In [1]:
import json
import torch
import torch.nn.functional as F
import math
from glob import glob

In [2]:
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained('mesolitica/nanot5-small-malaysian-cased')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
model = T5ForConditionalGeneration.from_pretrained('mesolitica/nanot5-small-malaysian-translation-v2',
                                                  attn_implementation = 'sdpa').to(torch.bfloat16)

In [4]:
encoder_emb = torch.nn.Embedding(
    model.config.relative_attention_num_buckets, 
    model.config.num_heads
)
encoder_emb.load_state_dict(
    model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.state_dict()
)

decoder_emb = torch.nn.Embedding(
    model.config.relative_attention_num_buckets, 
    model.config.num_heads
)
decoder_emb.load_state_dict(
    model.decoder.block[0].layer[0].SelfAttention.relative_attention_bias.state_dict()
)

<All keys matched successfully>

In [5]:
files = glob('/home/husein/ssd3/translation/train.json*.splitted')
len(files)

68

In [6]:
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
import streaming
import numpy as np
from tqdm import tqdm
from glob import glob
import os
import json

class UInt32(Encoding):
    def encode(self, obj) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes):
        return np.frombuffer(data, np.uint32)

_encodings['uint32'] = UInt32

columns = {
    'data': 'str',
}
hashes = 'sha1', 'xxh64'

In [7]:
!mkdir tokenized

mkdir: cannot create directory ‘tokenized’: File exists


In [8]:
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    """
    Adapted from Mesh Tensorflow:
    https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

    Translate relative position to a bucket number for relative attention. The relative position is defined as
    memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
    position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
    small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
    positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
    This should allow for more graceful generalization to longer sequences than the model has been trained on

    Args:
        relative_position: an int32 Tensor
        bidirectional: a boolean - whether the attention is bidirectional
        num_buckets: an integer
        max_distance: an integer

    Returns:
        a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
    """
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    # now relative_position is in the range [0, inf)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )

    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets

def compute_bias(
    query_length, 
    key_length,
    relative_attention_bias,
    bidirectional = True, 
    num_buckets = 32, 
    max_distance = 128, 
    device=None,
):
    """Compute binned relative position bias"""
    if device is None:
        device = relative_attention_bias.weight.device
    context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
    relative_position = memory_position - context_position  # shape (query_length, key_length)
    relative_position_bucket = _relative_position_bucket(
        relative_position,  # shape (query_length, key_length)
        bidirectional=bidirectional,
        num_buckets=num_buckets,
        max_distance=max_distance,
    )
    values = relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
    values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
    return values

def block_diagonal_concat_inverted(*masks, dtype=torch.bfloat16):
    total_size = sum(mask.size(0) for mask in masks)
    combined_mask = torch.zeros(total_size, total_size, dtype=dtype)

    current_pos = 0

    for mask in masks:
        size = mask.size(0)
        combined_mask[current_pos:current_pos + size, current_pos:current_pos + size] = mask
        current_pos += size

    min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
    inverted_mask = torch.where(combined_mask == 1, torch.tensor(0, dtype=dtype), min_value)
    return inverted_mask.unsqueeze(0)

def pad_attention_mask(attention_mask):
    maxlen_right = max([attention_mask[i].shape[1] for i in range(len(attention_mask))])
    maxlen_bottom = max([attention_mask[i].shape[0] for i in range(len(attention_mask))])
    attention_mask = [
        F.pad(
            attention_mask[i],
            (0, maxlen_right - attention_mask[i].shape[1], 0, maxlen_bottom - attention_mask[i].shape[0])) for i in range(
            len(attention_mask))]
    return torch.stack(attention_mask)

def pad_attention_mask_4d(attention_mask):
    maxlen_right = max([attention_mask[i].shape[-2] for i in range(len(attention_mask))])
    maxlen_bottom = max([attention_mask[i].shape[-1] for i in range(len(attention_mask))])
    attention_mask = [
        F.pad(
            attention_mask[i],
            (0, maxlen_right - attention_mask[i].shape[-2], 0, maxlen_bottom - attention_mask[i].shape[-1])) for i in range(
            len(attention_mask))]
    return torch.stack(attention_mask)

def block_diagonal_concat(*masks, dtype=torch.bfloat16):
    total_size = sum(mask.size(0) for mask in masks)
    combined_mask = torch.zeros(total_size, total_size, dtype=dtype)

    current_pos = 0

    for mask in masks:
        size = mask.size(0)
        combined_mask[current_pos:current_pos + size, current_pos:current_pos + size] = mask
        current_pos += size

    return combined_mask

