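# Preprocessing GLUE for fine-tuning

Loads a GLUE task (CoLA), tokenizes it with tiktoken, and writes each split's token ids to a flat `{split}.bin` file.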

In [8]:
import itertools
import operator
import os
import pickle

import numpy as np
import tiktoken
import torch
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [2]:
def get_dataset(task):
    """Load a GLUE task from the Hugging Face hub and count its labels.

    Args:
        task: GLUE configuration name, e.g. 'cola'.

    Returns:
        (dataset, num_labels): the object returned by
        ``load_dataset("nyu-mll/glue", task)`` and the number of classes in the
        train split's ``label`` feature. A label feature without a ``names``
        list (e.g. the float regression target of STS-B) counts as a single
        output.
    """
    dataset = load_dataset("nyu-mll/glue", task)
    label_feature = dataset['train'].features['label']
    # Only class-label features carry `.names`; a plain value feature has none,
    # and indexing `.names` unconditionally would raise AttributeError.
    label_names = getattr(label_feature, 'names', None)
    num_labels = len(label_names) if label_names is not None else 1

    return dataset, num_labels

In [3]:
# Load CoLA; num_labels is derived from the train split's label feature.
dataset, num_labels = get_dataset(task='cola')

In [148]:
def least_power_of_two(n):
    """Return the smallest power of two that is >= ``n``.

    Args:
        n: an integer (a Python int or a NumPy integer scalar).

    Returns:
        int: ``2**k`` for the least ``k >= 0`` with ``2**k >= n``; any
        ``n <= 1`` yields 1.

    Raises:
        TypeError: if ``n`` is not an integer (e.g. a float).
    """
    # operator.index accepts NumPy integers, which have no .bit_length().
    n = operator.index(n)
    if n <= 1:
        # (n - 1).bit_length() is 1 for n == 0 and grows for negatives, which
        # would wrongly give 2 or more; 2**0 == 1 already satisfies 1 >= n.
        return 1
    return 1 << (n - 1).bit_length()

In [11]:
def tokenize_batch(sents, enc=None, pad_id=None):
    """Tokenize a batch of sentences and right-pad them to a common length.

    Args:
        sents: sequence of strings to encode.
        enc: encoder exposing ``encode_batch(texts, allowed_special=...)``
            (tiktoken's ``Encoding`` API). Defaults to the global ``tokenizer``.
            NOTE(review): no visible cell defines ``tokenizer``; confirm it is
            created before this is called without ``enc``.
        pad_id: token id written into padding positions. Defaults to the global
            ``pad_token`` (also not defined in any visible cell).

    Returns:
        (padded, attention_mask): two integer arrays of shape
        ``(len(sents), longest_sequence)``. ``attention_mask`` is 1 on real
        tokens and 0 on padding.
    """
    if enc is None:
        enc = tokenizer
    if pad_id is None:
        pad_id = pad_token

    token_lists = enc.encode_batch(sents, allowed_special='all')
    lengths = np.array([len(ids) for ids in token_lists], dtype=int)
    max_len = int(lengths.max()) if lengths.size else 0

    padded = np.full((len(token_lists), max_len), pad_id, dtype=int)
    for row, ids in enumerate(token_lists):
        padded[row, :len(ids)] = ids

    # Derive the mask from the true lengths instead of `padded != pad_id`.
    # With allowed_special='all', a real special token can share the pad id
    # and would otherwise be masked out.
    attention_mask = (np.arange(max_len)[None, :] < lengths[:, None]).astype(int)

    return padded, attention_mask

In [22]:
# Toy 5x2 integer matrix for trying out row shuffling in the next cells.
x = torch.randint(low=0, high=10, size=(5, 2))

In [30]:
# A random permutation of x's row indices.
torch.randperm(len(x))

tensor([1, 4, 0, 3, 2])

In [26]:
# Rows of x reordered by a fresh random permutation.
x[torch.randperm(x.size(0))]

tensor([[8, 8],
        [0, 0],
        [8, 7],
        [7, 2],
        [7, 2]])

In [13]:
sample_sents = ['the dog', 'went', 'to school', 'today in the store']
padded, mask = tokenize_batch(sample_sents)

In [172]:
batch_size = 1
n_batches  = 6  # NOTE(review): not read by any visible cell
# A dedicated seeded generator makes the shuffle order reproducible across
# kernel restarts without touching the global torch RNG (seed 0 is arbitrary).
loader_rng = torch.Generator().manual_seed(0)
dataloader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True, generator=loader_rng)

In [176]:
# Each iter(dataloader) starts a fresh epoch. With shuffle=True every epoch
# draws a new permutation, so two epochs normally begin with different samples.
for _ in range(2):
    batch = next(iter(dataloader))
    print(batch)

{'sentence': ['That you will marry any particular student is not certain.'], 'label': tensor([1]), 'idx': tensor([1978])}
{'sentence': ['Stephen seemed to be intelligent.'], 'label': tensor([1]), 'idx': tensor([4276])}


In [43]:
# Tokenize every split. Each sentence becomes token ids terminated by the
# end-of-text token, plus its length; the packing cell below reads these as
# 'ids' and 'len'. `process` is defined here because no other cell defines it,
# so the map call would otherwise fail with NameError on a fresh kernel.
enc = tiktoken.get_encoding("gpt2")

def process(example):
    """Encode one example's 'sentence' into {'ids': [...], 'len': n}."""
    ids = enc.encode_ordinary(example['sentence'])
    ids.append(enc.eot_token)  # end-of-text separates sentences once packed
    return {'ids': ids, 'len': len(ids)}

tokenized = dataset.map(
    process,
    remove_columns=['sentence'],
    desc="tokenizing the splits"
)

tokenizing the splits:   0%|          | 0/8551 [00:00<?, ? examples/s]

tokenizing the splits:   0%|          | 0/1043 [00:00<?, ? examples/s]

tokenizing the splits:   0%|          | 0/1063 [00:00<?, ? examples/s]

In [44]:
# Pack each split's token ids into one flat uint16 memmap named {split}.bin.
# __file__ is undefined inside a notebook kernel, so fall back to the working
# directory there (and keep the script-directory behaviour when it exists).
out_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in globals() else os.getcwd()
dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)

for split, dset in tokenized.items():
    arr_len = np.sum(dset['len'], dtype=np.uint64)
    filename = os.path.join(out_dir, f'{split}.bin')
    if arr_len == 0:
        # np.memmap cannot map a zero-length file; nothing to write.
        continue
    arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
    # Write in contiguous shards for speed, but never ask for more shards than
    # rows: an empty shard would make np.concatenate raise.
    total_batches = min(1024, len(dset))

    idx = 0
    for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
        # Batch together samples for faster write
        batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
        arr_batch = np.concatenate(batch['ids'])
        # Write into mmap
        arr[idx : idx + len(arr_batch)] = arr_batch
        idx += len(arr_batch)
    arr.flush()
    assert idx == arr_len, f'{split}: wrote {idx} tokens, expected {arr_len}'

{'label': 1,
 'idx': 0,
 'ids': [5122,
  2460,
  1839,
  470,
  2822,
  428,
  3781,
  11,
  1309,
  3436,
  262,
  1306,
  530,
  356,
  18077,
  13,
  50256],
 'len': 17}