In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

(parent_folder_path, current_dir) = os.path.split(os.path.abspath(''))
sys.path.append(parent_folder_path)

os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import os
# import pickle
import numpy as np
from contextlib import nullcontext
import torch

from equities.model import GPTConfig, GPT
from equities.data_processing import itch_encoding

In [3]:
# INIT PARAMS
# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = parent_folder_path + '/out' # ignored if init_from is not 'resume'
dataset = '12302019.NASDAQ_ITCH50_AAPL_message_proc.npy' # dataset to use for initial prompt
# start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_context_msgs = 100 # 3 # number of messages from dataset to use as context
# num_samples = 10 # number of samples to draw
num_samples = 1 # number of samples to draw
# max_new_tokens = 500 # number of tokens generated in each sample
max_new_tokens = 1 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 42
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
# exec(open('equities/configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)


In [4]:
# model
if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)

model.eval()
model.to(device)

number of parameters: 94.57M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(12515, 768)
    (wpe): Embedding(10367, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=False)
          (c_proj): Linear(in_features=768, out_features=768, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=3072, out_features=768, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=768, out_features=12515, bias=False)
)

In [5]:
# define path for sample data
# dataset = '12302019.NASDAQ_ITCH50_AAPL_message_proc.npy'
data_dir = os.path.join('dataset/proc/ITCH/test/', dataset)
data_dir = parent_folder_path + '/' + data_dir

# grab sample data to use as context
context_dataset = np.load(data_dir, mmap_mode='r')
X_raw = np.array(context_dataset[0:num_context_msgs])
print("X_raw.shape:", X_raw.shape)
print("X_raw:", X_raw)

# encode the sample data
vocab = itch_encoding.Vocab()
X = itch_encoding.encode_msgs(X_raw, vocab.ENCODING)
print("X.shape:", X.shape)
print("X:", X)

# decode the sample data (will be missing order id, price_abs, old_id, and old_price_abs)
print("decoded X:", itch_encoding.decode_msgs(X, vocab.ENCODING))
print([ "ticker", "NA_VAL",
        "event_type", "direction", "NA_VAL", "price", "fill_size", "remain_size",
        "delta_t_s", "delta_t_ns", "time_s", "time_ns",
        "NA_VAL", "price_ref", "fill_size_ref", "time_s_ref", "time_ns_ref", "NA_VAL"])

X_raw.shape: (100, 18)
X_raw: [[       40   7872401         1 ...     -9999     -9999     -9999]
 [       40   7872405         1 ...     -9999     -9999     -9999]
 [       40   7872421         1 ...     -9999     -9999     -9999]
 ...
 [       40   8209581         1 ...     -9999     -9999     -9999]
 [       40   7963801         4 ...     34200 240679848     -9999]
 [       40   8209581         4 ...     34200 776929875     -9999]]
X.shape: (100, 24)
X: [[12051  1003 12010 ...     2     2     2]
 [12051  1003 12010 ...     2     2     2]
 [12051  1003 12011 ...     2     2     2]
 ...
 [12051  1003 12011 ...     2     2     2]
 [12051  1006 12011 ...   243   682   851]
 [12051  1006 12011 ...   779   932   878]]
