# A transformer-based model based on "What does self-attention learn from Masked Language Modelling?" paper

### Core parts of the transformer
- Separated position and spin
- Single attention layer


### Outline
1. Vanilla attention implementation
2. Factored attention implementation


In [320]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import numpy as np
import tqdm
import random
import torch.nn.functional as F

In [321]:
X = torch.randn(100,1000,300)
idx = torch.randint(0,1000,(100,))
# create masked_X, Y=model(masked_X)
# X[:,idx,:]-Y[:,idx,:]



In [322]:
data_train = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(1000)])
data_test = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(100)])

def one_hot_encoding(seq, vocab):
    if isinstance(seq, int):
        one_hot = torch.zeros(1, len(vocab), dtype=int)
        one_hot[:, vocab[seq]] = 1
    else:
        one_hot = torch.zeros((len(seq), len(vocab)), dtype=int)
        for i, spin in enumerate(seq):
            one_hot[i, vocab[spin]] = 1
    return one_hot

def mask_random_spin(sequence, mask_token=2):
    """
    Mask one random spin in a sequence of protein spins.
    
    Parameters:
    - sequence: a list or sequence of spins (integers)
    - mask_token: the token used to mask a spin (default is 2)
    
    Returns:
    - masked_sequence: a sequence similar to the input but with one spin masked
    - masked_position: the position of the spin that was masked
    """
    # define vocabulary
    vocab = {-1:0,1:1,2:2}
    # Ensure the sequence can be converted to a list for masking
    sequence_list = sequence.numpy().tolist() if isinstance(sequence, torch.Tensor) else list(sequence)
    
    # Choose a random position to mask, excluding the first spin
    mask_position = random.randint(1, len(sequence_list) - 1)
    
    # Mask the chosen position
    masked_sequence = sequence_list.copy()
    masked_sequence[mask_position] = mask_token

    # Create an array of zeros with shape (len(sequence), len(vocab))
    #one_hot = np.zeros((len(sequence_list), len(vocab)), dtype=int)
    #for i, spin in enumerate(masked_sequence):
    #    one_hot[i, vocab[spin]] = 1
    one_hot = one_hot_encoding(masked_sequence, vocab)
    # Display the one-hot encoding of the masked spin
    #mask_encoding = np.zeros((1, len(vocab)), dtype=int)
    #mask_encoding[0, vocab[2]] = 1
        
    return torch.tensor(one_hot), torch.tensor(mask_position)
    

In [327]:
class VanillaAttentionTransformer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins=3, dropout_rate=0.0):
        super(VanillaAttentionTransformer, self).__init__()
        self.word_embeddings = nn.Linear(num_spins, embed_dim)
        self.position_embeddings = nn.Embedding(max_seq_length, embed_dim)
        self.a = a  # parameter controlling how important positions are
        self.value_weight = nn.Linear(embed_dim, embed_dim)
        self.query_weight = nn.Linear(embed_dim, embed_dim)
        self.key_weight = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, num_spins)  # output layer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, s, masked_token):
        #if not isinstance(s, torch.Tensor):
        s = torch.tensor(s, dtype=torch.float)
        #if not isinstance(masked_token, torch.Tensor):
        masked_token = torch.tensor(masked_token, dtype=torch.float)

        position_ids = torch.arange(s.size(0), dtype=torch.long)
        x = self.word_embeddings(s) + self.a*self.position_embeddings(position_ids)
        
        query = self.query_weight(x)
        key = self.key_weight(x)
        values = self.value_weight(x)
        
        # Simple attention score calculation (Dot product): this is equivalent to the interaction matrix
        scores = torch.matmul(query, key.transpose(-2, -1))  # Transpose last two dimensions for matrix multiplication
        scores = torch.softmax(scores, dim=-1)  # Apply softmax to scores to get probabilities

        # Apply attention scores to values
        attn_output = torch.matmul(scores, values)

        # Sum over the sequence length dimensions
        #print("attention before summing:", attn_output.shape)
        attn_output = attn_output.sum(dim=0)
        #print("attention after summing:", attn_output.shape)
        output = self.fc(self.dropout(attn_output))

        return output


In [328]:
def evaluate(model, data_test, vocab, criterion, device=0):
    model.eval()
    epoch_loss = 0
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    for i, data in tqdm.tqdm(enumerate(data_test), total=len(data_test)):
        # Get the inputs
        input_seq = data
        input_one_hot = one_hot_encoding(input_seq.tolist(), vocab)
        # mask a token
        masked_sequence, position = mask_random_spin(input_seq, mask_token=2)
        # Forward pass
        outputs = model.forward(masked_sequence, masked_sequence[position])

        #output_token = F.log_softmax(outputs, dim=-1)
        target_token = torch.argwhere(input_one_hot[position]==1).squeeze(0) #target_token = input_seq[position]

        # Compute loss
        loss = criterion(outputs.unsqueeze(0), target_token)
        epoch_loss += loss.item()
    #model.train()
    return epoch_loss / len(data_test)

