### Hello:

All credit goes to the paper *Just Train Twice: Improving Group Robustness without Training Group Information*

and to its accompanying code: https://github.com/anniesch/jtt/tree/master

The code below was slightly modified and combined into a single Jupyter notebook.

This code should work with the provided configuration; if you change something, expect to experiment a bit. Lean heavily on the original GitHub repo, or work directly in a clone of jtt.

### Preliminary:

- `pip install` all dependencies imported two cells below
- download data (0.5GB) https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz
- unpack it and place it in the jtt folder
- suggested directory structure:

{jtt: main.ipynb, requirements.txt, waterbird_complete95_forest2water2, {results: log: {...}}}

If you cloned the Automating_Science repo, add the dataset folder to .gitignore.

This notebook is quite lengthy, but some people prefer it that way.

PS: some features (e.g. the chi-square geometry or label-shift support) are not provided or do not work in the original repo.

**Enjoy** and communicate on Piazza!

# CHAPTER 0

In [None]:
# for debugging: check which ipykernel version the notebook is running on
import ipykernel
ipykernel.__version__

In [None]:
import os, sys
import csv
import subprocess
from tqdm import tqdm
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.special import softmax

import bisect
import warnings

from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
import torchvision
from torchvision import transforms

In [None]:
# Device that the loss computer moves its tensors to via .to(device); CPU only here.
device = torch.device("cpu")

In [None]:
# --- Model / dataset / path configuration ----------------------------------
model = "resnet50"
model_type = model
dataset = "cub"

join = os.path.join
data_path = "waterbird_complete95_forest2water2"  # passed to CUBDataset as data_dir
metadata_path = f"{data_path}/metadata.csv"
results_dir = "results/model_outputs/"
log_dir = "results/log/"
root_dir = "./cub"

# exist_ok=True replaces the check-then-create pattern (no race, no-op if present).
os.makedirs(log_dir, exist_ok=True)

fraction = 1.0  # fraction of the train split kept (train_frac in get_splits)
val_fraction = 0.1

output_csv_name = "aa"
# Previously assigned twice ("bb" then "cc"); only the last value ever took effect.
job_script_name = "cc"

target = "waterbird_complete95"
confounder_name = ["forest2water2"]
augment_data = False
augment_data = False

In [None]:
exp_name = f"{dataset}_sample_exp"
train_from_scratch = True

# Set resume = True to continue a previous run. It only takes effect when
# log_dir already exists; in that case logs are appended ("a") instead of
# being overwritten ("w").
resume = False
resume = resume and os.path.exists(log_dir)
mode = "a" if resume else "w"

# Logging / checkpointing
log_every = 2
show_progress = False
save_step = 1
save_last = False
save_best = False

# Optimization
n_epochs = 3
final_epoch = n_epochs
lr = 1e-5
max_grad_norm = 1.0  # only bert!
batch_size = 64
wd = 1.0
gamma = 0.1
minimum_variational_weight = 0
use_bert_params = 1
scheduler_flag = False
warmup_steps = 0
adam_epsilon = 1e-8

# Training objective (loss_type must be one of the LossComputer loss types)
method = "ERM"
loss_type = ["erm", "group_dro", "joint_dro"][0]
hinge = False
btl = False

initialization_text = """"""
final_text = "***"
memory = 30
seed = 42

# Data shift / cross-validation folds
shift_type = "confounder"
up_weight = 0
fold = None
num_folds_per_sweep = 5
num_sweeps = 4
aug_col = None
reweight_groups = False

# Robust-loss hyper-parameters
alpha = 0.2
generalization_adjustment = "0.0"
automatic_adjustment = False
robust_step_size = 0.01
joint_dro_alpha = 1
use_normalized_loss = False

conf_threshold = 0.5
deploy = False
deploy = False

# CHAPTER 1.1

In [None]:
# Per-architecture settings looked up by model name:
#   feature_type      -> "image" means raw images go through torchvision transforms
#   target_resolution -> (height, width) fed to the resize/crop transforms
#   flatten           -> whether image tensors are flattened to 1-D in __getitem__
model_attributes = dict(
    resnet50=dict(
        feature_type="image",
        target_resolution=(224, 224),
        flatten=False,
    ),
)

In [None]:
def get_model(model, pretrained, resume, n_classes, dataset, log_dir, train_data=None):
    """Build (or reload) the classifier used for training.

    Args:
        model: architecture name; "resnet50", "resnet34" and "wideresnet50"
            are built with torchvision, while names whose ``model_attributes``
            feature_type is "precomputed"/"raw_flattened" get a linear model.
        pretrained: forwarded to the torchvision constructor; must be True for
            the linear (precomputed-feature) models.
        resume: if True, load the pickled model from ``log_dir/last_model.pth``.
        n_classes: output size of the final linear layer.
        dataset: unused here; kept for interface compatibility.
        log_dir: directory containing ``last_model.pth`` when resuming.
        train_data: dataset exposing ``input_size()``; only needed to infer the
            input dimension of the linear model.

    Returns:
        torch.nn.Module ready for training.

    Raises:
        ValueError: unknown architecture, or ``train_data`` missing for a
            linear model.
    """
    if resume:
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which cannot unpickle a full nn.Module — confirm
        # against the installed torch version.
        return torch.load(os.path.join(log_dir, "last_model.pth"))

    # .get() so architectures without an entry (resnet34, wideresnet50) do not
    # raise KeyError and fall through to the torchvision branches below.
    feature_type = model_attributes.get(model, {}).get("feature_type")
    if feature_type in ("precomputed", "raw_flattened"):
        assert pretrained
        if train_data is None:
            raise ValueError(f"{model}: train_data is required to infer the input size.")
        # Load precomputed features
        d = train_data.input_size()[0]
        linear = nn.Linear(d, n_classes)
        linear.has_aux_logits = False
        return linear

    if model == "resnet50":
        backbone = torchvision.models.resnet50(pretrained=pretrained)
    elif model == "resnet34":
        backbone = torchvision.models.resnet34(pretrained=pretrained)
    elif model == "wideresnet50":
        backbone = torchvision.models.wide_resnet50_2(pretrained=pretrained)
    else:
        raise ValueError(f"{model} Model not recognized.")

    # Replace the ImageNet head with one sized for this task.
    backbone.fc = nn.Linear(backbone.fc.in_features, n_classes)
    return backbone

In [None]:
class Logger(object):
    """Tee writes to the console (``sys.stdout`` at construction) and, optionally, a file.

    Args:
        fpath: path of the log file; if None, only the console is written.
        mode: file open mode, e.g. "w" to overwrite or "a" to append.
    """

    def __init__(self, fpath=None, mode="w"):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            self.file = open(fpath, mode)

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        """Close the log file (idempotent).

        The console stream is deliberately left open: it is the shared
        ``sys.stdout``, and closing it would break all later printing.
        """
        # getattr: __del__ may run on a partially initialised instance.
        log_file = getattr(self, "file", None)
        if log_file is not None:
            if not log_file.closed:
                log_file.close()
            self.file = None

In [None]:
class CSVBatchLogger:
    """Write per-batch training statistics as rows of a CSV file.

    Columns are "epoch", "batch", six per-group statistics for each of the
    ``n_groups`` groups, then five aggregate statistics. Keys missing from a
    logged dict are written as empty cells.

    Args:
        csv_path: output file path.
        n_groups: number of groups to create per-group columns for.
        mode: "w" writes a fresh file with a header; "a" appends without one.
    """

    def __init__(self, csv_path, n_groups, mode="w"):
        columns = ["epoch", "batch"]
        for idx in range(n_groups):
            columns.extend([
                f"avg_loss_group:{idx}",
                f"exp_avg_loss_group:{idx}",
                f"avg_acc_group:{idx}",
                f"processed_data_count_group:{idx}",
                f"update_data_count_group:{idx}",
                f"update_batch_count_group:{idx}",
            ])
        columns.extend([
            "avg_actual_loss",
            "avg_per_sample_loss",
            "avg_acc",
            "model_norm_sq",
            "reg_loss",
        ])

        self.path = csv_path
        # newline="" is required by the csv module; without it rows get an
        # extra blank line on platforms that translate "\n".
        self.file = open(csv_path, mode, newline="")
        self.columns = columns
        self.writer = csv.DictWriter(self.file, fieldnames=columns)
        if mode == "w":
            self.writer.writeheader()

    def log(self, epoch, batch, stats_dict):
        """Append one row; ``stats_dict`` keys must be a subset of ``columns``."""
        # Copy so the caller's dict is not mutated with epoch/batch keys.
        row = dict(stats_dict)
        row["epoch"] = epoch
        row["batch"] = batch
        self.writer.writerow(row)

    def flush(self):
        self.file.flush()

    def close(self):
        self.file.close()

In [None]:
class DRODataset(Dataset):
    """Wrap a dataset and cache its group / label statistics.

    Args:
        dataset: underlying dataset; when ``process_item_fn`` is None it must
            provide ``get_group_array()`` and ``get_label_array()``.
        process_item_fn: optional function applied to every item on access
            (group/label arrays are then unsupported).
        n_groups: number of groups (group ids counted are 0..n_groups-1).
        n_classes: number of classes (labels counted are 0..n_classes-1).
        group_str_fn: maps a group index to a human-readable name.
    """

    def __init__(self, dataset, process_item_fn, n_groups, n_classes,
                 group_str_fn):
        self.dataset = dataset
        self.process_item = process_item_fn
        self.n_groups = n_groups
        self.n_classes = n_classes
        self.group_str = group_str_fn

        self._group_array = torch.LongTensor(self.get_group_array())
        self._y_array = torch.LongTensor(self.get_label_array())
        # Counts via a (n_values, n_items) equality table summed per row;
        # ids outside [0, n) are simply not counted.
        self._group_counts = (torch.arange(
            self.n_groups).unsqueeze(1) == self._group_array).sum(1).float()
        self._y_counts = (torch.arange(
            self.n_classes).unsqueeze(1) == self._y_array).sum(1).float()

    def __getitem__(self, idx):
        if self.process_item is None:
            return self.dataset[idx]
        return self.process_item(self.dataset[idx])

    def get_group_array(self):
        if self.process_item is None:
            return self.dataset.get_group_array()
        raise NotImplementedError

    def get_label_array(self):
        if self.process_item is None:
            return self.dataset.get_label_array()
        raise NotImplementedError

    def __len__(self):
        return len(self.dataset)

    def group_counts(self):
        return self._group_counts

    def class_counts(self):
        return self._y_counts

    def input_size(self):
        """Size of the first example's input tensor, or None if the dataset is empty."""
        if len(self) == 0:
            return None
        return self[0][0].size()

