In [1]:
%reload_ext autoreload
%autoreload 2

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# os.environ["HF_DATASETS_OFFLINE"] = "1"
# os.environ["HF_HUB_OFFLINE"] = "1"

from src.model.model import ProtT5CLIP
from src.model.data_collator import DataCollatorForProtT5CLIP
from src.model.subset_trainer import ProteinSampleSubsetTrainer

from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict, load_from_disk

import torch
import re
import pandas as pd
import numpy as np
import gc
from datetime import datetime

from transformers import (
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    AutoConfig,
    CLIPConfig,
    PretrainedConfig,
)

from peft import (
    LoraConfig,
    get_peft_model,
)

device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

model_name_identifier = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

USE_WANDB = False

if USE_WANDB:
    import wandb

    run = wandb.init(project="protT5-CLIP", name=f"protT5-CLIP-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}")

report_to = "wandb" if USE_WANDB else None

print("Using device:", device)

Using device: cuda:0


In [2]:
plm_config = AutoConfig.from_pretrained("Rostlab/prot_t5_xl_uniref50")
llm_config = AutoConfig.from_pretrained("microsoft/Phi-3.5-mini-instruct", trust_remote_code=True)

model_config = PretrainedConfig(
    name_or_path_plm="Rostlab/prot_t5_xl_uniref50",
    name_or_path_llm="microsoft/Phi-3.5-mini-instruct",
    plm_config=plm_config,
    llm_config=llm_config,
    output_hidden_states=True,
    output_attentions=True,
    return_dict=True,
    frozen_plm=False,
    frozen_llm=False,
    projection_dim=1024,
    logit_scale_init_value=2.6592,
)

In [3]:
model = ProtT5CLIP(model_config)
model.to(device)
model.to(torch.bfloat16)
print("Loaded model...")

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded model...


In [4]:
model.base_model.model_llm.model.layers[0].self_attn.o_proj

Linear(in_features=3072, out_features=3072, bias=False)

In [5]:
target_modules = []
modules_to_save = ["protein_projection", "text_projection"]
if not model_config.frozen_plm:
    target_modules += ["q", "k", "v", "o"]
    modules_to_save += model.loading_info_plm["missing_keys"]
if not model_config.frozen_llm:
    target_modules += ["k_proj", "q_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
    modules_to_save += model.loading_info_llm["missing_keys"]

lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=target_modules,
    bias="none",
    modules_to_save=modules_to_save,
    # use_rslora=True,
    # use_dora=True,
)

model = get_peft_model(model, lora_config)
print("target_modules:", target_modules)
print("modules_to_save:", modules_to_save)
model.print_trainable_parameters()

target_modules: ['q', 'k', 'v', 'o', 'k_proj', 'q_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj', 'up_proj']
modules_to_save: ['protein_projection', 'text_projection']
trainable params: 12,582,912 || all params: 5,045,998,593 || trainable%: 0.2494


In [6]:
tokenizer_plm = T5Tokenizer.from_pretrained(
    pretrained_model_name_or_path=model_config.name_or_path_plm,
    do_lower_case=False,
    use_fast=True,
    legacy=False,
)

tokenizer_llm = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_config.name_or_path_llm,
)

In [7]:
dataset_path = "../tmp/data/train_val_GO_skimmed"
dataset_path_processed = "../tmp/data/train_val_GO_skimmed_processed"

if not os.path.exists(dataset_path_processed):
    print("Processing dataset...")
    dataset = load_from_disk(dataset_path)

    for split in ["train"]:  # dataset:
        tknz_plm = tokenizer_plm(text=dataset[split]["sequence"], padding=False, truncation=False)
        tknz_llm = tokenizer_llm(text=dataset[split]["GO Sentence"], padding=False, truncation=False)

        dataset[split] = dataset[split].add_column("input_ids_sequence", tknz_plm["input_ids"])
        dataset[split] = dataset[split].add_column("attention_mask_sequence", tknz_plm["attention_mask"])
        dataset[split] = dataset[split].add_column("input_ids_text", tknz_llm["input_ids"])
        dataset[split] = dataset[split].add_column("attention_mask_text", tknz_llm["attention_mask"])
    
    dataset.save_to_disk(dataset_path_processed)