def train(model, data_train, data_test, vocab, optimizer, criterion, num_epochs=8, device=0):
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # Training loop
    model.train()
    best_eval_loss = 1e-3 # used to do early stopping

    for epoch in tqdm.tqdm(range(num_epochs), leave=False, position=0):
        running_loss = 0
        epoch_loss = 0
        
        for i, data in tqdm.tqdm(enumerate(data_train), total=len(data_train)):
            # Get the inputs
            input_seq = data
            input_one_hot = one_hot_encoding(input_seq.tolist(), vocab)

            # mask a token
            masked_sequence, position = mask_random_spin(input_seq, mask_token=2)

            # Forward pass
            prediction = model.forward(masked_sequence, masked_sequence[position]) #masked_sequence[masked_position]
            
            #predicted_token = F.log_softmax(prediction, dim=-1)
            target_token = torch.argwhere(input_one_hot[position]==1).squeeze(0) #input_seq[masked_position]
            # Compute loss
            #print("predicted_token:", predicted_token)
            #print("predicted_token:", prediction)
            #print("target token:", target_token)
            loss = criterion(prediction.unsqueeze(0), target_token)
            epoch_loss += loss.item()
            
            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 10 == 9. :    # print every 10 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0

        print(f'Epoch {epoch + 1} | Train Loss: {(epoch_loss / len(data)):.4f}')
        eval_loss = evaluate(model, data_test, vocab, criterion, device=device)
        print(f'Epoch {epoch + 1} | Eval Loss: {(eval_loss):.4f}')
        
        # Perform early stopping based on eval loss
        if eval_loss < best_eval_loss:
            return epoch_loss / len(data_train)
    return epoch_loss / len(data_train)



In [329]:
# Define the parameters 
vocab_size = 3
vocab = {-1:0,1:1,2:2} 
L = 20
embedding_dim = 20
hidden_dim = 20
num_layers = 1 # have to adapt the model for 2 and 3 layers
dropout_rate = 0.0
lr = 1e-3
num_sequences = 1000
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

mps


In [330]:
# Example usage:
model = VanillaAttentionTransformer(embed_dim=embedding_dim, a=0, max_seq_length=L, num_spins=3, dropout_rate=dropout_rate)

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train(model, data_train, data_test, vocab, optimizer, criterion, device=device)
#torch.save(model.state_dict(), 'models/lstm_scratch.pt')
#evaluate(model, test_dataloader, criterion, device=device)

  return torch.tensor(one_hot), torch.tensor(mask_position)
  s = torch.tensor(s, dtype=torch.float)
  masked_token = torch.tensor(masked_token, dtype=torch.float)


[1,    10] loss: 0.136
[1,    20] loss: 0.089
[1,    30] loss: 0.080
[1,    40] loss: 0.056
[1,    50] loss: 0.092
[1,    60] loss: 0.068
[1,    70] loss: 0.073
[1,    80] loss: 0.071
[1,    90] loss: 0.077
[1,   100] loss: 0.067
[1,   110] loss: 0.080
[1,   120] loss: 0.065
[1,   130] loss: 0.078
[1,   140] loss: 0.066
[1,   150] loss: 0.075
[1,   160] loss: 0.061
[1,   170] loss: 0.082
[1,   180] loss: 0.079
[1,   190] loss: 0.071
[1,   200] loss: 0.073
[1,   210] loss: 0.066
[1,   220] loss: 0.085
[1,   230] loss: 0.090
[1,   240] loss: 0.081
[1,   250] loss: 0.086
[1,   260] loss: 0.065
[1,   270] loss: 0.072
[1,   280] loss: 0.070
[1,   290] loss: 0.070
[1,   300] loss: 0.066
[1,   310] loss: 0.070
[1,   320] loss: 0.074
[1,   330] loss: 0.072
[1,   340] loss: 0.050
[1,   350] loss: 0.101
[1,   360] loss: 0.062
[1,   370] loss: 0.073
[1,   380] loss: 0.067
[1,   390] loss: 0.085




[1,   400] loss: 0.079
[1,   410] loss: 0.072
[1,   420] loss: 0.072
[1,   430] loss: 0.069
[1,   440] loss: 0.069
[1,   450] loss: 0.086
[1,   460] loss: 0.075
[1,   470] loss: 0.069
[1,   480] loss: 0.094
[1,   490] loss: 0.071
[1,   500] loss: 0.089
[1,   510] loss: 0.083
[1,   520] loss: 0.047
[1,   530] loss: 0.081
[1,   540] loss: 0.078
[1,   550] loss: 0.063
[1,   560] loss: 0.073
[1,   570] loss: 0.112
[1,   580] loss: 0.082
[1,   590] loss: 0.079
[1,   600] loss: 0.073
[1,   610] loss: 0.075
[1,   620] loss: 0.101
[1,   630] loss: 0.075
[1,   640] loss: 0.079
[1,   650] loss: 0.063
[1,   660] loss: 0.093
[1,   670] loss: 0.069
[1,   680] loss: 0.059
[1,   690] loss: 0.105
[1,   700] loss: 0.070
[1,   710] loss: 0.056
[1,   720] loss: 0.087
[1,   730] loss: 0.075
[1,   740] loss: 0.070
[1,   750] loss: 0.072
[1,   760] loss: 0.072
[1,   770] loss: 0.069
[1,   780] loss: 0.067
[1,   790] loss: 0.083
[1,   800] loss: 0.067


100%|██████████| 1000/1000 [00:00<00:00, 1911.70it/s]


