In [1]:
import os
import numpy as np
import torch
from tokenizers import Tokenizer

In [2]:
device = 'cpu'

In [3]:
tokenizer = Tokenizer.from_file("/data/evan/CS285_Final_Project/model/tokenizer.model")

In [4]:
from datasets import load_dataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
dataset = load_dataset("evanfrick/lichess", num_proc=32)

Resolving data files: 100%|██████████| 40/40 [00:01<00:00, 33.14it/s]


In [6]:
split_dataset = dataset["train"].train_test_split(test_size=0.00025, seed=2357, shuffle=True)
split_dataset['val'] = split_dataset.pop('test')

In [7]:
split_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'game', 'result'],
        num_rows: 19052702
    })
    val: Dataset({
        features: ['id', 'game', 'result'],
        num_rows: 4765
    })
})

In [8]:
lengths = split_dataset['train'].map(lambda s: {'length' : len(s)}, input_columns=['game'], num_proc=64)

In [9]:
lengths

Dataset({
    features: ['id', 'game', 'result', 'length'],
    num_rows: 19052702
})

In [10]:
split_dataset_sorted = lengths.sort(['length'])

In [64]:
split_dataset_sorted = lengths.sort(['length']).filter(lambda x : x >= 36, input_columns=['length'], num_proc=64)

In [50]:
import torch
from torch.utils.data import Sampler
import random

# Define a custom sampler that shuffles the batches, but not the items within them
class BatchShuffleSampler(Sampler):
    def __init__(self, data_source, batch_size):
        # Initialize the base class
        super().__init__(data_source)
        # Store the data source, batch size, and generator
        self.data_source = data_source
        self.batch_size = batch_size
        # Compute the number of batches
        self.num_batches = (len(self.data_source) + self.batch_size - 1) // self.batch_size
    
    def __iter__(self):
        # Create a list of batch indices
        batch_indices = list(range(self.num_batches))
        # Shuffle the batch indices
        random.shuffle(batch_indices)
        # For each batch index
        for batch_index in batch_indices:
            # Compute the start and end indices of the items in the batch
            start_index = batch_index * self.batch_size
            end_index = min(start_index + self.batch_size, len(self.data_source))
            # Yield the item indices in the batch
            yield from range(start_index, end_index)
    
    def __len__(self):
        # Return the total number of items
        return len(self.data_source)


In [57]:
ds = split_dataset_sorted.with_format('torch', columns=['game'])
train_loader = DataLoader(ds, batch_size=96, pin_memory=True, pin_memory_device='cpu', sampler=BatchShuffleSampler(ds, 96))


In [58]:
tokenizer.token_to_id('<eos>')
tokenizer.enable_truncation(768)

In [59]:

def make_batch_iterator():
    for batch in iter(train_loader):
        ids = torch.IntTensor([enc.ids for enc in tokenizer.encode_batch(batch['game'])]).type(torch.int16)
        X = ids[:,:-1]
        Y = ids[:,1:]
        yield X, Y

    return

getter = make_batch_iterator()

def get_batch():
    return next(getter)


In [60]:
batch = get_batch()

In [61]:
k = next(iter(train_loader))

In [62]:
k

