In [None]:
import os
import requests
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
import optax
from jax import value_and_grad
import pickle
import pandas as pd
import math

from helper_funcs import generate, masked_fill, loss_fn, get_token_batch, encode, decode
from tqdm import tqdm
import matplotlib.pyplot as plt
# Print the list of devices JAX reports as available, so the backend in use
# is visible in the notebook output before any model is built.
print(jax.devices())

## Parameter Selection

The parameters used below are a scaled-down version of GPT-2. GPT-2 comes in four sizes: small, medium, large and XL. The model configured in this notebook could be considered an extra-small version. Note that the full-size models may not fit into RAM on your device. The exact specifications of the different model sizes are shown below:

### GPT-2 Small
- n_embed: 768
- block_size: 1024
- num_heads: 12
- num_layers: 12
- vocab_size: 50257 (uses Tiktoken vocab)

### GPT-2 Medium
- n_embed: 1024
- block_size: 1024
- num_heads: 16
- num_layers: 24
- vocab_size: 50257 (uses Tiktoken vocab)

### GPT-2 Large
- n_embed: 1280
- block_size: 1024
- num_heads: 20
- num_layers: 36
- vocab_size: 50257 (uses Tiktoken vocab)

### GPT-2 XL
- n_embed: 1600
- block_size: 1024
- num_heads: 25
- num_layers: 48
- vocab_size: 50257 (uses Tiktoken vocab)

In [None]:
# --- Model architecture ---------------------------------------------------
n_embed = 32       # width of each token embedding vector
num_heads = 4      # attention heads inside every multi-headed block
num_layers = 6     # number of stacked transformer decoder blocks
drop_rate = 0.1    # dropout rate used for regularization

# --- Batching / context ---------------------------------------------------
batch_size = 4     # independent sequences processed in parallel
block_size = 480   # maximum context length used for predictions

# --- Data -----------------------------------------------------------------
# Column of the dataset that holds the token ids ('vector_id' is the
# alternative column that was used previously).
token_id_col = 'cluster'

# --- Randomness -----------------------------------------------------------
# Root PRNG key for the notebook.
rng_key = jax.random.PRNGKey(128)

## Data Preparation

In [None]:
# Load the pickled vocabulary artefact.
# NOTE: pickle can execute arbitrary code while loading; only open files
# that come from a trusted source.
with open('vocab.pickle', 'rb') as vocab_file:
    loaded_data = pickle.load(vocab_file)

# The vocabulary size is not taken from loaded_data['vocab']; it is derived
# from the token-id column of X in a later cell.

# Read the dataset and display it as the cell output.
X = pd.read_parquet("X.parquet")
X

In [None]:
# Test: encode a small window of rows (positions 100..223 of X) with the
# project's encode helper, reading token ids from the configured column.
df_slice = X.iloc[100:224]

# NOTE(review): the return type of encode is not visible here; later cells
# wrap it in a list and feed it to jnp.array, so it is presumably a flat
# sequence of integer token ids - confirm against helper_funcs.
v = encode(df_slice, token_id_col=token_id_col)
v

In [None]:
# Round-trip check: pass the encoded ids from the previous cell back through
# decode and display the result.
# NOTE(review): presumably decode uses X as a lookup table to turn token ids
# back into rows - confirm against helper_funcs.
decode(v, X, token_id_col)

In [None]:
# Token ids are used as indices into the embedding table and as class labels
# for the output logits, so the model needs one slot per id in [0, max_id],
# i.e. max_id + 1 entries. Using the bare maximum leaves the largest id out
# of range, and JAX does not raise on out-of-bounds indices, so the mistake
# would go unnoticed.
# Assumes the ids in this column are non-negative integers - TODO confirm.
# int() converts the numpy scalar into a plain Python int so it can be used
# as a static model dimension.
vocab_size = int(X[token_id_col].max()) + 1
vocab_size

## Build the Attention Model

In [None]:
from attention_model import *

In [None]:
# Build the model from the hyper-parameters chosen in the configuration cell.
model = GPT2(
    vocab_size,
    n_embed,
    block_size,
    num_heads,
    num_layers,
    drop_rate,
)

# An all-zero batch of token ids with the training shape
# (batch_size, block_size); it is only used to initialise the variables.
dummy_x = jnp.zeros((batch_size, block_size), jnp.uint16)
variables = model.init(rng_key, dummy_x)

In [None]:
# Sanity check: run one forward pass on the dummy batch with the freshly
# initialised variables and print the shape of the result.
# NOTE(review): expected to be (batch_size, block_size, vocab_size) if the
# model returns per-position logits - confirm against attention_model.
out = model.apply(variables, dummy_x)
print(out.shape)

## Time Series Generation Pre-Training

In [None]:
# Optimiser step size, used when the optimizer is created in the training section.
learning_rate = 1e-4

# Number of tokens to sample when generating.
max_new_tokens = 24

