In [1]:
import torch
import numpy as np

class Sampler(object):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

class WeightedRandomSampler(Sampler):
    
    def __init__(self, labels, stratify = None, weights = None):
        self.labels = labels
        self.label_counts = labels.value_counts().to_dict()
        self.num_labels = len(self.label_counts.keys())
        
        if weights is None:
            self.weights = {key: 1 for key, val in self.label_counts.items()}
        else:
            self.weights = weights
        
        if stratify == 'increase':
            self.samples_per_label = {key: round(max(self.label_counts.values()) * weight) for key, weight in self.weights.items()}
        elif stratify == 'decrease':
            self.samples_per_label = {key: round(min(self.label_counts.values()) * weight) for key, weight in self.weights.items()}
        else:
            self.samples_per_label = {key: round(self.label_counts[key] * self.weights[key]) for key in self.weights} 
        
        self.num_samples = sum(self.samples_per_label.values())
        
    def __iter__(self):
        indices = []
        for lbl, amount in self.samples_per_label.items():
            label_data = self.labels[self.labels == lbl]
            label_counts = len(label_data)
            for i in range(amount // label_counts):
                indices += label_data.index[torch.randperm(label_counts)].tolist()
            indices += label_data.index[torch.randperm(amount % label_counts)].tolist()
        return iter(np.array(indices)[torch.randperm(self.num_samples)].tolist())
        
    def __len__(self):
        return self.num_samples

In [2]:
import torch
import pandas as pd
from ast import literal_eval
from tqdm import tqdm

class CLS_Dataset(torch.utils.data.Dataset):
    
    def __init__(self, data_path, split = [0,1]):
        
        encs = ['attention_mask', 'input_ids', 'token_type_ids']
        dataset = pd.read_csv(f'{data_path}', header = 0, index_col = 0).reset_index()
        print(pd.unique(dataset['labels']))
        dataset = dataset[round(len(dataset)*split[0]): round(len(dataset)*split[1])]
        dataset['labels'] = dataset['labels'].apply(lambda lbl: int(lbl))
        tmp = dataset[dataset['labels'] == 1].groupby(['groups']).first().reset_index()
        dataset = dataset[dataset['labels'] == 0].append(tmp).reset_index()
        
        dataset['labels'] += addit
        
        tqdm.pandas()
        
        self.encodings = dataset[encs].progress_applymap(literal_eval)
        self.encodings['labels'] = dataset['labels']
        self.others = dataset.drop(encs + ['labels'], axis=1)

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        return self.encodings.loc[idx].to_dict('list'), self.others.loc[idx]

In [3]:
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification
from transformers import BertTokenizer
from transformers import DataCollatorWithPadding
from transformers import AdamW
from transformers import get_scheduler
import torch
from tqdm.auto import tqdm

class Trainer():
    
    def __init__(self,
                 data_path,
                 save_path,
                 split = [0,1],
                 model_path = None,
                 num_labels = 2,
                 stratify = None,
                 sample_weights = None,
                 batch_size = 16,
                 learning_rate = 5e-5,
                 num_epochs = 3,
                 warmup_percent = 1,
                ):
        
        self.dataset = CLS_Dataset(data_path, split)
        self.save_path = save_path
        
        if model_path is None:
            self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels = num_labels)
        else:
            self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.wrs = WeightedRandomSampler(self.dataset.encodings['labels'], stratify = stratify, weights = sample_weights)
        sampler = torch.utils.data.sampler.BatchSampler(self.wrs, batch_size=batch_size,drop_last=False)
        
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', use_fast = True)
        
        dataCollator = DataCollatorWithPadding(tokenizer = self.tokenizer)
        self.train_loader = DataLoader(dataset = self.dataset, 
                                       batch_size = None, 
                                       collate_fn = lambda x: dataCollator(x[0]), 
                                       sampler = sampler)
        
        self.optimizer = AdamW(self.model.parameters(), lr=learning_rate)
        self.num_epochs = num_epochs
        self.num_warmup_steps = round(num_epochs * len(self.train_loader) * warmup_percent / (batch_size * 100)) * batch_size
        self.num_training_steps = num_epochs * len(self.train_loader)
        self.lr_scheduler = get_scheduler("linear", 
                                     optimizer = self.optimizer, 
                                     num_warmup_steps = self.num_warmup_steps, 
                                     num_training_steps = self.num_training_steps)
    
    def train(self, save = True, evaluator = None):
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model.to(device)
        
        progress_bar_train = tqdm(range(self.num_training_steps))
        
        losses = []
        for epoch in range(self.num_epochs):
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                #if i > 10: break
                enc = {k: v.to(device) for k, v in batch.items()}
                outputs = self.model(**enc)
                loss = outputs.loss
                losses += [loss.detach().cpu()]
                loss.backward()

                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                progress_bar_train.update(1)
                
                if i % round(len(self.train_loader)/3) == 1:
                    print(f'Average loss: {sum(losses)/len(losses)}')
                    losses = []
            if evaluator is not None:
                self.model.eval()
                evaluator.evaluate(self.model)
        if save:
            self.model.save_pretrained(self.save_path)

In [4]:
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from transformers import DataCollatorWithPadding
import torch
from tqdm.auto import tqdm
from sklearn import metrics

