# Enzyme function prediction using neural networks

In this homework, we will design several neural networks with different architectures to predict enzyme function from protein sequences.

Enzyme functions can be represented by Enzyme Commission (EC) numbers. In this problem, each enzyme (protein) in the training or test set is labeled with exactly one EC number. A total of 200 distinct EC numbers appear in the dataset, so the task can be formulated as a single-label multi-class classification problem.
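
As a small illustration of the label format (a sketch, not part of the assignment code): an EC number has four dot-separated fields, and the classifier only needs each distinct EC string mapped to an integer class index. The helper name `split_ec` and the three-item label list are made up for this example; the notebook builds the real mapping from `ec_numbers.json`.

```python
# Hypothetical helper: split an EC number into its four hierarchical fields.
def split_ec(ec):
    parts = ec.split(".")
    if len(parts) != 4:
        raise ValueError(f"expected four dot-separated fields, got {ec!r}")
    return tuple(parts)

# Three EC numbers that appear in the training data.
example_ecs = ["3.4.19.12", "1.11.1.21", "3.6.4.13"]

# Single-label multi-class setup: one integer class index per distinct EC number.
example_ec2idx = {ec: idx for idx, ec in enumerate(example_ecs)}

print(split_ec("3.4.19.12"))       # ('3', '4', '19', '12')
print(example_ec2idx["3.6.4.13"])  # 2
```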

To begin with, run the following cell to download the training and test data.

In [None]:
!wget https://drive.google.com/uc\?export\=download\&id\=1cJeJjoCfycp4f3yHABO8bai6Em0zoc15 -O train.csv
!wget https://drive.google.com/uc\?export\=download\&id\=1owiCCMlYXdT1z7wdz5k6If1fqQquI76P -O test.csv
!wget https://drive.google.com/uc\?export\=download\&id\=12HEAGnegf8h15M_3osmerN8gthW_ERoJ -O train_seqs.fasta
!wget https://drive.google.com/uc\?export\=download\&id\=1W1LDba5TLJwaNWvMT14QMDOfH6VPL7Wy -O test_seqs.fasta
!wget https://drive.google.com/uc\?export\=download\&id\=1-F1Seb2Fb-QOqBjfkFJIvc479IdTinlT -O ec_numbers.json
!wget https://drive.google.com/uc\?export\=download\&id\=1PlS7kXvcKGNlRa74FapVQ8M4GGcsWDjC -O train_subsample.csv
!wget https://drive.google.com/uc\?export\=download\&id\=1F2WyQV1xBdru3B3QBtDmTSQnxp1ZNSuP -O train_subsample.fasta

In [None]:
!pip install torch pandas numpy matplotlib scikit-learn tqdm pyarrow

Import the necessary packages:

In [61]:
# export
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json, os, time
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
##### DO NOT MODIFY ANYTHING IN THIS CELL #####

Load the training data and take a look at the sequences and EC number labels:

In [62]:
df = pd.read_csv('train.csv')
df.head()

Unnamed: 0,Sequence,EC number,split
0,MTDLGLKWSCEYCTYENWPSAIKCTMCRAQRHNAPIITEEPFKSSS...,3.4.19.12,train
1,MHASLSSWLLAASLLTQPISVSGQGCPFAKRDGTVDSSLPQKRADA...,1.11.1.21,train
2,MSGYSSDRDRGRDRGFGAPRFGGSRAGPLSGKKFGNPGEKLVKKKW...,3.6.4.13,train
3,MTDSGDLCPHLDSIGEVTKEELIQKSKGTCQSCGVGGPNLWACLQC...,3.4.19.12,train
4,MSDEGSKRGSRADSLEAEPPLPPPPPPPPPGESSLVPTSPRYRPPL...,2.3.2.27,train


In [63]:
with open('ec_numbers.json') as f:
    ec_list = json.load(f)
print(f'Number of EC numbers: {len(ec_list)}')

Number of EC numbers: 200


In [64]:
sequences = df['Sequence'].tolist()
ec_numbers = df['EC number'].tolist()
ec2idx = {ec: idx for idx, ec in enumerate(ec_list)}
train_seq2name = {seq: f'train_seq_{i}' for i, seq in enumerate(sequences)}

Hold out 20% of the training data as the validation set:

In [65]:
seq_train, seq_val, ec_train, ec_val = train_test_split(sequences, ec_numbers, test_size=0.2, random_state=42)
print(f'Training samples: {len(seq_train)}')
print(f'Validation samples: {len(seq_val)}')

Training samples: 16000
Validation samples: 4000


## Task 1: One-hot tokenizer

Protein sequences are strings of amino acids, and there are 20 standard amino acid types. We need to transform (tokenize) protein sequences into tensors so that neural networks can take them as inputs. A straightforward way to tokenize protein sequences is one-hot encoding ([wiki link](https://en.wikipedia.org/wiki/One-hot)). In this task you need to complete the function `one_hot_encode`, which takes a protein sequence (a string of amino acids) of length $L$ and outputs a one-hot-encoded tensor of shape $L\times 20$. (Note: some sequences contain the unknown amino acid 'X'; encode it as an all-zero vector.)

In [66]:
# export
##### DO NOT MODIFY ANYTHING ABOVE THIS LINE #####

amino_acids = "ACDEFGHIKLMNPQRSTVWY"
aa_to_idx = {aa: i for i, aa in enumerate(amino_acids)}

# One-hot encoding function
def one_hot_encode(sequence):
    # TODO: sequence is a string of amino acids, return the one-hot encoded tensor of the sequence.
    one_hot = torch.zeros(len(sequence), 20)

    for i, aa in enumerate(sequence):
        if aa in aa_to_idx:
            one_hot[i, aa_to_idx[aa]] = 1
    return one_hot
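
A quick sanity check can catch shape or indexing mistakes in the tokenizer early. The sketch below restates the encoder under a different name (`one_hot_encode_demo`, chosen here so the snippet runs on its own) and checks the output for a short made-up sequence containing an unknown residue.

```python
import torch

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_IDX = {aa: i for i, aa in enumerate(AMINO_ACIDS)}

def one_hot_encode_demo(sequence):
    """Return an (L, 20) tensor; unknown residues such as 'X' stay all-zero."""
    one_hot = torch.zeros(len(sequence), len(AMINO_ACIDS))
    for i, aa in enumerate(sequence):
        idx = AA_TO_IDX.get(aa)
        if idx is not None:
            one_hot[i, idx] = 1.0
    return one_hot

encoded = one_hot_encode_demo("ACXY")
print(encoded.shape)                        # torch.Size([4, 20])
print(encoded.sum(dim=1).tolist())          # [1.0, 1.0, 0.0, 1.0]
print(encoded.argmax(dim=1)[:2].tolist())   # [0, 1]
```

Each known residue contributes exactly one `1` in its row, and the row for 'X' sums to zero.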

With the one-hot tokenizer, we can design the dataset class. As the following cell shows, each data point returned by the dataset class contains three items: the one-hot encoded tensor, the length of the sequence, and the label. Since sequence lengths vary across the dataset, we provide a collate function that pads the sequences with zeros to the maximum length in the batch. Pass this `collate_fn` to the `collate_fn` parameter of PyTorch's `DataLoader` class so that batching works correctly.

In [67]:
# export
##### DO NOT MODIFY ANYTHING ABOVE THIS LINE #####

class ProteinDataset(Dataset):
    def __init__(self, sequences, labels, ec2idx):
        self.sequences = sequences
        self.labels = [ec2idx.get(ec, -1) for ec in labels]
        # TODO: use one_hot_encode to get the sequence tensors
        #####
        self.seq_tensors = [one_hot_encode(seq) for seq in sequences]
        #####

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq_tensor = self.seq_tensors[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return seq_tensor, len(seq_tensor), label

# Collate function with padding
def collate_fn(batch):
    sequences, seq_lens, labels = zip(*batch)
    max_len = max(seq_lens)

    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0)

    return padded_sequences, torch.tensor(seq_lens, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
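
To see what the padding step produces, here is a minimal sketch with two synthetic tensors standing in for encoded sequences of lengths 3 and 5. The tensors are made up for illustration; only `pad_sequence` is the real PyTorch call used by the collate function.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two fake "encoded sequences" of lengths 3 and 5 with 20 features each.
short = torch.ones(3, 20)
long = torch.ones(5, 20)

padded = pad_sequence([short, long], batch_first=True, padding_value=0)
seq_lens = torch.tensor([short.shape[0], long.shape[0]], dtype=torch.long)

print(padded.shape)                       # torch.Size([2, 5, 20])
print(seq_lens.tolist())                  # [3, 5]
print(padded[0, 3:].abs().sum().item())   # 0.0, the padded tail of the short sequence
```

The original lengths are returned alongside the padded batch so that a model can tell real positions from padding.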


## Task 2: Transformer built from multi-head attention

In this task, you are required to implement a vanilla transformer encoder model for the EC number prediction task. Construct the transformer model from building blocks such as `torch.nn.MultiheadAttention`, `torch.nn.Linear`, and `torch.nn.LayerNorm`. Your transformer model should have the same architecture as the encoder module described in the paper [Attention is all you need](https://arxiv.org/abs/1706.03762). We recommend checking PyTorch's documentation for the modules mentioned above.

In [68]:
# export
##### DO NOT MODIFY ANYTHING ABOVE THIS LINE #####

class AttentionClassifier(nn.Module):
    def __init__(self, num_classes, embed_dim=64, num_heads=1, num_layers=1, ff_dim=128):
        super(AttentionClassifier, self).__init__()
        self.embedding = nn.Linear(20, embed_dim)
        # TODO: implement the transformer block using multiheadattention, linear, and layernorm.
        #####
        self.attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'self_attn': nn.MultiheadAttention(embed_dim, num_heads, batch_first=True),
                'norm1': nn.LayerNorm(embed_dim),
                'ff': nn.Sequential(
                    nn.Linear(embed_dim, ff_dim),
                    nn.ReLU(),
                    nn.Linear(ff_dim, embed_dim)
                ),
                'norm2': nn.LayerNorm(embed_dim)
            }) for _ in range(num_layers)
        ])
        #####
        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x, seq_lens):
        x = self.embedding(x)
        max_len = x.shape[1]
        mask = torch.arange(max_len, device=x.device).expand(len(seq_lens), max_len) >= seq_lens.unsqueeze(1)
        mask = mask.to(x.device)

        # TODO: forward part of the transformer block
        #####
        for layer in self.attention_layers:
            attn_out, _ = layer['self_attn'](x, x, x, key_padding_mask=mask)
            x = layer['norm1'](x + attn_out)
            ff_out = layer['ff'](x)
            x = layer['norm2'](x + ff_out)
        #####

        x = x.permute(0, 2, 1)
        x = self.pooling(x).squeeze(-1)
        return self.fc(x)
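
The forward pass builds a boolean mask in which `True` marks padded positions, which is the convention expected by `key_padding_mask`. The final `AdaptiveAvgPool1d`, however, averages over every position, padded or not. The sketch below shows the mask for a toy batch and a length-aware average as one possible alternative; it is not the pooling required by the template, and the names `padding_mask` and `masked_mean` are made up for this example.

```python
import torch

def padding_mask(seq_lens, max_len):
    """True marks padded positions, the convention used by key_padding_mask."""
    positions = torch.arange(max_len, device=seq_lens.device)
    return positions.unsqueeze(0) >= seq_lens.unsqueeze(1)

def masked_mean(x, seq_lens):
    """Average (batch, seq_len, dim) features over real positions only."""
    mask = padding_mask(seq_lens, x.shape[1])        # (batch, seq_len)
    keep = (~mask).unsqueeze(-1).to(x.dtype)         # (batch, seq_len, 1)
    return (x * keep).sum(dim=1) / seq_lens.unsqueeze(1).to(x.dtype)

seq_lens = torch.tensor([2, 4])
x = torch.ones(2, 4, 8)
x[0, 2:] = 100.0   # pretend the padded positions hold arbitrary values

mask = padding_mask(seq_lens, 4)
print(mask.tolist())
# [[False, False, True, True], [False, False, False, False]]
print(masked_mean(x, seq_lens)[0, 0].item())   # 1.0
print(x.mean(dim=1)[0, 0].item())              # 50.5
```

The plain mean is pulled toward whatever values sit in the padded slots, while the masked mean depends only on the real positions.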

## Task 3: Transformer built from TransformerEncoder class

In this task, you will also implement a vanilla transformer model. Instead of constructing the model from small blocks like `MultiheadAttention`, use the wrapped modules `torch.nn.TransformerEncoderLayer` and `torch.nn.TransformerEncoder` to build the model directly. We recommend checking the documentation of these two modules to learn their usage.

In [69]:
class TransformerClassifier(nn.Module):
    def __init__(self, num_classes, embed_dim=64, num_heads=1, num_layers=1, ff_dim=128):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Linear(20, embed_dim)  # Project one-hot input to embedding space
        # TODO: transformer block using TransformerEncoderLayer
        #####
        encoder_layers = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim, batch_first=True)
        self.encoder_layers = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        #####
        self.pooling = nn.AdaptiveAvgPool1d(1)  # Global average pooling
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x, seq_lens):
        x = self.embedding(x)  # (batch_size, seq_len, embed_dim)

        # Create attention mask
        max_len = x.shape[1]
        mask = torch.arange(max_len, device=x.device).expand(len(seq_lens), max_len) >= seq_lens.unsqueeze(1)
        mask = mask.to(x.device)

        # TODO: forward part of the transformer block
        #####
        x = self.encoder_layers(x, src_key_padding_mask=mask)
        #####

        x = x.permute(0, 2, 1)
        x = self.pooling(x).squeeze(-1)  # (batch, embed_dim)
        return self.fc(x)
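
The encoder in the paper adds positional encodings to the token embeddings before the first layer, since self-attention on its own does not see token order. The template classes above do not include this step. A minimal sketch of the paper's sinusoidal encoding follows; the function name `sinusoidal_positions` is made up here, it assumes an even `embed_dim`, and whether it helps on this dataset is something to check on the validation set.

```python
import math
import torch

def sinusoidal_positions(max_len, embed_dim):
    """Return a (max_len, embed_dim) table of sinusoidal position encodings.

    Assumes embed_dim is even: even columns hold sines, odd columns cosines.
    """
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
    div_term = torch.exp(
        torch.arange(0, embed_dim, 2, dtype=torch.float32)
        * (-math.log(10000.0) / embed_dim)
    )                                                                    # (embed_dim / 2,)
    pe = torch.zeros(max_len, embed_dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

pe = sinusoidal_positions(max_len=50, embed_dim=64)
print(pe.shape)             # torch.Size([50, 64])
print(pe[0, :4].tolist())   # [0.0, 1.0, 0.0, 1.0]
```

Inside a `forward` method, such a table could be added right after the embedding layer, for example `x = x + pe[:x.shape[1]].to(x.device)`.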

## Task 4: 1D-CNN model

In this task, you will implement a model using 1D CNN layers. You can use PyTorch's `torch.nn.Conv1d` to construct the model. For simplicity, you do not need to mask the padded part of the input tensor. Refer to PyTorch's documentation for the usage of `torch.nn.Conv1d`.

In [70]:
class CNNClassifier(nn.Module):
    def __init__(self, num_classes, embed_dim=64, num_filters=128, kernel_size=3, num_layers=3):
        super(CNNClassifier, self).__init__()
        self.embedding = nn.Linear(20, embed_dim)
        # TODO: 1D convolutional layers
        #####
        layers = []
        in_channels = embed_dim
        for _ in range(num_layers):
            layers.append(nn.Conv1d(in_channels, num_filters, kernel_size, padding=kernel_size//2))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.BatchNorm1d(num_filters))
            in_channels = num_filters
        self.conv_layers = nn.Sequential(*layers)
        #####

        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(num_filters, num_classes)

    def forward(self, x, seq_lens):
        x = self.embedding(x).permute(0, 2, 1)  # Convert to (batch, embed_dim, seq_len)

        # TODO: forward part of the CNN block
        #####
        x = self.conv_layers(x)
        #####

        x = self.pooling(x).squeeze(-1)  # Global average pooling
        return self.fc(x)
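
`nn.Conv1d` expects input of shape (batch, channels, length), which is why the embedding output is permuted first. The sketch below, using random input of a made-up length, shows how `padding=kernel_size // 2` affects the output length: it is preserved for odd kernel sizes and grows by one for even ones.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 37)   # (batch, channels, seq_len)

# Odd kernel size: L_out = 37 + 2 * 1 - 3 + 1 = 37
same_len = nn.Conv1d(64, 128, kernel_size=3, padding=3 // 2)
out_odd = same_len(x)
print(out_odd.shape)         # torch.Size([2, 128, 37])

# Even kernel size: L_out = 37 + 2 * 2 - 4 + 1 = 38
longer = nn.Conv1d(64, 128, kernel_size=4, padding=4 // 2)
out_even = longer(x)
print(out_even.shape)        # torch.Size([2, 128, 38])
```

Because the model ends with global average pooling, a length change of one position does not break the classifier, but an odd `kernel_size` keeps the shapes easier to reason about.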

## Training the neural networks

Complete the function `train_model`.

In [71]:
def train_model(model, train_dataset, val_dataset, num_classes, epochs=100, batch_size=256, lr=1e-3, patience=10, device='cuda:0'):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_acc = 0
    patience_counter = 0
    best_ckpt = None

    for epoch in range(epochs):
        start_epoch = time.time()
        model.train()
        total_loss, correct, total = 0, 0, 0

        for sequences, seq_lens, labels in train_loader:
            # TODO: backpropagation
            #####
            sequences, seq_lens, labels = sequences.to(device), seq_lens.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(sequences, seq_lens)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            #####

            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_acc = correct / total

        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for sequences, seq_lens, labels in val_loader:
                # TODO: model inference
                #####
                sequences, seq_lens, labels = sequences.to(device), seq_lens.to(device), labels.to(device)
                outputs = model(sequences, seq_lens)
                preds = outputs.argmax(dim=1)
                #####
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        val_acc = correct / total
        end_epoch = time.time()
        print(f'Epoch [{epoch+1} / {epochs}]: Train Loss={total_loss:.4f}, Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}, Time={end_epoch - start_epoch:.4f} sec')

        # Early stopping
        if val_acc > best_acc:
            best_acc = val_acc
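            # Copy the weights: state_dict() returns tensors that alias the live parameters.
            best_ckpt = {k: v.detach().clone() for k, v in model.state_dict().items()}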
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    return model, best_ckpt
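
One detail worth knowing when keeping a "best" checkpoint: `model.state_dict()` returns tensors that share memory with the live parameters, so a stored reference keeps changing as training continues unless it is copied. The sketch below demonstrates this on a tiny made-up linear layer, simulating an optimizer step with an in-place update.

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2)

live = model.state_dict()                                      # shares storage with the parameters
snapshot = {k: v.detach().clone() for k, v in live.items()}    # independent copy

# Simulate a later optimizer step with an in-place parameter update.
with torch.no_grad():
    model.weight.add_(1.0)

follows_update = torch.equal(live["weight"], model.weight)
kept_old_values = not torch.equal(snapshot["weight"], model.weight)
print(follows_update)    # True: the live dict followed the update
print(kept_old_values)   # True: the snapshot kept the old values
```

`copy.deepcopy(model.state_dict())` is another common way to take such a snapshot.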

Train `AttentionClassifier`. To get better performance, you may need to tune hyperparameters such as the number of epochs, batch size, learning rate (`lr`), and early-stopping patience (`patience`), as well as the model size (number of layers, number of attention heads, embedding dimension, feed-forward dimension). The same applies to the other models.

In [75]:
train_dataset = ProteinDataset(seq_train, ec_train, ec2idx)
val_dataset = ProteinDataset(seq_val, ec_val, ec2idx)
num_classes = len(ec_list)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = AttentionClassifier(num_classes).to(device)
model, best_ckpt = train_model(model, train_dataset, val_dataset, num_classes=num_classes, epochs=28, batch_size=14, lr=5e-4, patience=8, device=device)
model.load_state_dict(best_ckpt)

df_test = pd.read_csv('test.csv')
test_sequences = df_test['Sequence'].tolist()
test_seq2name = {seq: f'test_seq_{i}' for i, seq in enumerate(test_sequences)}
test_dataset = ProteinDataset(test_sequences, [0]*len(test_sequences), ec2idx)
test_loader = DataLoader(test_dataset, batch_size=256, collate_fn=collate_fn)

model.eval()
preds = []
with torch.no_grad():
    for sequences, seq_lens, _ in test_loader:
        # TODO: inference on the test set
        #####
        sequences = sequences.to(device)
        seq_lens = seq_lens.to(device)
        output = model(sequences, seq_lens)
        preds_batch = output.argmax(dim=1).cpu().numpy()
        preds.extend(preds_batch)
        #####
# save the predictions to an individual CSV file, each row contains the predicted EC number for the corresponding sequence in the test set, no need for header
preds = [ec_list[pred] for pred in preds]
df_preds = pd.DataFrame(preds)
df_preds.to_csv('test_preds_attention.csv', index=False, header=False)

Epoch [1 / 28]: Train Loss=5760.4093, Train Acc=0.0217, Val Acc=0.0483, Time=5.4176 sec
Epoch [2 / 28]: Train Loss=4559.6886, Train Acc=0.1212, Val Acc=0.1948, Time=5.4369 sec
Epoch [3 / 28]: Train Loss=3637.6143, Train Acc=0.2724, Val Acc=0.3420, Time=5.4780 sec
Epoch [4 / 28]: Train Loss=3010.7403, Train Acc=0.3925, Val Acc=0.4315, Time=5.4180 sec
Epoch [5 / 28]: Train Loss=2556.4606, Train Acc=0.4746, Val Acc=0.4938, Time=5.4505 sec
Epoch [6 / 28]: Train Loss=2279.4355, Train Acc=0.5211, Val Acc=0.5032, Time=5.4366 sec
Epoch [7 / 28]: Train Loss=2115.7090, Train Acc=0.5504, Val Acc=0.5327, Time=5.4237 sec
Epoch [8 / 28]: Train Loss=1967.3829, Train Acc=0.5778, Val Acc=0.5543, Time=5.4508 sec
Epoch [9 / 28]: Train Loss=1864.6382, Train Acc=0.5982, Val Acc=0.5857, Time=5.4550 sec
Epoch [10 / 28]: Train Loss=1769.2826, Train Acc=0.6150, Val Acc=0.5837, Time=5.4464 sec
Epoch [11 / 28]: Train Loss=1662.0666, Train Acc=0.6338, Val Acc=0.5905, Time=5.4397 sec
Epoch [12 / 28]: Train Loss=15

Train TransformerClassifier

In [76]:
model = TransformerClassifier(num_classes).to(device)
model, best_ckpt = train_model(model, train_dataset, val_dataset, num_classes=num_classes, epochs=40, batch_size=32, lr=7e-4, patience=10, device=device)
model.load_state_dict(best_ckpt)

model.eval()
preds = []
with torch.no_grad():
    for sequences, seq_lens, _ in test_loader:
        # TODO: inference on the test set
        #####
        sequences = sequences.to(device)
        seq_lens = seq_lens.to(device)
        output = model(sequences, seq_lens)
        preds_batch = output.argmax(dim=1).cpu().numpy()
        preds.extend(preds_batch)
        #####
# save the predictions to an individual CSV file, each row contains the predicted EC number for the corresponding sequence in the test set, no need for header
preds = [ec_list[pred] for pred in preds]
df_preds = pd.DataFrame(preds)
df_preds.to_csv('test_preds_transformer.csv', index=False, header=False)



Epoch [1 / 40]: Train Loss=2537.1804, Train Acc=0.0211, Val Acc=0.0560, Time=3.8230 sec
Epoch [2 / 40]: Train Loss=1966.7617, Train Acc=0.1432, Val Acc=0.2360, Time=3.8565 sec
Epoch [3 / 40]: Train Loss=1534.0374, Train Acc=0.3099, Val Acc=0.3580, Time=3.8387 sec
Epoch [4 / 40]: Train Loss=1273.8096, Train Acc=0.4187, Val Acc=0.4230, Time=3.8467 sec
Epoch [5 / 40]: Train Loss=1101.4379, Train Acc=0.4861, Val Acc=0.4898, Time=3.8595 sec
Epoch [6 / 40]: Train Loss=980.4196, Train Acc=0.5411, Val Acc=0.5300, Time=3.8873 sec
Epoch [7 / 40]: Train Loss=889.8299, Train Acc=0.5752, Val Acc=0.5550, Time=3.9101 sec
Epoch [8 / 40]: Train Loss=816.5716, Train Acc=0.6042, Val Acc=0.5813, Time=3.8661 sec
Epoch [9 / 40]: Train Loss=763.6870, Train Acc=0.6273, Val Acc=0.6000, Time=3.8714 sec
Epoch [10 / 40]: Train Loss=717.1447, Train Acc=0.6465, Val Acc=0.6260, Time=3.8565 sec
Epoch [11 / 40]: Train Loss=677.2745, Train Acc=0.6633, Val Acc=0.6325, Time=3.8543 sec
Epoch [12 / 40]: Train Loss=650.0390

Train CNNClassifier

In [78]:
model = CNNClassifier(num_classes).to(device)
model, best_ckpt = train_model(model, train_dataset, val_dataset, num_classes=num_classes, epochs=20, batch_size=28, lr=5e-4, patience=10, device=device)
model.load_state_dict(best_ckpt)

model.eval()
preds = []
with torch.no_grad():
    for sequences, seq_lens, _ in test_loader:
        # TODO: inference on the test set
        #####
        sequences = sequences.to(device)
        seq_lens = seq_lens.to(device)
        output = model(sequences, seq_lens)
        preds_batch = output.argmax(dim=1).cpu().numpy()
        preds.extend(preds_batch)
        #####
# save the predictions to an individual CSV file, each row contains the predicted EC number for the corresponding sequence in the test set, no need for header
preds = [ec_list[pred] for pred in preds]
df_preds = pd.DataFrame(preds)
df_preds.to_csv('test_preds_cnn.csv', index=False, header=False)

Epoch [1 / 20]: Train Loss=2737.4207, Train Acc=0.0430, Val Acc=0.0457, Time=1.9805 sec
Epoch [2 / 20]: Train Loss=2422.9362, Train Acc=0.1562, Val Acc=0.1950, Time=1.9782 sec
Epoch [3 / 20]: Train Loss=2092.1353, Train Acc=0.3412, Val Acc=0.2898, Time=1.9811 sec
Epoch [4 / 20]: Train Loss=1709.0103, Train Acc=0.5163, Val Acc=0.5175, Time=1.9795 sec
Epoch [5 / 20]: Train Loss=1351.3714, Train Acc=0.6631, Val Acc=0.6348, Time=1.9803 sec
Epoch [6 / 20]: Train Loss=1039.5305, Train Acc=0.7739, Val Acc=0.7490, Time=2.0434 sec
Epoch [7 / 20]: Train Loss=792.3929, Train Acc=0.8399, Val Acc=0.8055, Time=1.9738 sec
Epoch [8 / 20]: Train Loss=613.6025, Train Acc=0.8803, Val Acc=0.8652, Time=1.9723 sec
Epoch [9 / 20]: Train Loss=471.8585, Train Acc=0.9108, Val Acc=0.8245, Time=1.9842 sec
Epoch [10 / 20]: Train Loss=369.4509, Train Acc=0.9299, Val Acc=0.8800, Time=1.9672 sec
Epoch [11 / 20]: Train Loss=292.9254, Train Acc=0.9436, Val Acc=0.8818, Time=1.9777 sec
Epoch [12 / 20]: Train Loss=232.205

## Task 5: Using pretrained protein language model embeddings

In the previous tasks we used one-hot encoded sequences as the model inputs. Pretrained protein language models (pLMs) are now widely used in protein-related problems. Below we explore pLM embeddings for EC number prediction. We will use ESM-2 (https://github.com/facebookresearch/esm) to extract protein sequence embeddings. First, check ESM-2's documentation to learn how to generate the embeddings from the FASTA files we have provided. Use the model `esm2_t33_650M_UR50D` and retrieve the last-layer sequence-level embedding (the residue-level embedding is not needed). Generate one `.pt` file per sequence embedding and save it in the directory `esm_embeddings`. Since embedding generation can be time-consuming, we use a subsampled training set (50% of the original training set) for this task. If you have adequate computational resources, you can also use the complete training set.

In [17]:
!pip install fair-esm
!pip install biopython



In [44]:
import esm
from Bio import SeqIO
from tqdm.auto import tqdm
import os
import numpy as np
import torch

def gen_emb(fasta_file, out_dir='esm_embeddings', device='cuda:0'):
    os.makedirs(out_dir, exist_ok=True)  # torch.save fails if the output directory is missing
    records = list(SeqIO.parse(fasta_file, 'fasta'))
    names = [rec.id for rec in records]
    sequences = [str(rec.seq) for rec in records]
    print(f'Number of sequences: {len(sequences)}')

    data = [(name, seq) for name, seq in zip(names, sequences)]

    # TODO: Load ESM-2 model (esm2_t33_650M_UR50D) and batch converter
    #####
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    #####
    model.to(device)
    model.eval()  # disables dropout for deterministic results

    batch_size = 16 # Reduce if you are running out of cuda memory
    num_batches = int(np.ceil(len(data) / batch_size))

    for i in tqdm(range(num_batches)):
        batch = data[i * batch_size:(i + 1) * batch_size]
        names_batch, seqs_batch = zip(*batch)
        batch_labels, batch_strs, batch_tokens = batch_converter(batch)
        batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
        batch_tokens = batch_tokens.to(device)
        # Run the model on `device`; the representations are moved to the CPU afterwards
        with torch.no_grad():
            # TODO: inference
            #####
            results = model(batch_tokens, repr_layers=[33], return_contacts=False)
            #####
        # TODO: get per-residue representations
        #####
        token_representations = results["representations"][33].cpu()
        #####
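        # Generate per-sequence representations via averaging.
        # Token 0 is the beginning-of-sequence token and the last non-padding token is
        # end-of-sequence, so average over positions 1 .. tokens_len - 2 only.
        for k, tokens_len in enumerate(batch_lens):
            seq_name = names_batch[k]
            seq_tokens = token_representations[k, 1:tokens_len - 1]
            seq_mean = seq_tokens.mean(0)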
            save = {'mean_representations': {33: seq_mean}}
            torch.save(save, os.path.join(out_dir, f'{seq_name}.pt'))
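
The ESM-2 batch converter adds a beginning-of-sequence token before each sequence and an end-of-sequence token after it, so the count of non-padding tokens is the residue count plus two. The toy example below uses synthetic token ids and features (no ESM model involved) to compare averaging over residue positions only with averaging over all non-padding tokens. The padding index and token ids are illustrative values for this sketch.

```python
import torch

PAD = 1  # illustrative padding index for this toy example

# Toy token batch: 0 = start token, 2 = end token, values >= 4 = residues.
batch_tokens = torch.tensor([
    [0, 5, 6, 7, 2],      # three residues
    [0, 8, 9, 2, PAD],    # two residues, one padded slot
])
batch_lens = (batch_tokens != PAD).sum(1)   # residues plus the two special tokens

# Fake per-token features: each position's vector is filled with its token id.
token_representations = batch_tokens.unsqueeze(-1).float().expand(-1, -1, 4)

results = []
for k, tokens_len in enumerate(batch_lens):
    residue_only = token_representations[k, 1:tokens_len - 1].mean(0)
    with_specials = token_representations[k, :tokens_len].mean(0)
    results.append((residue_only[0].item(), with_specials[0].item()))

print(batch_lens.tolist())   # [5, 4]
print(results)               # [(6.0, 4.0), (8.5, 4.75)]
```

Including the special tokens shifts the average, and the shift is larger for short sequences.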

You have two options for getting the ESM-2 embeddings:

### Option 1: Generate the embeddings yourself
You can run the `gen_emb` function in the following two cells to generate the embeddings.

In [13]:
gen_emb('test_seqs.fasta')

Number of sequences: 2000


100%|██████████| 125/125 [04:27<00:00,  2.14s/it]


In [14]:
gen_emb('train_subsample.fasta')

Number of sequences: 10000


100%|██████████| 625/625 [22:30<00:00,  2.16s/it]


### Option 2: Download the precomputed embeddings
If you have trouble with the GPU or PACE cluster, you can run the following cell to download the precomputed embeddings instead. Note that you still need to implement the `gen_emb` function even if you use the precomputed embeddings.

In [19]:
!wget https://drive.usercontent.google.com/download\?id\=1wLGtohLE1vdZigOxs9T7o-7STT_Jkpi5\&export\=download\&authuser\=0\&confirm\=t\&uuid\=ec23bb94-d652-41e8-9452-532243d8f7b9\&at\=AEz70l7XSEc5DkNaHY0svMQy9lv6:1740447531952 -O esm_embeddings.zip
!unzip esm_embeddings.zip

--2025-02-26 11:53:01--  https://drive.usercontent.google.com/download?id=1wLGtohLE1vdZigOxs9T7o-7STT_Jkpi5&export=download&authuser=0&confirm=t&uuid=ec23bb94-d652-41e8-9452-532243d8f7b9&at=AEz70l7XSEc5DkNaHY0svMQy9lv6:1740447531952
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 108.177.122.132, 2607:f8b0:4002:c02::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|108.177.122.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 122685700 (117M) [application/octet-stream]
Saving to: ‘esm_embeddings.zip’


2025-02-26 11:53:02 (107 MB/s) - ‘esm_embeddings.zip’ saved [122685700/122685700]

Archive:  esm_embeddings.zip
   creating: esm_embeddings/
  inflating: esm_embeddings/train_seq_7025.pt  
  inflating: esm_embeddings/train_seq_19379.pt  
  inflating: esm_embeddings/train_seq_177.pt  
  inflating: esm_embeddings/train_seq_19136.pt  
  inflating: esm_embeddings/train_seq_8971.pt  
  inflating: esm_embeddings/tr

In [45]:
df = pd.read_csv('train_subsample.csv')
sequences = df['Sequence'].tolist()
ec_numbers = df['EC number'].tolist()

train_seq2name = {seq: f'train_seq_{i}' for i, seq in enumerate(sequences)}
seq_train, seq_val, ec_train, ec_val = train_test_split(sequences, ec_numbers, test_size=0.2, random_state=42)

Construct a new dataset class to use ESM-2 embeddings.

In [46]:
class ProteinESMDataset(Dataset):
    def __init__(self, sequences, seq2name, emb_dir, labels, ec2idx):
        super().__init__()
        self.labels = [ec2idx.get(ec, -1) for ec in labels]
        self.embeddings = []
        for seq in tqdm(sequences, desc='Loading esm embeddings'):
            name = seq2name[seq]
            emb_file = os.path.join(emb_dir, f'{name}.pt')
            emb = torch.load(emb_file)['mean_representations'][33]
            self.embeddings.append(emb)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, index):
        emb = self.embeddings[index]
        label = torch.tensor(self.labels[index], dtype=torch.long)
        return emb, label

Implement a simple MLP model to use ESM-2 embeddings as inputs for EC number prediction.

In [47]:
class MLPClassifier(nn.Module):
    def __init__(self, num_classes, input_dim=1280, hidden_dim=640):
        super(MLPClassifier, self).__init__()
        # TODO: linear layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim // 2, num_classes)

    def forward(self, x):
        # TODO: forward function
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        
        return x
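
A quick shape check helps confirm that the classifier matches the embedding size and the number of classes. The sketch below rebuilds the same layer sizes with `nn.Sequential` (a stylistic choice for brevity, not the notebook's class) and feeds it random tensors standing in for ESM-2 embeddings.

```python
import torch
import torch.nn as nn

# Same layer sizes as the MLP defaults: 1280 -> 640 -> 320 -> 200 classes.
mlp = nn.Sequential(
    nn.Linear(1280, 640), nn.ReLU(),
    nn.Linear(640, 320), nn.ReLU(),
    nn.Linear(320, 200),
)

fake_embeddings = torch.randn(4, 1280)   # stand-in for four mean embeddings
logits = mlp(fake_embeddings)
num_params = sum(p.numel() for p in mlp.parameters())

print(logits.shape)   # torch.Size([4, 200])
print(num_params)     # 1089160
```

The outputs are raw logits; `nn.CrossEntropyLoss` applies the softmax internally, so no activation is needed after the last layer.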

Complete the function `train_model_esm`, which trains the model taking ESM-2 embeddings as inputs.

In [48]:
def train_model_esm(model, train_dataset, val_dataset, num_classes, epochs=100, batch_size=256, lr=1e-3, patience=10, device='cuda:0'):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_acc = 0
    patience_counter = 0
    best_ckpt = None

    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0

        for sequences, labels in train_loader:
            # TODO: backpropagation
            #####
            sequences, labels = sequences.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(sequences)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            #####

            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_acc = correct / total

        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for sequences, labels in val_loader:
                # TODO: inference
                #####
                sequences, labels = sequences.to(device), labels.to(device)
                outputs = model(sequences)
                preds = outputs.argmax(dim=1)
                #####
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        val_acc = correct / total
        print(f'Epoch {epoch+1}: Train Loss={total_loss:.4f}, Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}')

        # Early stopping
        if val_acc > best_acc:
            best_acc = val_acc
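            # Copy the weights: state_dict() returns tensors that alias the live parameters.
            best_ckpt = {k: v.detach().clone() for k, v in model.state_dict().items()}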
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    return model, best_ckpt

Train the MLP model and generate predictions for the test set.

In [49]:
emb_dir = 'esm_embeddings'
os.makedirs(emb_dir, exist_ok=True)
train_dataset = ProteinESMDataset(seq_train, train_seq2name, emb_dir, ec_train, ec2idx)
val_dataset = ProteinESMDataset(seq_val, train_seq2name, emb_dir, ec_val, ec2idx)
num_classes = len(ec_list)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MLPClassifier(num_classes).to(device)
model, best_ckpt = train_model_esm(model, train_dataset, val_dataset, num_classes=num_classes, epochs=20, batch_size=32, lr=1e-4, patience=3, device=device)
model.load_state_dict(best_ckpt)

df_test = pd.read_csv('test.csv')
test_sequences = df_test['Sequence'].tolist()
test_seq2name = {seq: f'test_seq_{i}' for i, seq in enumerate(test_sequences)}
test_dataset = ProteinESMDataset(test_sequences, test_seq2name, emb_dir, [0]*len(test_sequences), ec2idx)
test_loader = DataLoader(test_dataset, batch_size=256)

model.eval()
preds = []
with torch.no_grad():
    for sequences, _ in test_loader:
        # TODO: inference on the test set
        #####
        sequences = sequences.to(device)
        outputs = model(sequences)  
        batch_preds = outputs.argmax(dim=1).cpu().tolist()  
        preds.extend(batch_preds)
        #####
# save the predictions to an individual CSV file, each row contains the predicted EC number for the corresponding sequence in the test set, no need for header
preds = [ec_list[pred] for pred in preds]
df_preds = pd.DataFrame(preds)
df_preds.to_csv('test_preds_esm.csv', index=False, header=False)

Loading esm embeddings: 100%|██████████| 8000/8000 [00:04<00:00, 1612.54it/s]
Loading esm embeddings: 100%|██████████| 2000/2000 [00:01<00:00, 1801.96it/s]


Epoch 1: Train Loss=1307.9856, Train Acc=0.0739, Val Acc=0.1495
Epoch 2: Train Loss=1130.7603, Train Acc=0.1547, Val Acc=0.2695
Epoch 3: Train Loss=806.7474, Train Acc=0.3991, Val Acc=0.4800
Epoch 4: Train Loss=561.5579, Train Acc=0.6161, Val Acc=0.6785
Epoch 5: Train Loss=396.9936, Train Acc=0.7598, Val Acc=0.7785
Epoch 6: Train Loss=290.4386, Train Acc=0.8313, Val Acc=0.8340
Epoch 7: Train Loss=218.0014, Train Acc=0.8806, Val Acc=0.8910
Epoch 8: Train Loss=168.1701, Train Acc=0.9093, Val Acc=0.9080
Epoch 9: Train Loss=133.5334, Train Acc=0.9271, Val Acc=0.9295
Epoch 10: Train Loss=108.5431, Train Acc=0.9403, Val Acc=0.9440
Epoch 11: Train Loss=89.2271, Train Acc=0.9510, Val Acc=0.9420
Epoch 12: Train Loss=75.3087, Train Acc=0.9571, Val Acc=0.9500
Epoch 13: Train Loss=64.2074, Train Acc=0.9643, Val Acc=0.9615
Epoch 14: Train Loss=55.0883, Train Acc=0.9696, Val Acc=0.9605
Epoch 15: Train Loss=47.8729, Train Acc=0.9736, Val Acc=0.9705
Epoch 16: Train Loss=41.7994, Train Acc=0.9758, Val 

Loading esm embeddings: 100%|██████████| 2000/2000 [00:01<00:00, 1694.86it/s]


## Grading

### Task 1 (4 points)
Task 1 will be graded by the correctness of the function `one_hot_encode`.

### Task 2 (9 points)
Tasks 2-5 will be graded by the accuracy of the predictions made by each model on the test set.

- Accuracy >= 0.7: 9 points
- Accuracy >= 0.65: 8 points
- Accuracy >= 0.6: 7 points
- Accuracy >= 0.55: 6 points
- Accuracy < 0.55: 0 points

### Task 3 (9 points)

- Accuracy >= 0.7: 9 points
- Accuracy >= 0.65: 8 points
- Accuracy >= 0.6: 7 points
- Accuracy >= 0.55: 6 points
- Accuracy < 0.55: 0 points

### Task 4 (9 points)

- Accuracy >= 0.5: 9 points
- Accuracy >= 0.45: 7 points
- Accuracy < 0.45: 0 points

### Task 5 (9 points)

- Accuracy >= 0.97: 9 points
- Accuracy >= 0.96: 7 points
- Accuracy >= 0.95: 5 points
- Accuracy < 0.95: 0 points
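
The rubric above can be turned into a small lookup for estimating points from an accuracy value. This is only a convenience sketch: the names `RUBRIC`, `accuracy`, and `points_for` are made up here, and the real grade depends on the hidden test labels, so a value computed from validation accuracy is an estimate.

```python
# Thresholds copied from the rubric: (minimum accuracy, points), highest first.
RUBRIC = {
    "task2": [(0.70, 9), (0.65, 8), (0.60, 7), (0.55, 6)],
    "task3": [(0.70, 9), (0.65, 8), (0.60, 7), (0.55, 6)],
    "task4": [(0.50, 9), (0.45, 7)],
    "task5": [(0.97, 9), (0.96, 7), (0.95, 5)],
}

def accuracy(preds, labels):
    """Fraction of positions where the predicted EC number equals the label."""
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

def points_for(task, acc):
    """Return the points of the first threshold that acc reaches, else 0."""
    for threshold, points in RUBRIC[task]:
        if acc >= threshold:
            return points
    return 0

acc = accuracy(["3.4.19.12", "1.11.1.21", "2.3.2.27", "3.6.4.13"],
               ["3.4.19.12", "1.11.1.21", "2.3.2.27", "3.4.19.12"])
print(acc)                          # 0.75
print(points_for("task2", acc))     # 9
print(points_for("task5", 0.955))   # 5
print(points_for("task4", 0.40))    # 0
```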

## Submission

After completing all the tasks, submit the following six files to Gradescope:
- `hw3_nn.ipynb`, the notebook file with all tasks completed.
- `test_preds_attention.csv`
- `test_preds_transformer.csv`
- `test_preds_cnn.csv`
- `test_preds_esm.csv`
- `weights.csv`: the answer for Problem 3: Design a Neural Network by Hand.

Note that the four `.csv` files containing the test-set predictions are generated automatically when you run this notebook; do not change the code that saves the prediction results. **DO NOT** submit the files as a zip archive; submit them individually to Gradescope.