In [None]:
import torch
import numpy as np

class Sampler:
    r"""Abstract base for index samplers consumed by a ``DataLoader``.

    Subclasses must implement ``__iter__`` (yielding dataset indices) and
    ``__len__`` (how many indices one pass of ``__iter__`` yields). The
    constructor accepts a data source, but this base class ignores it.
    """

    def __init__(self, data_source):
        # The base class keeps no state; subclasses decide what to store.
        del data_source

    def __iter__(self):
        """Yield dataset indices; must be overridden."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of yielded indices; must be overridden."""
        raise NotImplementedError

class WeightedRandomSampler(Sampler):
    """Label-stratified sampler that yields dataset index labels in random order.

    For every label a target count is computed (see ``stratify``). The label's
    rows are drawn as whole random permutations while at least one full pass
    fits; the remainder is a random subset of that label's rows. All draws
    are then shuffled together.

    Args:
        labels: pandas Series of class labels. Its index values are what
            ``__iter__`` yields, so it must share the dataset's index.
        stratify: ``'increase'`` targets ``max(label counts) * weight`` per
            label; ``'decrease'`` targets ``min(label counts) * weight``; any
            other value targets ``own count * weight``.
        weights: optional ``{label: multiplier}``. Defaults to 1 for every
            label present in ``labels``.
    """

    def __init__(self, labels, stratify = None, weights = None):
        self.labels = labels
        self.label_counts = labels.value_counts().to_dict()
        self.num_labels = len(self.label_counts)

        if weights is None:
            self.weights = {key: 1 for key in self.label_counts}
        else:
            self.weights = weights

        if stratify == 'increase':
            base_count = max(self.label_counts.values())
            self.samples_per_label = {key: round(base_count * weight) for key, weight in self.weights.items()}
        elif stratify == 'decrease':
            base_count = min(self.label_counts.values())
            self.samples_per_label = {key: round(base_count * weight) for key, weight in self.weights.items()}
        else:
            # Labels listed in `weights` but absent from `labels` get 0 samples
            # instead of raising KeyError.
            self.samples_per_label = {key: round(self.label_counts.get(key, 0) * weight)
                                      for key, weight in self.weights.items()}

        # Non-positive targets draw nothing, so they must not count towards len().
        self.num_samples = sum(max(0, amount) for amount in self.samples_per_label.values())

    def __iter__(self):
        indices = []
        for lbl, amount in self.samples_per_label.items():
            if amount <= 0:
                continue
            label_index = self.labels[self.labels == lbl].index
            label_count = len(label_index)
            if label_count == 0:
                raise ValueError(f'Cannot draw {amount} samples for label {lbl!r}: no examples carry this label.')
            full_passes, remainder = divmod(amount, label_count)
            for _ in range(full_passes):
                indices += label_index[torch.randperm(label_count).numpy()].tolist()
            # Random subset for the remainder; torch.randperm(remainder) would
            # always select the first `remainder` rows of this label.
            indices += label_index[torch.randperm(label_count)[:remainder].numpy()].tolist()
        order = torch.randperm(len(indices)).tolist()
        return iter([indices[position] for position in order])

    def __len__(self):
        return self.num_samples

In [None]:
import torch
import pandas as pd
from ast import literal_eval
from tqdm import tqdm
import numpy as np

