In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import numpy as np
import tqdm
import random
import torch.nn.functional as F

# for tensorboard
from torch.utils.tensorboard import SummaryWriter
#writer = SummaryWriter()

# Experiment with cross attention
In this experiment, we sample data not only from uniform distribution but from variety of distributions. We then use this data for encoder and feed its output throught cross-attention mechanism to the decoder. The decoder takes the sequence with missing tokens and tries to fill in the gaps.

### Architecture
- We use multihead attention for the encoder and vanilla attention for the decoder.
- In the encoder, before splitting heads I "merge" the 5 sequences through linear layer.

In [48]:
cd semester-project

/Users/mariayuffa/semester-project


In [72]:
# Sequences for decoder
final_chains_train = np.load('final_chains_T=1_num_iters=400_train.npy')
print("Loaded train sequences of proteins sampled from Boltzmann distribution:", final_chains_train.shape)

final_chains_test = np.load('final_chains_T=1_num_iters=400_test.npy')
print("Loaded test sequences of proteins sampled from Boltzmann distribution:", final_chains_test.shape)


# Sequences for encoder
k = 0
final_chains_encoder_train = np.zeros((1000,5,200))
final_chains_encoder_train[:,0,:] = np.load('final_chains_T=1_num_iters=400_J=10_test.npy') 
final_chains_encoder_test = np.zeros((1000,5,200))
final_chains_encoder_test[:,0,:] = np.load('final_chains_T=1_num_iters=400_J=10_test.npy') 
for i in range(20,100,20):
    k+=1
    final_chains_encoder_train[:,k,:] = np.load('final_chains_T=1_num_iters=400_J='+str(i)+'_train.npy')
    final_chains_encoder_test[:,k,:] = np.load('final_chains_T=1_num_iters=400_J='+str(i)+'_test.npy') 

print("Loaded train sequences of proteins for encoder distribution:", final_chains_encoder_train.shape)
print("Loaded test sequences of proteins for encoder distribution:", final_chains_encoder_test.shape)

tensor_samples_train = torch.tensor(final_chains_encoder_train, dtype=torch.float32) 
tensor_samples_test = torch.tensor(final_chains_encoder_test, dtype=torch.float32) 

Loaded train sequences of proteins sampled from Boltzmann distribution: (1000, 200)
Loaded test sequences of proteins sampled from Boltzmann distribution: (1000, 200)
Loaded train sequences of proteins for encoder distribution: (1000, 5, 200)
Loaded test sequences of proteins for encoder distribution: (1000, 5, 200)


In [73]:
# Define the parameters for different distributions
distributions = [{"type": "normal", "mean": 0, "std": 1},
        {"type": "uniform", "low": -1, "high": 1},
        {"type": "exponential", "scale": 1},
        {"type": "gamma", "scale": 1},
        {"type": "poisson", "lam": 1.0}]

def sample_data(num_samples, distributions, len_of_seq):
    # Define the parameters for different distributions
    
    # Initialize the tensor to store the samples
    all_samples = torch.zeros(len(distributions), num_samples, len_of_seq)
    
    # Sample data from each distribution
    for i in range(len(distributions)):
        distr = distributions[i]
        if distr["type"] == "normal":
            samples = np.random.normal(distr["mean"], distr["std"], (num_samples, len_of_seq))
            samples = np.where(samples >= 0, 1, -1)
        elif distr["type"] == "uniform":
            samples = np.random.uniform(distr["low"], distr["high"], (num_samples, len_of_seq))
            samples = np.where(samples >= 0, 1, -1)
        elif distr["type"] == "exponential":
            samples = np.random.exponential(distr["scale"], (num_samples, len_of_seq))
            samples = np.where(samples >= distr["scale"], 1, -1)
        elif distr["type"] == "gamma":
            samples = np.random.poisson(distr["scale"], (num_samples, len_of_seq))
            samples = np.where(samples >= np.mean(samples), 1, -1)
        elif distr["type"] == "poisson":
            samples = np.random.poisson(distr["lam"], (num_samples, len_of_seq))
            samples = np.where(samples >= np.mean(samples), 1, -1)
        else:
            raise ValueError("Unsupported distribution type")
        
        # Store the samples in the tensor
        all_samples[i] = torch.tensor(samples, dtype=torch.float32)
    
    return all_samples.reshape(num_samples, len(distributions), len_of_seq)

