In [1]:
import os

import pandas as pd
import torch
from datasets import load_from_disk
from transformers import T5Tokenizer

from src.model.utils.data_collator import DataCollatorForT5Pssm

# Configure CUDA device selection for this kernel: enumerate devices in PCI bus
# order and expose only the device with index 1.
# NOTE(review): these variables are set after `import torch`; they are only
# honoured if CUDA has not been initialised yet — confirm, or move this above
# the imports.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "1",
    }
)

In [None]:
# Load the dataset saved on disk, expose the `pssm_features` column under the
# name `labels`, and drop the columns that are not needed downstream.
dataset = (
    load_from_disk("../tmp/data/pssm/pssm_dataset_0_only/")
    .rename_column("pssm_features", "labels")
    .remove_columns(["name", "sequence", "sequence_processed"])
)
print(dataset)

In [3]:
# Tokenizer loaded from the "Rostlab/prot_t5_xl_uniref50" checkpoint.
# NOTE(review): `use_fast` is normally an `AutoTokenizer` switch and
# `T5Tokenizer` is the slow implementation — confirm it has any effect here.
tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_uniref50",
    do_lower_case=False,
    use_fast=True,
    legacy=False,
)

# Collator that turns a list of dataset rows into one batch; padding is enabled
# and the padded length is requested to be a multiple of 8.
collator_options = {"padding": True, "pad_to_multiple_of": 8}
data_collator = DataCollatorForT5Pssm(tokenizer=tokenizer, **collator_options)

In [4]:
# Collate dataset rows 100..139 (40 examples) into a single batch.
batch = data_collator([dataset[row_idx] for row_idx in range(100, 140)])

In [5]:
# Let pandas render up to 256 rows and 256 columns so the per-token tables
# below are not truncated.
for display_option in ("display.max_rows", "display.max_columns"):
    pd.set_option(display_option, 256)
# Disabled: dump the full attention mask of the batch.
# pd.DataFrame(batch["attention_mask"])

In [None]:
# Inspect every 5th row of the batch: first the attention mask, then the decoded
# tokens laid out one token per column.
row_slice = slice(0, 100, 5)

display(pd.DataFrame(batch["attention_mask"][row_slice].tolist()))

decoded_rows = tokenizer.batch_decode(batch["input_ids"][row_slice].tolist())
# Put a space in front of every "<" so that special tokens end up in their own
# cell when splitting on spaces.
token_table = [text.replace("<", " <").split(" ") for text in decoded_rows]
display(pd.DataFrame(token_table))


In [None]:
# Experiment: remove the last attended position of every row from the mask.
attention_mask = batch["attention_mask"][0:100:5]

# Prefer the GPU, but fall back to the CPU so the cell also runs on machines
# without CUDA instead of raising.
target_device = "cuda" if torch.cuda.is_available() else "cpu"

print(attention_mask.device)
attention_mask = attention_mask.to(target_device)
print(attention_mask.device)

# Mask before the modification (columns from 70 onwards).
display(pd.DataFrame(attention_mask.tolist()).iloc[:, 70:])

# Work on a copy: the slice above is a view and `.to()` returns the same tensor
# when no device move is needed, so the in-place write below could otherwise
# modify `batch["attention_mask"]`.
attention_mask = attention_mask.clone()  #!

# Per row: (number of ones) - 1, used below as a column index.
# Assumes the mask is 0/1 with the ones forming a contiguous prefix (right
# padding), so this is the index of the last attended position — TODO confirm.
seq_lengths = attention_mask.sum(dim=1) - 1  #!

print("seq_lengths:", *seq_lengths.tolist())

# Row indices, created on the same device as the mask for the paired indexing.
batch_indices = torch.arange(attention_mask.size(0), device=attention_mask.device)  #!
print("batch_indices:", *batch_indices.tolist())

# Zero exactly one entry per row: (row i, column seq_lengths[i]).
attention_mask[batch_indices, seq_lengths] = 0

# Mask after the modification (same column window as above).
display(pd.DataFrame(attention_mask.tolist()).iloc[:, 70:])


In [None]:
# Random stand-in embeddings with one 1024-dim vector per mask position.
# The batch size is taken from the mask instead of being hard-coded, so the
# broadcast below stays valid if the number of selected rows changes.
random_embeddings = torch.randn(
    attention_mask.size(0),
    attention_mask.size(1),
    1024,
    device=attention_mask.device,
)

# Zero the embeddings at masked positions. `unsqueeze(-1)` adds a trailing
# axis of size 1 so the mask broadcasts over the embedding dimension; this is
# equivalent to the explicit form
#   random_embeddings * attention_mask[:, :, None].expand_as(random_embeddings)
masked_embeddings = random_embeddings * attention_mask.unsqueeze(-1)

# Show the last sequence of the batch, positions 70 onwards.
display(pd.DataFrame(masked_embeddings.cpu()[-1]).iloc[70:])


In [None]:
# Additive attention bias: 0.0 where a position may be attended to, -inf where
# it is masked out.
# `masked_fill` requires a boolean mask, and `~` on an integer mask is a bitwise
# NOT rather than a logical one, so the mask is converted to bool first. The
# fill target is a float tensor because -inf cannot be stored in an integer
# tensor.
# NOTE(review): assumes the intent is a bias of shape
# (batch, 1, seq_len) to be broadcast over query positions — TODO confirm.
masked_positions = ~attention_mask.bool()[:, None, :]
attention_bias = torch.zeros_like(masked_positions, dtype=torch.float32).masked_fill(
    masked_positions, float("-inf")
)
attention_bias

In [None]:
# Show all rows, but cap the number of rendered columns at 32.
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", 32)

# Random stand-in embeddings with one 1024-dim vector per mask position.
# The batch size is taken from the mask instead of being hard-coded, so the
# broadcast below stays valid if the number of selected rows changes.
random_embeddings = torch.randn(
    attention_mask.size(0),
    attention_mask.size(1),
    1024,
    device=attention_mask.device,
)
print(f"Random embeddings shape: {random_embeddings.shape}")

# First sequence before masking.
display(pd.DataFrame(random_embeddings.cpu()[0]))

# The trailing `None` axis lets the mask broadcast over the embedding dimension,
# zeroing every vector whose mask entry is 0.
masked_embeddings = random_embeddings * attention_mask[:, :, None]
print(f"Masked embeddings shape: {masked_embeddings.shape}")
# Same sequence after masking.
display(pd.DataFrame(masked_embeddings.cpu()[0]))
