# A transformer-based model based on "What does self-attention learn from Masked Language Modelling?" paper

### Core parts of the transformer
- Separated position and spin
- Single attention layer


### Outline
1. Vanilla attention implementation
2. Factored attention implementation


In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import numpy as np
import tqdm
import random
import torch.nn.functional as F

In [43]:
X = torch.randn(100,1000,300)
idx = torch.randint(0,1000,(100,))
# create masked_X, Y=model(masked_X)
# X[:,idx,:]-Y[:,idx,:]



torch.Size([100, 100, 300])

In [4]:
data_train = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(1000)])
data_test = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(100)])

def mask_random_spin(sequence, mask_token=2):
    """
    Mask one random spin in a sequence of protein spins.
    
    Parameters:
    - sequence: a list or sequence of spins (integers)
    - mask_token: the token used to mask a spin (default is 2)
    
    Returns:
    - masked_sequence: a sequence similar to the input but with one spin masked
    - masked_position: the position of the spin that was masked
    """
    # Ensure the sequence can be converted to a list for masking
    sequence_list = list(sequence)
    
    # Choose a random position to mask, excluding the first spin
    mask_position = random.randint(1, len(sequence_list) - 1)
    
    # Mask the chosen position
    masked_sequence = sequence_list.copy()
    masked_sequence[mask_position] = mask_token
    
    return masked_sequence, mask_position

In [32]:
class VanillaAttentionTransformer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins=2, dropout_rate=0.0):
        super(VanillaAttentionTransformer, self).__init__()
        self.word_embeddings = nn.Embedding(num_spins, embed_dim)
        self.position_embeddings = nn.Embedding(max_seq_length, embed_dim)
        self.a = a # parameter controlling how important are positions
        self.value_weight = nn.Linear(embed_dim, embed_dim)
        self.query_weight = nn.Linear(embed_dim, embed_dim)
        self.key_weight = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, num_spins) # output layer
        self.dropout = nn.Dropout(dropout_rate)
    
    def forward(self, s, masked_token):
        position_ids = torch.arange(len(s), dtype=torch.long)
        position_ids = position_ids
        x = self.word_embeddings(s) + self.position_embeddings(position_ids)
        query = self.query_weight(x)
        key = self.key_weight(x)

        values = self.value_weight(self.word_embeddings(s) + self.a*self.position_embeddings(position_ids)) # (batch_size,embed_dim)
        exp_scaling = torch.exp(self.word_embeddings(masked_token) + self.position_embeddings(position_ids).T@query.T@key@(self.word_embeddings(s) + self.a*self.position_embeddings(position_ids)))
        attn_output = torch.sum(exp_scaling/torch.sum(exp_scaling.sum(0))*values) # not sure it multiplies the way I want it to multiply - check
        output = self.dropout(self.fc(attn_output))
        return output
    
class FactoredAttentionTransformer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins=2, dropout_rate=0.0):
        super(FactoredAttentionTransformer, self).__init__()
        self.word_embeddings = nn.Embedding(num_spins, embed_dim)
        self.word_embedding.weight.requires_grad = False
        self.position_embeddings = nn.Embedding(max_seq_length, embed_dim)
        self.a = a # parameter controlling how important are positions
        self.value_weight = nn.Linear(embed_dim, embed_dim)
        self.query_weight = nn.Linear(embed_dim, embed_dim)
        self.key_weight = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, num_spins) # output layer
        self.dropout = nn.Dropout(dropout_rate)
    
    def forward(self, s, masked_token):
        # masked token should be equal to 0
        # masked_token = torch.tensor([0])
        position_ids = torch.arange(len(s), dtype=torch.long)
        position_ids = position_ids
        x = self.word_embeddings(s) + self.position_embeddings(position_ids)
        query = self.query_weight(x)
        key = self.key_weight(x)

        values = self.value_weight(self.word_embeddings(s)) # (embed_dim)
        exp_scaling = torch.exp(self.word_embeddings(masked_token) + self.position_embeddings(position_ids).T@query.T@key@(self.word_embeddings(s) + self.a*self.position_embeddings(position_ids)))
        attn_output = torch.sum(exp_scaling/torch.sum(exp_scaling.sum(0))*values) # not sure it multiplies the way I want it to multiply - check
        output = self.dropout(self.fc(attn_output))
        return output

In [33]:
def evaluate(model, data_test, criterion, device=0):
    model.eval()
    epoch_loss = 0
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    for i, data in tqdm(enumerate(data_test), total=len(data_test)):
        # Get the inputs
        input_seq = data[i]
        # mask a token
        masked_sequence, masked_position = mask_random_spin(input_seq, mask_token=2)
        # Forward pass
        outputs = model.forward(masked_sequence, masked_sequence[masked_position])

        output_token = F.log_softmax(outputs, dim=-1)
        target_token = input_seq[masked_position]

        # Compute loss
        loss = criterion(output_token, target_token)
        epoch_loss += loss.item()
    #model.train()
    return epoch_loss / len(data_test)

def train(model, data_train, data_test, optimizer, criterion, num_epochs=8, device=0):
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # Training loop
    model.train()
    best_eval_loss = 1e-3 # used to do early stopping

    for epoch in tqdm.tqdm(range(num_epochs), leave=False, position=0):
        running_loss = 0
        epoch_loss = 0
        
        for i, data in tqdm.tqdm(enumerate(data_train), total=len(data_train)):
            # Get the inputs
            input_seq = data
            print(input_seq)

            # mask a token
            masked_sequence, masked_position = mask_random_spin(input_seq, mask_token=2)
            # Forward pass
            prediction = model.forward(masked_sequence, masked_sequence[masked_position])
            
            predicted_token = F.log_softmax(prediction, dim=-1)
            target_token = input_seq[masked_position]

            # Compute loss
            loss = criterion(predicted_token, target_token)
            epoch_loss += loss.item()
            
            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 10 == 9. :    # print every 10 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0

        print(f'Epoch {epoch + 1} | Train Loss: {(epoch_loss / len(data)):.4f}')
        eval_loss = evaluate(model, data_test, criterion, device=device)
        print(f'Epoch {epoch + 1} | Eval Loss: {(eval_loss):.4f}')
        
        # Perform early stopping based on eval loss
        if eval_loss < best_eval_loss:
            return epoch_loss / len(data_train)
    return epoch_loss / len(data_train)



In [34]:
# Define the parameters 
vocab_size = 2
L = 20
embedding_dim = 20
hidden_dim = 20
num_layers = 1 # have to adapt the model for 2 and 3 layers
dropout_rate = 0.0
lr = 1e-3
num_sequences = 1000
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

mps


In [35]:
# Example usage:
model = VanillaAttentionTransformer(embed_dim=embedding_dim, a=1, max_seq_length=L, num_spins=2, dropout_rate=dropout_rate)

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train(model, data_train, data_test, optimizer, criterion, device=device)
#torch.save(model.state_dict(), 'models/lstm_scratch.pt')
#evaluate(model, test_dataloader, criterion, device=device)

  0%|          | 0/1000 [00:00<?, ?it/s]
                                     

[-1  1  1  1 -1 -1 -1 -1  1  1 -1 -1  1  1 -1  1 -1  1  1 -1]




TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not list

# Questions
- Do we pass through the sequence multiple times, each time masking different tokens or do we pass through the sequence once masking one random token?
- Maybe we evaluate on how well the model predicts the other token in the sequence?