In [1]:
import json

In [2]:
TRAIN_1_FILE = '../data/train/Task_1_train.jsonl'

In [4]:
with open(TRAIN_1_FILE,'r') as f:
    lines = [json.loads(line) for line in f.read().splitlines()]

In [6]:
example = lines[0]

In [8]:
from transformers import BertTokenizer
bt = BertTokenizer.from_pretrained('bert-base-uncased')

In [57]:
article_token = bt.encode(example['article'])[:512]

In [58]:
question_token = bt.encode(example['question'].replace('@placeholder','[MASK]'))

In [59]:
options_tokens = [bt.encode(example[f'option_{i}']) for i in range(5)]

In [60]:
import torch

In [61]:
article_token = torch.LongTensor([article_token])

In [62]:
question_token = torch.LongTensor([question_token])

In [63]:
question_token.shape

torch.Size([1, 26])

In [64]:
options_tokens = torch.LongTensor(options_tokens)

In [65]:
options_tokens.shape

torch.Size([5, 3])

In [66]:
from transformers import BertModel
m = BertModel.from_pretrained('bert-base-uncased')

In [67]:
m.embeddings(article_token).shape

torch.Size([1, 512, 768])

In [68]:
m.embeddings(question_token).shape

torch.Size([1, 26, 768])

In [69]:
m.embeddings(options_tokens).shape

torch.Size([5, 3, 768])

## GABert

In [70]:
m1 = BertModel.from_pretrained('bert-base-uncased')
m2 = BertModel.from_pretrained('bert-base-uncased')

In [71]:
e1 = m1.embeddings

In [72]:
layer_1_1 = m1.encoder.layer[0]
layer_2_1 = m2.encoder.layer[0]

In [76]:
article_embeds = e1(article_token)
print(article_embeds.shape)

torch.Size([1, 512, 768])


In [78]:
question_embeds = e1(question_token)
print(question_embeds.shape)

torch.Size([1, 26, 768])


In [83]:
options_embeds = e1(options_tokens)
print(options_tokensembeds.shape)

torch.Size([5, 3, 768])


In [109]:
class GatedAttention(torch.nn.Module):      
    def forward(self, question_states, article_states):
        question_att = question_states.permute(0,2,1)
        att_matrix = torch.bmm(article_states,question_att)
        
        att_weights = torch.nn.functional.softmax(att_matrix.view(-1,att_matrix.size(-1)),dim=1).view_as(att_matrix)
        question_rep = torch.bmm(att_weights, question_states)
    
        question_to_article = torch.mul(article_states, question_rep)
        
        return question_to_article ##Attention applied on articles    

In [110]:
ga = GatedAttention()

In [111]:
layer_1_1_out = layer_1_1(question_embeds)
layer_2_1_out = layer_2_1(article_embeds)

In [112]:
layer_1_1_out[0].shape

torch.Size([1, 26, 768])

In [113]:
layer_2_1_out[0].shape

torch.Size([1, 512, 768])

In [116]:
ga(layer_1_1_out[0],layer_2_1_out[0]).shape

torch.Size([1, 512, 768])

In [121]:
inp_1 = question_embeds
inp_2 = article_embeds

## Need to check attention_mask and layer_head_mask
for i in range(len(m1.encoder.layer)):
    current_layer_1 = m1.encoder.layer[i]
    current_layer_2 = m2.encoder.layer[i]
    inp_1 = current_layer_1(inp_1)[0]
    int_2 = current_layer_2(inp_2)[0]
    inp_2 = ga(inp_1,int_2)

In [122]:
inp_1.shape

torch.Size([1, 26, 768])

In [123]:
inp_2.shape

torch.Size([1, 512, 768])

In [124]:
inp_2