In [None]:
def log_data(data, logger):
    """Write the number of examples in every group of each split to ``logger``.

    ``data`` must map "train_data", "val_data" and "test_data" to datasets that
    expose ``n_groups``, ``group_str(idx)`` and ``group_counts()``. The test
    entry may be None, in which case its section is skipped.
    """
    def write_group_counts(split):
        counts = split.group_counts()
        for group_idx in range(split.n_groups):
            logger.write(
                f"    {split.group_str(group_idx)}: n = {counts[group_idx]:.0f}\n"
            )

    required_sections = (
        ("Training Data...\n", "train_data"),
        ("Validation Data...\n", "val_data"),
    )
    for header, key in required_sections:
        logger.write(header)
        write_group_counts(data[key])

    test_split = data["test_data"]
    if test_split is not None:
        logger.write("Test Data...\n")
        write_group_counts(test_split)

In [None]:
def get_loader(dataset, train, reweight_groups, **kwargs):
    """Create a DataLoader for ``dataset``.

    * evaluation (``train`` falsy): sequential order; ``reweight_groups`` must be None.
    * training without reweighting: shuffled order.
    * training with reweighting: each example is drawn with probability
      proportional to 1 / (size of its group), with replacement, so every
      (y, c) group carries equal total weight. This turns plain ERM into a
      reweighted ERM; for robust (group DRO) training the minibatch is only
      used to estimate each group's mean gradient, so the objective is unchanged.

    Extra keyword arguments are forwarded to ``DataLoader``.
    """
    shuffle = False
    sampler = None
    if not train:
        assert reweight_groups is None
    elif reweight_groups:
        inverse_group_freq = len(dataset) / dataset._group_counts
        per_example_weight = inverse_group_freq[dataset._group_array]
        # replacement=True is required, otherwise minority samples run out.
        sampler = WeightedRandomSampler(
            per_example_weight, len(dataset), replacement=True
        )
    else:
        shuffle = True

    return DataLoader(dataset, shuffle=shuffle, sampler=sampler, **kwargs)

