# Splitting
This code is what I will be using to split the data into train test and validation sets. We will be doing it to get an even distribution along the depth of the code

In [6]:
# imports
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import HfApi
from datasets import load_dataset

import os
import random
import json

In [14]:
cwd = os.getcwd()
file_path = os.path.join(cwd, "data/stage1-mcts-refactored/stage1-mcts.parquet")
table = pq.read_table(file_path)


In [15]:
depth_counts = {0: 18, 1: 360, 2: 1329, 3: 4015, 4: 12504, 5: 37510, 6: 112584, 7: 128791, 8: 145823, 9: 157765, 10: 166838, 11: 166997, 12: 167103, 13: 167212, 14: 167349, 15: 167520, 16: 167626, 17: 167711, 18: 167835, 19: 167898, 20: 168031, 21: 168190, 22: 168351, 23: 168508, 24: 168646, 25: 168779, 26: 168927, 27: 169073, 28: 169222, 29: 169390, 30: 169524, 31: 169651, 32: 169766, 33: 169853, 34: 169919, 35: 169906, 36: 169875, 37: 169545, 38: 169225, 39: 168313, 40: 167381, 41: 165305, 42: 163596, 43: 159999, 44: 157046, 45: 151488, 46: 147262, 47: 139541, 48: 133786, 49: 124429, 50: 117479, 51: 107041, 52: 98868, 53: 87729, 54: 79050, 55: 68179, 56: 59294, 57: 49391, 58: 41341, 59: 32931, 60: 26411, 61: 19935, 62: 14964, 63: 10802, 64: 7684, 65: 5187, 66: 3397, 67: 2123, 68: 1294, 69: 753, 70: 424, 71: 218, 72: 108, 73: 47, 74: 19, 75: 5, 76: 1}

In [19]:
depth_col = table["depth"]

unique_depths = list(depth_counts.keys())

# this works as row indices, each index corresponds to a row in the table
row_indices = list(range(table.num_rows))
random.shuffle(row_indices)

depth_indices = {depth: [] for depth in unique_depths}
for i in row_indices:
    depth = depth_col[i].as_py()
    depth_indices[depth].append(i)


1
2
3


In [20]:
# validate the indices
print(depth_indices[0])
print(depth_indices[1])
print(depth_indices[2])

[0, 3, 12, 2, 6, 9, 5, 1, 16, 15, 13, 17, 14, 8, 10, 7, 11, 4]
[122, 221, 159, 191, 247, 169, 21, 35, 177, 20, 75, 19, 217, 49, 234, 348, 362, 315, 195, 97, 47, 287, 179, 55, 87, 121, 175, 116, 95, 196, 311, 305, 147, 91, 151, 189, 37, 78, 90, 190, 270, 126, 286, 325, 228, 118, 368, 231, 327, 45, 182, 365, 240, 252, 224, 33, 192, 63, 61, 70, 149, 62, 107, 41, 24, 213, 359, 299, 134, 370, 290, 125, 193, 343, 242, 330, 233, 57, 371, 109, 165, 40, 261, 374, 218, 74, 136, 337, 38, 65, 155, 56, 238, 227, 350, 264, 262, 319, 239, 265, 351, 232, 201, 369, 52, 260, 86, 104, 29, 188, 274, 367, 150, 375, 25, 103, 211, 110, 199, 148, 296, 22, 145, 84, 114, 320, 324, 210, 332, 297, 222, 212, 318, 279, 289, 80, 364, 141, 340, 23, 144, 309, 166, 66, 146, 372, 173, 248, 263, 275, 214, 259, 46, 241, 139, 361, 170, 200, 284, 163, 202, 58, 124, 342, 28, 373, 243, 83, 356, 308, 292, 267, 154, 72, 105, 339, 27, 181, 307, 172, 76, 312, 54, 226, 335, 18, 225, 132, 273, 357, 360, 316, 96, 352, 295, 301, 254,

In [21]:
test_split = 0.2
validation_split = 0.1

test_split_count = {d: int(depth_counts[d] * test_split) for d in depth_counts}
validation_split_count = {d: int(depth_counts[d] * validation_split) for d in depth_counts}

# 75 will have some special cases
test_split_count[75] = 1
validation_split_count[75] = 1

In [22]:
# set up indices lists
train_indices, val_indices, test_indices = [], [], []
for depth, indices in depth_indices.items():
    test_indices.extend(indices[:test_split_count[depth]])
    val_indices.extend(indices[test_split_count[depth]:test_split_count[depth] + validation_split_count[depth]])
    train_indices.extend(indices[test_split_count[depth] + validation_split_count[depth]:])

In [25]:
print(len(train_indices))
print(len(val_indices))
print(len(test_indices))

print(train_indices[:20])

5601458
800165
1600367
8001990
8001990
[6, 9, 5, 1, 16, 15, 13, 17, 14, 8, 10, 7, 11, 4, 29, 188, 274, 367, 150, 375]


In [26]:
# little shuffle for security
random.shuffle(train_indices)
random.shuffle(val_indices)
random.shuffle(test_indices)

train_table = table.take(train_indices)
val_table = table.take(val_indices)
test_table = table.take(test_indices)

In [27]:
# validating the split
index_train_count = {d: 0 for d in depth_counts}
index_val_count = {d: 0 for d in depth_counts}
index_test_count = {d: 0 for d in depth_counts}

for i in train_indices:
    index_train_count[depth_col[i].as_py()] += 1

for i in val_indices:
    index_val_count[depth_col[i].as_py()] += 1
    
for i in test_indices:
    index_test_count[depth_col[i].as_py()] += 1
    
for i in depth_counts:
    print(f"Depth {i}: {index_train_count[i]} {index_val_count[i]} {index_test_count[i]}")


Depth 0: 14 1 3
Depth 1: 252 36 72
Depth 2: 932 132 265
Depth 3: 2811 401 803
Depth 4: 8754 1250 2500
Depth 5: 26257 3751 7502
Depth 6: 78810 11258 22516
Depth 7: 90154 12879 25758
Depth 8: 102077 14582 29164
Depth 9: 110436 15776 31553
Depth 10: 116788 16683 33367
Depth 11: 116899 16699 33399
Depth 12: 116973 16710 33420
Depth 13: 117049 16721 33442
Depth 14: 117146 16734 33469
Depth 15: 117264 16752 33504
Depth 16: 117339 16762 33525
Depth 17: 117398 16771 33542
Depth 18: 117485 16783 33567
Depth 19: 117530 16789 33579
Depth 20: 117622 16803 33606
Depth 21: 117733 16819 33638
Depth 22: 117846 16835 33670
Depth 23: 117957 16850 33701
Depth 24: 118053 16864 33729
Depth 25: 118147 16877 33755
Depth 26: 118250 16892 33785
Depth 27: 118352 16907 33814
Depth 28: 118456 16922 33844
Depth 29: 118573 16939 33878
Depth 30: 118668 16952 33904
Depth 31: 118756 16965 33930
Depth 32: 118837 16976 33953
Depth 33: 118898 16985 33970
Depth 34: 118945 16991 33983
Depth 35: 118935 16990 33981
Depth 36:

In [29]:
pq.write_table(train_table, "data/stage1-mcts-refactored/train.parquet")
pq.write_table(test_table, "data/stage1-mcts-refactored/test.parquet")
pq.write_table(val_table, "data/stage1-mcts-refactored/validation.parquet")

In [30]:
# uploads
api = HfApi()
api.upload_file(
    path_or_fileobj='data/stage1-mcts-refactored/train.parquet',
    path_in_repo='data/train.parquet',
    repo_id='markstanl/u3t',
    repo_type='dataset'
)
api.upload_file(
    path_or_fileobj='data/stage1-mcts-refactored/test.parquet',
    path_in_repo='data/test.parquet',
    repo_id='markstanl/u3t',
    repo_type='dataset'
)
api.upload_file(
    path_or_fileobj='data/stage1-mcts-refactored/validation.parquet',
    path_in_repo='data/validation.parquet',
    repo_id='markstanl/u3t',
    repo_type='dataset'
)

train.parquet:   0%|          | 0.00/883M [00:00<?, ?B/s]

test.parquet:   0%|          | 0.00/252M [00:00<?, ?B/s]

validation.parquet:   0%|          | 0.00/126M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/datasets/markstanl/u3t/commit/6ee2e7da4f8e0d640715e7960bac243244f8acc0', commit_message='Upload data/validation.parquet with huggingface_hub', commit_description='', oid='6ee2e7da4f8e0d640715e7960bac243244f8acc0', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/markstanl/u3t', endpoint='https://huggingface.co', repo_type='dataset', repo_id='markstanl/u3t'), pr_revision=None, pr_num=None)

