In [1]:
import sys

sys.path.append("..")

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path

plt.rcParams["figure.dpi"] = 300
sns.set_theme(style="whitegrid")

In [3]:
%load_ext watermark
%watermark

Last updated: 2024-02-15T17:02:12.444991-08:00

Python implementation: CPython
Python version       : 3.10.13
IPython version      : 8.18.1

Compiler    : GCC 11.2.0
OS          : Linux
Release     : 5.15.133.1-microsoft-standard-WSL2
Machine     : x86_64
Processor   : x86_64
CPU cores   : 28
Architecture: 64bit



In [4]:
data_path = Path("../data/cafa5")

In [5]:
df = pd.read_parquet(data_path / "top100_train_split.parquet")
df.head()

Unnamed: 0,Entry ID,Sequence,Index
59837,Q75PL7,MPPNSVDKTNETEYLKDNHVDYEKLIAPQASPIKHKIVVMNVIRFS...,59837
12347,E9QA15,MDDFERRRELRRQKREEMRLEAERIAYQRNDDDEEEAARERRRRAR...,12347
62582,Q84TI3,MKRLRSSDDLDFCNDKNVDGEPPNSDRPASSSHRGFFSGNNRDRGE...,62582
47476,Q24060,MRYLCVFSLTLILCCLSIKAQSLNCTRLRENCRPCTRRLVDPINDL...,47476
40185,P84282,APECGREAHCGDDCQSQVVTRDFDDRTCPKLLCCSKDGWCGNTDAN...,40185


In [6]:
from datasets import Dataset

In [7]:
dataset = Dataset.from_pandas(df, preserve_index=False)
dataset

Dataset({
    features: ['Entry ID', 'Sequence', 'Index'],
    num_rows: 70921
})

In [8]:
dataset = dataset.train_test_split(test_size=0.25, seed=0)
dataset

DatasetDict({
    train: Dataset({
        features: ['Entry ID', 'Sequence', 'Index'],
        num_rows: 53190
    })
    test: Dataset({
        features: ['Entry ID', 'Sequence', 'Index'],
        num_rows: 17731
    })
})

In [9]:
dataset["train"][0]

{'Entry ID': 'Q6U841',
 'Sequence': 'MEIKDQGAQMEPLLPTRNDEEAVVDRGGTRSILKTHFEKEDLEGHRTLFIGVHVPLGGRKSHRRHRHRGHKHRKRDRERDSGLEDGRESPSFDTPSQRVQFILGTEDDDEEHIPHDLFTELDEICWREGEDAEWRETARWLKFEEDVEDGGERWSKPYVATLSLHSLFELRSCILNGTVLLDMHANTLEEIADMVLDQQVSSGQLNEDVRHRVHEALMKQHHHQNQKKLTNRIPIVRSFADIGKKQSEPNSMDKNAGQVVSPQSAPACVENKNDVSRENSTVDFSKGLGGQQKGHTSPCGMKQRHEKGPPHQQEREVDLHFMKKIPPGAEASNILVGELEFLDRTVVAFVRLSPAVLLQGLAEVPIPTRFLFILLGPLGKGQQYHEIGRSIATLMTDEVFHDVAYKAKDRNDLVSGIDEFLDQVTVLPPGEWDPSIRIEPPKNVPSQEKRKIPAVPNGTAAHGEAEPHGGHSGPELQRTGRIFGGLILDIKRKAPYFWSDFRDAFSLQCLASFLFLYCACMSPVITFGGLLGEATEGRISAIESLFGASMTGIAYSLFGGQPLTILGSTGPVLVFEKILFKFCKEYGLSYLSLRASIGLWTATLCIILVATDASSLVCYITRFTEEAFASLICIIFIYEALEKLFELSEAYPINMHNDLELLTQYSCNCVEPHNPSNGTLKEWRESNISASDIIWENLTVSECKSLHGEYVGRACGHDHPYVPDVLFWSVILFFSTVTLSATLKQFKTSRYFPTKVRSIVSDFAVFLTILCMVLIDYAIGIPSPKLQVPSVFKPTRDDRGWFVTPLGPNPWWTVIAAIIPALLCTILIFMDQQITAVIINRKEHKLKKGCGYHLDLLMVAVMLGVCSIMGLPWFVAATVLSITHVNSLKLESECSAPGEQPKFLGIREQRVTGLMIFILMGSSVFMTSILKFIPMPVLYGVFLYMGASSLKGIQFFDRIKLFW

In [10]:
import os
import json
import torch
import torch.nn as nn
import random
from transformers import AutoTokenizer, EsmModel

In [11]:
model_name = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [12]:
def tokenize_seqs(batch):
    return tokenizer(
        batch["Sequence"],
        padding="longest",
        truncation=True,
    )

In [13]:
tokenized_dataset = dataset.map(tokenize_seqs, batched=True)
tokenized_dataset

Map:   0%|          | 0/53190 [00:00<?, ? examples/s]

Map:   0%|          | 0/17731 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['Entry ID', 'Sequence', 'Index', 'input_ids', 'attention_mask'],
        num_rows: 53190
    })
    test: Dataset({
        features: ['Entry ID', 'Sequence', 'Index', 'input_ids', 'attention_mask'],
        num_rows: 17731
    })
})

