In [1]:
# Re-import edited modules (e.g. the local `src` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import os
from datetime import datetime

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import pandas as pd
import torch
import wandb
import yaml
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model
from plms import PLMConfig, ProstT5, ProstT5Tokenizer
from torch import nn
from torch.nn import (
    KLDivLoss,
)
from transformers import (
    PreTrainedModel,
    TrainingArguments,
)

from src.model.configuration_md_pssm import MDPSSMConfig
from src.model.metrics import compute_metrics
from src.model.modeling_outputs import PSSMOutput
from src.model.trainer_protein_subset import ProteinSampleSubsetTrainer
from src.model.utils.data_collator import DataCollatorForT5Pssm

with open("../configs/model.yaml", "r") as f:
    train_config = yaml.safe_load(f)

# Never truncate pandas output in this notebook: rows, columns, total width and cell contents.
pd.set_option(
    "display.max_rows", None,
    "display.max_columns", None,
    "display.width", None,
    "display.max_colwidth", None,
)

In [None]:
VERBOSE = train_config["verbose"]
SEED = train_config["seed"]
# Seed torch now, before the PSSM head and LoRA adapters are initialised in later cells.
# The TrainingArguments seed is only applied once the Trainer is constructed, which is
# after those weights already exist.
torch.manual_seed(SEED)

project_name = train_config["project_name"]
# The YAML value may be empty/null; treat that as "no custom suffix" instead of failing
# on `None.replace`. Spaces are replaced once so the name is safe for paths and run names.
custom_run_name = (train_config.get("custom_run_name") or "").replace(" ", "-")
model_name_identifier = (
    f"{project_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
    + (f"-{custom_run_name}" if custom_run_name else "")
)

USE_WANDB = train_config["weights_and_biases"]["enabled"]

if USE_WANDB:
    run = wandb.init(project=project_name, name=model_name_identifier)

print(model_name_identifier)