else:
    print("Loading dataset from disk...")
    dataset = load_from_disk(dataset_path_processed)


print(dataset)
print(dataset["train"][0])

Loading dataset from disk...
DatasetDict({
    train: Dataset({
        features: ['identifier', 'term', 'aspect', 'GO Name', 'GO Sentence', 'sequence', 'species', '__index_level_0__', 'input_ids_sequence', 'attention_mask_sequence', 'input_ids_text', 'attention_mask_text'],
        num_rows: 143552
    })
    test: Dataset({
        features: ['identifier', 'term', 'aspect', 'GO Name', 'GO Sentence', 'sequence', 'species', '__index_level_0__'],
        num_rows: 1064435
    })
})
{'identifier': 'A0A021WW32', 'term': 'GO:0010628', 'aspect': 'BPO', 'GO Name': 'positive regulation of gene expression', 'GO Sentence': 'The biological process is positive regulation of gene expression.', 'sequence': 'MFYEHIILAKKGPLARIWLAAHWDKKITKAHVFETNIEKSVEGILQPKVKLALRTSGHLLLGVVRIYSRKAKYLLADCNEAFVKIKMAFRPGMVDLPEGHREANVNAITLPEVFHDFDTALPELNDIDIEAQFSINQSRADEITMREDYGSLSLSLQDDGFGDIGFEAETPEIIRCSIPSNINDKIFDNDVLENIESLDPHSLDAHADMPGSRLDGDGFGDSFGQPALFEDDLFGDPSQPVEQITKESTTVLNADDSDEDAIDNIHNVPSPATSLVNSIEDEKEENNLNGHASVSE

In [8]:
data_collator = DataCollatorForProtT5CLIP(
    tokenizer_plm=tokenizer_plm, tokenizer_llm=tokenizer_llm, padding=True, pad_to_multiple_of=8
)

training_args = TrainingArguments(
    output_dir=f"../tmp/models/checkpoints/{model_name_identifier}",
    # run_name=run.name if USE_WANDB else None,
    # report_to=report_to,
    learning_rate=1e-3,
    per_device_train_batch_size=2,#26,
    # per_device_eval_batch_size=16,
    num_train_epochs=1,
    logging_steps=1,
    # do_train=False,
    # do_eval=False,
    # eval_steps=300,
    # save_strategy="steps",
    # save_steps=300,
    remove_unused_columns=False,
    # label_names=["labels"],
    seed=69420,
)


def compute_metrics(eval_preds):
    return {
        "loss": 1.0,
        "accuracy": 0.5,
        "precision": 0.5,
        "recall": 0.5,
        "f1": 0.5,
    }


trainer = ProteinSampleSubsetTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    # eval_dataset=dataset['valid'],
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
)

Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [9]:
gc.collect()

pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)
torch.set_printoptions(profile="full")

