In [1]:
# !pip3.10 install -e . --no-deps

In [2]:
import torch
import torch.nn.functional as F
import math

In [3]:
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Tokenizer of the Malaysian nanoT5-small checkpoint; used by `input_ids(...)` below.
tokenizer = AutoTokenizer.from_pretrained('mesolitica/nanot5-small-malaysian-cased')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [78]:
# Same checkpoint as the tokenizer, using the SDPA attention backend, with all
# weights cast to bfloat16. `multipack` reads this global `model` for its config
# and relative-attention-bias tables.
model = T5ForConditionalGeneration.from_pretrained(
    'mesolitica/nanot5-small-malaysian-cased',
    attn_implementation='sdpa',
).to(torch.bfloat16)

In [5]:
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    """
    Adapted from Mesh Tensorflow:
    https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

    Translate relative position to a bucket number for relative attention. The relative position is defined as
    memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
    position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
    small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
    positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
    This should allow for more graceful generalization to longer sequences than the model has been trained on

    Args:
        relative_position: an int32 Tensor
        bidirectional: a boolean - whether the attention is bidirectional
        num_buckets: an integer
        max_distance: an integer

    Returns:
        a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
    """
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    # now relative_position is in the range [0, inf)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )

    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets

def compute_bias(
    query_length,
    key_length,
    relative_attention_bias,
    bidirectional=True,
    num_buckets=32,
    max_distance=128,
    device=None,
):
    """Build the binned relative-position bias for one attention map.

    Args:
        query_length: number of query positions.
        key_length: number of key positions.
        relative_attention_bias: module mapping bucket ids to per-head biases; its
            ``.weight.device`` is used when ``device`` is not given.
        bidirectional, num_buckets, max_distance: forwarded to
            ``_relative_position_bucket``.
        device: device on which to build the position grids (default: the bias
            table's device).

    Returns:
        Tensor of shape (1, num_heads, query_length, key_length).
    """
    target_device = relative_attention_bias.weight.device if device is None else device
    query_pos = torch.arange(query_length, dtype=torch.long, device=target_device)
    key_pos = torch.arange(key_length, dtype=torch.long, device=target_device)

    # Offset of every key relative to every query: (query_length, key_length).
    offsets = key_pos.unsqueeze(0) - query_pos.unsqueeze(1)
    buckets = _relative_position_bucket(
        offsets,
        bidirectional=bidirectional,
        num_buckets=num_buckets,
        max_distance=max_distance,
    )

    per_head = relative_attention_bias(buckets)  # (query_length, key_length, num_heads)
    return per_head.permute(2, 0, 1)[None]  # (1, num_heads, query_length, key_length)

def block_diagonal_concat_inverted(*masks, dtype=torch.bfloat16):
    """Block-diagonally stack square 0/1 masks and convert them to an additive mask.

    Cells equal to 1 become 0 (attend). Every other cell, including everything off
    the diagonal blocks, becomes the most negative value representable in ``dtype``.

    Args:
        *masks: square 2-D masks; ``mask.size(0)`` is used as both block height and width.
        dtype: dtype of the result.

    Returns:
        Tensor of shape (1, total, total), where total is the sum of the mask sizes.
    """
    sizes = [m.size(0) for m in masks]
    side = sum(sizes)
    dense = torch.zeros(side, side, dtype=dtype)

    start = 0
    for m, size in zip(masks, sizes):
        stop = start + size
        dense[start:stop, start:stop] = m
        start = stop

    if dtype.is_floating_point:
        fill = torch.finfo(dtype).min
    else:
        fill = torch.iinfo(dtype).min
    keep = dense == 1
    return torch.where(keep, torch.tensor(0, dtype=dtype), fill).unsqueeze(0)

In [62]:
def pad_attention_mask(attention_mask):
    """Zero-pad 2-D masks on the bottom/right to a common shape and stack them.

    Args:
        attention_mask: non-empty sequence of 2-D tensors, possibly of different shapes.

    Returns:
        Tensor of shape (len(attention_mask), max_rows, max_cols); padded cells are 0.
    """
    target_rows = max(m.shape[0] for m in attention_mask)
    target_cols = max(m.shape[1] for m in attention_mask)
    padded = []
    for m in attention_mask:
        # F.pad's tuple is (left, right, top, bottom) for the last two dims.
        padded.append(F.pad(m, (0, target_cols - m.shape[1], 0, target_rows - m.shape[0])))
    return torch.stack(padded)

