Write as a train loop:

In [3]:
%%time
import torch
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from transformers import EsmTokenizer, EsmForMaskedLM, AdamW

# Training configuration
num_epochs = 1
lr = 5e-5
max_len = 280
mask_prob = 0.15

# Initialize model and tokenizer
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name, cache_dir=".cache")
model = EsmForMaskedLM.from_pretrained(model_name, cache_dir=".cache")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Sample protein sequences (same as before)
sequences = [
    "RVQPTESIVRFPNITNLCPFGEVFNATRFSSVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYRYRLFRKSNLKPFERDISTEIYQAGSKPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNF",
    "RVQPTESIVRFPNITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNPAPFFTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGNIADYNYKLPDDFTGCVIAWNSNKLDSKVSGNYNYLYRLFRKSNLKPFERDISTEIYQAGNKPCNGVAGFNCYFPLRSYSFRPTYGVGHQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNF",
    "RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYRYRLFRKSNLKPFERDISTEIYQAGSKPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNF",
    "RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYRYRLFRKSNLKPFERDISTEIYQAGSKPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNF",
    "RVQPTESIVRFPNITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLAPFFTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGNIADYNYKLPDDFTGCVIAWNSNKLDSKVSGNYNYLYRLFRKSNLKPFERDISTEIYQAGNKPCNGVAGFNCYFPLRSYSFRPTYGVGHQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNF",
]   

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0.01)
aa_preds_tracker = {}

# Training loop
for epoch in range(num_epochs):
    model.train()
    aa_substitutions = []
    
    # Tokenize sequences fresh each epoch (simulating different batches)
    tokenized_seqs = tokenizer(sequences, return_tensors="pt", padding=True, truncation=True, max_length=max_len)
    tokenized_seqs = {k: v.to(device) for k, v in tokenized_seqs.items()}
    original_ids = tokenized_seqs["input_ids"]
    
    # Generate new mask for each epoch
    torch.manual_seed(0)
    rand = torch.rand(original_ids.shape, device=device)
    mask_arr = (rand < mask_prob) * \
               (original_ids != tokenizer.cls_token_id) * \
               (original_ids != tokenizer.eos_token_id) * \
               (original_ids != tokenizer.pad_token_id)
    
    masked_original_ids = original_ids.clone()
    masked_original_ids[mask_arr] = tokenizer.mask_token_id

    # Forward pass
    model.eval()
    optimizer.zero_grad()
    outputs = model(masked_original_ids, labels=original_ids)
    loss = outputs.loss
    preds = outputs.logits
    
    # Backward pass and optimize
    loss.backward()
    optimizer.step()
    
    # Calculate metrics
    predicted_ids = torch.argmax(preds, dim=-1)
    mask = (masked_original_ids == tokenizer.mask_token_id)
    
    original_tokens = original_ids[mask]
    predicted_tokens = predicted_ids[mask]
    correct = (original_tokens == predicted_tokens).sum().item()
    total = mask.sum().item()
    accuracy = (correct / total) * 100 if total > 0 else 0.0

    aa_keys = [f"{tokenizer.convert_ids_to_tokens(o.item())}->{tokenizer.convert_ids_to_tokens(p.item())}"
               for o, p in zip(original_tokens, predicted_tokens)]
    aa_substitutions.extend(aa_keys)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Loss: {loss.item():.4f} | Accuracy: {accuracy:.2f}%")
    aa_counter = Counter(aa_substitutions)
    aa_preds_tracker[epoch + 1] = aa_counter

df = pd.DataFrame.from_dict(aa_preds_tracker, orient='index').T

# Rename columns for clarity
df.columns = [f'epoch {i}' for i in df.columns]

# Reset index and rename the index column
df.rename_axis(index='expected_aa->predicted_aa', inplace=True)
df.reset_index(inplace=True)

# Fill NaN values with 0 and sort by total count
df = df.fillna(0)
df['total'] = df.iloc[:, 1:].sum(axis=1)
df = df.sort_values('total', ascending=False).drop('total', axis=1)

# Reset index again
df = df.reset_index(drop=True)
df


Epoch 1/1
Loss: 0.8311 | Accuracy: 7.78%
CPU times: user 69 ms, sys: 5.66 ms, total: 74.7 ms
Wall time: 332 ms


Unnamed: 0,expected_aa->predicted_aa,epoch 1
0,S->S,6
1,F->G,6
2,A->G,6
3,V->L,5
4,V->S,5
...,...,...
80,C->N,1
81,V->T,1
82,Q->L,1
83,H->D,1


Need to test if this is properly filtering out special tokens, when encountered:

In [15]:
from transformers import EsmTokenizer, EsmForMaskedLM
import torch

