In [1]:
import os

# Select the GPU *before* torch (or any package that imports torch) is loaded.
# The CUDA runtime reads these variables once, when it is first initialised;
# setting them after something has already probed CUDA is silently ignored.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # number GPUs by PCI bus id
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only physical GPU 1 to this kernel

import pandas as pd
import torch
from datasets import load_from_disk
from transformers import T5Tokenizer

from src.model.utils.data_collator import DataCollatorForT5Pssm

In [2]:
# Load the dataset saved on disk, expose the "pssm_features" column under the
# name "labels", and drop the columns that are not needed further down.
dataset = (
    load_from_disk("../tmp/data/pssm/pssm_dataset_0_only/")
    .rename_column("pssm_features", "labels")
    .remove_columns(["name", "sequence", "sequence_processed"])
)
print(dataset)

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 26990
})


In [3]:
# Tokenizer of the Rostlab/prot_t5_xl_uniref50 checkpoint.
# NOTE(review): `use_fast` is normally an AutoTokenizer switch; confirm whether
# it has any effect when passed to the T5Tokenizer class directly.
tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_uniref50", do_lower_case=False, use_fast=True, legacy=False
)

# Project collator used to turn a list of dataset rows into one batch.
# Presumably pads every field to a common length that is a multiple of 8
# (going by the argument names) - verify against DataCollatorForT5Pssm.
data_collator = DataCollatorForT5Pssm(tokenizer=tokenizer, padding=True, pad_to_multiple_of=8)

In [4]:
# Collate dataset rows 100..139 (40 examples) into a single batch.
batch = data_collator([dataset[row] for row in range(100, 140)])

In [5]:
# Allow pandas to render up to 256 rows and 256 columns before truncating.
pd.options.display.max_rows = 256
pd.options.display.max_columns = 256

In [6]:
# Preview every 5th example of the batch. A stop bound larger than the batch
# is simply clipped by slicing.
preview_rows = slice(0, 100, 5)

# Attention mask of the previewed examples, one row per example.
display(pd.DataFrame(batch["attention_mask"][preview_rows].tolist()))

# Decoded input ids of the same examples. A space is inserted in front of every
# "<" so that special tokens (e.g. </s>, <pad>) split into their own cells.
decoded_rows = tokenizer.batch_decode(batch["input_ids"][preview_rows].tolist())
display(pd.DataFrame([text.replace("<", " <").split(" ") for text in decoded_rows]))


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207
0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
6,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0
7,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207
0,G,E,A,V,I,K,V,I,S,S,A,C,K,T,Y,C,G,K,T,S,P,S,K,K,E,I,G,A,M,L,S,L,L,Q,K,E,G,L,L,M,S,P,S,D,L,Y,S,P,G,S,W,D,P,I,T,A,A,L,S,Q,R,A,M,I,L,G,K,S,G,E,L,K,T,W,G,L,V,L,G,A,L,K,A,A,R,E,E,</s>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>
1,E,V,Q,L,Q,Q,S,G,P,D,L,V,K,P,G,A,S,V,K,I,S,C,K,A,S,G,Y,S,F,S,T,Y,Y,M,X,W,V,K,Q,S,X,G,K,S,L,E,W,I,G,R,V,D,P,D,N,G,G,T,S,F,N,Q,K,F,K,G,K,A,I,L,T,V,D,K,S,S,S,T,A,Y,M,E,L,G,S,L,T,S,E,D,S,A,V,Y,Y,C,A,R,R,D,D,Y,Y,F,D,F,W,G,Q,G,T,S,L,T,V,S,S,</s>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>
2,A,G,I,L,A,D,A,D,C,A,A,A,V,K,A,C,E,A,A,D,S,F,S,Y,K,A,F,F,A,K,C,G,L,S,G,K,S,A,D,D,I,K,K,A,F,V,F,I,D,Q,D,K,S,G,F,I,E,E,D,E,L,K,L,F,L,Q,V,F,K,A,G,A,R,A,L,T,D,A,E,T,K,A,F,L,K,A,G,D,S,D,G,D,G,A,I,G,V,E,E,W,V,A,L,V,K,A,</s>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>
3,A,T,T,P,I,I,X,L,K,G,D,A,N,I,L,K,C,L,R,Y,R,L,S,K,Y,K,Q,L,Y,E,Q,V,S,S,T,W,X,W,T,C,T,D,G,K,X,K,N,A,I,V,T,L,T,Y,I,S,T,S,Q,R,D,D,F,L,N,T,V,V,I,P,N,T,V,S,V,S,T,G,Y,M,T,I,</s>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>
4,S,P,L,P,I,T,P,V,N,A,T,C,A,I,R,X,P,C,X,G,N,L,M,N,Q,I,K,N,Q,L,A,Q,L,N,G,S,A,N,A,L,F,I,S,Y,Y,T,A,Q,G,E,P,F,P,N,N,L,D,K,L,C,G,P,N,V,T,D,F,P,P,F,X,A,N,G,T,E,K,A,K,L,V,E,L,Y,R,M,V,A,Y,L,S,A,S,L,T,N,I,T,R,D,Q,K,V,L,N,P,S,A,V,S,L,X,S,K,L,N,A,T,I,D,V,M,R,G,L,L,S,N,V,L,C,R,L,C,N,K,Y,R,V,G,X,V,D,V,P,P,V,P,D,X,S,D,K,E,V,F,Q,K,K,K,L,G,C,Q,L,L,G,T,Y,K,Q,V,I,S,V,V,V,Q,A,F,</s>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>
5,S,A,K,V,G,E,I,T,I,T,P,D,N,S,K,P,G,R,Y,I,S,S,N,P,E,Y,S,L,L,A,K,L,I,D,A,E,S,I,K,G,T,E,V,Y,T,F,X,T,R,K,G,Q,Y,V,K,V,T,V,P,D,S,N,I,D,K,M,R,V,D,Y,V,N,W,K,G,P,K,Y,N,N,K,L,V,K,R,F,V,S,Q,F,L,L,F,R,K,E,E,</s>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>
6,K,E,K,N,E,K,E,A,L,L,K,A,S,E,L,V,S,G,M,G,D,K,L,G,E,Y,L,G,V,K,Y,K,N,V,A,K,E,V,A,N,D,I,K,N,F,X,G,R,N,I,R,S,Y,N,E,A,M,A,S,L,N,K,V,L,A,N,P,K,M,K,V,N,K,S,D,K,D,A,I,V,N,A,W,K,Q,V,N,A,K,D,M,A,N,K,I,G,N,L,G,K,A,F,K,V,A,D,L,A,I,K,V,E,K,I,R,E,K,S,I,E,G,Y,N,T,G,N,W,G,P,L,L,L,E,V,E,S,W,I,I,G,G,V,V,A,G,V,A,I,S,L,F,G,A,V,L,S,F,L,P,I,S,G,L,A,V,T,A,L,G,V,I,G,I,M,T,I,S,Y,L,S,S,F,I,D,A,N,R,V,S,N,I,N,N,I,I,S,S,V,I,R,</s>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>
7,M,E,N,L,N,M,D,L,L,Y,M,A,A,A,V,M,M,G,L,A,A,I,G,A,A,I,G,I,G,I,L,G,G,K,F,L,E,G,A,A,R,Q,P,D,L,I,P,L,L,R,T,Q,F,F,I,V,M,G,L,V,D,A,I,P,M,I,A,V,G,L,G,L,Y,V,M,F,A,V,A,</s>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>