def pad_attention_mask_4d(attention_mask):
    """Zero-pad per-sample bias/mask tensors in their last two dims and stack them.

    Used for the ``(num_heads, rows, cols)`` position biases from
    ``block_diagonal_concat_4d``. The result gains a leading batch dimension.

    Args:
        attention_mask: non-empty sequence of tensors with at least 2 dims. The
            leading dims must match; the last two dims may differ.

    Returns:
        Tensor of shape (len(attention_mask), ..., max_rows, max_cols); padded cells are 0.
    """
    # rows = dim -2, cols = dim -1. Each is padded by its own deficit.
    # (Previously the two were swapped, which only worked for square inputs.)
    max_rows = max(m.shape[-2] for m in attention_mask)
    max_cols = max(m.shape[-1] for m in attention_mask)
    padded = [
        # F.pad's tuple is (last_left, last_right, second_last_left, second_last_right).
        F.pad(m, (0, max_cols - m.shape[-1], 0, max_rows - m.shape[-2]))
        for m in attention_mask
    ]
    return torch.stack(padded)

def block_diagonal_concat(*masks, dtype=torch.bfloat16):
    """Place square 2-D masks along the diagonal of a zero matrix.

    Args:
        *masks: square 2-D masks; ``mask.size(0)`` is used as both block height and width.
        dtype: dtype of the result (values are cast on assignment).

    Returns:
        Tensor of shape (total, total), where total is the sum of the mask sizes.
        Off-diagonal blocks are 0.
    """
    bounds = [0]
    for m in masks:
        bounds.append(bounds[-1] + m.size(0))

    out = torch.zeros(bounds[-1], bounds[-1], dtype=dtype)
    for m, lo, hi in zip(masks, bounds, bounds[1:]):
        out[lo:hi, lo:hi] = m
    return out

def block_diagonal_concat_4d(*masks, dtype=torch.bfloat16):
    """Per-head block-diagonal stacking of ``(heads, n, n)`` tensors.

    Args:
        *masks: at least one tensor. ``size(1)`` gives the block size, and the heads
            count is taken from the first tensor.
        dtype: dtype of the result.

    Returns:
        Tensor of shape (heads, total, total) with each input on the diagonal and
        zeros elsewhere.
    """
    side = sum(m.size(1) for m in masks)
    heads = masks[0].shape[0]
    out = torch.zeros(heads, side, side, dtype=dtype)

    start = 0
    for m in masks:
        end = start + m.size(1)
        out[:, start:end, start:end] = m
        start = end
    return out

def block_diagonal_concat_cross(*masks, dtype=torch.bfloat16):
    """Block-diagonally stack possibly non-square 2-D masks (e.g. decoder->encoder).

    Each mask's rows and columns advance independently, so a (r, c) mask occupies
    rows [r0, r0 + r) and columns [c0, c0 + c).

    Args:
        *masks: 2-D masks.
        dtype: dtype of the result.

    Returns:
        Tensor of shape (sum of rows, sum of cols); off-block cells are 0.
    """
    n_rows = sum(m.size(0) for m in masks)
    n_cols = sum(m.size(1) for m in masks)
    out = torch.zeros((n_rows, n_cols), dtype=dtype)

    r0 = c0 = 0
    for m in masks:
        h, w = m.size()
        out[r0:r0 + h, c0:c0 + w] = m
        r0, c0 = r0 + h, c0 + w
    return out

def input_ids(left, right):
    """Tokenize one (source, target) pair into a single-example dict for `multipack`.

    Both texts get ``tokenizer.eos_token`` appended before tokenization.

    Args:
        left: source text fed to the encoder.
        right: target text; its token ids become ``labels``.

    Returns:
        The tokenizer output for ``left`` with ``token_type_ids`` removed (if present), plus:
            labels: token ids of ``right``.
            attention_mask: (left_len, left_len) all-ones encoder self-attention mask.
            encoder_attention_mask: (right_len, left_len) all-ones decoder-to-encoder mask.
            decoder_attention_mask: (right_len, right_len) lower-triangular (causal) mask.
    """
    inputs = tokenizer(left + tokenizer.eos_token)
    inputs['labels'] = tokenizer(right + tokenizer.eos_token)['input_ids']
    # Drop token_type_ids only if the tokenizer produced it. Some tokenizers (e.g. the
    # stock T5 ones) never return it, and a bare pop() would raise KeyError there.
    inputs.pop('token_type_ids', None)

    left_len = len(inputs['input_ids'])
    right_len = len(inputs['labels'])

    # Replace the tokenizer's 1-D attention_mask with explicit 2-D masks so that
    # multipack can stitch several examples together block-diagonally.
    inputs['attention_mask'] = torch.ones(left_len, left_len)
    inputs['encoder_attention_mask'] = torch.ones(right_len, left_len)
    inputs['decoder_attention_mask'] = torch.tril(torch.ones(right_len, right_len))
    return inputs
    