[1,   810] loss: 0.070
[1,   820] loss: 0.071
[1,   830] loss: 0.072
[1,   840] loss: 0.067
[1,   850] loss: 0.073
[1,   860] loss: 0.073
[1,   870] loss: 0.065
[1,   880] loss: 0.072
[1,   890] loss: 0.069
[1,   900] loss: 0.075
[1,   910] loss: 0.065
[1,   920] loss: 0.080
[1,   930] loss: 0.079
[1,   940] loss: 0.076
[1,   950] loss: 0.069
[1,   960] loss: 0.084
[1,   970] loss: 0.069
[1,   980] loss: 0.072
[1,   990] loss: 0.077
[1,  1000] loss: 0.065
Epoch 1 | Train Loss: 37.5421


100%|██████████| 100/100 [00:00<00:00, 4792.67it/s]
 12%|█▎        | 1/8 [00:00<00:03,  1.82it/s]

Epoch 1 | Eval Loss: 0.7750




[2,    10] loss: 0.061
[2,    20] loss: 0.079
[2,    30] loss: 0.073
[2,    40] loss: 0.072
[2,    50] loss: 0.072
[2,    60] loss: 0.066
[2,    70] loss: 0.084
[2,    80] loss: 0.070
[2,    90] loss: 0.070
[2,   100] loss: 0.075
[2,   110] loss: 0.068
[2,   120] loss: 0.078
[2,   130] loss: 0.067
[2,   140] loss: 0.073




[2,   150] loss: 0.072
[2,   160] loss: 0.068
[2,   170] loss: 0.060
[2,   180] loss: 0.081
[2,   190] loss: 0.086
[2,   200] loss: 0.082
[2,   210] loss: 0.069
[2,   220] loss: 0.071
[2,   230] loss: 0.074
[2,   240] loss: 0.070
[2,   250] loss: 0.066
[2,   260] loss: 0.069
[2,   270] loss: 0.081
[2,   280] loss: 0.062
[2,   290] loss: 0.053
[2,   300] loss: 0.090
[2,   310] loss: 0.043
[2,   320] loss: 0.092
[2,   330] loss: 0.075
[2,   340] loss: 0.064
[2,   350] loss: 0.075
[2,   360] loss: 0.066




[2,   370] loss: 0.061
[2,   380] loss: 0.081
[2,   390] loss: 0.067
[2,   400] loss: 0.072
[2,   410] loss: 0.069
[2,   420] loss: 0.074
[2,   430] loss: 0.064
[2,   440] loss: 0.045
[2,   450] loss: 0.099
[2,   460] loss: 0.074
[2,   470] loss: 0.065
[2,   480] loss: 0.071
[2,   490] loss: 0.058
[2,   500] loss: 0.054
[2,   510] loss: 0.113




[2,   520] loss: 0.072
[2,   530] loss: 0.079
[2,   540] loss: 0.074
[2,   550] loss: 0.073
[2,   560] loss: 0.067
[2,   570] loss: 0.082
[2,   580] loss: 0.074
[2,   590] loss: 0.070
[2,   600] loss: 0.062
[2,   610] loss: 0.082
[2,   620] loss: 0.070
[2,   630] loss: 0.067
[2,   640] loss: 0.078
[2,   650] loss: 0.064
[2,   660] loss: 0.066
[2,   670] loss: 0.081
[2,   680] loss: 0.090
[2,   690] loss: 0.072
[2,   700] loss: 0.069
[2,   710] loss: 0.063
[2,   720] loss: 0.089
[2,   730] loss: 0.072
[2,   740] loss: 0.075




[2,   750] loss: 0.074
[2,   760] loss: 0.064
[2,   770] loss: 0.074
[2,   780] loss: 0.062
[2,   790] loss: 0.079
[2,   800] loss: 0.071
[2,   810] loss: 0.072
[2,   820] loss: 0.070
[2,   830] loss: 0.071
[2,   840] loss: 0.072
[2,   850] loss: 0.070
[2,   860] loss: 0.071
[2,   870] loss: 0.073
[2,   880] loss: 0.071
[2,   890] loss: 0.071


100%|██████████| 1000/1000 [00:00<00:00, 1836.74it/s]


[2,   900] loss: 0.072
[2,   910] loss: 0.075
[2,   920] loss: 0.069
[2,   930] loss: 0.071
[2,   940] loss: 0.071
[2,   950] loss: 0.071
[2,   960] loss: 0.066
[2,   970] loss: 0.079
[2,   980] loss: 0.072
[2,   990] loss: 0.065
[2,  1000] loss: 0.078
Epoch 2 | Train Loss: 35.9532


100%|██████████| 100/100 [00:00<00:00, 4618.06it/s]
 25%|██▌       | 2/8 [00:01<00:03,  1.78it/s]

Epoch 2 | Eval Loss: 0.7441




[3,    10] loss: 0.076
[3,    20] loss: 0.070
[3,    30] loss: 0.068
[3,    40] loss: 0.075
[3,    50] loss: 0.068
[3,    60] loss: 0.072




