In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import os
import re

import torch
from transformers import AutoTokenizer, T5Tokenizer, AutoModelForCausalLM, T5EncoderModel

from src._shared import load_config
from src.model.modeling_protein_clip import smart_mean_pooling, attention_mask_to_trim_indices

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

Matplotlib created a temporary cache directory at /tmp/matplotlib-j8s_jbk8 because the default path (/home/lfi/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [3]:
train_config = load_config()

tokenizer_llm = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=train_config["model"]["text_encoder_name"],
        )
tokenizer_plm = T5Tokenizer.from_pretrained(
            pretrained_model_name_or_path=train_config["model"]["protein_encoder_name"],
            do_lower_case=False,
            use_fast=True,
            legacy=False,
        )

dummy_texts = ["This is a test protein sequence text", "This is a different protein test sequence"]
dummy_proteins = [
    "MLKFVVVLAAVLSLYAYAPAFEVHNKKNVLMQRVGETLRISDRYLYQTLSKPYKVTLKTLDGHEIFEVVGEAPVTFRFKDKERPVVVASPEHVVGIVAVHNGKIYARNLYIQNISIVSAGGQHSYSGLSWRYNQPNDGKVTDYF",
    "MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNTNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRAALINMVFQMGE",
]
dummy_proteins = [" ".join(list(re.sub(r"[UZOB]", "X", x))) for x in dummy_proteins]

text_tokens = tokenizer_llm(dummy_texts, return_tensors="pt", padding=True, truncation=False)
protein_tokens = tokenizer_plm(dummy_proteins, return_tensors="pt", padding=True, truncation=False)

In [4]:
print(text_tokens)
print(text_tokens.keys())
print(text_tokens["input_ids"][0])
print(text_tokens["attention_mask"][0])
print(tokenizer_llm.decode(text_tokens["input_ids"][0]))


{'input_ids': tensor([[  910,   338,   263,  1243, 26823,  5665,  1426],
        [  910,   338,   263,  1422, 26823,  1243,  5665]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1]])}
dict_keys(['input_ids', 'attention_mask'])
tensor([  910,   338,   263,  1243, 26823,  5665,  1426])
tensor([1, 1, 1, 1, 1, 1, 1])
This is a test protein sequence text


In [5]:
# print(protein_tokens)
# print(protein_tokens.keys())
print(*protein_tokens["input_ids"][0].tolist())
print(*protein_tokens["attention_mask"][0].tolist())
print(tokenizer_plm.decode(protein_tokens["input_ids"][0]))
print(*protein_tokens["input_ids"][1].tolist())
print(*protein_tokens["attention_mask"][1].tolist())
print(tokenizer_plm.decode(protein_tokens["input_ids"][1]))


19 4 14 15 6 6 6 4 3 3 6 4 7 4 18 3 18 3 13 3 15 9 6 20 17 14 14 17 6 4 19 16 8 6 5 9 11 4 8 12 7 10 8 18 4 18 16 11 4 7 14 13 18 14 6 11 4 14 11 4 10 5 20 9 12 15 9 6 6 5 9 3 13 6 11 15 8 15 14 10 14 9 8 13 6 6 6 3 7 13 9 20 6 6 5 12 6 3 6 20 17 5 14 12 18 3 8 17 4 18 12 16 17 12 7 12 6 7 3 5 5 16 20 7 18 7 5 4 7 21 8 18 17 16 13 17 10 5 14 6 11 10 18 15 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
M L K F V V V L A A V L S L Y A Y A P A F E V H N K K N V L M Q R V G E T L R I S D R Y L Y Q T L S K P Y K V T L K T L D G H E I F E V V G E A P V T F R F K D K E R P V V V A S P E H V V G I V A V H N G K I Y A R N L Y I Q N I S I V S A G G Q H S Y S G L S W R Y N Q P N D G K V T D Y F</s>
19 17 12 15 9 19 4 8 12 10 9 5 4 8 4 14 12 18 14 10 11 9 5

In [6]:
print(protein_tokens["input_ids"].shape)
print(protein_tokens["attention_mask"].shape)

torch.Size([2, 145])
torch.Size([2, 145])


In [7]:
text_tokens = {k: v.to('cuda') for k, v in text_tokens.items()}
protein_tokens = {k: v.to('cuda') for k, v in protein_tokens.items()}