def multipack(batch):
    """Pack several `input_ids(...)` examples into one sequence with block-diagonal masks.

    ``None`` entries in ``batch`` are skipped. Token ids and labels are concatenated.
    Each example's relative position bias is computed on its own, from the first block
    of the global ``model``'s encoder/decoder, and placed on the diagonal. The 2-D masks
    are stacked block-diagonally, so attention cannot cross example boundaries.

    NOTE(review): no ``decoder_input_ids`` are built here. If the model derives them by
    shifting the packed ``labels`` right (as stock HF T5 does), every segment after the
    first starts its decoder input with the previous segment's last label instead of the
    decoder start token. Confirm against the patched model.

    Returns:
        dict with ``input_ids`` / ``labels`` (flat lists), ``position_bias`` and
        ``decoder_position_bias`` (num_heads, total, total), and the 2-D
        ``attention_mask``, ``encoder_attention_mask`` and ``decoder_attention_mask``.
    """
    examples = [ex for ex in batch if ex is not None]

    def _segment_bias(length, stack, bidirectional):
        # Square (num_heads, length, length) bias for one segment, taken from the
        # first block of `stack` (the model's encoder or decoder).
        return compute_bias(
            length, length,
            stack.block[0].layer[0].SelfAttention.relative_attention_bias,
            bidirectional=bidirectional,
            num_buckets=model.config.relative_attention_num_buckets,
            max_distance=model.config.relative_attention_max_distance,
        )[0]

    packed_ids, packed_labels = [], []
    enc_biases, dec_biases = [], []
    for ex in examples:
        packed_ids.extend(ex['input_ids'])
        packed_labels.extend(ex['labels'])
        # Encoder attention is bidirectional; decoder self-attention is causal.
        enc_biases.append(_segment_bias(len(ex['input_ids']), model.encoder, True))
        dec_biases.append(_segment_bias(len(ex['labels']), model.decoder, False))

    return {
        'input_ids': packed_ids,
        'labels': packed_labels,
        'position_bias': block_diagonal_concat_4d(*enc_biases),
        'decoder_position_bias': block_diagonal_concat_4d(*dec_biases),
        'attention_mask': block_diagonal_concat(
            *[ex['attention_mask'] for ex in examples]),
        'encoder_attention_mask': block_diagonal_concat_cross(
            *[ex['encoder_attention_mask'] for ex in examples]),
        'decoder_attention_mask': block_diagonal_concat(
            *[ex['decoder_attention_mask'] for ex in examples]),
    }

def collator(batch, pad_token_id = 1, label_pad = -100):
    """Pad a list of `multipack(...)` rows to common lengths and stack them into a batch.

    Args:
        batch: non-empty list of packed dicts.
        pad_token_id: fill value for ``input_ids``. NOTE(review): T5-style vocabularies
            usually reserve 0 for pad and 1 for EOS. Confirm against
            ``tokenizer.pad_token_id``. Padded positions are zero in the padded
            ``attention_mask``, so they are masked either way.
        label_pad: fill value for ``labels``.

    Returns:
        dict of int64 ``input_ids`` / ``labels`` of shape (batch, max_len), the
        zero-padded position biases (batch, num_heads, max_len, max_len) and the
        zero-padded 2-D masks (batch, rows, cols).
    """
    def _right_pad(key, fill):
        # Right-pad each row's token list under `key` to the longest row.
        width = max(len(row[key]) for row in batch)
        return torch.tensor(
            [row[key] + [fill] * (width - len(row[key])) for row in batch],
            dtype=torch.int64,
        )

    results = {
        'input_ids': _right_pad('input_ids', pad_token_id),
        'labels': _right_pad('labels', label_pad),
    }
    for key in ('position_bias', 'decoder_position_bias'):
        results[key] = pad_attention_mask_4d([row[key] for row in batch])
    for key in ('attention_mask', 'encoder_attention_mask', 'decoder_attention_mask'):
        results[key] = pad_attention_mask([row[key] for row in batch])
    return results