{'game': ['<w>e2-e4<b>c7-c5<w>Ng1-f3<b>Nb8-c6<w>d2-d4<b>c5xd4<w>Nf3xd4<b>g7-g6<w>Nb1-c3<b>Bf8-g7<w>Nd4xc6<b>b7xc6<w>Bf1-c4<b>d7-d6<w>O-O<b>Ng8-f6<w>e4-e5<b>d6xe5<w>Qd1xd8+<b>Ke8xd8<w>Rf1-e1<b>Kd8-c7<w>Re1xe5<b>e7-e6<w>Bc1-f4<b>Kc7-b7<w>Nc3-a4<b>Nf6-d5<w>Bc4xd5<b>c6xd5<w>Re5-e3<b>d5-d4<w>Re3-b3+<b>Kb7-c6<w>c2-c3<b>e6-e5<w>Bf4-d2<b>Bc8-e6<w>Rb3-b4<b>Ra8-b8<w>c3xd4<b>Rb8xb4<w>Bd2xb4<b>Kc6-b5<w>d4-d5<b>Kb5xb4<w>d5xe6<b>Kb4xa4<w>e6xf7<b>Rh8-f8<w>Ra1-c1<b>Rf8xf7<w>Rc1-c6<b>Rf7-b7<w>b2-b3+<b>Ka4-b5<w>Rc6-e6<b>Kb5-c5<w>g2-g3<b>Kc5-d5<w>Re6-a6<b>e5-e4<w>Kg1-g2<b>Bg7-d4<w>Ra6-a5+<b>Kd5-e6<w>Ra5-a4<b>Rb7-d7<w>Kg2-f1<b>e4-e3<w>f2xe3<b>Bd4-b6<w>Kf1-e2<b>Rd7-c7<w>Ke2-d3<b>Rc7-f7<w>Kd3-e2<b>Rf7-e7<w>h2-h3<b>Ke6-d6<w>e3-e4<b>Re7-f7<w>h3-h4<b>Rf7-f2+<w>Ke2-d3<b>Rf2-f3+<w>Kd3-c4<b>Rf3xg3',
  '<w>d2-d4<b>d7-d5<w>c2-c4<b>c7-c6<w>Ng1-f3<b>Ng8-f6<w>e2-e3<b>e7-e6<w>Nb1-c3<b>Bf8-b4<w>Qd1-c2<b>Qd8-a5<w>Bc1-d2<b>O-O<w>Bf1-d3<b>h7-h6<w>Nf3-e5<b>Nb8-d7<w>f2-f4<b>Nd7xe5<w>f4xe5<b>Nf6-d7<w>c4xd5<b>e6xd5<w>O-O<b>Nd7

In [25]:
print(batch[0])

tensor([[    0, 22047,     1,  ..., 23253, 23253, 23253],
        [    0, 22041,     1,  ..., 23253, 23253, 23253],
        [    0, 22014,     1,  ..., 23253, 23253, 23253],
        ...,
        [    0, 22044,     1,  ..., 23253, 23253, 23253],
        [    0, 22047,     1,  ..., 23253, 23253, 23253],
        [    0, 22047,     1,  ..., 23253, 23253, 23253]], dtype=torch.int16)


In [161]:
batch = get_batch()
max_lens = []
while batch:
    max_lens.append(batch[0].shape[1])
    batch = get_batch()


KeyboardInterrupt: 

In [162]:
len(max_lens)

1775

In [163]:
max_lens

[367,
 345,
 479,
 393,
 339,
 345,
 377,
 361,
 597,
 391,
 541,
 357,
 305,
 511,
 497,
 355,
 431,
 407,
 465,
 541,
 391,
 439,
 327,
 355,
 353,
 331,
 495,
 497,
 387,
 395,
 427,
 335,
 369,
 387,
 299,
 589,
 395,
 347,
 401,
 607,
 373,
 335,
 405,
 325,
 325,
 505,
 611,
 455,
 599,
 403,
 415,
 435,
 593,
 545,
 347,
 399,
 375,
 337,
 343,
 359,
 323,
 355,
 475,
 345,
 327,
 389,
 637,
 309,
 359,
 403,
 451,
 323,
 353,
 325,
 595,
 345,
 367,
 407,
 355,
 455,
 485,
 339,
 425,
 415,
 355,
 325,
 437,
 321,
 347,
 775,
 481,
 555,
 477,
 407,
 457,
 335,
 319,
 953,
 347,
 361,
 459,
 313,
 471,
 489,
 515,
 571,
 389,
 461,
 357,
 341,
 527,
 369,
 363,
 361,
 357,
 507,
 347,
 329,
 343,
 425,
 467,
 347,
 381,
 359,
 381,
 357,
 359,
 313,
 547,
 305,
 307,
 423,
 363,
 387,
 323,
 505,
 417,
 379,
 299,
 433,
 367,
 321,
 339,
 329,
 443,
 613,
 295,
 423,
 335,
 451,
 377,
 361,
 485,
 765,
 567,
 357,
 345,
 515,
 371,
 381,
 337,
 311,
 401,
 329,
 889,
 611,
 383

In [165]:
max_lens = np.array(max_lens)

In [166]:
max_lens.mean()

418.2078873239437

In [168]:
max_lens.std()

106.85184765010301

In [169]:
max_lens.max()

1023

In [170]:
max_lens.min()

277

In [8]:
import os
import numpy as np
import torch
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
from ccc_dataset import ParquetDataset
from sampler import BatchShuffleSampler
from datasets import load_dataset

In [2]:
tokenizer = Tokenizer.from_file("/data/evan/CS285_Final_Project/model/tokenizer.model")


In [49]:
dataset = load_dataset("/data/evan/CS285_Final_Project/data/", data_files=["/data/evan/CS285_Final_Project/data/ccc_processed.parquet"])

ds = dataset['train'].with_format('torch', columns=['seq_in', 'seq_out'])
print("hit2")
train_loader = DataLoader(ds, batch_size=4, pin_memory=True, pin_memory_device='cpu', sampler=BatchShuffleSampler(ds, 4))

tokenizer = Tokenizer.from_file("/data/evan/CS285_Final_Project/model/tokenizer.model")
tokenizer.enable_truncation(768 + 1)

Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 3650.40it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 451.73it/s]
Generating train split: 12044268 examples [01:10, 170381.20 examples/s]


hit2




In [32]:
EOS = '<eos>'


In [33]:
train_iter = iter(train_loader)

#

In [34]:
batch = next(train_iter)
batch

{'seq_in': ['<w>e2-e4<b>c7-c5<w>Ng1-f3<b>Nb8-c6<w>Bf1-b5<b>g7-g6<w>O-O<b>Bf8-g7<w>c2-c3<b>Ng8-f6<w>Rf1-e1<b>O-O<w>h2-h3<b>Qd8-b6<w>Bb5-a4<b>d7-d5<w>e4-e5<b>Nf6-d7<w>d2-d4<b>Rf8-d8<w>a2-a3<b>Qb6-c7<w>Bc1-g5<b>Nd7-f8<w>b2-b4<b>Nf8-e6<w>b4xc5<b>Ne6xg5<w>Nf3xg5<b>h7-h6<w>Ng5-f3<b>Nc6-a5<w>Nb1-d2<b>Ra8-b8<w>e5-e6<b>Bc8xe6<w>Re1xe6<b>f7xe6<w>Nf3-h4<b>g6-g5<w>Nh4-g6<b>e6-e5<w>Qd1-g4<b>e7-e6<w>Nd2-f3<b>e5-e4<w>Nf3-e5<b>Bg7xe5<w>Ng6xe5<b>Qc7-e7<w>Qg4-h5<b>Qe7-g7<w>Ra1-e1<b>',
  '<w>d2-d4<b>d7-d6<w>e2-e4<b>Ng8-f6<w>Nb1-c3<b>g7-g6<w>Bc1-e3<b>Nb8-d7<w>f2-f4<b>Nd7-b6<w>Ng1-f3<b>Bf8-g7<w>a2-a4<b>O-O<w>a4-a5<b>Nb6-d7<w>Bf1-c4<b>c7-c6<w>e4-e5<b>Nf6-g4<w>Be3-g1<b>b7-b5<w>a5xb6<b>Qd8xb6<w>b2-b3<b>Qb6-b4<w>Qd1-d2<b>Nd7-b6<w>Bc4-d3<b>f7-f6<w>h2-h3<b>Ng4-h6<w>e5xd6<b>e7xd6<w>Bg1-f2<b>d6-d5<w>O-O<b>a7-a5<w>g2-g4<b>Rf8-e8<w>Rf1-e1<b>Bc8-e6<w>Ra1-a2<b>Nh6-f7<w>Re1-a1<b>Nf7-d6<w>Bf2-e1<b>Nb6-c8<w>Nc3-a4<b>Qb4xd2<w>Be1xd2<b>Nd6-e4<w>Na4-c5<b>',
  '<w>e2-e4<b>c7-c5<w>Ng1-f3<b>d7-d6<w>d2-d4<b>c5xd4<w>Nf3xd4<b>Ng8

In [45]:
seq_ins = batch['seq_in']
seq_outs = batch['seq_out']

last_game_in_batch = seq_ins[-1]

last_prompt_tokens = len(tokenizer.encode(last_game_in_batch))

last_prompt_tokens

107

In [48]:
list(map(lambda t: "".join(t) + EOS, zip(seq_ins, seq_outs)))

['<w>e2-e4<b>c7-c5<w>Ng1-f3<b>Nb8-c6<w>Bf1-b5<b>g7-g6<w>O-O<b>Bf8-g7<w>c2-c3<b>Ng8-f6<w>Rf1-e1<b>O-O<w>h2-h3<b>Qd8-b6<w>Bb5-a4<b>d7-d5<w>e4-e5<b>Nf6-d7<w>d2-d4<b>Rf8-d8<w>a2-a3<b>Qb6-c7<w>Bc1-g5<b>Nd7-f8<w>b2-b4<b>Nf8-e6<w>b4xc5<b>Ne6xg5<w>Nf3xg5<b>h7-h6<w>Ng5-f3<b>Nc6-a5<w>Nb1-d2<b>Ra8-b8<w>e5-e6<b>Bc8xe6<w>Re1xe6<b>f7xe6<w>Nf3-h4<b>g6-g5<w>Nh4-g6<b>e6-e5<w>Qd1-g4<b>e7-e6<w>Nd2-f3<b>e5-e4<w>Nf3-e5<b>Bg7xe5<w>Ng6xe5<b>Qc7-e7<w>Qg4-h5<b>Qe7-g7<w>Ra1-e1<b>Rd8-f8<w>f2-f3<b>e4xf3<w>g2xf3<b>Na5-c6<w>Ne5-g4<b>Rf8-f6<w>Re1-b1<b>e6-e5<w>Ng4xf6+<b>Qg7xf6<w>h3-h4<b>e5xd4<w>Ba4xc6<b>Qf6xc6<w>c3xd4<b>Qc6-f6<w>Kg1-g2<b>b7-b6<w>h4xg5<b>Qf6xg5+<w>Qh5xg5+<b>h6xg5<w>Kg2-g3<b>Kg8-f8<w>Kg3-g4<b>Kf8-e7<w>Rb1-e1+<b>Ke7-d7<w>Re1-e5<b>Kd7-c6<w>Re5-e6+<b>Kc6-c7<w>c5xb6+<b>a7xb6<w>Kg4xg5<b>Rb8-g8+<w>Kg5-f5<b>Rg8-f8+<w>Re6-f6<b>Rf8-e8<w>a3-a4<eos>',
 '<w>d2-d4<b>d7-d6<w>e2-e4<b>Ng8-f6<w>Nb1-c3<b>g7-g6<w>Bc1-e3<b>Nb8-d7<w>f2-f4<b>Nd7-b6<w>Ng1-f3<b>Bf8-g7<w>a2-a4<b>O-O<w>a4-a5<b>Nb6-d7<w>Bf1-c4<b>c7-c6<w>e4-e5<b>

In [38]:
max_prompt_length

107

In [41]:
dir(last_prompt_tokens)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'attention_mask',
 'char_to_token',
 'char_to_word',
 'ids',
 'merge',
 'n_sequences',
 'offsets',
 'overflowing',
 'pad',
 'sequence_ids',
 'set_sequence_id',
 'special_tokens_mask',
 'token_to_chars',
 'token_to_sequence',
 'token_to_word',
 'tokens',
 'truncate',
 'type_ids',
 'word_ids',
 'word_to_chars',
 'word_to_tokens',
 'words']