[3,    70] loss: 0.070
[3,    80] loss: 0.071
[3,    90] loss: 0.070
[3,   100] loss: 0.068
[3,   110] loss: 0.074
[3,   120] loss: 0.073
[3,   130] loss: 0.072
[3,   140] loss: 0.069
[3,   150] loss: 0.074
[3,   160] loss: 0.051
[3,   170] loss: 0.064
[3,   180] loss: 0.075
[3,   190] loss: 0.086
[3,   200] loss: 0.057
[3,   210] loss: 0.068
[3,   220] loss: 0.080
[3,   230] loss: 0.069
[3,   240] loss: 0.069
[3,   250] loss: 0.070
[3,   260] loss: 0.077
[3,   270] loss: 0.077
[3,   280] loss: 0.069
[3,   290] loss: 0.076
[3,   300] loss: 0.069
[3,   310] loss: 0.068
[3,   320] loss: 0.068
[3,   330] loss: 0.069
[3,   340] loss: 0.076




[3,   350] loss: 0.065
[3,   360] loss: 0.074
[3,   370] loss: 0.070
[3,   380] loss: 0.068
[3,   390] loss: 0.072
[3,   400] loss: 0.071
[3,   410] loss: 0.072
[3,   420] loss: 0.066
[3,   430] loss: 0.072
[3,   440] loss: 0.070




[3,   450] loss: 0.076
[3,   460] loss: 0.066
[3,   470] loss: 0.078
[3,   480] loss: 0.071
[3,   490] loss: 0.075
[3,   500] loss: 0.062
[3,   510] loss: 0.072
[3,   520] loss: 0.094
[3,   530] loss: 0.068
[3,   540] loss: 0.070
[3,   550] loss: 0.072
[3,   560] loss: 0.062
[3,   570] loss: 0.080




[3,   580] loss: 0.071
[3,   590] loss: 0.067
[3,   600] loss: 0.060
[3,   610] loss: 0.094
[3,   620] loss: 0.065
[3,   630] loss: 0.070
[3,   640] loss: 0.073
[3,   650] loss: 0.071
[3,   660] loss: 0.065
[3,   670] loss: 0.078
[3,   680] loss: 0.072
[3,   690] loss: 0.072
[3,   700] loss: 0.069
[3,   710] loss: 0.073
[3,   720] loss: 0.071
[3,   730] loss: 0.070
[3,   740] loss: 0.068
[3,   750] loss: 0.062
[3,   760] loss: 0.077
[3,   770] loss: 0.082
[3,   780] loss: 0.067
[3,   790] loss: 0.076
[3,   800] loss: 0.065
[3,   810] loss: 0.072




[3,   820] loss: 0.074
[3,   830] loss: 0.069
[3,   840] loss: 0.072
[3,   850] loss: 0.069
[3,   860] loss: 0.068
[3,   870] loss: 0.076
[3,   880] loss: 0.072
[3,   890] loss: 0.070
[3,   900] loss: 0.071
[3,   910] loss: 0.053
[3,   920] loss: 0.078
[3,   930] loss: 0.076
[3,   940] loss: 0.067
[3,   950] loss: 0.072


100%|██████████| 1000/1000 [00:00<00:00, 1802.64it/s]


[3,   960] loss: 0.068
[3,   970] loss: 0.076
[3,   980] loss: 0.080
[3,   990] loss: 0.072
[3,  1000] loss: 0.069
Epoch 3 | Train Loss: 35.5575


100%|██████████| 100/100 [00:00<00:00, 4724.33it/s]
 38%|███▊      | 3/8 [00:01<00:02,  1.76it/s]

Epoch 3 | Eval Loss: 0.6943




[4,    10] loss: 0.072
[4,    20] loss: 0.071
[4,    30] loss: 0.071
[4,    40] loss: 0.063
[4,    50] loss: 0.082
[4,    60] loss: 0.066
[4,    70] loss: 0.073
[4,    80] loss: 0.069
[4,    90] loss: 0.070
[4,   100] loss: 0.075
[4,   110] loss: 0.068
[4,   120] loss: 0.072
[4,   130] loss: 0.073
[4,   140] loss: 0.065




[4,   150] loss: 0.067
[4,   160] loss: 0.070
[4,   170] loss: 0.068
[4,   180] loss: 0.068
[4,   190] loss: 0.070
[4,   200] loss: 0.065
[4,   210] loss: 0.066
[4,   220] loss: 0.074
[4,   230] loss: 0.071
[4,   240] loss: 0.072
[4,   250] loss: 0.071
[4,   260] loss: 0.070
[4,   270] loss: 0.072
[4,   280] loss: 0.072
[4,   290] loss: 0.070
[4,   300] loss: 0.066
[4,   310] loss: 0.081
[4,   320] loss: 0.071
[4,   330] loss: 0.070




[4,   340] loss: 0.071
[4,   350] loss: 0.069
[4,   360] loss: 0.070
[4,   370] loss: 0.070
[4,   380] loss: 0.066
[4,   390] loss: 0.073
[4,   400] loss: 0.072
[4,   410] loss: 0.070
[4,   420] loss: 0.063
[4,   430] loss: 0.079
[4,   440] loss: 0.061
[4,   450] loss: 0.053
[4,   460] loss: 0.068




