In [1]:
%reload_ext autoreload
%autoreload 2

import json
import os
import re
import warnings

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import random
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import torch
from Bio import BiopythonWarning
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer

from src.model.configuration_md_pssm import MDPSSMConfig
from src.model.modeling_md_pssm import T5EncoderModelForPssmGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

warnings.filterwarnings("ignore", category=BiopythonWarning)

SCOP40_SEQUENCES_FILE = "../tmp/data/scope/scope40_sequences.json"
MODEL_PATH = "../tmp/models/adapters/prot-md-pssm-2025-03-05-17-43-47-full-dataset"
PROTEIN_ENCODER_NAME = "Rostlab/prot_t5_xl_uniref50"

AA_ALPHABET = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"]
STRUCTURE_ALPHABET = [x.lower() for x in AA_ALPHABET]

Matplotlib created a temporary cache directory at /tmp/matplotlib-yelvdbe_ because the default path (/home/lfi/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
with open(SCOP40_SEQUENCES_FILE, "r") as f:
    sequences = json.load(f)

sequences_with_x = sum(1 for seq in sequences.values() if "X" in seq.upper())
total_x_chars = sum(seq.upper().count("X") for seq in sequences.values())

print("Number of sequences: ", len(sequences))
print(f"Number of sequences containing X: {sequences_with_x}")
print("Fraction of sequences containing X: ", sequences_with_x / len(sequences))
print()
print("Total number of characters across all sequences: ", sum(len(seq) for seq in sequences.values()))
print(f"Total number of X characters across all sequences: {total_x_chars}")
print(f"Fraction of X characters across all sequences: {total_x_chars / sum(len(seq) for seq in sequences.values())}")
print()

named_sequences = dict(list(sequences.items())[:33])

# for i, (k, v) in enumerate(sequences.items()):
#     print(f"{i} {k} {len(v)}: {v}")
# print()

for k, v in named_sequences.items():
    named_sequences[k] = " ".join(list(re.sub(r"[UZOB]", "X", v)))

for i, (k, v) in enumerate(named_sequences.items()):
    if i == 4:
        break
    print(f"{i} {k} {len(v)}: {v}")
print()

seq_lengths = {k: len(v.replace(" ", "")) for k, v in named_sequences.items()}
longest_seq = max(seq_lengths.items(), key=lambda x: x[1])
shortest_seq = min(seq_lengths.items(), key=lambda x: x[1])

print(f"Longest sequence: {longest_seq[0]} with length {longest_seq[1]}")
print(f"Shortest sequence: {shortest_seq[0]} with length {shortest_seq[1]}")

Number of sequences:  11211
Number of sequences containing X: 0
Fraction of sequences containing X:  0.0

Total number of characters across all sequences:  1941189
Total number of X characters across all sequences: 0
Fraction of X characters across all sequences: 0.0

0 d3ci0k3 211: G R T R S Q Q E Y Q Q A L W Y S A S A E S L A L S A L S L S L K N E K R V H L E Q P W A S G P R F F P L P Q G Q I A V T L R D A Q S N Y F W L R S D I T V N E I E L T M N S L I V R M G P Q H F S V L W H Q T G E S
1 d2g50a3 269: E L A R A S S Q S T D L M E A M A M G S V E A S Y K C L A A A L I V L T E S G R S A H Q V A R Y R P R A P I I A V T R N H Q T A R Q A H L Y R G I F P V V C K D P V Q E A W A E D V D L R V N L A M N V G K A R G F F K K G D V V I V L T G W R P G S G F T N T M R V V P V P
2 d2idra_ 353: A H P L E N A W T F W F D N P Q G K S R Q V A W G S T I H P I H T F S T V E D F W G L Y N N I H N P S K L N V G A D F H C F K N K I E P K W E D P I C A N G G K W T I S C G R G K S D T F W L H T L L A M I 

In [3]:
tokenizer = T5Tokenizer.from_pretrained(
    pretrained_model_name_or_path=PROTEIN_ENCODER_NAME,
    do_lower_case=False,
    use_fast=True,
    legacy=False,
)

model_config = MDPSSMConfig()
model = T5EncoderModelForPssmGeneration(model_config)
model.load_adapter(MODEL_PATH)
model.to(device)

# protein_tokens = tokenizer(list(sequences.values()), return_tensors="pt", padding=True, truncation=False)
# protein_tokens = {k: v.to(device) for k, v in protein_tokens.items()}

# decoded_sequence = tokenizer.decode(protein_tokens["input_ids"][0], skip_special_tokens=False)
# print(decoded_sequence)
# print(*protein_tokens["attention_mask"][0].tolist())

# model.eval()
# with torch.no_grad():
#     protein_emb = model(
#         input_ids=protein_tokens["input_ids"],
#         attention_mask=protein_tokens["attention_mask"],
#         return_dict=True,
#     )

T5EncoderModelForPssmGeneration(
  (protein_encoder): T5EncoderModel(
    (shared): Embedding(128, 1024)
    (encoder): T5Stack(
      (embed_tokens): Embedding(128, 1024)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): lora.Linear(
                  (base_layer): Linear(in_features=1024, out_features=4096, bias=False)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=1024, out_features=8, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=8, out_features=4096, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                  (

In [None]:
test_sequence = [" ".join(random.choices(AA_ALPHABET, k=1000)) for x in range(10)]
# print(test_sequence)
# protein_tokens = tokenizer(test_sequence, return_tensors="pt", padding=True, truncation=False)
# protein_tokens.to(device)
# print(protein_tokens)

In [None]:
def plot_pssm(pssm, mask, original_sequence, prost_values=None):
    plt.figure(figsize=(30, 7))
    sns.heatmap(
        pssm.T,
        cmap="viridis",
        vmin=0,
        vmax=1,
        cbar_kws={"label": "Probability"},
        linewidths=0.5,
        linecolor="black",
    )
    plt.xticks(np.arange(len(original_sequence)) + 0.5, original_sequence, rotation=0, fontfamily="monospace")
    print()

    plt.yticks(np.arange(len(STRUCTURE_ALPHABET)) + 0.5, STRUCTURE_ALPHABET, rotation=0, fontfamily="monospace")

    plt.title("PSSM Heatmap", fontfamily="monospace")
    plt.xlabel("Sequence (Top: Amino Acid, Bottom: 3Di with ProstT5)", fontfamily="monospace")
    plt.ylabel("3Di", fontfamily="monospace")

    plt.show()


def run_batch(model, input_ids, attention_mask):
    # print(*attention_mask[0].tolist(), sep="")
    # print(*tokenizer.decode(input_ids[0], skip_special_tokens=True).replace(" ", ""), sep="")

    model.eval()
    with torch.no_grad():
        model_output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

    for pssm, mask, ids in list(zip(model_output.pssms, model_output.masks, input_ids)):
        original_sequence = tokenizer.decode(ids, skip_special_tokens=True).replace(" ", "")
        # pssm = pssm.cpu().numpy()
        # pssm = pssm[mask.cpu().numpy().astype(bool)]
        # print(original_sequence)
        # print(*mask.tolist(), sep="")
        # print(original_sequence)
        # plot_pssm(pssm, mask.cpu().numpy(), original_sequence)


batch_size = 10
named_sequences_items = list(named_sequences.items())
for i in range(0, len(named_sequences_items), batch_size):
    batch = dict(named_sequences_items[i : i + batch_size])
    protein_tokens = tokenizer(list(batch.values()), return_tensors="pt", padding=True, truncation=False)
    protein_tokens = {k: v.to(device) for k, v in protein_tokens.items()}
    print(protein_tokens)


# run_batch(model, protein_tokens["input_ids"], protein_tokens["attention_mask"])

---


In [4]:
tokenizer_prost = T5Tokenizer.from_pretrained("Rostlab/ProstT5", do_lower_case=False)
model_prost = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5").to(device)
model_prost.float() if device.type == "cpu" else model_prost.half()

sequence_examples = ["<AA2fold> " + s for s in sequences.values()]
ids = tokenizer_prost.batch_encode_plus(sequence_examples, add_special_tokens=True, padding="longest", return_tensors="pt").to(
    device
)

index = 0

print(len(list(sequences.values())[0]))
print(*[f"{x:<10}" for x in re.split("><| |</", tokenizer_prost.decode(ids["input_ids"][index]))])

print(*[f"{x:<10}" for x in ids["input_ids"][index].tolist()])
print(*[f"{x:<10}" for x in ids["attention_mask"][index].tolist()])

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


106
<AA2fold>  <unk       /s>       
149        2          1         
1          1          1         


In [5]:
gen_kwargs_aa2fold = {
    "do_sample": True,
    "num_beams": 3,
    "top_p": 0.95,
    "temperature": 1.2,
    "top_k": 6,
    "repetition_penalty": 1.2,
}

In [6]:
model_prost.eval()
with torch.no_grad():
    translations = model_prost.generate(
        ids.input_ids,
        attention_mask=ids.attention_mask,
        early_stopping=True,
        num_return_sequences=1,
        max_length=ids.input_ids.shape[1],
        **gen_kwargs_aa2fold,
    )

OutOfMemoryError: CUDA out of memory. Tried to allocate 790.00 MiB. GPU 0 has a total capacity of 47.45 GiB of which 535.50 MiB is free. Process 4183009 has 46.93 GiB memory in use. Of the allocated memory 45.47 GiB is allocated by PyTorch, and 668.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [13]:
decoded_translations = tokenizer_prost.batch_decode(translations, skip_special_tokens=True)
structure_sequences = ["".join(ts.split(" ")) for ts in decoded_translations]

In [None]:
print(translations)
print(structure_sequences)
for x, y in zip(structure_sequences, sequences.values()):
    print(len(x), len(y.replace(" ", "")))

In [None]:
total = 0

for x in sequences.values():
    total += len(x.replace(" ", ""))
    print(x.replace(" ", ""))
print(total)


---


In [7]:
import numpy as np
import pandas as pd

pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)

for i, ((k, v), (logits, mask)) in enumerate(zip(sequences.items(), zip(protein_emb.pssms, protein_tokens["attention_mask"]))):
    # if i == 2:
    #     break

    df_pssms = pd.DataFrame(logits.cpu().numpy()[: mask.sum()])
    original_sequence = v.split()
    original_sequence = [f"{a}\n{b}" for a, b in zip(original_sequence, structure_sequences[i])]

    # fig = px.imshow(
    #     df_pssms.T,
    #     color_continuous_scale="viridis",
    #     range_color=[0, 1],
    #     title=f"PSSM Heatmap for {k}",
    #     labels={"x": "Sequence", "y": "Position"},
    # )

    # # Create labels combining original sequence and structure sequence
    # combined_labels = [f"{a}\n{b}" for a, b in zip(original_sequence, structure_sequences[i])]
    # fig.update_xaxes(ticktext=combined_labels, tickvals=list(range(len(original_sequence))))
    # fig.show()

    plt.figure(figsize=(30, 7))
    sns.heatmap(
        df_pssms.T,
        cmap="viridis",
        vmin=0,
        vmax=1,
        cbar_kws={"label": "Probability"},
        linewidths=0.5,
        linecolor="black",
    )
    plt.xticks(np.arange(len(original_sequence)) + 0.5, original_sequence, rotation=0, fontfamily="monospace")
    print()

    plt.yticks(np.arange(len(STRUCTURE_ALPHABET)) + 0.5, STRUCTURE_ALPHABET, rotation=0, fontfamily="monospace")

    plt.title("PSSM Heatmap", fontfamily="monospace")
    plt.xlabel("Sequence (Top: Amino Acid, Bottom: 3Di with ProstT5)", fontfamily="monospace")
    plt.ylabel("3Di", fontfamily="monospace")

    plt.show()


NameError: name 'protein_emb' is not defined

---