In [None]:
class Subset(torch.utils.data.Dataset):
    """
    View of ``dataset`` restricted to ``indices`` that keeps the group and
    label arrays of the selected examples available.

    NOTE: torch.utils.data.Subset loses original indexing.
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
        # Snapshots; they may be overwritten later (e.g. 2-group DRO) and read
        # back with re_evaluate=False.
        self.group_array = self.get_group_array(re_evaluate=True)
        self.label_array = self.get_label_array(re_evaluate=True)

    def __getitem__(self, idx):
        original_idx = self.indices[idx]
        return self.dataset[original_idx]

    def __len__(self):
        return len(self.indices)

    def get_group_array(self, re_evaluate=True):
        """Return an array [g_x1, g_x2, ...]"""
        # re_evaluate=False returns the stored (possibly overwritten) snapshot.
        if not re_evaluate:
            return self.group_array
        selected = self.dataset.get_group_array()[self.indices]
        assert len(selected) == len(self)
        return selected

    def get_label_array(self, re_evaluate=True):
        """Return the labels of the selected examples (or the stored snapshot)."""
        if not re_evaluate:
            return self.label_array
        selected = self.dataset.get_label_array()[self.indices]
        assert len(selected) == len(self)
        return selected


class ConcatDataset(torch.utils.data.ConcatDataset):
    """
    Concatenate datasets.

    Extends the default torch class with ``get_group_array`` /
    ``get_label_array``, which concatenate the member datasets' arrays in order.
    """
    def __init__(self, datasets):
        super().__init__(datasets)

    @staticmethod
    def _as_flat_list(array):
        # squeeze drops singleton axes (e.g. (n, 1) -> (n,)); atleast_1d keeps a
        # length-1 member iterable (squeeze alone yields an un-iterable 0-d array).
        return list(np.atleast_1d(np.squeeze(array)))

    def get_group_array(self):
        group_array = []
        for dataset in self.datasets:
            group_array += self._as_flat_list(dataset.get_group_array())
        return group_array

    def get_label_array(self):
        label_array = []
        for dataset in self.datasets:
            label_array += self._as_flat_list(dataset.get_label_array())
        return label_array


def get_fold(
    dataset,
    fold_arg=None,
    cross_validation_ratio=0.2,
    num_valid_per_point=4,
    seed=0,
    shuffle=True,
):
    """Returns (train, valid) splits of the dataset.

    Args:
      dataset (DRODataset): the dataset to split into (train, valid) splits.
      fold_arg (str or None): selects one split, formatted as
          "<name>_<sweep>_<fold>" (the part before the first "_" is ignored).
          When given, that single (train, valid) pair is returned, wrapped in
          DRODataset objects, instead of all folds.
      cross_validation_ratio (float): valid set size is this times the size of
          the dataset (rounded up).
      num_valid_per_point (int): number of times each point appears in a
          validation set (number of sweeps).
      seed (int): under the same seed, the output of this is guaranteed to be
          the same.
      shuffle (bool): whether to shuffle the training-set for cross validation
          or not (used for debugging can be removed later.)

    Returns:
      folds (list[list[(Subset, Subset)]]): the (train, valid) splits when
          fold_arg is None. In each outer list, the inner list valid sets span
          the entire dataset; each inner list has ceil(len / valid_size) folds.
      Otherwise a (DRODataset, DRODataset) tuple for the selected fold.

    Raises:
      ValueError: empty dataset, non-positive validation size, or a fold_arg
          pointing outside the generated sweeps/folds.
    """
    n_points = len(dataset)
    if n_points == 0:
        raise ValueError("get_fold requires a non-empty dataset")
    valid_size = int(np.ceil(n_points * cross_validation_ratio))
    if valid_size <= 0:
        raise ValueError(
            f"cross_validation_ratio must be positive, got {cross_validation_ratio}")
    num_valid_sets = int(np.ceil(n_points / valid_size))

    if fold_arg is not None:
        fields = fold_arg.split("_")[1:]
        sweep_ind = int(fields[0])
        fold_ind = int(fields[1])
        # Validate against the folds actually generated below; int(1 / ratio)
        # can disagree with num_valid_sets because of the ceil() roundings.
        if not 0 <= sweep_ind < num_valid_per_point:
            raise ValueError(
                f"sweep index {sweep_ind} out of range [0, {num_valid_per_point})")
        if not 0 <= fold_ind < num_valid_sets:
            raise ValueError(
                f"fold index {fold_ind} out of range [0, {num_valid_sets})")

    rng = np.random.RandomState(seed)

    all_folds = []
    for sweep_counter in range(num_valid_per_point):
        folds = []
        indices = list(range(n_points))
        if shuffle:
            rng.shuffle(indices)
        else:
            print("\n" * 10, "WARNING, NOT SHUFFLING", "\n" * 10)
        for i in range(num_valid_sets):
            start, stop = i * valid_size, (i + 1) * valid_size
            train_indices = indices[:start] + indices[stop:]
            print("len(train_indices)", len(train_indices))
            train_split = Subset(dataset, train_indices)

            valid_indices = indices[start:stop]
            print("len(valid_indices)", len(valid_indices))
            valid_split = Subset(dataset, valid_indices)
            if sweep_counter == 0 and i == 0:
                print("train_split", train_split, "valid_split", valid_split)
            folds.append((train_split, valid_split))
        all_folds.append(folds)

    if fold_arg is None:
        return all_folds

    train_data_subset, val_data_subset = all_folds[sweep_ind][fold_ind]
    # Wrap in DRODataset Objects
    train_data = DRODataset(
        train_data_subset,
        process_item_fn=None,
        n_groups=dataset.n_groups,
        n_classes=dataset.n_classes,
        group_str_fn=dataset.group_str,
    )
    val_data = DRODataset(
        val_data_subset,
        process_item_fn=None,
        n_groups=dataset.n_groups,
        n_classes=dataset.n_classes,
        group_str_fn=dataset.group_str,
    )
    return train_data, val_data

In [None]:
def apply_label_shift(dataset, n_classes, shift_type, minority_frac,
                      imbalance_ratio):
    """Apply a label-shift subsampling scheme to ``dataset``.

    Only "label_shift_step" (see ``step_shift``) is implemented.

    Raises:
        AssertionError: ``shift_type`` does not start with "label_shift".
        ValueError: unsupported label-shift variant (previously returned None).
    """
    assert shift_type.startswith("label_shift")
    if shift_type == "label_shift_step":
        return step_shift(dataset, n_classes, minority_frac, imbalance_ratio)
    raise ValueError(f"Unsupported label shift type: {shift_type}")


def step_shift(dataset, n_classes, minority_frac, imbalance_ratio):
    """Subsample ``dataset`` to a step-shaped (imbalanced) label distribution.

    Classes with index < (1 - minority_frac) * n_classes are "major" and keep
    ``major_count`` examples each; the others are "minor" and keep
    floor(major_count / imbalance_ratio). ``major_count`` is the largest value
    every class can supply (minor class counts are scaled by imbalance_ratio).
    Sampling uses the global NumPy RNG.

    Returns:
        Subset of ``dataset`` with the sampled indices.
    """
    # Read the label from position 1 of each item so both (x, y) pairs and
    # (x, y, g, idx) tuples are accepted (a 2-way unpack rejects the latter).
    y_array = torch.LongTensor([item[1] for item in dataset])
    y_counts = ((
        torch.arange(n_classes).unsqueeze(1) == y_array).sum(1)).float()
    # figure out sample size for each class
    is_major = (torch.arange(n_classes) <
                (1 - minority_frac) * n_classes).float()
    major_count = int(
        torch.min(is_major * y_counts +
                  (1 - is_major) * y_counts * imbalance_ratio).item())
    minor_count = int(np.floor(major_count / imbalance_ratio))
    print(y_counts, major_count, minor_count)
    # subsample
    labels = y_array.numpy()
    sampled_indices = []
    for y in np.arange(n_classes):
        (indices, ) = np.where(labels == y)
        np.random.shuffle(indices)
        sample_size = major_count if is_major[y] else minor_count
        sampled_indices.append(indices[:sample_size])
    sampled_indices = torch.from_numpy(np.concatenate(sampled_indices))
    return Subset(dataset, sampled_indices)


In [None]:
class ConfounderDataset(Dataset):
    """Base class for datasets whose groups are (label, confounder) combinations.

    ``__init__`` must be implemented by subclasses. The methods below read the
    attributes a subclass sets: ``y_array``, ``group_array``,
    ``filename_array``, ``split_array``, ``split_dict``, ``model_type``,
    ``data_dir``, ``features_mat``, ``train_transform``, ``eval_transform``,
    ``n_groups``, ``n_classes``, ``n_confounders``, ``target_name`` and
    ``confounder_names``.
    """
    def __init__(
        self,
        root_dir,
        target_name,
        confounder_names,
        model_type=None,
        augment_data=None,
    ):
        raise NotImplementedError

    def get_group_array(self):
        return self.group_array

    def get_label_array(self):
        return self.y_array

    def __len__(self):
        return len(self.filename_array)

    def __getitem__(self, idx):
        """Return ``(x, y, g, idx)`` for example ``idx``."""
        y = self.y_array[idx]
        g = self.group_array[idx]

        if model_attributes[self.model_type]["feature_type"] == "precomputed":
            x = self.features_mat[idx, :]
        else:
            img_filename = os.path.join(self.data_dir,
                                        self.filename_array[idx])
            # The context manager releases the file handle deterministically;
            # convert() returns a new, fully loaded image.
            with Image.open(img_filename) as raw_img:
                img = raw_img.convert("RGB")
            # Figure out split and transform accordingly
            if self.split_array[idx] == self.split_dict[
                    "train"] and self.train_transform:
                img = self.train_transform(img)
            elif (self.split_array[idx]
                  in [self.split_dict["val"], self.split_dict["test"]]
                  and self.eval_transform):
                img = self.eval_transform(img)
            # Flatten if needed
            if model_attributes[self.model_type]["flatten"]:
                assert img.dim() == 3
                img = img.view(-1)
            x = img

        return x, y, g, idx

    def get_splits(self, splits, train_frac=1.0):
        """Return ``{split: Subset}`` for each requested split name.

        When ``train_frac`` < 1, the "train" split keeps a random (global NumPy
        RNG), index-sorted fraction of its examples.
        """
        subsets = {}
        for split in splits:
            assert split in ("train", "val",
                             "test"), f"{split} is not a valid split"
            mask = self.split_array == self.split_dict[split]
            indices = np.where(mask)[0]
            if train_frac < 1 and split == "train":
                num_to_retain = int(np.round(float(len(indices)) * train_frac))
                indices = np.sort(
                    np.random.permutation(indices)[:num_to_retain])
            subsets[split] = Subset(self, indices)
        return subsets

    def group_str(self, group_idx):
        """Human-readable group name, e.g. "<target> = 1, <confounder> = 0"."""
        # Groups are laid out as y * groups_per_class + c; integer division on
        # both sides keeps y and c consistent.
        groups_per_class = self.n_groups // self.n_classes
        y = group_idx // groups_per_class
        c = group_idx % groups_per_class

        group_name = f"{self.target_name} = {int(y)}"
        # One bit per confounder, least-significant bit first.
        bin_str = format(int(c), f"0{self.n_confounders}b")[::-1]
        for attr_idx, attr_name in enumerate(self.confounder_names):
            group_name += f", {attr_name} = {bin_str[attr_idx]}"
        return group_name

In [None]:
class CUBDataset(ConfounderDataset):
    """
    CUB dataset (already cropped and centered).
    NOTE: metadata_df is one-indexed.

    Reads a metadata CSV with columns "y", "place", "img_filename" and
    "split" (0 = train, 1 = val, 2 = test). Groups are y * 2 + place.
    """
    def __init__(
        self,
        root_dir,
        target_name,
        confounder_names,
        augment_data=False,
        model_type=None,
        data_dir=None,
        metadata_csv_name="metadata.csv"
    ):
        self.root_dir = root_dir
        self.target_name = target_name
        self.confounder_names = confounder_names
        self.model_type = model_type
        self.augment_data = augment_data

        if data_dir is None:
            self.data_dir = os.path.join(
                self.root_dir, "data",
                "_".join([self.target_name] + self.confounder_names))
        else:
            self.data_dir = data_dir

        if not os.path.exists(self.data_dir):
            raise ValueError(
                f"{self.data_dir} does not exist yet. Please generate the dataset first."
            )

        print(self.data_dir)

        # Resolve the metadata file: a name relative to data_dir (the default
        # "metadata.csv") takes precedence; otherwise metadata_csv_name is used
        # as given (e.g. a path that already includes data_dir).
        metadata_file = os.path.join(self.data_dir, metadata_csv_name)
        if not os.path.exists(metadata_file):
            metadata_file = metadata_csv_name
        print(f"Reading '{metadata_file}'")
        self.metadata_df = pd.read_csv(metadata_file)

        # Get the y values
        self.y_array = self.metadata_df["y"].values
        self.n_classes = 2

        # We only support one confounder for CUB for now
        self.confounder_array = self.metadata_df["place"].values
        self.n_confounders = 1
        # Map to groups: group = y * 2 + place
        self.n_groups = pow(2, 2)
        assert self.n_groups == 4, "check the code if you are running otherwise"
        self.group_array = (self.y_array * (self.n_groups / 2) +
                            self.confounder_array).astype("int")

        # Extract filenames and splits
        self.filename_array = self.metadata_df["img_filename"].values
        self.split_array = self.metadata_df["split"].values
        self.split_dict = {
            "train": 0,
            "val": 1,
            "test": 2,
        }

        # Set transform
        if model_attributes[self.model_type]["feature_type"] == "precomputed":
            self.features_mat = torch.from_numpy(
                np.load(
                    os.path.join(
                        root_dir,
                        "features",
                        model_attributes[self.model_type]["feature_filename"],
                    ))).float()
            self.train_transform = None
            self.eval_transform = None
        else:
            self.features_mat = None
            self.train_transform = get_transform_cub(self.model_type,
                                                     train=True,
                                                     augment_data=augment_data)
            self.eval_transform = get_transform_cub(self.model_type,
                                                    train=False,
                                                    augment_data=augment_data)


def get_transform_cub(model_type, train, augment_data):
    """Build the torchvision preprocessing pipeline for CUB / Waterbirds images.

    Evaluation, or training without augmentation: resize to
    target_resolution * 256/224, then center-crop to target_resolution.
    Training with augmentation: random resized crop (scale 0.7-1.0, aspect
    ratio 3/4-4/3, bilinear) plus a random horizontal flip.
    Both end with ToTensor and ImageNet mean/std normalization.
    """
    scale = 256.0 / 224.0
    target_resolution = model_attributes[model_type]["target_resolution"]
    assert target_resolution is not None

    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]

    if (not train) or (not augment_data):
        # Resizes the image to a slightly larger square then crops the center.
        spatial = [
            transforms.Resize((
                int(target_resolution[0] * scale),
                int(target_resolution[1] * scale),
            )),
            transforms.CenterCrop(target_resolution),
        ]
    else:
        spatial = [
            transforms.RandomResizedCrop(
                target_resolution,
                scale=(0.7, 1.0),
                ratio=(0.75, 1.3333333333333333),
                # Enum instead of the legacy PIL int code 2 (bilinear).
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
            transforms.RandomHorizontalFlip(),
        ]
    return transforms.Compose(spatial + to_normalized_tensor)

In [None]:
# Dataset key -> dataset class; both spellings of the key resolve to CUBDataset.
confounder_settings = {
    key: {"constructor": CUBDataset} for key in ("cub", "CUB")
}

def prepare_confounder_data(train, return_full_dataset=False):
    """Build the confounder dataset selected by the notebook globals.

    Uses ``dataset``, ``root_dir``, ``target``, ``confounder_name``, ``model``,
    ``augment_data``, ``data_path``, ``metadata_path`` and ``fraction``.

    Returns:
        A single DRODataset over all examples if ``return_full_dataset``;
        otherwise a list of DRODatasets for ["train", "val", "test"] when
        ``train`` is truthy, or just ["test"].
    """
    constructor = confounder_settings[dataset]["constructor"]
    # NOTE(review): `model` must still be the architecture-name string here,
    # not a built nn.Module — confirm the notebook never rebinds it first.
    full_dataset = constructor(
        root_dir=root_dir,
        target_name=target,
        confounder_names=confounder_name,
        model_type=model,
        augment_data=augment_data,
        data_dir=data_path,
        metadata_csv_name=metadata_path if metadata_path is not None else "metadata.csv",
    )

    def wrap(subset):
        return DRODataset(
            subset,
            process_item_fn=None,
            n_groups=full_dataset.n_groups,
            n_classes=full_dataset.n_classes,
            group_str_fn=full_dataset.group_str,
        )

    if return_full_dataset:
        return wrap(full_dataset)

    splits = ["train", "val", "test"] if train else ["test"]
    subsets = full_dataset.get_splits(splits, train_frac=fraction)
    return [wrap(subsets[split]) for split in splits]

In [None]:
# Default root directory per dataset key (used when root_dir is None).
dataset_attributes = {key: {"root_dir": "cub"} for key in ("cub", "CUB")}
def prepare_data(train, return_full_dataset=False):
    """Prepare datasets according to the notebook-level ``shift_type``.

    Falls back to ``dataset_attributes[dataset]["root_dir"]`` when the global
    ``root_dir`` is None (and updates that global).

    Raises:
        NotImplementedError: ``shift_type`` other than "confounder" (previously
            the function silently returned None).
    """
    global root_dir
    # Set root_dir to defaults if necessary
    if root_dir is None:
        root_dir = dataset_attributes[dataset]["root_dir"]
    if shift_type == "confounder":
        return prepare_confounder_data(
            train,
            return_full_dataset,
        )
    raise NotImplementedError(
        f"shift_type '{shift_type}' is not supported in this notebook")

In [None]:
# Re-declares the same key -> dataset-class mapping as the earlier cell, so
# this cell can be re-run on its own.
confounder_settings = {
    name: {"constructor": CUBDataset} for name in ("cub", "CUB")
}

In [None]:
def get_spurious_col_csv():
    """Write a copy of the metadata CSV with a boolean "spurious" column.

    A training-split row (split == 0) is spurious when its label "y" differs
    from its "place" confounder; all other rows are marked False. The file is
    written to results_dir/<dataset>/<exp_name>/<output_csv_name>, creating
    the directory if needed.
    """
    # output_dir: results_dir/<dataset>/<exp_name>/
    output_dir = join(join(results_dir, dataset), exp_name)
    os.makedirs(output_dir, exist_ok=True)
    output_path = join(output_dir, output_csv_name)

    new_metadata = pd.read_csv(metadata_path)
    split_name = "split"
    index_col = "img_id"

    # .copy() gives an independent frame, so adding a column does not trigger
    # pandas' SettingWithCopyWarning.
    train_data = new_metadata[new_metadata[split_name] == 0].copy()
    train_data["spurious"] = train_data["y"] != train_data["place"]

    spur_col = train_data[["spurious", index_col]]
    new_metadata = pd.merge(
        new_metadata, spur_col, how="outer", on=index_col
    )
    # Non-training rows have NaN after the merge; eq(True) maps True -> True and
    # False/NaN -> False without touching NaNs in other columns.
    new_metadata["spurious"] = new_metadata["spurious"].eq(True)

    # Save metadata
    new_metadata.to_csv(output_path)

In [None]:
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    and make cuDNN deterministic (benchmark mode off)."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
def hinge_loss(yhat, y):
    """Per-example margin hinge loss for binary classification.

    ``yhat[:, 1]`` is the score for class 1 and ``yhat[:, 0]`` for class 0.
    Labels in {0, 1} are mapped to {-1, +1}; the loss is
    max(0, 1 - t * (yhat[:, 1] - yhat[:, 0])) for each example (no reduction).
    """
    signed_target = y.float() * 2.0 - 1.0
    class1_score = yhat[:, 1]
    class0_score = yhat[:, 0]
    return torch.nn.functional.margin_ranking_loss(
        class1_score, class0_score, signed_target, margin=1.0, reduction="none"
    )

In [None]:
GEOMETRIES = ("cvar", "chi-square")  # uncertainty-set geometries RobustLoss accepts


def cvar_value(p, v, reg):
    """Return <p, v> - reg * KL(p || uniform) for Torch tensors.

    The KL term is computed without gradients over the nonzero entries of p
    (0 * log 0 is taken as 0), so gradients flow only through <p, v>.
    """
    n_items = p.shape[0]

    with torch.no_grad():
        support = torch.nonzero(p)  # torch.where is annoyingly backwards incompatible
        p_support = p[support]
        kl_to_uniform = np.log(n_items) + (p_support * torch.log(p_support)).sum()

    return torch.dot(p, v) - reg * kl_to_uniform

def chi_square_value(p, v, reg):
    """Return <p, v> - reg * chi^2(p, uniform) for Torch tensors.

    chi^2(p, uniform) = (1 / (2m)) * ||m * p - 1||^2 = 0.5 * mean((m * p - 1)^2),
    the same divergence that the constrained chi-square best response targets.
    The penalty is computed without gradients, so gradients flow only
    through <p, v>. (Previously a TODO stub that returned None.)
    """
    m = p.shape[0]

    with torch.no_grad():
        chi2 = (0.5 / m) * reg * (torch.norm(m * p - torch.ones_like(p), p=2) ** 2)

    return torch.dot(p, v) - chi2

In [None]:
class RobustLoss(torch.nn.Module):
    """PyTorch module for the batch robust loss estimator"""
    def __init__(self, size, reg, geometry, tol=1e-4,
                 max_iter=1000, debugging=False):
        """
        Parameters
        ----------
        size : float
            Size of the uncertainty set (\rho for \chi^2 and \alpha for CVaR)
            Set float('inf') for unconstrained
        reg : float
            Strength of the regularizer, entropy if geometry == 'cvar'
            $\chi^2$ divergence if geometry == 'chi-square'
        geometry : string
            Element of GEOMETRIES
        tol : float, optional
            Tolerance parameter for the bisection
        max_iter : int, optional
            Number of iterations after which to break the bisection
        """
        super().__init__()
        self.size = size
        self.reg = reg
        self.geometry = geometry
        self.tol = tol
        self.max_iter = max_iter
        self.debugging = debugging

        self.is_erm = size == 0

        if geometry not in GEOMETRIES:
            raise ValueError('Geometry %s not supported' % geometry)

        if geometry == 'cvar' and self.size > 1:
            raise ValueError(f'alpha should be < 1 for cvar, is {self.size}')

    def best_response(self, v):
        size = self.size
        reg = self.reg
        m = v.shape[0]

        if self.geometry == 'cvar':
            if self.reg > 0:
                if size == 1.0:
                    return torch.ones_like(v) / m

                def p(eta):
                    x = (v - eta) / reg
                    return torch.min(torch.exp(x),
                                     torch.Tensor([1 / size]).type(x.dtype)) / m

                def bisection_target(eta):
                    return 1.0 - p(eta).sum()

                eta_min = reg * torch.logsumexp(v / reg - np.log(m), 0)
                eta_max = v.max()

                if torch.abs(bisection_target(eta_min)) <= self.tol:
                    return p(eta_min)
            else:
                cutoff = int(size * m)
                surplus = 1.0 - cutoff / (size * m)

                p = torch.zeros_like(v)
                idx = torch.argsort(v, descending=True)
                p[idx[:cutoff]] = 1.0 / (size * m)
                if cutoff < m:
                    p[idx[cutoff]] = surplus
                return p

        if self.geometry == 'chi-square':
            if (v.max() - v.min()) / v.max() <= MIN_REL_DIFFERENCE:
                return torch.ones_like(v) / m

            if size == float('inf'):
                assert reg > 0

                def p(eta):
                    return torch.relu(v - eta) / (reg * m)

                def bisection_target(eta):
                    return 1.0 - p(eta).sum()

                eta_min = min(v.sum() - reg * m, v.min())
                eta_max = v.max()

            else:
                assert size < float('inf')

                # failsafe for batch sizes small compared to
                # uncertainty set size
                if m <= 1 + 2 * size:
                    out = (v == v.max()).float()
                    out /= out.sum()
                    return out

                if reg == 0:
                    def p(eta):
                        pp = torch.relu(v - eta)
                        return pp / pp.sum()

                    def bisection_target(eta):
                        pp = p(eta)
                        w = m * pp - torch.ones_like(pp)
                        return 0.5 * torch.mean(w ** 2) - size

                    eta_min = -(1.0 / (np.sqrt(2 * size + 1) - 1)) * v.max()
                    eta_max = v.max()
                else:
                    def p(eta):
                        pp = torch.relu(v - eta)

                        opt_lam = max(
                            reg, torch.norm(pp) / np.sqrt(m * (1 + 2 * size))
                        )

                        return pp / (m * opt_lam)

                    def bisection_target(eta):
                        return 1 - p(eta).sum()

                    eta_min = v.min() - 1
                    eta_max = v.max()

        eta_star = bisection(
            eta_min, eta_max, bisection_target,
            tol=self.tol, max_iter=self.max_iter)

        if self.debugging:
            return p(eta_star), eta_star
        return p(eta_star)

    def forward(self, v):
        """Value of the robust loss
        Note that the best response is computed without gradients
        Parameters
        ----------
        v : torch.Tensor
            Tensor containing the individual losses on the batch of examples
        Returns
        -------
        loss : torch.float
            Value of the robust loss on the batch of examples
        """
        if self.is_erm:
            return v.mean()
        else:
            with torch.no_grad():
                p = self.best_response(v)

            if self.geometry == 'cvar':
                return cvar_value(p, v, self.reg)
            elif self.geometry == 'chi-square':
                return chi_square_value(p, v, self.reg)



In [None]:
class LossComputer:
    """Compute the batch objective and track running per-group statistics.

    ``loss_type`` selects the objective:

    * ``"erm"`` -- mean of the per-sample losses;
    * ``"group_dro"`` -- per-group mean losses weighted by ``adv_probs``.
      By default the weights get a multiplicative update on every call
      (``compute_robust_loss``). With ``btl=True`` they instead come from
      ``compute_robust_loss_greedy``, ranked by the exponential moving
      average of the group losses;
    * ``"joint_dro"`` -- ``RobustLoss`` with CVaR geometry at level
      ``joint_dro_alpha`` over the per-sample losses.

    ``dataset`` must expose ``n_groups``, ``group_counts()`` and
    ``group_str(idx)``. Internal tensors are placed on the notebook-level
    ``device``.
    """

    def __init__(
        self,
        criterion,
        loss_type,
        dataset,
        alpha=None,
        gamma=0.1,
        adj=None,
        min_var_weight=0,
        step_size=0.01,
        normalize_loss=False,
        btl=False,
        joint_dro_alpha=None,
    ):
        """
        Parameters
        ----------
        criterion : callable
            Per-sample loss ``criterion(yhat, y)`` (no reduction).
        loss_type : {"group_dro", "erm", "joint_dro"}
        dataset :
            Provides ``n_groups``, ``group_counts()`` and ``group_str``.
        alpha : float, optional
            Mass level of the greedy weighting; required for group_dro.
        gamma : float
            Weight of the current batch in the group-loss moving average.
        adj : numpy.ndarray, optional
            Per-group generalization adjustment; ``adj / sqrt(group_count)``
            is added to the group losses used to update the group weights
            (only when every entry is > 0).
        min_var_weight : float
            Share of the greedy weights mixed back towards group frequencies.
        step_size : float
            Step size of the multiplicative ``adv_probs`` update.
        normalize_loss : bool
            Normalize the adjusted group losses to sum to 1 before the update.
        btl : bool
            Use the greedy weighting instead of the multiplicative update.
        joint_dro_alpha : float, optional
            CVaR level for ``loss_type="joint_dro"``.
        """
        assert loss_type in ["group_dro", "erm", "joint_dro"]

        self.criterion = criterion
        self.loss_type = loss_type
        self.gamma = gamma
        self.alpha = alpha
        self.min_var_weight = min_var_weight
        self.step_size = step_size
        self.normalize_loss = normalize_loss
        self.btl = btl

        self.n_groups = dataset.n_groups

        self.group_counts = dataset.group_counts().to(device)
        self.group_frac = self.group_counts / self.group_counts.sum()
        self.group_str = dataset.group_str

        if self.loss_type == "joint_dro":
            # Joint DRO reg should be 0.
            assert joint_dro_alpha is not None
            self._joint_dro_loss_computer = RobustLoss(
                    joint_dro_alpha, 0, "cvar")

        if adj is not None:
            self.adj = torch.from_numpy(adj).float().to(device)
        else:
            self.adj = torch.zeros(self.n_groups).float().to(device)

        if loss_type == "group_dro":
            assert alpha, "alpha must be specified"

        # quantities maintained throughout training
        self.adv_probs = torch.ones(self.n_groups).to(device) / self.n_groups
        self.exp_avg_loss = torch.zeros(self.n_groups).to(device)
        self.exp_avg_initialized = torch.zeros(self.n_groups).byte().to(device)

        self.reset_stats()

    def loss(self, yhat, y, group_idx=None, is_training=False):
        """Return the scalar objective for one batch and update statistics.

        Parameters
        ----------
        yhat : torch.Tensor
            Model outputs passed to ``criterion``; the argmax over dim 1 is
            used as the predicted class for the accuracy statistics.
        y : torch.Tensor
            Target labels.
        group_idx : torch.Tensor
            Group index of every sample (required despite the default).
        is_training : bool
            Unused; kept for interface compatibility.
        """
        # compute per-sample and per-group losses
        per_sample_losses = self.criterion(yhat, y)
        group_loss, group_count = self.compute_group_avg(
            per_sample_losses, group_idx)
        group_acc, _ = self.compute_group_avg(
            (torch.argmax(yhat, 1) == y).float(), group_idx)

        # update historical losses
        self.update_exp_avg_loss(group_loss, group_count)

        # compute overall loss
        if self.loss_type == "group_dro":
            if not self.btl:
                actual_loss, weights = self.compute_robust_loss(
                    group_loss, group_count)
            else:
                actual_loss, weights = self.compute_robust_loss_btl(
                    group_loss, group_count)
        elif self.loss_type == "joint_dro":
            actual_loss = self._joint_dro_loss_computer(per_sample_losses)
            weights = None
        else:
            assert self.loss_type == "erm"

            actual_loss = per_sample_losses.mean()
            weights = None

        # update stats
        self.update_stats(actual_loss, group_loss, group_acc, group_count,
                          weights)

        return actual_loss

    def compute_robust_loss(self, group_loss, group_count):
        """Multiplicative update of ``adv_probs``; return (weighted loss, weights).

        The adjustment and normalization only shape the weight update. The
        returned loss is ``group_loss @ adv_probs`` on the unadjusted losses.
        """
        # Build the adjusted loss out of place: ``+=`` on an alias of
        # ``group_loss`` used to mutate the caller's tensor, leaking the
        # adjustment into the returned loss and the logged group losses.
        adjusted_loss = group_loss.detach()
        if torch.all(self.adj > 0):
            adjusted_loss = adjusted_loss + self.adj / torch.sqrt(self.group_counts)
        if self.normalize_loss:
            adjusted_loss = adjusted_loss / adjusted_loss.sum()
        self.adv_probs = self.adv_probs * torch.exp(
            self.step_size * adjusted_loss)
        self.adv_probs = self.adv_probs / self.adv_probs.sum()

        robust_loss = group_loss @ self.adv_probs
        return robust_loss, self.adv_probs

    def compute_robust_loss_btl(self, group_loss, group_count):
        """Greedy weighting ranked by the adjusted moving-average group loss."""
        adjusted_loss = self.exp_avg_loss + self.adj / torch.sqrt(
            self.group_counts)
        return self.compute_robust_loss_greedy(group_loss, adjusted_loss)

    def compute_robust_loss_greedy(self, group_loss, ref_loss):
        """Weight the worst groups (by ``ref_loss``) up to mass ``alpha``.

        Groups are sorted by ``ref_loss`` in descending order. Whole groups
        receive ``frac / alpha`` while their cumulative frequency stays
        <= ``alpha``, and the next group takes the remaining weight. The
        result is then mixed with the group frequencies by ``min_var_weight``.

        Returns
        -------
        (robust_loss, weights)
            ``weights`` is in the original group order.
        """
        sorted_idx = ref_loss.sort(descending=True)[1]
        sorted_loss = group_loss[sorted_idx]
        sorted_frac = self.group_frac[sorted_idx]

        mask = torch.cumsum(sorted_frac, dim=0) <= self.alpha
        weights = mask.float() * sorted_frac / self.alpha
        last_idx = mask.sum()
        # When every group fits under alpha there is no partial group left;
        # indexing past the end used to raise IndexError.
        if last_idx < weights.numel():
            weights[last_idx] = 1 - weights.sum()
        weights = sorted_frac * self.min_var_weight + weights * (
            1 - self.min_var_weight)

        robust_loss = sorted_loss @ weights

        # sort the weights back
        _, unsort_idx = sorted_idx.sort()
        unsorted_weights = weights[unsort_idx]
        return robust_loss, unsorted_weights

    def compute_group_avg(self, losses, group_idx):
        """Return (per-group mean of ``losses``, per-group sample count).

        Groups absent from the batch get a mean of 0.
        """
        group_map = (group_idx == torch.arange(
            self.n_groups).unsqueeze(1).long().to(device)).float()

        group_count = group_map.sum(1)
        group_denom = group_count + (group_count == 0).float()  # avoid nans
        group_loss = (group_map @ losses.view(-1)) / group_denom
        return group_loss, group_count

    def update_exp_avg_loss(self, group_loss, group_count):
        """Update the moving average of group losses for groups seen in the batch.

        A group's first observation initializes its average directly.
        """
        # Detach: this buffer lives for the whole run and must not chain
        # autograd graphs from every batch.
        group_loss = group_loss.detach()
        prev_weights = (1 - self.gamma * (group_count > 0).float()) * (
            self.exp_avg_initialized > 0).float()
        curr_weights = 1 - prev_weights
        self.exp_avg_loss = self.exp_avg_loss * prev_weights + group_loss * curr_weights
        self.exp_avg_initialized = (self.exp_avg_initialized > 0) | (group_count > 0)

    def reset_stats(self):
        """Zero the running statistics reported by ``get_stats``/``log_stats``."""
        self.processed_data_counts = torch.zeros(self.n_groups).to(device)
        self.update_data_counts = torch.zeros(self.n_groups).to(device)
        self.update_batch_counts = torch.zeros(self.n_groups).to(device)
        self.avg_group_loss = torch.zeros(self.n_groups).to(device)
        self.avg_group_acc = torch.zeros(self.n_groups).to(device)
        self.avg_per_sample_loss = 0.0
        self.avg_actual_loss = 0.0
        self.avg_acc = 0.0
        self.batch_count = 0.0

    def update_stats(self,
                     actual_loss,
                     group_loss,
                     group_acc,
                     group_count,
                     weights=None):
        """Fold one batch into the running averages and counters.

        ``weights`` (the group weights) is required for group_dro.
        """
        # Statistics are for reporting only; keep them out of autograd.
        group_loss = group_loss.detach()
        if torch.is_tensor(actual_loss):
            actual_loss = actual_loss.detach()

        # avg group loss, weighted by the number of processed samples
        denom = self.processed_data_counts + group_count
        denom += (denom == 0).float()
        prev_weight = self.processed_data_counts / denom
        curr_weight = group_count / denom
        self.avg_group_loss = prev_weight * self.avg_group_loss + curr_weight * group_loss

        # avg group acc
        self.avg_group_acc = prev_weight * self.avg_group_acc + curr_weight * group_acc

        # batch-wise average actual loss
        denom = self.batch_count + 1
        self.avg_actual_loss = (self.batch_count /
                                denom) * self.avg_actual_loss + (
                                    1 / denom) * actual_loss

        # counts
        self.processed_data_counts += group_count
        if self.loss_type == "group_dro":
            self.update_data_counts += group_count * ((weights > 0).float())
            self.update_batch_counts += ((group_count * weights) > 0).float()
        else:
            self.update_data_counts += group_count
            self.update_batch_counts += (group_count > 0).float()
        self.batch_count += 1

        # avg per-sample quantities
        group_frac = self.processed_data_counts / (
            self.processed_data_counts.sum())
        self.avg_per_sample_loss = group_frac @ self.avg_group_loss
        self.avg_acc = group_frac @ self.avg_group_acc

    def get_model_stats(self, model, stats_dict):
        """Add the squared parameter norm and ``wd / 2`` times it to ``stats_dict``.

        Uses the notebook-level weight decay ``wd``.
        """
        model_norm_sq = 0.0
        with torch.no_grad():
            for param in model.parameters():
                model_norm_sq += torch.norm(param)**2
        model_norm_sq = float(model_norm_sq)
        stats_dict["model_norm_sq"] = model_norm_sq
        stats_dict["reg_loss"] = wd / 2 * model_norm_sq
        return stats_dict

    def get_stats(self, model=None):
        """Return the running statistics as a flat dict of Python numbers."""
        stats_dict = {}
        for idx in range(self.n_groups):
            stats_dict[f"avg_loss_group:{idx}"] = self.avg_group_loss[
                idx].item()
            stats_dict[f"exp_avg_loss_group:{idx}"] = self.exp_avg_loss[
                idx].item()
            stats_dict[f"avg_acc_group:{idx}"] = self.avg_group_acc[idx].item()
            stats_dict[
                f"processed_data_count_group:{idx}"] = self.processed_data_counts[
                    idx].item()
            stats_dict[
                f"update_data_count_group:{idx}"] = self.update_data_counts[
                    idx].item()
            stats_dict[
                f"update_batch_count_group:{idx}"] = self.update_batch_counts[
                    idx].item()

        # float() works both for tensors and for the 0.0 set by reset_stats.
        stats_dict["avg_actual_loss"] = float(self.avg_actual_loss)
        stats_dict["avg_per_sample_loss"] = float(self.avg_per_sample_loss)
        stats_dict["avg_acc"] = float(self.avg_acc)

        # Model stats
        if model is not None:
            stats_dict = self.get_model_stats(model, stats_dict)

        return stats_dict

    def log_stats(self, logger, is_training):
        """Write a human-readable summary to ``logger`` (no-op when None)."""
        if logger is None:
            return

        logger.write(
            f"Average incurred loss: {float(self.avg_per_sample_loss):.3f}  \n"
        )
        logger.write(
            f"Average sample loss: {float(self.avg_actual_loss):.3f}  \n")
        logger.write(f"Average acc: {float(self.avg_acc):.3f}  \n")
        for group_idx in range(self.n_groups):
            logger.write(
                f"  {self.group_str(group_idx)}  "
                f"[n = {int(self.processed_data_counts[group_idx])}]:\t"
                f"loss = {self.avg_group_loss[group_idx]:.3f}  "
                f"exp loss = {self.exp_avg_loss[group_idx]:.3f}  "
                f"adjusted loss = {self.exp_avg_loss[group_idx] + self.adj[group_idx]/torch.sqrt(self.group_counts)[group_idx]:.3f}  "
                f"adv prob = {self.adv_probs[group_idx]:.3f}   "
                f"acc = {self.avg_group_acc[group_idx]:.3f}\n")
        logger.flush()

In [None]:
def run_epoch(
    epoch,
    model,
    optimizer,
    loader,
    loss_computer,
    logger,
    csv_logger,
    is_training,
    show_progress=False,
    log_every=50,
    scheduler=None,
    csv_name=None,
    group=None,
):
    """Run one pass over ``loader``, training or evaluating ``model``.

    Each batch is ``(x, y, g, data_idx, ...)``. The model outputs go through
    ``loss_computer.loss``. When ``is_training``, the model is updated, and
    every ``log_every`` batches the running statistics are written to
    ``csv_logger``/``logger`` and then reset.

    After the pass, the predictions for the whole split are saved to
    ``output_{group}_epoch_{epoch}.csv`` in the directory of
    ``csv_logger.path``. The columns are prefixed with
    ``f"{csv_name}_epoch_{epoch}_val"``, and the ``pred_prob_*`` columns hold
    the raw model outputs. The final statistics are then logged. An empty
    loader produces no output and no log row.

    ``scheduler`` is only stepped here (once per batch) for BERT models
    trained with ``use_bert_params``; other schedulers are the caller's job.

    Relies on the notebook globals ``device`` and ``model_type`` and, for
    BERT training, ``use_bert_params`` and ``max_grad_norm``.
    """
    bert_update = is_training and model_type.startswith("bert") and use_bert_params

    if is_training:
        model.train()
        if bert_update:
            model.zero_grad()
    else:
        model.eval()

    batches = tqdm(loader) if show_progress else loader

    # Collect per-batch results and concatenate once at the end; rebuilding
    # the arrays and a DataFrame after every batch was quadratic.
    pred_chunks, true_chunks, index_chunks, output_chunks = [], [], [], []
    batch_idx = -1

    with torch.set_grad_enabled(is_training):
        for batch_idx, batch in enumerate(batches):
            x, y, g, data_idx = tuple(t.to(device) for t in batch)[:4]

            if model_type.startswith("bert"):
                outputs = model(
                    input_ids=x[:, :, 0],
                    attention_mask=x[:, :, 1],
                    token_type_ids=x[:, :, 2],
                    labels=y,
                )[1]  # [1] returns logits
            else:
                # outputs.shape: (batch_size, num_classes)
                outputs = model(x)

            batch_outputs = outputs.detach().cpu().numpy()
            output_chunks.append(batch_outputs)
            pred_chunks.append(np.argmax(batch_outputs, axis=1))
            true_chunks.append(y.cpu().numpy())
            index_chunks.append(data_idx.cpu().numpy())

            loss_main = loss_computer.loss(outputs, y, g, is_training)

            if is_training:
                if bert_update:
                    loss_main.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   max_grad_norm)
                    # optimizer.step() must come before scheduler.step()
                    # (PyTorch >= 1.1); the reverse order skips the first
                    # value of the LR schedule.
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                else:
                    optimizer.zero_grad()
                    loss_main.backward()
                    optimizer.step()

            if is_training and (batch_idx + 1) % log_every == 0:
                run_stats = loss_computer.get_stats(model)
                csv_logger.log(epoch, batch_idx, run_stats)
                csv_logger.flush()
                loss_computer.log_stats(logger, is_training)
                loss_computer.reset_stats()

    if batch_idx < 0:
        # Empty loader: nothing to save or log.
        return

    # Save per-example predictions for this split.
    probs = np.concatenate(output_chunks, axis=0)
    indices = np.concatenate(index_chunks)
    assert probs.shape[0] == indices.shape[0]
    # TODO: make this cleaner (the "_val" suffix is used for every split).
    run_name = f"{csv_name}_epoch_{epoch}_val"
    output_df = pd.DataFrame({
        f"y_pred_{run_name}": np.concatenate(pred_chunks),
        f"y_true_{run_name}": np.concatenate(true_chunks),
        f"indices_{run_name}": indices,
    })
    for class_ind in range(probs.shape[1]):
        output_df[f"pred_prob_{run_name}_{class_ind}"] = probs[:, class_ind]

    save_dir = os.path.dirname(csv_logger.path)
    output_path = os.path.join(save_dir, f"output_{group}_epoch_{epoch}.csv")
    output_df.to_csv(output_path)
    print("Saved", output_path)

    if (not is_training) or loss_computer.batch_count > 0:
        run_stats = loss_computer.get_stats(model)
        csv_logger.log(epoch, batch_idx, run_stats)
        csv_logger.flush()
        loss_computer.log_stats(logger, is_training)
        if is_training:
            loss_computer.reset_stats()


def _warmup_linear_lambda(warmup_steps, t_total):
    """Return a LambdaLR multiplier: linear warm-up, then linear decay to zero.

    The multiplier rises from 0 to 1 over the first ``warmup_steps`` steps and
    then falls linearly to 0 at step ``t_total``.
    """

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (t_total - step) / max(1, t_total - warmup_steps))

    return lr_lambda


def _make_loss_computer(criterion, split_data, adjustments):
    """Build a LossComputer for one data split from the notebook's loss settings."""
    return LossComputer(
        criterion,
        loss_type=loss_type,
        dataset=split_data,
        alpha=alpha,
        gamma=gamma,
        adj=adjustments,
        step_size=robust_step_size,
        normalize_loss=use_normalized_loss,
        btl=btl,
        min_var_weight=minimum_variational_weight,
        joint_dro_alpha=joint_dro_alpha,
    )