[4,   470] loss: 0.053
[4,   480] loss: 0.089
[4,   490] loss: 0.068
[4,   500] loss: 0.082
[4,   510] loss: 0.069
[4,   520] loss: 0.067
[4,   530] loss: 0.075
[4,   540] loss: 0.072
[4,   550] loss: 0.070
[4,   560] loss: 0.073
[4,   570] loss: 0.066
[4,   580] loss: 0.068
[4,   590] loss: 0.072
[4,   600] loss: 0.071
[4,   610] loss: 0.071
[4,   620] loss: 0.068
[4,   630] loss: 0.070
[4,   640] loss: 0.065
[4,   650] loss: 0.074
[4,   660] loss: 0.071
[4,   670] loss: 0.072
[4,   680] loss: 0.072
[4,   690] loss: 0.072
[4,   700] loss: 0.074




[4,   710] loss: 0.072
[4,   720] loss: 0.070
[4,   730] loss: 0.067
[4,   740] loss: 0.070
[4,   750] loss: 0.071
[4,   760] loss: 0.071
[4,   770] loss: 0.072
[4,   780] loss: 0.066
[4,   790] loss: 0.067
[4,   800] loss: 0.060
[4,   810] loss: 0.076
[4,   820] loss: 0.084
[4,   830] loss: 0.068
[4,   840] loss: 0.070




[4,   850] loss: 0.070
[4,   860] loss: 0.068
[4,   870] loss: 0.070
[4,   880] loss: 0.070
[4,   890] loss: 0.066
[4,   900] loss: 0.078
[4,   910] loss: 0.070
[4,   920] loss: 0.069
[4,   930] loss: 0.071
[4,   940] loss: 0.071
[4,   950] loss: 0.069
[4,   960] loss: 0.070
[4,   970] loss: 0.072
[4,   980] loss: 0.068
[4,   990] loss: 0.070


100%|██████████| 1000/1000 [00:00<00:00, 1752.85it/s]


[4,  1000] loss: 0.072
Epoch 4 | Train Loss: 35.0886


100%|██████████| 100/100 [00:00<00:00, 4945.82it/s]
 50%|█████     | 4/8 [00:02<00:02,  1.73it/s]

Epoch 4 | Eval Loss: 0.7072




[5,    10] loss: 0.068
[5,    20] loss: 0.071
[5,    30] loss: 0.080
[5,    40] loss: 0.070
[5,    50] loss: 0.071
[5,    60] loss: 0.070
[5,    70] loss: 0.070
[5,    80] loss: 0.063
[5,    90] loss: 0.064
[5,   100] loss: 0.082
[5,   110] loss: 0.073
[5,   120] loss: 0.070
[5,   130] loss: 0.069
[5,   140] loss: 0.070
[5,   150] loss: 0.068
[5,   160] loss: 0.072
[5,   170] loss: 0.071




[5,   180] loss: 0.072
[5,   190] loss: 0.069
[5,   200] loss: 0.070
[5,   210] loss: 0.068
[5,   220] loss: 0.064
[5,   230] loss: 0.076
[5,   240] loss: 0.060
[5,   250] loss: 0.068
[5,   260] loss: 0.067
[5,   270] loss: 0.073
[5,   280] loss: 0.077
[5,   290] loss: 0.066
[5,   300] loss: 0.068
[5,   310] loss: 0.074
[5,   320] loss: 0.069
[5,   330] loss: 0.070
[5,   340] loss: 0.071
[5,   350] loss: 0.071
[5,   360] loss: 0.072




[5,   370] loss: 0.073
[5,   380] loss: 0.074
[5,   390] loss: 0.073
[5,   400] loss: 0.069
[5,   410] loss: 0.068
[5,   420] loss: 0.068
[5,   430] loss: 0.072
[5,   440] loss: 0.071
[5,   450] loss: 0.069
[5,   460] loss: 0.072
[5,   470] loss: 0.070
[5,   480] loss: 0.070
[5,   490] loss: 0.070
[5,   500] loss: 0.071
[5,   510] loss: 0.070
[5,   520] loss: 0.070
[5,   530] loss: 0.070
[5,   540] loss: 0.067
[5,   550] loss: 0.072




[5,   560] loss: 0.068
[5,   570] loss: 0.069
[5,   580] loss: 0.060
[5,   590] loss: 0.063
[5,   600] loss: 0.062
[5,   610] loss: 0.070
[5,   620] loss: 0.076
[5,   630] loss: 0.080
[5,   640] loss: 0.070
[5,   650] loss: 0.069
[5,   660] loss: 0.070
[5,   670] loss: 0.071
[5,   680] loss: 0.067
[5,   690] loss: 0.060
[5,   700] loss: 0.075




[5,   710] loss: 0.061
[5,   720] loss: 0.061
[5,   730] loss: 0.080
[5,   740] loss: 0.077
[5,   750] loss: 0.077
[5,   760] loss: 0.071
[5,   770] loss: 0.068
[5,   780] loss: 0.074
[5,   790] loss: 0.070
[5,   800] loss: 0.070
[5,   810] loss: 0.070
[5,   820] loss: 0.070
[5,   830] loss: 0.070
[5,   840] loss: 0.067
[5,   850] loss: 0.070
[5,   860] loss: 0.066
[5,   870] loss: 0.076
[5,   880] loss: 0.071
[5,   890] loss: 0.066
[5,   900] loss: 0.071
[5,   910] loss: 0.073




[5,   920] loss: 0.068
[5,   930] loss: 0.071


