In [1]:
# !pip3.10 install --pre torch==2.5.0.dev20240912+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121

In [2]:
# !pip3.10 install torch -U

In [3]:
# !pip3.10 uninstall transformers -y

In [4]:
# !pip3.10 install -e . --no-deps

In [5]:
import json
import torch
import torch.nn.functional as F
import math
from glob import glob

In [6]:
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Load the tokenizer for 'mesolitica/nanot5-small-malaysian-cased'.
# local_files_only=True restricts the lookup to files already on this machine (no download).
# NOTE(review): this is a different repo id from the model loaded in the next cell
# ('...-translation-v2'); presumably both share one vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained('mesolitica/nanot5-small-malaysian-cased', local_files_only = True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# Load the translation checkpoint from local files only, requesting the 'sdpa'
# attention implementation, then cast all parameters to bfloat16.
model = T5ForConditionalGeneration.from_pretrained('mesolitica/nanot5-small-malaysian-translation-v2',
                                                  attn_implementation = 'sdpa', local_files_only = True).to(torch.bfloat16)

In [8]:
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
import streaming
import numpy as np
from tqdm import tqdm
from glob import glob
import os
import json

class UInt32(Encoding):
    """Custom streaming ``Encoding`` that (de)serialises arrays as raw uint32 bytes.

    ``encode`` does not convert or validate the dtype of its argument; it simply
    returns whatever ``obj.tobytes()`` yields, so callers are expected to pass an
    object whose byte layout is already uint32.
    """

    def encode(self, obj) -> bytes:
        """Return the raw byte buffer of ``obj`` (via its ``tobytes()`` method)."""
        raw = obj.tobytes()
        return raw

    def decode(self, data: bytes):
        """Reinterpret ``data`` as a flat array of ``np.uint32`` values."""
        return np.frombuffer(data, dtype=np.uint32)

# Register the custom encoding class in streaming's MDS encoding table under the key 'uint32'.
_encodings['uint32'] = UInt32

# Column schema: a single string column named 'data'.
columns = {
    'data': 'str',
}
# Tuple of hash algorithm names. Presumably intended for an MDSWriter call;
# neither `columns` nor `hashes` is used in the cells visible here.
hashes = 'sha1', 'xxh64'

In [9]:
def block_diagonal_concat_inverted(*masks, dtype=torch.bfloat16):
    """Build an additive block-diagonal mask from several square masks.

    The masks are laid out along the diagonal of a zero-initialised square
    matrix whose side is the sum of every mask's first dimension. Entries that
    equal 1 afterwards become 0; every other entry becomes the smallest value
    representable in ``dtype``.

    Args:
        *masks: Tensors written as ``size(0) x size(0)`` blocks.
        dtype: dtype of the intermediate matrix and of the zero fill value.

    Returns:
        Tensor of shape ``(1, total, total)``.
    """
    sizes = [m.size(0) for m in masks]
    side = sum(sizes)
    canvas = torch.zeros(side, side, dtype=dtype)

    start = 0
    for block, n in zip(masks, sizes):
        stop = start + n
        canvas[start:stop, start:stop] = block
        start = stop

    # Pick the most negative value the dtype can hold (float or integer variant).
    if dtype.is_floating_point:
        floor = torch.finfo(dtype).min
    else:
        floor = torch.iinfo(dtype).min

    keep = canvas == 1
    additive = torch.where(keep, torch.tensor(0, dtype=dtype), floor)
    return additive.unsqueeze(0)

def pad_attention_mask(attention_mask):
    """Zero-pad a sequence of 2-D masks to a common shape and stack them.

    Each mask is padded on the right (columns, ``shape[1]``) and at the bottom
    (rows, ``shape[0]``) up to the largest width and height found in the input.

    Args:
        attention_mask: Sequence of 2-D tensors.

    Returns:
        Tensor of shape ``(len(attention_mask), max_rows, max_cols)``.
    """
    widest = max([m.shape[1] for m in attention_mask])
    tallest = max([m.shape[0] for m in attention_mask])

    padded = []
    for m in attention_mask:
        extra_cols = widest - m.shape[1]
        extra_rows = tallest - m.shape[0]
        # F.pad order: (left, right) for the last dim, then (top, bottom) for the one before.
        padded.append(F.pad(m, (0, extra_cols, 0, extra_rows)))

    return torch.stack(padded)

