In [1]:
import ray
import sys

# Make modules in the parent directory importable from this notebook.
sys.path.append("..")

In [2]:
# Shut down any Ray instance left over from a previous run of this cell, then start a fresh one.
if ray.is_initialized():
    ray.shutdown()
ray.init()

2024-02-14 10:22:58,797	INFO worker.py:1715 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8266 [39m[22m


0,1
Python version:,3.10.13
Ray version:,2.9.2
Dashboard:,http://127.0.0.1:8266


In [3]:
# Inspect the resources (CPUs, GPUs, memory) Ray can schedule on.
ray.cluster_resources()

{'CPU': 28.0,
 'node:172.30.66.101': 1.0,
 'memory': 9315849831.0,
 'object_store_memory': 4657924915.0,
 'accelerator_type:G': 1.0,
 'GPU': 1.0,
 'node:__internal_head__': 1.0}

In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path

# High-resolution figures and seaborn's whitegrid theme for all plots.
plt.rcParams["figure.dpi"] = 300
sns.set_theme(style="whitegrid")

In [5]:
# Directory holding the CAFA5 parquet split and GO target array loaded below.
data_path = Path("../data/cafa5")

#### ☁️ Distributed Processing

In [6]:
import os
import json
import torch
import torch.nn as nn
import random
from transformers import AutoTokenizer, EsmModel

In [7]:
# ESM-2 checkpoint name shared by the tokenizer, the model cells and train_loop_config.
model_name = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [8]:
# Tokenize two toy protein sequences of different lengths: the shorter one is padded
# (token id 1) up to the longer one, and its padded positions get attention_mask == 0.
demo_seqs = ["MSLEQKKGADIISKILQIQNSIGK", "ISRKEQENARIQSKL"]
encoded_seq = tokenizer(demo_seqs, padding="longest", truncation=True, return_tensors="np")
print(encoded_seq)

{'input_ids': array([[ 0, 20,  8,  4,  9, 16, 15, 15,  6,  5, 13, 12, 12,  8, 15, 12,
         4, 16, 12, 16, 17,  8, 12,  6, 15,  2],
       [ 0, 12,  8, 10, 15,  9, 16,  9, 17,  5, 10, 12, 16,  8, 15,  4,
         2,  1,  1,  1,  1,  1,  1,  1,  1,  1]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0]])}


In [9]:
# Re-padding an already padded batch leaves the values unchanged and converts to torch
# tensors; collate_fn uses tokenizer.pad the same way.
tokenizer.pad(encoded_seq, return_tensors="pt")

{'input_ids': tensor([[ 0, 20,  8,  4,  9, 16, 15, 15,  6,  5, 13, 12, 12,  8, 15, 12,  4, 16,
         12, 16, 17,  8, 12,  6, 15,  2],
        [ 0, 12,  8, 10, 15,  9, 16,  9, 17,  5, 10, 12, 16,  8, 15,  4,  2,  1,
          1,  1,  1,  1,  1,  1,  1,  1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0]])}

In [10]:
# targets
go_targets = np.load(data_path / "train_bp_top500_targets.npy")
go_targets.shape

(92210, 500)

In [11]:
# Keep Ray Data output order deterministic across runs. `DataContext` is the current
# name; `DatasetContext` is only a deprecated alias for it.
ray.data.DataContext.get_current().execution_options.preserve_order = True

In [12]:
# Read the training split (only the columns preprocessing needs) and shuffle with a fixed
# seed; load_data, defined later, performs the same read + shuffle.
ds = ray.data.read_parquet(
    data_path / "train_split.parquet", columns=["Entry ID", "Sequence", "Index"]
)
ds = ds.random_shuffle(seed=0)
ds.take(1)

Parquet Files Sample 0:   0%|          | 0/1 [00:00<?, ?it/s]

2024-02-14 10:23:02,208	INFO dataset.py:2488 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2024-02-14 10:23:02,211	INFO set_read_parallelism.py:115 -- Using autodetected parallelism=89 for stage ReadParquet to satisfy output blocks of size at least DataContext.get_current().target_min_block_size=1.0MiB.
2024-02-14 10:23:02,212	INFO set_read_parallelism.py:122 -- To satisfy the requested parallelism of 89, each read task output is split into 89 smaller blocks.
2024-02-14 10:23:02,212	INFO streaming_executor.py:112 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> AllToAllOperator[RandomShuffle] -> LimitOperator[limit=1]
2024-02-14 10:23:02,213	INFO streaming_executor.py:113 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=0, gpu=0, object_store_memory=0), locality_with_output=False, prese

- RandomShuffle 1:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 2:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 3:   0%|          | 0/1 [00:00<?, ?it/s]

Running 0:   0%|          | 0/1 [00:00<?, ?it/s]

[36m(ReadParquet->SplitBlocks(89) pid=65195)[0m   return transform_pyarrow.concat(tables)


[{'Entry ID': 'Q17604',
  'Sequence': 'MVSHKKNDRPRPLWILKIHKRLSLFEFKRYATGIGKDDGQDISWVLKGNAKNNVYQVTVETMENCETDECKKVIWVPDELAESTGTMFEDFKEDQPQESVSSISNNEANWGSSVNELDENYEKMQKEETFDPYDSDSDTSEDSDFDEDFEDSDKTMCSGQS',
  'Index': 48142}]

In [13]:
def tokenize_seqs(batch):
    """Tokenize one pandas batch of protein sequences and attach its GO targets.

    Args:
        batch: pandas DataFrame with at least "Sequence" and "Index" columns.

    Returns:
        dict with "input_ids" and "attention_mask" (numpy arrays padded to the
        longest sequence in this batch, truncated to the tokenizer's default
        maximum length) and "targets" (rows of go_targets picked by "Index").
    """
    # NOTE(review): the tokenizer is re-created for every batch; consider caching it per worker.
    tok = AutoTokenizer.from_pretrained(model_name)
    encoded = tok(
        batch["Sequence"].tolist(),
        padding="longest",
        truncation=True,
        return_tensors="np",
    )
    row_ids = batch["Index"].tolist()
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "targets": go_targets[row_ids],
    }

In [14]:
def preprocess(df):
    """map_batches hook: tokenize a pandas batch and attach targets via tokenize_seqs."""
    encoded = tokenize_seqs(df)
    return encoded

In [15]:
# Hold out 25% of the (already shuffled) rows for validation.
test_size = 0.25
train_ds, valid_ds = ds.train_test_split(test_size=test_size)

2024-02-14 10:23:03,083	INFO set_read_parallelism.py:115 -- Using autodetected parallelism=89 for stage ReadParquet to satisfy output blocks of size at least DataContext.get_current().target_min_block_size=1.0MiB.
2024-02-14 10:23:03,084	INFO set_read_parallelism.py:122 -- To satisfy the requested parallelism of 89, each read task output is split into 89 smaller blocks.
2024-02-14 10:23:03,085	INFO streaming_executor.py:112 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> AllToAllOperator[RandomShuffle]
2024-02-14 10:23:03,086	INFO streaming_executor.py:113 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=0, gpu=0, object_store_memory=0), locality_with_output=False, preserve_order=True, actor_locality_enabled=True, verbose_progress=False)
2024-02-14 10:23:03,086	INFO streaming_executor.py:115 -- Tip: For detailed progress reporting, run `ray.data

- RandomShuffle 1:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 2:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 3:   0%|          | 0/1 [00:00<?, ?it/s]

Running 0:   0%|          | 0/1 [00:00<?, ?it/s]

In [16]:
# Peek at one preprocessed row to sanity-check tokenization and the target lookup.
sample_ds = train_ds.map_batches(preprocess, batch_format="pandas")
sample_ds.show(1)

2024-02-14 10:23:03,601	INFO streaming_executor.py:112 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(preprocess)] -> LimitOperator[limit=1]
2024-02-14 10:23:03,602	INFO streaming_executor.py:113 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=0, gpu=0, object_store_memory=0), locality_with_output=False, preserve_order=True, actor_locality_enabled=True, verbose_progress=False)
2024-02-14 10:23:03,602	INFO streaming_executor.py:115 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`


Running 0:   0%|          | 0/67 [00:00<?, ?it/s]

{'input_ids': array([ 0, 20,  7, ...,  1,  1,  1]), 'attention_mask': array([1, 1, 1, ..., 0, 0, 0]), 'targets': array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0

## 🏃 Training

In [17]:
import os
import json
import random

In [18]:
def set_seeds(seed=0):
    """Seed every RNG used here and make cuDNN deterministic, for reproducibility.

    Args:
        seed: integer seed applied to numpy, Python's `random` and torch (CPU + CUDA).
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Plain attribute assignment; routing these through eval() is unnecessary and
    # hides the writes from readers and linters.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # PYTHONHASHSEED is read at interpreter startup, so this only affects
    # subprocesses launched after this call, not the current process.
    os.environ["PYTHONHASHSEED"] = str(seed)

In [19]:
DATASET_LOC = data_path / "train_split.parquet"


def load_data(num_samples=None):
    """Read the training split and shuffle it deterministically (seed 0).

    Args:
        num_samples: if truthy, keep only the first ``num_samples`` shuffled rows
            (pulled to the driver with ``take`` and rebuilt via
            ``ray.data.from_items``); ``None`` or 0 keeps the full lazy dataset.

    Returns:
        A Ray Dataset with columns "Entry ID", "Sequence" and "Index".
    """
    shuffled = ray.data.read_parquet(
        DATASET_LOC, columns=["Entry ID", "Sequence", "Index"]
    ).random_shuffle(seed=0)
    if not num_samples:
        return shuffled
    return ray.data.from_items(shuffled.take(num_samples))

In [20]:
class CustomPreprocessor:
    """Minimal preprocessor that maps `preprocess` over a Ray Dataset."""

    def transform(self, ds):
        """Return a new dataset with `preprocess` applied to pandas-format batches.

        The mapping is lazy in Ray Data: it runs when the result is consumed or
        materialized.
        """
        mapped = ds.map_batches(preprocess, batch_format="pandas")
        return mapped

#### 🤖 Model

In [21]:
# Load the backbone here only to inspect its config (hidden size, pad token id).
llm = EsmModel.from_pretrained(model_name)
embedding_dim = llm.config.hidden_size
embedding_dim

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


480

In [22]:
# Padding token id; matches the 1s that pad input_ids in the tokenizer output.
pad_token_id = llm.config.pad_token_id
pad_token_id

1

In [23]:
class FinetunedESM(nn.Module):
    """ESM backbone + masked mean pooling + two-layer classification head.

    Head: Linear(embedding_dim, embedding_dim) -> ReLU -> Dropout ->
    Linear(embedding_dim, num_classes), producing one logit per class.

    Args:
        llm: backbone called as ``llm(input_ids=..., attention_mask=...)`` whose
            output exposes ``last_hidden_state`` (e.g. ``EsmModel``).
        dropout_p: dropout probability applied before the final classifier.
        embedding_dim: width of ``last_hidden_state`` (``llm.config.hidden_size``).
        num_classes: number of output logits per sequence.
    """

    def __init__(self, llm, dropout_p, embedding_dim, num_classes):
        super().__init__()
        self.llm = llm
        # Constructor args are kept so save() can serialize them to args.json.
        self.dropout_p = dropout_p
        self.embedding_dim = embedding_dim
        self.num_classes = num_classes
        self.dropout = nn.Dropout(dropout_p)
        self.pre_classifier = nn.Linear(embedding_dim, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average token embeddings over positions where attention_mask is non-zero.

        Args:
            token_embeddings: tensor of shape (batch, seq_len, dim).
            attention_mask: tensor of shape (batch, seq_len); 1 = real token, 0 = padding.

        Returns:
            Tensor of shape (batch, dim).

        Special tokens that carry mask 1 (token ids 0 and 2 at the sequence
        boundaries in this tokenizer's output) are included in the average.
        """
        # Broadcast the mask over the embedding dimension.
        expanded_mask = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.shape).float()
        )
        sum_embeddings = torch.sum(token_embeddings * expanded_mask, dim=1)
        # Clamp to avoid division by zero for an all-padding row.
        num_tokens = torch.clamp(expanded_mask.sum(1), min=1e-9)
        return sum_embeddings / num_tokens

    def forward(self, batch):
        """Return classification logits of shape (batch, num_classes).

        Args:
            batch: mapping with "input_ids" and "attention_mask" tensors.
        """
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]

        # Per-token representations from the last layer.
        token_embeddings = self.llm(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        pooled = self.mean_pooling(token_embeddings, attention_mask)

        # Head layout follows DistilBERT's sequence classifier:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/distilbert/modeling_distilbert.py
        hidden = self.pre_classifier(pooled)  # (bs, embedding_dim)
        hidden = torch.relu(hidden)  # functional ReLU: no module allocated per call
        hidden = self.dropout(hidden)

        return self.classifier(hidden)  # (bs, num_classes)

    @torch.inference_mode()
    def predict(self, batch):
        """Forward pass without autograd; returns raw logits as a numpy array.

        Switches the module to eval mode and leaves it there. The outputs are
        logits (training uses BCEWithLogitsLoss); apply a sigmoid for per-class
        probabilities.
        """
        self.eval()
        return self(batch).cpu().numpy()

    def save(self, dp):
        """Write constructor args to ``dp/args.json`` and weights to ``dp/model.pt``.

        ``dp`` is created if it does not exist. The backbone name is not saved;
        pass it to ``load`` separately.

        NOTE(review): when LoRA adapters have been injected (LinearWithLoRA, as
        train_loop_per_worker does), the state dict holds ``.linear.``/``.lora.``
        keys and the LoRA settings are not saved, so ``load`` (which builds a
        plain model and loads strictly) cannot restore it as-is.
        """
        dp = Path(dp)
        dp.mkdir(parents=True, exist_ok=True)
        contents = {
            "dropout_p": self.dropout_p,
            "embedding_dim": self.embedding_dim,
            "num_classes": self.num_classes,
        }
        with open(dp / "args.json", "w") as fp:
            json.dump(contents, fp, indent=4, sort_keys=False)

        torch.save(self.state_dict(), dp / "model.pt")

    @classmethod
    def load(cls, esm_model, args_fp, state_dict_fp):
        """Rebuild a model from files written by ``save`` (weights mapped to CPU).

        Args:
            esm_model: name/path of the backbone for ``EsmModel.from_pretrained``.
            args_fp: path to ``args.json``.
            state_dict_fp: path to ``model.pt``.
        """
        with open(args_fp, "r") as fp:
            kwargs = json.load(fp)

        llm = EsmModel.from_pretrained(esm_model)
        model = cls(llm=llm, **kwargs)
        model.load_state_dict(
            torch.load(state_dict_fp, map_location=torch.device("cpu"))
        )
        return model

In [24]:
# model = FinetunedESM(
#     llm=llm, dropout_p=0.1, embedding_dim=embedding_dim, num_classes=500
# )
# print(model.parameters)

In [25]:
def count_parameters(model):
    """Return the total number of scalar parameters with requires_grad=True."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [26]:
# count_parameters(model)

In [27]:
import math
from functools import partial

In [28]:
class LoRALayer(nn.Module):
    """Low-rank adapter branch computing ``alpha * (x @ W_a @ W_b)``.

    W_a (in_dim x rank) starts Gaussian scaled by 1/sqrt(rank); W_b (rank x out_dim)
    starts at zero, so the adapter outputs zeros until W_b is trained.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        rank: inner (bottleneck) dimension.
        alpha: scalar multiplier on the adapter output.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        down_init = torch.randn(in_dim, rank) / math.sqrt(rank)
        self.W_a = nn.Parameter(down_init)
        self.W_b = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        projected = x @ self.W_a @ self.W_b  # (..., in_dim) -> (..., rank) -> (..., out_dim)
        return self.alpha * projected

In [29]:
class LinearWithLoRA(nn.Module):
    """Wrap an ``nn.Linear`` with a parallel LoRA branch: ``linear(x) + lora(x)``.

    Args:
        linear: layer to wrap; must expose ``in_features``/``out_features`` and
            ``weight`` (a LinearWithLoRA does not, so do not wrap twice).
        rank: LoRA rank passed to LoRALayer.
        alpha: LoRA scaling factor passed to LoRALayer.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        # Build the adapter on the wrapped layer's device/dtype so wrapping a layer
        # that was already moved (e.g. to CUDA or fp16) does not cause a mismatch in
        # forward. For a CPU/fp32 layer this is a no-op.
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        ).to(device=linear.weight.device, dtype=linear.weight.dtype)

    def forward(self, x):
        return self.linear(x) + self.lora(x)

In [30]:
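# NOTE: `apply_lora` is not defined in this notebook; LoRA adapters are injected
# inside `train_loop_per_worker` instead.
# apply_lora(model)
# model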

In [31]:
# count_parameters(model)

In [32]:
# for name, param in model.named_parameters():
#     print(f"{name}: {param.requires_grad}")

In [33]:
from ray.train.torch import get_device

# Device that collate_fn moves tensors to when called from this (driver) process.
get_device()

device(type='cuda', index=0)

In [34]:
def collate_fn(batch):
    """Re-pad a numpy batch and move it to the training device as torch tensors.

    Args:
        batch: dict with "input_ids", "attention_mask" and "targets" arrays (as
            produced by `preprocess`); any other keys are passed through.

    Returns:
        A new dict (the input is not modified) with input_ids/attention_mask
        padded to the longest row by ``tokenizer.pad`` and float32 targets, all on
        ``get_device()``.
    """
    # Resolve the device once per batch rather than once per tensor.
    device = get_device()
    # Presumably needed because `preprocess` pads per map_batches batch, so rows from
    # different blocks can have different lengths — TODO confirm.
    padded = tokenizer.pad(
        {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]},
        return_tensors="pt",
    )

    collated = dict(batch)
    collated["input_ids"] = padded["input_ids"].to(device=device)
    collated["attention_mask"] = padded["attention_mask"].to(device=device)
    # BCEWithLogitsLoss expects float targets.
    collated["targets"] = torch.as_tensor(
        batch["targets"], dtype=torch.float, device=device
    )
    return collated

In [35]:
# Pull a 32-row preprocessed batch (numpy arrays) to exercise collate_fn.
sample_batch = sample_ds.take_batch(batch_size=32)
sample_batch

2024-02-14 10:23:06,563	INFO streaming_executor.py:112 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(preprocess)] -> LimitOperator[limit=32]
2024-02-14 10:23:06,564	INFO streaming_executor.py:113 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=0, gpu=0, object_store_memory=0), locality_with_output=False, preserve_order=True, actor_locality_enabled=True, verbose_progress=False)
2024-02-14 10:23:06,564	INFO streaming_executor.py:115 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`


Running 0:   0%|          | 0/67 [00:00<?, ?it/s]

[33m(raylet)[0m A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: a7264939eedf2d986d5a4317cb59ea58b3c0ae7301000000 Worker ID: 38fe109214120401e096136626a30dca50e2a8a4a0600919be19a5c3 Node ID: 78e8c183769d68b24ba063fdab01146d419dfa1149f0c4580506909f Worker IP address: 172.30.66.101 Worker port: 32927 Worker PID: 65188 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.


[36m(MapBatches(preprocess) pid=65188)[0m *** SIGSEGV received at time=1707934987 on cpu 7 ***
[36m(MapBatches(preprocess) pid=65188)[0m PC: @           0x4f5d43  (unknown)  frame_dealloc
[36m(MapBatches(preprocess) pid=65188)[0m     @     0x7f27810ac520  (unknown)  (unknown)
[36m(MapBatches(preprocess) pid=65188)[0m [2024-02-14 10:23:07,169 E 65188 65188] logging.cc:361: *** SIGSEGV received at time=1707934987 on cpu 7 ***
[36m(MapBatches(preprocess) pid=65188)[0m [2024-02-14 10:23:07,169 E 65188 65188] logging.cc:361: PC: @           0x4f5d43  (unknown)  frame_dealloc
[36m(MapBatches(preprocess) pid=65188)[0m [2024-02-14 10:23:07,169 E 65188 65188] logging.cc:361:     @     0x7f27810ac520  (unknown)  (unknown)
[36m(MapBatches(preprocess) pid=65188)[0m Fatal Python error: Segmentation fault
[36m(MapBatches(preprocess) pid=65188)[0m 
[36m(MapBatches(preprocess) pid=65188)[0m Stack (most recent call first):
[36m(MapBatches(preprocess) pid=65188)[0m   File "/home/ytia

{'input_ids': array([[ 0, 20,  7, ...,  1,  1,  1],
        [ 0, 20,  5, ...,  1,  1,  1],
        [ 0, 20, 16, ...,  9,  8,  2],
        ...,
        [ 0, 20,  6, ..., 11, 15,  2],
        [ 0, 20,  8, ...,  1,  1,  1],
        [ 0, 20,  5, ...,  1,  1,  1]]),
 'attention_mask': array([[1, 1, 1, ..., 0, 0, 0],
        [1, 1, 1, ..., 0, 0, 0],
        [1, 1, 1, ..., 1, 1, 1],
        ...,
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 0, 0, 0],
        [1, 1, 1, ..., 0, 0, 0]]),
 'targets': array([[1, 1, 1, ..., 0, 0, 0],
        [1, 0, 0, ..., 0, 0, 0],
        [1, 0, 1, ..., 0, 0, 0],
        ...,
        [1, 1, 1, ..., 0, 0, 0],
        [1, 0, 1, ..., 0, 0, 0],
        [1, 0, 0, ..., 0, 0, 0]])}

In [36]:
# (batch, padded_len) for the token ids.
sample_batch["input_ids"].shape

(32, 1024)

In [37]:
# Convert the numpy batch into device tensors (targets become float).
sample_batch = collate_fn(batch=sample_batch)
sample_batch

{'input_ids': tensor([[ 0, 20,  7,  ...,  1,  1,  1],
         [ 0, 20,  5,  ...,  1,  1,  1],
         [ 0, 20, 16,  ...,  9,  8,  2],
         ...,
         [ 0, 20,  6,  ..., 11, 15,  2],
         [ 0, 20,  8,  ...,  1,  1,  1],
         [ 0, 20,  5,  ...,  1,  1,  1]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0'),
 'targets': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 0., 1.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')}

In [38]:
# (batch, number of GO targets)
sample_batch["targets"].shape

torch.Size([32, 500])

In [39]:
import ray.train as train
from ray.train import Checkpoint, CheckpointConfig, DataConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchCheckpoint, TorchTrainer
import tempfile
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel

In [40]:
# model.to(get_device())
# sample_output = model(sample_batch)
# sample_output.shape

In [41]:
def train_step(ds, batch_size, model, num_classes, loss_fn, optimizer):
    """Run one training epoch over ``ds`` and return the mean batch loss.

    Args:
        ds: Ray dataset (shard) iterated with ``iter_torch_batches`` + ``collate_fn``.
        batch_size: rows per batch on this worker.
        model: module mapping a collated batch to logits.
        num_classes: logits per row; when 1, logits are reshaped to the targets'
            shape before computing the loss.
        loss_fn: callable(logits, targets) -> scalar loss tensor.
        optimizer: optimizer over the model's parameters.

    Returns:
        Mean of the per-batch losses (nan if ``ds`` yields no batches).
    """
    model.train()
    batch_losses = []
    for batch in ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn):
        optimizer.zero_grad()
        logits = model(batch)
        targets = batch["targets"]
        if num_classes == 1:
            # The model emits (bs, 1); targets may be (bs, 1) or (bs,). Match them
            # exactly — unsqueezing would give (bs, 1, 1), which BCE rejects.
            logits = logits.reshape(targets.shape)
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return np.mean(batch_losses)

In [42]:
def eval_step(ds, batch_size, model, num_classes, loss_fn):
    """Evaluate ``model`` on ``ds`` without gradients.

    Args:
        ds: Ray dataset (shard) iterated with ``iter_torch_batches`` + ``collate_fn``.
        batch_size: rows per batch on this worker.
        model: module mapping a collated batch to logits.
        num_classes: logits per row; when 1, logits are reshaped to the targets' shape.
        loss_fn: callable(logits, targets) -> scalar loss tensor.

    Returns:
        (mean batch loss, list of per-row target arrays, list of per-row logit arrays).
    """
    model.eval()
    batch_losses = []
    y_trues, y_preds = [], []
    with torch.no_grad():
        for batch in ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn):
            logits = model(batch)
            targets = batch["targets"]
            if num_classes == 1:
                # Match (bs, 1) logits to the targets' shape; unsqueezing would
                # give (bs, 1, 1).
                logits = logits.reshape(targets.shape)
            loss = loss_fn(logits, targets)
            batch_losses.append(loss.item())
            y_trues.extend(targets.cpu().numpy())
            y_preds.extend(logits.cpu().numpy())

    return np.mean(batch_losses), y_trues, y_preds

In [43]:
# LoRA settings read by train_loop_per_worker: adapter rank and scaling factor,
# plus which linear layers receive adapters.
lora_rank = 8
lora_alpha = 1
lora_query = True  # attention query projection
lora_key = False  # attention key projection
lora_value = True  # attention value projection
lora_projection = False  # attention output projection
lora_mlp = False  # feed-forward block
lora_head = True  # pre_classifier + classifier head

In [44]:
def _apply_lora(model, rank, alpha, query, key, value, projection, mlp, head):
    """Freeze all of `model`'s parameters, then wrap the selected linear layers with LoRA.

    Only the newly created LoRA parameters remain trainable. Modifies `model` in
    place and also returns it.
    """
    for param in model.parameters():
        param.requires_grad = False

    linear_with_lora = partial(LinearWithLoRA, rank=rank, alpha=alpha)

    # Iterate through each transformer layer.
    for layer in model.llm.encoder.layer:
        self_attn = layer.attention.self
        if query:
            self_attn.query = linear_with_lora(self_attn.query)
        if key:
            self_attn.key = linear_with_lora(self_attn.key)
        if value:
            self_attn.value = linear_with_lora(self_attn.value)
        if projection:
            layer.attention.output.dense = linear_with_lora(layer.attention.output.dense)
        if mlp:
            # Adapt both feed-forward projections. Wrapping `output.dense` twice
            # would crash: a LinearWithLoRA has no in_features/out_features.
            layer.intermediate.dense = linear_with_lora(layer.intermediate.dense)
            layer.output.dense = linear_with_lora(layer.output.dense)

    if head:
        model.pre_classifier = linear_with_lora(model.pre_classifier)
        model.classifier = linear_with_lora(model.classifier)

    return model


def train_loop_per_worker(config):
    """Ray Train per-worker loop: build the (optionally LoRA-adapted) model, train, checkpoint.

    Args:
        config: train_loop_config dict.
            Required keys: esm_model, dropout_p, lr, lr_factor, lr_patience,
            num_epochs, batch_size (global; divided across workers), num_classes,
            do_lora.
            Optional keys lora_rank, lora_alpha, lora_query, lora_key, lora_value,
            lora_projection, lora_mlp and lora_head override the notebook-level
            LoRA settings of the same names.

    After every epoch, reports epoch, lr, train_loss and val_loss together with a
    checkpoint via ``train.report``.
    """
    # Hyperparameters
    esm_model = config["esm_model"]
    dropout_p = config["dropout_p"]
    lr = config["lr"]
    lr_factor = config["lr_factor"]
    lr_patience = config["lr_patience"]
    num_epochs = config["num_epochs"]
    batch_size = config["batch_size"]
    num_classes = config["num_classes"]
    do_lora = config["do_lora"]

    set_seeds()
    train_ds = train.get_dataset_shard("train")
    val_ds = train.get_dataset_shard("val")

    llm = EsmModel.from_pretrained(esm_model)
    model = FinetunedESM(
        llm=llm,
        dropout_p=dropout_p,
        embedding_dim=llm.config.hidden_size,
        num_classes=num_classes,
    )

    if do_lora:
        _apply_lora(
            model,
            rank=config.get("lora_rank", lora_rank),
            alpha=config.get("lora_alpha", lora_alpha),
            query=config.get("lora_query", lora_query),
            key=config.get("lora_key", lora_key),
            value=config.get("lora_value", lora_value),
            projection=config.get("lora_projection", lora_projection),
            mlp=config.get("lora_mlp", lora_mlp),
            head=config.get("lora_head", lora_head),
        )

    print(f"# Trainable Parameters: {count_parameters(model)}")

    model = train.torch.prepare_model(model)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer, mode="min", factor=lr_factor, patience=lr_patience
    )

    num_workers = train.get_context().get_world_size()
    # Guard against batch_size < num_workers, which would yield a batch size of 0.
    batch_size_per_worker = max(1, batch_size // num_workers)

    for epoch in range(num_epochs):
        train_loss = train_step(
            train_ds, batch_size_per_worker, model, num_classes, loss_fn, optimizer
        )
        val_loss, _, _ = eval_step(
            val_ds, batch_size_per_worker, model, num_classes, loss_fn
        )
        scheduler.step(val_loss)

        # Checkpoint: write args + weights to a temp dir and hand it to Ray Train.
        with tempfile.TemporaryDirectory(prefix="ray_results") as dp:
            # prepare_model may wrap the model in DDP; save the underlying module.
            if isinstance(model, DistributedDataParallel):
                model.module.save(dp=dp)
            else:
                model.save(dp=dp)
            metrics = dict(
                epoch=epoch,
                lr=optimizer.param_groups[0]["lr"],
                train_loss=train_loss,
                val_loss=val_loss,
            )
            checkpoint = Checkpoint.from_directory(dp)
            train.report(metrics, checkpoint=checkpoint)

In [45]:
# Train loop config
train_loop_config = {
    "esm_model": model_name,
    "dropout_p": 0.1,
    "lr": 1e-3,
    "lr_factor": 0.8,
    "lr_patience": 3,
    "num_epochs": 1,
    "batch_size": 8,
    "num_classes": 500,
    "do_lora": True,
}

In [46]:
# Scaling config
num_workers = 1
resources_per_worker = {"CPU": 8, "GPU": 1}

scaling_config = ScalingConfig(
    num_workers=num_workers,
    use_gpu=bool(resources_per_worker["GPU"]),
    resources_per_worker=resources_per_worker,
)

In [47]:
# Run config
checkpoint_config = CheckpointConfig(
    num_to_keep=1, checkpoint_score_attribute="val_loss", checkpoint_score_order="min"
)
run_config = RunConfig(
    name="llm",
    checkpoint_config=checkpoint_config,
    storage_path=str(Path().resolve() / "ray_results"),
)

In [48]:
# Reload the shuffled training split and hold out 25% for validation.
ds = load_data()
test_size = 0.25
train_ds, valid_ds = ds.train_test_split(test_size=test_size)

Parquet Files Sample 0:   0%|          | 0/1 [00:00<?, ?it/s]

2024-02-14 10:23:08,361	INFO set_read_parallelism.py:115 -- Using autodetected parallelism=89 for stage ReadParquet to satisfy output blocks of size at least DataContext.get_current().target_min_block_size=1.0MiB.
2024-02-14 10:23:08,362	INFO set_read_parallelism.py:122 -- To satisfy the requested parallelism of 89, each read task output is split into 89 smaller blocks.
2024-02-14 10:23:08,362	INFO streaming_executor.py:112 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> AllToAllOperator[RandomShuffle]
2024-02-14 10:23:08,363	INFO streaming_executor.py:113 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=0, gpu=0, object_store_memory=0), locality_with_output=False, preserve_order=True, actor_locality_enabled=True, verbose_progress=False)
2024-02-14 10:23:08,363	INFO streaming_executor.py:115 -- Tip: For detailed progress reporting, run `ray.data

- RandomShuffle 1:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 2:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 3:   0%|          | 0/1 [00:00<?, ?it/s]

Running 0:   0%|          | 0/1 [00:00<?, ?it/s]

[36m(ReadParquet->SplitBlocks(89) pid=65192)[0m   return transform_pyarrow.concat(tables)


In [49]:
# Tokenize both splits and materialize them once, so training iterates cached
# blocks instead of re-running `preprocess` inside the trainer.
preprocessor = CustomPreprocessor()
train_ds = preprocessor.transform(train_ds).materialize()
# The trainer consumes `val_ds`, so the materialized validation split must be
# bound to that name.
val_ds = preprocessor.transform(valid_ds).materialize()
valid_ds = val_ds  # keep the `valid_ds` name pointing at the same materialized split

2024-02-14 10:23:08,828	INFO streaming_executor.py:112 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(preprocess)]
2024-02-14 10:23:08,828	INFO streaming_executor.py:113 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=0, gpu=0, object_store_memory=0), locality_with_output=False, preserve_order=True, actor_locality_enabled=True, verbose_progress=False)
2024-02-14 10:23:08,828	INFO streaming_executor.py:115 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`


Running 0:   0%|          | 0/67 [00:00<?, ?it/s]

2024-02-14 10:23:11,864	INFO streaming_executor.py:112 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(preprocess)]
2024-02-14 10:23:11,864	INFO streaming_executor.py:113 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=0, gpu=0, object_store_memory=0), locality_with_output=False, preserve_order=True, actor_locality_enabled=True, verbose_progress=False)
2024-02-14 10:23:11,865	INFO streaming_executor.py:115 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`


Running 0:   0%|          | 0/23 [00:00<?, ?it/s]

In [50]:
# Shard only the train set across workers (val is not split); keep row order deterministic.
options = ray.data.ExecutionOptions(preserve_order=True)
dataset_config = DataConfig(datasets_to_split=["train"], execution_options=options)

In [51]:
# Trainer
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets={"train": train_ds, "val": val_ds},
    dataset_config=dataset_config,
)

In [52]:
%%time
# Train (blocks until all epochs finish; %%time reports the wall time)
results = trainer.fit()

0,1
Current time:,2024-02-14 11:06:58
Running for:,00:43:45.64
Memory:,27.5/31.2 GiB

Trial name,status,loc,iter,total time (s),epoch,lr,train_loss
TorchTrainer_1d1de_00000,TERMINATED,172.30.66.101:70036,1,2621.45,0,0.001,0.166822


[36m(raylet)[0m Spilled 3497 MiB, 292 objects, write throughput 1512 MiB/s. Set RAY_verbose_spill_logs=0 to disable this message.
[36m(TorchTrainer pid=70036)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=70036)[0m - (ip=172.30.66.101, pid=70336) world_rank=0, local_rank=0, node_rank=0
[36m(RayTrainWorker pid=70336)[0m Setting up process group for: env:// [rank=0, world_size=1]


[36m(RayTrainWorker pid=70336)[0m # Trainable Parameters: 199840


[36m(RayTrainWorker pid=70336)[0m Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
[36m(RayTrainWorker pid=70336)[0m You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[36m(RayTrainWorker pid=70336)[0m Moving model to device: cuda:0
[36m(SplitCoordinator pid=70410)[0m Executing DAG InputDataBuffer[Input] -> OutputSplitter[split(1, equal=True)]
[36m(SplitCoordinator pid=70410)[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=9.0, gpu=1.0, object_store_memory=0.0), locality_with_output=False, preserve_order=True, actor_locality_enabled=True, verbose_progress=False)
[36m(SplitCoordinator pid=70410)[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current(

(pid=70410) Running 0:   0%|          | 0/67 [00:00<?, ?it/s]

[33m(raylet)[0m [2024-02-14 10:23:58,713 E 65082 65082] (raylet) node_manager.cc:3022: 7 Workers (tasks / actors) killed due to memory pressure (OOM), 0 Workers crashed due to other reasons at node (ID: 78e8c183769d68b24ba063fdab01146d419dfa1149f0c4580506909f, IP: 172.30.66.101) over the last time period. To see more information about the Workers killed on this node, use `ray logs raylet.out -ip 172.30.66.101`
[33m(raylet)[0m 
[33m(raylet)[0m Refer to the documentation on how to address the out of memory issue: https://docs.ray.io/en/latest/ray-core/scheduling/ray-oom-prevention.html. Consider provisioning more memory on this node or reducing task parallelism by requesting more CPUs per task. To adjust the kill threshold, set the environment variable `RAY_memory_usage_threshold` when starting Ray. To disable worker killing, set the environment variable `RAY_memory_monitor_refresh_ms` to zero.
[36m(RayTrainWorker pid=70336)[0m Executing DAG InputDataBuffer[Input] -> TaskPoolMapO

(pid=70336) Running 0:   0%|          | 0/23 [00:00<?, ?it/s]

[36m(RayTrainWorker pid=70336)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ytian/github/esm-lora/notebooks/ray_results/llm/TorchTrainer_1d1de_00000_0_2024-02-14_10-23-13/checkpoint_000000)
2024-02-14 11:06:58,717	INFO tune.py:1042 -- Total run time: 2625.65 seconds (2625.60 seconds for the tuning loop).


CPU times: user 11.9 s, sys: 2.52 s, total: 14.4 s
Wall time: 43min 45s


In [53]:
# One row per train.report call (i.e. per epoch).
results.metrics_dataframe

Unnamed: 0,epoch,lr,train_loss,val_loss,timestamp,checkpoint_dir_name,should_checkpoint,done,training_iteration,trial_id,...,iterations_since_restore,config/train_loop_config/esm_model,config/train_loop_config/dropout_p,config/train_loop_config/lr,config/train_loop_config/lr_factor,config/train_loop_config/lr_patience,config/train_loop_config/num_epochs,config/train_loop_config/batch_size,config/train_loop_config/num_classes,config/train_loop_config/do_lora
0,0,0.001,0.166822,0.162471,1707937617,checkpoint_000000,True,False,1,1d1de_00000,...,1,facebook/esm2_t12_35M_UR50D,0.1,0.001,0.8,3,1,8,500,True


In [54]:
# Retained checkpoints with their metrics (num_to_keep=1, ranked by val_loss).
results.best_checkpoints

[(Checkpoint(filesystem=local, path=/home/ytian/github/esm-lora/notebooks/ray_results/llm/TorchTrainer_1d1de_00000_0_2024-02-14_10-23-13/checkpoint_000000),
  {'epoch': 0,
   'lr': 0.001,
   'train_loss': 0.16682193633395168,
   'val_loss': 0.1624705195846965,
   'timestamp': 1707937617,
   'checkpoint_dir_name': 'checkpoint_000000',
   'should_checkpoint': True,
   'done': False,
   'training_iteration': 1,
   'trial_id': '1d1de_00000',
   'date': '2024-02-14_11-06-57',
   'time_this_iter_s': 2621.449065208435,
   'time_total_s': 2621.449065208435,
   'pid': 70036,
   'hostname': 'Witcher',
   'node_ip': '172.30.66.101',
   'config': {'train_loop_config': {'esm_model': 'facebook/esm2_t12_35M_UR50D',
     'dropout_p': 0.1,
     'lr': 0.001,
     'lr_factor': 0.8,
     'lr_patience': 3,
     'num_epochs': 1,
     'batch_size': 8,
     'num_classes': 500,
     'do_lora': True}},
   'time_since_restore': 2621.449065208435,
   'iterations_since_restore': 1})]