def _build_optimizer(model, train_loader):
    """Create ``(optimizer, scheduler)`` for ``model``.

    BERT with ``use_bert_params``: AdamW without weight decay on biases and
    LayerNorm weights, plus a per-step warm-up/linear-decay schedule (stepped
    in ``run_epoch``). Otherwise: SGD with momentum 0.9 and, if
    ``scheduler_flag``, ReduceLROnPlateau (stepped per epoch in ``train``).
    """
    if model_type.startswith("bert") and use_bert_params:
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": wd,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=lr,
                          eps=adam_epsilon)
        t_total = len(train_loader) * n_epochs
        print(f"\nt_total is {t_total}\n")
        # torch.optim.lr_scheduler.LinearLR accepts no warmup_steps/t_total
        # arguments (the old call raised TypeError); express the warm-up +
        # linear decay schedule with LambdaLR instead.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, _warmup_linear_lambda(warmup_steps, t_total))
        return optimizer, scheduler

    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr,
        momentum=0.9,
        weight_decay=wd,
    )
    scheduler = None
    if scheduler_flag:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            "min",
            factor=0.1,
            patience=5,
            threshold=0.0001,
            min_lr=0,
            eps=1e-08,
        )
    return optimizer, scheduler


def train(
    model,
    criterion,
    dataset,
    logger,
    train_csv_logger,
    val_csv_logger,
    test_csv_logger,
    epoch_offset,
    csv_name=None,
):
    """Train ``model`` for ``n_epochs`` epochs, evaluating after each one.

    Parameters
    ----------
    model : torch.nn.Module
    criterion : callable
        Per-sample loss (no reduction).
    dataset : dict
        ``train_data``/``val_data``/``test_data`` and matching ``*_loader``
        entries; ``test_data`` may be None.
    logger, train_csv_logger, val_csv_logger, test_csv_logger :
        Text and per-split CSV loggers passed through to ``run_epoch``.
    epoch_offset : int
        Number of the first epoch.
    csv_name : optional
        Column prefix used by ``run_epoch`` for the saved predictions.

    Each epoch:
    * trains, then evaluates on validation (and on test when available,
      without printing);
    * steps ReduceLROnPlateau if one is in use;
    * saves checkpoints to ``log_dir`` according to
      ``save_step``/``save_last``/``save_best``;
    * optionally recomputes the generalization adjustments.

    Hyper-parameters come from notebook globals.
    """
    model = model.to(device)

    # Generalization adjustments: one value for all groups, or one per group.
    n_groups = dataset["train_data"].n_groups
    adjustments = [float(c) for c in generalization_adjustment.split(",")]
    assert len(adjustments) in (1, n_groups)
    if len(adjustments) == 1:
        adjustments = adjustments * n_groups
    adjustments = np.array(adjustments)

    train_loss_computer = _make_loss_computer(
        criterion, dataset["train_data"], adjustments)
    optimizer, scheduler = _build_optimizer(model, dataset["train_loader"])

    best_val_acc = 0
    for epoch in range(epoch_offset, epoch_offset + n_epochs):
        logger.write("\nEpoch [%d]:\n" % epoch)
        logger.write("Training:\n")
        run_epoch(
            epoch,
            model,
            optimizer,
            dataset["train_loader"],
            train_loss_computer,
            logger,
            train_csv_logger,
            is_training=True,
            csv_name=csv_name,
            show_progress=show_progress,
            log_every=log_every,
            scheduler=scheduler,
            group="train",
        )

        logger.write("\nValidation:\n")
        val_loss_computer = _make_loss_computer(
            criterion, dataset["val_data"], adjustments)
        run_epoch(
            epoch,
            model,
            optimizer,
            dataset["val_loader"],
            val_loss_computer,
            logger,
            val_csv_logger,
            is_training=False,
            csv_name=csv_name,
            group="val",
        )

        # Test set; don't print to avoid peeking (logger=None)
        if dataset["test_data"] is not None:
            test_loss_computer = _make_loss_computer(
                criterion, dataset["test_data"], adjustments)
            run_epoch(
                epoch,
                model,
                optimizer,
                dataset["test_loader"],
                test_loss_computer,
                None,
                test_csv_logger,
                is_training=False,
                csv_name=csv_name,
                group="test",
            )

        # Inspect learning rates
        for param_group in optimizer.param_groups:
            logger.write("Current lr: %f\n" % param_group["lr"])

        # Only the plateau scheduler is stepped per epoch; the BERT schedule
        # is stepped per batch inside run_epoch.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            if loss_type == "group_dro":
                val_loss, _ = val_loss_computer.compute_robust_loss_greedy(
                    val_loss_computer.avg_group_loss,
                    val_loss_computer.avg_group_loss)
            else:
                val_loss = val_loss_computer.avg_actual_loss
            scheduler.step(val_loss)  # update lr at the end of the epoch

        if epoch % save_step == 0:
            torch.save(model, os.path.join(log_dir, "%d_model.pth" % epoch))

        if save_last:
            torch.save(model, os.path.join(log_dir, "last_model.pth"))

        if save_best:
            if loss_type == "group_dro" or reweight_groups:
                curr_val_acc = float(val_loss_computer.avg_group_acc.min())
            else:
                curr_val_acc = float(val_loss_computer.avg_acc)
            logger.write(f"Current validation accuracy: {curr_val_acc}\n")
            if curr_val_acc > best_val_acc:
                best_val_acc = curr_val_acc
                torch.save(model, os.path.join(log_dir, "best_model.pth"))
                logger.write(f"Best model saved at epoch {epoch}\n")

        if automatic_adjustment:
            gen_gap = val_loss_computer.avg_group_loss - train_loss_computer.exp_avg_loss
            new_adj = (gen_gap * torch.sqrt(
                train_loss_computer.group_counts)).detach()
            train_loss_computer.adj = new_adj
            # Keep ``adjustments`` a NumPy array: next epoch's val/test
            # LossComputers pass it to torch.from_numpy, which rejects tensors.
            adjustments = new_adj.cpu().numpy()
            logger.write("Adjustments updated\n")
            for group_idx in range(train_loss_computer.n_groups):
                logger.write(
                    f"  {train_loss_computer.group_str(group_idx)}:\t"
                    f"adj = {train_loss_computer.adj[group_idx]:.3f}\n")
        logger.write("\n")