def pad_attention_mask_4d(attention_mask):
    """Zero-pad masks along their last two dimensions to a common shape and stack them.

    Each mask is padded on the right of its last dimension and at the bottom of
    its second-to-last dimension, up to the largest sizes found in the input.
    Leading dimensions are left untouched and must already agree for the stack
    to succeed.

    Args:
        attention_mask: Sequence of tensors with at least two dimensions.

    Returns:
        The padded tensors stacked along a new leading dimension.
    """
    max_rows = max(m.shape[-2] for m in attention_mask)
    max_cols = max(m.shape[-1] for m in attention_mask)

    # F.pad reads its tuple starting from the LAST dimension, so the column
    # shortfall must come first and the row shortfall second. Mixing the two
    # axes only happens to work for square masks and produces mismatched
    # shapes (and a failing torch.stack) for rectangular ones.
    padded = [
        F.pad(m, (0, max_cols - m.shape[-1], 0, max_rows - m.shape[-2]))
        for m in attention_mask
    ]
    return torch.stack(padded)

def block_diagonal_concat(*masks, dtype=torch.bfloat16):
    """Place square masks along the diagonal of a zero-filled square matrix.

    Args:
        *masks: Tensors written as ``size(0) x size(0)`` blocks, in order.
        dtype: dtype of the returned matrix.

    Returns:
        Tensor of shape ``(total, total)`` where ``total`` is the sum of every
        mask's first dimension; entries outside the blocks are 0.
    """
    sizes = [m.size(0) for m in masks]
    side = sum(sizes)
    canvas = torch.zeros(side, side, dtype=dtype)

    start = 0
    for block, n in zip(masks, sizes):
        stop = start + n
        canvas[start:stop, start:stop] = block
        start = stop

    return canvas

def block_diagonal_concat_4d(*masks, dtype=torch.bfloat16):
    """Block-diagonal concatenation of 3-D masks that share a leading dimension.

    The leading size is taken from the first mask; each mask is written as a
    ``size(1) x size(1)`` block across every leading index.

    Args:
        *masks: Tensors whose second dimension gives the block size.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(masks[0].shape[0], total, total)`` with zeros outside
        the diagonal blocks.
    """
    sizes = [m.size(1) for m in masks]
    side = sum(sizes)
    leading = masks[0].shape[0]
    canvas = torch.zeros(leading, side, side, dtype=dtype)

    start = 0
    for block, n in zip(masks, sizes):
        stop = start + n
        canvas[:, start:stop, start:stop] = block
        start = stop

    return canvas

def block_diagonal_concat_cross(*masks, dtype=torch.bfloat16):
    """Block-diagonal concatenation of rectangular 2-D masks.

    Each mask is placed so that its top-left corner sits at the bottom-right
    corner of the previous one; everything outside the blocks stays 0.

    Args:
        *masks: 2-D tensors of arbitrary (rows, cols) shape.
        dtype: dtype of the returned matrix.

    Returns:
        Tensor of shape ``(sum of rows, sum of cols)``.
    """
    height = sum(m.size(0) for m in masks)
    width = sum(m.size(1) for m in masks)
    canvas = torch.zeros((height, width), dtype=dtype)

    top = 0
    left = 0
    for block in masks:
        n_rows, n_cols = block.size()
        bottom = top + n_rows
        right = left + n_cols
        canvas[top:bottom, left:right] = block
        top, left = bottom, right

    return canvas

def multipack(input_ids, labels, lengths):
    """Build the attention masks for one packed sample.

    For every entry of ``lengths`` (index 0 = left length, index 1 = right
    length) three blocks are created: an all-ones ``left x left`` block, an
    all-ones ``right x left`` block, and a lower-triangular ``right x right``
    block. The blocks are then merged block-diagonally across all entries.

    Args:
        input_ids: Passed through unchanged under the key ``'input_ids'``.
        labels: Passed through unchanged under the key ``'labels'``.
        lengths: Iterable of indexable pairs as described above.

    Returns:
        Dict with keys ``input_ids``, ``labels``, ``attention_mask``,
        ``encoder_attention_mask``, ``decoder_attention_mask``,
        ``encoder_lengths`` and ``decoder_lengths``.
    """
    self_blocks = []
    cross_blocks = []
    causal_blocks = []
    enc_pairs = []
    dec_pairs = []

    for pair in lengths:
        src_len, tgt_len = pair[0], pair[1]

        self_blocks.append(torch.ones(src_len, src_len))
        cross_blocks.append(torch.ones(tgt_len, src_len))
        # Lower-triangular: each right-side position sees itself and earlier ones only.
        causal_blocks.append(torch.tril(torch.ones(tgt_len, tgt_len)))

        enc_pairs.append([src_len, src_len])
        dec_pairs.append([tgt_len, tgt_len])

    return {
        'input_ids': input_ids,
        'labels': labels,
        'attention_mask': block_diagonal_concat(*self_blocks),
        'encoder_attention_mask': block_diagonal_concat_cross(*cross_blocks),
        'decoder_attention_mask': block_diagonal_concat(*causal_blocks),
        'encoder_lengths': torch.tensor(enc_pairs),
        'decoder_lengths': torch.tensor(dec_pairs),
    }

