In [1]:
%load_ext autoreload
%autoreload 2

import gc
import os
from datetime import datetime

# Number GPUs by PCI bus id (so indices match `nvidia-smi`) and expose only GPU 3
# to this process. Both must be set before torch initialises CUDA, which is why
# they sit above the torch import. NOTE(review): hard-coded GPU index; adjust per machine.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import pandas as pd
import torch
import yaml
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model
from plms import PLMConfig, ProstT5, ProstT5Tokenizer
from torch import nn
from torch.nn import (
    KLDivLoss,
)
from transformers import (
    PreTrainedModel,
    TrainingArguments,
)

from src.model.configuration_md_pssm import MDPSSMConfig
from src.model.metrics import compute_metrics
from src.model.modeling_outputs import PSSMOutput
from src.model.trainer_protein_subset import ProteinSampleSubsetTrainer
from src.model.utils.data_collator import DataCollatorForT5Pssm

# Training configuration (LoRA, trainer, dataset, ...) shared by every cell below.
with open("../configs/model.yaml", "r") as config_file:
    train_config = yaml.safe_load(config_file)

# Disable all pandas truncation so inspected frames are shown in full.
for display_option in ("display.max_rows", "display.max_columns", "display.width", "display.max_colwidth"):
    pd.set_option(display_option, None)

In [2]:
class PSSMHead(nn.Module):
    """Convolutional head that maps per-residue T5 embeddings to a PSSM.

    Based on https://github.com/hefeda/PGP/blob/master/prott5_batch_predictor.py#L144 :
    Conv1d -> ReLU -> Dropout -> Conv1d, followed by a softmax over the output
    channels so that every residue receives a probability distribution.

    The default arguments reproduce the original fixed architecture
    (1024 -> 32 -> 20 channels, kernel size 7, dropout 0.25), so state-dict
    keys and shapes are unchanged.
    """

    def __init__(
        self,
        in_channels: int = 1024,
        hidden_channels: int = 32,
        out_channels: int = 20,
        kernel_size: int = 7,
        dropout: float = 0.25,
    ):
        """
        Args:
            in_channels: Size of the incoming per-residue embeddings.
            hidden_channels: Channels of the intermediate convolution.
            out_channels: Number of PSSM columns produced per residue.
            kernel_size: Convolution window. Must be odd so that symmetric
                padding keeps the sequence length unchanged.
            dropout: Dropout probability applied after the first convolution.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer.
        """
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        # (k - 1) / 2 padding on both sides keeps seq_len identical (padding=3 for k=7).
        padding = kernel_size // 2
        self.classifier = nn.Sequential(
            nn.Conv1d(in_channels, hidden_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_channels, out_channels, kernel_size=kernel_size, padding=padding),
        )

    def forward(self, x):
        """Predict a per-residue probability profile.

        Args:
            x: Embeddings of shape ``[batch_size, seq_len, in_channels]``.

        Returns:
            Tensor of shape ``[batch_size, seq_len, out_channels]`` whose last
            dimension sums to 1.
        """
        # Conv1d works channels-first: [batch, channels, seq_len].
        x = self.classifier(x.transpose(1, 2))
        x = x.transpose(1, 2)
        return torch.softmax(x, dim=2)