# CHAPTER 1.2

In [None]:
# Text logger for the run, writing to <log_dir>/log.txt.
# NOTE(review): ``mode`` is forwarded unchanged to Logger -- presumably the file
# open mode ("w" / "a"); confirm against the Logger definition.
logger = Logger(os.path.join(log_dir, "log.txt"), mode)

In [None]:
# Apply the configured ``seed`` (set_seed is defined elsewhere in the notebook;
# presumably it seeds the Python/NumPy/torch RNGs -- confirm there).
set_seed(seed)

In [None]:
# Data
# Test data for label_shift_step is not implemented yet
test_data = None
test_loader = None
if shift_type == "confounder":
    train_data, val_data, test_data = prepare_data(
        train=True,
    )
else:
    # Only the confounder setting is wired up in this notebook. Fail here with
    # a clear message instead of a NameError on ``train_data`` in later cells.
    raise NotImplementedError(
        f"shift_type={shift_type!r} is not supported here; use 'confounder'."
    )

In [None]:
#########################################################################
###################### Prepare data for our method ######################
#########################################################################
# Not necessarily imported at the top of the notebook; harmless if it already is.
from torch.utils.data import ConcatDataset, Subset

# Should probably not be upweighting if folds are specified.
assert not fold or not up_weight

# Fold passed. Use it as train and valid.
if fold:
    train_data, val_data = get_fold(
        train_data,
        fold,
        cross_validation_ratio=(1 / num_folds_per_sweep),
        num_valid_per_point=num_sweeps,
        seed=seed,
    )

