In [None]:
import torch
import numpy as np

class Sampler:
    r"""Abstract parent of every sampler used in this notebook.

    A concrete sampler must override two methods:

    * ``__iter__`` -- yields the indices of the dataset elements to visit.
    * ``__len__``  -- reports how many indices one such iterator yields.
    """

    def __init__(self, data_source):
        # The base class accepts the data source but keeps no reference to it.
        return None

    def __iter__(self):
        # Must be overridden by subclasses.
        raise NotImplementedError()

    def __len__(self):
        # Must be overridden by subclasses.
        raise NotImplementedError()

class WeightedRandomSampler(Sampler):
    r"""Yield dataset indices label by label, optionally re-balancing classes.

    Args:
        labels: pandas Series of labels; its index values are the indices
            that the sampler yields.
        stratify: ``'increase'`` sizes every label relative to the most
            frequent label, ``'decrease'`` relative to the least frequent
            label; any other value keeps each label's own frequency.
        weights: optional ``{label: factor}`` mapping. The number of samples
            drawn for a label is ``round(base_count * factor)``. Defaults to
            a factor of 1 for every label found in ``labels``.

    Raises:
        KeyError: if ``weights`` names a label that does not occur in
            ``labels`` (there would be no rows to draw from).
    """

    def __init__(self, labels, stratify = None, weights = None):
        self.labels = labels
        self.label_counts = labels.value_counts().to_dict()
        self.num_labels = len(self.label_counts)

        if weights is None:
            self.weights = {key: 1 for key in self.label_counts}
        else:
            self.weights = weights

        # Fail early and uniformly: without this check the problem only
        # surfaced as a ZeroDivisionError while iterating.
        unknown = [key for key in self.weights if key not in self.label_counts]
        if unknown:
            raise KeyError(f'weights given for labels that do not occur in the data: {unknown}')

        if stratify == 'increase':
            # Computed once instead of once per label; default covers empty data.
            base = max(self.label_counts.values(), default = 0)
            self.samples_per_label = {key: round(base * weight) for key, weight in self.weights.items()}
        elif stratify == 'decrease':
            base = min(self.label_counts.values(), default = 0)
            self.samples_per_label = {key: round(base * weight) for key, weight in self.weights.items()}
        else:
            self.samples_per_label = {key: round(self.label_counts[key] * weight) for key, weight in self.weights.items()}

        self.num_samples = sum(self.samples_per_label.values())

    def __iter__(self):
        """Return an iterator over a freshly shuffled list of sampled indices."""
        indices = []
        for lbl, amount in self.samples_per_label.items():
            label_index = self.labels[self.labels == lbl].index
            label_count = len(label_index)
            full_passes, remainder = divmod(amount, label_count)
            # Oversampling: every row of the label appears once per full pass.
            for _ in range(full_passes):
                indices += label_index[torch.randperm(label_count).numpy()].tolist()
            if remainder:
                # Draw the remainder from ALL rows of the label. Permuting only
                # range(remainder) would always pick the first `remainder` rows.
                chosen = torch.randperm(label_count)[:remainder]
                indices += label_index[chosen.numpy()].tolist()
        # Mix the labels so batches are not grouped by class.
        order = torch.randperm(len(indices)).tolist()
        return iter([indices[position] for position in order])

    def __len__(self):
        return self.num_samples

In [None]:
import torch
import pandas as pd
from ast import literal_eval
from tqdm import tqdm
import numpy as np