# Load model/tokenizer
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name, cache_dir=".cache")
model = EsmForMaskedLM.from_pretrained(model_name, cache_dir=".cache")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Get special tokens and a few amino acids
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
a_id = tokenizer.convert_tokens_to_ids("A")
c_id = tokenizer.convert_tokens_to_ids("C")

# Simulate model outputs (including special tokens)
original_tokens = torch.tensor([a_id, pad_id, c_id, eos_id], device=device)
predicted_tokens = torch.tensor([a_id, pad_id, a_id, pad_id], device=device)

# Calculate accuracy
correct = (original_tokens == predicted_tokens).sum().item()
total = original_tokens.numel()
accuracy = (correct / total) * 100

# Substitution keys (filtered)
aa_keys = [
    f"{tokenizer.convert_ids_to_tokens(o.item())}->{tokenizer.convert_ids_to_tokens(p.item())}"
    for o, p in zip(original_tokens, predicted_tokens)
]

print("Original:", tokenizer.convert_ids_to_tokens(original_tokens.tolist()))
print("Predicted:", tokenizer.convert_ids_to_tokens(predicted_tokens.tolist()))
print("Substitutions:", aa_keys)
print(f"Accuracy: {accuracy:.2f}%")

# Assertions
assert total == 2, "Only two valid amino acid substitutions should be considered"
assert correct == 1, "Only A->A is correct"
assert "A->A" in aa_keys
assert all("<pad>" not in k and "<eos>" not in k for k in aa_keys), "Special tokens should not be in substitutions"
print("✅ Test passed (no special tokens included).")

Original: ['A', '<pad>', 'C', '<eos>']
Predicted: ['A', '<pad>', 'A', '<pad>']
Substitutions: ['A->A', '<pad>-><pad>', 'C->A', '<eos>-><pad>']
Accuracy: 50.00%


AssertionError: Only two valid amino acid substitutions should be considered

Try filtering the special tokens out to only amino acids:

In [16]:
from transformers import EsmTokenizer, EsmForMaskedLM
import torch

# Load model/tokenizer
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name, cache_dir=".cache")
model = EsmForMaskedLM.from_pretrained(model_name, cache_dir=".cache")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Get special tokens and a few amino acids
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
a_id = tokenizer.convert_tokens_to_ids("A")
c_id = tokenizer.convert_tokens_to_ids("C")

# Simulate model outputs (including special tokens)
original_tokens = torch.tensor([a_id, pad_id, c_id, eos_id], device=device)
predicted_tokens = torch.tensor([a_id, pad_id, a_id, pad_id], device=device)

# Filter to only canonical amino acids
aa_ids_tensor = torch.tensor([tokenizer.convert_tokens_to_ids(aa) for aa in "ACDEFGHIKLMNPQRSTVWY"], device=device)
is_aa_only = torch.isin(original_tokens, aa_ids_tensor) & torch.isin(predicted_tokens, aa_ids_tensor)
aa_only_original = original_tokens[is_aa_only]
aa_only_predicted = predicted_tokens[is_aa_only]

# Calculate accuracy on amino acids only
correct = (aa_only_original == aa_only_predicted).sum().item()
total = is_aa_only.sum().item()
accuracy = (correct / total) * 100 if total > 0 else 0.0

# Substitution keys (filtered)
aa_keys = [
    f"{tokenizer.convert_ids_to_tokens(o.item())}->{tokenizer.convert_ids_to_tokens(p.item())}"
    for o, p in zip(aa_only_original, aa_only_predicted)
]

print("Original:", tokenizer.convert_ids_to_tokens(aa_only_original.tolist()))
print("Predicted:", tokenizer.convert_ids_to_tokens(aa_only_predicted.tolist()))
print("Substitutions:", aa_keys)
print(f"Accuracy: {accuracy:.2f}%")

# Assertions
assert total == 2, "Only two valid amino acid substitutions should be considered"
assert correct == 1, "Only A->A is correct"
assert "A->A" in aa_keys
assert all("<pad>" not in k and "<eos>" not in k for k in aa_keys), "Special tokens should not be in substitutions"
print("✅ Test passed (no special tokens included).")

Original: ['A', 'C']
Predicted: ['A', 'A']
Substitutions: ['A->A', 'C->A']
Accuracy: 50.00%
✅ Test passed (no special tokens included).


Okay, now fix the train loop:

In [23]:
%%time
import torch
import numpy as np
from collections import defaultdict, Counter
from transformers import EsmTokenizer, EsmForMaskedLM, AdamW

# Training configuration
num_epochs = 1
lr = 5e-5
max_len = 280
mask_prob = 0.15