# Example usage
#tensor_samples_train = sample_data(1000, distributions, 200)  # 5 distributions, 1000 samples each, length of sequence 10
#print(tensor_samples_train.shape)

#tensor_samples_test = sample_data(1000, distributions, 200)  # 5 distributions, 1000 samples each, length of sequence 10
#print(tensor_samples_test.shape)

In [74]:
#data_train = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(1000)])
#data_train_dec = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(1000)])
#data_test = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(400)])
#data_test_dec = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(400)])

def one_hot_encoding(seq, vocab):
    if isinstance(seq, int):
        one_hot = torch.zeros(1, len(vocab), dtype=int)
        one_hot[:, vocab[seq]] = 1
    else:
        one_hot = torch.zeros((len(seq), len(vocab)), dtype=int)
        for i, spin in enumerate(seq):
            #print("spins:",spin)
            #print("vocab:",vocab)
            one_hot[i, vocab[spin]] = 1
    return one_hot

def mask_random_spins(sequence, vocab, mask_token=2):
    """
    Mask one random spin in a sequence of protein spins.
    
    Parameters:
    - sequence: a list or sequence of spins (integers)
    - mask_token: the token used to mask a spin (default is 2)
    
    Returns:
    - masked_sequence: a sequence similar to the input but with one spin masked
    - masked_position: the position of the spin that was masked
    """
    # Ensure the sequence can be converted to a list for masking
    sequence_list = sequence.numpy().tolist() if isinstance(sequence, torch.Tensor) else list(sequence)
    # Choose a random position to mask, excluding the first spin
    mask_positions = random.sample(range(len(sequence)), 10)
    
    # Mask the chosen position
    masked_sequence = sequence_list.copy()
    for mask_position in mask_positions:
        masked_sequence[mask_position] = mask_token

    # Create an array of zeros with shape (len(sequence), len(vocab))
    one_hot = one_hot_encoding(masked_sequence, vocab)
        
    return torch.tensor(one_hot, dtype=torch.float), torch.tensor(mask_position)
    

In [75]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = embed_dim
        self.num_heads = num_heads
        self.d_k = embed_dim // num_heads
        
        self.W_q = nn.Linear(embed_dim, embed_dim) # Linear layer for query 
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim) # Output layer to ensure that dimensionality matches the model's expected dimensionality
        
        self.combine_heads = nn.Linear(num_heads, 1)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, embed_dim = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
     
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        #print("x somewhere:", x.shape)
        #print("seqeunce length:",seq_length)
        #print("d model dim:",self.d_model)
        return x.transpose(1, 2).contiguous().view(batch_size,seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        #print("shape of Q before splitting heads:", Q.shape)
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        #print("shape of Q after splitting heads:", Q.shape)
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        #print(" before W_o:", self.combine_heads(attn_output.permute(1,2,0)).squeeze(-1).shape)
        output = self.W_o(self.combine_heads(attn_output))
        return output
    
class VanillaAttention(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins=3, dropout_rate=0.0):
        super(VanillaAttention, self).__init__()
        self.word_embeddings = nn.Linear(num_spins, embed_dim)
        self.position_embeddings = nn.Embedding(max_seq_length, embed_dim)
        self.a = a  # parameter controlling how important positions are
        self.value_weight = nn.Linear(embed_dim, embed_dim)
        self.query_weight = nn.Linear(embed_dim, embed_dim)
        self.key_weight = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)  # output layer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, s, enc_output):

        position_ids = torch.arange(s.size(0), dtype=torch.long)
        #x = self.word_embeddings(s) + self.a*self.position_embeddings(position_ids)
        x = s + self.a*self.position_embeddings(position_ids)
        
        query = self.query_weight(x)
        key = self.key_weight(enc_output)
        values = self.value_weight(enc_output)
        
        # Simple attention score calculation (Dot product): this is equivalent to the interaction matrix
        scores = torch.matmul(query, key.transpose(-2, -1))  # Transpose last two dimensions for matrix multiplication
        scores = torch.softmax(scores, dim=-1)  # Apply softmax to scores to get probabilities

        # Apply attention scores to values
        attn_output = torch.matmul(scores, values)

        # Sum over the sequence length dimensions
        #attn_output = attn_output.sum(dim=1)
        output = self.fc(self.dropout(attn_output)) # should have size (20,3)

        return output