In [14]:
tokenized_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "Index"]
)
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['Entry ID', 'Sequence', 'Index', 'input_ids', 'attention_mask'],
        num_rows: 53190
    })
    test: Dataset({
        features: ['Entry ID', 'Sequence', 'Index', 'input_ids', 'attention_mask'],
        num_rows: 17731
    })
})

In [15]:
tokenized_dataset["train"].num_rows

53190

In [16]:
from torch.utils.data import DataLoader, Dataset

In [17]:
class ProteinDataset(Dataset):
    def __init__(self, dataset, split="train"):
        self.data = dataset[split]

    def __len__(self):
        return self.data.num_rows

    def __getitem__(self, i):
        return self.data[i]

In [18]:
train_dataset = ProteinDataset(tokenized_dataset)
val_dataset = ProteinDataset(tokenized_dataset, split="test")

In [19]:
# targets
targets = np.load(data_path / "train_bp_top100_targets.npy")
targets.shape

(88652, 100)

In [20]:
def collate_fn(batch):
    batch = tokenizer.pad(batch)
    batch["targets"] = torch.as_tensor(
        targets[batch["Index"].tolist()], dtype=torch.float
    )

    return batch

In [21]:
batch_size = 8
num_workers = 8

In [22]:
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=collate_fn,
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    collate_fn=collate_fn,
)

In [23]:
next(iter(train_loader))

