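# ReduceMix Training

Fine-tunes `xlnet-base-cased` on GLUE SST-2 with `SibylCollator` batch augmentation (`SIB`, `INVSIB`, `TextMix`, `SentMix`, `WordMix`, all with `reduce_mixed=True`), then evaluates every run on a held-out, untransformed slice of the SST-2 training split.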

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from transformers import (
    AutoModelForSequenceClassification, 
    AutoTokenizer, 
    Trainer, 
    TrainingArguments, 
    TrainerCallback, 
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers.trainer_callback import TrainerControl
from datasets import load_dataset
import torch
import pandas as pd
from torch.utils.data import DataLoader
from transforms import TextMix, SentMix, WordMix, SibylCollator

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def tokenize_fn(text):
    """Tokenize raw text into padded, truncated PyTorch tensors.

    Reads the global ``tokenizer``, which must be assigned before the first call.

    :param text: text (or batch of texts) forwarded unchanged to the tokenizer.
    :return: whatever the tokenizer returns for ``return_tensors='pt'``.
    """
    encode_options = dict(padding=True, truncation=True, max_length=250, return_tensors='pt')
    return tokenizer(text, **encode_options)

def tokenize(batch):
    """Tokenize the ``'text'`` column of a batch (padded and truncated to 250 tokens).

    Reads the global ``tokenizer``, which must be assigned before the first call.
    Unlike ``tokenize_fn`` this does not request PyTorch tensors.

    :param batch: mapping with a ``'text'`` entry.
    :return: the tokenizer's output for that text.
    """
    texts = batch['text']
    return tokenizer(texts, padding=True, truncation=True, max_length=250)

def acc_at_k(y_true, y_pred, k=2):
    """
    Target-weighted top-k agreement between label scores and predicted scores.

    For every row the top-k class indices of ``y_true`` and ``y_pred`` are compared
    position by position (rank 1 with rank 1, rank 2 with rank 2, ...). Each matching
    position contributes the corresponding ``y_true`` weight; the total is divided by
    the number of rows.

    :param y_true: (batch, classes) target scores; tensor or anything ``torch.tensor`` accepts.
    :param y_pred: (batch, classes) predicted scores; tensor or anything ``torch.tensor`` accepts.
    :param k: number of top-ranked classes to compare per row.
    :return: the score as a Python number.
    """
    # isinstance (rather than an exact type comparison) lets Tensor subclasses
    # pass through untouched instead of being re-wrapped and copied.
    if not isinstance(y_true, torch.Tensor):
        y_true = torch.tensor(y_true)
    if not isinstance(y_pred, torch.Tensor):
        y_pred = torch.tensor(y_pred)
    total = len(y_true)
    y_weights, y_idx = torch.topk(y_true, k=k, dim=-1)
    # Only the ranking of the predictions matters, not their magnitudes.
    _, out_idx = torch.topk(y_pred, k=k, dim=-1)
    correct = torch.sum(torch.eq(y_idx, out_idx) * y_weights)
    acc = correct / total
    return acc.item()

def CEwST_loss(logits, target, reduction='mean'):
    """
    Cross Entropy with Soft Target (CEwST) Loss
    :param logits: (batch, *) raw scores; flattened to (batch, -1) before the softmax.
    :param target: (batch, *) same shape as logits, each item must be a valid distribution: target[i, :].sum() == 1.
    :param reduction: 'none' returns the per-sample loss, 'mean' its average, 'sum' its total.
    :raises NotImplementedError: for any other reduction value.
    """
    flat_logits = logits.view(logits.shape[0], -1)
    log_probabilities = torch.nn.functional.log_softmax(flat_logits, dim=1)
    flat_target = target.view(target.shape[0], -1)
    per_sample = -(flat_target * log_probabilities).sum(dim=1)

    if reduction == 'mean':
        return per_sample.mean()
    if reduction == 'sum':
        return per_sample.sum()
    if reduction == 'none':
        return per_sample
    raise NotImplementedError('Unsupported reduction mode.')
        
def compute_metrics(pred):
    """Compute accuracy for an evaluation result that unpacks as ``(preds, labels)``.

    Labels with more than one dimension are scored with ``acc_at_k`` (k=2);
    one-dimensional labels are scored by comparing them with the argmax of the
    predictions over the last axis.

    :return: ``{'accuracy': value}``
    """
    preds, labels = pred
    has_multi_dim_labels = len(labels.shape) > 1
    if has_multi_dim_labels:
        score = acc_at_k(labels, preds, k=2)
    else:
        score = accuracy_score(labels, preds.argmax(-1))
    return { 'accuracy': score }

class TargetedTrainer(Trainer):
    """Trainer whose loss accepts either hard (class index) or soft (distribution) labels."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the training loss for one batch.

        :param model: the model being trained; called as ``model(**inputs)`` without labels.
        :param inputs: batch mapping; its ``"labels"`` entry is popped before the forward pass.
        :param return_outputs: when true, return ``(loss, outputs)`` instead of just the loss.
        :param num_items_in_batch: accepted and ignored. Newer Transformers releases pass
            this keyword to ``compute_loss``; without it the override raises TypeError there.
        :return: the loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # Labels were removed above, so the first output element is taken as the logits.
        logits = outputs[0]
        if len(labels.shape) > 1:
            # Labels with more than one dimension are treated as per-class target weights.
            loss = CEwST_loss(logits, labels)
        else:
            loss = torch.nn.functional.cross_entropy(logits, labels)
        if return_outputs:
            return loss, outputs
        return loss

class TargetedMixturesCallback(TrainerCallback):
    """
    A callback that calculates a confusion matrix on the validation
    data and returns the most confused class pairings.
    """
    def __init__(self, dataloader, device):
        """
        :param dataloader: yields batches with ``'text'`` and ``'label'`` entries; its
            ``dataset['label']`` column is used to infer the number of classes.
        :param device: device the tokenized inputs are moved to before the forward pass.
        """
        self.dataloader = dataloader
        self.device = device

    def on_evaluate(self, args, state, control, model, tokenizer, **kwargs):
        """
        Recompute the most-confused class pairs and attach them to ``control``.

        The pairs are stored as ``control.new_targets`` on the control INSTANCE that was
        passed in. Assigning to the ``TrainerControl`` class itself (as an earlier version
        did) leaks state into every control object and hands the class back in place of
        the instance.

        Training is flagged to stop once ``state.global_step`` reaches ``state.max_steps``.
        The flag is never reset to False here, so a stop requested elsewhere is kept.

        :return: the same ``control`` object that was passed in.
        """
        cnf_mat = self.get_confusion_matrix(model, tokenizer, self.dataloader)
        new_targets = self.get_most_confused_per_class(cnf_mat)
        print("New targets:", new_targets)
        control.new_targets = new_targets
        if state.global_step >= state.max_steps:
            control.should_training_stop = True
        return control

    def get_confusion_matrix(self, model, tokenizer, dataloader, normalize=True):
        """
        Build a (true class x predicted class) confusion matrix over ``dataloader``.

        :param model: called as ``model(input_ids, attention_mask=...)``; must expose ``.logits``.
        :param tokenizer: callable used to encode each batch's ``'text'`` entry.
        :param dataloader: source of batches and of the ``'label'`` column.
        :param normalize: when true, each column is divided by its total, i.e. the matrix
            is normalized per predicted class.
            NOTE(review): per-true-class (row) normalization may be what is wanted — confirm.
        :return: CPU tensor of shape (n_classes, n_classes); rows are true labels.
        """
        n_classes = max(dataloader.dataset['label']) + 1
        confusion_matrix = torch.zeros(n_classes, n_classes)
        with torch.no_grad():
            # Iterate the dataloader that was passed in, which is also the one
            # n_classes was derived from.
            for batch in dataloader:
                data, targets = batch['text'], batch['label']
                data = tokenizer(data, padding=True, truncation=True, max_length=250, return_tensors='pt')
                input_ids = data['input_ids'].to(self.device)
                attention_mask = data['attention_mask'].to(self.device)
                outputs = model(input_ids, attention_mask=attention_mask).logits
                preds = torch.argmax(outputs, dim=1).cpu()
                # The matrix lives on the CPU, so the targets stay there as well.
                targets = targets.cpu()
                for t, p in zip(targets.view(-1), preds.view(-1)):
                    confusion_matrix[t.long(), p.long()] += 1
        if normalize:
            column_totals = confusion_matrix.sum(dim=0)
            # A class that is never predicted has a zero column total; 0/0 would
            # produce a NaN column. Entries are counts, so clamping the divisor to 1
            # leaves such a column at zero.
            confusion_matrix = confusion_matrix / column_totals.clamp(min=1)
        return confusion_matrix

    def get_most_confused_per_class(self, confusion_matrix):
        """
        For each true class (row), find the other class with the largest matrix entry.

        :param confusion_matrix: square 2-D tensor; it is not modified.
        :return: list of ``[class_index, most_confused_class_index]`` pairs.
        """
        # Work on a copy so zeroing the diagonal does not alter the caller's matrix.
        off_diagonal = confusion_matrix.clone().fill_diagonal_(0)
        idx = torch.arange(len(off_diagonal))
        cnf = off_diagonal.max(dim=1)[1]
        return torch.stack((idx, cnf)).T.tolist()

class TargetedMixturesCollator:
    """
    Collator that optionally applies a batch-level transform to (text, label) pairs,
    then tokenizes the text and attaches the labels.
    """
    def __init__(self, 
                 tokenize_fn, 
                 transform, 
                 transform_prob=1.0, 
                 target_pairs=None, 
                 target_prob=1.0, 
                 num_classes=2):
        """
        :param tokenize_fn: callable mapping a list of texts to a mutable mapping of model inputs.
        :param transform: callable invoked as
            ``transform((text, labels), target_pairs, target_prob, num_classes)``
            and returning a new ``(text, labels)`` pair.
        :param transform_prob: probability of applying ``transform`` to a given batch.
        :param target_pairs: forwarded to ``transform``. ``None`` (the default) means a
            fresh empty list per instance; a list that is passed in is stored as-is, not copied.
        :param target_prob: forwarded to ``transform``.
        :param num_classes: forwarded to ``transform`` and used as the one-hot width.
        """
        self.tokenize_fn = tokenize_fn
        self.transform = transform
        self.transform_prob = transform_prob
        # A literal [] default would be one list shared by every instance, so
        # mutating it on one collator would silently change all the others.
        self.target_pairs = [] if target_pairs is None else target_pairs
        self.target_prob = target_prob
        self.num_classes = num_classes
        print("TargetedMixturesCollator initialized with {}".format(transform.__class__.__name__))
        
    def __call__(self, batch):
        """
        Collate a list of examples, each with ``'text'`` and ``'label'`` entries.

        :return: the output of ``tokenize_fn`` with a ``'labels'`` tensor added and any
            ``'idx'`` / ``'label'`` entries removed.
        """
        text = [x['text'] for x in batch]
        labels = [x['label'] for x in batch]
        batch = (text, labels)
        if torch.rand(1) < self.transform_prob:
            batch = self.transform(
                batch, 
                self.target_pairs,   
                self.target_prob,
                self.num_classes
            )
        text, labels = batch
        labels = torch.tensor(labels)
        if len(labels.shape) == 1:
            # One-dimensional labels are class indices; expand them to one-hot rows
            # so the output has the same rank as transformed (2-D) labels.
            labels = torch.nn.functional.one_hot(labels, num_classes=self.num_classes)
        batch = self.tokenize_fn(text)
        batch['labels'] = labels
        batch.pop('idx', None)
        batch.pop('label', None)
        return batch
    
class DefaultCollator:
    """Stateless collator that hands the batch to PyTorch's ``default_collate``."""

    def __call__(self, batch):
        """Return ``torch.utils.data.dataloader.default_collate(batch)``."""
        collate = torch.utils.data.dataloader.default_collate
        return collate(batch)

In [4]:
# Pretrained checkpoints to fine-tune; the training cell loops over every entry.
MODEL_NAMES = [
    'xlnet-base-cased',
]
# Augmentation settings to run for each model. Full list, including the two
# settings that are currently switched off:
# ts = ['ORIG', 'INV', 'SIB', 'INVSIB', 'TextMix', 'SentMix', 'WordMix']
ts = [
    'SIB',
    'INVSIB',
    'TextMix',
    'SentMix',
    'WordMix',
]

In [5]:
# One evaluation-result dict per (model, transform) run; turned into a DataFrame below.
results = []

# Optimisation settings shared by every run.
train_batch_size = 6
eval_batch_size  = 32
num_epoch = 20
gradient_accumulation_steps = 1

for MODEL_NAME in MODEL_NAMES:

    # Everything in this section depends only on the model, so it is done once per
    # model rather than once per transform. `tokenizer` has to remain a global:
    # tokenize_fn() and tokenize() look it up by name.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # First 90% of the SST-2 training split is used for training/validation.
    # rename_column returns a new dataset (the in-place rename_column_ variant is
    # deprecated), matching how the test split is handled below.
    dataset = load_dataset('glue', 'sst2', split='train[:90%]')
    dataset = dataset.rename_column('sentence', 'text')

    # Last 10% of the SST-2 training split is held out as the untransformed test set.
    test_dataset = load_dataset('glue', 'sst2', split='train[90%:]')
    test_dataset = test_dataset.rename_column('sentence', 'text')
    test_dataset = test_dataset.rename_column('label', 'labels')
    test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset))
    test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

    for t in ts:

        # Translate the transform name into SibylCollator settings.
        transform = None
        num_sampled_INV = 0
        num_sampled_SIB = 0
        label_type = "soft"

        if t == "INV":
            num_sampled_INV = 2
            label_type = "hard"
        elif t == "SIB":
            num_sampled_SIB = 2
        elif t == 'INVSIB':
            num_sampled_INV = 1
            num_sampled_SIB = 1
            label_type = None
        elif t == "TextMix":
            transform = TextMix()
        elif t == "SentMix":
            transform = SentMix()
        elif t == "WordMix":
            transform = WordMix()

        checkpoint = './results/' + MODEL_NAME + '-SST2-SibylCollator-' + t

        # NOTE(review): 3 output labels although SST-2 is binary and the collator is
        # given num_classes=2 — presumably the extra class comes from
        # reduce_mixed=True; confirm against SibylCollator.
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3).to(device)

        # A fresh random train/validation split is drawn for every run.
        dataset_dict = dataset.train_test_split(
            test_size = 0.05,
            train_size = 0.95,
            shuffle = True
        )
        train_dataset = dataset_dict['train']
        eval_dataset = dataset_dict['test']

        max_steps = int((len(train_dataset) * num_epoch / gradient_accumulation_steps) / train_batch_size)

#         tmcb = TargetedMixturesCallback(
#             dataloader=DataLoader(eval_dataset, batch_size=32),
#             device=device
#         )
        escb = EarlyStoppingCallback(
            early_stopping_patience=10
        )
#         tmc = TargetedMixturesCollator(
#             tokenize_fn=tokenize_fn, 
#             transform=t, 
#             transform_prob=0.5,
#             target_pairs=[],
#             target_prob=0.5,
#             num_classes=4
#         )
        sibyl_collator = SibylCollator( 
            tokenize_fn=tokenize_fn, 
            transform=transform, 
            num_sampled_INV=num_sampled_INV, 
            num_sampled_SIB=num_sampled_SIB, 
            task_type="topic", 
            tran_type=None, 
            label_type=label_type,
            one_hot=label_type == "soft",
            transform_prob=0.5,
            target_pairs=[],
            target_prob=0.0,
            reduce_mixed=True,
            num_classes=2
        )

        training_args = TrainingArguments(
            output_dir=checkpoint,
            overwrite_output_dir=True,
            max_steps=max_steps,
            save_steps=int(max_steps / 10),
            save_total_limit=1,
            per_device_train_batch_size=train_batch_size,
            per_device_eval_batch_size=eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps, 
            warmup_steps=int(max_steps / 10),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=2000,
            logging_first_step=True,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            evaluation_strategy="steps",
            # The collator reads the raw 'text' / 'label' columns, so they must be kept.
            remove_unused_columns=False
        )

        trainer = TargetedTrainer(
            model=model, 
            tokenizer=tokenizer,
            args=training_args,
            compute_metrics=compute_metrics,                  
            train_dataset=train_dataset,         
            eval_dataset=eval_dataset,
            data_collator=sibyl_collator if t != "ORIG" else DefaultCollator(),
            callbacks=[escb] # [tmcb, escb]
        )

        trainer.train()

        # test with ORIG data
        trainer.eval_dataset = test_dataset
        trainer.data_collator = DefaultCollator()
        # trainer.remove_callback(tmcb)

        out_orig = trainer.evaluate()
        out_orig['run'] = checkpoint
        out_orig['test'] = "ORIG"
        print('ORIG for {}\n{}'.format(checkpoint, out_orig))

        results.append(out_orig)

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


SibylCollator initialized with num_sampled_INV=0 and num_sampled_SIB=2


W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
2000,0.7796,0.470491,0.835038,14.6584,206.775
4000,0.5165,0.397779,0.856813,14.6498,206.896
6000,0.5049,0.34593,0.881557,14.4384,209.927
8000,0.4834,0.474579,0.879248,13.7951,219.716
10000,0.4815,0.531787,0.877598,15.4676,195.958
12000,0.4566,0.513983,0.889475,13.452,225.32
14000,0.4699,0.54727,0.855823,13.832,219.13
16000,0.4858,0.433912,0.894754,13.5119,224.321
18000,0.4964,0.424506,0.87133,13.5418,223.826
20000,0.4876,0.429425,0.864071,13.7143,221.01


early_stopping_patience_counter


ORIG for ./results/xlnet-base-cased-SST2-SibylCollator-SIB
{'eval_loss': 0.30294692516326904, 'eval_accuracy': 0.9257609502598366, 'eval_runtime': 26.3794, 'eval_samples_per_second': 255.313, 'epoch': 3.75, 'run': './results/xlnet-base-cased-SST2-SibylCollator-SIB', 'test': 'ORIG'}


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


SibylCollator initialized with num_sampled_INV=1 and num_sampled_SIB=1


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
2000,0.8075,0.589471,0.774002,22.0939,137.187
4000,0.5299,0.503504,0.838997,21.533,140.761
6000,0.51,0.465756,0.848565,22.0925,137.196
8000,0.4747,0.479523,0.838667,20.7435,146.118
10000,0.4698,0.412564,0.846585,22.6638,133.737
12000,0.4653,0.469427,0.861102,18.8576,160.731
14000,0.4741,0.501554,0.856153,21.321,142.16
16000,0.4985,0.491795,0.871,20.6323,146.906
18000,0.505,0.461174,0.854503,19.7367,153.572
20000,0.497,0.507519,0.864401,17.5985,172.23


early_stopping_patience_counter


ORIG for ./results/xlnet-base-cased-SST2-SibylCollator-INVSIB
{'eval_loss': 0.2668299973011017, 'eval_accuracy': 0.9253155159613957, 'eval_runtime': 26.3495, 'eval_samples_per_second': 255.603, 'epoch': 3.75, 'run': './results/xlnet-base-cased-SST2-SibylCollator-INVSIB', 'test': 'ORIG'}


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


SibylCollator initialized with TextMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
2000,0.8363,0.517927,0.771363,10.9078,277.874
4000,0.5526,0.520226,0.835698,10.6585,284.374
6000,0.527,0.538134,0.843616,10.452,289.991
8000,0.502,0.450566,0.860442,10.458,289.825
10000,0.4994,0.455773,0.866711,10.7782,281.217
12000,0.4854,0.553503,0.872979,10.1,300.1
14000,0.5165,0.496163,0.859782,10.9766,276.132
16000,0.5231,0.470007,0.872979,10.2885,294.601
18000,0.5134,0.564205,0.845596,10.3922,291.661
20000,0.5126,0.639038,0.844936,10.3623,292.503


early_stopping_patience_counter


ORIG for ./results/xlnet-base-cased-SST2-SibylCollator-TextMix
{'eval_loss': 0.267045259475708, 'eval_accuracy': 0.9328878990348923, 'eval_runtime': 26.3675, 'eval_samples_per_second': 255.428, 'epoch': 5.0, 'run': './results/xlnet-base-cased-SST2-SibylCollator-TextMix', 'test': 'ORIG'}


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


SibylCollator initialized with SentMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
2000,0.8314,0.553951,0.751567,10.5119,288.341
4000,0.5571,0.534504,0.832399,10.7303,282.471
6000,0.519,0.577473,0.82811,10.471,289.467
8000,0.508,0.503339,0.848235,10.2067,296.962
10000,0.4918,0.75956,0.852854,10.9826,275.982
12000,0.495,0.651991,0.843946,10.1219,299.449
14000,0.4949,0.611563,0.851534,10.481,289.191
16000,0.4945,0.531208,0.861762,10.0002,303.093
18000,0.5183,0.650011,0.841636,10.0092,302.821
20000,0.5334,0.710924,0.841636,9.9863,303.516


early_stopping_patience_counter


ORIG for ./results/xlnet-base-cased-SST2-SibylCollator-SentMix
{'eval_loss': 0.29202547669410706, 'eval_accuracy': 0.932293986636971, 'eval_runtime': 26.2897, 'eval_samples_per_second': 256.184, 'epoch': 5.42, 'run': './results/xlnet-base-cased-SST2-SibylCollator-SentMix', 'test': 'ORIG'}


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


SibylCollator initialized with WordMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
2000,0.8792,0.59109,0.735071,10.5288,287.876
4000,0.6145,0.569824,0.794457,10.5258,287.958
6000,0.6049,0.540938,0.792148,10.4281,290.657
8000,0.5703,0.538365,0.808644,10.4141,291.047
10000,0.5477,0.585113,0.779611,11.3168,267.833
12000,0.5444,0.498109,0.78291,10.1498,298.625
14000,0.5507,0.57749,0.816232,10.4849,289.081
16000,0.5864,0.63051,0.800396,10.2566,295.518
18000,0.5686,0.595756,0.819861,10.2306,296.267
20000,0.5704,0.535857,0.830419,9.9025,306.084


early_stopping_patience_counter


ORIG for ./results/xlnet-base-cased-SST2-SibylCollator-WordMix
{'eval_loss': 0.33149608969688416, 'eval_accuracy': 0.9058648849294729, 'eval_runtime': 26.3794, 'eval_samples_per_second': 255.313, 'epoch': 4.17, 'run': './results/xlnet-base-cased-SST2-SibylCollator-WordMix', 'test': 'ORIG'}


In [6]:
# One row per training run, built from the dicts appended to `results` in the training cell.
df = pd.DataFrame(results)

In [7]:
# Persist the results table to a CSV file in the current working directory.
df.to_csv('train_SST2_SibylCollator_XLNET.csv')

In [8]:
# Copy the table to the system clipboard in a spreadsheet-pasteable layout.
# NOTE(review): presumably needs a clipboard backend on the kernel host and may
# fail on a headless/remote kernel — confirm before relying on Run All.
df.to_clipboard(excel=True)

In [9]:
# Display the results table as the cell output.
df

Unnamed: 0,eval_loss,eval_accuracy,eval_runtime,eval_samples_per_second,epoch,run,test
0,0.302947,0.925761,26.3794,255.313,3.75,./results/xlnet-base-cased-SST2-SibylCollator-SIB,ORIG
1,0.26683,0.925316,26.3495,255.603,3.75,./results/xlnet-base-cased-SST2-SibylCollator-...,ORIG
2,0.267045,0.932888,26.3675,255.428,5.0,./results/xlnet-base-cased-SST2-SibylCollator-...,ORIG
3,0.292025,0.932294,26.2897,256.184,5.42,./results/xlnet-base-cased-SST2-SibylCollator-...,ORIG
4,0.331496,0.905865,26.3794,255.313,4.17,./results/xlnet-base-cased-SST2-SibylCollator-...,ORIG
