In [2]:
# Reload edited modules (e.g. the project's src.model.*) automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

import os
# Pin this kernel to the GPU at PCI bus index 1. These variables must be set before torch
# initialises CUDA (torch is imported further down this cell); inside this process that GPU is
# then addressed as "cuda:0".
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= "1"

from src.model.model import ProtT5CLIP
from src.model.data_collator import DataCollatorForProtT5CLIP

from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict, load_from_disk

import torch
import re
import pandas as pd
import numpy as np
import gc
from datetime import datetime

from transformers import (
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    AutoConfig,
    CLIPConfig,
    PretrainedConfig
    
)

from peft import (
    LoraConfig,
    PeftModel,
    get_peft_model,
)

# "cuda:0" is the first *visible* GPU, i.e. the one selected via CUDA_VISIBLE_DEVICES above.
device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Toggle Weights & Biases logging for this run.
USE_WANDB = False

if USE_WANDB:
    import wandb
    run = wandb.init(project="protT5-CLIP", name=f"protT5-CLIP-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}")

# Passed to TrainingArguments(report_to=...). Reporting is disabled by the string "none", NOT by
# None: in transformers 4.x, report_to=None falls back to "all" installed integrations, which would
# still log to wandb/tensorboard/... whenever they happen to be installed.
report_to = "wandb" if USE_WANDB else "none"


In [3]:
# Hub identifiers of the two backbones; reused for the sub-configs and stored on model_config.
PLM_NAME = "Rostlab/prot_t5_xl_uniref50"
LLM_NAME = "microsoft/Phi-3.5-mini-instruct"

plm_config = AutoConfig.from_pretrained(PLM_NAME)
llm_config = AutoConfig.from_pretrained(LLM_NAME, trust_remote_code=True)

# Generic config container for ProtT5CLIP: every keyword becomes an attribute
# (e.g. model_config.frozen_plm / frozen_llm are read by the LoRA cell).
model_config = PretrainedConfig(
    # Backbones
    name_or_path_plm=PLM_NAME,
    name_or_path_llm=LLM_NAME,
    plm_config=plm_config,
    llm_config=llm_config,
    # Forward-pass outputs
    output_hidden_states=True,
    output_attentions=True,
    return_dict=True,
    # When False, the LoRA cell attaches adapters to that backbone.
    frozen_plm=False,
    frozen_llm=False,
    # Contrastive head
    projection_dim=1024,
    logit_scale_init_value=2.6592,  # ln(1 / 0.07): the CLIP default initial temperature
)

In [None]:
model = ProtT5CLIP(model_config)
# Cast and move in a single call, so parameters are converted to bfloat16 while being transferred,
# instead of first copying the full fp32 model onto the device and casting it there (much higher
# peak device memory for a multi-billion-parameter model).
# The trailing `;` suppresses the very long module repr as cell output.
model.to(device=device, dtype=torch.bfloat16);

In [4]:
# print("\nModel device:", next(model.parameters()).device)
# print("\nModel PLM device:", next(model.model_plm.parameters()).device)
# print("\nModel LLM device:", next(model.model_llm.parameters()).device)

# print("\nProtein Model (T5) Parameter dtypes:")
# for name, param in model.model_plm.named_parameters():
#     print(f"{name}: {param.dtype}")

# print("\nText Model (Phi) Parameter dtypes:")
# for name, param in model.model_llm.named_parameters():
#     print(f"{name}: {param.dtype}")

# print("\nProjection Layer Parameter dtypes:")
# for name, param in model.protein_projection.named_parameters():
#     print(f"protein_projection.{name}: {param.dtype}")
# for name, param in model.text_projection.named_parameters():
#     print(f"text_projection.{name}: {param.dtype}")

# print(f"\nLogit Scale dtype: {model.logit_scale.dtype}")