100%|██████████| 1000/1000 [00:00<00:00, 1841.91it/s]


[5,   940] loss: 0.069
[5,   950] loss: 0.065
[5,   960] loss: 0.075
[5,   970] loss: 0.072
[5,   980] loss: 0.071
[5,   990] loss: 0.070
[5,  1000] loss: 0.068
Epoch 5 | Train Loss: 35.0154


100%|██████████| 100/100 [00:00<00:00, 4914.01it/s]
 62%|██████▎   | 5/8 [00:02<00:01,  1.74it/s]

Epoch 5 | Eval Loss: 0.7024




[6,    10] loss: 0.065
[6,    20] loss: 0.078
[6,    30] loss: 0.070
[6,    40] loss: 0.070
[6,    50] loss: 0.070
[6,    60] loss: 0.072
[6,    70] loss: 0.070
[6,    80] loss: 0.070
[6,    90] loss: 0.065
[6,   100] loss: 0.073
[6,   110] loss: 0.070
[6,   120] loss: 0.074
[6,   130] loss: 0.070
[6,   140] loss: 0.069
[6,   150] loss: 0.070
[6,   160] loss: 0.074




[6,   170] loss: 0.069
[6,   180] loss: 0.070
[6,   190] loss: 0.069
[6,   200] loss: 0.067
[6,   210] loss: 0.070
[6,   220] loss: 0.071
[6,   230] loss: 0.063
[6,   240] loss: 0.059
[6,   250] loss: 0.087
[6,   260] loss: 0.067
[6,   270] loss: 0.073
[6,   280] loss: 0.070
[6,   290] loss: 0.068
[6,   300] loss: 0.071
[6,   310] loss: 0.070
[6,   320] loss: 0.071
[6,   330] loss: 0.070
[6,   340] loss: 0.071
[6,   350] loss: 0.070
[6,   360] loss: 0.070




[6,   370] loss: 0.070
[6,   380] loss: 0.069
[6,   390] loss: 0.070
[6,   400] loss: 0.070
[6,   410] loss: 0.070
[6,   420] loss: 0.069
[6,   430] loss: 0.066
[6,   440] loss: 0.061
[6,   450] loss: 0.073
[6,   460] loss: 0.077
[6,   470] loss: 0.071
[6,   480] loss: 0.071
[6,   490] loss: 0.070
[6,   500] loss: 0.072
[6,   510] loss: 0.070
[6,   520] loss: 0.070
[6,   530] loss: 0.069




[6,   540] loss: 0.069
[6,   550] loss: 0.068
[6,   560] loss: 0.066
[6,   570] loss: 0.073
[6,   580] loss: 0.072
[6,   590] loss: 0.071
[6,   600] loss: 0.070
[6,   610] loss: 0.070
[6,   620] loss: 0.070
[6,   630] loss: 0.069
[6,   640] loss: 0.068
[6,   650] loss: 0.071
[6,   660] loss: 0.068
[6,   670] loss: 0.068
[6,   680] loss: 0.068
[6,   690] loss: 0.072
[6,   700] loss: 0.072




[6,   710] loss: 0.066
[6,   720] loss: 0.068
[6,   730] loss: 0.071
[6,   740] loss: 0.071
[6,   750] loss: 0.073
[6,   760] loss: 0.071
[6,   770] loss: 0.070
[6,   780] loss: 0.070
[6,   790] loss: 0.070
[6,   800] loss: 0.070
[6,   810] loss: 0.069




[6,   820] loss: 0.070
[6,   830] loss: 0.068
[6,   840] loss: 0.068
[6,   850] loss: 0.068
[6,   860] loss: 0.072
[6,   870] loss: 0.071
[6,   880] loss: 0.070
[6,   890] loss: 0.067
[6,   900] loss: 0.072
[6,   910] loss: 0.070


100%|██████████| 1000/1000 [00:00<00:00, 1673.76it/s]


[6,   920] loss: 0.071
[6,   930] loss: 0.069
[6,   940] loss: 0.068
[6,   950] loss: 0.071
[6,   960] loss: 0.068
[6,   970] loss: 0.068
[6,   980] loss: 0.073
[6,   990] loss: 0.070
[6,  1000] loss: 0.067
Epoch 6 | Train Loss: 34.9400


100%|██████████| 100/100 [00:00<00:00, 4820.43it/s]
 75%|███████▌  | 6/8 [00:03<00:01,  1.69it/s]

Epoch 6 | Eval Loss: 0.7030




[7,    10] loss: 0.079
[7,    20] loss: 0.070
[7,    30] loss: 0.069
[7,    40] loss: 0.066
[7,    50] loss: 0.074
[7,    60] loss: 0.072
[7,    70] loss: 0.069
[7,    80] loss: 0.072
[7,    90] loss: 0.070
[7,   100] loss: 0.073
[7,   110] loss: 0.073
[7,   120] loss: 0.069
[7,   130] loss: 0.072