class T5EncoderModelForPssmGeneration(PreTrainedModel):
    """ProstT5 encoder with a convolutional PSSM head on top.

    The encoder yields per-residue hidden states, ``PSSMHead`` turns them into a
    20-column probability profile and, when ``labels`` are given, a KL-divergence
    loss is computed over the non-padded residues only.
    """

    def __init__(self, config: MDPSSMConfig):
        """
        Args:
            config (MDPSSMConfig): Must provide ``model_name`` (the checkpoint
                passed to ``PLMConfig``). If it also has ``device``, that device
                is used for the encoder and the head; otherwise the encoder is
                configured with ``"auto"``.
        """
        super().__init__(config=config)
        device_map = config.device if hasattr(config, "device") else "auto"
        plm_config = PLMConfig(
            name_or_path=config.model_name,
            device=device_map,
        )

        self.protein_encoder = ProstT5(config=plm_config)

        # "auto" is a device-map strategy, not a torch device: `nn.Module.to("auto")`
        # raises. Resolve it to a concrete device (cuda -> mps -> cpu) for the head.
        # NOTE(review): assumes ProstT5 places the encoder on that same device under
        # "auto" -- confirm, otherwise hidden states and head weights may disagree.
        if device_map == "auto":
            head_device = (
                "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
            )
        else:
            head_device = device_map
        self.pssm_head = PSSMHead().to(head_device)
        # KLDivLoss takes log-probabilities as input and probabilities as target;
        # "batchmean" divides by input.size(0), i.e. the number of residues kept
        # after masking in forward().
        self.loss_fct = KLDivLoss(reduction="batchmean")

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode a batch of sequences and predict a per-residue PSSM.

        Args:
            input_ids: Token ids forwarded to the ProstT5 encoder.
            attention_mask: Input padding mask forwarded to the encoder. Afterwards
                it is replaced by ``encoder_outputs["mask"]``, the mask aligned with
                the encoder's hidden states.
            inputs_embeds: Accepted for Hugging Face API compatibility; unused.
            labels: Target PSSMs. Assumed ``[batch_size, label_len, 20]`` with padded
                residues filled with ``-100`` -- TODO confirm against the data collator.
            output_attentions: Accepted for API compatibility; unused.
            output_hidden_states: When truthy (and ``return_dict`` is truthy), the
                encoder's last hidden state is included in the ``PSSMOutput``.
            return_dict: If falsy (including the default ``None``) a tuple is
                returned, otherwise a ``PSSMOutput``.

        Returns:
            ``(loss, pssm, encoder_outputs[2:-1])`` when labels are given and
            ``return_dict`` is falsy, ``(pssm, encoder_outputs[2:-1])`` without
            labels, or a ``PSSMOutput`` when ``return_dict`` is truthy.

        Raises:
            ValueError: If the number of unmasked predicted residues differs from
                the number of non-padded target residues.
        """
        encoder_outputs = self.protein_encoder.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=return_dict,
        )

        # [batch_size, seq_len, hidden_dim]
        hidden_states = encoder_outputs["last_hidden_state"]
        # Mask aligned with hidden_states; it may differ from the input mask if the
        # encoder drops tokens (presumably the ProstT5 prefix token -- confirm).
        attention_mask = encoder_outputs["mask"]

        # [batch_size, seq_len, 20]
        pssm = self.pssm_head(hidden_states)

        loss = None
        if labels is not None:
            # [batch_size * seq_len, 20]
            pred = pssm.flatten(end_dim=1)
            target = labels.flatten(end_dim=1)

            # Keep real residues only. Both boolean selections walk the flattened rows
            # in row-major order, so pred and target stay aligned as long as every
            # protein contributes the same number of residues to both.
            pred = pred[attention_mask.flatten(end_dim=1).bool()]
            target = target[~torch.any(target == -100, dim=1)]

            # Without this check a mismatch either fails deep inside kl_div or, when
            # one side has a single row, is silently broadcast into a wrong loss.
            if pred.shape[0] != target.shape[0]:
                raise ValueError(
                    f"PSSM loss masking mismatch: {pred.shape[0]} predicted residues vs "
                    f"{target.shape[0]} target residues"
                )

            # softmax can underflow to exactly 0 and log(0) = -inf would make the loss
            # inf/NaN; clamp to the smallest positive normal value of the dtype first.
            log_pred = torch.log(pred.clamp_min(torch.finfo(pred.dtype).tiny))
            loss = self.loss_fct(log_pred, target)

        if not return_dict:
            # NOTE(review): the sliced encoder outputs are nested as ONE element here,
            # whereas the usual Hugging Face pattern concatenates them
            # (`(pssm,) + outputs[2:]`). Kept as-is because the Trainer's
            # prediction_step / compute_metrics may rely on this layout -- confirm.
            output = (pssm, encoder_outputs[2:-1])
            return ((loss,) + output) if loss is not None else output

        return PSSMOutput(
            loss=loss,
            pssms=pssm,
            hidden_states=hidden_states if output_hidden_states else None,
            masks=attention_mask,
        )


In [3]:
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

config = MDPSSMConfig(model_name="Rostlab/ProstT5", device=device)
model = T5EncoderModelForPssmGeneration(config=config)

In [None]:
def check_cuda_params(model, device_type: str = "cuda"):
    """Report which parameters of ``model`` are not on the expected device type.

    Args:
        model: Any object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        device_type: Expected ``torch.device.type`` such as ``"cuda"``, ``"mps"``
            or ``"cpu"``. The default ``"cuda"`` reproduces the original
            ``param.is_cuda`` check and its messages.

    Returns:
        list[str]: Names of parameters that live on a different device type
        (empty when everything is placed correctly).
    """
    misplaced = [name for name, param in model.named_parameters() if param.device.type != device_type]

    label = device_type.upper()
    if not misplaced:
        print(f"All parameters are on {label}")
    else:
        print(f"The following parameters are not on {label}:")
        for name in misplaced:
            print(f"- {name}")
    return misplaced


# Sanity check: every parameter (encoder + PSSM head) should live on the GPU.
check_cuda_params(model=model);


In [None]:
lora_settings = train_config["lora"]
# Presumably the query/value projections of the T5 attention layers -- the
# module names LoRA adapters are injected into.
target_modules = ["q", "v"]
# The PSSM head is freshly initialised, so it is trained and saved in full.
modules_to_save = ["pssm_head"]

lora_config = LoraConfig(
    r=lora_settings["r"],
    lora_alpha=lora_settings["lora_alpha"],
    lora_dropout=lora_settings["lora_dropout"],
    use_rslora=lora_settings["use_rslora"],
    use_dora=lora_settings["use_dora"],
    target_modules=target_modules,
    modules_to_save=modules_to_save,
    bias="none",
    inference_mode=False,
)
model = get_peft_model(model, lora_config)

print("target_modules:", target_modules)
print("modules_to_save:", modules_to_save)
model.print_trainable_parameters()

In [None]:
# Unique run identifier: <project>-<timestamp>[-<custom run name>].
project_name = train_config["project_name"]
# `custom_run_name` may be empty/null (or absent) in the YAML; treat that as "no
# suffix" instead of crashing on `None.replace(...)`.
custom_run_name = (train_config.get("custom_run_name") or "").replace(" ", "-")
run_timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
model_name_identifier = f"{project_name}-{run_timestamp}" + (f"-{custom_run_name}" if custom_run_name else "")

dataset = load_from_disk(train_config["dataset"]["path"])
# The model's forward() reads its PSSM targets from the `labels` argument.
dataset = dataset.rename_column("pssm_features", "labels")
dataset = dataset.remove_columns(["name", "sequence"])

tokenizer = ProstT5Tokenizer().get_tokenizer()

data_collator = DataCollatorForT5Pssm(
    tokenizer=tokenizer,
    padding=True,
    pad_to_multiple_of=8,
)

trainer_cfg = train_config["trainer"]
training_args = TrainingArguments(
    output_dir=f"../tmp/models/checkpoints/{model_name_identifier}",
    # TODO: re-enable W&B reporting (run_name=model_name_identifier, report_to="wandb")
    # behind a USE_WANDB flag defined in the config cell.
    learning_rate=trainer_cfg["learning_rate"],
    per_device_train_batch_size=trainer_cfg["train_batch_size"],
    num_train_epochs=trainer_cfg["num_epochs"],
    eval_strategy=trainer_cfg["eval_strategy"],
    eval_steps=trainer_cfg["eval_steps"],
    per_device_eval_batch_size=trainer_cfg["eval_batch_size"],
    eval_on_start=trainer_cfg["eval_on_start"],
    batch_eval_metrics=trainer_cfg["batch_eval_metrics"],
    save_strategy=trainer_cfg["save_strategy"],
    save_steps=trainer_cfg["save_steps"],
    save_total_limit=trainer_cfg["save_total_limit"],
    remove_unused_columns=trainer_cfg["remove_unused_columns"],
    # NOTE(review): label_names are the input keys the Trainer treats as labels during
    # evaluation and hands to compute_metrics; the PSSM targets still reach the model
    # under `labels`. Presumably compute_metrics needs the ids/mask -- confirm.
    label_names=["input_ids", "attention_mask"],
    logging_strategy="steps",
    logging_steps=trainer_cfg["logging_steps"],
    seed=train_config["seed"],
    lr_scheduler_type=trainer_cfg["lr_scheduler_type"],
    warmup_steps=trainer_cfg["warmup_steps"],
)

# NOTE(review): the same dataset is used for training and evaluation, so eval metrics
# measure fit on training data rather than generalisation -- unless
# ProteinSampleSubsetTrainer holds out a subset internally (confirm).
trainer = ProteinSampleSubsetTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
def free_accelerator_memory():
    """Run the garbage collector and release cached CUDA / MPS allocator memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