{'Index': tensor([ 5574, 19919, 37912,  7491, 17731, 18442, 79841, 68678]), 'input_ids': tensor([[ 0, 20, 13,  ...,  1,  1,  1],
        [ 0, 20,  9,  ...,  1,  1,  1],
        [ 0, 20,  7,  ...,  1,  1,  1],
        ...,
        [ 0, 20,  8,  ...,  1,  1,  1],
        [ 0, 20,  8,  ...,  1,  1,  1],
        [ 0, 20, 14,  ...,  1,  1,  1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'targets': tensor([[1., 1., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
   

## 🏃Training

#### 🤖 Model

In [24]:
llm = EsmModel.from_pretrained(model_name)
embedding_dim = llm.config.hidden_size
embedding_dim

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


480

In [25]:
class FinetunedESM(nn.Module):
    def __init__(self, llm, dropout_p, embedding_dim, num_classes):
        super().__init__()
        self.llm = llm
        self.dropout_p = dropout_p
        self.embedding_dim = embedding_dim
        self.num_classes = num_classes
        self.dropout = nn.Dropout(dropout_p)
        self.pre_classifier = nn.Linear(embedding_dim, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average the embedding of all amino acids in a sequence"""

        # expand the mask
        expanded_mask = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.shape).float()
        )

        # sum unmasked token embeddings
        sum_embeddings = torch.sum(token_embeddings * expanded_mask, dim=1)

        # number of unmasked tokens for each sequence
        # set a min value to avoid divide by zero
        num_tokens = torch.clamp(expanded_mask.sum(1), min=1e-9)

        # divide
        mean_embeddings = sum_embeddings / num_tokens
        return mean_embeddings

    def forward(self, batch):
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]

        # per token representations from the last layer
        token_embeddings = self.llm(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        # average per token representations
        mean_embeddings = self.mean_pooling(token_embeddings, attention_mask)

        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/distilbert/modeling_distilbert.py
        mean_embeddings = self.pre_classifier(mean_embeddings)  # (bs, embedding_dim)
        mean_embeddings = nn.ReLU()(mean_embeddings)
        mean_embeddings = self.dropout(mean_embeddings)

        logits = self.classifier(mean_embeddings)  # (bs, num_classes)
        return logits

    @torch.inference_mode()
    def predict(self, batch):
        self.eval()
        y = self(batch)
        return y.cpu().numpy()

    def save(self, dp):
        with open(Path(dp, "args.json"), "w") as fp:
            contents = {
                "dropout_p": self.dropout_p,
                "embedding_dim": self.embedding_dim,
                "num_classes": self.num_classes,
            }
            json.dump(contents, fp, indent=4, sort_keys=False)

        torch.save(self.state_dict(), Path(dp) / "model.pt")

    @classmethod
    def load(cls, esm_model, args_fp, state_dict_fp):
        with open(args_fp, "r") as fp:
            kwargs = json.load(fp=fp)

        llm = EsmModel.from_pretrained(esm_model)
        model = cls(llm=llm, **kwargs)
        model.load_state_dict(
            torch.load(state_dict_fp, map_location=torch.device("cpu"))
        )
        return model

In [26]:
num_classes = 100

model = FinetunedESM(
    llm=llm, dropout_p=0.1, embedding_dim=embedding_dim, num_classes=num_classes
)
print(model.parameters)

<bound method Module.parameters of FinetunedESM(
  (llm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 480, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 480, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-11): 12 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=480, out_features=480, bias=True)
              (key): Linear(in_features=480, out_features=480, bias=True)
              (value): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNo

In [27]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [28]:
print(f"# Trainable Parameters: {count_parameters(model)}")

# Trainable Parameters: 34271861


In [29]:
import math
from functools import partial

In [30]:
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        self.W_a = nn.Parameter(torch.randn(in_dim, rank) / math.sqrt(rank))
        self.W_b = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        x = x @ self.W_a  # batch * rank
        x = x @ self.W_b  # batch * out_dim
        return self.alpha * x

In [31]:
class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        return self.linear(x) + self.lora(x)

In [32]:
def apply_lora(
    model,
    lora_rank=8,
    lora_alpha=16,
    lora_query=True,
    lora_key=False,
    lora_value=True,
    lora_projection=False,
    lora_mlp=False,
    lora_head=True,
):
    # freeze model layers
    for param in model.parameters():
        param.requires_grad = False

    # for adding lora to linear layers
    linear_with_lora = partial(LinearWithLoRA, rank=lora_rank, alpha=lora_alpha)

    # iterate through each transfomer layer
    for layer in model.llm.encoder.layer:
        if lora_query:
            layer.attention.self.query = linear_with_lora(layer.attention.self.query)

        if lora_key:
            layer.attention.self.key = linear_with_lora(layer.attention.self.key)

        if lora_value:
            layer.attention.self.value = linear_with_lora(layer.attention.self.value)

        if lora_projection:
            layer.attention.output.dense = linear_with_lora(
                layer.attention.output.dense
            )

        if lora_mlp:
            layer.output.dense = linear_with_lora(layer.output.dense)
            layer.output.dense = linear_with_lora(layer.output.dense)

    if lora_head:
        model.pre_classifier = linear_with_lora(model.pre_classifier)
        model.classifier = linear_with_lora(model.classifier)

In [33]:
apply_lora(model)
print(f"# Trainable Parameters: {count_parameters(model)}")

# Trainable Parameters: 196640


In [34]:
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from torchmetrics.functional.classification import multilabel_f1_score

In [35]:
L.seed_everything(0)

Seed set to 0


0

In [36]:
class ESMLightningModule(L.LightningModule):
    def __init__(self, model, learning_rate=1e-3):
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.loss_fn(logits, batch["targets"])
        self.log("train_loss", loss)
        return loss  # this is passed to the optimizer for training

    def validation_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.loss_fn(logits, batch["targets"])
        self.log("val_loss", loss, prog_bar=True)

        f1_score = multilabel_f1_score(
            logits, batch["targets"].type(torch.int), num_classes
        )
        self.log("val_f1_score", f1_score, prog_bar=True)

    def test_step(self, batch, batch_idx):
        logits = self(batch)

        f1_score = multilabel_f1_score(
            logits, batch["targets"].type(torch.int), num_classes
        )
        self.log("f1_score", f1_score, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [37]:
lightning_model = ESMLightningModule(model)

In [38]:
callbacks = [ModelCheckpoint(save_top_k=1, mode="max", monitor="val_f1_score")]
logger = CSVLogger(save_dir="logs/", name="esm_model_lora")

In [39]:
trainer = L.Trainer(
    max_epochs=5,
    callbacks=callbacks,
    accelerator="gpu",
    precision="16-mixed",
    devices=1,
    logger=logger,
    deterministic=True,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [40]:
import time

In [41]:
start = time.time()

trainer.fit(
    model=lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader
)
end = time.time()
print(f"Training Time: {(end-start)/60:.2f} min")

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params
----------------------------------------------
0 | model   | FinetunedESM      | 34.5 M
1 | loss_fn | BCEWithLogitsLoss | 0     
----------------------------------------------
196 K     Trainable params
34.3 M    Non-trainable params
34.5 M    Total params
137.874   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


Training Time: 170.78 min
