In [12]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler, RandomSampler


In [13]:
class MUTANT(nn.Module):

    def __init__(self, d_model=768, seq_len=16, dropout=0.1):
        super(MUTANT,self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.token_type_embeddings = nn.Embedding(3, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=2)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        
        # Scoreing heads
        self.head_passage = nn.Linear(d_model, 1)
        self.head_entity = nn.Linear(d_model, 1)

        
    def forward(self, input_CLSs, type_mask=None):
        # input_CLSs -> [seq_len, batch_size, d_model]
        # type_mask -> [seq_len, batch_size] 0 or 1 for different types
        
        if isinstance(type_mask, torch.Tensor):
            token_type_embeddings = self.token_type_embeddings(type_mask)
            input_CLSs = input_CLSs + token_type_embeddings 
        
        # Build padding masks i.e. type_mask == 0.
        #src_key_padding_mask = (type_mask > 0).type(torch.int).T
        src_key_padding_mask = (type_mask > 0).T
        
        # Forward pass of Transformer encoder.
        output_CLSs = self.transformer_encoder(input_CLSs, src_key_padding_mask=src_key_padding_mask)
        
        # Ensure Passage and Entity heads score correct mask type i.e. passage == 1 & entity = 2. 
        passage_mask = (type_mask == 1).type(torch.int).unsqueeze(-1)
        entity_mask = (type_mask == 2).type(torch.int).unsqueeze(-1)
        entity_output = self.head_entity(output_CLSs) * entity_mask
        passage_output = self.head_passage(output_CLSs) * passage_mask
                
        return passage_output+entity_output
    
    
    def get_device(self):
        return next(self.parameters()).device
    

In [6]:
model = MUTANT(d_model=10, seq_len=6, dropout=0.1)
lr = 0.001

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

loss_func = torch.nn.MSELoss()  # this is for regression mean squared loss
train_loss_total = 0.0

model.train()
bag_of_CLS = torch.rand(6, 3, 10) # [seq_len, batch_size, d_model]
type_mask = torch.tensor([[1,1,1],
                          [2,2,2],
                          [2,2,2],
                          [2,2,0],
                          [2,2,0],
                          [0,0,0]]) # [seq_len, batch_size]

labels = torch.tensor([[[1.0],[0.0],[1.0]],
                        [[0.0],[0.0],[0.0]],
                        [[1.0],[0.0],[1.0]],
                        [[0.0],[1.0],[0.0]],
                        [[0.0],[0.0],[0.0]],
                        [[0.0],[0.0],[0.0]]]) # [seq_len, batch_size]
for i in range(1000):
    # ========================================
    #               Training
    # ========================================
    model.zero_grad()
    outputs = model.forward(bag_of_CLS, type_mask=type_mask)

    # Calculate Loss: softmax --> cross entropy loss
    loss = loss_func(outputs, labels)
    # Getting gradients w.r.t. parameters
    loss.sum().backward()
    optimizer.step()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)


    train_loss_total += loss.sum().item()
    
    if i % 1000 == 0:
#         print('--------')
        print(train_loss_total/(1+i))
#         print(labels)
#         print(outputs)


--------
0.6013029217720032
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
tensor([[[-0.0505],
         [-0.3043],
         [-0.2274]],

        [[ 1.0343],
         [ 1.1432],
         [ 1.4726]],

        [[ 1.1681],
         [ 0.8412],
         [ 1.2295]],

        [[ 1.1092],
         [ 0.5907],
         [-0.0000]],

        [[ 0.6793],
         [ 0.9632],
         [-0.0000]],

        [[-0.0000],
         [-0.0000],
         [-0.0000]]], grad_fn=<AddBackward0>)
--------
0.15750577594790902
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],


--------
0.14419755443487353
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
tensor([[[0.5067],
         [0.4522],
         [0.9831]],

        [[0.2584],
         [0.2793],
         [0.4883]],

        [[0.2590],
         [0.2569],
         [0.4991]],

        [[0.2992],
         [0.2392],
         [0.0000]],

        [[0.2284],
         [0.2446],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000]]], grad_fn=<AddBackward0>)
--------
0.14399028665736935
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])


--------
0.1427907815633712
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
tensor([[[0.5083],
         [0.5214],
         [0.7658]],

        [[0.2543],
         [0.2510],
         [0.2807]],

        [[0.2509],
         [0.2514],
         [0.4738]],

        [[0.2583],
         [0.2752],
         [0.0000]],

        [[0.2472],
         [0.2518],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000]]], grad_fn=<AddBackward0>)
--------
0.1427281411654267
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
te

--------
0.14215703297903437
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
tensor([[[0.5267],
         [0.5371],
         [0.9138]],

        [[0.2556],
         [0.2743],
         [0.2550]],

        [[0.2634],
         [0.2527],
         [0.4867]],

        [[0.2510],
         [0.2495],
         [0.0000]],

        [[0.2742],
         [0.2531],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000]]], grad_fn=<AddBackward0>)
--------
0.14211680607199234
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])


