In [None]:
import torch
import numpy as np

class Sampler:
    r"""Abstract parent of every sampler defined in this notebook.

    A concrete sampler overrides ``__iter__`` so that it yields indices of
    dataset elements, and ``__len__`` so that it reports how many indices one
    such iteration produces.
    """

    def __init__(self, data_source):
        """Accept ``data_source``; the base class neither stores nor uses it."""

    def __iter__(self):
        """Must be overridden; the base implementation always raises."""
        raise NotImplementedError()

    def __len__(self):
        """Must be overridden; the base implementation always raises."""
        raise NotImplementedError()

class WeightedRandomSampler(Sampler):
    r"""Shuffled sampler that draws a configurable number of rows per label.

    Args:
        labels: pandas Series of labels. Its index values are what the
            sampler yields (``labels.index``), not positional offsets.
        stratify: ``'increase'`` draws ``round(max_count * weight)`` rows for
            every weighted label (up-sampling), ``'decrease'`` draws
            ``round(min_count * weight)`` rows (down-sampling); any other value
            draws ``round(own_count * weight)`` rows per label.
        weights: optional ``{label: multiplier}`` dict. Defaults to ``1`` for
            every label present in ``labels``. Labels missing from the dict
            are never sampled.

    Raises:
        KeyError: if ``weights`` names a label that does not occur in
            ``labels`` (previously this surfaced as a ``ZeroDivisionError``
            during iteration in the stratified modes).
    """

    def __init__(self, labels, stratify = None, weights = None):
        self.labels = labels
        self.label_counts = labels.value_counts().to_dict()
        self.num_labels = len(self.label_counts.keys())

        if weights is None:
            self.weights = {key: 1 for key in self.label_counts}
        else:
            self.weights = weights

        unknown = [key for key in self.weights if key not in self.label_counts]
        if unknown:
            raise KeyError(f'weights given for labels that do not occur in `labels`: {unknown}')

        if stratify == 'increase':
            reference = max(self.label_counts.values()) if self.weights else 0
            self.samples_per_label = {key: int(round(reference * weight)) for key, weight in self.weights.items()}
        elif stratify == 'decrease':
            reference = min(self.label_counts.values()) if self.weights else 0
            self.samples_per_label = {key: int(round(reference * weight)) for key, weight in self.weights.items()}
        else:
            self.samples_per_label = {key: int(round(self.label_counts[key] * weight)) for key, weight in self.weights.items()}

        self.num_samples = sum(self.samples_per_label.values())

    def __iter__(self):
        """Return an iterator over a freshly shuffled list of index values."""
        indices = []
        for lbl, amount in self.samples_per_label.items():
            label_index = self.labels[self.labels == lbl].index.to_numpy()
            label_count = len(label_index)
            full_passes, remainder = divmod(amount, label_count)
            # Every full pass contributes each row of the label exactly once.
            for _ in range(full_passes):
                indices += label_index[torch.randperm(label_count).numpy()].tolist()
            # The remainder must be a random subset of ALL rows of the label:
            # permute the whole label first, then keep the leading entries.
            # (Permuting only `remainder` positions would always return the
            # first rows of the label.)
            if remainder:
                chosen = torch.randperm(label_count)[:remainder].numpy()
                indices += label_index[chosen].tolist()
        # Mix the per-label runs so batches are not grouped by label.
        order = torch.randperm(len(indices)).numpy()
        return iter(np.array(indices)[order].tolist())

    def __len__(self):
        """Total number of indices produced by one iteration."""
        return self.num_samples

In [None]:
import torch
import pandas as pd
from ast import literal_eval
from tqdm import tqdm
import numpy as np