class CLS_Datasets(torch.utils.data.Dataset):
    """Concatenation of several pre-tokenized classification CSV files.

    Each CSV must provide the columns ``attention_mask``, ``input_ids`` and
    ``token_type_ids`` (stored as Python literals) plus ``labels``. Labels are
    re-encoded as consecutive integers so that every dataset occupies its own
    label range: the first dataset starts at ``label_offset``, each following
    dataset starts after the labels of the datasets before it.

    Args:
        data_paths: iterable of ``(csv_path, data_id)`` pairs.
        split: ``(start, stop)`` fractions selecting a row range of every CSV.
        label_offset: constant added to every encoded label. When omitted, the
            notebook-level global ``addit`` is used if it exists (this is how
            the offset was configured previously), otherwise 0.

    Attributes:
        encodings: DataFrame with the tokenizer columns and the encoded
            ``labels``.
        others: DataFrame with all remaining columns (including the original
            ``labels`` and ``data_id``); shares its index with ``encodings``.
    """

    def __init__(self, data_paths, split = (0, 1), label_offset = None):
        if label_offset is None:
            # Backward compatibility with cells that set the global `addit`
            # before building a dataset. Falling back to 0 keeps a fresh
            # kernel from failing with a NameError.
            label_offset = globals().get('addit', 0)

        token_columns = ['attention_mask', 'input_ids', 'token_type_ids']
        encoding_columns = token_columns + ['encoded_labels']

        tqdm.pandas()
        encoding_frames = []
        other_frames = []
        num_labels = 0  # distinct encoded labels produced by earlier datasets
        for (data_path, data_id) in data_paths:
            dataset = pd.read_csv(data_path, header = 0, index_col = 0).reset_index()
            start = round(len(dataset) * split[0])
            stop = round(len(dataset) * split[1])
            # Copy so the column assignments below do not write into a slice.
            dataset = dataset.iloc[start:stop].copy()
            dataset['data_id'] = data_id
            for column in token_columns:
                dataset[column] = dataset[column].progress_apply(literal_eval)

            lbls = np.sort(pd.unique(dataset['labels']))
            label_codes = {lbl: position + num_labels + label_offset
                           for position, lbl in enumerate(lbls)}
            dataset['encoded_labels'] = dataset['labels'].map(label_codes)
            num_labels += len(lbls)

            encoding_frames.append(dataset[encoding_columns])
            other_frames.append(dataset.drop(encoding_columns, axis = 1))

        if encoding_frames:
            # One concat instead of repeated DataFrame.append (removed in
            # pandas 2.0). ignore_index gives every row a unique label, so
            # .loc[idx] cannot return rows of several datasets at once.
            self.encodings = pd.concat(encoding_frames, ignore_index = True)
            self.others = pd.concat(other_frames, ignore_index = True)
        else:
            self.encodings = pd.DataFrame(columns = encoding_columns)
            self.others = pd.DataFrame(columns = ['groups', 'data_id', 'labels'])
        self.encodings = self.encodings.rename(columns = {'encoded_labels': 'labels'})

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        """Return ``(encodings, others)`` for a list of index labels.

        ``idx`` is expected to be a list of labels (a batch from a
        ``BatchSampler``). The encodings come back as a dict of lists, the
        remaining columns as a DataFrame.
        """
        return self.encodings.loc[idx].to_dict('list'), self.others.loc[idx]

In [None]:
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification
from transformers import BertTokenizer
from transformers import DataCollatorWithPadding
from transformers import AdamW
from transformers import get_scheduler
import torch
from tqdm.auto import tqdm