class DatasetForMultiTaskLearning(torch.utils.data.Dataset):
    """Concatenates tokenized CSV datasets into one multi-task dataset.

    Each CSV is read with its first column as the index. It must contain
    ``input_ids`` and ``attention_mask`` columns holding Python-literal lists,
    and optionally ``gold_ids``/``gold_mask``. It also needs ``labels`` and
    ``groups`` columns.

    Per dataset, sorted distinct labels map to consecutive class ids. These
    are offset by the number of distinct class ids collected from earlier
    datasets. Rows with class id 1 are collapsed to one row per ``groups``
    value (first non-null value per column), and only rows with class id 0
    or 1 are kept.

    ``self.encodings`` holds the token columns, ``cls_labels`` and, when gold
    sequences exist, ``lm_labels``. ``self.others`` holds every remaining
    column (e.g. ``data_id``, ``groups``). Both are re-indexed 0..n-1.

    Args:
        data_paths: iterable of ``(csv_path, data_id)`` pairs.
        split: ``(start, end)`` fractions selecting a contiguous slice of rows.
        label_offset: constant added to every class id. This replaces the
            notebook-global ``addit``, which made the class fail on a fresh
            kernel unless a later cell had already defined it.
    """

    def __init__(self, data_paths, split = (0, 1), label_offset = 0):
        tqdm.pandas()  # registers DataFrame.progress_applymap
        encoding_frames = []
        other_frames = []
        seen_cls_labels = set()
        for (data_path, data_id) in data_paths:
            dataset = pd.read_csv(data_path, header = 0, index_col = 0)
            start, end = round(len(dataset) * split[0]), round(len(dataset) * split[1])
            # .copy() so the column assignments below do not write into a view.
            dataset = dataset.iloc[start:end].copy()
            dataset['data_id'] = data_id
            lit_evals = ['input_ids', 'attention_mask']
            if 'gold_ids' in dataset:
                lit_evals += ['gold_ids', 'gold_mask']
            # Token columns are stored as Python-literal strings in the CSV.
            dataset[lit_evals] = dataset[lit_evals].progress_applymap(literal_eval)

            num_labels = len(seen_cls_labels)
            sorted_labels = np.sort(pd.unique(dataset['labels']))
            label_to_id = {lbl: position + num_labels for position, lbl in enumerate(sorted_labels)}
            dataset['cls_labels'] = dataset['labels'].map(label_to_id)

            # NOTE(review): this filter compares against the absolute ids 0/1,
            # so for any dataset after the first (num_labels > 0) rows are likely
            # all dropped. Confirm it is meant for single-dataset runs only.
            positives = dataset[dataset['cls_labels'] == 1].groupby(['groups']).first().reset_index()
            dataset = pd.concat([dataset[dataset['cls_labels'] == 0], positives]).reset_index()

            dataset['cls_labels'] += label_offset
            seen_cls_labels.update(dataset['cls_labels'].unique().tolist())

            encoding_frames.append(dataset[lit_evals + ['cls_labels']])
            other_frames.append(dataset.drop(lit_evals + ['cls_labels'], axis=1))

        # Concatenate once instead of growing a frame inside the loop
        # (DataFrame.append was removed in pandas 2.0).
        if encoding_frames:
            self.encodings = pd.concat(encoding_frames)
            self.others = pd.concat(other_frames)
        else:
            self.encodings = pd.DataFrame(columns = ['cls_labels'])
            self.others = pd.DataFrame()

        if 'gold_ids' in self.encodings:
            def to_lm_labels(row):
                # -100 masks the prompt so only the continuation beyond the
                # prompt length contributes to the LM loss.
                prompt_len = len(row['input_ids'])
                return [-100] * prompt_len + row['gold_ids'][prompt_len:]
            self.encodings['lm_labels'] = self.encodings.apply(to_lm_labels, axis=1)
        self.encodings = self.encodings.reset_index()
        self.others = self.others.reset_index()

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        # `idx` is expected to be a list of index labels (from a BatchSampler);
        # returns ({column: list of values}, DataFrame of the other columns).
        return self.encodings.loc[idx].to_dict('list'), self.others.loc[idx]

In [None]:
from transformers import GPT2Model
from transformers import GPT2PreTrainedModel
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.modeling_outputs import ModelOutput
from transformers import GPT2Config
from dataclasses import dataclass
from typing import Optional
import torch

class GPT2ForMultiTaskConfig(GPT2Config):
    """``GPT2Config`` extended with the size of the classification head.

    Args:
        n_labels: number of classes predicted by the classification head
            (default 2).
        **kwargs: forwarded unchanged to ``GPT2Config``.
    """

    def __init__(self, n_labels = 2, **kwargs):
        super(GPT2ForMultiTaskConfig, self).__init__(**kwargs)
        self.n_labels = n_labels

@dataclass
class GPT2ForMultiTaskOutput(ModelOutput):
    """Container returned by the multitask GPT-2 model's ``forward``.

    Every field defaults to ``None``. Keep the declaration order:
    ``ModelOutput`` subclasses can also be accessed positionally (see the
    transformers docs), so reordering may change positional access.
    """
    # Language-modeling loss; None when no LM labels were supplied.
    loss: Optional[torch.FloatTensor] = None
    # Output of the LM head for every position (presumably (batch, seq, vocab)).
    logits: Optional[torch.FloatTensor] = None
    # Classification cross-entropy loss; None when no class labels were supplied.
    cls_loss: Optional[torch.FloatTensor] = None
    # Classification-head output at one selected position per sequence.
    cls_logits: Optional[torch.FloatTensor] = None