if up_weight != 0:
    assert aug_col is not None
    # Points to upsample: training-split rows (split == 0) with aug_col == 1.
    # np.where yields positions within the training split, so this presumably
    # relies on train_data following metadata order -- confirm in prepare_data.
    metadata_df = pd.read_csv(metadata_path)
    train_col = metadata_df[metadata_df["split"] == 0]
    aug_indices = np.where(train_col[aug_col] == 1)[0]
    print("len", len(train_col), len(aug_indices))
    if up_weight == -1:
        # Pick the factor that roughly balances flagged vs. unflagged points.
        if len(aug_indices) == 0:
            raise ValueError(
                f"No training rows have {aug_col} == 1; cannot derive an "
                "up-weight factor with up_weight=-1."
            )
        up_weight_factor = int(
            (len(train_col) - len(aug_indices)) / len(aug_indices)) - 1
    else:
        up_weight_factor = up_weight

    print(f"Up-weight factor: {up_weight_factor}")
    upsampled_points = Subset(train_data,
                              list(aug_indices) * up_weight_factor)
    # Convert to DRODataset
    train_data = DRODataset(
        ConcatDataset([train_data, upsampled_points]),
        process_item_fn=None,
        n_groups=train_data.n_groups,
        n_classes=train_data.n_classes,
        group_str_fn=train_data.group_str,
    )