def collator(batch, pad_token_id = 1, label_pad = -100):
    """Collate packed samples into one batch of tensors.

    ``input_ids`` and ``labels`` (Python lists) are right-padded to the longest
    entry in the batch with ``pad_token_id`` and ``label_pad`` respectively and
    converted to int64 tensors. The three mask entries are zero-padded to a
    common shape and stacked with ``pad_attention_mask``. The two length
    tensors are concatenated along their first dimension.

    Args:
        batch: Sequence of dicts with the keys produced by ``multipack``.
        pad_token_id: Fill value for ``input_ids``.
        label_pad: Fill value for ``labels``.

    Returns:
        Dict with the same seven keys, each holding a batched tensor.
    """
    def _right_pad(key, fill):
        # Pad every list under `key` to the longest one, then tensorise.
        longest = max(len(item[key]) for item in batch)
        rows = [
            item[key] + [fill] * (longest - len(item[key]))
            for item in batch
        ]
        return torch.tensor(rows, dtype = torch.int64)

    out = {}
    out['input_ids'] = _right_pad('input_ids', pad_token_id)
    out['labels'] = _right_pad('labels', label_pad)

    for key in ('attention_mask', 'encoder_attention_mask', 'decoder_attention_mask'):
        out[key] = pad_attention_mask([item[key] for item in batch])

    for key in ('encoder_lengths', 'decoder_lengths'):
        out[key] = torch.concat([item[key] for item in batch])

    return out

In [10]:
# Open the packed dataset from local disk with streaming's LocalDataset.
# NOTE(review): absolute, machine-specific path; the later `du` and upload cells refer to a
# relative 'packing' folder, so a single configurable path would keep them in sync.
dataset = LocalDataset('/home/husein/mesolitica/t5-sdpa-multipack/packing')

In [11]:
len(dataset)

2390857

In [12]:
# Each sample's 'data' field is a JSON string; the decoded dict is unpacked as
# keyword arguments to multipack (so its keys must be input_ids, labels, lengths).
d = json.loads(dataset[0]['data'])
b = multipack(**d)
# Second sample, packed the same way; `d` is reused for the decoded dict.
d = json.loads(dataset[1]['data'])
b1 = multipack(**d)

In [14]:
%%time

# Collate a batch containing only the first packed sample.
# NOTE(review): despite its name, `input_ids` holds the whole batch dict — later cells
# index it by key and unpack it into the model call.
input_ids = collator([b])

CPU times: user 206 ms, sys: 21.9 ms, total: 228 ms
Wall time: 22.3 ms


In [15]:
input_ids['input_ids'].shape

torch.Size([1, 2001])

In [16]:
# Move the model to the GPU; binding the result to `_` keeps the cell from echoing the module repr.
_ = model.cuda()

In [17]:
# Move every tensor in the batch dict to the GPU, overwriting the entries in place.
for k in input_ids.keys():
    input_ids[k] = input_ids[k].cuda()

In [19]:
# Run the model on the collated batch, passing every dict entry as a keyword argument.
# NOTE(review): keys such as encoder_lengths / decoder_lengths are presumably accepted only by
# the locally installed (editable) transformers build — confirm.
o = model(**input_ids)

position_bias torch.Size([1, 8, 2001, 2001])
position_bias torch.Size([1, 8, 1756, 1756])


In [None]:
from huggingface_hub import create_repo, delete_repo
from huggingface_hub.utils import RepositoryNotFoundError

# Start from a clean dataset repo: drop any existing one, then recreate it as private.
# Only "repo does not exist" is ignored; authentication or network failures must
# surface here instead of being silently swallowed and showing up later as a
# confusing create/upload error.
try:
    delete_repo(repo_id="mesolitica/malaysian-translation-v2-multipack-2048", repo_type="dataset")
except RepositoryNotFoundError:
    pass
create_repo("mesolitica/malaysian-translation-v2-multipack-2048", repo_type="dataset", private = True)

In [None]:
!du -hs packing

In [None]:
from huggingface_hub import HfApi
api = HfApi()

# Upload the contents of the local 'packing' folder (relative to the notebook's working
# directory) to the dataset repo on the Hugging Face Hub.
api.upload_folder(
    folder_path="packing",
    repo_id="mesolitica/malaysian-translation-v2-multipack-2048",
    repo_type="dataset",
)