free_accelerator_memory()
try:
    trainer.train()
    eval_metrics = trainer.evaluate()
finally:
    # Release cached memory even if training/evaluation raises (e.g. OOM), so the
    # kernel stays usable without a restart.
    free_accelerator_memory()

eval_metrics


In [None]:
import pandas as pd
from src.plots.train_plots import plot_training_history
import matplotlib.pyplot as plt

model_save_path = f"../tmp/models/adapters/{model_name_identifier}"

# Saves the adapters and the fully trained `modules_to_save` (the PSSM head).
model.save_pretrained(save_directory=model_save_path)

# Build the log frame once; it feeds both the CSV and the plot.
log_history = pd.DataFrame(trainer.state.log_history)
log_history.to_csv(f"{model_save_path}/training_log.csv", index=False)

# Record where the adapters live so the saved config can reload them. Update the
# config *before* opening the file, so a failure here cannot leave an empty
# train_config.yaml behind; tolerate a config without a `model` section.
train_config.setdefault("model", {})["reload_from_checkpoint_path"] = model_save_path
with open(f"{model_save_path}/train_config.yaml", "w") as f:
    yaml.dump(train_config, f, sort_keys=False)

fig = plot_training_history(log_history=log_history, train_config=train_config, metrics_names=["loss"])
fig.savefig(f"{model_save_path}/training_history.png")
plt.close(fig)

print("Model, config, and log saved to:", model_save_path)


In [None]:
# Sanity check for the loss masking in T5EncoderModelForPssmGeneration.forward:
# flattening [batch, seq_len, 20] -> [batch * seq_len, 20] and indexing with the
# flattened mask keeps the valid residues of every protein in row-major order.
demo_pssm = torch.arange(200).reshape(2, 5, 20).flatten(end_dim=1)
display(pd.DataFrame(demo_pssm))

demo_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 0, 0, 0, 0]]).flatten(end_dim=1)
demo_selected = demo_pssm[demo_mask.bool()]
display(pd.DataFrame(demo_selected))

# Residues 0-2 of protein 0 (flat rows 0, 1, 2) followed by residue 0 of protein 1 (flat row 5).
assert torch.equal(demo_selected, demo_pssm[[0, 1, 2, 5]])
