In [19]:
import os
import sys

(parent_folder_path, current_dir) = os.path.split(os.path.abspath(''))
sys.path.append(parent_folder_path)
sys.path.append(os.path.join(parent_folder_path, 'equities'))

os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"


from pathlib import Path
from typing import Optional
import random
import numpy as np
import pandas as pd
import pickle as pkl
from tqdm import tqdm
from glob import glob
from decimal import Decimal
from matplotlib import pyplot as plt

from equities.data_processing.itch_encoding import Vocab, encode_msgs, decode_msg, encode_msg
from tokenizer.bpe_basic import BasicTokenizer


In [2]:
# load the data
train_data_dir = os.path.join(os.path.dirname(os.path.abspath('')), 'dataset/proc/ITCH/train/')
message_file_paths = sorted(glob(str(train_data_dir) + '/*message*.npy'))
# fail with a readable message (instead of a bare IndexError) when nothing matches the pattern
assert len(message_file_paths) > 0, f'no message files found in {train_data_dir}'
# NOTE: despite the plural name this holds ONE path - the first file in sorted order
train_message_files = message_file_paths[0]

# np.array(...) copies the memory-mapped file into a regular in-memory array
proc_messages = np.array(np.load(train_message_files, mmap_mode='r'))
print("proc_messages.shape:", proc_messages.shape)

proc_messages

proc_messages.shape: (1586524, 18)


array([[       40,   9795465,         4, ...,     -9999,     -9999,
            -9999],
       [       40,   9806105,         1, ...,     -9999,     -9999,
            -9999],
       [       40,   9806141,         1, ...,     -9999,     -9999,
            -9999],
       ...,
       [       40, 327992593,         5, ...,     57599, 961935496,
            16027],
       [       40, 327992597,         1, ...,     -9999,     -9999,
            -9999],
       [       40, 327992857,         5, ...,     57599, 981953005,
            16528]])

In [3]:
# only the first 1000 processed messages are used for the experiments below
test_messages = proc_messages[:1000]

# first-stage encoding: message fields -> token ids of the base vocabulary
vocab = Vocab()
encoded_messages = encode_msgs(test_messages, vocab.ENCODING)

# every message gets a trailing "end of message" token, added as one extra column
eom_token_val = 0
eom_column = np.full((encoded_messages.shape[0], 1), eom_token_val)
encoded_messages = np.concatenate([encoded_messages, eom_column], axis=1)

print("encoded_messages.shape:", encoded_messages.shape)

encoded_messages

encoded_messages.shape: (1000, 25)


array([[12051,  1006, 12011, ...,     2,     2,     0],
       [12051,  1003, 12010, ...,     2,     2,     0],
       [12051,  1003, 12011, ...,     2,     2,     0],
       ...,
       [12051,  1003, 12011, ...,     2,     2,     0],
       [12051,  1003, 12010, ...,     2,     2,     0],
       [12051,  1003, 12010, ...,     2,     2,     0]])

In [4]:
# show the first encoded message followed by each pair of adjacent tokens in it
for msg in encoded_messages[:1]:
    print(msg)
    for pair in zip(msg[:-1], msg[1:]):
        print(pair)

[12051  1006 12011 12009 11013  1108  1008     3     3    13   633    37
   203     3   259    29     2     2     2     2     2     2     2     2
     0]
(12051, 1006)
(1006, 12011)
(12011, 12009)
(12009, 11013)
(11013, 1108)
(1108, 1008)
(1008, 3)
(3, 3)
(3, 13)
(13, 633)
(633, 37)
(37, 203)
(203, 3)
(3, 259)
(259, 29)
(29, 2)
(2, 2)
(2, 2)
(2, 2)
(2, 2)
(2, 2)
(2, 2)
(2, 2)
(2, 0)


In [5]:
def get_stats(encoded_messages):
    """Count how often each pair of adjacent tokens occurs over all messages.

    Pairs are formed inside a single message only, never across two messages.

    Args:
        encoded_messages: iterable of token sequences.

    Returns:
        dict mapping (left_token, right_token) -> occurrence count, with keys
        in order of first appearance.
    """
    pair_counts = {}
    for tokens in encoded_messages:
        for pos in range(len(tokens) - 1):
            key = (tokens[pos], tokens[pos + 1])
            if key in pair_counts:
                pair_counts[key] += 1
            else:
                pair_counts[key] = 1
    return pair_counts

stats = get_stats(encoded_messages)
stats

# # get the statistics of the encoded messages by reading each message one by one
# message_stats = []
# for i in range(encoded_messages.shape[0]):
#     stats = get_stats(encoded_messages[i])
#     message_stats.append(stats)