decoded X: [[       40     -9999         1 ...     -9999     -9999     -9999]
 [       40     -9999         1 ...     -9999     -9999     -9999]
 [       40     -9999         1 ...     -9999     -9999     -9999]
 ...
 [       40     -9999         1 ...     -9999     -9999    

In [6]:
encoded_tok_len = X.shape[1]
print("encoded_tok_len:", encoded_tok_len)

encoded_tok_len: 24


In [7]:
# prepare context tensor
x = (torch.tensor(X.reshape(-1), dtype=torch.long, device=device)[None, ...])
print("x.shape:", x.shape)
print("x:", x)

# run generation
with torch.no_grad():
    with ctx:
        for k in range(num_samples):
            # y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            y = model.generate(x, max_new_tokens*encoded_tok_len, temperature=temperature, top_k=top_k)
            # print(decode(y[0].tolist()))
            print("last generated msg:", y[0][-1].tolist())
            # print(y[0].tolist())
            print('---------------')
        print("new sequence", y[0].tolist())

x.shape: torch.Size([1, 2400])
x: tensor([[12051,  1003, 12010,  ...,   779,   932,   878]], device='cuda:0')


last generated msg: 172
---------------
new sequence [12051, 1003, 12010, 12008, 12007, 1108, 2, 3, 18, 164, 949, 37, 203, 25, 606, 15, 2, 2, 2, 2, 2, 2, 2, 2, 12051, 1003, 12010, 12008, 11603, 1108, 2, 3, 3, 9, 902, 37, 203, 25, 612, 914, 2, 2, 2, 2, 2, 2, 2, 2, 12051, 1003, 12011, 12009, 12007, 1108, 2, 3, 3, 19, 785, 37, 203, 25, 629, 696, 2, 2, 2, 2, 2, 2, 2, 2, 12051, 1003, 12011, 12009, 11577, 1108, 2, 3, 3, 13, 621, 37, 203, 25, 640, 314, 2, 2, 2, 2, 2, 2, 2, 2, 12051, 1003, 12010, 12008, 12007, 1108, 2, 3, 3, 470, 892, 37, 203, 26, 108, 203, 2, 2, 2, 2, 2, 2, 2, 2, 12051, 1003, 12011, 12009, 12007, 1108, 2, 3, 3, 21, 811, 37, 203, 26, 127, 11, 2, 2, 2, 2, 2, 2, 2, 2, 12051, 1003, 12010, 12008, 11603, 1108, 2, 3, 3, 160, 127, 37, 203, 26, 284, 135, 2, 2, 2, 2, 2, 2, 2, 2, 12051, 1003, 12011, 12009, 11577, 1108, 2, 3, 3, 23, 59, 37, 203, 26, 304, 191, 2, 2, 2, 2, 2, 2, 2, 2, 12051, 1003, 12011, 12009, 11322, 1108, 2, 3, 3, 841, 1001, 37, 203, 27, 143, 189, 2, 2, 2, 2, 2, 2, 2, 2,

In [8]:
# print the last message in the generated sequence
print("last generated msg:", y[0][-24:].tolist())

# print(y[0].tolist())
print("y:", y)

# decode the generated sequence
# print("decoded y:", itch_encoding.decode_msg(y[0][-24:].tolist(), vocab.ENCODING))
# print("decoded msg:", itch_encoding.decode_msg(np.array(y[0][-24:].tolist()), vocab.ENCODING))
decoded_msg = itch_encoding.decode_msg(np.array(y[0][-24:].tolist()), vocab.ENCODING)
print(decoded_msg)
print([ "ticker", "NA_VAL",
        "event_type", "direction", "NA_VAL", "price", "fill_size", "remain_size",
        "delta_t_s", "delta_t_ns", "time_s", "time_ns",
        "NA_VAL", "price_ref", "fill_size_ref", "time_s_ref", "time_ns_ref", "NA_VAL"])


last generated msg: [12051, 1006, 12011, 12009, 11121, 1012, 1008, 3, 3, 113, 82, 37, 203, 780, 83, 722, 12009, 11124, 1012, 37, 203, 778, 218, 172]
y: tensor([[12051,  1003, 12010,  ...,   778,   218,   172]], device='cuda:0')
[       40     -9999         4         1     -9999       113         4
         0         0    110079     34200 777080719     -9999       116
         4     34200 775215169     -9999]
['ticker', 'NA_VAL', 'event_type', 'direction', 'NA_VAL', 'price', 'fill_size', 'remain_size', 'delta_t_s', 'delta_t_ns', 'time_s', 'time_ns', 'NA_VAL', 'price_ref', 'fill_size_ref', 'time_s_ref', 'time_ns_ref', 'NA_VAL']


In [9]:
X_true = np.array(context_dataset[0:num_context_msgs+max_new_tokens])
print("X_true:", X_true)

print("true last msg:", X_true[-1])

X_true: [[       40   7872401         1 ...     -9999     -9999     -9999]
 [       40   7872405         1 ...     -9999     -9999     -9999]
 [       40   7872421         1 ...     -9999     -9999     -9999]
 ...
 [       40   7963801         4 ...     34200 240679848     -9999]
 [       40   8209581         4 ...     34200 776929875     -9999]
 [       40   8209645         1 ...     -9999     -9999     -9999]]
true last msg: [       40   8209645         1         1     29277       333       100
     -9999         0      4786     34200 777045652     -9999     -9999
     -9999     -9999     -9999     -9999]


In [10]:
print("decoded msg:", decoded_msg)
print("predicted symbol:", decoded_msg[0])
print("predicted event type:", decoded_msg[2])
print("predicted price:", decoded_msg[5])
print("predicted fill size:", decoded_msg[6])
print("predicted remain size:", decoded_msg[7])
print("predicted time:", decoded_msg[10], decoded_msg[11])
print("predicted price ref:", decoded_msg[13])
print("predicted fill size ref:", decoded_msg[14])
print("predicted time ref:", decoded_msg[15], decoded_msg[16])

decoded_msg

decoded msg: [       40     -9999         4         1     -9999       113         4
         0         0    110079     34200 777080719     -9999       116
         4     34200 775215169     -9999]
predicted symbol: 40
predicted event type: 4
predicted price: 113
predicted fill size: 4
predicted remain size: 0
predicted time: 34200 777080719
predicted price ref: 116
predicted fill size ref: 4
predicted time ref: 34200 775215169


array([       40,     -9999,         4,         1,     -9999,       113,
               4,         0,         0,    110079,     34200, 777080719,
           -9999,       116,         4,     34200, 775215169,     -9999])