elif aug_col is not None:
    print("\n"*2 + "WARNING: aug_col is not being used." + "\n"*2)

In [None]:
#########################################################################
#########################################################################
#########################################################################

# DataLoader options shared by every split.
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

# ``reweight_groups`` is only forwarded for the training split.
train_loader = get_loader(
    train_data, train=True, reweight_groups=reweight_groups, **loader_kwargs
)
val_loader = get_loader(
    val_data, train=False, reweight_groups=None, **loader_kwargs
)
if test_data is not None:
    test_loader = get_loader(
        test_data, train=False, reweight_groups=None, **loader_kwargs
    )

# Everything train() needs, in one dict.
data = {
    "train_loader": train_loader,
    "val_loader": val_loader,
    "test_loader": test_loader,
    "train_data": train_data,
    "val_data": val_data,
    "test_data": test_data,
}

n_classes = train_data.n_classes

log_data(data, logger)

In [None]:
## Initialize model
model = get_model(
    model=model,
    pretrained=not train_from_scratch,
    resume=resume,
    n_classes=train_data.n_classes,
    dataset=dataset,
    log_dir=log_dir,
)

logger.flush()

## Define the objective
if hinge:
    # Only supports binary (CUB). ``dataset`` is set as lowercase "cub" at the
    # top of the notebook, so compare case-insensitively (``dataset in
    # ["CUB"]`` could never pass).
    assert str(dataset).upper() == "CUB"
    criterion = hinge_loss
else:
    criterion = torch.nn.CrossEntropyLoss(reduction="none")

if resume:
    # Resuming is not implemented. A resume would read the last "epoch" value
    # from <log_dir>/test.csv and start at that value + 1.
    raise NotImplementedError("Resuming training is not implemented.")
epoch_offset = 0

In [None]:
train_csv_logger = CSVBatchLogger(os.path.join(log_dir, "train.csv"),
                                  train_data.n_groups,
                                  mode=mode)
val_csv_logger = CSVBatchLogger(os.path.join(log_dir, "val.csv"),
                                val_data.n_groups,
                                mode=mode)
# test_data can be None; train() then never logs test rows, but the logger is
# still created so it can be passed to train() and closed afterwards.
test_groups_source = test_data if test_data is not None else train_data
test_csv_logger = CSVBatchLogger(os.path.join(log_dir, "test.csv"),
                                 test_groups_source.n_groups,
                                 mode=mode)

# CHAPTER 1.3

In [None]:
try:
    train(
        model,
        criterion,
        data,
        logger,
        train_csv_logger,
        val_csv_logger,
        test_csv_logger,
        epoch_offset=epoch_offset,
        csv_name=fold,
    )
finally:
    # Release the CSV loggers even if training raises or is interrupted.
    train_csv_logger.close()
    val_csv_logger.close()
    test_csv_logger.close()

# CHAPTER 2

In [None]:
final_epoch = n_epochs - 1
folder_name = "ERM_upweight_0_epochs_3_lr_1e-05_weight_decay_1.0"

# run_epoch names its columns f"{csv_name}_epoch_{epoch}_val" and training was
# launched with csv_name=fold. Derive the prefix from ``fold``; this is
# identical to the former hard-coded "None_epoch_..." when fold is None.
output_run_name = f"{fold}_epoch_{final_epoch}_val"
index_col = f"indices_{output_run_name}"

train_df = pd.read_csv(os.path.join(log_dir, f"output_train_epoch_{final_epoch}.csv"))
train_df = train_df.sort_values(index_col)
# 1 for training points the final-epoch model misclassified, else 0.
train_df["wrong_1_times"] = (
    train_df[f"y_pred_{output_run_name}"] != train_df[f"y_true_{output_run_name}"]
).astype(np.int64)
print("Total wrong", np.sum(train_df['wrong_1_times']), "Total points", len(train_df))

original_df = pd.read_csv(metadata_path)
original_train_df = original_df[original_df["split"] == 0]
merged_csv = original_train_df.join(train_df.set_index(index_col))
# "spurious": the label y differs from the place attribute.
merged_csv["spurious"] = merged_csv['y'] != merged_csv["place"]

merged_csv["our_spurious"] = merged_csv["spurious"] & merged_csv["wrong_1_times"]
merged_csv["our_nonspurious"] = (merged_csv["spurious"] == 0) & merged_csv["wrong_1_times"]
print("Number of our spurious: ", np.sum(merged_csv["our_spurious"]))
print("Number of our nonspurious:", np.sum(merged_csv["our_nonspurious"]))

train_probs_df = merged_csv.fillna(0)

# Precision/recall of the error set w.r.t. the "spurious" points.
wrong = merged_csv["wrong_1_times"] == 1
spurious = merged_csv["spurious"] == 1
spur_precision = np.sum(wrong & spurious) / np.sum(wrong)
print("Spurious precision", spur_precision)
spur_recall = np.sum(wrong & spurious) / np.sum(spurious)
print("Spurious recall", spur_recall)

# The saved pred_prob_* columns are raw model outputs; softmax them here.
probs = softmax(np.array(train_probs_df[[f"pred_prob_{output_run_name}_0", f"pred_prob_{output_run_name}_1"]]), axis=1)
train_probs_df["probs_0"] = probs[:, 0]
train_probs_df["probs_1"] = probs[:, 1]
train_probs_df["confidence"] = train_probs_df["y"] * train_probs_df["probs_1"] + (1 - train_probs_df["y"]) * train_probs_df["probs_0"]
train_probs_df[f"confidence_thres{conf_threshold}"] = (train_probs_df["confidence"] < conf_threshold).astype(np.int64)

downstream_dir = f"results/{dataset}/{exp_name}/train_downstream_{folder_name}/final_epoch{final_epoch}"
os.makedirs(downstream_dir, exist_ok=True)
train_probs_df.to_csv(f"{downstream_dir}/metadata_aug.csv")
root = f"{exp_name}/train_downstream_{folder_name}/final_epoch{final_epoch}"

# Build the command as an argument list so no shell is involved.
sbatch_args = [
    "python", "generate_downstream.py",
    "--exp_name", root,
    "--lr", str(lr),
    "--weight_decay", str(wd),
    "--method", "JTT",
    "--dataset", str(dataset),
    "--aug_col", str(aug_col),
]
if batch_size:
    sbatch_args += ["--batch_size", str(batch_size)]
sbatch_command = " ".join(sbatch_args)
print(sbatch_command)
if deploy:
    subprocess.run(sbatch_args, check=True)

# CHAPTER 3.1

In [None]:
def sanitize_df(df):
    """
    Fix a results df for problems arising from resuming.

    Drops rows whose (epoch, batch) pair repeats (keeping the last
    occurrence) and renumbers the index 0..n-1. It then asserts that rows
    are ordered by epoch and, within an epoch, by strictly increasing batch.

    Raises
    ------
    ValueError
        If a row's epoch/batch cannot be read as integers.
    AssertionError
        If the remaining rows are out of order.
    """
    # Remove stray epoch/batches
    duplicates = df.duplicated(subset=["epoch", "batch"], keep="last")
    if np.sum(duplicates) > 0:
        # Report from the unfiltered frame: applying the original mask to the
        # filtered, renumbered frame selected the wrong rows.
        print(
            f"Removed {np.sum(duplicates)} duplicates from epochs {np.unique(df.loc[duplicates, 'epoch'])}"
        )
    df = df.loc[~duplicates, :]
    df.index = np.arange(len(df))

    # Make sure epoch/batch is increasing monotonically
    prev_epoch = -1
    prev_batch = -1
    last_batch_in_epoch = -1
    for i in range(len(df)):
        try:
            epoch, batch = df.loc[i, ["epoch", "batch"]].astype(int)
        except (KeyError, TypeError, ValueError) as err:
            raise ValueError(
                f"Could not read epoch/batch of row {i} (of {len(df)})"
            ) from err
        assert ((prev_epoch == epoch) and
                (prev_batch < batch)) or ((prev_epoch == epoch - 1))
        if prev_epoch == epoch - 1:
            # Every epoch boundary must close at the same last batch index.
            assert (last_batch_in_epoch == -1) or (last_batch_in_epoch
                                                   == prev_batch)
            last_batch_in_epoch = prev_batch
        prev_epoch = epoch
        prev_batch = batch

    return df