{(12051, 1006): 99,
 (1006, 12011): 54,
 (12011, 12009): 485,
 (12009, 11013): 28,
 (11013, 1108): 21,
 (1108, 1008): 86,
 (1008, 3): 188,
 (3, 3): 1151,
 (3, 13): 19,
 (13, 633): 1,
 (633, 37): 2,
 (37, 203): 837,
 (203, 3): 1,
 (3, 259): 2,
 (259, 29): 1,
 (29, 2): 2,
 (2, 2): 5285,
 (2, 0): 755,
 (12051, 1003): 742,
 (1003, 12010): 388,
 (12010, 12008): 483,
 (12008, 12007): 115,
 (12007, 1108): 39,
 (1108, 2): 246,
 (2, 3): 745,
 (3, 8): 25,
 (8, 833): 1,
 (833, 209): 1,
 (209, 37): 1,
 (203, 9): 2,
 (9, 89): 1,
 (89, 235): 1,
 (235, 2): 1,
 (1003, 12011): 354,
 (12009, 12007): 71,
 (3, 22): 3,
 (22, 163): 1,
 (163, 37): 2,
 (9, 108): 1,
 (108, 395): 1,
 (395, 2): 2,
 (1006, 12010): 45,
 (12008, 11137): 1,
 (11137, 2908): 1,
 (2908, 1108): 1,
 (1108, 3): 1,
 (3, 4): 133,
 (4, 167): 1,
 (167, 528): 1,
 (528, 37): 2,
 (203, 10): 2,
 (10, 272): 1,
 (272, 920): 1,
 (920, 2): 2,
 (12008, 11737): 1,
 (11737, 1208): 1,
 (1208, 1008): 1,
 (3, 208): 2,
 (208, 195): 1,
 (195, 37): 3,
 (10, 4

In [6]:
top_pair = max(stats, key=stats.get)
top_pair

(2, 2)

In [7]:
def merge(ids, pair, idx):
  """Replace every consecutive occurrence of `pair` in `ids` with the token `idx`.

  The scan runs left to right and matches do not overlap: after a match the
  scan resumes behind the matched pair.

  Args:
    ids: sequence of token ids.
    pair: the two tokens (first, second) to look for.
    idx: token id written in place of each matched pair.

  Returns:
    A new list; `ids` itself is not modified.
  """
  merged = []
  last = len(ids) - 1
  pos = 0
  while pos <= last:
    # a pair can only start before the final position
    hit = pos != last and ids[pos] == pair[0] and ids[pos + 1] == pair[1]
    merged.append(idx if hit else ids[pos])
    pos += 2 if hit else 1
  return merged

# print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99)) # test

# apply the single most frequent pair merge to every message
# NOTE(review): the new token id used here is len(vocab)+1, while the training loop
# further down starts numbering at len(vocab) - confirm whether skipping one id is intended
new_msgs = [merge(msg, top_pair, len(vocab)+1) for msg in encoded_messages]

new_msgs

[[12051,
  1006,
  12011,
  12009,
  11013,
  1108,
  1008,
  3,
  3,
  13,
  633,
  37,
  203,
  3,
  259,
  29,
  12516,
  12516,
  12516,
  12516,
  0],
 [12051,
  1003,
  12010,
  12008,
  12007,
  1108,
  2,
  3,
  8,
  833,
  209,
  37,
  203,
  9,
  89,
  235,
  12516,
  12516,
  12516,
  12516,
  0],
 [12051,
  1003,
  12011,
  12009,
  12007,
  1108,
  2,
  3,
  3,
  22,
  163,
  37,
  203,
  9,
  108,
  395,
  12516,
  12516,
  12516,
  12516,
  0],
 [12051,
  1006,
  12010,
  12008,
  11137,
  2908,
  1108,
  3,
  4,
  167,
  528,
  37,
  203,
  10,
  272,
  920,
  12516,
  12516,
  12516,
  12516,
  0],
 [12051,
  1006,
  12010,
  12008,
  11737,
  1208,
  1008,
  3,
  3,
  208,
  195,
  37,
  203,
  10,
  478,
  112,
  12516,
  12516,
  12516,
  12516,
  0],
 [12051,
  1003,
  12011,
  12009,
  11017,
  1108,
  2,
  3,
  10,
  931,
  433,
  37,
  203,
  18,
  406,
  542,
  12516,
  12516,
  12516,
  12516,
  0],
 [12051,
  1003,
  12010,
  12008,
  11991,
  1108,
  2,
  3,

In [8]:
# histogram of message lengths after the merge: length (in tokens) -> number of messages
msg_lens = {}
for msg_len in map(len, new_msgs):
    if msg_len not in msg_lens:
        msg_lens[msg_len] = 0
    msg_lens[msg_len] += 1

print(msg_lens)

# total token count = sum over all lengths of (length * number of messages of that length)
total_tokens = sum(length * count for length, count in msg_lens.items())
print("total_tokens:", total_tokens)

{21: 755, 25: 245}
total_tokens: 21980


In [9]:
# -------------------------------------------
# learn BPE merges until the vocabulary reaches the desired size
vocab_size = 12544 # the desired final vocabulary size
num_merges = vocab_size - len(vocab) # one new token is created per merge
ids = list(encoded_messages) # fresh outer list; the loop below only rebinds `ids`

merges = {} # (int, int) -> int
for i in range(num_merges):
  stats = get_stats(ids)
  # most frequent adjacent pair; on a tie the pair that was seen first wins
  pair = max(stats, key=stats.get)
  # merged tokens are numbered consecutively, starting at len(vocab)
  idx = len(vocab) + i
  print(f"merging {pair} into a new token {idx}")
  new_msgs = [merge(msg, pair, idx) for msg in ids]
  ids = new_msgs
  merges[pair] = idx

merging (2, 2) into a new token 12515
merging (12515, 12515) into a new token 12516
merging (3, 3) into a new token 12517
merging (37, 203) into a new token 12518
merging (12516, 12516) into a new token 12519
merging (12519, 0) into a new token 12520
merging (12051, 1003) into a new token 12521
merging (2, 12517) into a new token 12522
merging (12011, 12009) into a new token 12523
merging (12010, 12008) into a new token 12524
merging (37, 204) into a new token 12525
merging (12521, 12524) into a new token 12526
merging (12521, 12523) into a new token 12527
merging (12518, 969) into a new token 12528
merging (1108, 12522) into a new token 12529
merging (12051, 1004) into a new token 12530
merging (1008, 12517) into a new token 12531
merging (2, 3) into a new token 12532
merging (12526, 12007) into a new token 12533
merging (12517, 12518) into a new token 12534
merging (12051, 1006) into a new token 12535
merging (12530, 12523) into a new token 12536
merging (303, 426) into a new token 1

In [10]:
# token count before BPE: base-vocabulary tokens over all messages (encoded_messages)
orig_tokens_len = sum(map(len, encoded_messages))
print("orig_tokens_len:", orig_tokens_len)

# token count after applying every learned merge
new_tokens_len = sum(map(len, ids))
print("new_tokens_len:", new_tokens_len)

# ratio of original token count to BPE token count
compression_ratio = orig_tokens_len / new_tokens_len
print(f"compression ratio: {compression_ratio:.2f}X")

ids

orig_tokens_len: 25000
new_tokens_len: 11956
compression ratio: 2.09X


[[12535, 12523, 11013, 12543, 13, 633, 12518, 3, 259, 29, 12520],
 [12533, 12540, 8, 833, 209, 12518, 9, 89, 235, 12520],
 [12541, 12529, 22, 163, 12518, 9, 108, 395, 12520],
 [12535, 12524, 11137, 2908, 1108, 3, 4, 167, 528, 12518, 10, 272, 920, 12520],
 [12535, 12524, 11737, 1208, 12531, 208, 195, 12518, 10, 478, 112, 12520],
 [12527, 11017, 12540, 10, 931, 433, 12518, 18, 406, 542, 12520],
 [12526, 11991, 12540, 46, 344, 844, 12518, 61, 748, 383, 12520],
 [12527, 11992, 12529, 29, 55, 12518, 61, 774, 435, 12520],
 [12527, 11013, 12540, 12, 672, 769, 12518, 71, 444, 201, 12520],
 [12535,
  12523,
  11013,
  12543,
  500,
  617,
  12518,
  71,
  941,
  815,
  12009,
  11013,
  12539,
  71,
  444,
  201,
  0],
 [12535,
  12523,
  11017,
  12543,
  14,
  328,
  12518,
  71,
  953,
  140,
  12009,
  11017,
  12539,
  18,
  406,
  542,
  0],
 [12541, 1258, 12532, 4, 61, 977, 12518, 73, 12, 114, 12520],
 [12527, 11013, 12529, 415, 552, 12518, 73, 424, 663, 12520],
 [12526, 11025, 1020, 125

In [11]:
# histogram of message lengths after all merges: length (in tokens) -> number of messages
msg_lens = {}
for msg_len in map(len, ids):
    if msg_len not in msg_lens:
        msg_lens[msg_len] = 0
    msg_lens[msg_len] += 1

msg_lens

{11: 245,
 10: 299,
 9: 62,
 14: 18,
 12: 56,
 17: 73,
 13: 10,
 20: 26,
 18: 32,
 21: 15,
 22: 1,
 19: 29,
 8: 41,
 6: 36,
 16: 35,
 15: 22}

### Decoding

In [12]:
vocab.ENCODING

{'time': (array([-10000, -20000,  -9999, ...,    997,    998,    999], dtype=int32),
  array([   0,    1,    2, ..., 1000, 1001, 1002], dtype=int32)),
 'type': (array([-10000, -20000,  -9999,      1,      2,      3,      4,      5],
        dtype=int32),
  array([   0,    1,    2, 1003, 1004, 1005, 1006, 1007], dtype=int32)),
 'size': (array([-10000, -20000,  -9999, ...,   9997,   9998,   9999], dtype=int32),
  array([    0,     1,     2, ..., 11005, 11006, 11007], dtype=int32)),
 'price': (array([-10000, -20000,  -9999, ...,    997,    998,    999], dtype=int32),
  array([    0,     1,     2, ..., 12005, 12006, 12007], dtype=int32)),
 'sign': (array([-10000, -20000,  -9999,     -1,      1], dtype=int32),
  array([    0,     1,     2, 12008, 12009], dtype=int32)),
 'side': (array([-10000, -20000,  -9999,      0,      1], dtype=int32),
  array([    0,     1,     2, 12010, 12011], dtype=int32)),
 'ticker': (array([-10000, -20000,  -9999,      1,      2,      3,      4,      5,
          

In [13]:
# fill out new byte-pair encoding vocab
bpe_vocab = {}

# # print each key in vocab.ENCODING
# for key in vocab.ENCODING:
#     key_enc = vocab.ENCODING[key]
#     for i in range(len(key_enc[1])):
#         bpe_vocab[key_enc[1][i]] = key_enc[0][i]

# print("Length before adding merges:", len(bpe_vocab))

# add new merges to the bpe vocab
for (p0, p1), idx in merges.items():
    bpe_vocab[idx] = (p0, p1)

# print("Length after adding merges:", len(bpe_vocab))
print("Length of bpe_vocab:", len(bpe_vocab)) # should be equal to num_merges

bpe_vocab

Length of bpe_vocab: 29


{12515: (2, 2),
 12516: (12515, 12515),
 12517: (3, 3),
 12518: (37, 203),
 12519: (12516, 12516),
 12520: (12519, 0),
 12521: (12051, 1003),
 12522: (2, 12517),
 12523: (12011, 12009),
 12524: (12010, 12008),
 12525: (37, 204),
 12526: (12521, 12524),
 12527: (12521, 12523),
 12528: (12518, 969),
 12529: (1108, 12522),
 12530: (12051, 1004),
 12531: (1008, 12517),
 12532: (2, 3),
 12533: (12526, 12007),
 12534: (12517, 12518),
 12535: (12051, 1006),
 12536: (12530, 12523),
 12537: (303, 426),
 12538: (12537, 119),
 12539: (1108, 12518),
 12540: (1108, 12532),
 12541: (12527, 12007),
 12542: (12522, 12534),
 12543: (1108, 12531)}

In [14]:
bpe_vocab[12515]

(2, 2)

In [15]:
def bpe_decode(tokens, vocab, bpe_table=None):
    """Expand BPE tokens back into base-vocabulary tokens.

    Any token whose id is >= len(vocab) is treated as a merged token and is
    replaced by the pair it stands for; this repeats until only base tokens
    remain. The order of the tokens is preserved.

    Args:
        tokens: iterable of token ids (base and/or merged).
        vocab: base vocabulary; only len(vocab) is used, as the boundary
            between base ids and merged ids.
        bpe_table: optional dict {merged_id: (left_id, right_id)}. When omitted
            the notebook-level `bpe_vocab` is used.

    Returns:
        list of base-vocabulary token ids.

    Raises:
        KeyError: a token id is >= len(vocab) but has no entry in the table.
    """
    table = bpe_vocab if bpe_table is None else bpe_table
    base_size = len(vocab)

    msg = []
    # tokens still to be processed, stored reversed so pop() yields them left to right
    pending = list(tokens)[::-1]
    while pending:
        token = pending.pop()
        if token >= base_size:
            if token not in table:
                raise KeyError(f"token {token} is not a base token and has no BPE merge entry")
            left, right = table[token]
            # push right first so that left is expanded (and emitted) first
            pending.append(right)
            pending.append(left)
        else:
            msg.append(token)
    return msg

# take the first BPE-compressed message produced by the training loop
example_tokens = ids[0]
print(example_tokens)

# expand merged tokens back into base-vocabulary tokens
decoded_msg = bpe_decode(example_tokens, vocab)
decoded_msg = np.array(decoded_msg)
print(decoded_msg)
# printed for a visual comparison with the expanded tokens above
print(encoded_messages[0])

# decode the message back to the original message
# the [-25:-1] slice drops the last element, i.e. the end-of-message token appended earlier
decoded_msg = decode_msg(np.array(decoded_msg[-25:-1].tolist()), vocab.ENCODING)
print(decoded_msg)

[12535, 12523, 11013, 12543, 13, 633, 12518, 3, 259, 29, 12520]
[12051  1006 12011 12009 11013  1108  1008     3     3    13   633    37
   203     3   259    29     2     2     2     2     2     2     2     2
     0]
[12051  1006 12011 12009 11013  1108  1008     3     3    13   633    37
   203     3   259    29     2     2     2     2     2     2     2     2
     0]
[    40  -9999      4      1  -9999      5    100      0      0  10630
  34200 256026  -9999  -9999  -9999  -9999  -9999  -9999]


### Encoding

In [16]:
def get_stats_single(ids):
    """Count adjacent token pairs inside one token sequence.

    Args:
        ids: a single sequence of token ids.

    Returns:
        dict mapping (left_token, right_token) -> occurrence count, with keys
        in order of first appearance.
    """
    tally = {}
    for pos in range(len(ids) - 1):
        key = (ids[pos], ids[pos + 1])
        tally[key] = tally[key] + 1 if key in tally else 1
    return tally

merges

{(2, 2): 12515,
 (12515, 12515): 12516,
 (3, 3): 12517,
 (37, 203): 12518,
 (12516, 12516): 12519,
 (12519, 0): 12520,
 (12051, 1003): 12521,
 (2, 12517): 12522,
 (12011, 12009): 12523,
 (12010, 12008): 12524,
 (37, 204): 12525,
 (12521, 12524): 12526,
 (12521, 12523): 12527,
 (12518, 969): 12528,
 (1108, 12522): 12529,
 (12051, 1004): 12530,
 (1008, 12517): 12531,
 (2, 3): 12532,
 (12526, 12007): 12533,
 (12517, 12518): 12534,
 (12051, 1006): 12535,
 (12530, 12523): 12536,
 (303, 426): 12537,
 (12537, 119): 12538,
 (1108, 12518): 12539,
 (1108, 12532): 12540,
 (12527, 12007): 12541,
 (12522, 12534): 12542,
 (1108, 12531): 12543}

In [17]:
def bpe_encode(msg):
    """Compress one base-encoded message using the learned `merges`.

    Repeatedly picks, among the adjacent pairs currently present, the one that
    was learned earliest (lowest merged token id) and merges all its
    occurrences. Stops when no present pair has a learned merge.

    Args:
        msg: sequence of base-vocabulary token ids.

    Returns:
        list of token ids after all applicable merges.
    """
    ids = list(msg)
    while len(ids) > 1:
        # adjacent pairs that have a learned merge, mapped to their merged token id
        candidates = {pair: merges[pair] for pair in zip(ids, ids[1:]) if pair in merges}
        if not candidates:
            # nothing else can be merged anymore
            break
        # earliest learned merge first, mirroring the order used during training
        best = min(candidates, key=candidates.get)
        ids = merge(ids, best, candidates[best])
    return ids

# compute first phase encoding of example message (before merges)
# NOTE(review): judging by the printed output, this encoding has no trailing end-of-message
# token, so the (12519, 0) -> 12520 merge cannot fire and the last token printed below differs
example_encoding = encode_msg(decoded_msg, vocab.ENCODING)
print(example_encoding)

# NOTE(review): this rebinds `ids`, which until now held the list of all BPE-compressed messages
ids = bpe_encode(example_encoding)
print(ids)
print(example_tokens)

[12051  1006 12011 12009 11013  1108  1008     3     3    13   633    37
   203     3   259    29     2     2     2     2     2     2     2     2]
[12535, 12523, 11013, 12543, 13, 633, 12518, 3, 259, 29, 12519]
[12535, 12523, 11013, 12543, 13, 633, 12518, 3, 259, 29, 12520]


In [21]:
print(bpe_decode(ids, vocab))
print(bpe_decode(example_tokens, vocab))

[12051, 1006, 12011, 12009, 11013, 1108, 1008, 3, 3, 13, 633, 37, 203, 3, 259, 29, 2, 2, 2, 2, 2, 2, 2, 2]
[12051, 1006, 12011, 12009, 11013, 1108, 1008, 3, 3, 13, 633, 37, 203, 3, 259, 29, 2, 2, 2, 2, 2, 2, 2, 2, 0]


### Using Tokenizer Class

In [24]:
# instantiate
bpe_tokenizer = BasicTokenizer()

# prepare dataset
messages = encode_msgs(proc_messages, vocab.ENCODING)

# # add a "end of message" token to the end of each message by appending a new column
# eom_token_val = 0
# messages = np.concatenate([messages, np.full((messages.shape[0], 1), eom_token_val)], axis=1)

print("messages.shape:", messages.shape)

messages

messages.shape: (1586524, 24)


array([[12051,  1006, 12011, ...,     2,     2,     2],
       [12051,  1003, 12010, ...,     2,     2,     2],
       [12051,  1003, 12011, ...,     2,     2,     2],
       ...,
       [12051,  1007, 12010, ...,   964,   938,   499],
       [12051,  1003, 12010, ...,     2,     2,     2],
       [12051,  1007, 12011, ...,   984,   956,     8]], dtype=int32)

In [25]:
# train
bpe_tokenizer.train(messages, 12544, vocab, verbose=True)

merge 1/29: (2, 2) -> 12515 ((2, 2)) had 5180819 occurrences
merge 2/29: (12515, 12515) -> 12516 ((12515, 12515)) had 2220351 occurrences
merge 3/29: (3, 3) -> 12517 ((3, 3)) had 1205701 occurrences
merge 4/29: (12010, 12008) -> 12518 ((12010, 12008)) had 786733 occurrences
merge 5/29: (12011, 12009) -> 12519 ((12011, 12009)) had 766025 occurrences
merge 6/29: (12516, 12516) -> 12520 ((12516, 12516)) had 740117 occurrences
merge 7/29: (12051, 1003) -> 12521 ((12051, 1003)) had 739582 occurrences
merge 8/29: (12051, 1006) -> 12522 ((12051, 1006)) had 638282 occurrences
merge 9/29: (2, 12517) -> 12523 ((2, 12517)) had 608509 occurrences
merge 10/29: (1008, 12517) -> 12524 ((1008, 12517)) had 406446 occurrences
merge 11/29: (12521, 12518) -> 12525 ((12521, 12518)) had 374236 occurrences
merge 12/29: (1108, 12523) -> 12526 ((1108, 12523)) had 357606 occurrences
merge 13/29: (12521, 12519) -> 12527 ((12521, 12519)) had 344281 occurrences
merge 14/29: (12522, 12518) -> 12528 ((12522, 12518))

In [26]:
bpe_tokenizer.merges
bpe_tokenizer.bpe_vocab

{12515: (2, 2),
 12516: (12515, 12515),
 12517: (3, 3),
 12518: (12010, 12008),
 12519: (12011, 12009),
 12520: (12516, 12516),
 12521: (12051, 1003),
 12522: (12051, 1006),
 12523: (2, 12517),
 12524: (1008, 12517),
 12525: (12521, 12518),
 12526: (1108, 12523),
 12527: (12521, 12519),
 12528: (12522, 12518),
 12529: (12522, 12519),
 12530: (1008, 3),
 12531: (11009, 1108),
 12532: (2, 3),
 12533: (1108, 12524),
 12534: (12051, 1007),
 12535: (12008, 11008),
 12536: (1108, 12532),
 12537: (1108, 12530),
 12538: (12009, 12531),
 12539: (12051, 1004),
 12540: (12008, 12531),
 12541: (12527, 11009),
 12542: (12525, 11009),
 12543: (12534, 12519)}

In [27]:
# test encoding and decoding capabilities
example_msg = messages[0]
print("example_msg:", example_msg)

# encode
encoded_msg = bpe_tokenizer.bpe_encode(example_msg)
print("encoded_msg:", encoded_msg)

# decode
decoded_msg = bpe_tokenizer.bpe_decode(encoded_msg, vocab)
print("decoded_msg:", decoded_msg)


example_msg: [12051  1006 12011 12009 11013  1108  1008     3     3    13   633    37
   203     3   259    29     2     2     2     2     2     2     2     2]
encoded_msg: [12529, 11013, 12533, 13, 633, 37, 203, 3, 259, 29, 12520]
decoded_msg: [12051, 1006, 12011, 12009, 11013, 1108, 1008, 3, 3, 13, 633, 37, 203, 3, 259, 29, 2, 2, 2, 2, 2, 2, 2, 2]


In [28]:
# compress the dataset: run every base-encoded message through the trained tokenizer
compressed_msgs = [bpe_tokenizer.bpe_encode(msg) for msg in messages]

# token count of the original (base-encoded) messages
orig_tokens_len = sum(map(len, messages))
print("orig_tokens_len:", orig_tokens_len)

# token count of the BPE-compressed messages
compressed_tokens_len = sum(map(len, compressed_msgs))
print("compressed_tokens_len:", compressed_tokens_len)

# ratio of original token count to compressed token count
compression_ratio = orig_tokens_len / compressed_tokens_len
print(f"compression ratio: {compression_ratio:.2f}X")

orig_tokens_len: 38076576
compressed_tokens_len: 24235883
compression ratio: 1.57X


In [29]:
compressed_msgs

# TODO: pad(?) and then save these

# Alternative: bpe_encode the messages dynamically when they are loaded (i.e. at train time);
# in that case the BPE-encoded dataset does not need to be saved at all

[[12529, 11013, 12533, 13, 633, 37, 203, 3, 259, 29, 12520],
 [12525, 12007, 12536, 8, 833, 209, 37, 203, 9, 89, 235, 12520],
 [12527, 12007, 12526, 22, 163, 37, 203, 9, 108, 395, 12520],
 [12528, 11137, 2908, 1108, 3, 4, 167, 528, 37, 203, 10, 272, 920, 12520],
 [12528, 11737, 1208, 12524, 208, 195, 37, 203, 10, 478, 112, 12520],
 [12527, 11017, 12536, 10, 931, 433, 37, 203, 18, 406, 542, 12520],
 [12525, 11991, 12536, 46, 344, 844, 37, 203, 61, 748, 383, 12520],
 [12527, 11992, 12526, 29, 55, 37, 203, 61, 774, 435, 12520],
 [12527, 11013, 12536, 12, 672, 769, 37, 203, 71, 444, 201, 12520],
 [12529,
  11013,
  12533,
  500,
  617,
  37,
  203,
  71,
  941,
  815,
  12009,
  11013,
  1108,
  37,
  203,
  71,
  444,
  201],
 [12529,
  11017,
  12533,
  14,
  328,
  37,
  203,
  71,
  953,
  140,
  12009,
  11017,
  1108,
  37,
  203,
  18,
  406,
  542],
 [12527, 12007, 1258, 12532, 4, 61, 977, 37, 203, 73, 12, 114, 12520],
 [12527, 11013, 12526, 415, 552, 37, 203, 73, 424, 663, 12520],

In [30]:
# save the tokenizer
with open('bpe_tokenizer.pkl', 'wb') as f:
    pkl.dump(bpe_tokenizer, f)


#### Using Tokenizer in Training Loop

In [31]:
import torch

In [32]:
# locate the pre-processed message files of the train split
train_data_dir = parent_folder_path + '/dataset/proc/ITCH/train/'
train_message_files = sorted(glob(str(train_data_dir) + '/*message*.npy'))
assert len(train_message_files) > 0, f'no message files found in {train_data_dir}'

# ... and of the validation split
val_data_dir = parent_folder_path + '/dataset/proc/ITCH/val/'
val_message_files = sorted(glob(str(val_data_dir) + '/*message*.npy'))
assert len(val_message_files) > 0, f'no message files found in {val_data_dir}'

print("len(train_message_files):", len(train_message_files))
print("len(val_message_files):", len(val_message_files))

# seeded generator used for every random dataset choice below
seed = 42
rng = random.Random(seed)

# memory-map each message file (mmap_mode='r' opens it read-only); one entry per file
train_datasets = [np.load(path, mmap_mode='r') for path in train_message_files]
val_datasets = [np.load(path, mmap_mode='r') for path in val_message_files]

# print shape of each dataset in train_datasets
for dataset in train_datasets:
    print(dataset.shape)

train_datasets

len(train_message_files): 6
len(val_message_files): 1
(1586524, 18)
(1942488, 18)
(1953936, 18)
(1206471, 18)
(727346, 18)
(1388800, 18)


[memmap([[       40,   9795465,         4, ...,     -9999,     -9999,
              -9999],
         [       40,   9806105,         1, ...,     -9999,     -9999,
              -9999],
         [       40,   9806141,         1, ...,     -9999,     -9999,
              -9999],
         ...,
         [       40, 327992593,         5, ...,     57599, 961935496,
              16027],
         [       40, 327992597,         1, ...,     -9999,     -9999,
              -9999],
         [       40, 327992857,         5, ...,     57599, 981953005,
              16528]]),
 memmap([[       40,  10134653,         1, ...,     -9999,     -9999,
              -9999],
         [       40,  10140325,         1, ...,     -9999,     -9999,
              -9999],
         [       40,  10140337,         1, ...,     -9999,     -9999,
              -9999],
         ...,
         [       40, 387932841,         1, ...,     -9999,     -9999,
              -9999],
         [       40, 387932961,         5, ...,   

In [77]:
# use the GPU when one is visible, otherwise fall back to CPU so the cells below still run
# (kept as a plain string because device_type is derived from it with a substring check)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
# NOTE: this draw advances `rng`, so it influences every later rng.choice call
sample_dataset = rng.choice(train_datasets)
print("train_data.shape:", sample_dataset.shape)
vocab = Vocab()
eom_token_val = 0 # token appended after each BPE-encoded message
msg_seq_len = 432 # 112 # 432  (number of messages per sampled sequence)
batch_size = 1
use_sink = True # prepend a sink token to every sequence
use_bpe = True # compress messages with the trained BPE tokenizer

# load the tokenizer
# NOTE: unpickling can execute arbitrary code - only load tokenizer files you created yourself
with open('bpe_tokenizer.pkl', 'rb') as f:
    bpe_tokenizer = pkl.load(f)

train_data.shape: (1942488, 18)


In [78]:
# poor man's data loader
def get_batch(split):
    """Sample one random batch (x, y) for next-token prediction.

    Reads the notebook-level settings `train_datasets`, `val_datasets`, `rng`,
    `msg_seq_len`, `batch_size`, `use_bpe`, `use_sink`, `vocab`, `bpe_tokenizer`,
    `eom_token_val`, `device` and `device_type`.

    Args:
        split: 'train' samples from train_datasets; any other value samples
            from val_datasets.

    Returns:
        (x, y): token-id tensors on `device`, where y is x shifted by one token.
        With use_bpe the sequence length depends on how well the sampled
        messages compress, so it varies from call to call.
    """
    # data = train_data if split == 'train' else val_data
    datasets = train_datasets if split == 'train' else val_datasets
    data = rng.choice(datasets)
    # random start offsets, one window of msg_seq_len consecutive messages per batch element
    ix = torch.randint(len(data) - msg_seq_len, (batch_size,))
    if use_bpe:
        assert batch_size == 1, "batch size must be 1 for BPE encoding (for now)" # TODO: make this batch-wise
        # basic encoding of the messages
        basic_encoded = [(encode_msgs((data[i:i+msg_seq_len]).astype(np.int64), vocab.ENCODING)) for i in ix]
        # bpe encode the messages and concat EOM token to the end
        # (extend/append grow the list in place; the former `list + list` re-copied
        # the whole sequence for every message)
        bpe_encoded = []
        for window in basic_encoded:
            for msg in window:
                bpe_encoded.extend(bpe_tokenizer.bpe_encode(msg))
                bpe_encoded.append(eom_token_val)
        # convert to tensor, unsqueeze to add batch dimension
        x = torch.tensor(bpe_encoded).unsqueeze(0)
    else:
        x = torch.stack([torch.from_numpy((encode_msgs((data[i:i+msg_seq_len]).astype(np.int64), vocab.ENCODING)).reshape(-1)) for i in ix])
    if use_sink:
        # append sink token to start of each batch sequence (since vocab.SINK_TOK = 1, we can just use torch.ones)
        x = torch.cat([torch.ones((batch_size, 1), dtype=torch.int), x], dim=1)
    # target y is the same as x but shifted by one token
    y = x[:, 1:]
    y = y.type(torch.LongTensor) # casting to long for cross entropy loss fn
    x = x[:, :-1] # offset x by one (final) token to match y
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [34]:
data = rng.choice(train_datasets)
ix = torch.randint(len(data) - msg_seq_len, (batch_size,))

In [66]:
# torch.stack([torch.from_numpy((encode_msgs((data[i:i+msg_seq_len]).astype(np.int64), vocab.ENCODING)).reshape(-1)) for i in ix])
# [torch.from_numpy((encode_msgs((data[i:i+msg_seq_len]).astype(np.int64), vocab.ENCODING)).reshape(-1)) for i in ix]
# base-vocabulary encoding of each sampled window of msg_seq_len messages
basic_encoded = [(encode_msgs((data[i:i+msg_seq_len]).astype(np.int64), vocab.ENCODING)) for i in ix]

# bpe encode the messages and concat EOM token to the end
assert len(basic_encoded) == 1
bpe_encoded = []
for batch in range(len(basic_encoded)):
    for msg in range(len(basic_encoded[batch])):
        # grow the flat token list in place instead of re-copying it for every message
        bpe_encoded.extend(bpe_tokenizer.bpe_encode(basic_encoded[batch][msg]))
        bpe_encoded.append(eom_token_val)

# convert to tensor, unsqueeze to add batch dimension
x = torch.tensor(bpe_encoded).unsqueeze(0)
print("x.shape:", x.shape)


# x = torch.stack([torch.from_numpy((encode_msgs((data[i:i+msg_seq_len]).astype(np.int64), vocab.ENCODING)).reshape(-1)) for i in ix])

x.shape: torch.Size([1, 1790])


In [53]:
# np.full((1, 1), eom_token_val)
# bpe_tokenizer.bpe_encode(basic_encoded[batch][msg]) + [eom_token_val]

[12525, 11039, 12536, 4, 90, 268, 37, 821, 218, 86, 565, 12520, 0]

In [81]:
X, Y = get_batch('train') # fetch the very first batch
print("X.shape:", X.shape) # (batch_size, block_size)
print("Y.shape:", Y.shape)

X

X.shape: torch.Size([1, 6875])
Y.shape: torch.Size([1, 6875])


tensor([[    1, 12529, 11028,  ...,   956,    96,   985]], device='cuda:0')

In [83]:
X.shape[1] * 1.57 # close to 10368 ?

10793.75