class LM_Datasets(torch.utils.data.Dataset):
    """Concatenation of one or more tokenized CSV datasets.

    After construction the instance holds two row-aligned frames:

    * ``encodings`` - the list-valued columns named in ``_LIST_COLUMNS`` plus
      ``lm_labels``.
    * ``others`` - every remaining CSV column plus ``data_id`` and
      ``cls_labels``.

    Both frames are re-indexed with ``reset_index()``, so the previous index
    is kept as an extra column and rows are addressed by 0..len-1.

    Args:
        data_paths: iterable of ``(csv_path, data_id)`` pairs. Each CSV is read
            with its first column as index and must contain a ``labels``
            column and the columns in ``_LIST_COLUMNS``.
        split: ``(start_fraction, stop_fraction)``; only the rows between
            ``round(len * start)`` and ``round(len * stop)`` of each CSV are
            kept.
    """

    # CSV columns holding stringified Python lists, parsed with literal_eval.
    _LIST_COLUMNS = ['input_ids', 'attention_mask', 'gold_ids', 'gold_mask']

    def __init__(self, data_paths, split = (0, 1)):
        encoding_frames = []
        other_frames = []
        # Number of distinct class ids handed out so far; keeps the ids of
        # different datasets disjoint.
        label_offset = 0

        tqdm.pandas()
        for (data_path, data_id) in data_paths:
            dataset = pd.read_csv(f'{data_path}', header = 0, index_col = 0)
            first = round(len(dataset) * split[0])
            last = round(len(dataset) * split[1])
            # Copy the slice so the column assignments below do not write
            # into a view of the full frame.
            dataset = dataset.iloc[first:last].copy()
            dataset['data_id'] = data_id
            dataset[self._LIST_COLUMNS] = dataset[self._LIST_COLUMNS].progress_applymap(literal_eval)

            # Class id = rank of the label among this dataset's sorted unique
            # labels, shifted past the ids of the datasets loaded before it.
            lbls = np.sort(pd.unique(dataset['labels']))
            class_ids = {lbl: position + label_offset for position, lbl in enumerate(lbls)}
            dataset['cls_labels'] = dataset['labels'].map(class_ids)
            label_offset += len(lbls)

            encoding_frames.append(dataset[self._LIST_COLUMNS])
            other_frames.append(dataset.drop(self._LIST_COLUMNS, axis=1))

        encodings, others = self._assemble(encoding_frames, other_frames)
        self.encodings = encodings.reset_index()
        self.others = others.reset_index()

    @staticmethod
    def _lm_labels(input_ids, gold_ids):
        """Mask the prompt positions with -100 and keep the gold ids that follow."""
        prompt_len = len(input_ids)
        return [-100] * prompt_len + gold_ids[prompt_len:]

    @staticmethod
    def _assemble(encoding_frames, other_frames):
        """Join the per-dataset frames once and derive the ``lm_labels`` column.

        ``DataFrame.append`` no longer exists in pandas 2.x, so the frames are
        joined with a single ``pd.concat`` each.
        """
        if not encoding_frames:
            return pd.DataFrame(), pd.DataFrame(columns = ['cls_labels'])

        encodings = pd.concat(encoding_frames)
        encodings['lm_labels'] = [
            LM_Datasets._lm_labels(input_ids, gold_ids)
            for input_ids, gold_ids in zip(encodings['input_ids'], encodings['gold_ids'])
        ]

        others = pd.concat(other_frames)
        # Keep 'cls_labels' as the leading column, as before.
        leading = ['cls_labels']
        others = others[leading + [col for col in others.columns if col not in leading]]
        return encodings, others

    def __len__(self):
        """Number of rows across all loaded datasets."""
        return len(self.encodings)

    def __getitem__(self, idx):
        """Return ``(encodings, others)`` for a list of row indices.

        ``idx`` is expected to be a list (a whole batch); the encodings are
        returned as ``{column: [value per row]}`` and the remaining columns as
        a DataFrame slice.
        """
        return self.encodings.loc[idx].to_dict('list'), self.others.loc[idx]

In [None]:
from torch.utils.data import DataLoader
from transformers import GPT2LMHeadModel
from transformers import GPT2TokenizerFast
from transformers import AdamW
from transformers import get_scheduler
import torch
from tqdm.auto import tqdm