def block_diagonal_concat_4d(*masks, dtype=torch.bfloat16):
    total_size = sum(mask.size(1) for mask in masks)
    combined_mask = torch.zeros(masks[0].shape[0], 
                                total_size, total_size, dtype=dtype)

    current_pos = 0

    for mask in masks:
        size = mask.size(1)
        combined_mask[:, current_pos:current_pos + size, current_pos:current_pos + size] = mask
        current_pos += size

    return combined_mask

def block_diagonal_concat_cross(*masks, dtype=torch.bfloat16):
    total_rows = sum(mask.size(0) for mask in masks)
    total_cols = sum(mask.size(1) for mask in masks)
    
    combined_mask = torch.zeros((total_rows, total_cols), dtype=dtype)
    
    current_row, current_col = 0, 0

    for mask in masks:
        rows, cols = mask.size()
        combined_mask[current_row:current_row + rows, current_col:current_col + cols] = mask
        current_row += rows
        current_col += cols
        
    return combined_mask

def multipack(input_ids, labels, lengths):
    results = {
        'input_ids': input_ids,
        'labels': labels,
    }
    attention_mask = []
    encoder_attention_mask = []
    decoder_attention_mask = []
    encoder_biases = []
    decoder_biases = []
    
    for length in lengths:
        left_len = length[0]
        right_len = length[1]
        
        attention_mask.append(torch.ones(left_len, left_len))
        encoder_attention_mask.append(torch.ones(right_len, left_len))
        decoder_attention_mask.append(torch.tril(torch.ones(right_len, right_len)))
        
        encoder_bias = compute_bias(
            left_len, left_len,
            encoder_emb,
            bidirectional=True,
            num_buckets=model.config.relative_attention_num_buckets,
            max_distance=model.config.relative_attention_max_distance,
        )
        encoder_biases.append(encoder_bias[0])
        
        decoder_bias = compute_bias(
            right_len, right_len,
            decoder_emb,
            bidirectional=False,
            num_buckets=model.config.relative_attention_num_buckets,
            max_distance=model.config.relative_attention_max_distance,
        )
        decoder_biases.append(decoder_bias[0])
        
    results['attention_mask'] = block_diagonal_concat(*attention_mask)
    results['encoder_attention_mask'] = block_diagonal_concat_cross(*encoder_attention_mask)
    results['decoder_attention_mask'] = block_diagonal_concat(*decoder_attention_mask)
    
    results['position_bias'] = block_diagonal_concat_4d(*encoder_biases)
    results['decoder_position_bias'] = block_diagonal_concat_4d(*decoder_biases)
    
    return results

def collator(batch, pad_token_id = 1, label_pad = -100):
    max_length = max(len(l['input_ids']) for l in batch)
    results = {}
    results['input_ids'] = [
        b['input_ids'] + [pad_token_id] * (max_length - len(b['input_ids']))
        for b in batch
    ]
    results['input_ids'] = torch.tensor(results['input_ids'], dtype = torch.int64)
    
    max_length = max(len(l['labels']) for l in batch)
    results['labels'] = [
        b['labels'] + [label_pad] * (max_length - len(b['labels']))
        for b in batch
    ]
    results['labels'] = torch.tensor(results['labels'], dtype = torch.int64)
    
    results['position_bias'] = pad_attention_mask_4d([b['position_bias'] for b in batch])
    results['decoder_position_bias'] = pad_attention_mask_4d([b['decoder_position_bias'] for b in batch])
    
    attention_mask = [b['attention_mask'] for b in batch]
    results['attention_mask'] = pad_attention_mask(attention_mask)
    encoder_attention_mask = [b['encoder_attention_mask'] for b in batch]
    results['encoder_attention_mask'] = pad_attention_mask(encoder_attention_mask)
    decoder_attention_mask = [b['decoder_attention_mask'] for b in batch]
    results['decoder_attention_mask'] = pad_attention_mask(decoder_attention_mask)
    
    dtype = results['attention_mask'].dtype
    encoder_extended_attention_mask = results['attention_mask'][:, None, :, :]
    encoder_extended_attention_mask = encoder_extended_attention_mask
    encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(dtype).min
    results['position_bias'] = results['position_bias'] + encoder_extended_attention_mask
    
    dtype = results['decoder_attention_mask'].dtype
    encoder_extended_attention_mask = results['decoder_attention_mask'][:, None, :, :]
    encoder_extended_attention_mask = encoder_extended_attention_mask
    encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(dtype).min
    results['decoder_position_bias'] = results['decoder_position_bias'] + encoder_extended_attention_mask
    
    return results

