# Setup

In [10]:
import sys
from google.colab import drive
# Project locations on Google Drive. The root is also appended to sys.path
# below (presumably so the `reprshift` package stored there can be imported).
path_to_root = '/content/drive/My Drive/Colab Notebooks/BatuEl_Dissertation'
data_path = f'{path_to_root}/data'

# force_remount=True re-mounts the drive even if a mount already exists.
drive.mount('/content/drive', force_remount=True)
sys.path.append(path_to_root)
print("Drive mounted.")

Mounted at /content/drive
Drive mounted.


In [11]:
import os

import torch
import tqdm

from reprshift.learning.algorithms import ERM, GroupDRO, Focal, Algorithm
from reprshift.models.hparams import hparams_f
from reprshift.dataset.datasets import MultiNLI, CivilComments
from reprshift.dataset.dataloaders import InfiniteDataLoader, FastDataLoader

# JTT Implementation

In [12]:
## Based on https://github.com/YyzHarry/SubpopBench

### Algorithm 6: JTT ###
class AbstractTwoStage(Algorithm):
    """Base class for algorithms that train two models one after the other.

    A plain ``ERM`` model is trained while ``step < switch_step``; from
    ``switch_step`` onwards every update goes to ``stage2_model`` instead.
    ``stage2_model`` is ``None`` here and has to be assigned by subclasses.
    ``cur_model`` always points at the model that received the latest update
    and is the one used by ``predict`` / ``return_feats``.
    """

    def __init__(self, num_classes, num_attributes, hparams):
        super().__init__(num_classes, num_attributes, hparams)
        total_steps = hparams['num_training_steps']

        # NOTE: the order of the model assignments below is kept on purpose;
        # do not shuffle it when editing.
        self.stage1_model = ERM(num_classes, num_attributes, hparams)
        self.first_stage_step_frac = hparams['first_stage_step_frac']
        # First step (inclusive) at which the second stage takes over.
        self.switch_step = int(self.first_stage_step_frac * total_steps)
        self.cur_model = self.stage1_model
        self.stage2_model = None    # implement in child classes

    def update(self, minibatch, step):
        """Run one optimisation step on the model that owns ``step``.

        Args:
            minibatch: tuple ``(i, x, y, a)`` that is forwarded unchanged to
                the active model's ``_compute_loss``.
            step: current training step; compared against ``switch_step`` to
                choose the stage.

        Returns:
            dict with a single key ``'loss'`` holding the loss as a Python
            number.
        """
        idx, inputs, labels, attrs = minibatch

        if step >= self.switch_step:
            # Second stage: stage-1 model is put in eval mode and handed to
            # the stage-2 loss as a reference model.
            self.cur_model = self.stage2_model
            self.cur_model.train()
            self.stage1_model.eval()
            loss = self.stage2_model._compute_loss(
                idx, inputs, labels, attrs, step, self.stage1_model)
        else:
            # First stage: ordinary update of the stage-1 model.
            self.cur_model = self.stage1_model
            self.cur_model.train()
            loss = self.stage1_model._compute_loss(idx, inputs, labels, attrs, step)

        active = self.cur_model
        active.optimizer.zero_grad()
        loss.backward()
        if active.clip_grad:
            torch.nn.utils.clip_grad_norm_(active.network.parameters(), 1.0)
        active.optimizer.step()

        scheduler = active.lr_scheduler
        if scheduler is not None:
            scheduler.step()

        active.network.zero_grad()
        return {'loss': loss.item()}

    def return_feats(self, x):
        """Return the features the currently active model computes for ``x``."""
        active = self.cur_model
        return active.featurizer(x)

    def predict(self, x):
        """Return the output of the currently active model's network for ``x``."""
        active = self.cur_model
        return active.network(x)

