In [1]:
import sys

sys.path.append("..")

In [2]:
import pandas as pd
import numpy as np

from pathlib import Path

In [3]:
%load_ext watermark
%watermark

Last updated: 2024-02-17T01:43:23.810230+00:00

Python implementation: CPython
Python version       : 3.10.13
IPython version      : 8.21.0

Compiler    : GCC 12.3.0
OS          : Linux
Release     : 5.15.0-1033-aws
Machine     : x86_64
Processor   : x86_64
CPU cores   : 48
Architecture: 64bit



In [4]:
data_path = Path("../data/cafa5")

In [5]:
df = pd.read_parquet(data_path / "top100_train_split.parquet")
df.head()

Unnamed: 0,Entry ID,Sequence,Index
59837,Q75PL7,MPPNSVDKTNETEYLKDNHVDYEKLIAPQASPIKHKIVVMNVIRFS...,59837
12347,E9QA15,MDDFERRRELRRQKREEMRLEAERIAYQRNDDDEEEAARERRRRAR...,12347
62582,Q84TI3,MKRLRSSDDLDFCNDKNVDGEPPNSDRPASSSHRGFFSGNNRDRGE...,62582
47476,Q24060,MRYLCVFSLTLILCCLSIKAQSLNCTRLRENCRPCTRRLVDPINDL...,47476
40185,P84282,APECGREAHCGDDCQSQVVTRDFDDRTCPKLLCCSKDGWCGNTDAN...,40185


In [6]:
from datasets import Dataset
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
dataset = Dataset.from_pandas(df, preserve_index=False)
dataset

Dataset({
    features: ['Entry ID', 'Sequence', 'Index'],
    num_rows: 70921
})

In [8]:
dataset = dataset.train_test_split(test_size=0.25, seed=0)
dataset

DatasetDict({
    train: Dataset({
        features: ['Entry ID', 'Sequence', 'Index'],
        num_rows: 53190
    })
    test: Dataset({
        features: ['Entry ID', 'Sequence', 'Index'],
        num_rows: 17731
    })
})

In [9]:
dataset["train"][0]

{'Entry ID': 'Q6U841',
 'Sequence': 'MEIKDQGAQMEPLLPTRNDEEAVVDRGGTRSILKTHFEKEDLEGHRTLFIGVHVPLGGRKSHRRHRHRGHKHRKRDRERDSGLEDGRESPSFDTPSQRVQFILGTEDDDEEHIPHDLFTELDEICWREGEDAEWRETARWLKFEEDVEDGGERWSKPYVATLSLHSLFELRSCILNGTVLLDMHANTLEEIADMVLDQQVSSGQLNEDVRHRVHEALMKQHHHQNQKKLTNRIPIVRSFADIGKKQSEPNSMDKNAGQVVSPQSAPACVENKNDVSRENSTVDFSKGLGGQQKGHTSPCGMKQRHEKGPPHQQEREVDLHFMKKIPPGAEASNILVGELEFLDRTVVAFVRLSPAVLLQGLAEVPIPTRFLFILLGPLGKGQQYHEIGRSIATLMTDEVFHDVAYKAKDRNDLVSGIDEFLDQVTVLPPGEWDPSIRIEPPKNVPSQEKRKIPAVPNGTAAHGEAEPHGGHSGPELQRTGRIFGGLILDIKRKAPYFWSDFRDAFSLQCLASFLFLYCACMSPVITFGGLLGEATEGRISAIESLFGASMTGIAYSLFGGQPLTILGSTGPVLVFEKILFKFCKEYGLSYLSLRASIGLWTATLCIILVATDASSLVCYITRFTEEAFASLICIIFIYEALEKLFELSEAYPINMHNDLELLTQYSCNCVEPHNPSNGTLKEWRESNISASDIIWENLTVSECKSLHGEYVGRACGHDHPYVPDVLFWSVILFFSTVTLSATLKQFKTSRYFPTKVRSIVSDFAVFLTILCMVLIDYAIGIPSPKLQVPSVFKPTRDDRGWFVTPLGPNPWWTVIAAIIPALLCTILIFMDQQITAVIINRKEHKLKKGCGYHLDLLMVAVMLGVCSIMGLPWFVAATVLSITHVNSLKLESECSAPGEQPKFLGIREQRVTGLMIFILMGSSVFMTSILKFIPMPVLYGVFLYMGASSLKGIQFFDRIKLFW

In [10]:
import os
import json
import torch
import torch.nn as nn
import random
from transformers import AutoTokenizer, EsmModel

In [11]:
model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer_config.json: 100%|██████████| 95.0/95.0 [00:00<00:00, 686kB/s]
vocab.txt: 100%|██████████| 93.0/93.0 [00:00<00:00, 799kB/s]
special_tokens_map.json: 100%|██████████| 125/125 [00:00<00:00, 1.10MB/s]


In [12]:
def tokenize_seqs(batch):
    return tokenizer(
        batch["Sequence"],
        padding="longest",
        truncation=True,
    )

In [13]:
tokenized_dataset = dataset.map(tokenize_seqs, batched=True)
tokenized_dataset

Map:   0%|          | 0/53190 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
                                                                  

DatasetDict({
    train: Dataset({
        features: ['Entry ID', 'Sequence', 'Index', 'input_ids', 'attention_mask'],
        num_rows: 53190
    })
    test: Dataset({
        features: ['Entry ID', 'Sequence', 'Index', 'input_ids', 'attention_mask'],
        num_rows: 17731
    })
})

In [14]:
tokenized_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "Index"]
)
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['Entry ID', 'Sequence', 'Index', 'input_ids', 'attention_mask'],
        num_rows: 53190
    })
    test: Dataset({
        features: ['Entry ID', 'Sequence', 'Index', 'input_ids', 'attention_mask'],
        num_rows: 17731
    })
})

