In [1]:
# (Re)load IPython's autoreload extension and select mode 2, so edits to
# imported modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

import os

# GPU selection for this kernel: enumerate devices by PCI bus id and expose
# only device "3". Set ahead of the `import torch` below, presumably so CUDA
# reads these values when it initializes -- keep this ordering.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "3",
    }
)

import torch
import pandas as pd

from plms.configurations.configuration_base_plm import PLMConfig
from plms.models import auto_model, auto_tokenizer
from plms.models.T5.modeling_protT5 import ProtT5
from plms.models.T5.tokenization_protT5 import ProtT5Tokenizer

In [2]:
# Checkpoint name used for both the tokenizer and the model.
# Alternative checkpoint:
# model_name = "Rostlab/prot_t5_xl_uniref50"
model_name = "Rostlab/ProstT5"

# Build the tokenizer through the auto_* factory.
tokenizer = auto_tokenizer(model_name)
# Or, more explicitly, instantiate the tokenizer class directly:
# tokenizer = ProtT5Tokenizer(name_or_path="Rostlab/prot_t5_xl_uniref50")


# Build the model through the auto_* factory.
model = auto_model(model_name=model_name)
# Or, more explicitly, instantiate the model class directly:
# model = ProtT5(name_or_path="Rostlab/prot_t5_xl_uniref50")
# Or from a config object:
# config=PLMConfig(name_or_path="Rostlab/prot_t5_xl_uniref50")
# model = ProtT5(config=config)

In [37]:
# Multi-record FASTA sample. In this cell it is only referenced by the
# commented-out `tokenize_fasta` alternative below.
sequence_fasta = """>seq1
ACDEFGHIKLMNPQRSTVWY
ACDEFGHIKLMNPQRSTVWY
ACDEFGHIKLMNPQRSTVWY
ACDEFGHIKLMNPQRSTVWY
>seq2
ACDEFGHIKLMNPQRSTVWY
>seq3
acdefghiklmnpqrstvwy
acdefghiklmnpqrstvwy
"""

# Batch of raw sequences with different lengths and letter case.
sequence_strings = [
    "ACDEFGHIKLMNPQRSTVWYO",
    "ACDEFGHIKLMNPQRSTVWYOACDEFGHIKLMNPQRSTVWYO",
    "acdefghiklmnpqrstvwyo",
]

# Alternative entry points (from a FASTA file on disk, or from the FASTA
# string defined above):
# tokenizer_output = tokenizer.tokenize_fasta(
#     "../../prot-md-pssm-benchmark/scope-benchmark-minimal/data/scope40_sequences_aa_short.fasta",
#     padding=True,
# )
# tokenizer_output = tokenizer.tokenize_fasta(fasta=sequence_fasta, return_headers=True)
tokenizer_output = tokenizer.encode(sequence_strings, padding=True)

# For every sequence, print its token ids and, directly underneath, its
# attention mask, each value left-aligned in a 4-character column.
id_rows = tokenizer_output["input_ids"]
mask_rows = tokenizer_output["attention_mask"]
for token_row, mask_row in zip(id_rows, mask_rows):
    for row in (token_row, mask_row):
        print("".join(f"{value:<4d}" for value in row))

149 3   22  10  9   15  5   20  12  14  4   19  17  13  16  8   7   11  6   21  18  23  1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   
1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   
149 3   22  10  9   15  5   20  12  14  4   19  17  13  16  8   7   11  6   21  18  23  3   22  10  9   15  5   20  12  14  4   19  17  13  16  8   7   11  6   21  18  23  1   
1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   
148 128 147 135 134 140 130 145 137 139 129 144 142 138 141 133 132 136 131 146 143 1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   
1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   0   0   0   0   0   0   0  

In [38]:
# Round-trip check: decode the token ids back to tokens and print ids, mask
# and decoded tokens as aligned rows, one group per sequence.
input_ids = tokenizer_output["input_ids"]
attention_mask = tokenizer_output["attention_mask"]
decoded_tokens = tokenizer.decode(input_ids)

for id_row, mask_row, token_row in zip(input_ids, attention_mask, decoded_tokens):
    for row in (id_row, mask_row, token_row):
        # Truncate every entry to 4 characters and pad to at least 3, so the
        # single-space separator yields aligned columns.
        cells = [str(item)[:4].ljust(3) for item in row]
        print(" ".join(cells))
    # Blank line between sequences.
    print()


149 3   22  10  9   15  5   20  12  14  4   19  17  13  16  8   7   11  6   21  18  23  1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  
1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  
A   C   D   E   F   G   H   I   K   L   M   N   P   Q   R   S   T   V   W   Y   X  

