In [None]:
import glob as glob
import time
import torch
from transformers import AutoTokenizer
# OLMoE tokenizer with `add_eos_token` / `add_bos_token` disabled. The JSON-shard loaders
# below are handed `tokenizer.eos_token_id` explicitly as their separator id instead.
tokenizer = AutoTokenizer.from_pretrained(
    'allenai/OLMoE-1B-7B-0924',
    add_eos_token=False,
    add_bos_token=False,
)

In [None]:
from helpers.dataset import load_shard_as_dataloader

# Original approach: load the first raw JSON shard with the single-process loader and time it.
shard_files = sorted(glob.glob("./../../data/train_shard_*.json"))
print(f"Found {len(shard_files)} shards.")
# Fail fast: with no shards the loop below would silently do nothing, and later cells
# that read `shard_dl` would fail with a NameError far from the real cause.
assert shard_files, "No train_shard_*.json files found under ./../../data/"

for shard_path in shard_files[:1]:  # only the first shard is timed
    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
    print('Start')
    shard_dl = load_shard_as_dataloader(
        shard_path,
        tokenizer,
        batch_size=64 * 4,
        seq_len=2048,
        eos_seperator_id=tokenizer.eos_token_id,  # (sic) must match the helper's keyword name
        shuffle=False,
    )
    print(f"Elapsed: {time.perf_counter() - start_time:.2f} s")

## Alt Version 1 - Multiprocess shard loading

In [None]:
from helpers.dataset import load_shard_as_dataloader_mp

# Alt Version 1: the same JSON shards, loaded with the multiprocess loader, and timed.
shard_files = sorted(glob.glob("./../../data/train_shard_*.json"))
print(f"Found {len(shard_files)} shards.")
# Fail fast instead of silently skipping the loop when no shards are found.
assert shard_files, "No train_shard_*.json files found under ./../../data/"

# NOTE: this reassigns `shard_dl`, so any later cell that reads `shard_dl` sees this
# multiprocess loader's DataLoader rather than the single-process one.
# NOTE(review): unlike the other two loaders, `shuffle=False` is not passed here. Batch
# order matters for any element-wise comparison of `shard_dl` against the .pt loader, so
# confirm this loader does not shuffle by default.
for shard_path in shard_files[:1]:  # only the first shard is timed
    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
    print('Start')
    shard_dl = load_shard_as_dataloader_mp(
        shard_path,
        tokenizer,
        batch_size=64 * 4,
        seq_len=2048,
        eos_seperator_id=tokenizer.eos_token_id,  # (sic) must match the helper's keyword name
    )
    print(f"Elapsed: {time.perf_counter() - start_time:.2f} s")

## Alt Version 2 - Preprocess first, shard later (load pre-tokenized `.pt` shards)

In [None]:
from helpers.dataset import load_pt_shard_as_dataloader

# Alt Version 2: load the first preprocessed `.pt` shard (presumably pre-tokenized with the
# OLMoE tokenizer, judging by the directory name) and time it.
shard_pt_files = sorted(glob.glob("./../../data/olmoe-tokenizer/train_shard_*.pt"))
print(f"Found {len(shard_pt_files)} shards.")
# Fail fast: with no shards the loop below would silently do nothing, and the validation
# cells that read `shard_pt_dl` would fail with a NameError.
assert shard_pt_files, "No train_shard_*.pt files found under ./../../data/olmoe-tokenizer/"

# NOTE(review): the validation below assumes the first sorted .pt shard holds the same
# data as the first sorted .json shard - TODO confirm the naming lines up.
for shard_pt_path in shard_pt_files[:1]:  # only the first shard is timed
    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
    print('Start')
    shard_pt_dl = load_pt_shard_as_dataloader(
        shard_pt_path,
        tokenizer,
        batch_size=64 * 4,
        seq_len=2048,
        shuffle=False,
    )
    print(f"Elapsed: {time.perf_counter() - start_time:.2f} s")

## Validate that the approaches produce identical batches

In [None]:
# Preview the first batch of `shard_dl` for a by-eye comparison with the .pt loader.
# The "Alt Version 1" cell reassigns `shard_dl`, so after a top-to-bottom run this is
# the multiprocess loader's output.
try:
    first_batch = next(iter(shard_dl))
except StopIteration:
    pass  # empty loader: nothing to preview
else:
    print(first_batch)

In [None]:
# Preview the first batch of the .pt-based loader (`shard_pt_dl`), for a by-eye
# comparison with the `shard_dl` preview in the previous cell.
for b in shard_pt_dl:
    print(b)
    break

In [None]:
# Flatten every batch from the original JSON-shard loader (`shard_dl`) into single tensors,
# so they can be compared element-wise with the .pt-based loader (`shard_pt_dl`).
# After a top-to-bottom run, `shard_dl` holds the Alt Version 1 (multiprocess) loader,
# because that cell reassigns it.
old_input_ids_list = []
old_attention_list = []

for b in shard_dl:
    old_input_ids_list.append(b['input_ids'])
    old_attention_list.append(b['attention_mask'])

# torch.cat raises a terse RuntimeError on an empty list; give a clearer message first.
assert old_input_ids_list, "shard_dl yielded no batches"

old_input_ids_all = torch.cat(old_input_ids_list, dim=0)
old_attention_all = torch.cat(old_attention_list, dim=0)
print(f"old_input_ids_all shape = {old_input_ids_all.shape}")

In [None]:
# Flatten every batch from the .pt-based loader into single tensors (concatenated along
# dim 0). The loader is iterated once, and each batch's two fields are kept together.
new_batch_pairs = [(batch['input_ids'], batch['attention_mask']) for batch in shard_pt_dl]
new_input_ids_list = [ids for ids, _ in new_batch_pairs]
new_attention_list = [mask for _, mask in new_batch_pairs]

new_input_ids_all = torch.cat(new_input_ids_list, dim=0)
new_attention_all = torch.cat(new_attention_list, dim=0)
print(f"new_input_ids_all shape = {new_input_ids_all.shape}")

In [None]:
# Element-wise comparison of the flattened old vs. new tensors. Both the input_ids and
# attention_mask shapes are checked first: `==` on mismatched shapes either broadcasts
# (which can hide a real mismatch) or raises.
ids_shape_matches = old_input_ids_all.shape == new_input_ids_all.shape
mask_shape_matches = old_attention_all.shape == new_attention_all.shape

if not ids_shape_matches:
    print("ERROR: shapes differ in input_ids!")
    print(f"  old {tuple(old_input_ids_all.shape)} vs new {tuple(new_input_ids_all.shape)}")
if not mask_shape_matches:
    print("ERROR: shapes differ in attention_mask!")
    print(f"  old {tuple(old_attention_all.shape)} vs new {tuple(new_attention_all.shape)}")

if ids_shape_matches and mask_shape_matches:
    same_ids = (old_input_ids_all == new_input_ids_all).all().item()
    same_mask = (old_attention_all == new_attention_all).all().item()

    print(same_ids, same_mask)

In [None]:
# Last row (index -1 along dim 0) of the flattened input_ids built from `shard_pt_dl`;
# displayed for a by-eye comparison with `old_input_ids_all[-1, :]` in the next cell.
new_input_ids_all[-1, :]

In [None]:
# Last row (index -1 along dim 0) of the flattened input_ids from the "old approach"
# flatten cell; compare by eye with the `new_input_ids_all[-1, :]` output above.
old_input_ids_all[-1, :]