In [7]:
# Take every 5th example of the batch (a strided view into the batch tensor).
attention_mask = batch["attention_mask"][0:100:5]

# Use the GPU when one is available, otherwise stay on the CPU so the cell
# still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(attention_mask.device)
attention_mask = attention_mask.to(device)
print(attention_mask.device)

# Mask before the modification, from column 70 onwards.
display(pd.DataFrame(attention_mask.tolist()).iloc[:, 70:])

# Copy before the in-place write below. When no device transfer happens,
# `.to()` hands back the same view, and batch["attention_mask"] must not change.
attention_mask = attention_mask.clone()  #!

# Number of 1s per row minus one, i.e. the index of the last attended position
# (assumes the 1s are contiguous and left-aligned - presumably the </s> token).
seq_lengths = attention_mask.sum(dim=1) - 1  #!

print("seq_lengths:", *seq_lengths.tolist())

# Row indices 0..batch-1, created on the same device as the mask.
batch_indices = torch.arange(attention_mask.size(0), device=attention_mask.device)  #!
print("batch_indices:", *batch_indices.tolist())

# Clear exactly one entry per row: (row i, seq_lengths[i]).
attention_mask[batch_indices, seq_lengths] = 0

# Mask after the modification: each row has one fewer trailing 1.
display(pd.DataFrame(attention_mask.tolist()).iloc[:, 70:])


cpu
cuda:0


Unnamed: 0,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207
0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
6,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0
7,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


seq_lengths: 87 117 106 82 180 97 200 79
batch_indices: 0 1 2 3 4 5 6 7


Unnamed: 0,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207
0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
6,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0
7,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [8]:
# Random stand-in embeddings: one 1024-dim vector per mask position. The batch
# size is read from the mask instead of being hard-coded, so the cell keeps
# working when a different number of examples is previewed.
# NOTE(review): no random seed is set, so the displayed values differ per run.
random_embeddings = torch.randn(
    attention_mask.size(0),
    attention_mask.size(1),
    1024,
    device=attention_mask.device,
)

# unsqueeze(-1) gives the mask a trailing axis of size 1 so it broadcasts over
# the embedding axis; positions whose mask is 0 become all-zero vectors.
masked_embeddings = random_embeddings * attention_mask.unsqueeze(-1)

