In [12]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler, RandomSampler


In [13]:
class MUTANT(nn.Module):
    """Transformer scorer over a bag of precomputed CLS embeddings.

    Each sequence position holds one CLS vector tagged with a token type:
    0 = padding, 1 = passage, 2 = entity. A small Transformer encoder
    contextualises the bag; then the passage head scores type-1 positions,
    the entity head scores type-2 positions, and padding positions score 0.

    Args:
        d_model: Width of the CLS embeddings.
        seq_len: Unused; kept so existing call sites keep working.
        dropout: Dropout probability used inside the encoder layers.
        nhead: Attention heads per encoder layer (must divide d_model).
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, d_model=768, seq_len=16, dropout=0.1, nhead=2, num_layers=2):
        super().__init__()
        # Not used in forward(); kept as an attribute for backward compatibility.
        self.dropout = nn.Dropout(p=dropout)
        # One learned embedding per token type: 0 = padding, 1 = passage, 2 = entity.
        self.token_type_embeddings = nn.Embedding(3, d_model)
        # Forward `dropout` to the layers; previously the argument was silently ignored.
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Scoring heads.
        self.head_passage = nn.Linear(d_model, 1)
        self.head_entity = nn.Linear(d_model, 1)

    def forward(self, input_CLSs, type_mask=None, *, debug=False):
        """Score every position of a bag of CLS embeddings.

        Args:
            input_CLSs: Float tensor [seq_len, batch_size, d_model] (sequence-first).
            type_mask: Integer tensor [seq_len, batch_size] with values in {0, 1, 2}
                (0 = padding, 1 = passage, 2 = entity). Required despite the
                default: it drives the type embeddings, the padding mask and the
                head routing.
            debug: If True, print intermediate tensors.

        Returns:
            Tensor [seq_len, batch_size, 1]: passage-head scores at type-1
            positions, entity-head scores at type-2 positions, 0 at padding.

        Raises:
            TypeError: If type_mask is not a tensor.
        """
        if not isinstance(type_mask, torch.Tensor):
            raise TypeError(
                'type_mask must be a [seq_len, batch_size] tensor of token types '
                '(0=padding, 1=passage, 2=entity)')

        input_CLSs = input_CLSs + self.token_type_embeddings(type_mask)
        if debug:
            print('----- input_CLSs -----')
            print(input_CLSs)

        # The encoder ignores key positions where this mask is True, so True must
        # mark padding (type 0). Transposed to [batch_size, seq_len] as the API expects.
        src_key_padding_mask = (type_mask == 0).T
        if debug:
            print('----- src_key_padding_mask -----')
            print(src_key_padding_mask)

        output_CLSs = self.transformer_encoder(input_CLSs, src_key_padding_mask=src_key_padding_mask)
        if debug:
            print('----- output_CLSs -----')
            print(output_CLSs)

        # Route each position to the head for its type (passage == 1, entity == 2);
        # padding positions are multiplied by 0 in both terms.
        passage_mask = (type_mask == 1).unsqueeze(-1).to(output_CLSs.dtype)
        entity_mask = (type_mask == 2).unsqueeze(-1).to(output_CLSs.dtype)
        passage_output = self.head_passage(output_CLSs) * passage_mask
        entity_output = self.head_entity(output_CLSs) * entity_mask

        output = passage_output + entity_output
        if debug:
            print('----- output -----')
            print(output)
        return output

    def get_device(self):
        """Return the device holding this model's parameters."""
        return next(self.parameters()).device
    

In [14]:
# model = MUTANT(d_model=10, seq_len=6, dropout=0.1)
# lr = 0.001

# optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# loss_func = torch.nn.MSELoss()  # this is for regression mean squared loss
# train_loss_total = 0.0

# model.train()
# bag_of_CLS = torch.rand(6, 3, 10) # [seq_len, batch_size, d_model]
# type_mask = torch.tensor([[1,1,1],
#                           [2,2,2],
#                           [2,2,2],
#                           [2,2,0],
#                           [2,2,0],
#                           [0,0,0]]) # [seq_len, batch_size]

# labels = torch.tensor([[[1.0],[0.0],[1.0]],
#                         [[0.0],[0.0],[0.0]],
#                         [[1.0],[0.0],[1.0]],
#                         [[0.0],[1.0],[0.0]],
#                         [[0.0],[0.0],[0.0]],
#                         [[0.0],[0.0],[0.0]]]) # [seq_len, batch_size, 1]
# for i in range(100):
#     # ========================================
#     #               Training
#     # ========================================
#     model.zero_grad()
#     outputs = model.forward(bag_of_CLS, type_mask=type_mask)