In [4]:
class PSSMHead(nn.Module):
    """Convolutional head turning per-residue T5 embeddings into per-residue PSSM probabilities.

    Input is ``(batch, seq_len, 1024)`` and output is ``(batch, seq_len, 20)``. Each
    position's 20 values sum to 1 because a softmax is taken over the last axis.

    Architecture based on https://github.com/hefeda/PGP/blob/master/prott5_batch_predictor.py#L144
    """

    def __init__(self):
        super().__init__()
        # Keep the attribute name and layer order. The state-dict keys are
        # `classifier.0.*` and `classifier.3.*`, and saved adapters
        # (modules_to_save=["pssm_head"]) refer to them.
        self.classifier = nn.Sequential(
            nn.Conv1d(in_channels=1024, out_channels=32, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Conv1d(in_channels=32, out_channels=20, kernel_size=7, padding=3),
        )

    def forward(self, x):
        # Conv1d expects channels first: (batch, hidden, seq_len).
        channels_first = x.permute(0, 2, 1)
        logits = self.classifier(channels_first).permute(0, 2, 1)
        return logits.softmax(dim=-1)


class T5EncoderModelForPssmGeneration(PreTrainedModel):
    """ProstT5 encoder followed by ``PSSMHead``.

    The encoder produces per-residue hidden states, and the head turns them into a
    20-way probability distribution per position. When ``labels`` are given, a
    KL-divergence loss between predicted and target distributions is computed over
    valid positions only.
    """

    def __init__(self, config: MDPSSMConfig):
        super().__init__(config=config)
        # Fall back to "auto" when the config has no `device` attribute.
        # NOTE(review): "auto" is not a valid torch device for `.to()`, so the head
        # placement below only works when an explicit device is configured.
        device_map = config.device if hasattr(config, "device") else "auto"
        plm_config = PLMConfig(
            name_or_path=config.model_name,
            device=device_map,
        )

        self.protein_encoder = ProstT5(config=plm_config)
        self.pssm_head = PSSMHead().to(device_map)
        # "batchmean" divides the summed KL by the number of rows passed in. In
        # `forward` that is the number of valid residues after masking.
        self.loss_fct = KLDivLoss(reduction="batchmean")

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Predict per-residue PSSMs and, if ``labels`` are given, the KL-divergence loss.

        Args:
            input_ids: Token ids forwarded to the ProstT5 encoder.
            attention_mask: Mask forwarded to the encoder. All later masking uses the
                mask the encoder returns (``encoder_outputs["mask"]``) instead.
            inputs_embeds: Accepted for API compatibility; not used.
            labels: Target PSSMs. Rows containing ``-100`` are treated as padding and
                are ignored.
            output_attentions: Accepted for API compatibility; not used.
            output_hidden_states: If truthy, the encoder's last hidden state is
                returned as ``hidden_states``.
            return_dict: If falsy (including the default ``None``), a tuple is
                returned; otherwise a ``PSSMOutput``.

        Returns:
            ``PSSMOutput(loss, pssms, hidden_states, masks)``, or the tuple
            ``([loss,] pssm, encoder_outputs[2:-1])``.

        Raises:
            ValueError: If the number of valid predicted positions differs from the
                number of valid label rows.
        """
        encoder_outputs = self.protein_encoder.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=return_dict,
        )

        # [batch_size, seq_len, hidden_dim]
        hidden_states = encoder_outputs["last_hidden_state"]
        # Use the encoder's own mask, which is the one aligned with `hidden_states`.
        attention_mask = encoder_outputs["mask"]

        # [batch_size, seq_len, 20]
        pssm = self.pssm_head(hidden_states)

        loss = None
        if labels is not None:
            # [batch_size * seq_len, 20]
            pred = pssm.flatten(end_dim=1)
            target = labels.flatten(end_dim=1)

            # Predictions and labels may be padded to different lengths, so each side
            # is filtered by its own validity mask.
            pred_mask = attention_mask.flatten(end_dim=1).bool()
            target_mask = ~torch.any(target == -100, dim=1)

            pred = pred[pred_mask]
            target = target[target_mask]

            if pred.shape[0] != target.shape[0]:
                raise ValueError(
                    f"Number of valid predicted positions ({pred.shape[0]}) does not match "
                    f"number of valid label rows ({target.shape[0]})"
                )

            # KLDivLoss expects log-probabilities. Clamping avoids log(0) = -inf when a
            # softmax output underflows, which would otherwise make the loss nan/inf.
            log_pred = torch.log(pred.clamp_min(torch.finfo(pred.dtype).tiny))
            loss = self.loss_fct(log_pred, target)

        if not return_dict:
            # NOTE(review): the encoder outputs are nested as a single element here rather
            # than concatenated (the HF convention is `(pssm,) + encoder_outputs[2:]`). It is
            # kept as-is because Trainer / compute_metrics may rely on this layout.
            output = (pssm, encoder_outputs[2:-1])
            return ((loss,) + output) if loss is not None else output

        return PSSMOutput(
            loss=loss,
            pssms=pssm,
            hidden_states=encoder_outputs["last_hidden_state"] if output_hidden_states else None,
            masks=attention_mask,
        )


# Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

config = MDPSSMConfig(model_name="Rostlab/ProstT5", device=device)
model = T5EncoderModelForPssmGeneration(config)

---

# Training


In [None]:
def check_cuda_params(model):
    """Print whether every parameter of ``model`` is a CUDA tensor.

    On success a single confirmation line is printed. Otherwise a header is printed,
    followed by one ``- <name>`` line per parameter that is not on CUDA.
    """
    off_cuda = [name for name, param in model.named_parameters() if not param.is_cuda]

    if not off_cuda:
        print("All parameters are on CUDA")
        return

    print("The following parameters are not on CUDA:")
    for name in off_cuda:
        print(f"- {name}")


# Report any model parameters that did not end up on a CUDA device.
check_cuda_params(model)


In [None]:
target_modules = ["q", "v"]
modules_to_save = ["pssm_head"]

# Structural settings are fixed here; the numeric hyperparameters and switches are
# forwarded verbatim from the `lora` section of the YAML config.
lora_config = LoraConfig(
    inference_mode=False,
    target_modules=target_modules,
    modules_to_save=modules_to_save,
    bias="none",
    **{key: train_config["lora"][key] for key in ("r", "lora_alpha", "lora_dropout", "use_rslora", "use_dora")},
)

model = get_peft_model(model, lora_config)

print("target_modules:", target_modules)
print("modules_to_save:", modules_to_save)
model.print_trainable_parameters()

In [None]:
dataset = load_from_disk(train_config["dataset"]["path"])
dataset = dataset.rename_column("pssm_features", "labels")
dataset = dataset.remove_columns(["name", "sequence"])

tokenizer = ProstT5Tokenizer().get_tokenizer()

data_collator = DataCollatorForT5Pssm(
    tokenizer=tokenizer,
    padding=True,
    pad_to_multiple_of=8,
)

training_args = TrainingArguments(
    output_dir=f"../tmp/models/checkpoints/{model_name_identifier}",
    run_name=model_name_identifier if USE_WANDB else None,
    # "none" explicitly disables reporting. `None` would fall back to the default,
    # "all installed integrations", which still logs to W&B when it is disabled here.
    report_to="wandb" if USE_WANDB else "none",
    learning_rate=train_config["trainer"]["learning_rate"],
    per_device_train_batch_size=train_config["trainer"]["train_batch_size"],
    num_train_epochs=train_config["trainer"]["num_epochs"],
    eval_strategy=train_config["trainer"]["eval_strategy"],
    eval_steps=train_config["trainer"]["eval_steps"],
    per_device_eval_batch_size=train_config["trainer"]["eval_batch_size"],
    eval_on_start=train_config["trainer"]["eval_on_start"],
    batch_eval_metrics=train_config["trainer"]["batch_eval_metrics"],
    save_strategy=train_config["trainer"]["save_strategy"],
    save_steps=train_config["trainer"]["save_steps"],
    save_total_limit=train_config["trainer"]["save_total_limit"],
    remove_unused_columns=train_config["trainer"]["remove_unused_columns"],
    # NOTE(review): these are model *inputs*, not the `labels` column. The Trainer passes
    # them to compute_metrics as label_ids — presumably so metrics can mask per sequence; confirm.
    label_names=["input_ids", "attention_mask"],
    logging_strategy="steps",
    logging_steps=train_config["trainer"]["logging_steps"],
    seed=SEED,
    lr_scheduler_type=train_config["trainer"]["lr_scheduler_type"],
    warmup_steps=train_config["trainer"]["warmup_steps"],
)

# NOTE(review): evaluation runs on the training dataset. ProteinSampleSubsetTrainer may
# subsample it, but there is no held-out split, so eval metrics measure fit, not generalisation.
trainer = ProteinSampleSubsetTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
def _release_cached_memory():
    """Run the garbage collector and release cached allocator memory on CUDA and MPS."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


# Free memory before and after training so leftovers from earlier cells do not crowd the device.
_release_cached_memory()
trainer.train()
trainer.evaluate()
_release_cached_memory()


In [None]:
import pandas as pd
from src.plots.train_plots import plot_training_history
import matplotlib.pyplot as plt

model_save_path = f"../tmp/models/adapters/{model_name_identifier}"

# save_pretrained creates the directory that the files below are written into.
model.save_pretrained(save_directory=model_save_path)

# Build the log-history frame once and reuse it for the CSV and the plot.
log_history = pd.DataFrame(trainer.state.log_history)
log_history.to_csv(f"{model_save_path}/training_log.csv", index=False)

# Record the save location *before* opening the file, so that a missing key cannot
# leave behind an empty/truncated train_config.yaml.
train_config["model"]["reload_from_checkpoint_path"] = model_save_path
with open(f"{model_save_path}/train_config.yaml", "w") as f:
    yaml.dump(train_config, f, sort_keys=False)

fig = plot_training_history(log_history=log_history, train_config=train_config, metrics_names=["loss"])
fig.savefig(f"{model_save_path}/training_history.png")
plt.close(fig)

print("Model, config, and log saved to:", model_save_path)


In [None]:
import re
from transformers import T5Tokenizer


def get_model_info(model):
    """Return a multi-line report of device placement and parameter dtypes.

    The report lists the device of the first model parameter, the device of the first
    ``protein_encoder`` parameter, then the dtype of every parameter in
    ``protein_encoder`` and in ``pssm_head``.
    """
    lines = [
        f"Model device: {next(model.parameters()).device}",
        f"Model Protein Encoder device: {next(model.protein_encoder.parameters()).device}",
        "\nProtein Encoder (T5) Parameter dtypes:",
    ]
    lines.extend(f"{name}: {param.dtype}" for name, param in model.protein_encoder.named_parameters())
    lines.append("\nPSSM Head Parameter dtypes:")
    lines.extend(f"{name}: {param.dtype}" for name, param in model.pssm_head.named_parameters())
    return "\n".join(lines)


def compare_model_parameters_state_dicts(model1, model2, should_match=True, verbose=False):
    """
    Compare two models' state dicts entry by entry.

    Args:
        model1: First model (anything exposing ``state_dict()``).
        model2: Second model.
        should_match: The outcome the caller expects. If the comparison result disagrees
            with it, an extra ``WARNING`` line is printed. The return value is unaffected.
        verbose: If True, print details about each mismatching entry.

    Returns:
        bool: True if both state dicts have the same keys and every pair of tensors has the
        same shape and is ``torch.allclose`` (rtol=1e-5, atol=1e-8, compared as float32).
        False otherwise.
    """

    def _report_expectation(result):
        # Make a surprising outcome explicit instead of relying on the caller to notice it.
        if result != should_match:
            expected = "match" if should_match else "differ"
            print(f"WARNING: expected the models to {expected}, but the comparison returned {result}")

    state_dict1 = model1.state_dict()
    state_dict2 = model2.state_dict()

    keys1 = set(state_dict1.keys())
    keys2 = set(state_dict2.keys())

    missing_in_2 = keys1 - keys2
    missing_in_1 = keys2 - keys1

    if missing_in_2 or missing_in_1:
        print("Found mismatch in parameters\n")
        print(f"Models are identical: {False}")

        if missing_in_2:
            print("Keys missing from reloaded model:")
            for key in sorted(missing_in_2):
                print(f"Key {key} missing from reloaded model")

        if missing_in_1:
            print("Keys missing from original model:")
            for key in sorted(missing_in_1):
                print(f"Key {key} missing from original model")

        _report_expectation(False)
        return False

    parameters_match = True
    mismatched_params = []

    for key, param1 in state_dict1.items():
        param2 = state_dict2[key]

        if param1.shape != param2.shape:
            parameters_match = False
            mismatched_params.append((key, "shape mismatch", param1.shape, param2.shape))
            continue

        # Compare and diff the float32 copies. Subtraction is not defined for bool tensors
        # (e.g. mask buffers), so differencing the raw tensors could raise.
        value1 = param1.float()
        value2 = param2.float()
        if not torch.allclose(value1, value2, rtol=1e-5, atol=1e-8):
            parameters_match = False
            mismatched_params.append((key, "value mismatch", torch.max(torch.abs(value1 - value2)).item()))

    if verbose and not parameters_match:
        print("Parameter mismatches:")
        for index, param_info in enumerate(mismatched_params):
            if len(param_info) == 4:
                key, msg, shape1, shape2 = param_info
                print(f"{index}: {key}: {msg} - shape1: {shape1}, shape2: {shape2}")
            else:
                key, msg, diff = param_info
                print(f"{index}: {key}: {msg} - max difference: {diff}")

    print(f"State dictionaries are identical: {parameters_match}")
    _report_expectation(parameters_match)

    return parameters_match


def compare_model_embeddings(model, reloaded_model, train_config, tokenizer_plm=None, dummy_proteins=None):
    """Compare encoder hidden states of the original and reloaded models on sample inputs.

    Args:
        model: Original model. Must expose ``.device`` and accept the ``input_ids``,
            ``attention_mask``, ``output_hidden_states`` and ``return_dict`` keywords.
        reloaded_model: Reloaded model to compare against (same call interface).
        train_config: Training configuration (currently unused; kept for API compatibility).
        tokenizer_plm: Tokenizer callable returning ``input_ids`` / ``attention_mask`` tensors.
            Defaults to the ``Rostlab/ProstT5`` ``T5Tokenizer``.
        dummy_proteins: List of sample protein sequences. Defaults to two test sequences.

    Returns:
        tuple: ``(protein_match, protein_exact_match)``. The first is ``torch.allclose``
        with rtol=atol=1e-4; the second is ``torch.equal``.
    """
    if dummy_proteins is None:
        dummy_proteins = [
            "MLKFVVVLAAVLSLYAYAPAFEVHNKKNVLMQRVGETLRISDRYLYQTLSKPYKVTLKTLDGHEIFEVVGEAPVTFRFKDKERPVVVASPEHVVGIVAVHNGKIYARNLYIQNISIVSAGGQHSYSGLSWRYNQPNDGKVTDYF",
            "MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNTNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRAALINMVFQMGE",
        ]

    # Map rare/ambiguous residues (U, Z, O, B) to X and space-separate residues for the T5 tokenizer.
    dummy_proteins = [" ".join(re.sub(r"[UZOB]", "X", seq)) for seq in dummy_proteins]

    if tokenizer_plm is None:
        tokenizer_plm = T5Tokenizer.from_pretrained(
            pretrained_model_name_or_path="Rostlab/ProstT5",
            do_lower_case=False,
            use_fast=True,
            legacy=False,
        )
    protein_tokens = tokenizer_plm(dummy_proteins, return_tensors="pt", padding=True, truncation=False)
    protein_tokens = {k: v.to(model.device) for k, v in protein_tokens.items()}

    def _hidden_states(m):
        """Run ``m`` in eval mode without gradients and return its encoder hidden states."""
        m.eval()
        with torch.no_grad():
            output = m(
                input_ids=protein_tokens["input_ids"],
                attention_mask=protein_tokens["attention_mask"],
                # Without this flag the PSSM model returns hidden_states=None and the
                # comparison below would fail.
                output_hidden_states=True,
                return_dict=True,
            )
        return output.hidden_states

    emb_orig = _hidden_states(model)
    emb_reload = _hidden_states(reloaded_model)

    protein_match = torch.allclose(emb_orig, emb_reload, rtol=1e-4, atol=1e-4)
    protein_exact_match = torch.equal(emb_orig, emb_reload)

    print(f"Protein embeddings shape: {emb_orig.shape}")
    print(f"Protein embeddings match: {protein_match}")
    print(f"Protein embeddings exact match: {protein_exact_match}")

    return protein_match, protein_exact_match


def check_model_on_cuda(model):
    """Print whether every parameter of ``model`` lives on a CUDA device.

    If CUDA is unavailable, a single notice is printed. Otherwise one warning is printed
    per parameter that is not on CUDA, followed by a summary line.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available")
        return

    all_on_cuda = True
    for param_name, tensor in model.named_parameters():
        if tensor.is_cuda:
            continue
        print(f"WARNING: Parameter {param_name} is not on CUDA")
        all_on_cuda = False

    print("All model parameters are on CUDA" if all_on_cuda else "Some parameters are not on CUDA - see warnings above")


def check_model_parameters_requires_grad(model, verbose=False):
    """Report whether every parameter of ``model`` requires gradients.

    Args:
        model: Any ``torch.nn.Module``.
        verbose: If True, print one warning per frozen parameter. This is off by default
            because a LoRA-wrapped model freezes its base weights, which would flood the output.
    """
    named_params = list(model.named_parameters())
    frozen = [name for name, param in named_params if not param.requires_grad]

    if not frozen:
        print("All model parameters require gradients")
        return

    if verbose:
        for name in frozen:
            print(f"WARNING: Parameter {name} does not require gradients")
        print("Some parameters do not require gradients - see warnings above")
    else:
        # Previously this message pointed at warnings that were never printed.
        print(
            f"Some parameters do not require gradients ({len(frozen)} of {len(named_params)} frozen) "
            "- rerun with verbose=True to list them"
        )


def sanity_checks_new(model, train_config, model_save_path):
    """Check that the weights saved at ``model_save_path`` restore the trained ``model``.

    Steps:
    1. Build a fresh ``T5EncoderModelForPssmGeneration``; it should NOT match ``model``.
    2. Load the saved adapter (LoRA) or projections into it; now it should match.
    3. Compare the hidden states both models produce on sample sequences.

    Results are printed, not raised.

    Args:
        model: The trained model (PEFT-wrapped when ``train_config["lora"]["enabled"]``).
        train_config: Training configuration. ``["lora"]["enabled"]`` selects the reload
            path; the dict is also forwarded to ``compare_model_embeddings``.
        model_save_path: Directory written by ``save_pretrained``.
    """
    # Build the fresh model from its own config instead of silently reusing the global `config`.
    # NOTE(review): this still relies on the global `device` from the model-setup cell.
    reloaded_config = MDPSSMConfig(device=device, model_name="Rostlab/ProstT5")
    reloaded_model = T5EncoderModelForPssmGeneration(reloaded_config)
    reloaded_model.to(model.device)

    if train_config["lora"]["enabled"]:
        # Compare against the model underneath the PEFT wrapper rather than the wrapper itself.
        model = model.base_model.model

    models_match = compare_model_parameters_state_dicts(reloaded_model, model, should_match=False, verbose=True)
    print("Models match (should NOT match):", models_match)

    if train_config["lora"]["enabled"]:
        reloaded_model.load_adapter(model_save_path)
    else:
        # NOTE(review): `load_projections_from_safetensors` is not defined on
        # T5EncoderModelForPssmGeneration in this notebook. Unless a base class provides it,
        # this branch raises AttributeError; confirm.
        reloaded_model.load_projections_from_safetensors(model_save_path)

    models_match = compare_model_parameters_state_dicts(reloaded_model, model, should_match=True, verbose=True)
    print("Models match (should match):", models_match)

    compare_model_embeddings(model, reloaded_model, train_config)


sanity_checks_new(model, train_config, model_save_path)

---

# Inference


In [5]:
import warnings
from Bio import BiopythonWarning
import json
import os
from tqdm import tqdm
import re

warnings.filterwarnings("ignore", category=BiopythonWarning)

# Trained adapter to run and where its generated PSSMs are written.
MODEL_NAME = "prot-md-pssm-2025-04-02-15-20-18-dataset_320_0_prostt5"
MODEL_PATH = f"../tmp/models/adapters/{MODEL_NAME}"
PSSM_SAVE_DIR = f"../tmp/data/generated_pssms/scope40_{MODEL_NAME}"
SCOP40_SEQUENCES_FILE = "../tmp/data/scope/scope40_sequences.json"
# NOTE(review): not used in the visible cells, and it names ProtT5 while the model above is ProstT5.
PROTEIN_ENCODER_NAME = "Rostlab/prot_t5_xl_uniref50"

# The 20 standard amino acids; this order is the PSSM column order written to disk.
AA_ALPHABET = list("ACDEFGHIKLMNPQRSTVWY")
STRUCTURE_ALPHABET = list("acdefghiklmnpqrstvwy")

In [6]:
# Keys become output file names and values are the sequences fed to the tokenizer.
# For a quick debug run, slice the items, e.g. dict(list(...items())[:11]).
with open(SCOP40_SEQUENCES_FILE, "r") as f:
    scop_sequences = dict(json.load(f).items())

# for k, v in scop_sequences.items():
#     scop_sequences[k] = " ".join(list(re.sub(r"[UZOB]", "X", v)))

In [11]:
# scop_sequences = {"d12asa_": scop_sequences['d12asa_']}

In [None]:
# scop_sequences

In [None]:
# Load the adapter saved at MODEL_PATH into `model`, move it to `device` and switch to eval mode.
# NOTE(review): if the training cells ran in this kernel, `model` is a PeftModel, whose
# `load_adapter` presumably requires an explicit `adapter_name`; confirm before re-running.
model.load_adapter(MODEL_PATH)
model.to(device).eval()
print("Loaded model")

In [None]:
# Rebinds `tokenizer`: the training cell bound it to `ProstT5Tokenizer().get_tokenizer()`,
# whereas the inference loop uses the wrapper object itself.
tokenizer = ProstT5Tokenizer()

# Output directory for the generated per-domain PSSM files.
os.makedirs(PSSM_SAVE_DIR, exist_ok=True)


def pssm_to_csv(name, pssm, save_dir=None, alphabet=None):
    """Write one PSSM to ``<save_dir>/<name>.tsv`` as a space-separated text profile.

    File layout:
        line 1: ``Query profile of sequence <name>``
        line 2: the alphabet letters as a column header
        then one line per residue, with values rounded/formatted to 4 decimals and
        each line ending in `` \\n``.
    Despite the function name and the ``.tsv`` extension, columns are separated by spaces.

    Args:
        name: Identifier used in the header line and as the file stem.
        pssm: 2-D array-like of shape ``(n_residues, len(alphabet))``.
        save_dir: Output directory; defaults to the module-level ``PSSM_SAVE_DIR``.
        alphabet: Column letters; defaults to the module-level ``AA_ALPHABET``.

    Raises:
        ValueError: If the number of PSSM columns differs from ``len(alphabet)``.
    """
    if save_dir is None:
        save_dir = PSSM_SAVE_DIR
    if alphabet is None:
        alphabet = AA_ALPHABET

    # Build and validate the table before opening the file, so a bad input cannot
    # leave a half-written profile behind.
    df_pssm = pd.DataFrame(pssm).round(4)
    if df_pssm.shape[1] != len(alphabet):
        raise ValueError(f"PSSM for {name} has {df_pssm.shape[1]} columns, expected {len(alphabet)}")

    with open(f"{save_dir}/{name}.tsv", "w") as f:
        f.write(f"Query profile of sequence {name}\n")
        f.write("     " + "      ".join(alphabet) + "      \n")
        df_pssm.to_csv(f, index=False, sep=" ", float_format="%.4f", header=False, lineterminator=" \n")


batch_size = 20
sequence_items = list(scop_sequences.items())
sequence_batches = [dict(sequence_items[i : i + batch_size]) for i in range(0, len(sequence_items), batch_size)]


for batch in tqdm(sequence_batches, desc="Processing batches"):
    protein_tokens = tokenizer.encode(list(batch.values()), return_tensors="pt", padding=True, truncation=False).to(device)

    with torch.no_grad():
        model_output = model(
            input_ids=protein_tokens["input_ids"],
            attention_mask=protein_tokens["attention_mask"],
            # Only `pssms` and `masks` are used. Not requesting hidden states avoids keeping
            # the encoder's last hidden state alive in `model_output` between batches.
            output_hidden_states=False,
            return_dict=True,
        )
    torch.cuda.empty_cache()

    for name, pssm, mask in zip(batch.keys(), model_output.pssms, model_output.masks):
        # Keep only the positions flagged by the encoder mask. Both tensors are moved to
        # CPU first, so the boolean index never crosses devices.
        pssm_to_csv(name, pssm.cpu()[mask.cpu().bool()].numpy())

In [None]:
index = 1

# Inspect one sequence of the last processed batch alongside its encoder mask and predicted PSSM.
debug_sequence = list(batch.values())[index].replace(" ", "")
debug_mask = model_output["masks"][index]

print(len(debug_sequence))
print(debug_sequence)
print(*debug_mask.tolist(), sep="")
print(model_output["pssms"][index].shape)
print(debug_mask.sum())