class Trainer():
    """Fine-tune ``bert-base-uncased`` for sequence classification.

    Args:
        data_paths: ``(csv_path, data_id)`` pairs forwarded to ``CLS_Datasets``.
        save_path: directory passed to ``save_pretrained`` after training.
        split: ``(start, stop)`` fractions forwarded to ``CLS_Datasets``.
        model_path: currently unused (see the NOTE in ``__init__``).
        stratify: forwarded to ``WeightedRandomSampler``.
        sample_weights: forwarded to ``WeightedRandomSampler`` as ``weights``.
        batch_size: number of samples per batch.
        learning_rate: learning rate of the AdamW optimizer.
        num_epochs: number of passes over the sampled data.
        warmup_percent: percentage of the total training steps used for
            warm-up; the resulting step count is rounded to a multiple of
            ``batch_size``.
    """

    def __init__(self,
                 data_paths,
                 save_path,
                 split = (0, 1),
                 model_path = None,
                 stratify = None,
                 sample_weights = None,
                 batch_size = 16,
                 learning_rate = 5e-5,
                 num_epochs = 3,
                 warmup_percent = 1,
                ):

        self.dataset = CLS_Datasets(data_paths, split)
        self.save_path = save_path

        # NOTE(review): `model_path` is accepted but not used; training always
        # starts from 'bert-base-uncased' with a fixed 7-way classification
        # head. Left unchanged so existing runs load the same weights --
        # TODO decide whether `model_path` should select the checkpoint.
        self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels = 7)

        self.wrs = WeightedRandomSampler(self.dataset.encodings['labels'],
                                         stratify = stratify,
                                         weights = sample_weights)
        sampler = torch.utils.data.sampler.BatchSampler(self.wrs, batch_size=batch_size, drop_last=False)

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', use_fast = True)
        dataCollator = DataCollatorWithPadding(tokenizer = self.tokenizer)
        # batch_size=None: the BatchSampler already hands the dataset a whole
        # batch of indices, so the loader must not batch again. The dataset
        # returns (encodings, others); only the encodings are padded and used.
        self.train_loader = DataLoader(dataset = self.dataset,
                                       batch_size = None,
                                       collate_fn = lambda x: dataCollator(x[0]),
                                       sampler = sampler)

        self.optimizer = AdamW(self.model.parameters(), lr=learning_rate)
        self.num_epochs = num_epochs
        self.num_training_steps = num_epochs * len(self.train_loader)
        self.num_warmup_steps = round(self.num_training_steps * warmup_percent / (batch_size * 100)) * batch_size
        self.lr_scheduler = get_scheduler("linear",
                                     optimizer = self.optimizer,
                                     num_warmup_steps = self.num_warmup_steps,
                                     num_training_steps = self.num_training_steps)

    def train(self, save = True, evaluator = None):
        """Run the training loop.

        Args:
            save: when true, write the model to ``self.save_path`` at the end.
            evaluator: optional object whose ``evaluate(model)`` method is
                called after every epoch.
        """
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model.to(device)
        progress_bar_train = tqdm(range(self.num_training_steps))
        # Report the running average roughly three times per epoch. The
        # max(1, ...) guard avoids a modulo-by-zero with a single-batch loader.
        log_interval = max(1, round(len(self.train_loader) / 3))
        losses = []
        current_loss = 0
        for epoch in range(self.num_epochs):
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                enc = {k: v.to(device) for k, v in batch.items()}
                outputs = self.model(**enc)
                loss = outputs.loss
                current_loss = loss.detach().cpu()
                losses.append(current_loss)
                loss.backward()

                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                progress_bar_train.update(1)

                if i % log_interval == 1:
                    print(f'Average loss: {sum(losses)/len(losses)}')
                    losses = []
            if evaluator is not None:
                self.model.eval()
                evaluator.evaluate(self.model)
        print(f'Current loss: {current_loss}')
        if save:
            self.model.save_pretrained(self.save_path)

In [None]:
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from transformers import DataCollatorWithPadding
import torch
from tqdm.auto import tqdm
from sklearn import metrics

class Evaluator():
    """Evaluate a sequence-classification model on pre-tokenized CSV data.

    Args:
        data_paths: ``(csv_path, data_id)`` pairs forwarded to ``CLS_Datasets``.
        split: ``(start, stop)`` fractions forwarded to ``CLS_Datasets``.
        stratify: forwarded to ``WeightedRandomSampler``.
        sample_weights: forwarded to ``WeightedRandomSampler`` as ``weights``.
        batch_size: number of samples per evaluation batch.
    """

    def __init__(self,
                 data_paths,
                 split = (0, 1),
                 stratify = None,
                 sample_weights = None,
                 batch_size = 16
                ):

        self.dataset = CLS_Datasets(data_paths, split)
        wrs = WeightedRandomSampler(self.dataset.encodings['labels'], stratify = stratify, weights = sample_weights)
        sampler = torch.utils.data.sampler.BatchSampler(wrs, batch_size=batch_size, drop_last=False)

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', use_fast = True)
        dataCollator = DataCollatorWithPadding(tokenizer = self.tokenizer)
        # batch_size=None: the BatchSampler already delivers whole batches.
        # The non-encoding columns are passed through next to the padded batch.
        self.eval_loader = DataLoader(dataset = self.dataset,
                                       batch_size = None,
                                       collate_fn = lambda x: (dataCollator(x[0]), x[1]),
                                       sampler = sampler)

    def evaluate(self, model):
        """Print per-class metrics and a per-dataset group accuracy.

        For every ``data_id`` the row(s) with the highest winning logit inside
        each ``groups`` value are selected; the printed score is the share of
        groups for which such a row was classified correctly.
        """
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        model.to(device)

        data_ids = []
        groups = []
        lgits = []
        preds = []
        refs = []

        model.eval()
        self.progress_bar_eval = tqdm(range(len(self.eval_loader)))
        self.progress_bar_eval.reset()
        for enc, others in self.eval_loader:
            enc = {k: v.to(device) for k, v in enc.items()}
            with torch.no_grad():
                outputs = model(**enc)
            logits = outputs.logits

            vals, predictions = torch.max(logits, dim=-1)
            data_ids += others['data_id'].tolist()
            groups += others['groups'].tolist()
            # .tolist() stores plain Python numbers. Extending the lists with
            # the tensors themselves left 0-d tensor objects in the DataFrame
            # and in the arrays handed to sklearn.
            lgits += vals.detach().cpu().tolist()
            preds += predictions.detach().cpu().tolist()
            refs += enc['labels'].detach().cpu().tolist()

            self.progress_bar_eval.update(1)

        df = pd.DataFrame({'data_ids': data_ids, 'groups': groups, 'logits': lgits, 'preds': preds, 'refs': refs})

        prec, rec, f1, dist = metrics.precision_recall_fscore_support(refs, preds, average=None)
        print("-"*20 + "EVALUATION" + "-"*20)
        print('Precision: \t\t{}'.format(prec))
        print('Recall: \t\t{}'.format(rec))
        print('F1: \t\t\t{}'.format(f1))
        print('Distribution: \t\t{}'.format(dist))

        for data_id in pd.unique(self.dataset.others['data_id']):
            d = df[df['data_ids'] == data_id]
            # Keep, per group, the row(s) whose winning logit is the largest.
            best = d[d.groupby('groups')['logits'].transform('max') == d['logits']]
            correct_groups = best.loc[best['preds'] == best['refs'], 'groups'].nunique()
            total_groups = d['groups'].nunique()
            # Guard against a dataset that contributed no groups at all.
            score = correct_groups / total_groups if total_groups else float('nan')
            print("Original task ({}) v2: \t{}".format(data_id, score))

        print("-"*20 + "----------" + "-"*20)

