In [1]:
# Load (or reload) IPython's autoreload extension and switch it to mode 2, so that
# edits to imported project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import gc
import os
import random

import accelerate
import pandas as pd
import torch
import transformers
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model
from transformers import T5Tokenizer, Trainer, TrainingArguments

from src._shared import (
    load_config,
    setup_environment,
)
from src.model.configuration_md_pssm import MDPSSMConfig
from src.model.modeling_md_pssm import T5EncoderModelForPssmGeneration
from src.model.utils.data_collator import DataCollatorForT5Pssm


In [None]:
# Restrict CUDA to the device with index 2, with devices numbered in PCI bus order.
# NOTE(review): the device index is hard-coded for one particular machine.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "2",
    }
)

# Load the run configuration and derive the run-level settings from it.
train_config = load_config()
(
    model_name_identifier,
    device,
    report_to,
    run,
    USE_WANDB,
    SEED,
) = setup_environment(train_config)

# Every seeding call receives its own offset from SEED. Keep the call order
# unchanged: later calls may overwrite RNG state set by the earlier ones.
accelerate.utils.set_seed(SEED + 1)
transformers.set_seed(SEED + 2)
torch.manual_seed(SEED + 3)
random.seed(SEED + 4)

In [4]:
# Instantiate the model configuration with no arguments, i.e. MDPSSMConfig's own defaults.
config = MDPSSMConfig()
# Construct the model from that config object. No `from_pretrained` call is made in
# this cell, so any weight loading would have to happen inside the class itself.
model = T5EncoderModelForPssmGeneration(config=config)

In [None]:
# Module-name lists handed to LoraConfig below and echoed at the end of the cell.
target_modules = ["q", "v"]
modules_to_save = ["pssm_head"]

# All tunable LoRA hyper-parameters come from the "lora" section of the run config.
lora_settings = train_config["lora"]

lora_config = LoraConfig(
    inference_mode=False,
    r=lora_settings["r"],
    lora_alpha=lora_settings["lora_alpha"],
    lora_dropout=lora_settings["lora_dropout"],
    target_modules=target_modules,
    bias="none",
    modules_to_save=modules_to_save,
    use_rslora=lora_settings["use_rslora"],
    use_dora=lora_settings["use_dora"],
)

# NOTE: `model` is rebound to its PEFT-wrapped version, so re-running this cell
# without re-running the model-construction cell passes the wrapped model in again.
model = get_peft_model(model, lora_config)

# Report which modules were targeted and how many parameters are trainable.
print("target_modules:", target_modules)
print("modules_to_save:", modules_to_save)
model.print_trainable_parameters()


In [None]:
# Load the pre-built PSSM dataset from disk.
dataset = load_from_disk("../tmp/data/pssm/pssm_dataset_0_only")
# dataset = dataset.select(range(22, 30))
print(dataset)

# Preview the first few samples. The count is clamped to the dataset length so the
# preview cannot raise IndexError on a small dataset (e.g. a narrowed `select`).
PREVIEW_ROWS = 6
preview_count = min(PREVIEW_ROWS, len(dataset))

features = None
for x in range(preview_count):
    # Index the row once and build its tensor once, instead of once per field.
    sample = dataset[x]
    features = torch.tensor(sample["pssm_features"])
    print(sample["name"])
    print(sample["sequence"])
    print(features.shape)

# Show the full feature tensor of the last previewed sample only.
if features is not None:
    display(features)

# Rename the PSSM feature column to "labels" and drop the name/sequence columns
# listed below; the remaining columns are what the Trainer's data collator receives.
dataset = dataset.rename_column("pssm_features", "labels")
dataset = dataset.remove_columns(["name", "sequence", "sequence_processed"])

print(dataset)

In [None]:
# Tokenizer options, kept together so they are easy to compare and tweak.
# NOTE(review): `use_fast` is passed to the concrete `T5Tokenizer` class here;
# confirm it has the intended effect on this class.
tokenizer_options = {
    "do_lower_case": False,
    "use_fast": True,
    "legacy": False,
}
tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_uniref50",
    **tokenizer_options,
)

# Collator settings: padding on, max length 512, padded length a multiple of 8.
collator_options = {
    "padding": True,
    "max_length": 512,
    "pad_to_multiple_of": 8,
}
data_collator = DataCollatorForT5Pssm(tokenizer=tokenizer, **collator_options)


def compute_metrics(eval_preds, compute_result=True):
    """Placeholder metric hook for the Trainer: reports no custom metrics.

    Args:
        eval_preds: Evaluation predictions passed in by the Trainer; unused.
        compute_result: Accepted so the hook can also be called with the
            ``compute_result=...`` keyword that the Trainer supplies when
            ``batch_eval_metrics`` is enabled; unused.

    Returns:
        An empty dict, so only the Trainer's built-in values are logged.
    """
    return {}


# All tunable trainer hyper-parameters come from the "trainer" section of the run config.
trainer_settings = train_config["trainer"]

training_args = TrainingArguments(
    output_dir=f"../tmp/models/checkpoints/{model_name_identifier}",
    run_name=model_name_identifier if USE_WANDB else None,
    # Use the explicit string "none" to switch reporting off: passing None does not
    # reliably disable logging integrations (older transformers treat it as "all").
    report_to="wandb" if USE_WANDB else "none",
    learning_rate=trainer_settings["learning_rate"],
    per_device_train_batch_size=trainer_settings["train_batch_size"],
    num_train_epochs=trainer_settings["num_epochs"],
    eval_strategy=trainer_settings["eval_strategy"],
    eval_steps=trainer_settings["eval_steps"],
    per_device_eval_batch_size=trainer_settings["eval_batch_size"],
    eval_on_start=trainer_settings["eval_on_start"],
    batch_eval_metrics=trainer_settings["batch_eval_metrics"],
    save_strategy=trainer_settings["save_strategy"],
    save_steps=trainer_settings["save_steps"],
    save_total_limit=trainer_settings["save_total_limit"],
    remove_unused_columns=trainer_settings["remove_unused_columns"],
    # NOTE(review): the dataset cell renames `pssm_features` to `labels`, yet the
    # label names listed here are the input keys; confirm this is intentional.
    label_names=["input_ids", "attention_mask"],
    logging_strategy="steps",
    logging_steps=trainer_settings["logging_steps"],
    seed=train_config["seed"],
    lr_scheduler_type=trainer_settings["lr_scheduler_type"],
    warmup_steps=trainer_settings["warmup_steps"],
)

# NOTE: evaluation runs on the same dataset object that is used for training, so
# the evaluation loss is not a held-out estimate.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
def _release_cached_memory():
    """Run the garbage collector, then empty the CUDA and MPS caches where available."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


# Use the default tensor print profile; swap in profile="full" when debugging:
# torch.set_printoptions(profile="full")
torch.set_printoptions(profile="default")

# Free cached memory before the run starts.
_release_cached_memory()

trainer.train()
trainer.evaluate()

# Free cached memory again once training and evaluation have finished.
_release_cached_memory()

# Last expression of the cell: the Trainer's log history rendered as a DataFrame.
pd.DataFrame(trainer.state.log_history)