In [None]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder

In [None]:
class MUTANT(nn.Module):
    """Transformer model that assigns one scalar score to every position of a typed sequence.

    Every position of the input sequence carries a type id:
    0 = padding, 1 = passage, 2 = entity.

    The forward pass:
      1. adds a learned type embedding to every input vector;
      2. contextualises the sequence with a Transformer encoder, with padding
         positions excluded from attention;
      3. scores each position with the linear head matching its type.

    Padding positions always receive a score of 0.

    Args:
        d_model: Width of the input vectors; must be divisible by ``nhead``.
        seq_len: Not used by the model; kept so existing call sites keep working.
        dropout: Dropout probability used inside each encoder layer.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, d_model=768, seq_len=16, dropout=0.1, nhead=2, num_layers=2):
        super().__init__()
        # Not applied in forward(); retained so existing attribute access keeps working.
        self.dropout = nn.Dropout(p=dropout)
        # Three type ids: 0 = padding, 1 = passage, 2 = entity.
        self.token_type_embeddings = nn.Embedding(3, d_model)
        # Forward `dropout` to the encoder. Previously the argument was ignored and
        # the layer silently used its own default.
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Scoring heads: one scalar per position, selected by the position's type.
        self.head_passage = nn.Linear(d_model, 1)
        self.head_entity = nn.Linear(d_model, 1)

    def forward(self, input_CLSs, type_mask=None):
        """Score every position of ``input_CLSs`` with its type-specific head.

        Args:
            input_CLSs: Float tensor of shape [seq_len, batch_size, d_model]
                (sequence-first, the encoder's default layout).
            type_mask: Integer (long) tensor of shape [seq_len, batch_size].
                Values are 0 = padding, 1 = passage, 2 = entity. Required.

        Returns:
            Float tensor of shape [seq_len, batch_size, 1]:
              - passage positions hold the passage-head score;
              - entity positions hold the entity-head score;
              - padding positions hold 0.

        Raises:
            ValueError: If ``type_mask`` is None.
        """
        if type_mask is None:
            # The mask drives both attention padding and head selection, so there is
            # nothing sensible to compute without it (previously: TypeError on `None > 0`).
            raise ValueError(
                "type_mask is required: a [seq_len, batch_size] tensor of type ids 0/1/2"
            )

        input_CLSs = input_CLSs + self.token_type_embeddings(type_mask)

        # nn.TransformerEncoder wants a *bool* key-padding mask of shape
        # [batch_size, seq_len], where True marks positions to IGNORE.
        # The old `(type_mask > 0).type(torch.int)` had two problems:
        #   - wrong dtype;
        #   - inverted: it flagged the real tokens instead of the padding.
        src_key_padding_mask = (type_mask == 0).T

        # Forward pass of the Transformer encoder.
        output_CLSs = self.transformer_encoder(input_CLSs, src_key_padding_mask=src_key_padding_mask)

        # Route each position to the head matching its type. Both masks are 0 on
        # padding, so padded positions score exactly 0.
        passage_mask = (type_mask == 1).unsqueeze(-1).to(output_CLSs.dtype)
        entity_mask = (type_mask == 2).unsqueeze(-1).to(output_CLSs.dtype)
        passage_output = self.head_passage(output_CLSs) * passage_mask
        entity_output = self.head_entity(output_CLSs) * entity_mask

        return passage_output + entity_output

    def get_device(self):
        """Return the device holding the model's parameters."""
        return next(self.parameters()).device

In [None]:
# Toy sanity check: train MUTANT on random input vectors against fixed labels.
# Only each position's type id and its context in the bag carry signal.
torch.manual_seed(0)  # reproducible random inputs and initial weights

D_MODEL = 10
SEQ_LEN = 6
BATCH_SIZE = 3
NUM_STEPS = 100_000
LOG_EVERY = 1_000
MAX_GRAD_NORM = 1.0
lr = 0.001

model = MUTANT(d_model=D_MODEL, seq_len=SEQ_LEN, dropout=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_func = torch.nn.MSELoss()  # regression: mean squared error

# Type id per position: 0 = padding, 1 = passage, 2 = entity.  [seq_len, batch_size]
type_mask = torch.tensor([[1, 1, 1],
                          [2, 2, 2],
                          [2, 2, 2],
                          [2, 2, 0],
                          [2, 2, 0],
                          [0, 0, 0]])

# Target score per position.  [seq_len, batch_size, 1]
labels = torch.tensor([[[1.0], [0.0], [1.0]],
                       [[0.0], [0.0], [0.0]],
                       [[1.0], [0.0], [1.0]],
                       [[0.0], [1.0], [0.0]],
                       [[0.0], [0.0], [0.0]],
                       [[0.0], [0.0], [0.0]]])

# Only real (non-padding) positions count towards the loss, so padding does not
# dilute the mean.
valid = type_mask != 0

train_loss_total = 0.0

model.train()
for i in range(NUM_STEPS):
    bag_of_CLS = torch.rand(SEQ_LEN, BATCH_SIZE, D_MODEL)  # [seq_len, batch_size, d_model]

    # ========================================
    #               Training
    # ========================================
    optimizer.zero_grad()
    # Call the module (not .forward) so hooks run. Output: [seq_len, batch_size, 1]
    outputs = model(bag_of_CLS, type_mask=type_mask)

    # Mean squared error over non-padding positions.
    loss = loss_func(outputs[valid], labels[valid])
    loss.backward()
    # Clip BEFORE stepping: clipping after optimizer.step() never affected the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
    optimizer.step()

    train_loss_total += loss.item()

    if i % LOG_EVERY == 0:
        print('--------')
        print(train_loss_total / (1 + i))  # running mean loss
        print(labels)
        print(outputs.detach())