tensor([[[ 8.1008e-02,  8.4509e-02,  9.0809e-02,  ..., -2.0987e-02,
           2.0178e-02,  1.3246e-01],
         [ 7.2911e-02,  5.7018e-02,  7.9535e-02,  ..., -1.7414e-02,
           4.2224e-02,  1.2953e-01],
         [ 6.1251e-02,  2.1614e-02,  7.0671e-02,  ..., -8.3078e-03,
           6.3368e-02,  1.2328e-01],
         ...,
         [ 5.1315e-02, -4.3455e-01, -1.0332e-01,  ...,  2.9795e-02,
          -5.7426e-02, -1.7541e-04],
         [ 2.6858e-02, -4.4956e-01, -1.2668e-01,  ...,  5.7979e-02,
          -2.2074e-01, -4.7746e-02],
         [-1.1459e-01,  3.6559e-02, -1.0911e-01,  ..., -6.2565e-02,
          -1.8362e-01,  7.9108e-02]]], grad_fn=<MulBackward0>)

In [125]:
article_embeds

tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [-0.4367,  0.5360, -0.0514,  ..., -0.0397,  0.6783, -0.5318],
         [ 0.7838, -0.3506, -1.1582,  ..., -0.8033,  0.1465,  0.2171],
         ...,
         [ 0.2022,  0.0762,  0.3220,  ...,  0.5130, -0.6300, -0.0597],
         [ 1.3539,  0.4626,  0.3129,  ..., -0.9238, -0.9422, -0.4833],
         [ 0.7480,  0.4874, -0.3261,  ..., -0.5679,  0.9606, -1.7922]]],
       grad_fn=<NativeLayerNormBackward>)

In [136]:
class GABert(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = BertModel.from_pretrained('bert-base-uncased')
        self.m2 = BertModel.from_pretrained('bert-base-uncased')
        self.embeddings = self.m1.embeddings
        self.ga = GatedAttention()
        
        
    def forward(self, article_tokens, question_tokens, options_tokens, article_attention_mask=None, question_attention_mask=None):
        article_embeds = self.embeddings(article_tokens)
        question_embeds = self.embeddings(question_tokens)
        options_embeds = self.embeddings(options_tokens)
        
        question_contexts = question_embeds
        article_contexts = article_embeds
        
        for i in range(len(self.m1.encoder.layer)):
            current_layer_1 = self.m1.encoder.layer[i]
            current_layer_2 = self.m2.encoder.layer[i]
            question_contexts = current_layer_1(question_contexts, question_attention_mask)[0]
            article_intermediates = current_layer_2(article_contexts, article_attention_mask)[0]
            article_contexts = ga(question_contexts,article_intermediates)
        return question_contexts,article_contexts

In [137]:
gabert = GABert()

In [138]:
gabert(article_token,question_token,options_tokens)

(tensor([[[-0.2625,  0.1005,  0.5384,  ..., -0.2146,  0.1149,  0.2406],
          [ 0.4076, -0.0331,  0.6515,  ..., -0.4912, -0.0818, -0.2506],
          [ 1.0710, -0.0842,  0.7467,  ..., -0.4354, -0.1184, -0.0889],
          ...,
          [ 0.3348, -0.5068,  0.6650,  ...,  0.3952,  0.4215, -0.4632],
          [ 0.4878,  0.5904,  0.0101,  ..., -0.0534, -0.7088, -0.1635],
          [ 0.1715,  0.4857,  0.2637,  ...,  0.1968, -0.8782, -0.2744]]],
        grad_fn=<NativeLayerNormBackward>),
 tensor([[[ 8.1008e-02,  8.4509e-02,  9.0809e-02,  ..., -2.0987e-02,
            2.0178e-02,  1.3246e-01],
          [ 7.2911e-02,  5.7018e-02,  7.9535e-02,  ..., -1.7414e-02,
            4.2224e-02,  1.2953e-01],
          [ 6.1251e-02,  2.1614e-02,  7.0671e-02,  ..., -8.3078e-03,
            6.3368e-02,  1.2328e-01],
          ...,
          [ 5.1315e-02, -4.3455e-01, -1.0332e-01,  ...,  2.9795e-02,
           -5.7426e-02, -1.7541e-04],
          [ 2.6858e-02, -4.4956e-01, -1.2668e-01,  ...,  5.7979e