In [1]:
%reload_ext autoreload
%autoreload 2

from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict
import torch
from src.model.basic_model import ProtT5CLIP
from src.model.data_collator import DataCollatorForProtT5CLIP
import re
import pandas as pd
import numpy as np
import os
import gc

from transformers import (
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)

from peft import (
    LoraConfig,
    get_peft_model,
)

In [2]:
model_cfg = {
    "base_model_plm": "Rostlab/prot_t5_xl_uniref50",
    "freeze_plm": False,
    "base_model_llm": "microsoft/Phi-3.5-mini-instruct",
    "freeze_llm": False,
}

model = ProtT5CLIP(model_cfg)

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
target_modules = []
modules_to_save = []
if not model_cfg["freeze_plm"]:
    target_modules += ["q", "k", "v", "o"]
    modules_to_save = model.loading_info_plm["missing_keys"]
if not model_cfg["freeze_llm"]:
    target_modules += ["k_proj", "q_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
    modules_to_save += model.loading_info_llm["missing_keys"]

lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=target_modules,
    bias="none",
    modules_to_save=modules_to_save,
    # use_rslora=True,
    # use_dora=True,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 8,388,608 || all params: 5,037,609,984 || trainable%: 0.1665


In [4]:
tokenizer_plm = T5Tokenizer.from_pretrained(
    pretrained_model_name_or_path=model_cfg["base_model_plm"],
    do_lower_case=False,
    use_fast=True,
    legacy=False,
)

tokenizer_llm = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_cfg["base_model_llm"],
)

In [5]:
dataset = [
    {
        "uid": "A001",
        "sequence": "MLEVPVWIPILAFAVGLGLGLLIPHLQKPFQRFPHLQKPFQRF",
        "text": "This protein is involved in membrane transport.",
    },
    {
        "uid": "A002",
        "sequence": "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTTSPSTLKT",
        "text": "This enzyme catalyzes the hydrolysis of ATP.",
    },
    {
        "uid": "A003",
        "sequence": "MKMKQQGLVADLLPNIRVMKTFGHFVFNYYNDN",
        "text": "This transcription factor regulates gene expression.",
    },
] * 1000

dataset = Dataset.from_list(dataset)
dataset = dataset.add_column("sequence_original", dataset["sequence"])
dataset = dataset.map(lambda x: {"sequence": " ".join(list(re.sub(r"[UZOB]", "X", x["sequence"])))})

tknz_plm = tokenizer_plm(text=dataset["sequence"], padding=False, truncation=False)
tknz_llm = tokenizer_llm(text=dataset["text"], padding=False, truncation=False)

dataset = dataset.add_column(
    "input_ids", [{"sequence": seq, "text": txt} for seq, txt in zip(tknz_plm["input_ids"], tknz_llm["input_ids"])]
)
dataset = dataset.add_column(
    "attention_mask", [{"sequence": seq, "text": txt} for seq, txt in zip(tknz_plm["attention_mask"], tknz_llm["attention_mask"])]
)

dataset = dataset.remove_columns(["uid", "sequence", "text", "sequence_original"])
dataset = DatasetDict({"train": dataset, "test": dataset})

print(dataset)
print(dataset["train"][0])

Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 3000
    })
})
{'input_ids': {'sequence': [19, 4, 9, 6, 13, 6, 21, 12, 13, 12, 4, 3, 15, 3, 6, 5, 4, 5, 4, 5, 4, 4, 12, 13, 20, 4, 16, 14, 13, 15, 16, 8, 15, 13, 20, 4, 16, 14, 13, 15, 16, 8, 15, 1], 'text': [910, 26823, 338, 9701, 297, 3813, 10800, 8608, 29889]}, 'attention_mask': {'sequence': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'text': [1, 1, 1, 1, 1, 1, 1, 1, 1]}}


In [6]:
# collated_data = data_collator(dataset['train'].select(range(20)))
# collated_data#['input_ids']

In [9]:
data_collator = DataCollatorForProtT5CLIP(
    tokenizer_plm=tokenizer_plm,
    tokenizer_llm=tokenizer_llm,
    padding=True,
    pad_to_multiple_of=8,
)

training_args = TrainingArguments(
    output_dir="../tmp/models/",
    learning_rate=1e-3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    logging_steps=1,
    do_train=False,
    do_eval=False,
    evaluation_strategy="steps",  # use eval_strategy
    eval_steps=300,
    save_strategy="steps",
    save_steps=300,
    remove_unused_columns=True,
    # label_names=["labels"],
    seed=69420,
)

def compute_metrics(eval_preds):
    return {
        "loss": 1.0,
        "accuracy": 0.5,
        "precision": 0.5,
        "recall": 0.5,
        "f1": 0.5,
    }

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"].select(range(3)),
    # eval_dataset=dataset['valid'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

gc.collect()

pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)

if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

trainer.train()

  0%|          | 0/2 [00:00<?, ?it/s]

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[ 0.0088, -0.2823,  0.0425,  ...,  0.2538, -0.1532,  0.0862],
         [ 0.2733,  0.0373,  0.0339,  ...,  0.2583,  0.3603,  0.2235],
         [ 0.1986, -0.0644,  0.1578,  ...,  0.0536,  0.1903,  0.1227],
         ...,
         [-0.2545, -0.1984,  0.0010,  ..., -0.1478, -0.1356,  0.0838],
         [-0.0566, -0.0406, -0.1214,  ..., -0.1220, -0.1797,  0.0865],
         [-0.1056, -0.0000, -0.1906,  ..., -0.0821, -0.0557,  0.1157]],

        [[ 0.0918, -0.0767, -0.0157,  ...,  0.2731,  0.1496,  0.0275],
         [ 0.3758, -0.0374, -0.2219,  ...,  0.3062,  0.1348,  0.0006],
         [ 0.3602, -0.2302, -0.1732,  ...,  0.3243,  0.1327, -0.0235],
         ...,
         [ 0.1157, -0.0701, -0.1848,  ..., -0.2225, -0.1143,  0.1544],
         [-0.1004, -0.0853, -0.0898,  ..., -0.1680, -0.0000,  0.1170],
         [-0.0917, -0.0228, -0.1487,  ...,  0.0295, -0.1818, -0.0235]]],
       device='mps:0', grad_fn=<MulBackward0>), past_key

TypeError: 'NoneType' object is not subscriptable