In [76]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [77]:
class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, proj_layer_dim, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(embed_dim=embed_dim, num_heads=len(distributions))
        self.feed_forward = PositionWiseFeedForward(d_model=embed_dim, d_ff=proj_layer_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(len(distributions),1)
        
    def forward(self, x):
        x = self.fc(x.permute(1,2,0)).permute(2,0,1)
        attn_output = self.self_attn(x, x, x)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
    
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = VanillaAttention(embed_dim=embed_dim, a=a, max_seq_length=max_seq_length, num_spins=num_spins, dropout_rate=dropout)# masking one of the word in the sequence
        self.cross_attn = VanillaAttention(embed_dim=embed_dim, a=a, max_seq_length=max_seq_length, num_spins=num_spins, dropout_rate=dropout)
        self.feed_forward = PositionWiseFeedForward(d_model=embed_dim, d_ff=proj_layer_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output):
        enc_output = enc_output.squeeze(0)
        attn_output = self.self_attn(x, x)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        x = x.sum(dim=0)
        return x

In [78]:
class Transformer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout):
        super(Transformer, self).__init__()
        #self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        #self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.word_embeddings = nn.Linear(num_spins, embed_dim)

        #self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        #self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.encoder_layer = EncoderLayer(embed_dim, proj_layer_dim, dropout)
        self.decoder_layer = DecoderLayer(embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout)
        self.fc = nn.Linear(embed_dim, num_spins)

    def forward(self, src, tgt):
        #src_mask, tgt_mask = self.generate_mask(src, tgt)
        #src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        #tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
        src_embedded = self.word_embeddings(src)
        tgt_embedded = self.word_embeddings(tgt)
        enc_output = self.encoder_layer(src_embedded)
        dec_output = self.decoder_layer(tgt_embedded, enc_output)
        output = self.fc(dec_output)
        #print("output of the transformer:", output.shape)
        return output

## Training and validation

In [82]:
writer = SummaryWriter('runs/data_exp_distr_run_1')

def evaluate(model, data_test, data_test_dec, vocab, criterion, device=0):
    model.eval()
    epoch_loss = 0
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    for i, data in tqdm.tqdm(enumerate(data_test), total=len(data_test)):
        # Get the inputs
        input_seq_enc = data
        input_seq_dec = data_test_dec[i]
        
        input_encoder_one_hot = torch.stack([one_hot_encoding(input_seq_enc[i].tolist(), vocab) for i in range(len(input_seq_enc))], dim=0)
        input_encoder_one_hot = torch.tensor(input_encoder_one_hot, dtype=torch.float)

        input_decoder_one_hot = one_hot_encoding(input_seq_dec.tolist(), vocab)
        # mask a token
        masked_sequence_dec, position = mask_random_spins(input_seq_dec, vocab, mask_token=2)
        
        # Forward pass
        outputs = model.forward(input_encoder_one_hot, masked_sequence_dec)

        #output_token = F.log_softmax(outputs, dim=-1)
        target_token = torch.argwhere(input_decoder_one_hot[position]==1).squeeze(0) #target_token = input_seq[position]

        # Compute loss
        loss = criterion(outputs.unsqueeze(0), target_token)
        epoch_loss += loss.item()
    #model.train()
    return epoch_loss / len(data_test)