In [4]:
# LoRA adapters go on the attention/MLP projections of every non-frozen backbone; the projection
# heads (and any weights reported missing at load time) are trained in full via modules_to_save.
target_modules = []
modules_to_save = ['protein_projection', 'text_projection']
if not model_config.frozen_plm:
    # T5 attention projections are named q / k / v / o.
    target_modules += ["q", "k", "v", "o"]
    # NOTE(review): assumes loading_info_plm["missing_keys"] holds names PEFT can match as modules — confirm.
    modules_to_save += model.loading_info_plm["missing_keys"]
if not model_config.frozen_llm:
    # Phi-3 / Phi-3.5 fuse attention into `qkv_proj` and the MLP into `gate_up_proj`; without these
    # names only o_proj/down_proj would receive adapters. The unfused Llama-style names are kept
    # so other LLM backbones still match (PEFT ignores names that match nothing in the model).
    target_modules += [
        "qkv_proj", "gate_up_proj",
        "k_proj", "q_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj",
    ]
    modules_to_save += model.loading_info_llm["missing_keys"]

lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=target_modules,
    bias="none",
    modules_to_save=modules_to_save,
    # use_rslora=True,
    # use_dora=True,
)

# Keep the cell idempotent: re-running it must not wrap an existing PeftModel a second time.
if isinstance(model, PeftModel):
    print("model is already a PeftModel; skipping get_peft_model (re-create the model to re-apply LoRA).")
else:
    model = get_peft_model(model, lora_config)
print("target_modules:", target_modules)
print("modules_to_save:", modules_to_save)
model.print_trainable_parameters()

NameError: name 'model' is not defined

In [5]:

# T5Tokenizer is the SentencePiece-based (slow) tokenizer class. `use_fast` is only meaningful for
# AutoTokenizer, so it is not passed here (it had no effect).
tokenizer_plm = T5Tokenizer.from_pretrained(
    pretrained_model_name_or_path=model_config.name_or_path_plm,
    do_lower_case=False,
    legacy=False,
)

tokenizer_llm = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_config.name_or_path_llm,
)

In [6]:
# dataset = [
#     {
#         "uid": "A001",
#         "sequence": "MLEVPVWIPILAFAVGLGLGLLIPHLQKPFQRFPHLQKPFQRF",
#         "text": "This protein is involved in membrane transport.",
#     },
#     {
#         "uid": "A002",
#         "sequence": "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTTSPSTLKT",
#         "text": "This enzyme catalyzes the hydrolysis of ATP.",
#     },
#     {
#         "uid": "A003",
#         "sequence": "MKMKQQGLVADLLPNIRVMKTFGHFVFNYYNDN",
#         "text": "This transcription factor regulates gene expression.",
#     },
# ] * 1000

# dataset = Dataset.from_list(dataset)
# dataset = dataset.add_column("sequence_original", dataset["sequence"])
# dataset = dataset.map(lambda x: {"sequence": " ".join(list(re.sub(r"[UZOB]", "X", x["sequence"])))})

# tknz_plm = tokenizer_plm(text=dataset["sequence"], padding=False, truncation=False)
# tknz_llm = tokenizer_llm(text=dataset["text"], padding=False, truncation=False)

# dataset = dataset.add_column(
#     "input_ids", [{"sequence": seq, "text": txt} for seq, txt in zip(tknz_plm["input_ids"], tknz_llm["input_ids"])]
# )
# dataset = dataset.add_column(
#     "attention_mask", [{"sequence": seq, "text": txt} for seq, txt in zip(tknz_plm["attention_mask"], tknz_llm["attention_mask"])]
# )

# dataset = dataset.remove_columns(["uid", "sequence", "text", "sequence_original"])
# dataset = DatasetDict({"train": dataset, "test": dataset})

# print(dataset)
# print(dataset["train"][0])

In [6]:
# Set to True to rebuild the tokenized dataset even if a processed copy already exists on disk.
overwrite = False

processed_dataset_path = "../tmp/data/processed_train_val_GO"
# processed_dataset_path = "tmp/data/processed_train_val_GO_FULL"