In [8]:
model_plm, loading_info_plm = T5EncoderModel.from_pretrained(
    pretrained_model_name_or_path='Rostlab/prot_t5_xl_uniref50',
    device_map='cuda:0',
    output_loading_info=True,
    torch_dtype="auto",
    trust_remote_code=True,
)
model_plm.to('cuda')

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


T5EncoderModel(
  (shared): Embedding(128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=4096, bias=False)
              (k): Linear(in_features=1024, out_features=4096, bias=False)
              (v): Linear(in_features=1024, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 32)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=16384, bias=False)
              (wo): Linear(in_features=16384, out_features=1024, bias=False)
              (dropout): Dropo

In [9]:
model_plm.eval()
with torch.no_grad():
    model_output = model_plm(protein_tokens["input_ids"], protein_tokens["attention_mask"])

In [10]:
print(model_output['last_hidden_state'].shape)
print(model_output['last_hidden_state'])
print()
print(model_output['last_hidden_state'].mean(dim=1).shape)
print(model_output['last_hidden_state'].mean(dim=1))
print()
print(model_output['last_hidden_state'].mean(dim=1).mean(dim=-1).shape)
print(model_output['last_hidden_state'].mean(dim=1).mean(dim=-1))


torch.Size([2, 145, 1024])
tensor([[[ 0.2427, -0.2669,  0.2241,  ...,  0.4926, -0.0906,  0.0016],
         [-0.0530, -0.1876,  0.1500,  ...,  0.2387, -0.0264, -0.0929],
         [ 0.0649, -0.0347,  0.3434,  ...,  0.1018, -0.0669, -0.1747],
         ...,
         [ 0.0078,  0.1548, -0.0381,  ..., -0.0581, -0.0658, -0.3015],
         [-0.2471, -0.0544,  0.0465,  ...,  0.1646,  0.0707, -0.3671],
         [ 0.0017, -0.0949,  0.0099,  ..., -0.0407, -0.0417,  0.0333]],

        [[-0.0464, -0.3680,  0.2976,  ..., -0.0458, -0.0230, -0.0038],
         [-0.2348, -0.1525,  0.1276,  ...,  0.0795,  0.1079, -0.0907],
         [ 0.0544,  0.1019,  0.2195,  ...,  0.1389, -0.0771, -0.1166],
         ...,
         [-0.0218, -0.2385,  0.1239,  ..., -0.1804, -0.2499,  0.0594],
         [-0.0318, -0.2399,  0.1235,  ..., -0.1831, -0.2554,  0.0663],
         [ 0.0308, -0.2422,  0.0607,  ..., -0.1697, -0.2450, -0.0066]]],
       device='cuda:0')

torch.Size([2, 1024])
tensor([[ 0.0150,  0.0740,  0.0449,  ..., 

In [11]:
pooled_output = smart_mean_pooling(model_output['last_hidden_state'], protein_tokens["attention_mask"])
print(pooled_output.shape)
print(pooled_output)
print(pooled_output.mean(dim=-1))

torch.Size([2, 1024])
tensor([[ 0.0150,  0.0740,  0.0449,  ...,  0.0753,  0.0156, -0.0082],
        [ 0.0223, -0.0566,  0.0859,  ..., -0.0246, -0.0437, -0.0225]],
       device='cuda:0')
tensor([9.1497e-06, 1.4323e-03], device='cuda:0')


In [12]:
pooled_output = smart_mean_pooling(model_output['last_hidden_state'], attention_mask_to_trim_indices(protein_tokens["attention_mask"], trim_end=1))
print(pooled_output.shape)
print(pooled_output)
print(pooled_output.mean(dim=-1))

torch.Size([2, 1024])
tensor([[ 0.0151,  0.0752,  0.0451,  ...,  0.0761,  0.0160, -0.0085],
        [ 0.0227, -0.0564,  0.0874,  ..., -0.0242, -0.0443, -0.0229]],
       device='cuda:0')
tensor([4.9078e-06, 1.4458e-03], device='cuda:0')


In [13]:
print(attention_mask_to_trim_indices(torch.tensor([[0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0], [1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0]])))
print()
print(attention_mask_to_trim_indices(torch.tensor([[0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0], [1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0]]), trim_beginning=2, trim_end=2))


tensor([[0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0],
        [1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0]])

tensor([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0]])
