In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

from sklearn.metrics import f1_score

! pip install fair-esm
import esm

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
root = "/content/drive/MyDrive/CS279_Project/"

In [None]:
def read_file(file):
  with open(file, "r") as f:
    sequence = f.readlines()[1].strip()
  return sequence

In [None]:
def read_data(path):
  data = []
  fastas = [i for i in os.listdir(path) if i.endswith("fasta")]
  for fasta in fastas:
    print(fasta)
    prot_id = fasta[:fasta.rindex(".")]
    # prot_files = [i for i in os.listdir(root + "training/") if i.startswith(prot_id)]
    sequence = read_file(path + prot_id + ".fasta")
    secondary_structure = read_file(path + prot_id + ".dssp")
    assert len(sequence) == len(secondary_structure)
    data.append([prot_id, sequence, secondary_structure])
  return data

In [None]:
data = read_data(root + "retr231/training/")
len(data)

d1tf5a2.fasta
d2a65a1.fasta
d1n62c1.fasta
e1qd6.1C.fasta
d1kafa_.fasta
d3u9wa1.fasta
d3bvua2.fasta
d1px5a1.fasta
d1or7c_.fasta
d1wuil_.fasta
d2p6va1.fasta
d2baya_.fasta
d2w1va_.fasta
d3lyea_.fasta
d3piwa_.fasta
d1nkda_.fasta
d2gsva1.fasta
d1dtdb_.fasta
d3u1ua_.fasta
d2rkqa_.fasta
d2f2ha1.fasta
d1hf2a1.fasta
d1m15a2.fasta
d1usub_.fasta
d1nvus_.fasta
d2hq4a1.fasta
d3o6ca_.fasta
d1lg7a_.fasta
d1vq8v1.fasta
d1k6ka_.fasta
d1cy5a_.fasta
d1vr4a1.fasta
d1uoya_.fasta
d1cipa1.fasta
d4dnda_.fasta
d2pspa1.fasta
d1g8ea_.fasta
d3e9va1.fasta
d1c5ea_.fasta
d1t6aa_.fasta
d1izma_.fasta
d1a62a1.fasta
d1ji7a_.fasta
d2asba2.fasta
d3ar4a4.fasta
d3swoa2.fasta
d2aska_.fasta
d2fmaa_.fasta
d1l0qa1.fasta
d1llaa1.fasta
d3bhpa_.fasta
d1wy6a_.fasta
d3orsa_.fasta
d1bx7a_.fasta
d3w5ha1.fasta
d2j97a_.fasta
d1dk8a_.fasta
d2a26a1.fasta
d2je6i2.fasta
d3h7ia2.fasta
d1gpja3.fasta
d1sg4a1.fasta
d2xu3a_.fasta
d2a9ua1.fasta
d1boua_.fasta
d1mkya3.fasta
d1l3la2.fasta
d3pkza_.fasta
d1oisa_.fasta
d1ufya_.fasta
d2hnua_.fasta
d1eu1

1348

In [None]:
# pickle.dump(data, open(root + "all_data.dat", "wb"))

In [None]:
data = pickle.load(open(root + "all_data.dat", "rb"))

In [14]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [None]:
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [None]:
model.to(device)

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [None]:
def encode_secondary_structure(string):
  map = {"-": [1,0,0], "H": [0,1,0], "E": [0,0,1]}
  return torch.tensor([map[char] for char in string], dtype=torch.float32)

In [None]:
def embed_and_encode(data, model, device):
  for i in range(len(data)):
    tup = (data[i][0], data[i][1])
    _, _, batch_tokens = batch_converter([tup])
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
      results = model(batch_tokens, repr_layers=[33], return_contacts=False)
    token_representations = results["representations"][33]
    token_representations = token_representations[0,1:-1,:].cpu()
    data[i][1] = token_representations
    data[i][2] = encode_secondary_structure(data[i][2])
    if (i + 1) % 100 == 0:
      print("Finished with", i+1)
  return data

In [None]:
data = embed_and_encode(data, model, device)