def train(model, data_train, data_train_dec, data_test, data_test_dec, vocab, optimizer, criterion, num_epochs=8, device=0):
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # Training loop
    model.train()
    best_eval_loss = 1e-3 # used to do early stopping

    for epoch in tqdm.tqdm(range(num_epochs), leave=False, position=0):
        running_loss = 0
        epoch_loss = 0
        
        for i, data in tqdm.tqdm(enumerate(data_train), total=len(data_train)):
            # Get the inputs
            input_seq_enc = data
            input_seq_dec = data_train_dec[i]

            input_encoder_one_hot = torch.stack([one_hot_encoding(input_seq_enc[i].tolist(), vocab) for i in range(len(input_seq_enc))], dim=0)
            input_encoder_one_hot = torch.tensor(input_encoder_one_hot, dtype=torch.float)

            input_decoder_one_hot = one_hot_encoding(input_seq_dec.tolist(), vocab)

            # mask a token in decoder
            masked_sequence_dec, position = mask_random_spins(input_seq_dec, vocab, mask_token=2)

            # Forward pass
            prediction = model.forward(input_encoder_one_hot, masked_sequence_dec) #masked_sequence[masked_position]
            
            #predicted_token = F.log_softmax(prediction, dim=-1)
            target_token = torch.argwhere(input_decoder_one_hot[position]==1).squeeze(0) #input_seq[masked_position]
            
            # Compute loss
            #print("model prediction:", prediction.shape)
            #print("target:", target_token.shape)
            loss = criterion(prediction.unsqueeze(0), target_token)
            epoch_loss += loss.item()
            
            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 10 == 9. :    # print every 10 mini-batches
                writer.add_scalar("Running Loss", running_loss / 100, epoch)
                #print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0

        print(f'Epoch {epoch + 1} | Train Loss: {(epoch_loss / len(data)):.4f}')
        writer.add_scalar("Train Loss", epoch_loss / len(data), epoch)
        eval_loss = evaluate(model, data_test, data_test_dec, vocab, criterion, device=device)
        writer.add_scalar("Eval Loss", eval_loss, epoch)
        print(f'Epoch {epoch + 1} | Eval Loss: {(eval_loss):.4f}')
        
        # Perform early stopping based on eval loss
        if eval_loss < best_eval_loss:
            return epoch_loss / len(data_train)
    return epoch_loss / len(data_train)

writer.flush()
writer.close()

In [83]:
# Define the parameters 
vocab_size = 3
vocab = {-1:0,1:1,2:2} 
L = 200
embedding_dim = 200
proj_layer_dim = 128
hidden_dim = 200
num_layers = 1 # have to adapt the model for 2 and 3 layers
dropout_rate = 0.0
lr = 1e-3
num_sequences = 1000
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

mps


In [84]:
# Example usage:

model = Transformer(embed_dim=embedding_dim, a=0, max_seq_length=L, num_spins=3, proj_layer_dim=128, dropout=dropout_rate)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train(model, tensor_samples_train, final_chains_train, tensor_samples_test, final_chains_test, vocab, optimizer, criterion, device=device)
#torch.save(model.state_dict(), 'models/lstm_scratch.pt')
#evaluate(model, test_dataloader, criterion, device=device)

  0%|          | 0/8 [00:00<?, ?it/s]

  input_encoder_one_hot = torch.tensor(input_encoder_one_hot, dtype=torch.float)
  return torch.tensor(one_hot, dtype=torch.float), torch.tensor(mask_position)
100%|██████████| 1000/1000 [00:15<00:00, 64.53it/s]