In [None]:
def process_df(train_df, val_df, test_df, n_groups):
    """Add worst-group ``robust_loss`` / ``robust_acc`` columns, in place.

    For each row, ``robust_loss`` is the max of ``avg_loss_group:{g}`` and
    ``robust_acc`` the min of ``avg_acc_group:{g}`` over ``g < n_groups``.
    ``None`` frames are skipped. A frame lacking the needed columns (or
    holding non-numeric ones) is left as far as it got, with no error.
    """
    loss_metrics = [f"avg_loss_group:{group_idx}" for group_idx in range(n_groups)]
    acc_metrics = [f"avg_acc_group:{group_idx}" for group_idx in range(n_groups)]
    for df in (train_df, val_df, test_df):
        if df is None:
            continue
        try:
            df["robust_loss"] = np.max(df.loc[:, loss_metrics], axis=1)
            df["robust_acc"] = np.min(df.loc[:, acc_metrics], axis=1)
        except (KeyError, TypeError):
            # Best-effort: frames without per-group metrics are skipped.
            pass

In [None]:
def process_df_waterbird9(train_df, val_df, test_df, params):
    """Add robust metrics plus train-distribution-weighted averages, in place.

    ``params`` must provide ``"n_groups"`` and ``"n_train"`` (per-group
    training counts, list or array). ``process_df`` adds the
    ``robust_loss``/``robust_acc`` columns. ``val_df`` and ``test_df`` also
    get ``avg_acc``/``avg_loss``: the per-group metrics weighted by each
    group's share of the training set.
    """
    n_groups = params["n_groups"]
    # process_df takes the group count, not the params dict (passing the dict
    # made its range(n_groups) raise TypeError).
    process_df(train_df, val_df, test_df, n_groups)
    loss_metrics = [f"avg_loss_group:{group_idx}" for group_idx in range(n_groups)]
    acc_metrics = [f"avg_acc_group:{group_idx}" for group_idx in range(n_groups)]

    # asarray so a plain list of counts works too.
    n_train = np.asarray(params["n_train"], dtype=float)
    ratio = n_train / np.sum(n_train)
    for df in (val_df, test_df):
        df["avg_acc"] = df.loc[:, acc_metrics] @ ratio
        df["avg_loss"] = df.loc[:, loss_metrics] @ ratio

In [None]:
def get_accs_for_epoch_across_batches(df, epoch):
    """Average and worst-group accuracy of one epoch, pooled over its logged rows.

    Each row holds per-group accuracies ``avg_acc_group:{g}`` measured on
    ``processed_data_count_group:{g}`` examples. Correct counts are recovered
    as round(acc * count) and summed over all rows of ``epoch``.

    Returns
    -------
    (avg_acc, robust_acc)
        Count-weighted mean accuracy and the minimum per-group accuracy. A
        group with no examples in that epoch gives NaN (0 / 0).
    """
    n_groups = 1 + max(
        int(col.split(":")[1])
        for col in df.columns if col.startswith("avg_acc_group")
    )

    # Boolean-mask selection works for any index; the previous loop passed
    # positional row numbers to the label-based ``df.loc``.
    rows = df.loc[df["epoch"] == epoch]

    total_counts = np.zeros(n_groups)
    correct_counts = np.zeros(n_groups)
    for group in range(n_groups):
        counts = rows[f"processed_data_count_group:{group}"].to_numpy(dtype=float)
        group_accs = rows[f"avg_acc_group:{group}"].to_numpy(dtype=float)
        total_counts[group] = counts.sum()
        correct_counts[group] = np.round(group_accs * counts).sum()

    accs = correct_counts / total_counts
    robust_acc = np.min(accs)
    avg_acc = accs @ total_counts / np.sum(total_counts)
    return avg_acc, robust_acc

In [None]:
def _epoch_metric_value(df, split, metric, epoch):
    """Return ``metric`` for ``split`` at ``epoch``, or None if it is unavailable."""
    if split == "train":
        # Train metrics are pooled over every row logged for this epoch. Only
        # the average and worst-group accuracies are available this way.
        avg_acc, robust_acc = get_accs_for_epoch_across_batches(df, epoch)
        return {"avg_acc": avg_acc, "robust_acc": robust_acc}.get(metric)
    # Val/test: take the last row logged for this epoch. The boolean mask keeps
    # the lookup independent of the dataframe's index labels.
    return df.loc[df["epoch"] == epoch, metric].iloc[-1]


def _stddev_sample_size(params, split, metric):
    """Return the sample size used for the binomial standard error of ``metric``.

    ``robust_acc`` uses the smallest group (conservative) and ``avg_acc`` uses
    the total. ``avg_acc_group:<i>`` uses entry ``i``, which assumes
    ``params["n_<split>"]`` is ordered by group id (TODO confirm).
    """
    counts = np.asarray(params[f"n_{split}"])
    if metric == "robust_acc":
        return np.min(counts)
    if metric == "avg_acc":
        return np.sum(counts)
    return counts[int(metric.split(":")[1])]


def print_accs(
    dfs,
    output_dir,
    params=None,
    epoch_to_eval=None,
    print_avg=False,
    output=True,
    splits=("train", "val", "test"),
    early_stop=True,
    print_groups=False,
):
    """Report accuracies at the early-stopping epoch and/or a chosen epoch.

    Args:
        dfs: dict mapping split name to its metrics dataframe. It must contain
            "val" (used to pick the early-stopping epoch) and every entry of
            ``splits``.
        output_dir: directory whose ``val_accuracies.txt`` is appended to.
            This happens only when ``params`` is None and ``output`` is True.
        params: optional dict of per-group counts ``n_<split>``. When given,
            a binomial standard error is computed and stored in the result.
        epoch_to_eval: extra epoch to report (optional).
        print_avg: also report ``avg_acc``.
        output: print (and possibly write) the numbers.
        splits: splits to report.
        early_stop: report the epoch whose validation ``robust_acc`` is highest.
        print_groups: also report every ``avg_acc_group:<i>`` column present
            in the validation dataframe.

    Returns:
        ``{metric: {split: {epoch_key: (acc, stddev)}}}``, where ``epoch_key`` is
        "early_stopping" or "epoch_to_eval". The inner entries are filled only
        when ``params`` is given.

    The worst-group standard error takes the minority-group ``n``, which is
    conservative. Clean val/test accuracy for waterbirds is estimated from a
    val/test set with a different distribution, so there is probably a bit more
    variability. This effect is minor because the overall n is high.
    """
    for split in splits:
        assert split in dfs

    val_df = dfs["val"]
    # np.argmax returns a row position; map it to the epoch logged on that row.
    early_stopping_epoch = val_df["epoch"].to_numpy()[
        np.argmax(val_df["robust_acc"].to_numpy())
    ]

    assert early_stop or (epoch_to_eval is not None)
    epochs = []
    if early_stop:
        epochs.append(
            ("early stop at epoch", "early_stopping", early_stopping_epoch))
    if epoch_to_eval is not None:
        epochs.append(("epoch", "epoch_to_eval", epoch_to_eval))

    metrics = [("Val Robust Worst Group", "robust_acc")]
    if print_avg:
        metrics.append(("Val Average Acc", "avg_acc"))
    if print_groups:
        # Discover the groups from the validation columns instead of relying
        # on a notebook-global group count.
        group_ids = sorted({
            int(col.split(":")[1])
            for col in val_df.columns
            if col.startswith("avg_acc_group:")
        })
        metrics.extend(
            (f"group {i} acc", f"avg_acc_group:{i}") for i in group_ids)

    results = {}
    for metric_str, metric in metrics:
        results[metric] = {}
        for split in splits:
            df = dfs[split]
            for epoch_print_str, epoch_save_str, epoch in epochs:
                label = f"{metric_str} {split:<5} acc ({epoch_print_str} {epoch})"
                if epoch not in df["epoch"].values:
                    if output:
                        print(f"{label}:               Not yet run")
                    continue

                acc = _epoch_metric_value(df, split, metric, epoch)
                if acc is None:
                    if output:
                        print(f"{label}: not available for this split")
                    continue

                results[metric].setdefault(split, {})

                if params is None:
                    if output:
                        line = f"{label}: {acc*100:.1f}"
                        print(line)
                        with open(os.path.join(output_dir, "val_accuracies.txt"),
                                  "a") as text_file:
                            print(line, file=text_file)
                else:
                    n = _stddev_sample_size(params, split, metric)
                    stddev = np.sqrt(acc * (1 - acc) / n)
                    results[metric][split][epoch_save_str] = (acc, stddev)
                    if output:
                        print(f"{label}: {acc*100:.1f} ({stddev*100:.1f})")
    return results

# CHAPTER 3.2

In [None]:
# Experiment root folder, and the downstream run sub-folders (relative to it)
# whose model_outputs/*.csv files should be summarised.
output_dir = "CUB/CUB_sample_exp"
runs = [
    "ERM_upweight_0_epochs_3_lr_1e-05_weight_decay_1.0",
]

In [None]:
# Print robust val/test accuracies for each downstream run.
# A run whose model_outputs CSVs are missing is skipped. Any other failure is
# reported and the loop continues with the next run.
for run in runs:
    sub_exp_name = run
    training_output_dir = os.path.join(output_dir, sub_exp_name, "model_outputs")

    try:
        train_df = pd.read_csv(os.path.join(training_output_dir, "train.csv"))
        val_df = pd.read_csv(os.path.join(training_output_dir, "val.csv"))
        test_df = pd.read_csv(os.path.join(training_output_dir, "test.csv"))
    except FileNotFoundError:
        print(f"skipping {run}: no model outputs in {training_output_dir}")
        continue

    try:
        # One plus the largest group index found in any "*_group:<i>" column.
        group_count = 1 + max(
            int(col.split(":")[1]) for col in val_df.columns if "_group:" in col
        )

        process_df(train_df, val_df, test_df, n_groups=group_count)

        dfs = {"train": train_df, "val": val_df, "test": test_df}

        print(f"Downstream Accuracies for {sub_exp_name} with {group_count} groups.")
        with open(os.path.join(training_output_dir, "val_accuracies.txt"), "a") as text_file:
            print(f"Downstream Accuracies for {sub_exp_name}", file=text_file)

        # Print average and worst-group accuracies for val and test.
        print_accs(
            dfs,
            training_output_dir,
            params=None,
            epoch_to_eval=None,
            print_avg=True,
            print_groups=True,
            output=True,
            splits=["val", "test"],
            early_stop=True,
        )
        print("\n")
    except Exception as exc:  # best-effort report; keep going with other runs
        print("\n")
        print(f"problem with {run}")
        print(f"{type(exc).__name__}: {exc}")