class Trainer():
    """Fine-tunes a GPT-2 LM-head model on the datasets loaded by ``LM_Datasets``.

    Args:
        data_paths: ``(csv_path, data_id)`` pairs forwarded to ``LM_Datasets``.
        save_path: directory passed to ``save_pretrained`` after training.
        split: ``(start_fraction, stop_fraction)`` forwarded to ``LM_Datasets``.
        model_path: checkpoint name or path for ``GPT2LMHeadModel.from_pretrained``.
        stratify: forwarded to ``WeightedRandomSampler``.
        sample_weights: forwarded to ``WeightedRandomSampler`` as ``weights``.
        batch_size: number of rows per batch.
        learning_rate: learning rate of the AdamW optimizer.
        num_epochs: number of passes over the sampler.
        warmup_percent: percentage of the training steps used for warm-up.
    """

    def __init__(self,
                 data_paths,
                 save_path,
                 split = (0, 1),
                 model_path = 'gpt2',
                 stratify = None,
                 sample_weights = None,
                 batch_size = 16,
                 learning_rate = 5e-5,
                 num_epochs = 3,
                 warmup_percent = 1,
                ):

        self.dataset = LM_Datasets(data_paths, split)
        self.save_path = save_path

        wrs = WeightedRandomSampler(self.dataset.others['cls_labels'], stratify = stratify, weights = sample_weights)
        sampler = torch.utils.data.sampler.BatchSampler(wrs, batch_size=batch_size,drop_last=False)

        self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', use_fast = True)

        # batch_size=None: the BatchSampler already yields whole index lists,
        # so the dataset is queried once per batch and collated here.
        self.train_loader = DataLoader(dataset = self.dataset,
                                       batch_size = None,
                                       collate_fn = self.train_collator,
                                       sampler = sampler)

        self.model = GPT2LMHeadModel.from_pretrained(model_path)
        # The padding id lives on the model config (the Evaluator sets it the
        # same way); assigning it on the module itself had no effect.
        self.model.config.pad_token_id = self.tokenizer.eos_token_id

        self.optimizer = AdamW(self.model.parameters(), lr=learning_rate)
        self.num_epochs = num_epochs
        # NOTE(review): the warm-up length is rounded to a multiple of
        # batch_size although it counts optimizer steps - confirm intent.
        self.num_warmup_steps = round(num_epochs * len(self.train_loader) * warmup_percent / (batch_size * 100)) * batch_size
        self.num_training_steps = num_epochs * len(self.train_loader)
        self.lr_scheduler = get_scheduler("linear",
                                     optimizer = self.optimizer,
                                     num_warmup_steps = self.num_warmup_steps,
                                     num_training_steps = self.num_training_steps)

    @staticmethod
    def _log_interval(num_batches):
        """Batches between two loss reports: a third of an epoch, at least 1."""
        return max(1, round(num_batches / 3))

    def train_collator(self, batch):
        """Right-pad one batch and rename the columns to the model's keyword names.

        ``batch`` is the ``(encodings, others)`` tuple returned by the
        dataset; only the encodings are used. Returns a dict with the tensors
        ``input_ids``, ``attention_mask`` and ``labels``.
        """
        encodings = batch[0]
        max_len = max([len(l) for l in encodings['gold_ids']])

        pads = {'input_ids': self.tokenizer.eos_token_id, 'labels': -100, 'attention_mask': 0}
        keys = {'gold_ids': 'input_ids', 'gold_mask': 'attention_mask', 'lm_labels': 'labels'}
        out = {}
        for key, val in keys.items():
            out[val] = torch.tensor([sample + [pads[val]] * (max_len - len(sample)) for sample in encodings[key]])
        return out

    def train(self, save = True, evaluator = None):
        """Run the training loop.

        Args:
            save: when true, write the model to ``save_path`` after the last
                epoch; also forwarded to ``evaluator.evaluate``.
            evaluator: optional object whose ``evaluate(model, save)`` is
                called after every epoch.
        """
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model.to(device)

        progress_bar_train = tqdm(range(self.num_training_steps))
        # Guarded interval: a loader with fewer than two batches used to
        # produce a modulo-by-zero here.
        log_every = self._log_interval(len(self.train_loader))
        losses = []
        for epoch in range(self.num_epochs):
            self.model.train()
            for step, batch in enumerate(self.train_loader):
                enc = {k: v.to(device) for k, v in batch.items()}
                outputs = self.model(**enc)
                loss = outputs.loss
                # Keep plain floats so no tensors are held across steps.
                losses.append(loss.item())
                loss.backward()

                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                progress_bar_train.update(1)

                if (step + 1) % log_every == 0:
                    print(f'Average loss: {sum(losses)/len(losses)}')
                    losses = []
            if evaluator is not None:
                self.model.eval()
                evaluator.evaluate(self.model, save)
        if save:
            self.model.save_pretrained(self.save_path)

