In [1]:
# !pip3.10 install --pre torch==2.5.0.dev20240912+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121

In [2]:
# !pip3.10 install torch -U

In [3]:
# !pip3.10 uninstall transformers -y

In [4]:
# !pip3.10 install -e . --no-deps

In [5]:
import json
import torch
import torch.nn.functional as F
import math
from glob import glob

In [6]:
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained('mesolitica/nanot5-small-malaysian-cased', local_files_only = True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
model = T5ForConditionalGeneration.from_pretrained('mesolitica/nanot5-small-malaysian-cased',local_files_only = True).to(torch.bfloat16)

In [8]:
# Stand-alone copies of the relative-attention bias tables taken from the first
# encoder / decoder self-attention layer of the loaded T5 model.
encoder_emb = torch.nn.Embedding(
    num_embeddings = model.config.relative_attention_num_buckets,
    embedding_dim = model.config.num_heads,
)
decoder_emb = torch.nn.Embedding(
    num_embeddings = model.config.relative_attention_num_buckets,
    embedding_dim = model.config.num_heads,
)

encoder_emb.load_state_dict(
    model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.state_dict()
)
decoder_emb.load_state_dict(
    model.decoder.block[0].layer[0].SelfAttention.relative_attention_bias.state_dict()
)

<All keys matched successfully>

In [9]:
# Collect every pre-split translation shard. glob() returns paths in
# filesystem order, which is not stable across machines/runs, so sort the
# combined list to make the shard -> worker index assignment reproducible.
files = glob('/home/husein/ssd4/translation/post-translation.json*.splitted')
files.extend(glob('/home/husein/ssd4/translation/post-translation-part2.json*.splitted'))
files.extend(glob('/home/husein/ssd4/translation/post-translation-part5.json*.splitted'))
files = sorted(files)
len(files)

61

In [10]:
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
import streaming
import numpy as np
from tqdm import tqdm
from glob import glob
import os
import json

class UInt32(Encoding):
    """MDS column encoding that stores an array as raw ``uint32`` bytes."""

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` (array-like of non-negative ints) to uint32 bytes."""
        # Coerce to uint32 first: decode() always reads 4-byte unsigned values,
        # so writing e.g. an int64 array verbatim would be silently mis-decoded.
        return np.asarray(obj, dtype=np.uint32).tobytes()

    def decode(self, data: bytes):
        """Rebuild a 1-D ``uint32`` array from the bytes produced by ``encode``."""
        return np.frombuffer(data, np.uint32)

# Register the custom encoding under the name 'uint32' in streaming's encoding table.
_encodings['uint32'] = UInt32

# Schema of every MDS record written by this notebook: a single string column
# named 'data' (the cells below store a JSON document in it).
columns = {
    'data': 'str',
}
# Hash algorithms passed to MDSWriter.
hashes = 'sha1', 'xxh64'

In [11]:
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    """
    Adapted from Mesh Tensorflow:
    https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

    Translate relative position to a bucket number for relative attention. The relative position is defined as
    memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
    position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
    small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
    positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
    This should allow for more graceful generalization to longer sequences than the model has been trained on

    Args:
        relative_position: an int32 Tensor
        bidirectional: a boolean - whether the attention is bidirectional
        num_buckets: an integer
        max_distance: an integer

    Returns:
        a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
    """
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    # now relative_position is in the range [0, inf)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )

    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets

def compute_bias(
    query_length, 
    key_length,
    relative_attention_bias,
    bidirectional = True, 
    num_buckets = 32, 
    max_distance = 128, 
    device=None,
):
    """
    Compute the binned relative position bias for a query/key grid.

    Args:
        query_length: number of query positions.
        key_length: number of key positions.
        relative_attention_bias: embedding module called with the bucket ids;
            its ``weight.device`` is used when ``device`` is None.
        bidirectional, num_buckets, max_distance: forwarded to
            ``_relative_position_bucket``.
        device: device on which the position indices are created.

    Returns:
        Tensor of shape (1, num_heads, query_length, key_length).
    """
    target_device = relative_attention_bias.weight.device if device is None else device

    query_idx = torch.arange(query_length, dtype=torch.long, device=target_device).unsqueeze(1)
    key_idx = torch.arange(key_length, dtype=torch.long, device=target_device).unsqueeze(0)

    # (query_length, key_length) grid of memory_position - query_position.
    buckets = _relative_position_bucket(
        key_idx - query_idx,
        bidirectional=bidirectional,
        num_buckets=num_buckets,
        max_distance=max_distance,
    )

    # (query_length, key_length, num_heads) -> (1, num_heads, query_length, key_length)
    per_head = relative_attention_bias(buckets)
    return per_head.permute(2, 0, 1).unsqueeze(0)

def block_diagonal_concat_inverted(*masks, dtype=torch.bfloat16):
    """
    Pack square masks along the diagonal and convert the result to an additive mask.

    Entries equal to 1 in the packed matrix become 0; every other entry becomes
    the minimum value of ``dtype``. The result has a leading axis of size 1.
    """
    sizes = [m.size(0) for m in masks]
    side = sum(sizes)
    packed = torch.zeros(side, side, dtype=dtype)

    offset = 0
    for m, n in zip(masks, sizes):
        stop = offset + n
        packed[offset:stop, offset:stop] = m
        offset = stop

    if dtype.is_floating_point:
        lowest = torch.finfo(dtype).min
    else:
        lowest = torch.iinfo(dtype).min

    additive = torch.where(packed == 1, torch.tensor(0, dtype=dtype), lowest)
    return additive.unsqueeze(0)

def pad_attention_mask(attention_mask, maxlen = 2048):
    """
    Zero-pad each 2-D mask on the right and bottom to (maxlen, maxlen) and stack them.

    Returns a tensor of shape (len(attention_mask), maxlen, maxlen).
    """
    padded = []
    for mask in attention_mask:
        n_rows, n_cols = mask.shape[0], mask.shape[1]
        # F.pad takes (left, right, top, bottom) for the last two axes.
        padded.append(F.pad(mask, (0, maxlen - n_cols, 0, maxlen - n_rows)))
    return torch.stack(padded)

def pad_attention_mask_4d(attention_mask, maxlen = 2048):
    """
    Zero-pad the last two axes of each tensor to (maxlen, maxlen) and stack them.

    Args:
        attention_mask: sequence of tensors whose last two axes are
            (rows, cols); any leading axes are left untouched.
        maxlen: target size for both trailing axes.

    Returns:
        Tensor of shape (len(attention_mask), ..., maxlen, maxlen).
    """
    padded = []
    for mask in attention_mask:
        # F.pad's first (left, right) pair applies to the LAST axis and the
        # second (top, bottom) pair to the second-to-last axis, so the right
        # pad must be derived from shape[-1] and the bottom pad from shape[-2].
        pad_right = maxlen - mask.shape[-1]
        pad_bottom = maxlen - mask.shape[-2]
        padded.append(F.pad(mask, (0, pad_right, 0, pad_bottom)))
    return torch.stack(padded)

def block_diagonal_concat(*masks, dtype=torch.bfloat16):
    """
    Place square 2-D masks one after another along the diagonal of a zero matrix.

    The side of each block is taken from ``mask.size(0)``; the result is a
    (total, total) tensor of ``dtype``.
    """
    sizes = [m.size(0) for m in masks]
    side = sum(sizes)
    packed = torch.zeros(side, side, dtype=dtype)

    offset = 0
    for m, n in zip(masks, sizes):
        stop = offset + n
        packed[offset:stop, offset:stop] = m
        offset = stop

    return packed

def block_diagonal_concat_4d(*masks, dtype=torch.bfloat16):
    """
    Block-diagonal packing for tensors with one leading axis.

    Each mask contributes a (size, size) block, with ``size = mask.size(1)``,
    written for every index of the leading axis. The leading axis length is
    taken from the first mask. Returns a (leading, total, total) tensor of ``dtype``.
    """
    sizes = [m.size(1) for m in masks]
    side = sum(sizes)
    leading = masks[0].shape[0]
    packed = torch.zeros(leading, side, side, dtype=dtype)

    offset = 0
    for m, n in zip(masks, sizes):
        stop = offset + n
        packed[:, offset:stop, offset:stop] = m
        offset = stop

    return packed

def block_diagonal_concat_cross(*masks, dtype=torch.bfloat16):
    """
    Block-diagonal packing of rectangular 2-D masks.

    Each (rows, cols) mask is placed so that its top-left corner sits at the
    running (row, column) offset. The result is a
    (sum of rows, sum of cols) tensor of ``dtype``, zero outside the blocks.
    """
    height = sum(m.size(0) for m in masks)
    width = sum(m.size(1) for m in masks)
    packed = torch.zeros((height, width), dtype=dtype)

    row_at = 0
    col_at = 0
    for m in masks:
        n_rows, n_cols = m.shape
        packed[row_at:row_at + n_rows, col_at:col_at + n_cols] = m
        row_at, col_at = row_at + n_rows, col_at + n_cols

    return packed

def collator(batch, pad_token_id = 1, label_pad = -100, maxlen = 2048):
    """
    Collate packed samples into fixed-size batch tensors.

    Args:
        batch: list of dicts. Each dict provides list-valued ``input_ids`` and
            ``labels``, plus tensor-valued ``position_bias``,
            ``decoder_position_bias``, ``attention_mask``,
            ``encoder_attention_mask`` and ``decoder_attention_mask``.
            NOTE(review): the producer of these dicts is not visible here; the
            masks are presumably floating-point 2-D tensors with 1 = attend
            (``torch.finfo`` is applied to their dtype) — confirm against it.
        pad_token_id: fill value for ``input_ids``.
        label_pad: fill value for ``labels``.
        maxlen: target length for every padded axis. It is forwarded to the
            mask / bias padding helpers so that all outputs agree on the same
            size (they previously always used their own default of 2048).

    Returns:
        dict with int64 ``input_ids`` and ``labels``, the three padded masks,
        and ``position_bias`` / ``decoder_position_bias`` with the additive
        form of ``attention_mask`` / ``decoder_attention_mask`` already added.
    """
    def _pad_ids(key, fill_value):
        # Right-pad every id list to maxlen and build one int64 tensor.
        rows = [b[key] + [fill_value] * (maxlen - len(b[key])) for b in batch]
        return torch.tensor(rows, dtype = torch.int64)

    def _to_additive(mask):
        # Insert a broadcast axis at dim 1, then map 1 -> 0 and 0 -> the most
        # negative finite value of the mask dtype.
        return (1.0 - mask[:, None, :, :]) * torch.finfo(mask.dtype).min

    results = {}
    results['input_ids'] = _pad_ids('input_ids', pad_token_id)
    results['labels'] = _pad_ids('labels', label_pad)

    results['position_bias'] = pad_attention_mask_4d(
        [b['position_bias'] for b in batch], maxlen = maxlen)
    results['decoder_position_bias'] = pad_attention_mask_4d(
        [b['decoder_position_bias'] for b in batch], maxlen = maxlen)

    for key in ('attention_mask', 'encoder_attention_mask', 'decoder_attention_mask'):
        results[key] = pad_attention_mask([b[key] for b in batch], maxlen = maxlen)

    # Fold the masks into the position biases.
    results['position_bias'] = (
        results['position_bias'] + _to_additive(results['attention_mask'])
    )
    results['decoder_position_bias'] = (
        results['decoder_position_bias'] + _to_additive(results['decoder_attention_mask'])
    )

    return results

In [12]:
!rm -rf tokenized-post
!mkdir tokenized-post

In [13]:
def loop(files, block_size = 2048):
    """
    Tokenize translation pairs and greedily pack them into MDS records.

    Args:
        files: a ``(paths, index)`` tuple. ``paths`` is a list of JSON-lines
            files whose rows carry a ``translation`` object with ``prefix``,
            ``src`` and ``tgt`` strings; ``index`` names the output shard
            directory ``tokenized-post/tokenized-{index}``.
        block_size: token budget of a pack. A new pack is started as soon as
            adding the next pair would bring either side to ``block_size`` or
            more; pairs longer than ``block_size`` on either side are skipped.

    Each written record has one ``data`` column holding a JSON document with
    the concatenated ``input_ids``, the concatenated ``labels`` and
    ``lengths``, a list of ``(len(input), len(label))`` per packed pair.
    Rows that fail to parse or tokenize are skipped.
    """
    # Local import: shutil is not among the notebook-level imports.
    import shutil

    files, index = files

    out_root = f'tokenized-post/tokenized-{index}'
    # Start from a clean output directory (replaces the shell `rm -rf`).
    shutil.rmtree(out_root, ignore_errors=True)

    count_input_ids = 0
    count_labels = 0
    input_ids = []
    labels = []
    lengths = []

    def write_pack(out, input_ids, labels, lengths):
        # Emit the current pack, unless it is empty.
        if len(input_ids) and len(labels):
            d = {
                'input_ids': input_ids,
                'labels': labels,
                'lengths': lengths
            }
            out.write({
                'data': json.dumps(d),
            })

    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:
        for f in files:
            with open(f) as fopen:
                for l in tqdm(fopen):
                    try:
                        l = json.loads(l)['translation']

                        # Word-count sanity filter between source and target.
                        # (The target count used to be computed from 'src',
                        # which made the ratio checks below always pass.)
                        src_words = len(l['src'].split())
                        tgt_words = len(l['tgt'].split())

                        if src_words < 1 or tgt_words < 1:
                            continue

                        if (src_words / tgt_words) < 0.5:
                            continue

                        if (tgt_words / src_words) < 0.5:
                            continue

                        left = l['prefix'] + l['src'] + tokenizer.eos_token
                        right = l['tgt'] + tokenizer.eos_token

                        left = tokenizer(left)['input_ids']
                        right = tokenizer(right)['input_ids']

                        if len(left) > block_size or len(right) > block_size:
                            continue
                    except Exception:
                        # Best-effort: skip malformed rows, but let
                        # KeyboardInterrupt / SystemExit propagate.
                        continue

                    if count_input_ids + len(left) >= block_size or count_labels + len(right) >= block_size:
                        # The pair does not fit: flush the current pack (if
                        # any) and ALWAYS start the next pack with this pair,
                        # so it is not lost when the buffers are still empty.
                        write_pack(out, input_ids, labels, lengths)
                        count_input_ids = len(left)
                        count_labels = len(right)
                        input_ids = left
                        labels = right
                        lengths = [(len(left), len(right))]
                    else:
                        count_input_ids += len(left)
                        count_labels += len(right)
                        input_ids.extend(left)
                        labels.extend(right)
                        lengths.append((len(left), len(right)))

        # Flush whatever is left after the last file.
        write_pack(out, input_ids, labels, lengths)

In [14]:
# !wget https://gist.githubusercontent.com/huseinzol05/98974ae8c6c7a65d4bc0af9f5003786a/raw/2e06e71ef7349a57bc58cc9913ae6bae1f9f8447/mp.py

In [15]:
# loop((files[:1], 0))

In [16]:
import mp

In [17]:
len(files)

61

In [18]:
# loop((files[-1:], 0))

In [None]:
mp.multiprocessing(files, loop, cores = 10, returned = False)

201170it [00:35, 5725.23it/s]
500000it [01:07, 7406.28it/s]]
500000it [01:25, 5868.10it/s] 
500000it [02:20, 3557.34it/s]
500000it [01:47, 4629.94it/s]
353508it [00:58, 6413.03it/s]
500000it [01:20, 6221.59it/s]]
213365it [00:53, 808.09it/s]] 
500000it [01:18, 6338.11it/s] 
411036it [04:03, 1449.84it/s] 
500000it [01:04, 7703.55it/s] 
258234it [02:02, 2116.26it/s]
500000it [04:39, 1789.88it/s] 
500000it [02:18, 3612.69it/s]]
500000it [04:43, 1762.79it/s] 
39325it [00:06, 5818.09it/s]] 
500000it [04:52, 1709.24it/s]
500000it [01:29, 5617.53it/s] 
500000it [01:07, 7402.54it/s] 
500000it [01:23, 5992.95it/s] 
500000it [01:40, 4965.33it/s]]
160984it [01:38, 1641.81it/s]]
500000it [01:05, 7668.47it/s] 
500000it [07:00, 1189.20it/s] 
500000it [02:51, 2919.98it/s] 
500000it [01:09, 7191.82it/s] 
500000it [00:55, 9075.74it/s] 
500000it [01:47, 4663.53it/s]
500000it [02:08, 3879.38it/s]]
500000it [04:42, 1769.56it/s] 
500000it [02:01, 4121.69it/s] 
500000it [00:51, 9653.03it/s] 
500000it [09:25

In [None]:
# List the per-worker output directories, ordered by the integer suffix after
# the last '-' in the path (so tokenized-10 sorts after tokenized-2).
folders = sorted(glob('tokenized-post/tokenized-*'), key = lambda x: int(x.split('-')[-1]))
folders

In [None]:
!rm -rf packing-post

In [None]:
# Merge every per-worker MDS directory into a single dataset at 'packing-post',
# copying the records one by one.
with MDSWriter(out='packing-post', columns=columns, compression=None, hashes=hashes) as out:
    for f in folders:
        try:
            dataset = LocalDataset(local=f)
            for i in tqdm(range(len(dataset))):
                out.write(dataset[i])
        except Exception as e:
            # Best-effort: an error while opening or reading a folder is printed
            # and the loop moves on to the next folder. Records already copied
            # from that folder stay in the output.
            print(e)
            pass

In [None]:
dataset = LocalDataset('packing-post')

In [None]:
len(dataset) * 2048

In [None]:
!du -hs packing-post

In [None]:
# d = json.loads(dataset[0]['data'])
# b = multipack(**d)
# d = json.loads(dataset[1]['data'])
# b1 = multipack(**d)

In [None]:
# %%time

# input_ids = collator([b, b1])

In [None]:
model.config.vocab_size

In [None]:
# Sanity check: report every record that contains a token id outside the
# model's vocabulary. Valid ids are 0 .. vocab_size - 1, so an id EQUAL to
# vocab_size is already out of range (hence >=, not >).
for i in tqdm(range(len(dataset))):
    d = json.loads(dataset[i]['data'])
    for key in ('input_ids', 'labels'):
        c = (np.array(d[key]) >= model.config.vocab_size).sum()
        if c > 0:
            print(i, key)

In [None]:
# Number of out-of-vocabulary label ids in the last inspected record
# (ids equal to vocab_size are out of range too, hence >=).
(np.array(d['labels']) >= model.config.vocab_size).sum()

In [None]:
input_ids['input_ids']

In [None]:
input_ids['decoder_attention_mask'].shape

In [None]:
from huggingface_hub import create_repo, delete_repo

# Recreate the dataset repo from scratch. Deleting is best-effort (for example
# the repo may not exist yet), but only ordinary errors are tolerated:
# KeyboardInterrupt / SystemExit still propagate, and the reason is reported
# instead of being silently dropped.
try:
    delete_repo(repo_id="mesolitica/malaysian-translation-v2-multipack-2048-post", repo_type="dataset")
except Exception as e:
    print(f'delete_repo skipped: {e!r}')
create_repo("mesolitica/malaysian-translation-v2-multipack-2048-post", repo_type="dataset", private = True)

In [None]:
!du -hs packing-post

In [None]:
(len(dataset) * 2048) / 1e9

In [None]:
from huggingface_hub import HfApi
api = HfApi()

# Upload the merged 'packing-post' directory to the dataset repo on the Hub.
api.upload_folder(
    folder_path="packing-post",
    repo_id="mesolitica/malaysian-translation-v2-multipack-2048-post",
    repo_type="dataset",
)

In [None]:
len(b['labels'])