In [63]:
# Toy (source, target) pairs: left[i] is the encoder text paired with target right[i].
left = ['hidup ini', 'sangat bodoh ye', 'bodddddooooooo gile']
right = ['hidup ini sangat keras ye', 'sangat bodoh ye dan teramat', 'ye tahu']

In [64]:
# Tokenize each pair into a single-example dict with its own 2-D attention masks.
b = input_ids(left[0], right[0])
b1 = input_ids(left[1], right[1])
b2 = input_ids(left[2], right[2])

In [65]:
# Two packed rows: m1 = pairs 0+1, m2 = pairs 0+2, so their packed lengths differ.
m1 = multipack([b, b1])
m2 = multipack([b, b2])

In [67]:
# Pad both packed rows to common lengths and stack them into a batch of 2.
inputs = collator([m1, m2])

In [72]:
# Per-row encoder bias is (num_heads, packed_len, packed_len); rows differ in length before collation.
m1['position_bias'].shape, m2['position_bias'].shape

(torch.Size([8, 7, 7]), torch.Size([8, 14, 14]))

In [79]:
# Forward pass with precomputed position biases and block-diagonal masks.
# NOTE(review): `position_bias` / `decoder_position_bias` kwargs presumably require the
# locally installed transformers fork (the `pip install -e .` cell); the debug prints in
# the output come from it. Confirm before running on stock transformers.
o = model(**inputs)

self.is_decoder False
extended_attention_mask tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00, -3.3895e+38, -3.3895e+38,
           -3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38,
           -3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -3.3895e+38, -3.3895e+38,
           -3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38,
           -3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -3.3895e+38, -3.3895e+38,
           -3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38,
           -3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38, -0.0000e+00, -0.0000e+00,
           -0.0000e+00, -0.0000e+00, -3.3895e+38, -3.3895e+38, -3.3895e+38,
           -3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38, -0.0000e+00, -0.0000e+00,
           -0.

else False tensor([[[[  4.0625,   7.3750,   5.6250,  ...,   0.0000,   0.0000,   0.0000],
          [ -0.1455,   4.0625,   7.3750,  ...,   0.0000,   0.0000,   0.0000],
          [ -1.6484,  -0.1455,   4.0625,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[-11.6250, -15.1250, -16.6250,  ...,   0.0000,   0.0000,   0.0000],
          [  6.9062, -11.6250, -15.1250,  ...,   0.0000,   0.0000,   0.0000],
          [  7.3438,   6.9062, -11.6250,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  3.4219

torch.Size([2, 13, 512])
T5Block self.is_decoder True
T5Block self.is_decoder True self attention
else False tensor([[[[  2.7812,   2.7812,   2.7812,  ...,   0.0000,   0.0000,   0.0000],
          [  3.9375,   2.7812,   2.7812,  ...,   0.0000,   0.0000,   0.0000],
          [  2.4062,   3.9375,   2.7812,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   2.7812,   2.7812,   2.7812],
          [  0.0000,   0.0000,   0.0000,  ...,   3.9375,   2.7812,   2.7812],
          [  0.0000,   0.0000,   0.0000,  ...,   2.4062,   3.9375,   2.7812]],

         [[  1.9844,   1.9844,   1.9844,  ...,   0.0000,   0.0000,   0.0000],
          [  5.0938,   1.9844,   1.9844,  ...,   0.0000,   0.0000,   0.0000],
          [  4.1250,   5.0938,   1.9844,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   1.9844,   1.9844,   1.9844],
          [  0.0000,   0.0000,   0.0000,  ...,   5.0938,   1.9844,   1.9844],
 

T5Block self.is_decoder True cross attention
torch.Size([2, 13, 512]) torch.Size([2, 14, 512])
else False tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          ...,
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [-3.3895e+38, -3.3895e+38, -3.3895e+38,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38]],

         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3895e+38,
           -3.3895e+38, -3.3895e+38],
          [ 0.0000e+0

In [80]:
# Expected (batch, padded label length, vocab size).
o.logits.shape

torch.Size([2, 13, 32128])