149 3   22  10  9   15  5   20  12  14  4   19  17  13  16  8   7   11  6   21  18  23  3   22  10  9   15  5   20  12  14  4   19  17  13  16  8   7   11  6   21  18  23  1  
1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1  
A   C   D   E   F   G   H   I   K   L   M   N   P   Q   R   S   T   V   W   Y   X   A   C   D   E   F   G   H   I   K   L   M   N   P   Q   R   S   T   V   W   Y   X  

148 128 147 135 134 140 130 145 137 139 12

In [70]:
# Switch the model to evaluation mode and run one forward pass over the
# tokenized batch without gradient tracking.
model.eval()

with torch.no_grad():
    # Convert both tokenizer outputs to tensors on the model's device.
    model_inputs = {
        field: torch.tensor(tokenizer_output[field]).to(model.device)
        for field in ("input_ids", "attention_mask")
    }
    embeddings = model(**model_inputs)

149 3   22  10  9   15  5   20  12  14  4   19  17  13  16  8   7   11  6   21  18  23  1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   
1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   
149 3   22  10  9   15  5   20  12  14  4   19  17  13  16  8   7   11  6   21  18  23  3   22  10  9   15  5   20  12  14  4   19  17  13  16  8   7   11  6   21  18  23  1   
1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   
148 128 147 135 134 140 130 145 137 139 129 144 142 138 141 133 132 136 131 146 143 1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   
1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   0   0   0   0   0   0   0  

In [17]:
# Show which fields the model call returned.
embeddings.keys()

odict_keys(['last_hidden_state', 'mask'])

In [63]:
# Compare one raw input sequence with the mask the model returned for it.
index = 2

print(sequence_strings[index])

# Print the mask entries back-to-back on a single line.
mask_values = embeddings["mask"][index].tolist()
print("".join(map(str, mask_values)))

acdefghiklmnpqrstvwyo
1111111111111111111100000000000000000000000


In [None]:
# Attention mask of the first sequence in the tokenized batch.
print(tokenizer_output["attention_mask"][0])

In [None]:
# Attention masks of every sequence in the tokenized batch.
print(tokenizer_output["attention_mask"])

In [None]:
# Tabulate the per-position embeddings of one sequence and label every row
# with the residue it is meant to correspond to.
index = 0

row_embeddings = embeddings["last_hidden_state"][index]
residues = sequence_strings[index]

# One "-" placeholder in front of the residues, then "-" for every remaining
# position so the label column has exactly as many entries as there are rows.
# NOTE(review): this assumes the first embedding row is a non-residue position
# -- confirm against the layout of the model output.
trailing_pad = row_embeddings.shape[0] - len(residues) - 1
position_labels = list("-" + residues + trailing_pad * "-")

df = pd.DataFrame(row_embeddings.cpu().numpy()).assign(sequence=position_labels)

In [None]:
from plms.utils import modeling_utils

# Sanity check for `modeling_utils.mean_pool`: pool every sequence on its own,
# pool the whole batch at once, then compare the first batch row with a
# manually computed masked mean.

hidden_states = embeddings["last_hidden_state"]
# The model output exposes the key "mask" (see the `embeddings.keys()` output
# earlier in the notebook); "masks" does not exist and raised a KeyError.
pool_mask = embeddings["mask"]

# Per-sequence pooling, printed as one scalar summary per sequence.
for x in range(len(hidden_states)):
    print(
        modeling_utils.mean_pool(
            hidden_states[x].unsqueeze(0),
            pool_mask[x].unsqueeze(0),
        ).mean()
    )

# Batched pooling over all sequences at once.
df_mean = modeling_utils.mean_pool(hidden_states, pool_mask)

# Manual reference for sequence 0: zero out masked positions before summing,
# then divide by the number of unmasked positions. Previously the sum also
# included padded positions while the divisor only counted unmasked ones.
# NOTE(review): assumes mean_pool is a mask-weighted mean over the sequence
# axis -- confirm against plms.utils.modeling_utils.
weights = pool_mask[0].unsqueeze(-1).to(hidden_states.dtype)
manual_mean = (hidden_states[0] * weights).sum(dim=0) / weights.sum()

# Kept as plain lists of floats for ad-hoc inspection.
first = df_mean[0].tolist()
second = manual_mean.tolist()

# Compare with a tolerance on CPU float tensors; exact `==` between lists of
# floats/tensors is brittle and device-dependent.
print(torch.allclose(df_mean[0].float().cpu(), manual_mean.float().cpu(), atol=1e-5))