# Show the last example from position 70 onwards. Only that one example is
# copied back to the CPU.
display(pd.DataFrame(masked_embeddings[-1].cpu()).iloc[70:])


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,...,896,897,898,899,900,901,902,903,904,905,906,907,908,909,910,911,912,913,914,915,916,917,918,919,920,921,922,923,924,925,926,927,928,929,930,931,932,933,934,935,936,937,938,939,940,941,942,943,944,945,946,947,948,949,950,951,952,953,954,955,956,957,958,959,960,961,962,963,964,965,966,967,968,969,970,971,972,973,974,975,976,977,978,979,980,981,982,983,984,985,986,987,988,989,990,991,992,993,994,995,996,997,998,999,1000,1001,1002,1003,1004,1005,1006,1007,1008,1009,1010,1011,1012,1013,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
70,0.561359,1.293465,-1.714545,-1.106702,-0.952702,-1.686057,0.805851,0.679847,1.117028,0.994236,-0.810386,0.273791,-1.069556,0.053401,-0.405108,0.664275,1.60865,-0.276963,-0.845644,0.447291,-2.367003,0.629626,-0.501591,-0.356254,1.473984,1.236181,-1.172098,1.30965,1.23018,0.198716,0.407126,-0.931169,-0.270573,-1.834012,-0.017454,1.240105,0.669067,-0.037401,1.807875,-0.966988,-0.584711,-1.621653,1.002524,0.767202,-0.65891,0.06214,-0.485499,2.552914,-0.084167,1.452801,1.736065,1.943252,0.961067,0.234755,0.974091,2.352182,-1.13108,0.6108,0.047064,-0.714467,1.206639,1.027667,-0.933463,-1.27872,-0.03347,0.084827,0.841631,-0.209145,0.681074,-1.985563,1.677782,-0.960931,-0.688954,-0.567404,0.460509,-1.901227,-0.125547,-0.239597,0.348871,1.108995,2.15551,0.676828,-1.498259,-0.182052,0.05958,-0.530113,0.267465,0.644697,-0.233345,1.289753,3.212892,-1.340336,2.539402,-1.766818,-0.640133,0.348606,-0.556638,-0.465316,-0.19896,-1.979727,0.312184,1.538851,1.018783,0.578636,-2.200294,0.52537,-0.399511,0.356617,-0.025037,-0.138553,-0.596599,1.687275,0.026329,0.164793,1.457117,-0.729422,-1.079206,-0.70944,0.624423,0.2284,-1.240625,-0.666977,0.643524,0.028333,0.260709,2.427839,0.433972,-0.184081,...,-0.243392,1.043083,0.455985,-1.483434,0.154047,0.957325,0.587665,-0.621019,1.810383,1.140293,-0.460457,0.248811,-2.21285,1.451595,0.629415,-1.249763,0.253318,2.017636,0.824246,1.220112,0.512322,-0.254285,-1.077262,-0.233358,0.662609,-1.186234,-0.801927,0.165936,0.762556,0.073339,0.268235,-0.203201,0.880474,1.690391,0.732994,1.698448,0.356247,2.485138,1.55273,-0.196302,-0.099549,-1.441009,-1.083942,0.054598,1.52645,-0.736516,0.919423,-0.366782,-0.426495,-0.772909,0.103632,-1.822573,-0.158104,-1.309901,-0.994564,0.009831,-0.725236,0.891025,1.824725,-1.099034,0.804259,-0.57778,1.294978,-0.119113,-0.531341,1.159603,-0.370735,0.526743,-2.300165,1.303996,-1.705064,0.605347,-0.885711,0.687241,0.250935,-0.604001,0.862549,-0.110622,-0.345346,0.085912,-1.280274,-0.95257,0.044201,0.343178,0.647395,0
.404941,-0.330349,-1.992079,-0.755824,0.902442,-0.120175,0.684003,-1.198514,-0.009014,-0.414203,1.191985,0.921292,0.500681,1.193069,-0.783967,0.084561,-0.784663,-0.87149,-0.6692,-0.501796,0.43075,0.798521,0.967758,1.432566,-0.800976,1.077289,1.777586,0.842614,1.346673,0.359535,-1.143981,0.123423,0.039212,-0.394604,-0.242108,2.637919,1.931587,-0.397086,0.361226,1.105887,-0.170685,0.621389,0.829523
71,-2.81633,0.734815,0.86101,0.015977,0.725366,-1.182192,0.854787,-0.76322,0.772916,-1.573389,1.601826,2.600528,-1.247223,-0.184189,-0.796711,-0.667799,-1.110986,-1.443784,-0.177409,-0.85833,0.61275,-0.105407,2.341597,-1.920806,-0.15736,0.043879,0.945817,1.258875,1.116436,0.066099,1.316952,-1.969362,-0.537198,-0.543172,0.752654,0.153006,0.056174,-1.556831,-0.362195,-0.318256,-0.815123,-0.958631,1.414276,-0.431538,-1.253696,0.985871,1.762918,0.94228,-1.29617,0.195163,0.542875,-2.666186,-1.951096,0.984618,-1.348997,0.645963,-0.328268,0.110687,-0.111538,-0.293492,-0.626054,-0.314433,-0.782206,-1.106169,-0.771355,0.053457,1.034348,0.542292,-0.342624,0.218924,1.749038,0.142462,0.590502,0.860979,0.924996,-0.403876,-0.491195,2.108575,1.080398,0.670915,0.922182,1.576095,0.596736,0.125577,0.588867,-1.033126,-0.695192,1.276087,-0.557938,0.419669,0.693122,0.549178,0.160818,0.05631,0.288766,1.750629,1.57858,-0.265425,0.345992,1.40684,0.180496,0.202434,1.167845,-1.936759,1.088571,-0.311347,0.001472,-0.086934,1.167128,-0.193952,0.448894,-1.064202,2.340641,-1.629629,-0.009077,1.561366,0.190005,-0.416331,-0.354488,1.605923,-0.230835,0.580857,0.014023,-0.919451,-1.455112,1.308192,1.574155,-1.574024,...,-1.387092,0.674315,-0.562064,0.889245,0.850181,-1.379,1.360928,0.283924,-0.870093,1.574916,-0.099671,2.131582,-1.364982,-0.511302,0.772279,0.601008,-1.350574,1.291034,1.35059,-1.256811,0.797916,-1.024814,0.20941,0.236889,0.086637,-1.021841,-1.300445,1.429545,-0.144039,0.511366,-0.959736,-0.643153,1.311152,-0.683613,0.934408,0.780484,-0.012503,-1.122848,0.527482,0.759115,0.358753,-0.956899,1.55217,0.294023,-0.260269,1.504288,-1.654041,0.432047,-0.738953,0.801363,1.226628,0.312712,2.780505,0.502185,0.004409,-0.388169,0.072101,-1.536226,0.369633,0.510997,0.641181,-1.30367,0.01561,-0.388149,-0.743759,1.154593,-1.971677,-1.416281,0.059949,1.083285,-1.264993,0.248147,0.941392,0.177605,0.537322,1.517876,-1.318025,1.024754,2.40227,0.998589,-0.527477,0.640735,1.246558,-0.118628,0.488932,1.1464
41,-0.435496,-1.752233,0.807499,-1.19674,0.052504,0.12457,-0.152111,-0.885024,0.829948,0.050355,-1.210146,-0.337227,-0.317885,-0.924683,1.122421,-0.957578,0.710633,-0.176964,-0.266324,0.503849,-0.254303,-0.564192,1.562201,-0.318061,-0.195295,-0.888564,0.628833,0.800427,0.583291,-1.476126,-0.970438,1.116144,-0.065029,1.137745,1.378209,-0.815385,1.141792,0.68681,0.148614,0.421916,-1.724353,0.627813
72,0.83979,0.78756,-1.396913,-0.152296,-0.097517,-1.844065,0.192965,0.367647,1.024197,-1.091069,1.390817,1.007125,1.047367,-1.276119,-0.68959,-0.318462,0.182513,0.456492,-0.534451,-0.90926,1.229816,0.126511,1.88909,-0.41886,0.927714,0.635661,0.235202,-0.481564,1.399505,0.367647,1.018609,1.705888,0.294394,-0.810969,0.263462,1.450186,-0.078282,0.569885,-0.680416,1.147708,1.167934,-0.144541,0.964346,-1.333882,0.819839,-1.182935,0.281109,0.644556,-1.187449,0.083265,-1.165487,1.027943,0.035933,-0.461221,1.577343,0.300679,0.610102,-0.949055,-0.920145,1.443651,1.79853,1.047553,-0.893084,-1.07958,1.452439,-1.843726,0.876791,0.705752,-0.025173,2.154903,0.269979,0.080939,-0.028397,-0.421209,0.389367,-0.426266,-2.779283,-0.124244,-0.881131,-1.005868,-1.630245,-0.052144,0.237097,-1.91503,-0.978838,-1.199138,0.932195,-0.023805,0.872279,2.570677,0.667984,-0.046285,0.486567,1.05557,1.042498,0.190372,0.928762,0.891946,-1.907079,-0.130259,-0.011292,-1.361716,0.512903,-1.233113,-0.16496,0.34143,-0.246945,-0.165084,0.190173,-1.823566,-0.438425,-1.014094,0.492405,0.383641,-1.185966,-0.954377,-0.782984,-0.116461,0.578384,-1.156708,-0.37356,-0.038782,-0.104312,0.254054,0.116561,0.924578,0.76885,-0.223231,...,-0.180496,-0.904812,1.229555,0.285752,1.972863,0.544063,-1.69587,-0.283662,-0.821942,0.654928,-0.328649,-0.606372,2.213628,1.970129,0.779319,0.706382,-0.793325,0.164304,-0.427318,-0.802272,0.002459,0.021097,0.748457,-0.649512,0.304935,-1.467517,0.369157,-0.160861,-0.714013,1.253595,0.21119,-0.961057,-0.16297,-0.631857,-1.591741,1.686975,-0.501743,-0.314669,0.862126,1.184529,1.354703,0.63942,1.104429,-0.715303,-0.695258,-1.101178,-0.271542,-0.249675,0.039991,0.18958,-0.2723,-0.119877,-0.052555,-0.141244,-1.330991,-0.629941,1.0731,1.047196,0.152844,-0.834429,-0.060878,0.755271,1.616507,-0.833361,-0.759606,-0.796707,-1.836808,1.010042,2.168688,0.318382,-0.01042,-0.87996,-1.744009,-0.610175,-0.002357,-1.620495,-0.777494,-0.422524,-0.123179,-0.145546,0.433056,-0.588693,-0.180988,0.384092,
-0.842818,-0.817586,0.942736,0.065597,-1.533777,-0.070834,-2.281589,-0.099712,0.399622,-0.602428,-0.294217,-1.877374,1.798501,-1.08149,-0.93352,0.874894,-0.115862,-1.508759,0.247991,0.302388,-0.333662,1.361538,1.055434,-1.125378,-1.924777,-0.136276,1.42998,-0.596442,-1.360358,1.194259,1.130043,1.253627,1.333516,0.466034,0.550204,0.221757,1.194172,0.41892,0.354532,-0.458714,1.547453,1.021887,0.589062,-0.482921
73,-0.238578,0.528111,0.33736,0.974683,-2.301317,1.243873,-1.268411,-1.287285,-0.586741,0.727817,0.918803,-1.155384,-0.409249,0.758711,-1.097583,-1.965496,0.553247,1.453513,-0.969461,1.421965,1.316514,0.512356,2.014128,-0.274339,0.547147,-0.807841,-2.07959,-2.274706,-0.115228,-0.296401,2.214706,-1.348952,-0.883554,0.778489,0.156197,-1.102015,-0.024179,-0.752734,0.697479,-0.419197,-0.47634,-1.404852,-1.3768,-2.575574,-0.57565,0.201354,-2.048348,-0.2157,0.590082,0.58414,0.605862,0.093689,0.044609,-2.864722,-1.266668,-0.391163,-0.934608,0.809965,-1.307156,0.623646,-1.163956,-0.641223,-1.669206,0.342198,2.533758,-0.564346,-0.75498,0.679821,1.92089,0.092779,0.391887,0.189732,0.178197,1.005693,0.723065,0.586747,1.11693,0.369206,-0.919226,-0.080715,1.392007,0.815658,-0.224036,-1.501244,-1.979007,0.222096,0.031531,0.123869,-1.25043,-1.649102,-1.496541,0.517283,0.999946,0.14858,-0.536612,-0.541651,0.107746,-0.295936,-0.749183,0.898662,-0.843688,-0.979431,1.392767,-0.202007,-1.376245,0.556628,1.549659,-2.408323,1.960774,2.238667,-0.026818,0.541739,0.596506,-0.283034,1.2122,-0.353834,-1.775429,-1.666009,-0.046291,1.934512,1.228066,1.255244,-0.784885,1.276392,1.005271,0.211982,-1.889768,-0.357913,...,0.79943,0.666074,-1.308844,0.37094,0.948541,-0.197936,0.746076,-0.570582,1.270175,0.085272,0.215098,-0.084932,-0.897334,-2.00944,0.448866,-0.086473,1.621261,0.531491,-1.004356,1.320942,-0.643383,0.461525,0.48351,-0.61526,2.420102,-1.124014,-1.235853,0.397105,0.170048,-0.497943,-0.903997,1.23023,0.281276,1.928679,0.260533,0.684565,0.72131,-0.252024,-0.150601,0.027832,-0.185245,1.031849,1.256576,0.793411,0.345264,-0.512801,1.266636,0.250741,1.503787,0.618339,1.516746,-0.792366,-0.182002,0.78779,0.461149,0.107578,1.754319,-1.200391,1.578906,-0.427638,-1.43701,0.199423,1.103845,0.579822,0.837286,0.4676,-0.168553,0.283165,-0.084608,-0.127546,-0.708428,-0.167767,-0.653974,1.473835,0.401064,0.744186,0.307676,-0.432621,0.219519,-0.213134,-1.066209,-1.363991,-0.287433,-0.421248,-0.282183,1.
83917,-1.60605,-0.691611,0.969474,-1.022302,-0.779501,1.131147,0.522705,0.822811,1.949573,-0.057996,1.445631,-0.835687,0.53737,-0.745605,-0.002129,-0.50838,-0.397854,0.293663,-0.355134,0.105843,0.408919,0.295744,-0.551878,0.329224,-0.942266,-0.248307,-0.453691,0.708887,0.237799,1.927828,0.248766,1.875688,-0.930072,-1.304665,1.846984,0.207128,-1.100932,0.446787,0.566196,-1.468139,-0.067401,0.828445
74,-0.139449,0.993389,-0.620221,1.335137,-0.413601,0.641423,1.458499,1.669834,0.843501,-1.455607,0.599659,0.184745,0.379053,1.205154,-0.567561,1.381958,0.315912,-1.196859,0.470861,-1.563513,-0.010791,1.408725,0.657477,-0.172742,-0.874244,0.267687,0.574764,-0.76683,-0.721947,0.181524,-1.546653,-0.424096,-0.242946,-2.216582,1.570356,-0.237189,-0.296206,-1.154306,-0.455164,1.000355,-0.257187,-0.426971,1.722413,0.560312,0.652138,0.109462,1.420833,0.219691,1.04766,-1.368121,-0.134704,1.406061,-0.552784,-0.56712,-1.285125,-0.743856,-0.43465,0.461762,-0.618833,-1.638595,-1.572219,3.058269,-0.30726,-0.512365,1.310975,1.601153,-1.224257,0.873876,-0.671609,1.79065,1.09943,-1.872571,0.907982,1.496497,0.816035,1.529781,0.547445,-0.110838,-1.812107,1.186197,-0.519746,0.664645,-0.124757,0.823432,-1.224443,-0.970158,0.316065,-0.741254,-0.456264,0.309826,0.073437,1.519151,0.607654,0.322432,-1.160266,0.481979,-0.046263,1.207959,0.02433,-0.519629,-0.118729,-0.135106,-2.302926,0.204983,0.154014,-0.823027,0.561504,-1.731582,-1.473726,-0.240334,1.759773,-0.128184,-0.102177,1.189097,-0.150791,-0.737411,1.392241,-1.033603,-1.198136,1.315487,1.052588,-1.217151,1.340869,-1.325918,-0.736737,-0.310368,-2.166844,0.90508,...,-0.383995,-0.688127,-1.443533,-0.471824,0.315644,0.822342,-0.957535,-0.100022,1.532791,-0.690234,-1.049347,-0.115854,0.665152,0.451567,1.470575,-0.897241,-0.55277,-0.10513,-0.955789,-0.997662,-0.491212,-0.928895,-1.059088,-0.002882,0.279817,1.984156,-0.480957,-1.139908,-0.957367,0.048378,-0.299986,-1.380676,0.201796,0.660662,1.09687,0.018148,-0.284105,3.164117,-1.06815,-1.149733,-0.380987,0.037062,-1.113744,-0.45466,-0.82744,0.006925,0.551592,-0.650079,-0.987407,0.12365,0.861645,0.16793,0.184555,-0.817586,-1.534794,0.672573,0.140349,1.095309,-0.312646,0.050244,0.954095,-1.201286,0.275896,0.07848,-0.026723,0.186076,-0.27113,-1.969471,-1.178644,-0.404819,-1.686092,0.308457,1.741624,-0.920096,-1.230401,-2.709389,-0.16425,0.285212,0.27158,0.622479,-0.062145,-0.72198,-1.11038,-1
.28898,-0.000562,-0.324405,0.450893,-0.374199,0.707637,0.855419,0.060896,0.742792,0.108039,0.644225,-0.367709,0.798654,-0.635619,-0.200447,1.263038,-0.07741,-1.05016,-0.39,-0.483803,1.411899,0.774761,-0.265163,0.633471,-2.650584,1.0033,1.75731,-0.762234,0.251995,0.381454,-0.023087,-1.758436,1.784353,-0.319955,0.414028,-0.675014,1.130008,0.465075,-0.365562,-0.438778,0.20189,-0.259244,0.448234,0.765714,-0.898663
75,-0.401798,1.605834,1.025579,0.364479,-1.217507,0.556258,-1.045106,1.830024,1.298572,0.896757,-0.122129,0.659084,-0.52432,0.09649,2.121836,0.218393,0.114717,-0.74804,0.845945,-0.103211,-0.287063,-0.719473,-1.590595,0.85566,-2.402812,-0.277935,-1.376309,-0.08887,0.340761,-0.535123,0.572874,-1.545173,1.585153,1.911344,0.515786,-0.625565,-0.349296,-0.542288,0.837518,1.010208,-1.130061,-0.554137,-0.6382,0.248393,0.010019,-0.156826,-0.328593,-0.000963,0.864057,1.578892,-0.725909,1.746472,-1.867288,-1.482579,-1.017996,-0.37895,0.758871,1.929718,2.603052,0.724012,1.82832,-0.467927,-1.086125,0.292282,1.246555,0.261061,0.460174,-1.195243,-1.019017,-1.810944,-0.389612,-2.685623,-0.130096,0.581966,0.69919,-0.180975,0.748096,-0.059929,-0.013665,0.398778,0.530399,-0.855625,-0.913149,0.22884,-0.398549,0.184866,-0.541182,0.051607,-0.500466,-0.641794,1.328022,0.094773,0.071123,-0.078554,0.103045,0.48456,-1.125608,-0.732438,-1.395933,-0.806574,-0.518313,1.701406,-1.547046,0.173587,0.096991,-1.261613,-0.304478,-0.087205,-0.195203,0.509509,-0.430825,0.911305,0.585451,1.087865,0.440096,0.067539,0.310353,1.179215,-0.144134,0.230359,1.334187,-1.570529,1.254367,1.751838,-0.297421,0.023701,0.233737,-0.730381,...,0.113768,-1.087116,-1.394177,-0.132048,1.83809,0.226251,-0.817174,0.509889,-1.513944,1.603503,0.738869,-0.288137,2.413967,0.007213,-1.408599,-0.824773,1.659112,1.519203,-1.134265,0.34926,0.273918,-1.191947,-0.186201,0.325184,1.518457,0.815245,1.369062,0.900526,-0.613411,-0.684044,-0.661544,0.403464,-1.328495,-1.958555,-0.425741,0.955532,-0.778114,0.985608,-0.028181,2.450496,0.499826,-0.972471,0.66473,-0.304588,1.381757,-0.230508,0.333732,1.664476,-0.42728,-0.85564,-0.821952,-0.517037,-0.502979,-1.08668,-0.096717,1.460046,1.102506,1.15604,0.300171,-1.165693,-0.268138,0.444831,0.062798,-0.46486,0.771175,0.931604,-0.068459,0.52703,0.261893,-0.856358,-0.007039,-0.258113,-0.514879,-1.430918,-0.382464,-1.506844,-0.842201,-1.913091,1.806575,0.333435,0.638284,0.474256,1.391269,1.832034,0
.033176,-0.395742,-0.198532,-0.404378,-0.062879,1.177911,-0.617719,0.751823,0.756909,0.979209,-1.383787,0.601455,0.326984,-0.376218,0.510312,-0.754883,-0.550703,0.249139,1.015669,0.568199,0.75745,0.29727,1.333271,1.674182,-0.458495,-1.239259,-0.41626,-0.578293,-0.066369,0.052017,0.739836,1.048273,1.253419,0.843449,0.812459,0.017433,-0.185757,1.484963,-0.830111,0.787134,-0.830537,-2.176436,-0.31032,-0.761419
76,-2.103021,1.494955,-0.479166,-0.641063,-1.409153,-0.555872,0.560713,0.543476,0.102527,0.89456,-0.199601,-1.047135,-0.053883,1.554372,0.465411,0.32428,0.900713,0.995122,2.279093,-0.58874,-0.619802,-0.328069,-0.980116,-1.708529,0.279629,-0.074118,1.523858,1.045153,-0.278666,-0.271868,-0.975878,1.517254,1.183362,0.41203,0.748666,1.041673,0.572525,0.521048,-0.074266,-0.367938,0.866145,-0.717418,-0.21326,-0.260279,0.122936,-0.339753,2.059657,0.970608,-1.119884,-0.792899,-0.701945,0.402769,0.930962,-1.228297,0.527817,-0.293073,-0.81997,0.31246,2.236013,0.912872,-0.107399,0.060397,-1.245773,-1.121294,-0.794178,0.345787,-1.167119,-1.037602,-1.85453,0.051691,-1.035633,1.707477,-1.2765,-0.850391,-0.017625,-0.14851,1.833359,-1.573401,-0.531905,-1.017858,-0.790638,-0.517032,0.325692,0.969355,0.561118,1.386343,1.106185,0.534921,0.934126,-0.443171,1.203743,-0.530324,0.905116,0.450464,1.157611,-0.005306,0.507149,-2.124785,-1.128429,-0.403285,0.985023,0.502249,0.096913,1.170322,2.146341,0.808604,-0.452799,0.957749,-1.067937,-0.343553,0.647565,-1.635156,-2.218938,-0.166131,-0.804044,-1.007559,0.051974,-1.964276,1.148287,0.672903,-0.893753,0.668609,0.815772,0.323084,0.284053,0.762081,0.478125,-0.442881,...,-1.65566,0.028225,-0.193132,-0.401259,-1.169354,0.096842,0.581771,0.453464,-0.751633,-1.189085,0.529334,1.367857,-0.508582,-2.689246,-0.65426,-1.194825,-0.806893,0.802052,0.883107,0.255266,-1.715464,0.749442,0.78664,0.65275,0.911224,1.188962,0.361388,0.068078,-0.297758,1.298524,0.066777,1.059052,1.621613,-0.921555,-0.612679,0.872085,0.500242,-0.939171,0.297649,0.268674,1.016084,0.121239,1.466786,0.555539,0.284587,0.996619,1.663297,-0.478952,0.695869,-0.424118,2.304816,-0.276053,0.98203,0.576314,-0.150263,0.066757,-0.08847,-0.394468,1.285669,0.687243,1.404772,0.9993,-0.544202,1.006984,-0.900555,-0.645634,1.445042,0.380741,1.300879,-0.794961,-0.80588,-1.42027,0.048059,0.615987,0.653028,-0.453661,-1.340657,0.255766,0.19723,1.804404,-2.001785,-0.222488,0.148886,-0.389911,-0.6052,-1.
366884,0.434304,0.107639,0.032815,0.724288,0.203515,0.858038,-0.547337,0.01596,1.514198,-2.380525,-0.148926,-0.008834,-0.664057,0.521123,0.095073,0.274009,1.157728,-0.211263,-1.035399,1.796161,-1.797436,-1.095188,-0.94449,-0.142344,2.625052,-0.567282,-1.294164,-0.181219,-1.627482,0.364037,-0.021832,-0.67436,-0.463762,-0.649923,-0.295907,1.131421,-0.265052,1.008866,0.357296,-0.841112,-1.385538,0.920072
77,1.574522,2.029862,0.670599,-1.702644,0.053664,-0.879597,1.596601,-1.193508,0.10524,-0.196343,-1.97019,2.053218,-0.028593,0.348897,0.369246,1.413106,-0.563812,0.529782,0.692975,-0.53531,0.918502,0.731147,-0.389938,-0.863607,-0.258073,0.840611,0.804799,0.468742,1.150492,-0.874291,0.982341,-0.276078,-1.859561,0.288724,1.417195,0.723211,1.41642,-0.267323,-0.044051,0.296828,-1.581376,-0.388666,-0.832224,0.398872,-1.649461,0.179371,-1.52128,1.387928,-1.025936,-0.199561,-0.751202,-1.773969,0.497988,-1.067965,-0.825039,-1.491479,0.068368,-1.914459,-0.301406,-0.522329,-1.347845,-0.064728,0.915245,-0.780521,-0.188942,1.263396,-0.363749,0.998983,-0.76812,0.930826,1.112245,-0.160654,2.414436,-0.086798,-0.56003,0.779974,-0.55266,-1.543445,-0.355251,-0.148253,0.03642,-0.116113,0.654703,1.590119,0.893058,0.660585,-0.874773,-1.46496,-0.779351,-1.106339,-1.848394,0.049876,-1.002442,-0.142155,-0.542498,-1.29975,1.19968,-0.330761,-0.816518,-0.847058,-0.312715,0.612648,-1.631388,-0.746687,0.440927,0.240586,0.393489,0.243078,-0.457464,-0.672761,-0.144279,0.278166,0.155857,0.469468,0.241421,0.334614,-0.471083,-0.357609,0.027171,-0.534814,0.0378,1.236968,-0.875716,0.094614,-0.663242,0.47722,-0.723469,-0.982765,...,-0.43007,0.164184,-1.311591,0.487084,0.977991,0.185393,-0.685266,0.683797,-0.879524,-0.71484,-0.405432,-1.041436,-1.328355,1.139459,-0.288987,-1.035476,-0.067964,0.272432,-0.843049,0.318483,-0.290006,-0.111755,1.416239,-0.226373,1.551428,-0.292368,1.882424,0.329068,0.809145,-2.701997,-0.937761,-2.346121,-0.039545,1.301907,-0.709678,-1.390559,-2.003779,-1.083949,1.264693,-0.192084,0.463061,-1.531164,-1.175702,1.029332,-1.914879,-0.502518,0.62043,0.320235,1.027373,0.874852,0.183884,0.648628,0.453372,-0.325113,1.392014,1.446987,-1.303837,-0.859139,0.450616,0.45757,-1.597237,-0.939064,0.358812,-1.644956,1.213311,-0.520538,-1.406539,-0.734681,1.069546,-1.74798,1.363857,0.001832,0.903136,-0.641754,1.911649,0.85789,-0.245148,-0.05855,-2.684394,-0.294902,0.834304,-0.01341,-0.001891,0
.853565,0.123617,-0.520315,0.002192,1.250246,-0.691558,0.432724,-0.133321,-1.846329,-0.029975,-0.001684,-0.613234,1.535929,-0.602034,-1.310015,0.245608,-0.289588,-1.470129,-2.49112,0.913015,-0.988046,-0.236765,0.348729,-1.984911,1.588244,0.668524,0.228767,0.658831,-0.156068,-0.394013,0.604722,0.665497,0.005642,-1.488805,-1.050802,-0.221285,-0.937523,-0.210657,0.252069,-1.160771,1.210404,1.424799,1.31936,0.133303,0.796561
78,-1.54465,-0.590516,-0.071332,-0.542838,0.103536,-0.420064,-0.083184,-1.492521,-0.116769,-1.018257,0.231034,-0.417809,0.418711,-2.177373,0.013697,-0.061483,1.669038,-0.676833,-1.820671,1.693282,-0.129046,1.179477,-1.186123,-0.424377,-1.234182,0.975176,-0.59083,-0.81971,0.962937,0.255318,-0.32919,-0.515268,0.443413,-0.738666,-0.124192,-0.91586,1.552581,-0.589304,-3.149258,0.089448,1.38451,-2.161188,0.084016,-1.377949,1.405786,1.300329,-0.949065,0.892443,0.637502,-0.153541,0.880949,-1.646719,1.511763,-0.034271,0.876951,0.649124,1.447865,0.632723,-1.40929,-0.000626,-0.117964,0.591168,0.986716,1.117524,0.246398,-0.180667,0.604845,-0.956418,0.621234,-1.146938,-0.262765,-0.124894,2.201695,-0.126073,-0.804607,0.369506,-0.739495,0.004958,0.625336,-0.205094,0.977011,1.633765,1.26333,1.449535,-0.793289,2.254357,0.250587,-1.35472,-0.296698,0.060062,-1.434679,0.448551,0.318442,-0.560852,1.057899,1.079867,-0.114622,1.547646,-0.286509,0.79764,0.017334,1.506121,0.086214,-0.723084,0.831897,-0.836302,0.573392,0.673393,0.406881,-0.497833,-0.191019,0.1239,-0.87239,-0.426389,0.183741,0.257095,-1.364837,-1.823686,-0.752313,1.486159,2.266516,-0.049081,-1.284168,-0.42937,-0.234575,1.050772,0.163294,-0.104186,...,0.463323,0.30887,0.738947,1.074358,0.972543,0.327184,1.319268,1.32159,0.007472,1.345273,0.225171,0.089809,-0.029491,0.186524,1.399819,0.462033,2.341275,-0.488546,0.172844,0.099218,-1.253115,1.191519,1.802826,-0.357381,0.28591,0.829772,0.38548,0.197222,0.355962,0.742229,-0.669446,-0.418264,0.61985,-0.12072,1.99076,-0.35872,-1.048537,0.052089,1.0941,-0.314443,0.81853,-2.112221,0.653322,0.913713,0.23188,0.982191,-0.834128,1.387411,0.159367,-1.138356,0.520591,1.106018,0.980567,-0.905684,1.332848,-0.497112,0.67225,-1.259853,-0.407138,-1.227283,-1.073032,0.183815,-0.875296,-1.39325,0.218305,-0.133468,1.521602,0.473478,0.329653,1.151341,0.24296,1.357642,-0.614949,0.846134,0.237954,-0.580176,1.760548,-0.872776,-0.506684,-0.920763,-0.025371,-0.409698,1.231031,0.385882,1.477442,0.586752,0
.406803,1.038668,1.306178,2.257884,0.464262,-0.407766,0.602834,-0.250097,-0.861129,1.765219,-1.119744,0.530884,0.321929,-0.934067,-1.117402,-0.192899,-0.467621,-0.35338,1.124013,1.445264,0.225488,-0.376535,-0.535725,-0.567047,-0.792302,-0.69857,-0.448437,-0.130267,-1.188159,-0.456427,1.870351,1.249107,-1.042436,-0.469798,-0.442865,0.418282,-0.620768,-2.453227,0.942678,0.562795,-0.014431,0.876265
79,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,-0.0,0.0,-0.0,-0.0,0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,0.0,-0.0,0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,-0.0,0.0,0.0,-0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,0.0,-0.0,0.0,-0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,0.0,-0.0,0.0,...,0.0,-0.0,0.0,-0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,0.0,-0.0,-0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0