# Seed sequence for generation: a single sequence holding one token with id 0.
index_seq = jnp.zeros((1, 1), jnp.uint16)

# Optional: sample from the untrained model as a baseline (left disabled).
#generated_indices = generate(variables, model.apply, index_seq, rng_key, vocab_size, 1, block_size, max_new_tokens)
#generated_indices = list(np.array(generated_indices[0]))
#decode(generated_indices, X)

## Train the Model

In [None]:
# AdamW with a constant learning rate (the first positional argument of
# optax.adamw); every other optimiser setting keeps its optax default.
optimizer = optax.adamw(learning_rate)

# Optimiser state matching the structure of the model variables.
opt_state = optimizer.init(variables)

In [None]:
# Number of optimisation steps. Each "epoch" here is a single update on one
# freshly sampled batch, not a full pass over the data.
epochs = 500
train_ids = X[token_id_col].to_numpy()
losses = []

# Differentiate loss_fn with respect to its first argument (the variables).
# Built once up front instead of being re-created on every iteration.
loss_and_grad = value_and_grad(loss_fn, argnums=0)

pbar = tqdm(range(epochs))
for step in pbar:
    # Fresh subkey per step so every batch is sampled with its own key.
    rng_key, subkey = jax.random.split(rng_key)
    xb, yb = get_token_batch(train_ids, subkey, batch_size, block_size)

    loss, grads = loss_and_grad(variables, model.apply, xb, yb)

    # Convert to a host-side Python float: the history then holds plain
    # numbers rather than device arrays, which keeps the pickle small.
    loss = float(loss)

    # Stop BEFORE applying the update when training has diverged. Checking
    # after the update (as was done previously) would overwrite `variables`
    # and `opt_state` with the non-finite step, and the checkpoint written
    # below would be unusable.
    if math.isnan(loss):
        break

    updates, opt_state = optimizer.update(grads, opt_state, variables)
    variables = optax.apply_updates(variables, updates)

    losses.append(loss)

    pbar.set_description(f"Epoch: {step}, Loss: {loss :.4f}")

# Save model: everything needed to resume training or to generate later.
# "epoch" is the index of the last step that was started (the step at which
# the loop broke, if the loss became NaN).
model_file = {
    "epochs": epochs,
    "epoch": step,
    "model": model,
    "vocab_size": vocab_size,
    "block_size": block_size,
    "variables": variables,
    "losses": losses,
    "opt_state": opt_state,
    "learning_rate": learning_rate,
}

# NOTE: pickle files can execute arbitrary code when loaded; only load
# checkpoints that come from a trusted source.
with open('model.pickle', 'wb') as f:
    pickle.dump(model_file, f)

In [None]:
# Training-loss curve: one point per recorded optimisation step.
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(losses)
ax.set_title('Loss')
plt.show()

## Time Series Generation Post-Training

In [None]:
rng_key, subkey = jax.random.split(rng_key)
max_new_tokens = 480

# The random start position needs room for the context window and the
# ground-truth continuation; fail early with a clear message otherwise.
assert len(X) > 3 * max_new_tokens, "dataset is too short to sample a test window"

# Draw the start row with the fresh subkey. rng_key itself is split again
# later, so consuming it here as well would reuse the same key.
t = int(random.randint(subkey, shape=(), minval=0, maxval=len(X)-(3*max_new_tokens)))

# Context fed to the model: rows [t, t + max_new_tokens).
X_test = X.iloc[t:t+max_new_tokens]
# Ground truth to compare against: the max_new_tokens rows that immediately
# follow the context, i.e. [t + max_new_tokens, t + 2 * max_new_tokens).
# (Starting one row later would skip the first value being predicted.)
Y_test = X.iloc[t+max_new_tokens:t+2*max_new_tokens].reset_index()

# Encode the context and add a leading batch axis of size 1.
x = [encode(X_test, token_id_col)]
index_seq = jnp.array(x)

print(index_seq.shape)
print(index_seq)

In [None]:
rng_key, subkey = jax.random.split(rng_key)

# Sample with the fresh subkey; rng_key is kept as the parent key for any
# later split, so passing it here as well would reuse the same key.
# NOTE(review): the literal 1 is presumably the batch size expected by
# generate - confirm against helper_funcs.
generated_indices = generate(variables, model.apply, index_seq, subkey, vocab_size, 1, block_size, max_new_tokens)

# Take the first (only) sequence of the batch and move it to the host.
generated_indices = list(np.array(generated_indices[0]))

# Turn the generated token ids back into rows via the decode helper.
Y = decode(generated_indices, X, token_id_col)

In [None]:
# Overlay the generated series on the held-out continuation.
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(Y.index, Y['Power demand'], label="Predicted")
ax.plot(Y_test.index, Y_test['Power demand'], label="Actual")
ax.set_title('Predicted vs actual')
ax.legend()
plt.show()