In [None]:
%reload_ext autoreload
%autoreload 2

import json
import os
import re
import warnings

# Number GPUs in PCI bus order and expose only GPU 0 to this kernel. These variables
# only take effect if set before torch initialises CUDA, which is why this runs ahead
# of the torch import below (re-running it in a live kernel changes nothing).
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})

import random
import re
import os, glob

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import pandas as pd
import plotly.express as px
import seaborn as sns
import torch
from Bio import BiopythonWarning
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer

from src.model.configuration_md_pssm import MDPSSMConfig
from src.model.modeling_md_pssm import T5EncoderModelForPssmGeneration

# Prefer CUDA when it is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hide BiopythonWarning noise for the rest of the session.
warnings.filterwarnings("ignore", category=BiopythonWarning)

# Trained adapter run; its name also namespaces the generated-PSSM output directory.
MODEL_NAME = "prot-md-pssm-2025-03-05-17-43-47-full-dataset"
MODEL_PATH = f"../tmp/models/adapters/{MODEL_NAME}"
PSSM_SAVE_DIR = f"../tmp/data/generated_pssms/{MODEL_NAME}"

# JSON object mapping sequence names to sequences (loaded in the next cell).
SCOP40_SEQUENCES_FILE = "../tmp/data/scope/scope40_sequences.json"
# Base protein language model whose tokenizer encodes the sequences.
PROTEIN_ENCODER_NAME = "Rostlab/prot_t5_xl_uniref50"

# The 20 standard amino acids in alphabetical one-letter order; these label the PSSM columns.
AA_ALPHABET = list("ACDEFGHIKLMNPQRSTVWY")
# Lower-case counterpart of AA_ALPHABET.
STRUCTURE_ALPHABET = [letter.lower() for letter in AA_ALPHABET]

In [2]:
# Optional cap on how many sequences to process, e.g. 11 for a quick smoke test.
# None keeps every sequence in the file.
MAX_SEQUENCES = None

with open(SCOP40_SEQUENCES_FILE, "r") as f:
    raw_scop_sequences = json.load(f)

# ProtT5 preprocessing: upper-case residues, map the rare amino acids U/Z/O/B to X,
# and separate residues with single spaces so each one becomes its own token.
scop_sequences = {
    name: " ".join(re.sub(r"[UZOB]", "X", sequence.upper()))
    for name, sequence in list(raw_scop_sequences.items())[:MAX_SEQUENCES]
}

In [None]:
# Tokenizer of the base ProtT5 encoder; lower-casing is disabled because residues are
# upper-case single-letter codes, and legacy=False opts out of the legacy
# SentencePiece handling.
# NOTE(review): `use_fast` is an AutoTokenizer option; T5Tokenizer is the slow
# SentencePiece class, so the flag presumably has no effect here — confirm.
tokenizer = T5Tokenizer.from_pretrained(
    PROTEIN_ENCODER_NAME,
    legacy=False,
    use_fast=True,
    do_lower_case=False,
)

# Build the PSSM model from its default config, then attach the trained adapter
# weights stored at MODEL_PATH and move everything to the selected device.
# NOTE(review): confirm that constructing from MDPSSMConfig() loads the pretrained
# ProtT5 encoder weights; otherwise only the adapter parameters are trained.
model_config = MDPSSMConfig()
model = T5EncoderModelForPssmGeneration(model_config)
model.load_adapter(MODEL_PATH)
model.to(device)

print("Loaded model")

In [None]:
# All PSSMs for this model go into one directory. Any .tsv files left over from a
# previous run are deleted first so the directory only holds this run's output.
print("Saving PSSMs to", PSSM_SAVE_DIR)

os.makedirs(PSSM_SAVE_DIR, exist_ok=True)

stale_tsv_files = glob.glob(f"{PSSM_SAVE_DIR}/*.tsv")
for stale_path in stale_tsv_files:
    os.remove(stale_path)


def pssm_to_csv(name, pssm, save_dir=None, alphabet=None):
    """Write one PSSM to ``<save_dir>/<name>.tsv`` as a space-separated text profile.

    Despite the function name the output is not comma-separated: columns are
    separated by a single space and every data row ends with ``" \\n"``. The file
    starts with a ``Query profile of sequence <name>`` line, followed by a header
    line listing ``alphabet``, followed by one line per row of ``pssm`` with every
    value formatted to 4 decimal places.

    Args:
        name: Sequence identifier; used in the first line and as the file stem.
        pssm: 2-D array-like accepted by ``pd.DataFrame``. Presumably shaped
            (residues, len(alphabet)) — TODO confirm against the model output.
        save_dir: Output directory. Defaults to ``PSSM_SAVE_DIR``, looked up at
            call time.
        alphabet: Column labels written in the header line. Defaults to
            ``AA_ALPHABET``, looked up at call time.
    """
    if save_dir is None:
        save_dir = PSSM_SAVE_DIR
    if alphabet is None:
        alphabet = AA_ALPHABET

    df_pssm = pd.DataFrame(pssm)
    # newline="" disables text-mode newline translation, as pandas requires for file
    # handles passed to to_csv, so the " \n" terminator is written verbatim on every OS.
    with open(os.path.join(save_dir, f"{name}.tsv"), "w", newline="") as f:
        f.write(f"Query profile of sequence {name}\n")
        f.write("     " + "      ".join(alphabet) + "      \n")
        # float_format already renders each value rounded to 4 decimals.
        df_pssm.to_csv(f, index=False, sep=" ", float_format="%.4f", header=False, lineterminator=" \n")


# Split the (name -> spaced sequence) mapping into dicts of at most `batch_size` entries.
batch_size = 20
sequence_items = list(scop_sequences.items())
sequence_batches = [dict(sequence_items[i : i + batch_size]) for i in range(0, len(sequence_items), batch_size)]

model.eval()

for batch in tqdm(sequence_batches, desc="Processing batches"):
    # Pad to the longest sequence in the batch and never truncate, so no residue is dropped.
    protein_tokens = tokenizer(list(batch.values()), return_tensors="pt", padding=True, truncation=False).to(device)

    with torch.no_grad():
        # NOTE(review): hidden states are requested but never read in this cell; drop
        # output_hidden_states if the model's forward does not need them, to save memory.
        model_output = model(
            input_ids=protein_tokens["input_ids"],
            attention_mask=protein_tokens["attention_mask"],
            output_hidden_states=True,
            return_dict=True,
        )
    torch.cuda.empty_cache()

    for name, pssm, mask in zip(batch.keys(), model_output.pssms, model_output.masks):
        # Keep only the rows the model's mask marks as valid. Boolean-index directly on
        # the PSSM's own device instead of round-tripping the mask through NumPy.
        keep = mask.to(device=pssm.device, dtype=torch.bool)
        pssm_to_csv(name, pssm[keep].cpu().numpy())