class JTT_Stage2(ERM):
    """Second-stage JTT model: ERM with a loss re-weighted by stage-1 errors.

    Samples that the stage-1 model gets wrong are weighted by
    ``hparams["jtt_lambda"]``; every other sample keeps weight 1.
    """

    def __init__(self, num_classes, num_attributes, hparams):
        super().__init__(num_classes, num_attributes, hparams)

    def _compute_loss(self, i, x, y, a, step, stage1_model):
        """Return the mean of the error-weighted loss for one minibatch.

        Args:
            i, a, step: unused here; kept for signature compatibility with
                the caller.
            x: model inputs; also determines the device of the weight tensor.
            y: target labels.
            stage1_model: model whose ``predict`` output decides which samples
                are up-weighted. It is evaluated without gradient tracking.

        Returns:
            Scalar tensor: ``(self.loss(self.predict(x), y) * weights).mean()``.
            Assumes ``self.loss`` yields one value per sample -- TODO confirm.
        """
        with torch.no_grad():
            stage1_logits = stage1_model.predict(x)

        # Pick the binary / multi-class branch from the class dimension.
        # The previous test, `squeeze().ndim == 1`, also fired for a batch
        # holding a single multi-class sample (shape (1, C)), and for a real
        # (B, 1) binary head the comparison with (B,) labels broadcast to a
        # (B, B) mask.
        single_logit_head = stage1_logits.ndim == 1 or stage1_logits.shape[-1] == 1
        if single_logit_head:
            # Threshold the logit at 0 and align with the label layout so the
            # comparison is element-wise.
            stage1_labels = (stage1_logits > 0).reshape(y.shape)
        else:
            stage1_labels = stage1_logits.argmax(1)
        wrong_predictions = stage1_labels.ne(y).float()

        weights = torch.ones(wrong_predictions.shape, dtype=torch.float, device=x.device)
        weights[wrong_predictions == 1] = self.hparams["jtt_lambda"]

        return (self.loss(self.predict(x), y) * weights).mean()

class JTT(AbstractTwoStage):
    """
    Just-train-twice (JTT) [https://arxiv.org/pdf/2107.09044.pdf]

    Two-stage algorithm whose second-stage model is a ``JTT_Stage2``.
    """
    def __init__(self, num_classes, num_attributes, hparams):
        super().__init__(
            num_classes=num_classes,
            num_attributes=num_attributes,
            hparams=hparams,
        )
        # Replaces the ``None`` placeholder left by the parent constructor.
        self.stage2_model = JTT_Stage2(
            num_classes=num_classes,
            num_attributes=num_attributes,
            hparams=hparams,
        )

# Data

In [13]:
# Fetch the hyperparameters that hparams_f returns for 'JTT' and display them.
hparams = hparams_f('JTT')
hparams

{'batch_size': 32,
 'last_layer_dropout': 0.5,
 'optimizer': 'adamw',
 'weight_decay': 0.0001,
 'lr': 1e-05,
 'group_balanced': False,
 'num_training_steps': 30001,
 'num_warmup_steps': 0,
 'first_stage_step_frac': 0.5,
 'jtt_lambda': 10}

In [14]:
# Run configuration shared by the cells below.
device = "cuda"  # device string used for `.to(device)` in the training cell
train_weights = None  # forwarded as `weights=` to InfiniteDataLoader
batch_size = hparams['batch_size']

In [18]:
DATASET = 'MultiNLI'  # 'CivilComments' , 'MultiNLI'

# Select the training split, the label/attribute counts and the model folder
# for the chosen dataset. An unknown name is a hard error: merely printing a
# message would leave NUM_CLASSES / train_dataset / models_path undefined (or
# stale from an earlier run of this cell) and let later cells fail -- or
# silently train on the wrong data.
if DATASET == 'MultiNLI':
    NUM_CLASSES = 3
    NUM_ATTRIBUTES = 2
    train_dataset = MultiNLI(data_path, 'tr', hparams)
    models_path = path_to_root + '/models/models_mnli'
elif DATASET == 'CivilComments':
    NUM_CLASSES = 2
    NUM_ATTRIBUTES = 8
    train_dataset = CivilComments(data_path, 'tr', hparams, granularity="fine")
    models_path = path_to_root + '/models/models_civilcomments'
else:
    raise ValueError(
        f"Dataset Not Implemented: {DATASET!r} "
        "(expected 'MultiNLI' or 'CivilComments')"
    )