In [7]:
table = pq.read_table("data/stage1-mcts-refactored/train.parquet")
dataset = load_dataset('markstanl/u3t', split="train")

README.md:   0%|          | 0.00/5.95k [00:00<?, ?B/s]

train.parquet:   0%|          | 0.00/883M [00:00<?, ?B/s]

test.parquet:   0%|          | 0.00/252M [00:00<?, ?B/s]

validation.parquet:   0%|          | 0.00/126M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5601458 [00:00<?, ? examples/s]

Failed to read file 'C:\Users\markt\.cache\huggingface\hub\datasets--markstanl--u3t\snapshots\fc6bdaeb4e0865fb420f2291466eb9e7103e0d8d\data\train.parquet' with error <class 'datasets.table.CastError'>: Couldn't cast
state: string
num_visits: int64
num_wins: int64
num_draws: int64
num_losses: int64
actions: list<element: struct<index: int64, num_draws: int64, num_losses: int64, num_wins: int64, symbol: int64>>
  child 0, element: struct<index: int64, num_draws: int64, num_losses: int64, num_wins: int64, symbol: int64>
      child 0, index: int64
      child 1, num_draws: int64
      child 2, num_losses: int64
      child 3, num_wins: int64
      child 4, symbol: int64
depth: int64
to
{'state': Value(dtype='string', id=None), 'num_visits': Value(dtype='int64', id=None), 'num_wins': Value(dtype='int64', id=None), 'num_draws': Value(dtype='int64', id=None), 'num_losses': Value(dtype='int64', id=None), 'actions': [{'index': Value(dtype='int64', id=None), 'num_draws': Value(dtype='int64', id

DatasetGenerationError: An error occurred while generating the dataset

In [5]:
print(table.schema)

state: string
num_visits: int64
num_wins: int64
num_draws: int64
num_losses: int64
actions: list<element: struct<index: int64, num_draws: int64, num_losses: int64, num_wins: int64, symbol: int64>>
  child 0, element: struct<index: int64, num_draws: int64, num_losses: int64, num_wins: int64, symbol: int64>
      child 0, index: int64
      child 1, num_draws: int64
      child 2, num_losses: int64
      child 3, num_wins: int64
      child 4, symbol: int64
depth: int64


In [None]:
print(dataset.features)