# Single-task run: fine-tune on the first half of the e-SNLI training CSV,
# with every class down-sampled ('decrease') to the size of the rarest one.
trainer = Trainer(
    data_paths = [
        #('../../data/tokenized/gpt2/ecqa/train.csv', 'ecqa'),
        ('../../data/tokenized/gpt2/esnli/train.csv', 'esnli'),
        #('../../data/tokenized/gpt2/comve/train.csv', 'comve')
    ],
    save_path = 'eSNLI/GPT2SingleTask',
    #model_path = 'eSNLI/Gold',
    split = [0, 0.5],
    stratify = 'decrease',
    batch_size = 4,
    num_epochs = 3,
)
trainer.train(save = True)

In [None]:
from torch.utils.data import DataLoader
from transformers import GPT2LMHeadModel
from transformers import GPT2TokenizerFast
from transformers import DataCollatorForLanguageModeling
from transformers import AdamW
from transformers import get_scheduler
import torch
from tqdm.auto import tqdm
from datasets import load_metric
import json

class Evaluator():
    """Generates continuations for a tokenized dataset and optionally scores them.

    Args:
        data_paths: ``(csv_path, data_id)`` pairs forwarded to ``LM_Datasets``.
        save_paths: ``(json_path, data_id)`` pairs; on ``save`` the rows of
            each ``data_id`` are dumped to the matching path.
        split: ``(start_fraction, stop_fraction)`` forwarded to ``LM_Datasets``.
        stratify: forwarded to ``WeightedRandomSampler``.
        sample_weights: forwarded to ``WeightedRandomSampler`` as ``weights``.
        batch_size: number of rows per generation batch.
        gold_explanations: when true, predictions are compared with the
            ``explanations`` column using BERTScore.
    """

    def __init__(self,
                 data_paths,
                 save_paths,
                 split = (0, 1),
                 stratify = None,
                 sample_weights = None,
                 batch_size = 16,
                 gold_explanations = True
                ):

        self.dataset = LM_Datasets(data_paths, split)
        self.save_paths = save_paths

        self.wrs = WeightedRandomSampler(self.dataset.others['cls_labels'], stratify = stratify, weights = sample_weights)
        sampler = torch.utils.data.sampler.BatchSampler(self.wrs, batch_size=batch_size,drop_last=False)

        self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', use_fast = True)
        # batch_size=None: the BatchSampler already yields whole index lists.
        self.test_loader = DataLoader(dataset = self.dataset,
                                       batch_size = None,
                                       collate_fn = self.test_collator,
                                       sampler = sampler)

        self.bertscore = load_metric('bertscore')
        self.gold_explanations = gold_explanations

    def test_collator(self, batch):
        """Left-pad the prompts of one batch.

        ``batch`` is the ``(encodings, others)`` tuple returned by the
        dataset. Returns ``(tensors, others)`` where ``tensors`` holds
        ``input_ids`` and ``attention_mask``; padding goes in front so that
        generation continues directly after the last prompt token.
        """
        encodings = batch[0]
        max_len = max([len(l) for l in encodings['input_ids']])

        pads = {'input_ids': self.tokenizer.eos_token_id, 'attention_mask': 0}
        keys = {'input_ids': 'input_ids', 'attention_mask': 'attention_mask'}
        out = {}
        for key, val in keys.items():
            out[val] = torch.tensor([[pads[val]] * (max_len - len(sample)) + sample for sample in encodings[key]])

        return out, batch[1]

    @staticmethod
    def _strip_prompts(prompts, decoded):
        """Cut each decoded prompt off the front of its decoded generation."""
        return [full[len(prompt):].strip() for prompt, full in zip(prompts, decoded)]

    @staticmethod
    def _mean_score(values):
        """Arithmetic mean of ``values`` rounded to two decimals."""
        return round(sum(values)/len(values), 2)

    def evaluate(self, model, save = False):
        """Generate a continuation for every row and collect them in ``self.df``.

        Args:
            model: model exposing ``generate``; its config is updated in place
                (``pad_token_id``, ``max_length``).
            save: when true, dump the generations per ``data_id`` to
                ``save_paths`` and, if gold explanations are enabled, print
                and store the mean BERTScore precision/recall/F1.
        """
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        model.to(device)
        model.config.pad_token_id = model.config.eos_token_id
        model.config.max_length = 150

        progress_bar = tqdm(range(len(self.test_loader)))

        # Collected per batch and concatenated once; DataFrame.append no
        # longer exists in pandas 2.x.
        frames = []
        model.eval()
        for step, batch in enumerate(self.test_loader):
            tensors, meta = batch
            enc = {k: v.to(device) for k, v in tensors.items()}
            with torch.no_grad():
                generation = model.generate(
                    input_ids = enc['input_ids'],
                    attention_mask = enc['attention_mask'],
                    do_sample = True,
                    max_length = 100,
                    temperature = 0.7,
                    top_k = 50,
                    top_p = 0.7
                )
            prompts = self.tokenizer.batch_decode(enc['input_ids'].detach(), skip_special_tokens=True)
            decoded = self.tokenizer.batch_decode(generation.detach(), skip_special_tokens=True)
            preds = self._strip_prompts(prompts, decoded)

            if self.gold_explanations:
                try:
                    golds = meta['explanations'].tolist()
                    self.bertscore.add_batch(predictions=preds, references=golds)
                except Exception as err:
                    # Best effort as before: a batch that cannot be scored is
                    # left out of the results, but no longer silently.
                    print(f'Skipping batch {step}: could not add it to BERTScore ({err!r})')
                    progress_bar.update(1)
                    continue

            # assign() returns a new frame, so the dataset's slice is not mutated.
            frames.append(meta.assign(generated = preds))
            progress_bar.update(1)

        self.df = pd.concat(frames) if frames else pd.DataFrame()

        if save:
            for (save_path, data_id) in self.save_paths:
                with open(save_path, 'w') as f:
                    json.dump(self.df[self.df['data_id'] == data_id].to_dict(), f)

            if self.gold_explanations:
                res = self.bertscore.compute(lang = 'en')
                precision = self._mean_score(res["precision"])
                recall = self._mean_score(res["recall"])
                f1 = self._mean_score(res["f1"])
                print(f'{precision}\t\t|\t'+
                      f'{recall}\t\t|\t'+
                      f'{f1}')

                with open('../../data/generated/bertscores.json', 'w') as f:
                    json.dump({'Precision': precision,
                               'Recall': recall,
                               'F1-score': f1},
                              f)

In [None]:
# Evaluation set-up: generate for the e-SNLI test CSV in batches of four;
# BERTScore against gold explanations is switched off.
evaluator = Evaluator(
    data_paths = [
        #('../../data/tokenized/gpt2/ecqa/train.csv', 'ecqa'),
        ('../../data/tokenized/gpt2/esnli/test.csv', 'esnli'),
        #('../../data/tokenized/gpt2/comve/train.csv', 'comve')
    ],
    save_paths = [
        #('../../data/generated/singletask/shared/ecqa_train.json', 'ecqa'),
        ('../../data/generated/multitask/separate/esnli_test.json', 'esnli'),
        #('../../data/generated/singletask/shared/comve_train.json', 'comve')
    ],
    #split = [0,0.5],
    gold_explanations = False,
    batch_size = 4,
)

In [None]:
# Load the checkpoint stored under the same path the training cell uses as save_path.
model = GPT2LMHeadModel.from_pretrained('eSNLI/GPT2SingleTask')

In [None]:
# Generate for every evaluation batch; save=True also writes the results to the evaluator's save_paths.
evaluator.evaluate(model, save = True)