print(DATASET)

MultiNLI


In [20]:
# Build a loader over the training split. The training cell further down
# creates a fresh InfiniteDataLoader with the same arguments for every seed.
train_loader  = InfiniteDataLoader(  dataset=train_dataset,
                                    weights=train_weights,
                                    batch_size=batch_size,
                                    num_workers=1)
# True division: this is a float, not an integer step count.
steps_per_epoch = len(train_dataset) / batch_size

# Model

In [21]:
algorithm_name = 'JTT'
random_seeds = [1,2] #[0,1,2]
# Folder under which checkpoints and loss curves of this algorithm are written.
state_dict_PATH = models_path + '/06_jtt/'


def init_state_dict_path(random_seed):
    """Return the path of the saved initial state dict for ``random_seed``."""
    return models_path + f'/00_randominit/seed{random_seed}/sd_epoch0.pth'

# Training

In [22]:
start_step = 1
n_steps = 30001 #hparams['num_training_steps']
checkpoint_freq = 1000
train_losses = {}

# torch.save does not create missing parent folders; create the output folder
# for the loss curves up front so a run cannot fail at the very end.
losses_dir = state_dict_PATH + 'losses/'
os.makedirs(losses_dir, exist_ok=True)

for seed in random_seeds:
    print('Training Seed:' , seed)
    # Seed PyTorch's global RNG so a run is repeatable for its seed. Before,
    # `seed` only selected the initial-weights file and the output paths.
    torch.manual_seed(seed)

    checkpoint_dir = state_dict_PATH + f'seed{seed}/'
    os.makedirs(checkpoint_dir, exist_ok=True)

    algorithm = JTT(num_classes=NUM_CLASSES, num_attributes=NUM_ATTRIBUTES, hparams=hparams)
    algorithm.to(device)
    ### Matching the Keys ###
    # The saved initial weights belong to a single model; prefix their keys so
    # they overwrite the `stage1_model.*` entries of the full state dict.
    sd_init = torch.load(init_state_dict_path(seed))
    sd_init_matched = {f'stage1_model.{key}' :sd_init[key]  for key in sd_init.keys()}
    sd_algorithm = algorithm.state_dict()
    for key in sd_init_matched:
        sd_algorithm[key] = sd_init_matched[key]
    algorithm.load_state_dict(sd_algorithm)
    #########################
    train_losses[seed] = []

    train_loader = InfiniteDataLoader(  dataset=train_dataset,
                                        weights=train_weights,
                                        batch_size=batch_size,
                                        num_workers=1)
    train_minibatches_iterator = iter(train_loader)

    for step in tqdm.tqdm(range(start_step, n_steps)):
        ### Training Step ###
        i, x, y, a = next(train_minibatches_iterator)
        minibatch_device = (i, x.to(device), y.to(device), a.to(device))
        algorithm.train()
        step_vals = algorithm.update(minibatch_device, step)
        train_losses[seed].append(step_vals['loss'])

        ### Evaluation ###
        # Checkpoint every `checkpoint_freq` steps and after the last step.
        # NOTE: if n_steps - 1 is not a multiple of checkpoint_freq, the final
        # save gets the same index as the previous checkpoint and replaces it.
        if (step % checkpoint_freq == 0) or (step == n_steps - 1):
            epoch = int(step / checkpoint_freq)
            algorithm_state_dict = algorithm.state_dict()
            algorithm_state_dict_PATH = state_dict_PATH + f'seed{seed}/sd_epoch{epoch}.pth'
            torch.save(algorithm_state_dict, algorithm_state_dict_PATH)

    # The whole dict is written, so each file holds the curves of every seed
    # trained so far in this loop, keyed by seed.
    loss_PATH = state_dict_PATH + f'losses/Loss_{algorithm_name}_{seed}.pth'
    torch.save(train_losses, loss_PATH)

Training Seed: 1


100%|██████████| 30000/30000 [1:36:49<00:00,  5.16it/s]


Training Seed: 2


100%|██████████| 30000/30000 [1:36:47<00:00,  5.17it/s]