In [None]:
# Turn the padding mask into an additive attention bias: 0.0 for positions to
# keep, -inf for padded positions, with a singleton axis at dim 1 so it
# broadcasts over query positions.
# Assumes attention_mask is a 2-D (batch, seq) mask whose non-zero entries mark
# real tokens -- it is defined in an earlier cell, TODO confirm.
#
# Why not `attention_mask.masked_fill(~attention_mask[:, None, :], -inf)`:
#   * on an integer mask `~` is a bitwise NOT (1 -> -2, 0 -> -1), and
#     masked_fill only accepts boolean masks;
#   * -inf cannot be stored in an integer/bool tensor, so the fill target must
#     be a float tensor;
#   * filling the 2-D mask itself with a (batch, 1, seq) mask broadcasts the
#     result to (batch, batch, seq), mixing unrelated batch items.
padding_positions = ~attention_mask.bool()[:, None, :]
torch.zeros(
    padding_positions.shape,
    dtype=torch.float32,
    device=attention_mask.device,
).masked_fill(padding_positions, float("-inf"))

In [None]:
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", 32)

# Sanity check: multiplying per-token embeddings by the padding mask should
# zero out every masked position while leaving the other rows untouched.
# Assumes attention_mask is a 2-D (batch, seq) 0/1 mask defined in an earlier
# cell -- TODO confirm.

# Dedicated, seeded generator: the displayed tables are identical on every
# re-run, and the global RNG state used by other cells is left alone.
embedding_rng = torch.Generator(device=attention_mask.device).manual_seed(0)

# Batch and sequence sizes are taken from the mask itself so the product below
# always lines up with it. 1024 is the embedding width used for this check
# (presumably the model's hidden size -- verify).
random_embeddings = torch.randn(
    attention_mask.size(0),
    attention_mask.size(1),
    1024,
    generator=embedding_rng,
    device=attention_mask.device,
)
print(f"Random embeddings shape: {random_embeddings.shape}")

# Pick the first sample before moving to host memory, and hand pandas a NumPy
# array rather than a torch tensor.
display(pd.DataFrame(random_embeddings[0].cpu().numpy()))

# The trailing None adds a size-1 feature axis, so each token's mask value is
# broadcast across all of its embedding features.
masked_embeddings = random_embeddings * attention_mask[:, :, None]
print(f"Masked embeddings shape: {masked_embeddings.shape}")
display(pd.DataFrame(masked_embeddings[0].cpu().numpy()))