In [15]:
tokenized_dataset["train"].num_rows

53190

In [16]:
from torch.utils.data import DataLoader, Dataset

In [17]:
class ProteinDataset(Dataset):
    def __init__(self, dataset, split="train"):
        self.data = dataset[split]

    def __len__(self):
        return self.data.num_rows

    def __getitem__(self, i):
        return self.data[i]

In [18]:
train_dataset = ProteinDataset(tokenized_dataset)
val_dataset = ProteinDataset(tokenized_dataset, split="test")

In [19]:
# targets
targets = np.load(data_path / "train_bp_top100_targets.npy")
targets.shape

(88652, 100)

In [20]:
def collate_fn(batch):
    batch = tokenizer.pad(batch)
    batch["targets"] = torch.as_tensor(
        targets[batch["Index"].tolist()], dtype=torch.float
    )

    return batch

In [21]:
batch_size = 8
num_workers = 8

In [22]:
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=collate_fn,
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    collate_fn=collate_fn,
)

In [23]:
next(iter(train_loader))

{'Index': tensor([25674, 20809, 73439, 86298, 42191, 25796, 11249, 22095]), 'input_ids': tensor([[ 0, 20,  6,  ...,  1,  1,  1],
        [ 0, 20, 14,  ...,  1,  1,  1],
        [ 0, 20, 22,  ...,  1,  1,  1],
        ...,
        [ 0, 20,  8,  ...,  1,  1,  1],
        [ 0, 20,  8,  ...,  1,  1,  1],
        [ 0, 20,  8,  ...,  1,  1,  1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'targets': tensor([[1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1.,
         1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,
         1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1.,
   

## 🏃Training

#### 🤖 Model

In [24]:
llm = EsmModel.from_pretrained(model_name)
embedding_dim = llm.config.hidden_size
embedding_dim

config.json: 100%|██████████| 724/724 [00:00<00:00, 4.10MB/s]
model.safetensors: 100%|██████████| 2.61G/2.61G [00:33<00:00, 76.9MB/s]
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


1280

In [25]:
class FinetunedESM(nn.Module):
    def __init__(self, llm, dropout_p, embedding_dim, num_classes):
        super().__init__()
        self.llm = llm
        self.dropout_p = dropout_p
        self.embedding_dim = embedding_dim
        self.num_classes = num_classes
        self.dropout = nn.Dropout(dropout_p)
        self.pre_classifier = nn.Linear(embedding_dim, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average the embedding of all amino acids in a sequence"""

        # expand the mask
        expanded_mask = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.shape).float()
        )

        # sum unmasked token embeddings
        sum_embeddings = torch.sum(token_embeddings * expanded_mask, dim=1)

        # number of unmasked tokens for each sequence
        # set a min value to avoid divide by zero
        num_tokens = torch.clamp(expanded_mask.sum(1), min=1e-9)

        # divide
        mean_embeddings = sum_embeddings / num_tokens
        return mean_embeddings

    def forward(self, batch):
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]

        # per token representations from the last layer
        token_embeddings = self.llm(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        # average per token representations
        mean_embeddings = self.mean_pooling(token_embeddings, attention_mask)

        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/distilbert/modeling_distilbert.py
        mean_embeddings = self.pre_classifier(mean_embeddings)  # (bs, embedding_dim)
        mean_embeddings = nn.ReLU()(mean_embeddings)
        mean_embeddings = self.dropout(mean_embeddings)

        logits = self.classifier(mean_embeddings)  # (bs, num_classes)
        return logits

    @torch.inference_mode()
    def predict(self, batch):
        self.eval()
        y = self(batch)
        return y.cpu().numpy()

    def save(self, dp):
        with open(Path(dp, "args.json"), "w") as fp:
            contents = {
                "dropout_p": self.dropout_p,
                "embedding_dim": self.embedding_dim,
                "num_classes": self.num_classes,
            }
            json.dump(contents, fp, indent=4, sort_keys=False)

        torch.save(self.state_dict(), Path(dp) / "model.pt")

    @classmethod
    def load(cls, esm_model, args_fp, state_dict_fp):
        with open(args_fp, "r") as fp:
            kwargs = json.load(fp=fp)

        llm = EsmModel.from_pretrained(esm_model)
        model = cls(llm=llm, **kwargs)
        model.load_state_dict(
            torch.load(state_dict_fp, map_location=torch.device("cpu"))
        )
        return model

In [26]:
num_classes = 100

model = FinetunedESM(
    llm=llm, dropout_p=0.1, embedding_dim=embedding_dim, num_classes=num_classes
)
print(model.parameters)

<bound method Module.parameters of FinetunedESM(
  (llm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm

In [27]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [28]:
print(f"# Trainable Parameters: {count_parameters(model)}")

# Trainable Parameters: 654121721


In [29]:
# freeze model layers
for param in model.parameters():
    param.requires_grad = False

for param in model.pre_classifier.parameters():
    param.requires_grad = True

for param in model.classifier.parameters():
    param.requires_grad = True

In [30]:
print(f"# Trainable Parameters: {count_parameters(model)}")

# Trainable Parameters: 1767780


In [31]:
for name, param in model.named_parameters():
    print(f"{name}: {param.requires_grad}")

llm.embeddings.word_embeddings.weight: False
llm.embeddings.position_embeddings.weight: False
llm.encoder.layer.0.attention.self.query.weight: False
llm.encoder.layer.0.attention.self.query.bias: False
llm.encoder.layer.0.attention.self.key.weight: False
llm.encoder.layer.0.attention.self.key.bias: False
llm.encoder.layer.0.attention.self.value.weight: False
llm.encoder.layer.0.attention.self.value.bias: False
llm.encoder.layer.0.attention.output.dense.weight: False
llm.encoder.layer.0.attention.output.dense.bias: False
llm.encoder.layer.0.attention.LayerNorm.weight: False
llm.encoder.layer.0.attention.LayerNorm.bias: False
llm.encoder.layer.0.intermediate.dense.weight: False
llm.encoder.layer.0.intermediate.dense.bias: False
llm.encoder.layer.0.output.dense.weight: False
llm.encoder.layer.0.output.dense.bias: False
llm.encoder.layer.0.LayerNorm.weight: False
llm.encoder.layer.0.LayerNorm.bias: False
llm.encoder.layer.1.attention.self.query.weight: False
llm.encoder.layer.1.attention.s

In [48]:
import pytorch_lightning as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from torchmetrics.functional.classification import multilabel_f1_score

In [49]:
pl.seed_everything(0)

Seed set to 0


0

In [50]:
class ESMLightningModule(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-3):
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.loss_fn(logits, batch["targets"])
        self.log("train_loss", loss)
        return loss  # this is passed to the optimizer for training

    def validation_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.loss_fn(logits, batch["targets"])
        self.log("val_loss", loss, prog_bar=True)

        f1_score = multilabel_f1_score(
            logits, batch["targets"].type(torch.int), num_classes
        )
        self.log("val_f1_score", f1_score, prog_bar=True)

    def test_step(self, batch, batch_idx):
        logits = self(batch)

        f1_score = multilabel_f1_score(
            logits, batch["targets"].type(torch.int), num_classes
        )
        self.log("f1_score", f1_score, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [53]:
lightning_model = ESMLightningModule(model, learning_rate=1e-3)

In [54]:
callbacks = [ModelCheckpoint(save_top_k=1, mode="max", monitor="val_f1_score")]
logger = CSVLogger(save_dir="logs/", name="esm_model_head")

In [56]:
from pytorch_lightning.strategies import FSDPStrategy

In [57]:
strategy = FSDPStrategy(
    auto_wrap_policy = {
        nn.TransformerEncoderLayer,
        nn.TransformerDecoderLayer,
    },
    activation_checkpointing_policy = None
)

In [58]:
trainer = pl.Trainer(
    max_epochs=10,
    callbacks=callbacks,
    accelerator="cuda",
    precision="16-mixed",
    devices=4,
    strategy=strategy,
    logger=logger,
    deterministic=True,
)

MisconfigurationException: `Trainer(strategy=<pytorch_lightning.strategies.fsdp.FSDPStrategy object at 0x7fd517790d00>)` is not compatible with an interactive environment. Run your code as a script, or choose a notebook-compatible strategy: `Trainer(strategy='ddp_notebook')`. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.

In [None]:
trainer.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

In [44]:
import time

In [39]:
start = time.time()

trainer.fit(
    model=lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader
)
end = time.time()
print(f"Training Time: {(end-start)/60:.2f} min")

You are using a CUDA device ('NVIDIA A10G') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[rank: 0] Seed set to 0
You are using a CUDA device ('NVIDIA A10G') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
You are using a CUDA device ('NVIDIA A10G') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For mo

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
    fn(i, *args)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/multiprocessing.py", line 173, in _wrapping_function
    results = function(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
    self._run_sanity_check()
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1059, in _run_sanity_check
    val_loop.run()
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 411, in validation_step
    return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 642, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1523, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 635, in wrapped_forward
    out = method(*_args, **_kwargs)
  File "/tmp/ipykernel_2842/3126010323.py", line 19, in validation_step
    logits = self(batch)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_2842/3126010323.py", line 10, in forward
    return self.model(batch)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_2842/2619312275.py", line 35, in forward
    token_embeddings = self.llm(
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/transformers/models/esm/modeling_esm.py", line 914, in forward
    encoder_outputs = self.encoder(
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/transformers/models/esm/modeling_esm.py", line 619, in forward
    layer_outputs = layer_module(
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/transformers/models/esm/modeling_esm.py", line 509, in forward
    self_attention_outputs = self.attention(
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/transformers/models/esm/modeling_esm.py", line 443, in forward
    self_outputs = self.self(
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/ml/lib/python3.10/site-packages/transformers/models/esm/modeling_esm.py", line 347, in forward
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 21.43 GiB. GPU 1 has a total capacity of 22.19 GiB of which 16.56 GiB is free. Including non-PyTorch memory, this process has 5.63 GiB memory in use. Of the allocated memory 4.81 GiB is allocated by PyTorch, and 369.27 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
