In [None]:
import tiktoken
import random
import tqdm
import numpy as np
from pathlib import Path
import multiprocessing as mp

from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

In [None]:
# Load SODA and drop its 'test' split; only the remaining splits are tokenised
# and written out by the cells below.
dataset_full = load_dataset("allenai/soda")
dataset_full.pop('test')

In [None]:
gpt2_base_tokeniser = tiktoken.get_encoding("gpt2")

# Reuse GPT-2's BPE (split regex + merges) unchanged and register two extra
# special tokens with ids directly after the base vocabulary:
#   <|sep|>  separates the narrative, the scene description and the dialogue
#   <|turn|> terminates each dialogue line
soda_special_tokens = dict(gpt2_base_tokeniser._special_tokens)
soda_special_tokens["<|sep|>"] = 50257
soda_special_tokens["<|turn|>"] = 50258

tokeniser = tiktoken.Encoding(
    name="gpt2_soda",
    pat_str=gpt2_base_tokeniser._pat_str,
    mergeable_ranks=gpt2_base_tokeniser._mergeable_ranks,
    special_tokens=soda_special_tokens,
)


def encode(s):
    """Tokenise text ``s``; any registered special-token string in it is encoded as that special token."""
    return tokeniser.encode(s, allowed_special=tokeniser.special_tokens_set)


def decode(s):
    """Convert a sequence of token ids back into text."""
    return tokeniser.decode(s)

In [None]:
def process(item):
    speakers = [e.title() for e in item['speakers'][0]]
    unique_speakers = list(set(speakers))
    dialog = item['dialogue'][0]

    context_str = (
        f"{item['narrative'][0]}<|sep|>The following is a conversation in the scene between "
        f"{unique_speakers[0]} and {unique_speakers[1]}<|sep|>"
    )
    
    x_tokens = []
    y_tokens = []
    first_speaker = speakers[0]
    last_x = context_str
    for i, (speaker, dialog_line) in enumerate(zip(speakers, dialog)):
        if speaker == first_speaker:
            last_x += f"{dialog_line}<|turn|>"
            continue
        
        x = last_x
        y = f"{dialog_line}"
        
        last_x += f"{dialog_line}<|turn|>"
        
        x_tokens.append(encode(x))
        y_tokens.append(encode(y))

    return {'x': x_tokens, 'y': y_tokens, 'x_size': [len(e) for e in x_tokens], 'y_size': [len(e) for e in y_tokens]}

In [None]:
# `process` emits one row per reply, i.e. a different number of rows than it
# receives, so every original column has to be removed: datasets rejects a batch
# whose returned columns differ in length. Derive the names from the data rather
# than hard-coding them, so a schema change cannot leave a stale column behind.
# batch_size=1 keeps one conversation per call.
dataset_tokenised = dataset_full.map(
    process,
    num_proc=mp.cpu_count(),
    batched=True,
    batch_size=1,
    remove_columns=dataset_full['train'].column_names,
)

In [None]:
# Compare row counts before/after tokenisation: each conversation expands into
# one row per reply from a speaker other than the one who opened it.
for dataset_dict in (dataset_tokenised, dataset_full):
    print(dataset_dict)

In [None]:
data_dir = Path('~/nanoGPT/data/soda/').expanduser()
# np.memmap(mode='w+') creates the .bin files but not their parent directory.
data_dir.mkdir(parents=True, exist_ok=True)


def _row_offsets(sizes):
    """Pack rows of the given lengths back to back; return an (n, 2) int64 array of [start, end) offsets."""
    sizes = np.asarray(sizes, dtype=np.int64)
    ends = np.cumsum(sizes)
    return np.stack([ends - sizes, ends], axis=1)


for split, dataset in dataset_tokenised.items():
    # On-disk layout per split:
    #   {split}_x.bin / {split}_y.bin                 every row's tokens concatenated (uint16)
    #   {split}_x_lengths.bin / {split}_y_lengths.bin one [start, end) pair per row into those
    #                                                 buffers (offsets, despite the file name)
    x_offsets = _row_offsets(dataset['x_size'])
    y_offsets = _row_offsets(dataset['y_size'])
    x_total_length = int(x_offsets[-1, 1]) if len(x_offsets) else 0
    y_total_length = int(y_offsets[-1, 1]) if len(y_offsets) else 0

    # Offsets are stored with the platform `int` dtype. Check once, before any
    # file is created, instead of breaking out mid-loop and leaving partially
    # written files behind.
    index_max = np.iinfo(int).max
    largest_total = max(x_total_length, y_total_length)
    if largest_total > index_max:
        raise OverflowError(f'index int too big! value: `{largest_total:,}` > {index_max:,}')

    x_array = np.memmap(data_dir / f'{split}_x.bin', dtype=np.uint16, mode='w+', shape=(x_total_length,))
    y_array = np.memmap(data_dir / f'{split}_y.bin', dtype=np.uint16, mode='w+', shape=(y_total_length,))
    x_size_array = np.memmap(data_dir / f'{split}_x_lengths.bin', dtype=int, mode='w+', shape=(len(dataset), 2))
    y_size_array = np.memmap(data_dir / f'{split}_y_lengths.bin', dtype=int, mode='w+', shape=(len(dataset), 2))

    x_size_array[:] = x_offsets
    y_size_array[:] = y_offsets

    # tqdm goes first in zip so its iterator is the one exhausted and the bar closes.
    for example, (x_start, x_end), (y_start, y_end) in zip(
        tqdm.tqdm(dataset, unit='examples', smoothing=0.01), x_offsets, y_offsets
    ):
        x_array[x_start:x_end] = example['x']
        y_array[y_start:y_end] = example['y']

    for array in (x_array, y_array, x_size_array, y_size_array):
        array.flush()

In [None]:
split = 'train'

# Read back from the directory the write-out cell used. Bare relative file names
# would depend on the kernel's working directory and miss those files.
x_data = np.memmap(data_dir / f'{split}_x.bin', dtype=np.uint16, mode='r')
y_data = np.memmap(data_dir / f'{split}_y.bin', dtype=np.uint16, mode='r')

# The offset files are flat `int` buffers of [start, end) pairs -> (n_rows, 2).
x_size = np.memmap(data_dir / f'{split}_x_lengths.bin', dtype=int, mode='r').reshape(-1, 2)
y_size = np.memmap(data_dir / f'{split}_y_lengths.bin', dtype=int, mode='r').reshape(-1, 2)

In [None]:
# Spot-check: decode one random stored (input, target) pair and eyeball that the
# offsets line up with the token buffers.

i = random.randint(0, x_size.shape[0] - 1)

print('example', i)

x_idxs = x_size[i]
y_idxs = y_size[i]

x_item = x_data[x_idxs[0]:x_idxs[1]]
y_item = y_data[y_idxs[0]:y_idxs[1]]

# tiktoken's decode expects a sequence of ints; hand it plain Python ints rather
# than numpy uint16 scalars from the memmap slice.
x = decode(x_item.tolist())
y = decode(y_item.tolist())

print(f'input: `{x}`')
print()
print(f'target: `{y}`')
print()
print()