# EVALUATE THE ATTENTION MODEL

## 1. Get the model to be evaluated

In [None]:
class MyAttentionModelWithWordPosition(nn.Module):
    """Attention model built on a Transformer encoder with a scalar output head.

    Pipeline (see ``forward``): token embedding -> positional encoding ->
    selection by ``word_position_mask`` -> Transformer encoder -> average
    pooling -> linear layer with a single output unit -> ``log_softmax``.

    Args:
        vocab_size: Number of rows of the embedding table.
        embedding_dim: Size of each token embedding (the Transformer d_model).
        max_seq_len: Passed as ``max_len`` to the positional encoder and used
            as the kernel size of the average-pooling layer.
        num_heads: Number of attention heads of the encoder layer.
        dim_feedforward: Hidden size of the encoder layer's feed-forward part.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability of the encoder layer. The positional
            encoder uses a fixed dropout of 0.1 regardless of this value.
    """

    def __init__(self, vocab_size, embedding_dim, max_seq_len, num_heads, dim_feedforward, num_layers, dropout=0.5):
        super().__init__()
        try:
            from torch.nn import TransformerEncoder, TransformerEncoderLayer
        except ImportError as err:
            # Only a failed import is translated; anything else must propagate.
            raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or lower.') from err
        self.model_type = 'Transformer'
        # Attention mask cache, (re)built lazily in forward().
        self.src_mask = None
        # PositionalEncoding is not defined in this class; it must be in scope.
        self.pos_encoder = PositionalEncoding(embedding_dim, dropout=0.1, max_len=max_seq_len)
        encoder_layers = TransformerEncoderLayer(embedding_dim, num_heads, dim_feedforward, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
        self.encoder = nn.Embedding(vocab_size, embedding_dim)
        self.embedding_dim = embedding_dim
        # Averages over windows of max_seq_len along the last dimension.
        self.pooler = nn.AvgPool1d(max_seq_len, stride=1)
        self.decoder = nn.Linear(embedding_dim, 1)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        """Return a (sz, sz) float mask: 0.0 on/below the diagonal, -inf above.

        Added to attention scores, it prevents position i from attending to
        any position j > i.
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        """Initialise embedding and decoder weights uniformly; zero the decoder bias."""
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        # The bias (not the weight) is zeroed: zeroing the weight here would be
        # overwritten by the uniform initialisation on the next line.
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, word_position_mask, has_mask=True):
        """Run the model on a batch of token ids.

        Args:
            src: Integer tensor of token ids fed to the embedding layer.
            word_position_mask: Boolean tensor handed to ``torch.masked_select``
                together with the position-encoded embeddings.
            has_mask: When true, a square subsequent (causal) mask of side
                ``len(src)`` is applied in the encoder; otherwise no mask.

        Returns:
            ``log_softmax`` over the last dimension of the decoder output.
        """
        if has_mask:
            device = src.device
            # Rebuild the cached mask when missing, when the leading dimension
            # changed, or when the input lives on a different device.
            needs_rebuild = (
                self.src_mask is None
                or self.src_mask.size(0) != len(src)
                or self.src_mask.device != device
            )
            if needs_rebuild:
                self.src_mask = self._generate_square_subsequent_mask(len(src)).to(device)
        else:
            self.src_mask = None

        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        src = self.encoder(src) * math.sqrt(self.embedding_dim)
        src = self.pos_encoder(src)
        # Keep only the entries selected by word_position_mask.
        # NOTE(review): torch.masked_select returns a flattened 1-D tensor, so
        # the 3-D layout required by permute(0, 2, 1) below is lost here --
        # confirm the intended selection (e.g. indexing/gather) and shapes.
        src = torch.masked_select(src, word_position_mask)
        output = self.transformer_encoder(src, self.src_mask)
        # Move the feature axis to position 1 so pooling runs over the last axis.
        # Presumably (batch_size, n_tokens, emb_dim) -> (batch_size, emb_dim, 1)
        # when n_tokens == max_seq_len -- TODO confirm against the data loader.
        output = self.pooler(output.permute(0, 2, 1))
        output = self.decoder(output.view(-1, output.shape[1]))
        # NOTE(review): the decoder has a single output unit, so log_softmax
        # over the last dimension is always 0 -- confirm whether a sigmoid /
        # raw logit was intended. Left unchanged to keep the return contract.
        return F.log_softmax(output, dim=-1)

In [None]:
# Rebuild the architecture before loading weights: load_state_dict is strict
# by default, so these hyper-parameters must match the saved checkpoint.
model = MyAttentionModelWithWordPosition(
    vocab_size=len(voc),
    embedding_dim=32,
    max_seq_len=max_sequence_length,
    num_heads=4,
    dim_feedforward=16,
    num_layers=1,
    dropout=0.1,
)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine; the freshly built model lives on the CPU anyway.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
# Switch to inference mode (disables dropout).
model.eval()

## 2. Define an evaluation method

In [None]:
def model_evaluator_for_sense():
    """Evaluate the model (placeholder -- not implemented yet).

    The original cell held only an unterminated ``def`` line, which is a
    syntax error. This stub keeps the name and the empty parameter list so
    the notebook parses, and fails loudly if it is called.

    Raises:
        NotImplementedError: Always, until the evaluation logic is written.
    """
    # TODO: implement the evaluation procedure.
    raise NotImplementedError("model_evaluator_for_sense is not implemented yet")

## 3. Evaluate the model