#     # Calculate loss: mean squared error (regression), not softmax/cross-entropy
#     loss = loss_func(outputs, labels)
#     # Getting gradients w.r.t. parameters
#     loss.sum().backward()
#     optimizer.step()

#     torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)


#     train_loss_total += loss.sum().item()
    
#     if i % 10 == 0:
#         print('--------')
#         print(train_loss_total/(1+i))
#         print(labels)
#         print(outputs)


0.7456469535827637
0.2570289806886153
0.20149808660859153
0.1798853624251581
0.16487641160081073
0.15623908533769496
0.14912645262284357
0.1404143590217745
0.13302875659715982
0.12317588700206725


In [None]:
import json

# TODO: absolute cluster path; consider moving it to a config cell / DATA_DIR constant.
path = '/nfs/trec_news_track/data/5_fold/scaled_5fold_0_data/mutant_data/valid/0_mutant_max.json'

# JSON text is UTF-8 by specification (RFC 8259); don't depend on the locale's
# default encoding, which open() otherwise uses.
with open(path, 'r', encoding='utf-8') as f:
    d = json.load(f)

In [None]:

# Build one fixed-length sequence per passage:
#   position 0      -> the passage CLS (type 1)
#   next positions  -> its entity CLSs (type 2), at most max_seq_len - 1 of them
#   remaining       -> zero vectors (type 0, padding)
# Entities that do not fit are truncated.
bag_of_CLS = []
labels = []
type_mask = []
max_seq_len = 16
for passage_id, passage in d['query']['passage'].items():
    passage_cls = passage['cls_token']
    cls_dim = len(passage_cls)  # pad with vectors as wide as the real CLS tokens
    seq_cls = [passage_cls]
    seq_labels = [[passage['relevant']]]
    seq_mask = [1]

    for entity in passage['entity'].values():
        if len(seq_mask) >= max_seq_len:
            # Sequence is full; the remaining entities are dropped.
            break
        seq_cls.append(entity['cls_token'])
        seq_labels.append([entity['relevant']])
        seq_mask.append(2)

    padding_len = max_seq_len - len(seq_mask)
    seq_cls.extend([0.0] * cls_dim for _ in range(padding_len))
    seq_labels.extend([0.0] for _ in range(padding_len))
    seq_mask.extend(0 for _ in range(padding_len))

    bag_of_CLS.append(seq_cls)
    labels.append(seq_labels)
    type_mask.append(seq_mask)

# Explicit dtypes: labels must be float for MSELoss even when `relevant` is stored
# as an int in the JSON; type ids must be long for nn.Embedding.
bag_of_CLS_tensor = torch.tensor(bag_of_CLS, dtype=torch.float)
type_mask_tensor = torch.tensor(type_mask, dtype=torch.long)
labels_tensor = torch.tensor(labels, dtype=torch.float)
print(bag_of_CLS_tensor.shape, type_mask_tensor.shape, labels_tensor.shape)

train_dataset = TensorDataset(bag_of_CLS_tensor, type_mask_tensor, labels_tensor)

train_data_loader = DataLoader(train_dataset, sampler=SequentialSampler(train_dataset), batch_size=8)


In [None]:

# Train MUTANT with a per-position MSE regression loss on the relevance labels.
torch.manual_seed(42)  # reproducible weight init and dropout

model = MUTANT(d_model=768, seq_len=16, dropout=0.1)
lr = 0.001
num_epochs = 1

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Regression loss: mean squared error averaged over every position. Padding
# positions are included, but both their score and their label are 0.
loss_func = torch.nn.MSELoss()

model.train()

for epoch in range(num_epochs):
    train_loss_total = 0.0
    num_batches = 0
    for batch_cls, batch_type_mask, batch_labels in train_data_loader:
        # The DataLoader stacks examples on dim 0 (batch-first); the encoder expects
        # sequence-first. transpose swaps the axes; view would scramble examples.
        batch_cls = batch_cls.transpose(0, 1)                  # [seq_len, batch, d_model]
        batch_type_mask = batch_type_mask.transpose(0, 1)      # [seq_len, batch]
        batch_labels = batch_labels.transpose(0, 1).float()    # [seq_len, batch, 1]

        optimizer.zero_grad()
        outputs = model(batch_cls, type_mask=batch_type_mask)

        loss = loss_func(outputs, batch_labels)
        loss.backward()
        # Clip BEFORE stepping; clipping after step() has no effect on the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        train_loss_total += loss.item()
        num_batches += 1

    # Mean loss over the batches actually processed this epoch.
    print(train_loss_total / max(num_batches, 1))