if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
2024/12/19 22:49:30 INFO server is running addr=127.0.0.1:35039
2024/12/19 22:49:30 INFO Will exit if parent process dies. ppid=1104365
2024/12/19 22:49:30 INFO connection: ManageConnectionData: new connection created id=127.0.0.1:60756
[34m[1mwandb[0m: Currently logged in as: [33mfinnlueth[0m ([33mfinnlueth-organization[0m). Use [1m`wandb login --relogin`[0m to force relogin
2024/12/19 22:49:30 INFO handleInformInit: received streamId=50k776xh id=127.0.0.1:60756
2024/12/19 22:49:31 INFO handleInformInit: stream started streamId=50k776xh id=127.0.0.1:60756


----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])


You are not running the flash-attention implementation, expect numerical differences.


Step,Training Loss
1,0.6875
2,0.7188
3,0.8164
4,0.6914
5,0.6992
6,0.6953
7,0.6953
8,0.6953
9,0.6953
10,0.6953


----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 32])
torch.Size([2, 8])
torch.Size([2, 32])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 48])
torch.Size([2, 8])
torch.Size([2, 48])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 32])
torch.Size([2, 8])
torch.Size([2, 32])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----



----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 32])
torch.Size([2, 8])
torch.Size([2, 32])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 24])
torch.Size([2, 8])
torch.Size([2, 24])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 8])
torch.Size([2, 8])
torch.Size([2, 8])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 24])
torch.Size([2, 8])
torch.Size([2, 24])
----
torch.Size([2, 8])
torch.Size([2, 8])
torch.Size([2, 8])
torch.Size([2, 8])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
torch.Size([2, 8])
torch.Size([2, 16])
torch.Size([2, 8])
torch.Size([2, 16])
----
tor

KeyboardInterrupt: 

---
# Model saving

In [11]:
model_save_path = f"../tmp/models/protT5-CLIP-{model_name_identifier}"
model.save_pretrained(model_save_path)



In [15]:
ProtT5CLIP.from_pretrained(model_save_path)

AttributeError: 'NoneType' object has no attribute 'from_pretrained'

In [14]:
reloaded_model = ProtT5CLIP.from_pretrained(model_save_path)

AttributeError: 'NoneType' object has no attribute 'from_pretrained'

In [12]:
reloaded_model = ProtT5CLIP.from_pretrained(model_save_path)

# Move both models to CPU for comparison
model.to("cpu")
reloaded_model.to("cpu")

# Compare model parameters
models_equal = True
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
    if not torch.allclose(p1, p2):
        print("Found mismatch in parameters")
        models_equal = False
        break

print(f"\nModels are identical: {models_equal}")

# Compare model state dicts
original_state = model.state_dict()
reloaded_state = reloaded_model.state_dict()

states_equal = True
for key in original_state:
    if key not in reloaded_state:
        print(f"Key {key} missing from reloaded model")
        states_equal = False
    elif not torch.equal(original_state[key], reloaded_state[key]):
        print(f"Mismatch in values for key {key}")
        states_equal = False

print(f"State dictionaries are identical: {states_equal}")


AttributeError: 'NoneType' object has no attribute 'from_pretrained'

---
## Analysis

In [10]:
# pd.DataFrame(trainer.state.log_history)

In [11]:
# import matplotlib.pyplot as plt

# log_df = pd.DataFrame(trainer.state.log_history)

# fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# ax1_twin = ax1.twinx()
# log_df.plot(y='loss', ax=ax1, color='blue', label='Training Loss')
# log_df.plot(y='grad_norm', ax=ax1_twin, color='red', label='Gradient Norm')
# ax1.set_xlabel('Step')
# ax1.set_ylabel('Loss', color='blue')
# ax1_twin.set_ylabel('Gradient Norm', color='red')
# ax1.set_title('Training Loss and Gradient Norm over Time')
# ax1.grid(True)

# lines1, labels1 = ax1.get_legend_handles_labels()
# lines2, labels2 = ax1_twin.get_legend_handles_labels()
# ax1_twin.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

# log_df.plot(y='learning_rate', ax=ax2, color='green', label='Learning Rate')
# ax2.set_xlabel('Step')
# ax2.set_ylabel('Learning Rate')
# ax2.set_title('Learning Rate Schedule')
# ax2.grid(True)
# ax2.legend()

# plt.tight_layout()
# plt.show()


In [12]:
# import os

# os.makedirs("../tmp/models", exist_ok=True)

# log_df.to_csv("../tmp/models/training_logs.csv", index=True)
# print("Training logs saved to ../tmp/models/training_logs.csv")
