In [2]:
%reload_ext autoreload
%autoreload 2

import os

# Pin this kernel to a single GPU. Both variables are set before `torch` is
# imported below. NOTE(review): the CUDA runtime is expected to read them on
# first initialization, so keep these lines above the torch import.
# "PCI_BUS_ID" orders devices by PCI bus id; "3" is the device index to expose.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import torch
import pandas as pd

from plms import (
    ProtT5,
    ProstT5,
    ProtT5Tokenizer,
    ProstT5Tokenizer,
    auto_model,
    auto_tokenizer,
    PLMConfig,
)

In [3]:
# Checkpoint name used to build the tokenizer here and the model in a later cell.
# Swap to the commented-out line to use the ProtT5 checkpoint instead.
# model_name = "Rostlab/prot_t5_xl_uniref50"
model_name = "Rostlab/ProstT5"

# `auto_tokenizer` picks the tokenizer from the checkpoint name.
tokenizer = auto_tokenizer(model_name)
# Or, to be more explicit, instantiate the tokenizer class directly:
# tokenizer = ProtT5Tokenizer(name_or_path="Rostlab/prot_t5_xl_uniref50")

In [None]:
# Example input in FASTA form (multi-line records, mixed upper/lower case).
# Only used by the commented-out `tokenize_fasta` call below.
sequence_fasta = """>seq1
ACDEFGHIKLMNPQRSTVWY
ACDEFGHIKLMNPQRSTVWY
ACDEFGHIKLMNPQRSTVWY
ACDEFGHIKLMNPQRSTVWY
>seq2
ACDEFGHIKLMNPQRSTVWY
>seq3
acdefghiklmnpqrstvwy
acdefghiklmnpqrstvwy
"""
# Same kind of input as plain strings of different lengths, so that
# `padding=True` has something to pad.
sequence_strings = ["ACDEFGHIKLMNPQRSTVWYOOO", "ACDEFGHIKLMNPQRSTVWYACDEFGHIKLMNPQRSTVWYOOO", "acdefghiklmnpqrstvwyooo"]

# Alternative entry points (FASTA file on disk / FASTA string):
# tokenizer_output = tokenizer.tokenize_fasta(
#     "../../prot-md-pssm-benchmark/scope-benchmark-minimal/data/scope40_sequences_aa_short.fasta",
#     padding=True,
# )
# tokenizer_output = tokenizer.tokenize_fasta(fasta=sequence_fasta, return_headers=True)
tokenizer_output = tokenizer.encode(sequence_strings, padding=True)

input_ids = tokenizer_output["input_ids"]
attention_mask = tokenizer_output["attention_mask"]

# Every printed cell is truncated AND padded to the same width so the rows
# below line up column by column.
COLUMN_WIDTH = 4

print("Encoded tokens:")
for ids, mask in zip(input_ids, attention_mask):
    print(*[f"{x:<{COLUMN_WIDTH}d}" for x in ids], sep="")
    print(*[f"{x:<{COLUMN_WIDTH}d}" for x in mask], sep="")
    print()

print("\nDecoded tokens:")

decoded_tokens = tokenizer.decode(input_ids)
for ids, mask, decoded, original in zip(input_ids, attention_mask, decoded_tokens, sequence_strings):
    # Rows: token ids, attention mask, decoded tokens, original characters.
    # Previously cells were cut to 4 characters but padded to only 3, so any
    # 4-character cell shifted the rest of its row.
    for row in (ids, mask, decoded, original):
        print(*[str(x)[:COLUMN_WIDTH].ljust(COLUMN_WIDTH) for x in row], sep=" ")
    print()

In [5]:
# Build the model from the same checkpoint name the tokenizer was created with.
model = auto_model(model_name=model_name)

# Or, to be more explicit, instantiate the model class directly:
# model = ProtT5(name_or_path="Rostlab/prot_t5_xl_uniref50")
# model = ProstT5(name_or_path="Rostlab/ProstT5")

# Or with a config object passed to the model class:
# config=PLMConfig(name_or_path="Rostlab/prot_t5_xl_uniref50")
# model = ProtT5(config=config)
# config = PLMConfig(name_or_path="Rostlab/ProstT5")
# model = ProstT5(config=config)

In [6]:
# Single forward pass in eval mode with gradient tracking disabled.
model.eval()

# Convert the tokenizer output to tensors on the model's device; the dict keys
# double as the keyword names of the model call.
model_inputs = {
    key: torch.tensor(tokenizer_output[key]).to(model.device)
    for key in ("input_ids", "attention_mask")
}

with torch.no_grad():
    embeddings = model(**model_inputs)

In [None]:
# Inspect one sequence of the batch: one DataFrame row per position, with the
# embedding values, the matching input character and the model's mask value.
index = 1
selected_sequence = sequence_strings[index]
print(selected_sequence)

position_mask = embeddings["mask"][index].tolist()
print(*position_mask, sep="")

hidden_states = embeddings["last_hidden_state"][index]
df = pd.DataFrame(hidden_states.cpu().numpy())
# Right-pad the input string with "-" so it supplies one character per row.
df["sequence"] = list(selected_sequence.ljust(hidden_states.shape[0], "-"))
df["mask"] = position_mask
display(df)

In [None]:
from plms.utils.modeling_utils import mean_pool
import numpy as np

# Sanity check for `mean_pool`: pool per sequence, pool the whole batch, and
# compare the batched result for sequence 0 against a hand-written masked mean.
hidden_states = embeddings["last_hidden_state"]
pool_mask = embeddings["mask"]

print("First")

# Pool each sequence on its own (batch of one) and print the mean of the
# pooled vector as a one-number fingerprint.
for hidden, mask in zip(hidden_states, pool_mask):
    print(mean_pool(hidden.unsqueeze(0), mask.unsqueeze(0)).mean().tolist())

# Pool the whole batch in one call.
df_mean = mean_pool(hidden_states, pool_mask)

# For contrast: plain mean over dim 1 twice, ignoring the mask entirely.
print(*hidden_states.mean(dim=1).mean(dim=1).tolist(), sep="\n")


print("Second")
first = list(df_mean[0].cpu().numpy())

# Hand-written reference for sequence 0. The denominator counts only the
# masked-in positions, so the numerator must sum only those positions too;
# summing every position (as before) mixes padded positions into the mean.
# NOTE(review): assumes the mask is 0/1 (or bool) per position - confirm.
masked_hidden = hidden_states[0] * pool_mask[0].unsqueeze(-1)
second = masked_hidden.sum(dim=0) / pool_mask[0].sum(dim=0)
second = second.tolist()

print(np.mean(first))
print(np.mean(second))
# The two vectors come from different operation orders, so compare with a
# tolerance instead of exact float equality.
print(np.allclose(first, second, atol=1e-5))