# Evaluator for optional per-epoch evaluation. No test set is enabled at the
# moment; uncomment a path below and pass `evaluator = evaluator` to
# trainer.train() to use it.
evaluator = Evaluator(data_paths = [#('../../data/tokenized/bert/ecqa/test.csv', 'ecqa'),
                                    #('../../data/tokenized/bert/esnli/test.csv', 'esnli'),
                                    #('../../data/generated/multitask/separate/comve_test_bert.csv', 'ecqa')
                                   ],
                      #split = [0,0.01],
                      batch_size = 32)

# Train on the first half of each task's training set.
# Fixed: the 'esnli' entry pointed at the ECQA file, so ECQA was loaded twice
# (once mislabelled as e-SNLI) and no e-SNLI data was used.
# NOTE(review): path follows the layout of the esnli test path above -- confirm it exists.
trainer = Trainer(data_paths = [('../../data/tokenized/bert/ecqa/train.csv', 'ecqa'),
                                ('../../data/tokenized/bert/esnli/train.csv', 'esnli'),
                                ('../../data/tokenized/bert/comve/train.csv', 'comve')
                               ],
                  save_path = 'Shared/Gold',
                  model_path = 'eSNLI/Gold',
                  split = [0,0.5],
                  num_epochs = 3,
                  batch_size = 4,
                  stratify = 'decrease')
trainer.train(save = True)#, evaluator = evaluator)

In [None]:
from transformers import BertForSequenceClassification

# Reload the classifier from 'Shared/Gold', the directory the training cell
# passes as `save_path` (written by trainer.train(save = True)).
model = BertForSequenceClassification.from_pretrained('Shared/Gold')

In [None]:
# Global offset that CLS_Datasets adds to every encoded label.
# NOTE(review): presumably 2 moves the e-SNLI classes behind the labels of the
# task loaded first during training -- TODO confirm.
addit = 2

# Evaluate the reloaded model on every generated e-SNLI test set. A failing
# file is reported and skipped so the remaining files are still evaluated.
path_template = '../../data/generated/{}/{}/esnli_test_bert.csv'
for sepsh in ['separate', 'shared']:
    for stmt in ['singletask', 'multitask']:
        csv_path = path_template.format(stmt, sepsh)
        try:
            evaluator = Evaluator(data_paths = [(csv_path, 'esnli')],
                                  batch_size = 32)
            evaluator.evaluate(model)
            print('SUCCESS: ' + csv_path)
        except Exception as e:
            print('FAILED: ' + csv_path)
            print(e)


# Runs once more on whichever evaluator was constructed successfully last.
evaluator.evaluate(model)

In [None]:
# Scratch calculation: share of 3814 in 3814 + 7464 (about 0.338).
# NOTE(review): the notebook does not say what the two counts are --
# presumably class or sample counts; TODO confirm and name them.
3814 / (3814 + 7464)