Epoch 1 | Train Loss: 2864.2012


  input_encoder_one_hot = torch.tensor(input_encoder_one_hot, dtype=torch.float)
100%|██████████| 1000/1000 [00:07<00:00, 142.03it/s]
 12%|█▎        | 1/8 [00:22<02:37, 22.54s/it]

Epoch 1 | Eval Loss: 1.5495


100%|██████████| 1000/1000 [00:13<00:00, 74.71it/s]


Epoch 2 | Train Loss: 630.8991


100%|██████████| 1000/1000 [00:06<00:00, 151.56it/s]
 25%|██▌       | 2/8 [00:42<02:06, 21.04s/it]

Epoch 2 | Eval Loss: 2.4862


100%|██████████| 1000/1000 [00:14<00:00, 68.80it/s]


Epoch 3 | Train Loss: 389.6560


100%|██████████| 1000/1000 [00:07<00:00, 132.32it/s]
 38%|███▊      | 3/8 [01:04<01:47, 21.52s/it]

Epoch 3 | Eval Loss: 1.0908


100%|██████████| 1000/1000 [00:15<00:00, 63.07it/s]


Epoch 4 | Train Loss: 213.7412


100%|██████████| 1000/1000 [00:07<00:00, 141.26it/s]
 50%|█████     | 4/8 [01:27<01:28, 22.08s/it]

Epoch 4 | Eval Loss: 0.9384


100%|██████████| 1000/1000 [00:13<00:00, 73.96it/s]


Epoch 5 | Train Loss: 166.3436


100%|██████████| 1000/1000 [00:06<00:00, 148.28it/s]
 62%|██████▎   | 5/8 [01:47<01:04, 21.43s/it]

Epoch 5 | Eval Loss: 1.6429


100%|██████████| 1000/1000 [00:13<00:00, 74.12it/s]


Epoch 6 | Train Loss: 157.8172


100%|██████████| 1000/1000 [00:06<00:00, 147.68it/s]
 75%|███████▌  | 6/8 [02:08<00:42, 21.03s/it]

Epoch 6 | Eval Loss: 0.6804


100%|██████████| 1000/1000 [00:13<00:00, 75.24it/s]


Epoch 7 | Train Loss: 130.7992


100%|██████████| 1000/1000 [00:06<00:00, 151.75it/s]
 88%|████████▊ | 7/8 [02:27<00:20, 20.66s/it]

Epoch 7 | Eval Loss: 0.5222


100%|██████████| 1000/1000 [00:14<00:00, 71.11it/s]


Epoch 8 | Train Loss: 127.9613


100%|██████████| 1000/1000 [00:06<00:00, 147.66it/s]
                                             

Epoch 8 | Eval Loss: 0.5579




0.639806738367537

In [85]:
# To save only the decoder layer weights
torch.save(model.decoder_layer.state_dict(), 'model_decoder/decoder_weights.pth')

# If you need to load these weights later
decoder_weights = torch.load('model_decoder/decoder_weights.pth')
model.decoder_layer.load_state_dict(decoder_weights)

# Save the weights of the FC layer
torch.save(model.fc.state_dict(), 'model_decoder/transformer_fc_weights.pth')

# To load these weights back into the FC layer later
fc_weights = torch.load('model_decoder/transformer_fc_weights.pth')
model.fc.load_state_dict(fc_weights)

<All keys matched successfully>

# Ablation study
In this study we remove the encoder when testing the model

In [98]:
class TransformerAblated(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout):
        super(TransformerAblated, self).__init__()
        self.word_embeddings = nn.Linear(num_spins, embed_dim)
        self.decoder_layer = DecoderLayer(embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout)
        self.fc = nn.Linear(embed_dim, num_spins)

    def forward(self, tgt):
        tgt_embedded = self.word_embeddings(tgt)
        dec_output = self.decoder_layer(tgt_embedded, tgt_embedded)
        output = self.fc(dec_output)
        return output

