In [1]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder

In [2]:
bag_of_CLS = torch.rand(16, 10, 768) # [seq_len, batch_size, d_model]
type_mask = torch.tensor([[1],
                          [2],
                          [2],
                          [2],
                          [2],
                          [2],
                          [2],
                          [2],
                          [2],
                          [0],
                          [0],
                          [0],
                          [0],
                          [0],
                          [0],
                          [0]]) # [seq_len, batch_size]
# type_mask = None

In [8]:
class MUTANT(nn.Module):

    def __init__(self, d_model=768, seq_len=16, dropout=0.1):
        super(MUTANT,self).__init__()
        self.seq_len = seq_len
        self.dropout = nn.Dropout(p=dropout)
        self.token_type_embeddings = nn.Embedding(3, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=2)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.linear_heads = []
        for i in range(seq_len):
            self.linear_heads.append(nn.Linear(d_model, 1))
                
    def forward_embed(self, input_CLSs, type_mask=None):
        # input_CLSs -> [seq_len, batch_size, d_model]
        # type_mask -> [seq_len, batch_size] 0 or 1 for different types
        input_CLSs_shape = input_CLSs.shape
        
        if isinstance(type_mask, torch.Tensor):
            token_type_embeddings = self.token_type_embeddings(type_mask)
            input_CLSs = input_CLSs + token_type_embeddings
            
        output_CLSs = self.transformer_encoder(input_CLSs)
        return output_CLSs
    
    def forward(self, input_CLSs, type_mask=None):
        # input_CLSs -> [seq_len, batch_size, d_model]
        # type_mask -> [seq_len, batch_size] 0 or 1 for different types
        seq_length = input_CLSs.shape[0]
        
        if isinstance(type_mask, torch.Tensor):
            token_type_embeddings = self.token_type_embeddings(type_mask)
            input_CLSs = input_CLSs + token_type_embeddings
            
        output_CLSs = self.transformer_encoder(input_CLSs)
        
        output_linear = []
        for i in range(self.seq_len):
            print(i)
            output_CLS = output_CLSs[i,:]
            print(output_CLS)
            print(self.linear_heads[i](output_CLS))
            output_linear.append(self.linear_heads[i](output_CLS))
        return output_linear
    
    def get_device(self):
        return next(self.parameters()).device

In [9]:
model = MUTANT(d_model=768)

In [12]:
output_linear = model.forward(bag_of_CLS, type_mask)

0
tensor([[-0.1964,  2.2750,  0.4229,  ..., -0.1228, -0.4160,  1.0900],
        [ 0.1467,  1.8875,  0.1617,  ...,  0.2106, -0.7010,  0.7633],
        [ 0.0795,  2.2577,  0.5681,  ...,  0.4870, -0.3145,  0.3927],
        ...,
        [-0.2176,  2.2198,  0.4763,  ..., -0.5688, -0.5419,  0.5890],
        [-0.3192,  2.3840,  0.3361,  ...,  0.4277, -1.2896,  0.8952],
        [-0.8053,  1.7312,  0.9045,  ...,  0.0305, -0.0812,  0.9816]],
       grad_fn=<SliceBackward>)
tensor([[ 0.2479],
        [ 0.0230],
        [ 0.0696],
        [ 0.0067],
        [ 0.2736],
        [-0.0230],
        [ 0.2690],
        [ 0.1101],
        [ 0.0873],
        [ 0.0277]], grad_fn=<AddmmBackward>)
1
tensor([[-0.6617,  0.8897, -0.0866,  ..., -0.1019,  0.2273,  0.5821],
        [-0.5505,  0.9639, -0.2937,  ..., -0.1924, -0.4754,  0.6994],
        [-0.2339,  0.5938, -0.2251,  ..., -0.0508, -0.1750,  0.5315],
        ...,
        [-1.0123,  0.8945,  0.4292,  ...,  0.5607,  0.6318,  0.6011],
        [-0.4662,  1.

In [14]:
len(output_linear)

16