def _space_residues(example):
    """Replace U/Z/O/B residues with X and insert a space between every residue."""
    cleaned = re.sub(r"[UZOB]", "X", example["sequence"])
    return {"sequence": " ".join(cleaned)}


def _tokenize_split(split_ds):
    """Filter, clean and tokenize one split.

    Keeps sequences shorter than 256 residues, drops metadata columns, and adds `input_ids` /
    `attention_mask` columns whose rows are {"sequence": <ProtT5 tokens>, "text": <LLM tokens>}.
    """
    split_ds = split_ds.filter(lambda x: len(x["sequence"]) < 256)
    split_ds = split_ds.map(_space_residues)
    split_ds = split_ds.remove_columns(["identifier", "term", "aspect", "GO Name", "species", "__index_level_0__"])

    plm_tokens = tokenizer_plm(text=split_ds["sequence"], padding=False, truncation=False)
    llm_tokens = tokenizer_llm(text=split_ds["GO Sentence"], padding=False, truncation=False)

    for column in ("input_ids", "attention_mask"):
        nested_rows = [
            {"sequence": seq, "text": txt}
            for seq, txt in zip(plm_tokens[column], llm_tokens[column])
        ]
        split_ds = split_ds.add_column(column, nested_rows)
    return split_ds


if overwrite or not os.path.exists(processed_dataset_path):
    print("Processing dataset...")
    raw_dataset = load_from_disk("../tmp/data/train_val_GO")
    # For quick experiments, subsample here, e.g. raw_dataset['train'].select(range(10000)).
    dataset = DatasetDict({
        "train": _tokenize_split(raw_dataset["train"]),
        "valid": _tokenize_split(raw_dataset["test"]),
    })
    dataset = dataset.remove_columns(["sequence", "GO Sentence"])

    print("Saving processed dataset to disk...")
    dataset.save_to_disk(processed_dataset_path)
else:
    print("Loading processed dataset from disk...")
    dataset = load_from_disk(processed_dataset_path)

print(dataset)
print(dataset["train"][0])

Processing dataset...


Filter:   0%|          | 0/4299428 [00:00<?, ? examples/s]

Map:   0%|          | 0/896917 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1064435 [00:00<?, ? examples/s]

Map:   0%|          | 0/221346 [00:00<?, ? examples/s]

Saving processed dataset to disk...