[7,   140] loss: 0.071
[7,   150] loss: 0.070
[7,   160] loss: 0.073
[7,   170] loss: 0.071
[7,   180] loss: 0.066
[7,   190] loss: 0.064
[7,   200] loss: 0.075
[7,   210] loss: 0.073
[7,   220] loss: 0.072
[7,   230] loss: 0.070
[7,   240] loss: 0.069
[7,   250] loss: 0.070
[7,   260] loss: 0.070
[7,   270] loss: 0.069
[7,   280] loss: 0.071
[7,   290] loss: 0.070
[7,   300] loss: 0.070
[7,   310] loss: 0.070
[7,   320] loss: 0.069
[7,   330] loss: 0.071
[7,   340] loss: 0.069
[7,   350] loss: 0.069
[7,   360] loss: 0.069
[7,   370] loss: 0.070




[7,   380] loss: 0.070
[7,   390] loss: 0.069
[7,   400] loss: 0.069
[7,   410] loss: 0.070
[7,   420] loss: 0.069
[7,   430] loss: 0.069
[7,   440] loss: 0.072
[7,   450] loss: 0.070
[7,   460] loss: 0.069
[7,   470] loss: 0.069
[7,   480] loss: 0.069
[7,   490] loss: 0.070
[7,   500] loss: 0.070
[7,   510] loss: 0.069




[7,   520] loss: 0.068
[7,   530] loss: 0.070
[7,   540] loss: 0.070
[7,   550] loss: 0.071
[7,   560] loss: 0.067
[7,   570] loss: 0.068
[7,   580] loss: 0.071
[7,   590] loss: 0.063
[7,   600] loss: 0.068
[7,   610] loss: 0.071
[7,   620] loss: 0.070
[7,   630] loss: 0.073
[7,   640] loss: 0.068
[7,   650] loss: 0.067
[7,   660] loss: 0.070
[7,   670] loss: 0.068
[7,   680] loss: 0.075
[7,   690] loss: 0.074
[7,   700] loss: 0.071
[7,   710] loss: 0.071
[7,   720] loss: 0.069
[7,   730] loss: 0.070
[7,   740] loss: 0.069
[7,   750] loss: 0.069




[7,   760] loss: 0.069
[7,   770] loss: 0.068
[7,   780] loss: 0.070
[7,   790] loss: 0.071
[7,   800] loss: 0.070
[7,   810] loss: 0.068
[7,   820] loss: 0.067
[7,   830] loss: 0.071
[7,   840] loss: 0.071
[7,   850] loss: 0.070
[7,   860] loss: 0.070
[7,   870] loss: 0.069
[7,   880] loss: 0.069
[7,   890] loss: 0.068


100%|██████████| 1000/1000 [00:00<00:00, 1850.35it/s]


[7,   900] loss: 0.071
[7,   910] loss: 0.070
[7,   920] loss: 0.072
[7,   930] loss: 0.070
[7,   940] loss: 0.069
[7,   950] loss: 0.070
[7,   960] loss: 0.070
[7,   970] loss: 0.069
[7,   980] loss: 0.069
[7,   990] loss: 0.070
[7,  1000] loss: 0.070
Epoch 7 | Train Loss: 34.9832


100%|██████████| 100/100 [00:00<00:00, 4848.01it/s]
 88%|████████▊ | 7/8 [00:04<00:00,  1.72it/s]

Epoch 7 | Eval Loss: 0.6925




[8,    10] loss: 0.070
[8,    20] loss: 0.069
[8,    30] loss: 0.071
[8,    40] loss: 0.068
[8,    50] loss: 0.071
[8,    60] loss: 0.070
[8,    70] loss: 0.068
[8,    80] loss: 0.065
[8,    90] loss: 0.071
[8,   100] loss: 0.070
[8,   110] loss: 0.075
[8,   120] loss: 0.068
[8,   130] loss: 0.069
[8,   140] loss: 0.070
[8,   150] loss: 0.072




[8,   160] loss: 0.071
[8,   170] loss: 0.069
[8,   180] loss: 0.069
[8,   190] loss: 0.068
[8,   200] loss: 0.074
[8,   210] loss: 0.069
[8,   220] loss: 0.070
[8,   230] loss: 0.068
[8,   240] loss: 0.069
[8,   250] loss: 0.072
[8,   260] loss: 0.070
[8,   270] loss: 0.069
[8,   280] loss: 0.070
[8,   290] loss: 0.068
[8,   300] loss: 0.074
[8,   310] loss: 0.067
[8,   320] loss: 0.071
[8,   330] loss: 0.070
[8,   340] loss: 0.069
[8,   350] loss: 0.066
[8,   360] loss: 0.068




[8,   370] loss: 0.072
[8,   380] loss: 0.062
[8,   390] loss: 0.071
[8,   400] loss: 0.076
[8,   410] loss: 0.066
[8,   420] loss: 0.068
[8,   430] loss: 0.072
[8,   440] loss: 0.066
[8,   450] loss: 0.076
[8,   460] loss: 0.073
[8,   470] loss: 0.070
[8,   480] loss: 0.069
[8,   490] loss: 0.070
[8,   500] loss: 0.070
[8,   510] loss: 0.069
[8,   520] loss: 0.070
[8,   530] loss: 0.069