In [100]:
# Create an instance of the new model
new_model = TransformerAblated(embed_dim=embedding_dim, a=0, max_seq_length=L, num_spins=3, proj_layer_dim=128, dropout=dropout_rate)

# Load the saved decoder weights
decoder_weights = torch.load('model_decoder/decoder_weights.pth')
new_model.decoder_layer.load_state_dict(decoder_weights)

# Load the saved FC weights
fc_weights = torch.load('model_decoder/transformer_fc_weights.pth')
new_model.fc.load_state_dict(fc_weights)


<All keys matched successfully>

In [103]:
writer = SummaryWriter('runs/transformer_ablation_run_3')

def evaluate(new_model, data_test, data_test_dec, vocab, criterion, device=0):
    new_model.eval()
    model.eval()
    epoch_loss = 0
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    for i, data in tqdm.tqdm(enumerate(data_test), total=len(data_test)):
        # Get the inputs
        input_seq_dec = data_test_dec[i]
        input_decoder_one_hot = one_hot_encoding(input_seq_dec.tolist(), vocab)
        # mask a token
        masked_sequence_dec, position = mask_random_spins(input_seq_dec, vocab, mask_token=2)
        
        # Forward pass
        outputs = new_model.forward(masked_sequence_dec)

        #output_token = F.log_softmax(outputs, dim=-1)
        target_token = torch.argwhere(input_decoder_one_hot[position]==1).squeeze(0) #target_token = input_seq[position]

        # Compute loss
        loss = criterion(outputs.unsqueeze(0), target_token)
        epoch_loss += loss.item()

    return epoch_loss / len(data_test)

def train(model, new_model, data_train, data_train_dec, data_test, data_test_dec, vocab, optimizer, criterion, num_epochs=20, device=0):
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # Training loop
    model.train()
    best_eval_loss = 1e-3 # used to do early stopping

    for epoch in tqdm.tqdm(range(num_epochs), leave=False, position=0):
        running_loss = 0
        epoch_loss = 0
        
        for i, data in tqdm.tqdm(enumerate(data_train), total=len(data_train)):
            # Get the inputs
            input_seq_enc = data
            input_seq_dec = data_train_dec[i]

            input_encoder_one_hot = torch.stack([one_hot_encoding(input_seq_enc[i].tolist(), vocab) for i in range(len(input_seq_enc))], dim=0)
            input_encoder_one_hot = torch.tensor(input_encoder_one_hot, dtype=torch.float)

            input_decoder_one_hot = one_hot_encoding(input_seq_dec.tolist(), vocab)

            # mask a token in decoder
            masked_sequence_dec, position = mask_random_spins(input_seq_dec, vocab, mask_token=2)

            # Forward pass
            prediction = model.forward(input_encoder_one_hot, masked_sequence_dec) #masked_sequence[masked_position]
            
            #predicted_token = F.log_softmax(prediction, dim=-1)
            target_token = torch.argwhere(input_decoder_one_hot[position]==1).squeeze(0) #input_seq[masked_position]
            
            # Compute loss
            #print("model prediction:", prediction.shape)
            #print("target:", target_token.shape)
            loss = criterion(prediction.unsqueeze(0), target_token)
            epoch_loss += loss.item()
            
            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 10 == 9. :    # print every 10 mini-batches
                writer.add_scalar("Running Loss", running_loss / 100, epoch)
                #print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0

        new_model.decoder_layer = model.decoder_layer
        new_model.fc = model.fc
        print(f'Epoch {epoch + 1} | Train Loss: {(epoch_loss / len(data)):.4f}')
        writer.add_scalar("Train Loss", epoch_loss / len(data), epoch)
        eval_loss = evaluate(new_model, data_test, data_test_dec, vocab, criterion, device=device)
        writer.add_scalar("Eval Loss", eval_loss, epoch)
        print(f'Epoch {epoch + 1} | Eval Loss: {(eval_loss):.4f}')
        
        # Perform early stopping based on eval loss
        if eval_loss < best_eval_loss:
            return epoch_loss / len(data_train)
    return epoch_loss / len(data_train)