Finished with 99
Finished with 199
Finished with 299
Finished with 399
Finished with 499
Finished with 599
Finished with 699
Finished with 799
Finished with 899
Finished with 999
Finished with 1099
Finished with 1199
Finished with 1299


In [None]:
# pickle.dump(data, open(root + "data_embedded_and_encoded.dat", "wb"))

In [4]:
data = pickle.load(open(root + "data_embedded_and_encoded.dat", "rb"))

In [5]:
class EmbeddingDataset(Dataset):
  def __init__(self, data):
    self.data = data

  def __getitem__(self, index):
    prot_id, embedding, label = self.data[index]
    return embedding, label

  def __len__(self):
    return len(self.data)

def collate_fn(batch):
  padded_embeddings = pad_sequence([i[0] for i in batch], batch_first=True)
  padded_labels = pad_sequence([i[1] for i in batch], batch_first=True)
  mask = (padded_embeddings != 0).any(dim=-1)
  return padded_embeddings, padded_labels, mask

In [6]:
np.random.seed(0)
train_idx = np.random.choice(len(data), size=int(len(data) * 0.7), replace=False)
val_idx = [i for i in range(len(data)) if not i in train_idx]

In [7]:
train_dataset = EmbeddingDataset(list(map(data.__getitem__, train_idx)))
val_dataset = EmbeddingDataset(list(map(data.__getitem__, val_idx)))
train_dataloader = DataLoader(train_dataset, batch_size=32, collate_fn=collate_fn, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn, shuffle=False)

# Sanity Check
for batch in train_dataloader:
    sequences, labels, mask = batch
    print(sequences.shape)
    print(labels.shape)
    print(mask.shape)

