In [1]:
datapoints = [["What is the capital of Russia?", "The capital of Russia is Moscow.", 1],
           ["What is the capital of India?", "The capital of Russia is Delhi.", 1],
           ["What is the capital of United States?", "The capital of Russia is Washington.", 1], 
           ["What is the capital of Germany?", "The capital of Russia is Berlin.", 1],
           ["What is the capital of France?", "The capital of Russia is Paris.", 1],
           ["What is the capital of Russia?", "Goku loves chi chi.", 0],
           ["What is the capital of India?", "Gohan is better than Goku for sure.", 0],
           ["What is the capital of United States?", "Freeza has to freeze.", 0], 
           ["What is the capital of Germany?", "Einstien should have nuked Hitler.", 0],
           ["What is the capital of France?", "Newton lost it when the apple fell on his head.", 0]]

In [2]:
import numpy as np

import torch
import torch.nn as nn
from torch import nn

from transformers import BertTokenizer, BertModel, BertConfig, BertPreTrainedModel, AdamW
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

In [31]:
class Ensemble_Bert(BertPreTrainedModel):
    def __init__(self,config,*args,**kwargs):
        super().__init__(config)
        
        #first model
        self.bert_model_1 = BertModel(config)
        #second model
        self.bert_model_2 = BertModel(config)
        
        self.cls = nn.Linear(2*self.config.hidden_size,2)
        self.init_weights()
        
    def forward(self,
               input_ids=None,
               attention_mask=None,
               token_type_ids=None,
               next_sentence_label=None):
        
        outputs = []
        input_ids_1 = input_ids[0]
        attention_mask_1 = attention_mask[0]
        outputs.append(self.bert_model_1(input_ids_1,
                                         attention_mask=attention_mask_1))
        input_ids_2 = input_ids[1]
        attention_mask_2 = attention_mask[0]
        outputs.append(self.bert_model_2(input_ids_2,
                                         attention_mask=attention_mask_2))
        
        last_hidden_state =torch.cat([output[1] for output in outputs],dim=1)
        logits = self.cls(last_hidden_state)
        
        if next_sentence_label is not None:
            loss = nn.CrossEntropyLoss(ignore_index=-1)
            l = loss(logits.view(-1,2),next_sentence_label.view(-1))
            return l,logits
        else:
            return logits

In [32]:
config = BertConfig()
model = Ensemble_Bert(config)
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
lr = 1e-5
no_decay =["bias","LayerNorm.weight"]
param = [p for n,p in model.named_parameters()if not  any(nd in n for nd in no_decay)]
optimizer = AdamW(param,lr=lr)



In [33]:
def qa_dataset(dataset,qa=True):
    input_ids,attention_mask = [],[]
    labels = []
    for point in dataset:
        if qa is True:
            q,a,_ = point
        else:
            a,q,_ = point
        inputs = tokenizer.encode_plus(q,a,
                             return_tensors="pt",
                             max_length=128,
                             truncation=True,
                             padding="max_length")
        input_ids.append(inputs["input_ids"])
        attention_mask.append(inputs["attention_mask"])
        labels.append(point[-1])
    input_ids = torch.cat(input_ids,dim=0)
    attention_mask = torch.cat(attention_mask,dim=0)
    return input_ids,attention_mask,labels
        

In [34]:
class dataset_pytorch(Dataset):
    def __init__(self,input_ids,attention_mask,labels=None):
        self.input_ids = np.array(input_ids)
        self.attention_mask = np.array(attention_mask)
        self.labels = torch.tensor(labels,dtype=torch.long)
    def __getitem__(self,idx):
        return self.input_ids[idx],self.attention_mask[idx],self.labels[idx]
    def __len__(self):
        return self.input_ids.shape[0]

In [35]:
input_ids_qa,attention_mask_qa,label_qa = qa_dataset(datapoints)
data_set_qa = dataset_pytorch(input_ids_qa,attention_mask_qa,label_qa)
dataloader_qa = DataLoader(dataset=data_set_qa,
                           batch_size=5,sampler=SequentialSampler(data_set_qa))

In [36]:
input_ids_aq,attention_mask_aq,labels_aq = qa_dataset(datapoints,qa=False)
dataset_aq = dataset_pytorch(input_ids_aq,attention_mask_aq,labels_aq)
dataloader_aq = DataLoader(dataset=dataset_aq,
                           batch_size=5,
                          sampler=SequentialSampler(dataset_aq))

In [51]:
epochs = 5
for i in range(epochs):
    for step, combined_data in enumerate(zip(dataloader_qa,dataloader_aq)):
        batch_1,batch_2 = combined_data
        model.train()
        inputs = {
            "input_ids" : [batch_1[0],batch_2[0]],
            "attention_mask" : [batch_1[1],batch_1[1]],
            "next_sentence_label" : batch_2[2]
        }
        outputs = model(**inputs)
        loss = outputs[0]
        print(f"epoch{i} :loss{loss}")
        loss.backward()
        optimizer.step()
        model.zero_grad()
    print("\n")

epoch0 :loss1.0444568395614624
epoch0 :loss0.3622767925262451


epoch1 :loss0.8111569285392761
epoch1 :loss0.55194491147995


epoch2 :loss0.43230146169662476
epoch2 :loss0.8224819898605347


epoch3 :loss0.34562668204307556
epoch3 :loss0.8164757490158081


epoch4 :loss0.38298362493515015
epoch4 :loss0.7245930433273315