writer.flush()
writer.close()

In [104]:
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train(model, new_model, tensor_samples_train, final_chains_train, tensor_samples_test, final_chains_test, vocab, optimizer, criterion, device=device)

# Define the parameters 
vocab_size = 3
vocab = {-1:0,1:1,2:2} 
L = 200
embedding_dim = 200
proj_layer_dim = 128
hidden_dim = 200
num_layers = 1 # have to adapt the model for 2 and 3 layers
dropout_rate = 0.0
lr = 1e-3
num_sequences = 1000
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

  input_encoder_one_hot = torch.tensor(input_encoder_one_hot, dtype=torch.float)
100%|██████████| 1000/1000 [00:15<00:00, 63.55it/s]


Epoch 1 | Train Loss: 111.3822


100%|██████████| 1000/1000 [00:03<00:00, 324.37it/s]
  5%|▌         | 1/20 [00:18<05:57, 18.82s/it]

Epoch 1 | Eval Loss: 57.9337


100%|██████████| 1000/1000 [00:16<00:00, 62.13it/s]


Epoch 2 | Train Loss: 107.1652


100%|██████████| 1000/1000 [00:02<00:00, 347.44it/s]
 10%|█         | 2/20 [00:37<05:40, 18.92s/it]

Epoch 2 | Eval Loss: 6.2594


100%|██████████| 1000/1000 [00:15<00:00, 63.73it/s]


Epoch 3 | Train Loss: 107.3647


100%|██████████| 1000/1000 [00:03<00:00, 324.13it/s]
 15%|█▌        | 3/20 [00:56<05:20, 18.85s/it]

Epoch 3 | Eval Loss: 3.6162


100%|██████████| 1000/1000 [00:15<00:00, 64.01it/s]


Epoch 4 | Train Loss: 113.5557


100%|██████████| 1000/1000 [00:03<00:00, 333.04it/s]
 20%|██        | 4/20 [01:15<05:00, 18.77s/it]

Epoch 4 | Eval Loss: 2.8574


100%|██████████| 1000/1000 [00:15<00:00, 64.37it/s]


Epoch 5 | Train Loss: 114.4541


100%|██████████| 1000/1000 [00:03<00:00, 325.59it/s]
 25%|██▌       | 5/20 [01:33<04:40, 18.71s/it]

Epoch 5 | Eval Loss: 2.3178


100%|██████████| 1000/1000 [00:14<00:00, 67.86it/s]


Epoch 6 | Train Loss: 110.8094


100%|██████████| 1000/1000 [00:03<00:00, 301.39it/s]
 30%|███       | 6/20 [01:51<04:18, 18.49s/it]

Epoch 6 | Eval Loss: 5.4813


100%|██████████| 1000/1000 [00:15<00:00, 65.96it/s]


Epoch 7 | Train Loss: 115.2452


100%|██████████| 1000/1000 [00:03<00:00, 331.95it/s]
 35%|███▌      | 7/20 [02:10<03:59, 18.39s/it]

Epoch 7 | Eval Loss: 1.5391


100%|██████████| 1000/1000 [00:15<00:00, 65.61it/s]


Epoch 8 | Train Loss: 116.9898


100%|██████████| 1000/1000 [00:02<00:00, 356.33it/s]
 40%|████      | 8/20 [02:28<03:39, 18.28s/it]

Epoch 8 | Eval Loss: 2.5127


100%|██████████| 1000/1000 [00:13<00:00, 73.98it/s]


Epoch 9 | Train Loss: 96.2396