--------
0.14177563286614153
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
tensor([[[0.5122],
         [0.5071],
         [0.9838]],

        [[0.2511],
         [0.2515],
         [0.5253]],

        [[0.2670],
         [0.2786],
         [0.5359]],

        [[0.2615],
         [0.2482],
         [0.0000]],

        [[0.2546],
         [0.2527],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000]]], grad_fn=<AddBackward0>)
--------
0.14174384209567734
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])


--------
0.14153033669252768
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
tensor([[[0.5053],
         [0.5251],
         [1.0078]],

        [[0.2480],
         [0.2602],
         [0.4826]],

        [[0.2467],
         [0.2821],
         [0.4825]],

        [[0.2600],
         [0.2668],
         [0.0000]],

        [[0.2512],
         [0.2455],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000]]], grad_fn=<AddBackward0>)
--------
0.1415117709104916
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
t

--------
0.14134115539783504
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
tensor([[[0.5025],
         [0.4905],
         [0.9186]],

        [[0.2766],
         [0.2657],
         [0.4568]],

        [[0.2596],
         [0.2581],
         [0.4948]],

        [[0.2559],
         [0.2493],
         [0.0000]],

        [[0.2533],
         [0.2757],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000]]], grad_fn=<AddBackward0>)
--------
0.14133063229976
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
ten

--------
0.14119564436681298
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
tensor([[[0.5322],
         [0.5527],
         [0.9908]],

        [[0.2845],
         [0.2551],
         [0.4867]],

        [[0.2471],
         [0.2554],
         [0.4787]],

        [[0.2513],
         [0.2533],
         [0.0000]],

        [[0.2526],
         [0.2576],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000]]], grad_fn=<AddBackward0>)
--------
0.1411833509654159
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
t

--------
0.14109366323822836
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
tensor([[[0.4964],
         [0.4789],
         [1.0045]],

        [[0.2456],
         [0.2714],
         [0.4963]],

        [[0.2667],
         [0.2585],
         [0.4914]],

        [[0.2526],
         [0.2538],
         [0.0000]],

        [[0.2614],
         [0.2524],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000]]], grad_fn=<AddBackward0>)
--------
0.1410797950770159
tensor([[[1.],
         [0.],
         [1.]],

        [[0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])
t

In [None]:
import json

path = '/nfs/trec_news_track/data/5_fold/scaled_5fold_0_data/mutant_data/valid/0_mutant_max.json'

with open(path, 'r') as f:
    d = json.load(f)

In [None]:

bag_of_CLS = []
labels = []
type_mask = []
max_seq_len = 16
for passage_id in d['query']['passage'].keys():
    seq_cls = []
    seq_labels = []
    seq_mask = []
    
    print(passage_id)
    
    passage_cls = d['query']['passage'][passage_id]['cls_token']
    passage_relevant = d['query']['passage'][passage_id]['relevant']
    seq_cls.append(passage_cls)
    seq_labels.append([passage_relevant])
    seq_mask.append(1)

    for entity_id in d['query']['passage'][passage_id]['entity']:
        if len(seq_mask) < max_seq_len:
            entity_cls = d['query']['passage'][passage_id]['entity'][entity_id]['cls_token']
            entity_relevant = d['query']['passage'][passage_id]['entity'][entity_id]['relevant']
            seq_cls.append(entity_cls)
            seq_labels.append([entity_relevant])
            seq_mask.append(2)
        else:
            pass
            #print('not enough max_seq_leng for: {} - entity: {}, total ents: {}'.format(passage_id, entity_id, len(d['query']['passage'][passage_id]['entity'])))
            
    if len(seq_mask) < max_seq_len:
        padding_len = max_seq_len - len(type_mask)
        for i in range(padding_len):
            seq_cls.append([0]*768)
            seq_labels.append([0])
            seq_mask.append(0)
        
    bag_of_CLS.append(seq_cls) 
    labels.append(seq_labels)
    type_mask.append(seq_mask)
    
    
    break
    
bag_of_CLS_tensor = torch.tensor(bag_of_CLS)
type_mask_tensor = torch.tensor(type_mask)
labels_tensor = torch.tensor(labels)
print(bag_of_CLS_tensor.shape, type_mask_tensor.shape, labels_tensor.shape)

train_dataset = TensorDataset(bag_of_CLS_tensor, type_mask_tensor, labels_tensor)

train_data_loader = DataLoader(train_dataset, sampler=SequentialSampler(train_dataset), batch_size=8)


In [None]:
for i_train, train_batch in enumerate(train_data_loader):
    bag_of_CLS, type_mask, labels = train_batch
    
    model.zero_grad()
    
    outputs = model.forward(bag_of_CLS, type_mask=type_mask)

    # Calculate Loss: softmax --> cross entropy loss
    loss = loss_func(outputs, labels)
    # Getting gradients w.r.t. parameters
    loss.sum().backward()
    optimizer.step()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)


    train_loss_total += loss.sum().item()
    
    if i % 1000 == 0:
#         print('--------')
        print(train_loss_total/(1+i))
#         print(labels)
#         print(outputs)


torch.Size([6, 3, 10])