# Initialize model and tokenizer
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name, cache_dir=".cache")
model = EsmForMaskedLM.from_pretrained(model_name, cache_dir=".cache")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Sample protein sequences (same as before)
sequences = [
    "RVQPTESIVRFPNITNLCPFGEVFNATRFSSVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYRYRLFRKSNLKPFERDISTEIYQAGSKPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNF",
    "RVQPTESIVRFPNITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNPAPFFTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGNIADYNYKLPDDFTGCVIAWNSNKLDSKVSGNYNYLYRLFRKSNLKPFERDISTEIYQAGNKPCNGVAGFNCYFPLRSYSFRPTYGVGHQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNF",
    "RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYRYRLFRKSNLKPFERDISTEIYQAGSKPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNF",
    "RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYRYRLFRKSNLKPFERDISTEIYQAGSKPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNF",
    "RVQPTESIVRFPNITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLAPFFTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGNIADYNYKLPDDFTGCVIAWNSNKLDSKVSGNYNYLYRLFRKSNLKPFERDISTEIYQAGNKPCNGVAGFNCYFPLRSYSFRPTYGVGHQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNF",
]   

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0.01)
aa_preds_tracker = {}

# Training loop
for epoch in range(num_epochs):
    model.train()
    aa_substitutions = []
    
    # Tokenize sequences fresh each epoch (simulating different batches)
    tokenized_seqs = tokenizer(sequences, return_tensors="pt", padding=True, truncation=True, max_length=max_len)
    tokenized_seqs = {k: v.to(device) for k, v in tokenized_seqs.items()}
    original_ids = tokenized_seqs["input_ids"]
    attention_mask = tokenized_seqs["attention_mask"]
    
    # Generate new mask for each epoch
    torch.manual_seed(0)
    rand = torch.rand(original_ids.shape, device=device)
    mask_arr = (rand < mask_prob) * \
               (original_ids != tokenizer.cls_token_id) * \
               (original_ids != tokenizer.eos_token_id) * \
               (original_ids != tokenizer.pad_token_id)
    
    masked_original_ids = original_ids.clone()
    masked_original_ids[mask_arr] = tokenizer.mask_token_id

    # Forward pass
    model.eval()
    optimizer.zero_grad()
    outputs = model(input_ids=masked_original_ids, attention_mask=attention_mask, labels=original_ids)
    loss = outputs.loss
    preds = outputs.logits
    
    # Backward pass and optimize
    loss.backward()
    optimizer.step()
    
    # Calculate metrics
    predicted_ids = torch.argmax(preds, dim=-1)
    mask = (masked_original_ids == tokenizer.mask_token_id)
    
    original_tokens = original_ids[mask]
    predicted_tokens = predicted_ids[mask]

    aa_ids_tensor = torch.tensor([tokenizer.convert_tokens_to_ids(aa) for aa in "ACDEFGHIKLMNPQRSTVWY"], device=device)
    is_aa_only = torch.isin(original_tokens, aa_ids_tensor) & torch.isin(predicted_tokens, aa_ids_tensor)
    aa_only_original = original_tokens[is_aa_only]
    aa_only_predicted = predicted_tokens[is_aa_only]

    # Calculate accuracy 
    correct = (aa_only_original == aa_only_predicted).sum().item()
    total = is_aa_only.sum().item()
    accuracy = (correct / total) * 100 if total > 0 else 0.0

    aa_keys = [f"{tokenizer.convert_ids_to_tokens(o.item())}->{tokenizer.convert_ids_to_tokens(p.item())}"
               for o, p in zip(original_tokens, predicted_tokens)]
    aa_substitutions.extend(aa_keys)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Loss: {loss.item():.4f} | Accuracy: {accuracy:.2f}%")
    aa_counter = Counter(aa_substitutions)

    # for substitution, count in aa_counter.items():
    #     substitution = substitution.split("->")
    #     original = substitution[0]
    #     predicted = substitution[1]
        
    #     print(f"{original}->{predicted}: {count}")

    aa_preds_tracker[epoch + 1] = aa_counter

df = pd.DataFrame.from_dict(aa_preds_tracker, orient='index').T

# Rename columns for clarity
df.columns = [f'epoch {i}' for i in df.columns]

# Reset index and rename the index column
df.rename_axis(index='expected_aa->predicted_aa', inplace=True)
df.reset_index(inplace=True)

# Fill NaN values with 0 and sort by total count
df = df.fillna(0)
df['total'] = df.iloc[:, 1:].sum(axis=1)
df = df.sort_values('total', ascending=False).drop('total', axis=1)

# Reset index again
df = df.reset_index(drop=True)
df


Epoch 1/1
Loss: 0.8311 | Accuracy: 7.78%
CPU times: user 68.6 ms, sys: 6.96 ms, total: 75.5 ms
Wall time: 343 ms


Unnamed: 0,expected_aa->predicted_aa,epoch 1
0,S->S,6
1,F->G,6
2,A->G,6
3,V->L,5
4,V->S,5
...,...,...
80,C->N,1
81,V->T,1
82,Q->L,1
83,H->D,1
