In [1]:
import torch
import pandas as pd
import os

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from torch import nn
from datasets import Dataset , DatasetDict
from transformers import AutoTokenizer, AutoModel , BertForQuestionAnswering

### Preparing data

In [2]:
binary_sentiment_classification = pd.read_csv(os.getcwd() + '/data/IMDB Dataset.csv')
summarization = pd.read_csv(os.getcwd() + '/data/new_summarization_data.csv')

In [3]:
binary_sentiment_classification.head(1)

Unnamed: 0,review,sentiment
0,One of the other reviewers has mentioned that ...,positive


In [4]:
binary_sentiment_classification['type'] = ['binary_classification'] * binary_sentiment_classification.shape[0]

In [5]:
binary_sentiment_classification.rename(columns={'review' : 'text' , 'sentiment':'labels'} , inplace=True)

In [6]:
binary_sentiment_classification.head(1)

Unnamed: 0,text,labels,type
0,One of the other reviewers has mentioned that ...,positive,binary_classification


In [7]:
summarization.head(1)

Unnamed: 0.1,Unnamed: 0,ID,Content,Summary,Dataset
0,0,f49ee725a0360aa6881ed1f7999cc531885dd06a,New York police are concerned drones could bec...,Police have investigated criminals who have ri...,CNN/Daily Mail


In [8]:
summarization = summarization.filter({'Content' , 'Summary' , 'summarization'})

In [42]:
summarization['type'] = ['summ'] * summarization.shape[0]

In [43]:
summarization.head(1)

Unnamed: 0,labels,text,type
0,New York police are concerned drones could bec...,Police have investigated criminals who have ri...,summ


In [44]:
summarization.rename(columns={'Summary': 'text' , 'Content':'labels'} , inplace=True)

In [45]:
summarization.head(1)

Unnamed: 0,labels,text,type
0,New York police are concerned drones could bec...,Police have investigated criminals who have ri...,summ


In [46]:
df = pd.concat([summarization , binary_sentiment_classification])

In [47]:
df = df.sample(frac=1).reset_index(drop=True)

In [48]:
df.head(20)

Unnamed: 0,labels,text,type
0,She's been accused of shamelessly attaching he...,Will Germany's triumph at the World Cup help A...,summ
1,11 July 2017 Last updated at 07:55 BST\nScrapp...,Tackling a blaze isn't just a job for firefigh...,summ
2,Cecilio Lopez Sanchez was hoisted out of the c...,A Spanish caver has been freed after being tra...,summ
3,Controversy: the comments made by Mr Davies ha...,MP criticised for 'insulting' and 'embarrassin...,summ
4,"(CNN) -- A year ago, President Obama delivered...","A year ago, President Obama gave a key speech ...",summ
5,By . Darren Boyle for MailOnline . A former pu...,"Archie Reed, 20, is accused of attempting to r...",summ
6,(CNN) -- German football is riding the crest o...,Owen Hargreaves says English football need to ...,summ
7,The Democratic-run U.S. Senate defeated measur...,– A measure to defund Planned Parenthood was s...,summ
8,"By . Ap Reporter . PUBLISHED: . 17:23 EST, 23...",Mormons have long taught disaster preparedness...,summ
9,"By . Ruth Styles . Sheila Thomas, now 64, was ...","Sheila Thomas, now 64, was 17 when she had to ...",summ


In [49]:
train, test = train_test_split(df, test_size=0.2)

In [50]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [51]:
datasets_train_test = DatasetDict({
    "train": Dataset.from_pandas(train),
    "test": Dataset.from_pandas(test)
    })

def tokenize_function(df):
    return tokenizer(df["text"], padding="max_length", truncation=True)
tokenized_datasets = datasets_train_test.map(tokenize_function, batched=True)

  0%|          | 0/737 [00:00<?, ?ba/s]

  0%|          | 0/185 [00:00<?, ?ba/s]

In [52]:
import pickle

# preprocessing checkpoint

with open('tokenized_datasets.t', 'wb') as f:
    pickle.dump(tokenized_datasets, f)

In [53]:
tokenized_datasets = tokenized_datasets.remove_columns(['text' , '__index_level_0__'])
tokenized_datasets = tokenized_datasets.rename_column('type' , 'tasks')
tokenized_datasets.set_format("torch")

In [54]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['labels', 'tasks', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 736416
    })
    test: Dataset({
        features: ['labels', 'tasks', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 184105
    })
})

### Defining model

In [55]:
class BinaryClassificationHead(nn.Module):
    def __init__(self, hidden_size, num_labels, model, dropout_p=0.1):
        super().__init__()
        self.num_labels = num_labels
        self.dropout = nn.Dropout(dropout_p)
        self.classifier = nn.Linear(hidden_size, num_labels)
        
        self._init_weights()
        
        self.forward = model.forward
    
    def _init_weights(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()
            
class QAClassificationHead(nn.Module):
    def __init__(self, hidden_size, num_labels, model, dropout_p=0.1):
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_size, num_labels)
        
        self._init_weights()
    
        self.forward = model.forward
        
    def _init_weights(self):
        self.qa_outputs.weight.data.normal_(mean=0.0, std=0.02)
        if self.qa_outputs.bias is not None:
            self.qa_outputs.bias.data.zero_()

In [81]:
# tasks is list type
# bert-base-uncased
class MLT(nn.Module):
    def __init__(self, model_name, tasks):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.task_heads = nn.ModuleDict()
        for task in tasks:
            output_head = self._create_output_head(self.model.config.hidden_size , task , self.model)
            self.task_heads[task] = output_head
    @staticmethod
    def _create_output_head(model_hidden_size , task, model):
        if task == "summ":
            return QAClassificationHead(model_hidden_size, 2 ,model)
        elif task == "binary_class":
            return BinaryClassificationHead(model_hidden_size, 2 ,model)
        else:
            raise NotImplementedError()
            
    # 'labels', 'tasks', 'input_ids', 'token_type_ids', 'attention_mask
    # forward method of MLT class
    
    def forward(self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        tasks=None,
        **kwargs,):
        
        print("IN MAIN FORWARD")
        
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        
        # get only last sequence and pooled output
        last_sequence, pooled_output = outputs[:2]
        print(type(tasks))
        unique_tasks = [tasks] # Need to make this into a tensor of strings, or better, just keep as list and see where it goes
        
        loss = []
        logits = None
        
        for unique_task in unique_tasks:
            task_id_filter = tasks == unique_task
            logits, task_loss = self.task_heads[str(unique_task[0])].forward(
                last_sequence[task_id_filter],
                pooled_output[task_id_filter],
                attention_mask[task_id_filter],
            )

            if labels is not None:
                loss_list.append(task_loss)
        
    
        outputs = (logits, outputs[2:])
        
        if loss:
            loss = torch.stack(loss)
            outputs = (loss.mean(),) + outputs
            
        return outputs

In [57]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

In [58]:
from transformers import TrainingArguments, Trainer
import numpy as np
import evaluate

training_args = TrainingArguments(output_dir="test_trainer")

In [59]:
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [82]:
model = MLT("bert-base-uncased" , ['binary_class' , 'summ'])

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [83]:
for e in small_train_dataset:
    print((e['input_ids'].size()))
    model(input_ids = e['input_ids'].unsqueeze(0) , tasks = [e['tasks']] , 
          attention_mask = e['attention_mask'].unsqueeze(0) , token_type_ids = e['token_type_ids'])
    break

torch.Size([512])
IN MAIN FORWARD
<class 'list'>


ValueError: too many values to unpack (expected 2)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()