100%|██████████| 1000/1000 [00:02<00:00, 357.96it/s]
 45%|████▌     | 9/20 [02:44<03:14, 17.67s/it]

Epoch 9 | Eval Loss: 3.0840


100%|██████████| 1000/1000 [00:13<00:00, 71.88it/s]


Epoch 10 | Train Loss: 116.8877


100%|██████████| 1000/1000 [00:02<00:00, 333.82it/s]
 50%|█████     | 10/20 [03:01<02:54, 17.43s/it]

Epoch 10 | Eval Loss: 1.3160


100%|██████████| 1000/1000 [00:13<00:00, 71.57it/s]


Epoch 11 | Train Loss: 103.1891


100%|██████████| 1000/1000 [00:03<00:00, 291.57it/s]
 55%|█████▌    | 11/20 [03:18<02:36, 17.43s/it]

Epoch 11 | Eval Loss: 1.8759


100%|██████████| 1000/1000 [00:15<00:00, 64.95it/s]


Epoch 12 | Train Loss: 105.6388


100%|██████████| 1000/1000 [00:03<00:00, 328.21it/s]
 60%|██████    | 12/20 [03:37<02:21, 17.74s/it]

Epoch 12 | Eval Loss: 1.6712


100%|██████████| 1000/1000 [00:15<00:00, 66.65it/s]


Epoch 13 | Train Loss: 115.3018


100%|██████████| 1000/1000 [00:03<00:00, 262.81it/s]
 65%|██████▌   | 13/20 [03:56<02:06, 18.06s/it]

Epoch 13 | Eval Loss: 1.6458


100%|██████████| 1000/1000 [00:14<00:00, 71.28it/s]


Epoch 14 | Train Loss: 114.2144


100%|██████████| 1000/1000 [00:02<00:00, 370.62it/s]
 70%|███████   | 14/20 [04:12<01:45, 17.66s/it]

Epoch 14 | Eval Loss: 2.2263


100%|██████████| 1000/1000 [00:13<00:00, 72.25it/s]


Epoch 15 | Train Loss: 111.1181


100%|██████████| 1000/1000 [00:02<00:00, 358.30it/s]
 75%|███████▌  | 15/20 [04:29<01:26, 17.35s/it]

Epoch 15 | Eval Loss: 1.9236


100%|██████████| 1000/1000 [00:14<00:00, 71.25it/s]


Epoch 16 | Train Loss: 103.5366


100%|██████████| 1000/1000 [00:02<00:00, 340.71it/s]
 80%|████████  | 16/20 [04:46<01:08, 17.24s/it]

Epoch 16 | Eval Loss: 2.3727


100%|██████████| 1000/1000 [00:14<00:00, 69.84it/s]


Epoch 17 | Train Loss: 108.6093


100%|██████████| 1000/1000 [00:03<00:00, 330.44it/s]
 85%|████████▌ | 17/20 [05:03<00:51, 17.27s/it]

Epoch 17 | Eval Loss: 1.7544


100%|██████████| 1000/1000 [00:14<00:00, 68.49it/s]


Epoch 18 | Train Loss: 126.6667


100%|██████████| 1000/1000 [00:03<00:00, 311.90it/s]
 90%|█████████ | 18/20 [05:21<00:34, 17.43s/it]

Epoch 18 | Eval Loss: 1.7751


100%|██████████| 1000/1000 [00:15<00:00, 62.58it/s]


Epoch 19 | Train Loss: 111.4818


100%|██████████| 1000/1000 [00:02<00:00, 340.29it/s]
 95%|█████████▌| 19/20 [05:40<00:17, 17.88s/it]

Epoch 19 | Eval Loss: 3.1513


100%|██████████| 1000/1000 [00:14<00:00, 70.57it/s]


Epoch 20 | Train Loss: 102.2865


100%|██████████| 1000/1000 [00:03<00:00, 317.43it/s]
                                               

Epoch 20 | Eval Loss: 2.0364
mps