class GPT2ForMultiTaskLearning(GPT2PreTrainedModel):
    """GPT-2 with a language-modeling head and a sequence-classification head.

    Both heads read the final hidden states of a shared ``GPT2Model``. The
    classification logits are taken at index ``(#tokens != pad_token_id) - 1``
    of each row. This equals the last real token only for right-padded
    inputs whose real tokens contain no ``pad_token_id``.

    NOTE(review): ``post_init()`` is not called and no ``get_output_embeddings``
    is defined, so ``lm_head`` is presumably not tied to the token embeddings.
    Because ``lm_head.weight`` is in ``_keys_to_ignore_on_load_missing``,
    loading a plain ``gpt2`` checkpoint may leave it randomly initialised.
    Confirm this before relying on LM quality. Tying now would change how
    existing (untied) checkpoints load.
    """

    config_class = GPT2ForMultiTaskConfig
    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.n_labels
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.cls_head = nn.Linear(config.n_embd, config.n_labels, bias=False)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        lm_labels=None,
        cls_labels=None,
    ):
        """Run the shared transformer and both heads.

        Args:
            input_ids: token ids, (batch, seq). Required, because they locate
                the classification position.
            lm_labels: optional LM targets aligned with ``input_ids``; -100 is
                ignored by ``CrossEntropyLoss``. Targets are shifted by one.
            cls_labels: optional class ids, one per sequence.
            Remaining arguments are forwarded to ``GPT2Model``.

        Returns:
            GPT2ForMultiTaskOutput with ``loss`` (LM loss or None), ``logits``
            (LM logits for every position), ``cls_loss`` (or None) and
            ``cls_logits`` of shape (batch, n_labels).

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError('input_ids is required to locate the classification position.')

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            # Forward the caller's choice: only the last layer is used below, so
            # forcing True just kept every layer's activations alive.
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        batch_size = input_ids.shape[0]
        last_positions = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(hidden_states.device)
        batch_positions = torch.arange(batch_size, device=hidden_states.device)
        # Apply the classification head only at the selected position of each row.
        cls_logits = self.cls_head(hidden_states[batch_positions, last_positions])

        lm_loss = None
        if lm_labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        cls_loss = None
        if cls_labels is not None:
            loss_fct = CrossEntropyLoss()
            cls_loss = loss_fct(cls_logits.view(-1, self.num_labels), cls_labels.view(-1))

        return GPT2ForMultiTaskOutput(
            loss=lm_loss,
            cls_loss=cls_loss,
            logits=lm_logits,
            cls_logits=cls_logits,
        )

In [None]:
from torch.utils.data import DataLoader
from transformers import GPT2LMHeadModel
from transformers import GPT2TokenizerFast
from transformers import AdamW
from transformers import get_scheduler
import torch
from tqdm.auto import tqdm

class Trainer():
    """Fine-tunes ``GPT2ForMultiTaskLearning`` on the sum of LM and classification losses.

    Args:
        data_paths: ``(csv_path, data_id)`` pairs used for training.
        save_path: checkpoint prefix; epoch ``k`` is saved to ``f'{save_path}_{k}'``.
        eval_paths: optional ``(csv_path, data_id)`` pairs. They are evaluated
            before every epoch and once after training.
        split: ``(start, end)`` row fractions applied to train and eval CSVs.
        model_path: checkpoint passed to ``from_pretrained``.
        stratify, sample_weights: forwarded to ``WeightedRandomSampler``.
        batch_size: examples per optimizer step.
        learning_rate: AdamW learning rate (linear schedule).
        num_epochs: passes over the sampled training data.
        warmup_percent: accepted for compatibility but currently ignored;
            warmup is fixed at 0 steps.
    """

    def __init__(self,
                 data_paths,
                 save_path,
                 eval_paths = None,
                 split = [0,1],
                 model_path = 'gpt2',
                 stratify = None,
                 sample_weights = None,
                 batch_size = 16,
                 learning_rate = 5e-5,
                 num_epochs = 3,
                 warmup_percent = 1,
                ):
        self.dataset = DatasetForMultiTaskLearning(data_paths, split)
        self.save_path = save_path

        self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', use_fast = True)

        self.train_loader = self._make_loader(self.dataset, stratify, sample_weights, batch_size)
        self.eval_loader = None
        self.eval_dataset = None
        if eval_paths is not None:
            self.eval_dataset = DatasetForMultiTaskLearning(eval_paths, split)
            self.eval_loader = self._make_loader(self.eval_dataset, stratify, sample_weights, batch_size)

        # NOTE(review): CrossEntropyLoss needs class ids in [0, n_labels);
        # that holds only if cls_labels are contiguous from 0. Confirm.
        num_labels = self.dataset.encodings['cls_labels'].nunique()
        config = GPT2ForMultiTaskConfig(n_labels = num_labels, pad_token_id = self.tokenizer.eos_token_id)
        self.model = GPT2ForMultiTaskLearning.from_pretrained(model_path, config = config)

        self.optimizer = AdamW(self.model.parameters(), lr=learning_rate)
        self.num_epochs = num_epochs
        self.num_warmup_steps = 0  # warmup disabled; warmup_percent is unused
        self.num_training_steps = num_epochs * len(self.train_loader)
        self.lr_scheduler = get_scheduler("linear",
                                          optimizer = self.optimizer,
                                          num_warmup_steps = self.num_warmup_steps,
                                          num_training_steps = self.num_training_steps)

    def _make_loader(self, dataset, stratify, sample_weights, batch_size):
        """Build a DataLoader whose sampler yields lists of index labels.

        ``batch_size=None`` disables automatic batching. Each list from the
        BatchSampler is therefore passed straight to ``dataset.__getitem__``,
        and that result goes to ``train_collator``.
        """
        label_sampler = WeightedRandomSampler(dataset.encodings['cls_labels'], stratify = stratify, weights = sample_weights)
        batch_sampler = torch.utils.data.sampler.BatchSampler(label_sampler, batch_size=batch_size, drop_last=False)
        return DataLoader(dataset = dataset,
                          batch_size = None,
                          collate_fn = self.train_collator,
                          sampler = batch_sampler)

    @staticmethod
    def _device():
        """Return the CUDA device when available, else the CPU."""
        return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def train_collator(self, batch):
        """Right-pad gold sequences into tensors for a training step.

        ``batch`` is ``(encodings_dict, others_frame)`` from the dataset.
        ``gold_ids``/``gold_mask`` become the model's ``input_ids``/
        ``attention_mask``. ``input_ids`` is padded with EOS, the mask with 0
        and ``lm_labels`` with -100.
        """
        max_len = max([len(l) for l in batch[0]['gold_ids']])

        pads = {'input_ids': self.tokenizer.eos_token_id, 'lm_labels': -100, 'attention_mask': 0}
        keys = {'gold_ids': 'input_ids', 'gold_mask': 'attention_mask', 'lm_labels': 'lm_labels'}
        out = {}
        for key, val in keys.items():
            out[val] = torch.tensor([sample + [pads[val]] * (max_len - len(sample)) for sample in batch[0][key]])
        out['cls_labels'] = torch.tensor(batch[0]['cls_labels'])
        return out

    def train(self, save = True):
        """Run ``num_epochs`` epochs of joint training.

        Evaluates before every epoch and once at the end when an eval set
        exists. Saves ``f'{save_path}_{epoch}'`` after each epoch if ``save``.
        """
        device = self._device()
        self.model.to(device)

        progress_bar_train = tqdm(range(self.num_training_steps))
        self.losses = []
        # Print the running mean roughly three times per epoch. max() avoids a
        # modulo-by-zero when the loader has fewer than two batches.
        log_every = max(1, round(len(self.train_loader) / 3))
        for epoch in range(self.num_epochs):
            if self.eval_loader is not None:
                self.evaluate()
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                enc = {k: v.to(device) for k, v in batch.items()}

                outputs = self.model(**enc)
                # Unweighted sum of the LM and classification losses.
                loss = outputs.loss + outputs.cls_loss
                self.losses.append(loss.detach().cpu())
                loss.backward()

                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                progress_bar_train.update(1)

                if i % log_every == 1:
                    print(f'Average training loss: {sum(self.losses)/len(self.losses)}')

            if save:
                self.model.save_pretrained(f'{self.save_path}_{epoch}')

        if self.eval_loader is not None:
            self.evaluate()

    def evaluate(self):
        """Report mean LM loss, classification loss and accuracy on the eval set.

        Returns:
            alpha = mean LM loss / mean classification loss (a 0-d tensor).

        Raises:
            ValueError: if there is no eval loader or it yields no batches.
        """
        if self.eval_loader is None:
            raise ValueError('No evaluation data: pass eval_paths to Trainer.')
        device = self._device()
        # Move explicitly so evaluate() also works before/without train().
        self.model.to(device)
        self.model.eval()
        lm_losses = []
        cls_losses = []
        cls_preds = []
        cls_true = []
        progress_bar_eval = tqdm(range(len(self.eval_loader)))
        for batch in self.eval_loader:
            enc = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = self.model(**enc)
            lm_losses.append(outputs.loss.detach().cpu())
            cls_losses.append(outputs.cls_loss.detach().cpu())
            _, predictions = torch.max(outputs.cls_logits, dim=-1)
            cls_preds += predictions.detach().cpu().tolist()
            cls_true += enc['cls_labels'].cpu().tolist()
            progress_bar_eval.update(1)

        if not cls_true:
            raise ValueError('The evaluation loader produced no batches.')
        mean_lm_loss = sum(lm_losses) / len(lm_losses)
        mean_cls_loss = sum(cls_losses) / len(cls_losses)
        print(f'Average language modeling loss: {mean_lm_loss}')
        print(f'Average classification loss: {mean_cls_loss}')
        print(f'Average classification accuracy: {sum([x == y for x, y in zip(cls_preds, cls_true)])/len(cls_true)}')
        alpha = mean_lm_loss / mean_cls_loss
        print(f'Alpha = {alpha}')
        return alpha

In [None]:
# Train on the generated ECQA multitask data. One checkpoint per epoch is
# written to f'{save_path}_{epoch}'. Commented entries are alternative inputs.
train_data_paths = [
    ('../../data/generated/multitask/separate/ecqa_train_gpt.csv', 'ecqa'),
    # ('../../data/tokenized/gpt2/ecqa/train.csv', 'ecqa'),
    # ('../../data/tokenized/gpt2/esnli/train.csv', 'esnli'),
    # ('../../data/tokenized/gpt2/comve/train.csv', 'comve'),
]
trainer = Trainer(
    data_paths = train_data_paths,
    save_path = 'eSNLI/Gen2_notft',
    num_epochs = 3,
    batch_size = 4,
    stratify = 'decrease',
)
trainer.train(save = True)

In [None]:
# Save the final weights under the bare save_path (per-epoch checkpoints carry an _<epoch> suffix).
trainer.model.save_pretrained(save_directory = trainer.save_path)

In [None]:
from torch.utils.data import DataLoader
from transformers import GPT2LMHeadModel
from transformers import GPT2TokenizerFast
from transformers import DataCollatorForLanguageModeling
from transformers import AdamW
from transformers import get_scheduler
import torch
from tqdm.auto import tqdm
from datasets import load_metric
import json

class Generator():
    """Samples explanations from a causal LM for tokenized CSV datasets.

    Args:
        data_paths: ``(csv_path, data_id)`` pairs to generate for.
        save_paths: ``(json_path, data_id)`` pairs. Generations whose
            ``data_id`` matches are written to that JSON file.
        split: ``(start, end)`` row fractions passed to the dataset.
        stratify, sample_weights: forwarded to ``WeightedRandomSampler``.
        batch_size: prompts per ``generate`` call.
        gold_explanations: when True, score generations against the
            ``explanations`` column with BERTScore.
    """

    def __init__(self,
                 data_paths,
                 save_paths,
                 split = [0,1],
                 stratify = None,
                 sample_weights = None,
                 batch_size = 16,
                 gold_explanations = True
                ):
        self.dataset = DatasetForMultiTaskLearning(data_paths, split)
        self.save_paths = save_paths

        self.wrs = WeightedRandomSampler(self.dataset.encodings['cls_labels'], stratify = stratify, weights = sample_weights)
        sampler = torch.utils.data.sampler.BatchSampler(self.wrs, batch_size=batch_size, drop_last=False)

        self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', use_fast = True)
        self.test_loader = DataLoader(dataset = self.dataset,
                                      batch_size = None,
                                      collate_fn = self.test_collator,
                                      sampler = sampler)

        self.bertscore = load_metric('bertscore')
        self.gold_explanations = gold_explanations

    def test_collator(self, batch):
        """Left-pad prompts into tensors; returns ``(tensors, others_frame)``.

        Left padding (EOS for ids, 0 for the mask) presumably keeps each prompt
        ending at the final position, so generation continues right after it.
        """
        max_len = max([len(l) for l in batch[0]['input_ids']])

        pads = {'input_ids': self.tokenizer.eos_token_id, 'attention_mask': 0}
        keys = {'input_ids': 'input_ids', 'attention_mask': 'attention_mask'}
        out = {}
        for key, val in keys.items():
            out[val] = torch.tensor([[pads[val]] * (max_len - len(sample)) + sample for sample in batch[0][key]])

        return out, batch[1]

    def _decode_continuations(self, input_ids, generation):
        """Decode generations and strip the decoded prompt text from each one."""
        prompts = self.tokenizer.batch_decode(input_ids.detach(), skip_special_tokens=True)
        outputs = self.tokenizer.batch_decode(generation.detach(), skip_special_tokens=True)
        return [text[len(prompt):].strip() for prompt, text in zip(prompts, outputs)]

    def _add_bertscore_batch(self, preds, others):
        """Queue one batch for BERTScore. Returns False (with a warning) if it cannot be scored."""
        if 'explanations' not in others:
            print('Warning: no "explanations" column; batch not scored with BERTScore.')
            return False
        try:
            self.bertscore.add_batch(predictions=preds, references=others['explanations'].tolist())
        except Exception as err:  # best-effort: report the bad batch, keep its generations
            print(f'Warning: batch not scored with BERTScore ({err!r}).')
            return False
        return True

    def _save_generations(self):
        """Write the generations of each ``data_id`` to its JSON path."""
        for (save_path, data_id) in self.save_paths:
            subset = self.df[self.df['data_id'] == data_id] if 'data_id' in self.df else self.df
            with open(save_path, 'w') as f:
                json.dump(subset.to_dict('list'), f)

    def _report_bertscore(self, bertscore_path):
        """Compute BERTScore over all queued batches, print it and write it as JSON."""
        res = self.bertscore.compute(lang = 'en')
        means = {key: round(sum(res[key])/len(res[key]), 2) for key in ('precision', 'recall', 'f1')}
        print(f'{means["precision"]}\t\t|\t' +
              f'{means["recall"]}\t\t|\t' +
              f'{means["f1"]}')
        with open(bertscore_path, 'w') as f:
            json.dump({'Precision': means['precision'],
                       'Recall': means['recall'],
                       'F1-score': means['f1']},
                      f)

    def generate(self, model, save = False, bertscore_path = '../../data/generated/bertscores.json'):
        """Sample one explanation per prompt and collect them in ``self.df``.

        Saving and BERTScore run once after all batches. Previously both ran
        inside the batch loop: the JSON files were rewritten every batch, and
        scores were computed per batch, overwriting each other. A batch that
        cannot be scored is reported, and its generations are still kept.

        Args:
            model: a model with a Hugging Face style ``generate`` method.
            save: write per-``data_id`` JSON files listed in ``save_paths``.
            bertscore_path: where to write the BERTScore summary.
        """
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        model.to(device)
        model.eval()

        progress_bar = tqdm(range(len(self.test_loader)))
        frames = []
        scored_any = False
        for enc_cpu, others in self.test_loader:
            enc = {k: v.to(device) for k, v in enc_cpu.items()}
            with torch.no_grad():
                generation = model.generate(
                    input_ids = enc['input_ids'],
                    attention_mask = enc['attention_mask'],
                    do_sample = True,
                    max_length = 100,
                    temperature = 0.7,
                    top_k = 50,
                    top_p = 0.7
                )
            preds = self._decode_continuations(enc['input_ids'], generation)
            if self.gold_explanations:
                scored_any = self._add_bertscore_batch(preds, others) or scored_any
            frames.append(others.assign(generated = preds))
            progress_bar.update(1)

        self.df = pd.concat(frames) if frames else pd.DataFrame()

        if save:
            self._save_generations()

        if self.gold_explanations:
            if scored_any:
                self._report_bertscore(bertscore_path)
            else:
                print('No batch could be scored with BERTScore.')

# Generate explanations for the first half of the e-SNLI training split
# (no gold scoring). Commented entries are alternative datasets.
generation_data_paths = [
    # ('../../data/tokenized/gpt2/ecqa/train.csv', 'ecqa'),
    ('../../data/tokenized/gpt2/esnli/train.csv', 'esnli'),
    # ('../../data/tokenized/gpt2/comve/test.csv', 'comve'),
]
generation_save_paths = [
    # ('../../data/generated/multitask/shared/ecqa_train.json', 'ecqa'),
    ('../../data/generated/multitask/separate/esnli_train.json', 'esnli'),
    # ('../../data/generated/multitask/separate/comve_test.json', 'comve'),
]
generator = Generator(
    data_paths = generation_data_paths,
    save_paths = generation_save_paths,
    split = [0, 0.5],
    batch_size = 8,
    gold_explanations = False,
)

# Load the multitask checkpoint used for explanation generation.
model = GPT2ForMultiTaskLearning.from_pretrained('eSNLI/GPT2MultiTask_2')#GPT2LMHeadModel.from_pretrained('gpt2')

# Manual sanity check on one hand-written prompt. Only a cell's last expression
# is displayed, so this tokenizer call and the batch_decode below show nothing here.
generator.tokenizer(f"Statement: Where is the best place to keep ice crean?\nStatement: party\nExplanation:", return_tensors='pt')

model.to('cpu')
out = model.generate(**generator.tokenizer(f"Statement: Where is the best place to keep ice crean?\nStatement: party\nExplanation:", return_tensors='pt'), max_length = 200)
generator.tokenizer.batch_decode(out)

# Generate for the whole configured split and write the JSON files.
generator.generate(model, save = True)

# `stop` appears to be undefined on purpose: the NameError halts "Run All"
# here, so the inspection code that follows runs only when executed manually.
stop

import pandas as pd
import json


def _show_generated_samples(json_path, text_fields, n = 2):
    """Print ``n`` random generations per label from a Generator JSON file.

    Args:
        json_path: file written by ``Generator.generate(save=True)``.
        text_fields: ``(title, column)`` pairs printed before the label,
            gold explanation and generation.
        n: samples drawn per label.

    Returns:
        The loaded generations as a DataFrame.
    """
    with open(json_path, 'r') as f:
        generated = pd.DataFrame(json.load(f))
    fields = list(text_fields) + [('Label', 'labels'), ('Explanation', 'explanations'), ('Generated', 'generated')]
    for _, sample in generated.groupby('labels').sample(n = n).iterrows():
        print('\n'.join(f'{title}: {sample[column]}' for title, column in fields) + '\n\n')
    return generated


ecqa_generated = _show_generated_samples('../../data/generated/multitask/separate/ecqa_train.json',
                                         [('Questions', 'questions'), ('Options', 'options')])

# `esnli_generated` was never loaded before use. It is read from the path the
# generator setup saves e-SNLI generations to (assumed: esnli_train.json).
esnli_generated = _show_generated_samples('../../data/generated/multitask/separate/esnli_train.json',
                                          [('Premise', 'premise'), ('Hypothesis', 'hypothesis')])

comve_generated = _show_generated_samples('../../data/generated/multitask/separate/comve_test.json',
                                          [('Correct', 'correct'), ('Incorrect', 'incorrect')])

In [None]:
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from transformers import GPT2TokenizerFast
from transformers import DataCollatorWithPadding
import torch
from tqdm.auto import tqdm
from sklearn import metrics

class Evaluator():
    """Evaluates the classification head of a multitask model on CSV data.

    Args:
        data_paths: ``(csv_path, data_id)`` pairs to evaluate.
        split: ``(start, end)`` row fractions passed to the dataset.
        stratify, sample_weights: forwarded to ``WeightedRandomSampler``.
        batch_size: sequences per forward pass.
    """

    def __init__(self,
                 data_paths,
                 split = [0,1],
                 stratify = None,
                 sample_weights = None,
                 batch_size = 16,
                ):
        self.dataset = DatasetForMultiTaskLearning(data_paths, split)

        self.wrs = WeightedRandomSampler(self.dataset.encodings['cls_labels'], stratify = stratify, weights = sample_weights)
        sampler = torch.utils.data.sampler.BatchSampler(self.wrs, batch_size=batch_size, drop_last=False)

        self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', use_fast = True)
        self.eval_loader = DataLoader(dataset = self.dataset,
                                      batch_size = None,
                                      collate_fn = self.eval_collator,
                                      sampler = sampler)

    def eval_collator(self, batch):
        """Right-pad prompts (EOS ids, 0 mask); returns ``(tensors, others_frame)``."""
        max_len = max([len(l) for l in batch[0]['attention_mask']])

        pads = {'input_ids': self.tokenizer.eos_token_id, 'attention_mask': 0}
        keys = {'input_ids': 'input_ids', 'attention_mask': 'attention_mask'}
        out = {}
        for key, val in keys.items():
            out[val] = torch.tensor([sample + [pads[val]] * (max_len - len(sample)) for sample in batch[0][key]])
        out['cls_labels'] = torch.tensor(batch[0]['cls_labels'])
        return out, batch[1]

    @staticmethod
    def _rows_with_group_max(df, column):
        """Rows whose ``column`` equals the maximum within their (data_ids, groups) group.

        Grouping by ``data_ids`` too keeps equal group ids from different
        datasets apart.
        """
        group_max = df.groupby(['data_ids', 'groups'])[column].transform('max')
        return df[group_max == df[column]]

    def evaluate(self, model):
        """Print per-class P/R/F1 and group-level task accuracies.

        "Original task ecqa" is the mean ``refs`` over ECQA rows whose class-1
        logit is the group maximum. "Original task v2" is the fraction of
        e-SNLI groups whose top-logit row is predicted correctly (nan if there
        are no e-SNLI rows).

        Raises:
            ValueError: if the loader yields no examples.
        """
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        model.to(device)
        model.eval()

        groups = []
        data_ids = []
        max_logits = []
        logits_bin1 = []
        preds = []
        refs = []

        self.progress_bar_eval = tqdm(range(len(self.eval_loader)))
        for enc, others in self.eval_loader:
            enc = {k: v.to(device) for k, v in enc.items()}
            with torch.no_grad():
                outputs = model(**enc)
            logits = outputs.cls_logits
            vals, predictions = torch.max(logits, dim=-1)
            groups += others['groups'].tolist()
            data_ids += others['data_id'].tolist()
            # .tolist() stores plain Python numbers instead of 0-d tensors, so
            # the DataFrame below gets numeric columns.
            logits_bin1 += logits[:, 1].detach().cpu().tolist()
            max_logits += vals.detach().cpu().tolist()
            preds += predictions.detach().cpu().tolist()
            refs += enc['cls_labels'].cpu().tolist()
            self.progress_bar_eval.update(1)

        if not refs:
            raise ValueError('The evaluation loader produced no examples.')

        df = pd.DataFrame({'data_ids': data_ids, 'groups': groups, 'logits': max_logits,
                           'logits_bin1': logits_bin1, 'preds': preds, 'refs': refs})

        print("-"*20 + "----------" + "-"*20)
        prec, rec, f1, dist = metrics.precision_recall_fscore_support(refs, preds, average=None)
        print("-"*20 + "EVALUATION" + "-"*20)
        print('Precision: \t\t{}'.format(prec))
        print('Recall: \t\t{}'.format(rec))
        print('F1: \t\t\t{}'.format(f1))
        print('Distribution: \t\t{}'.format(dist))

        ecqa_task = self._rows_with_group_max(df, 'logits_bin1')
        ecqa_task = ecqa_task[ecqa_task['data_ids'] == 'ecqa']
        print("Original task ecqa: \t\t{}".format(ecqa_task['refs'].mean()))

        best = self._rows_with_group_max(df, 'logits')
        best = best[best['data_ids'] == 'esnli']
        correct_groups = best.loc[best['preds'] == best['refs'], 'groups'].nunique()
        # Denominator: e-SNLI groups only (previously every group of every dataset).
        esnli_groups = df.loc[df['data_ids'] == 'esnli', 'groups'].nunique()
        esnli_accuracy = correct_groups / esnli_groups if esnli_groups else float('nan')
        print("Original task v2: \t{}".format(esnli_accuracy))

        print("-"*20 + "----------" + "-"*20)

In [None]:
# Classification checkpoint evaluated by the following cells.
model = GPT2ForMultiTaskLearning.from_pretrained(pretrained_model_name_or_path = 'ECQA/GPT2MultiTask')

In [None]:
# ECQA text-column names (not read by any active code path here).
sen1, sen2 = 'questions', 'options'
# Class-id offset; DatasetForMultiTaskLearning may read this global as `addit`.
addit = 0
ecqa_test_paths = [
    ('../../data/generated/multitask/separate/ecqa_test_gpt.csv', 'ecqa'),
    # ('../../data/tokenized/gpt2/esnli/test.csv', 'esnli'),
    # ('../../data/tokenized/gpt2/comve/test.csv', 'comve'),
]
evaluator = Evaluator(data_paths = ecqa_test_paths, batch_size = 8)
evaluator.evaluate(model)

# Deliberately undefined name: the NameError halts "Run All" here, so the
# CoMVE sweep below runs only when executed manually.
stop

In [None]:
addit = 0
# Evaluate every (task setup, head sharing) variant of the CoMVE test data;
# a failing variant is reported and the sweep continues.
comve_path_template = '../../data/generated/{}/{}/comve_test_gpt.csv'
for sharing in ['separate', 'shared']:
    for task_setup in ['singletask', 'multitask']:
        comve_path = comve_path_template.format(task_setup, sharing)
        try:
            evaluator = Evaluator(data_paths = [(comve_path, 'comve')], batch_size = 8)
            evaluator.evaluate(model)
            print('SUCCESS: ' + comve_path)
        except Exception as e:
            print('FAILED: ' + comve_path)
            print(e)

evaluator.dataset.encodings['cls_labels'].value_counts()

evaluator.evaluate(model)