In [9]:
def loop(files, block_size = 200):
    files, index = files
    
    out_root = f'tokenized/tokenized-{index}'
    os.system(f'rm -rf {out_root}')
    
    count_input_ids = 0
    count_labels = 0
    input_ids = []
    labels = []
    lengths = []
    found = False
    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:
        for f in files:
            with open(f) as fopen:
                for l in tqdm(fopen):
                    l = json.loads(l)['translation']
                    left = l['prefix'] + l['src'] + tokenizer.eos_token
                    right = l['tgt'] + tokenizer.eos_token

                    left = tokenizer(left)['input_ids']
                    right = tokenizer(right)['input_ids']

                    if count_input_ids + len(left) >= block_size or count_labels + len(left) >= block_size:
                        
                        if len(input_ids) and len(labels):
                            d = {
                                'input_ids': input_ids,
                                'labels': labels,
                                'lengths': lengths
                            }
                            count_input_ids = 0
                            count_labels = 0
                            input_ids = []
                            labels = []
                            lengths = []

                            d = json.dumps(d)

                            out.write({
                                'data': d,
                            })       

                    else:
                        count_input_ids += len(left)
                        count_labels += len(right)
                        input_ids.extend(left)
                        labels.extend(right)
                        lengths.append((len(left), len(right)))

In [10]:
# loop((files[:1], 0))

In [11]:
dataset = LocalDataset('tokenized/tokenized-0')

In [12]:
len(dataset)

5368

In [13]:
dataset[0]

{'data': '{"input_ids": [4130, 10670, 344, 8349, 304, 30248, 29, 8034, 15, 11638, 954, 1354, 2957, 564, 4636, 887, 5873, 519, 2610, 537, 9370, 677, 1267, 20977, 69, 3834, 17, 304, 935, 498, 384, 4834, 16, 1978, 1066, 17306, 29932, 14725, 17, 941, 435, 12435, 390, 2553, 17362, 34, 12284, 15, 355, 1221, 3037, 355, 12692, 7931, 564, 384, 14725, 17, 1860, 10905, 15, 331, 3654, 7597, 11484, 498, 1069, 390, 25788, 11419, 384, 19697, 17, 410, 1384, 4, 2, 4130, 10670, 344, 3146, 29, 224, 8987, 354, 12388, 6061, 4209, 1357, 1172, 8330, 3896, 688, 14317, 409, 2680, 354, 6407, 561, 2680, 3139, 409, 7726, 19527, 4077, 354, 22644, 12149, 9971, 4919, 224, 224, 2377, 409, 354, 872, 1549, 1357, 14527, 1243, 106, 465, 1400, 5494, 354, 1075, 9981, 30362, 1549, 13041, 16857, 16787, 641, 465, 13866, 4, 2], "labels": [224, 8987, 354, 12388, 6061, 4209, 1357, 1172, 8330, 3896, 688, 14317, 409, 2680, 354, 6407, 561, 2680, 3139, 409, 7726, 19527, 4077, 354, 22644, 12149, 9971, 4919, 224, 224, 2377, 409, 354, 

In [14]:
d = json.loads(dataset[0]['data'])
b = multipack(**d)
d = json.loads(dataset[1]['data'])
b1 = multipack(**d)

In [15]:
%%time

input_ids = collator([b, b1])

CPU times: user 2.6 ms, sys: 1.25 ms, total: 3.85 ms
Wall time: 984 µs


In [16]:
o = model(**input_ids)

mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

False True position_bias tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]],

         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+

mask tensor([[[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...

In [17]:
o.loss

tensor(2.1719, dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)

In [20]:
o.logits.argmax(-1)

tensor([[  224,  8987,   354, 12388,  1423,  1357,  1357,  3113,  1015,   354,
          3113, 14317,   409,  2680,   354,   224,   561,  2680,  4506,  7726,
          7726,  4303, 14527,   354, 21877,  2377,  9971,  6426,  1213,   224,
          1295,  1549,   354,   872,  2377,   872, 14527,   465,   257,   354,
          1400,  5494,   354,  1075,  9981,   561,   561,  8366, 16857,   465,
           641,   465, 13866,   522,     2,     1,    15,  8034,   700,  5346,
          3256,  3448,   344,  1951,   647,   515,   436,  3196,   689,  1305,
          3196,    17,  1637,   715, 11860,   492,   918,  1330, 25267,   287,
           691, 20525,   871,     2,  2647,   449,   689,    34,     2,    15,
           689,   515,  3335,  1553,   603,  3448,   344,   436,     2,  1856,
           449,  1401,  2498,  2498,  2150,  3448,   689, 14088, 14906, 12349,
            17,   410,  1384,     4,     2,   224,   224,   224,   224,   224,
           224,   224,   224,   224,   224,   224,  

In [None]:
o.logits.argmax(-1)

In [None]:
o.loss

In [None]:
b = multipack(**d)
input_ids = collator([b])

In [None]:
input_ids

In [None]:
o = model(**input_ids)

In [None]:
o.loss