Saving the dataset (0/6 shards):   0%|          | 0/896917 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/221346 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 896917
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 221346
    })
})
{'input_ids': {'sequence': [19, 20, 7, 11, 12, 6, 18, 3, 22, 4, 4, 3, 4, 3, 6, 15, 6, 3, 4, 20, 5, 11, 13, 4, 3, 3, 4, 3, 9, 17, 5, 9, 5, 11, 11, 16, 13, 10, 18, 10, 17, 7, 11, 10, 18, 18, 17, 18, 9, 10, 15, 14, 22, 11, 22, 13, 3, 13, 20, 4, 17, 17, 11, 17, 5, 11, 6, 19, 14, 13, 12, 5, 22, 18, 18, 11, 22, 17, 6, 11, 8, 22, 11, 3, 13, 10, 11, 18, 13, 22, 18, 17, 4, 11, 9, 20, 16, 3, 14, 17, 4, 11, 11, 7, 13, 11, 11, 4, 22, 3, 6, 5, 17, 22, 10, 20, 5, 12, 22, 6, 13, 17, 5, 11, 14, 9, 4, 22, 15, 14, 3, 13, 17, 4, 9, 9, 1], 'text': [450, 4768, 5996, 1889, 338, 8178, 1072, 2785, 310, 5094, 17082, 457, 6354, 29889]}, 'attention_mask': {'sequence': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [None]:

# Collator built from both tokenizers; pads each batch (padding=True) to a multiple of 8.
data_collator = DataCollatorForProtT5CLIP(
    tokenizer_plm=tokenizer_plm,
    tokenizer_llm=tokenizer_llm,
    pad_to_multiple_of=8,
    padding=True,
)

# `run` only exists when USE_WANDB is True (it is created in the setup cell).
if USE_WANDB:
    run_name = run.name
else:
    run_name = None

training_args = TrainingArguments(
    output_dir="../tmp/models/checkpoints/",
    run_name=run_name,
    report_to=report_to,
    seed=69420,
    # Optimisation schedule
    learning_rate=1e-3,
    per_device_train_batch_size=26,
    # per_device_eval_batch_size=16,
    num_train_epochs=1,
    # Logging / evaluation / checkpointing
    logging_steps=1,
    # do_train=False,
    # do_eval=False,
    # eval_steps=300,
    # save_strategy="steps",
    # save_steps=300,
    # Keep every dataset column; otherwise the Trainer drops columns not named in model.forward.
    remove_unused_columns=False,
    # label_names=["labels"],
)

def compute_metrics(eval_preds):
    """Placeholder metrics hook for ``Trainer(compute_metrics=...)``.

    Returns fixed dummy values; nothing is computed from ``eval_preds`` yet.
    TODO: replace with real retrieval/classification metrics before enabling it on the Trainer.

    There is deliberately no ``"loss"`` key: the Trainer renames un-prefixed metric keys to
    ``eval_<key>``, so a ``"loss"`` entry would overwrite the real ``eval_loss`` with the dummy value.

    Args:
        eval_preds: the evaluation predictions object the Trainer passes in (currently unused).

    Returns:
        dict mapping metric name to a placeholder float.
    """
    return {
        "accuracy": 0.5,
        "precision": 0.5,
        "recall": 0.5,
        "f1": 0.5,
    }

# Important to limit the training set size for now: only the first 1000 examples are used.
train_subset = dataset["train"].select(range(1000))

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_subset,
    # eval_dataset=dataset['valid'],
    # compute_metrics=compute_metrics,
)

In [None]:
# Release unreferenced Python objects and cached accelerator memory before training.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

# Unlimited pandas/torch display for any later printing in this kernel.
for option_name in ("display.max_columns", "display.max_rows"):
    pd.set_option(option_name, None)
torch.set_printoptions(profile="full")

trainer.train()

---
# Model saving

In [None]:
# Timestamped output directory, e.g. ../tmp/models/protT5-CLIP-2024-01-31-12-00-00.
# NOTE(review): after the LoRA cell `model` is a PeftModel, so this presumably writes the adapter
# and modules_to_save weights rather than the full backbones — confirm.
save_dir = "../tmp/models/protT5-CLIP-{}".format(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
model.save_pretrained(save_dir)

---
## Analysis

In [None]:
# pd.DataFrame(trainer.state.log_history)

In [None]:
# import matplotlib.pyplot as plt

# log_df = pd.DataFrame(trainer.state.log_history)

# fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# ax1_twin = ax1.twinx()
# log_df.plot(y='loss', ax=ax1, color='blue', label='Training Loss')
# log_df.plot(y='grad_norm', ax=ax1_twin, color='red', label='Gradient Norm')
# ax1.set_xlabel('Step')
# ax1.set_ylabel('Loss', color='blue')
# ax1_twin.set_ylabel('Gradient Norm', color='red')
# ax1.set_title('Training Loss and Gradient Norm over Time')
# ax1.grid(True)

# lines1, labels1 = ax1.get_legend_handles_labels()
# lines2, labels2 = ax1_twin.get_legend_handles_labels()
# ax1_twin.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

# log_df.plot(y='learning_rate', ax=ax2, color='green', label='Learning Rate') 
# ax2.set_xlabel('Step')
# ax2.set_ylabel('Learning Rate')
# ax2.set_title('Learning Rate Schedule')
# ax2.grid(True)
# ax2.legend()

# plt.tight_layout()
# plt.show()


In [None]:
# import os

# os.makedirs("../tmp/models", exist_ok=True)

# log_df.to_csv("../tmp/models/training_logs.csv", index=True)
# print("Training logs saved to ../tmp/models/training_logs.csv")