torch.Size([32, 401, 1280])
torch.Size([32, 401, 3])
torch.Size([32, 401])
torch.Size([32, 483, 1280])
torch.Size([32, 483, 3])
torch.Size([32, 483])
torch.Size([32, 557, 1280])
torch.Size([32, 557, 3])
torch.Size([32, 557])
torch.Size([32, 328, 1280])
torch.Size([32, 328, 3])
torch.Size([32, 328])
torch.Size([32, 413, 1280])
torch.Size([32, 413, 3])
torch.Size([32, 413])
torch.Size([32, 469, 1280])
torch.Size([32, 469, 3])
torch.Size([32, 469])
torch.Size([32, 539, 1280])
torch.Size([32, 539, 3])
torch.Size([32, 539])
torch.Size([32, 612, 1280])
torch.Size([32, 612, 3])
torch.Size([32, 612])
torch.Size([32, 451, 1280])
torch.Size([32, 451, 3])
torch.Size([32, 451])
torch.Size([32, 759, 1280])
torch.Size([32, 759, 3])
torch.Size([32, 759])
torch.Size([32, 452, 1280])
torch.Size([32, 452, 3])
torch.Size([32, 452])
torch.Size([32, 715, 1280])
torch.Size([32, 715, 3])
torch.Size([32, 715])
torch.Size([32, 522, 1280])
torch.Size([32, 522, 3])
torch.Size([32, 522])
torch.Size([32, 410, 1280

In [9]:
class ProteinSSModel(nn.Module):
  def __init__(self, embedding_dim, out_dim, body_num=5, dropout=0.4):
    super(ProteinSSModel, self).__init__()

    self.conv = nn.Sequential(
      nn.Conv1d(embedding_dim, 128, kernel_size=1),
      nn.GELU(),
      nn.BatchNorm1d(128)
    )
    self.body_num = body_num
    self.body1 = nn.ModuleList([
      nn.Sequential(
          nn.Conv1d(128, 128, kernel_size=1, padding=0),
          nn.GELU(),
          nn.BatchNorm1d(128),
      ) for _ in range(body_num)
    ])
    self.bn_list = nn.ModuleList([nn.BatchNorm1d(128) for _ in range(body_num)])
    self.dropout = nn.Dropout(p=dropout)
    self.tail = nn.Sequential(
      nn.Conv1d(128, 64, kernel_size=3, padding=1),
      nn.Conv1d(64, 32, kernel_size=3, padding=1),
      nn.GELU(),
      nn.BatchNorm1d(32),
      nn.Conv1d(32, 16, kernel_size=3, padding=1),
      nn.Conv1d(16, out_dim, kernel_size=1)
    )

  def forward(self, x, mask=None):
    x = x.transpose(-1, -2).contiguous()
    x = self.conv(x)
    for bind in range(self.body_num):
      x1 = self.body1[bind](x)
      x = self.bn_list[bind](x + x1)
    x = self.dropout(x)
    x = self.tail(x)
    x = x.transpose(-1, -2).contiguous()
    if mask is not None:
      x = x * mask.unsqueeze(-1)
    return x
    # torch.softmax(x, dim=-1)

In [24]:
def categorical_accuracy(predictions, labels, masks):
  correct = (predictions.argmax(dim=-1) == labels.argmax(dim=-1))
  correct = correct * masks
  total_valid = masks.sum()
  accuracy = correct.sum().float() / total_valid
  return accuracy.item()

def f1_score_macro(predictions, labels, masks, num_classes=3):
    predicted_classes = predictions.argmax(dim=-1)  # Shape: [batch_size, max_seq_len]
    labels = labels.argmax(dim=-1)
    pred_flat = predicted_classes[masks.bool()].cpu().numpy()
    labels_flat = labels[masks.bool()].cpu().numpy()
    f1 = f1_score(labels_flat, pred_flat, labels=list(range(num_classes)), average="macro")
    return f1

In [17]:
# Model input: [batch_size, max_seq_len, embedding_dim]
# Labels shape: [batch_size, max_seq_len, 3] (one-hot encoded for 3 classes)
# Mask shape: [batch_size, max_seq_len] (1 for valid positions, 0 for padding)

def train(model, device, train_loader, val_loader, epochs=20, lr=5e-4):
  model = model.to(device)

  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.CrossEntropyLoss(reduction="none")

  for epoch in range(epochs):
    # Train
    model.train()
    train_loss = 0.0
    train_acc = 0.0
    train_f1 = 0.0
    for inputs, labels, masks in train_loader:
      inputs = inputs.to(device)  # [batch_size, max_seq_len, embedding_dim]
      labels = labels.to(device)  # [batch_size, max_seq_len, 3]
      masks = masks.to(device)   # [batch_size, max_seq_len]

      outputs = model(inputs)  # Outputs: [batch_size, max_seq_len, 3]
      outputs_flat = outputs.view(-1, labels.shape[-1])  # [batch_size * max_seq_len, 3]
      labels_flat = labels.view(-1, labels.shape[-1]).argmax(dim=-1)
      masks_flat = masks.view(-1)  # [batch_size * max_seq_len]

      loss = criterion(outputs_flat, labels_flat)  # [batch_size * max_seq_len]
      loss = (loss * masks_flat).sum() / masks_flat.sum()

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      train_loss += loss.item()
      acc = categorical_accuracy(outputs, labels, masks)
      f1 = f1_score_macro(outputs, labels, masks)
      train_acc += acc
      train_f1 += f1

    avg_train_loss = train_loss / len(train_loader)
    avg_train_acc = train_acc / len(train_loader)
    avg_train_f1 = train_f1 / len(train_loader)
    print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Accuracy: {avg_train_acc:.4f}, F1 Macro: {avg_train_f1:.4f}", end="\t")

    # Validate
    model.eval()
    val_loss = 0.0
    val_acc = 0.0
    val_f1 = 0.0
    with torch.no_grad():
      for inputs, labels, masks in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        outputs = model(inputs)
        outputs_flat = outputs.view(-1, labels.shape[-1])
        labels_flat = labels.view(-1, labels.shape[-1]).argmax(dim=-1)
        masks_flat = masks.view(-1)

        loss = criterion(outputs_flat, labels_flat)
        loss = (loss * masks_flat).sum() / masks_flat.sum()
        val_loss += loss.item()
        acc = categorical_accuracy(outputs, labels, masks)
        f1 = f1_score_macro(outputs, labels, masks)
        val_acc += acc
        val_f1 += f1

    avg_val_loss = val_loss / len(val_loader)
    avg_val_acc = val_acc / len(val_loader)
    avg_val_f1 = val_f1 / len(val_loader)
    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {avg_val_acc:.4f}, F1 Macro: {avg_val_f1:.4f}")

  return model

In [18]:
ssmodel = ProteinSSModel(embedding_dim=data[0][1].shape[-1],
                         out_dim=data[0][2].shape[-1],
                         body_num=10)

In [19]:
ssmodel = train(ssmodel, device, train_dataloader, val_dataloader, epochs=10, lr=5e-4)

Epoch 1/10, Train Loss: 0.6682, Accuracy: 0.7214, F1 Macro: 0.6714	Validation Loss: 0.7728, Accuracy: 0.8141, F1 Macro: 0.7975
Epoch 2/10, Train Loss: 0.3854, Accuracy: 0.8527, F1 Macro: 0.8476	Validation Loss: 0.3629, Accuracy: 0.8591, F1 Macro: 0.8515
Epoch 3/10, Train Loss: 0.3404, Accuracy: 0.8673, F1 Macro: 0.8627	Validation Loss: 0.3551, Accuracy: 0.8600, F1 Macro: 0.8521
Epoch 4/10, Train Loss: 0.3125, Accuracy: 0.8762, F1 Macro: 0.8721	Validation Loss: 0.3520, Accuracy: 0.8618, F1 Macro: 0.8547
Epoch 5/10, Train Loss: 0.2892, Accuracy: 0.8857, F1 Macro: 0.8823	Validation Loss: 0.3694, Accuracy: 0.8579, F1 Macro: 0.8515
Epoch 6/10, Train Loss: 0.2627, Accuracy: 0.8949, F1 Macro: 0.8912	Validation Loss: 0.3814, Accuracy: 0.8553, F1 Macro: 0.8491
Epoch 7/10, Train Loss: 0.2416, Accuracy: 0.9035, F1 Macro: 0.8996	Validation Loss: 0.4174, Accuracy: 0.8451, F1 Macro: 0.8407
Epoch 8/10, Train Loss: 0.2138, Accuracy: 0.9162, F1 Macro: 0.9131	Validation Loss: 0.4282, Accuracy: 0.8457, F

In [None]:
torch.cuda.empty_cache()

In [21]:
parameters = filter(lambda p: p.requires_grad, ssmodel.parameters())
params = sum([np.prod(p.size()) for p in parameters])

In [22]:
params

366947

In [25]:
def categorical_accuracy_verbose(predictions, labels, masks):
  correct = (predictions.argmax(dim=-1) == labels.argmax(dim=-1))
  correct = correct * masks
  total_valid = masks.sum()
  accuracy = correct.sum().float() / total_valid
  return accuracy.item(), predictions.argmax(dim=-1).numpy(), labels.argmax(dim=-1).numpy()

In [69]:
def class_wise_accuracy(predictions, labels, masks, num_classes=3):
  predicted_classes = predictions.argmax(dim=-1)
  true_classes = labels.argmax(dim=-1)
  class_accuracies = []
  for cls in range(num_classes):
    class_mask = (true_classes == cls) * masks
    correct = ((predicted_classes == cls) & (true_classes == cls)) * class_mask
    correct_count = correct.sum().float()
    total_count = class_mask.sum().float()
    accuracy = correct_count / (total_count + 1e-5)
    class_accuracies.append(accuracy.item())
  return class_accuracies

In [71]:
for inputs, labels, masks in DataLoader(val_dataset, batch_size=len(val_dataset), collate_fn=collate_fn, shuffle=False):
  outputs = ssmodel(inputs.to(device)).detach().cpu()
  predictions = torch.softmax(outputs, dim=-1)
  class_accuracies = class_wise_accuracy(predictions, labels, masks)
  print(class_accuracies, inputs.shape[0])

[0.8537600636482239, 0.8966524004936218, 0.7440081834793091] 405