[8,   540] loss: 0.071
[8,   550] loss: 0.069
[8,   560] loss: 0.069
[8,   570] loss: 0.070
[8,   580] loss: 0.070
[8,   590] loss: 0.069
[8,   600] loss: 0.069
[8,   610] loss: 0.070
[8,   620] loss: 0.069
[8,   630] loss: 0.069
[8,   640] loss: 0.070
[8,   650] loss: 0.069
[8,   660] loss: 0.068
[8,   670] loss: 0.072
[8,   680] loss: 0.071
[8,   690] loss: 0.069
[8,   700] loss: 0.068
[8,   710] loss: 0.069
[8,   720] loss: 0.068
[8,   730] loss: 0.070
[8,   740] loss: 0.070




[8,   750] loss: 0.069
[8,   760] loss: 0.071
[8,   770] loss: 0.068
[8,   780] loss: 0.067
[8,   790] loss: 0.066
[8,   800] loss: 0.069
[8,   810] loss: 0.076
[8,   820] loss: 0.072
[8,   830] loss: 0.068
[8,   840] loss: 0.070
[8,   850] loss: 0.070
[8,   860] loss: 0.067
[8,   870] loss: 0.065
[8,   880] loss: 0.073
[8,   890] loss: 0.069
[8,   900] loss: 0.070




[8,   910] loss: 0.070
[8,   920] loss: 0.068
[8,   930] loss: 0.073
[8,   940] loss: 0.069
[8,   950] loss: 0.070
[8,   960] loss: 0.070
[8,   970] loss: 0.070
[8,   980] loss: 0.069


100%|██████████| 1000/1000 [00:00<00:00, 1837.94it/s]


[8,   990] loss: 0.068
[8,  1000] loss: 0.070
Epoch 8 | Train Loss: 34.7998


100%|██████████| 100/100 [00:00<00:00, 3967.00it/s]
100%|██████████| 8/8 [00:04<00:00,  1.69it/s]

Epoch 8 | Eval Loss: 0.6881


                                             

0.6959957875907421

# Questions
- Do we pass through the sequence multiple times, each time masking different tokens or do we pass through the sequence once masking one random token?
- Maybe we evaluate on how well the model predicts the other token in the sequence?

## Previous implementations

In [198]:
class VanillaAttentionTransformer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins=3, dropout_rate=0.0):
        super(VanillaAttentionTransformer, self).__init__()
        self.word_embeddings = nn.Linear(num_spins, embed_dim)
        self.position_embeddings = nn.Embedding(max_seq_length, embed_dim)
        self.a = a # parameter controlling how important are positions
        self.value_weight = nn.Linear(embed_dim, embed_dim)
        self.query_weight = nn.Linear(embed_dim, embed_dim)
        self.key_weight = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, num_spins) # output layer
        self.dropout = nn.Dropout(dropout_rate)
    
    def forward(self, s, masked_token):
        position_ids = torch.arange(len(s), dtype=torch.long)
        position_ids = position_ids
        s = torch.tensor(s, dtype=torch.float)
        masked_token = torch.tensor(masked_token, dtype=torch.float)
        #print("the sequence with the masked token:", s)
        print("shape of position ids before embedding:", position_ids.shape)
        print("shape the sequence with the masked token:", s.shape)
        print("value of the word embedding:", self.word_embeddings(s).shape)
        print("positional embedding:", self.position_embeddings(position_ids))
        x = self.word_embeddings(s) + self.position_embeddings(position_ids)
        query = self.query_weight(x)
        key = self.key_weight(x)

        values = self.value_weight(self.word_embeddings(s) + self.a*self.position_embeddings(position_ids)) # (batch_size,embed_dim)
        exp_scaling = torch.exp(self.word_embeddings(masked_token) + self.position_embeddings(position_ids).T@query.T@key@(self.word_embeddings(s) + self.a*self.position_embeddings(position_ids)))
        attn_output = torch.sum(exp_scaling/torch.sum(exp_scaling.sum(0))*values) # not sure it multiplies the way I want it to multiply - check
        output = self.dropout(self.fc(attn_output))
        return output
    
class FactoredAttentionTransformer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins=3, dropout_rate=0.0):
        super(FactoredAttentionTransformer, self).__init__()
        self.word_embeddings = nn.Embedding(num_spins, embed_dim)
        self.word_embedding.weight.requires_grad = False
        self.position_embeddings = nn.Embedding(max_seq_length, embed_dim)
        self.a = a # parameter controlling how important are positions
        self.value_weight = nn.Linear(embed_dim, embed_dim)
        self.query_weight = nn.Linear(embed_dim, embed_dim)
        self.key_weight = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, num_spins) # output layer
        self.dropout = nn.Dropout(dropout_rate)
    
    def forward(self, s, masked_token):
        # masked token should be equal to 0
        # masked_token = torch.tensor([0])
        position_ids = torch.arange(len(s), dtype=torch.long)
        position_ids = position_ids
        x = self.word_embeddings(s) + self.position_embeddings(position_ids)
        query = self.query_weight(x)
        key = self.key_weight(x)

        values = self.value_weight(self.word_embeddings(s)) # (embed_dim)
        exp_scaling = torch.exp(self.word_embeddings(masked_token) + self.position_embeddings(position_ids).T@query.T@key@(self.word_embeddings(s) + self.a*self.position_embeddings(position_ids)))
        attn_output = torch.sum(exp_scaling/torch.sum(exp_scaling.sum(0))*values) # not sure it multiplies the way I want it to multiply - check
        output = self.fc(self.dropout(attn_output))
        return output