class Evaluator():
    
    def __init__(self,
                 data_path,
                 split = [0,1],
                 stratify = None,
                 sample_weights = None,
                 batch_size = 16
                ):
        
        self.dataset = CLS_Dataset(data_path, split)
        wrs = WeightedRandomSampler(self.dataset.encodings['labels'], stratify = stratify, weights = sample_weights)
        sampler = torch.utils.data.sampler.BatchSampler(wrs, batch_size=batch_size,drop_last=False)
        
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', use_fast = True)
        dataCollator = DataCollatorWithPadding(tokenizer = self.tokenizer)
        self.eval_loader = DataLoader(dataset = self.dataset, 
                                       batch_size = None, 
                                       collate_fn = lambda x: (dataCollator(x[0]), x[1]), 
                                       sampler = sampler)
        
    
    def evaluate(self, model):
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        model.to(device)

        groups = []
        lgits = []
        lgits_bin = []
        preds = []
        refs = []
        opts = []

        model.eval()
        self.progress_bar_eval = tqdm(range(len(self.eval_loader)))
        self.progress_bar_eval.reset()
        for i, (enc, others) in enumerate(self.eval_loader):
            #if i > 10: break
            enc = {k: v.to(device) for k, v in enc.items()}
            with torch.no_grad():
                outputs = model(**enc)
            logits = outputs.logits
            vals, predictions = torch.max(logits, dim=-1)
            groups += others['groups'].tolist()
            lgits_bin += logits[:,1+addit].detach().cpu()
            lgits += vals.detach().cpu()
            preds += predictions.detach().cpu() - addit
            refs += enc['labels'].detach().cpu() - addit

            self.progress_bar_eval.update(1)
        
        df = pd.DataFrame({'groups': groups, 'logits': lgits, 'logits_bin': lgits_bin, 'preds': preds, 'refs': refs})
        print("-"*20 + "----------" + "-"*20)
        print(df.head())
        prec, rec, f1, dist = metrics.precision_recall_fscore_support(refs, preds, average=None)
        print("-"*20 + "EVALUATION" + "-"*20)
        print('Precision: \t\t{}'.format(prec))
        print('Recall: \t\t{}'.format(rec))
        print('F1: \t\t\t{}'.format(f1))
        print('Distribution: \t\t{}'.format(dist))
        
        print("Original task: \t\t{}".format(df[(df.groupby('groups'))['logits_bin'].transform(max) == df['logits_bin']]['refs'].mean()))
        
        tmp = df[(df.groupby('groups'))['logits'].transform(max) == df['logits']]
        print("Original task v2: \t{}".format(len(tmp[tmp['preds'] == tmp['refs']].groupby('groups').first())/len(df.groupby('groups').first())))
        
        print("-"*20 + "----------" + "-"*20)

In [5]:
sen1 = 'incorrect'
sen2 = 'explanation'


In [6]:
addit = 0
trainer = Trainer(data_path = '../../data/generated/multitask/separate/ecqa_train_bert.csv', 
                  save_path = 'ECQA/Gen3_notft',
                  num_labels = 3,
                  #model_path = 'eSNLI/Gen3', 
                  #split = [0.25,0.5],
                  num_epochs = 3, 
                  batch_size = 16,
                  stratify = 'decrease')
trainer.train(save = True)#, evaluator=evaluator)

[0. 1.]


  0%|          | 0/97647 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

  0%|          | 0/2445 [00:00<?, ?it/s]

Average loss: 0.8940809965133667
Average loss: 0.7197940945625305
Average loss: 0.6927849054336548
Average loss: 0.6727956533432007
Average loss: 0.5723435282707214
Average loss: 0.5591082572937012
Average loss: 0.5627303123474121
Average loss: 0.28720608353614807
Average loss: 0.28685641288757324


stop

In [13]:
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('ECQA/Gen3')

In [8]:
addit = 0

In [14]:
evaluator = Evaluator(data_path = '../../data/tokenized/bert/ecqa/test.csv', 
#evaluator = Evaluator(data_path = '../../data/generated/multitask/separate/ecqa_test_bert.csv', 
                          #split = [0.4,0.6], 
                          batch_size = 32)


[1 0]


  0%|          | 0/27996 [00:00<?, ?it/s]

In [15]:
evaluator.evaluate(model)

  0%|          | 0/292 [00:00<?, ?it/s]

--------------------------------------------------
                             groups          logits       logits_bin  \
0  8d87db96a87432ee5f01d24326b121b5  tensor(3.4676)  tensor(-3.7930)   
1  ddea491827544bc6d1a9e3aa715817cf  tensor(0.2067)   tensor(0.2067)   
2  37743518d3b95c18cc0e820c2eb6175f  tensor(1.5626)  tensor(-1.4801)   
3  19629ab397ea8f1bdc46629668c62af2  tensor(3.2118)  tensor(-3.3375)   
4  2156156d215e0fe1ded434bd40a7f3d8  tensor(1.2851)   tensor(1.2851)   

       preds       refs  
0  tensor(0)  tensor(0)  
1  tensor(1)  tensor(0)  
2  tensor(0)  tensor(0)  
3  tensor(0)  tensor(1)  
4  tensor(1)  tensor(1)  
--------------------EVALUATION--------------------
Precision: 		[0.88026706 0.40933642]
Recall: 		[0.7948821  0.56798715]
F1: 			[0.83539848 0.47578475]
Distribution: 		[7464 1868]
Original task: 		0.5048179871520343
Original task v2: 	0.8832976445396146
--------------------------------------------------


model.config

addit = 2
for sepsh in ['separate']:
    for stmt in ['singletask']:
        try:
            evaluator = Evaluator(data_path = '../../data/generated/{}/{}/esnli_test_bert.csv'.format(stmt, sepsh), 
                          #split = [0.4,0.6], 
                          batch_size = 32)
            evaluator.evaluate(model)
        except:
            print('FAILED: ../../data/generated/{}/{}/esnli_test_bert.csv'.format(stmt, sepsh))

pd.unique(evaluator.dataset.encodings['labels'])