# GAN Development - Kochems Approach


## Import libraries

In [None]:

%load_ext autoreload
%autoreload 2

from sklearn.preprocessing import MinMaxScaler
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import csv
from collections import defaultdict
from data_reader import DataReader
from torch.autograd import Variable, grad
import matplotlib.pyplot as plt

np.set_printoptions(threshold=np.inf)

## Read in Data

### Functions to preprocess data

In [None]:
def convert_ask_bid_int(dataset):
    """Encode the side column numerically and cast the dataset to float32.

    In column 2 of the last axis, entries ending in 'ask' become '1' and
    entries ending in 'bid' become '0'. The replacement is written into the
    array that was passed in; the float32 array that is returned is a new
    array produced by ``astype``.
    """
    # Basic slicing yields a view, so the assignments below land in `dataset`.
    side_column = dataset[:, :, 2]
    # Both masks are computed before any entry is overwritten.
    is_ask = np.char.endswith(side_column, 'ask')
    is_bid = np.char.endswith(side_column, 'bid')
    side_column[is_ask] = '1'
    side_column[is_bid] = '0'
    return dataset.astype(np.float32)
    
    
def get_dataset_max_price(dataset, rows_per_orderbook, level = -1):
    """Return the largest price found in one fixed row of every order book.

    With ``level == -1`` row 0 of each order book is inspected; otherwise the
    row ``rows_per_orderbook // 2 - level`` is used. Prices are read from
    column 0.
    """
    row = 0 if level == -1 else rows_per_orderbook // 2 - level
    return np.max(dataset[:, row, 0])

def get_dataset_min_price(dataset, rows_per_orderbook, level = -1):
    """Return the smallest price found in one fixed row of every order book.

    With ``level == -1`` the last row of each order book is inspected;
    otherwise the row ``rows_per_orderbook // 2 + (level - 1)`` is used.
    Prices are read from column 0.
    """
    row = -1 if level == -1 else rows_per_orderbook // 2 + (level - 1)
    return np.min(dataset[:, row, 0])

def make_histogram_from_dataset(dataset, rows_per_orderbook = 100, bin_width = 0.5, level = -1):
    """Turn every order book into a signed quantity-by-price histogram.

    Each order book is read as rows of (price, quantity, side). Quantities on
    rows whose side value equals 0 are negated before being summed into price
    bins, so the two sides end up with opposite signs in the histogram.

    Args:
        dataset: numeric array indexed as [order book, row, column].
        rows_per_orderbook: forwarded to the min/max price helpers.
        bin_width: target width used to derive the number of bin edges.
        level: forwarded to the min/max price helpers (-1 uses the outermost
            rows).

    Returns:
        (histograms, hist_min, hist_max, bins) where ``histograms`` has one
        row per order book and ``bins`` holds the shared bin edges.

    The input ``dataset`` is not modified.
    """
    hist_max = get_dataset_max_price(dataset, rows_per_orderbook, level)
    hist_min = get_dataset_min_price(dataset, rows_per_orderbook, level)
    print("range: ", hist_min, " ", hist_max)
    num_bins = int(np.ceil((hist_max-hist_min) / bin_width))
    # NOTE(review): linspace returns `num_bins` *edges*, i.e. num_bins - 1
    # bins of width (hist_max - hist_min) / (num_bins - 1) rather than
    # bin_width. Left as-is so existing histogram shapes do not change;
    # confirm whether num_bins + 1 edges were intended.
    bins = np.linspace(hist_min, hist_max, num_bins)

    X_train = []
    for orderbook in dataset:
        price = orderbook[:, 0]
        # Build the signed quantities in a new array. The previous in-place
        # `*= -1` on a slice wrote through to `dataset`, so calling this
        # function twice on the same data flipped the signs back.
        quantity = np.where(orderbook[:, 2] == 0, -orderbook[:, 1], orderbook[:, 1])
        hist, _ = np.histogram(price, bins=bins, weights=quantity)
        X_train.append(hist)
    X_train = np.array(X_train)
    return X_train, hist_min, hist_max, bins

def _first_ask_bin(book):
    """Return the index of the first positive bin directly preceded by a negative bin.

    Returns ``len(book)`` when no adjacent (negative, positive) pair exists.
    """
    book = np.asarray(book)
    crossings = np.flatnonzero((book[:-1] < 0) & (book[1:] > 0))
    return int(crossings[0]) + 1 if crossings.size else len(book)


def make_centred_LOB_snapshots(histograms, level = 1):
    """Build (current, next) training pairs centred on the sign change.

    For every consecutive pair of histograms the bin where the values switch
    from negative to positive is located, and a window of ``2 * level`` bins
    around it is cut out. The current window is padded with ``level`` zeros on
    each side; the next window is padded with ``level + d`` zeros before and
    ``level - d`` zeros after, where ``d`` is the current crossing index minus
    the next crossing index. Every output row is therefore ``4 * level`` wide.

    A pair is skipped when
      * the crossing moved by more than ``level`` bins between the two
        histograms, or
      * either window would run past the start or end of its histogram
        (this includes histograms with no adjacent negative/positive bins).

    Args:
        histograms: sequence of 1-D signed histograms, in time order.
        level: number of bins kept on each side of the crossing.

    Returns:
        (X_train, y_train): two arrays of shape (n_pairs, 4 * level).

    Raises:
        ValueError: if no pair survives the filters above.
    """
    window = 2 * level
    padding = np.zeros((level,))
    # Each histogram is the "next" of one pair and the "current" of the
    # following pair, so locate every crossing once up front.
    ask_bins = [_first_ask_bin(book) for book in histograms]

    X_train = []
    y_train = []
    for i, (current_OB, next_OB) in enumerate(zip(histograms[:-1], histograms[1:])):
        j = ask_bins[i]
        k = ask_bins[i + 1]
        jk_diff = j - k
        if abs(jk_diff) > level:
            continue

        current_start = j - level
        next_start = k - level
        # A negative start would be read as "count from the end" by slicing,
        # and an end past the array would be truncated; both give rows of the
        # wrong width, so such pairs are dropped instead.
        if current_start < 0 or next_start < 0:
            continue
        if current_start + window > len(current_OB) or next_start + window > len(next_OB):
            continue

        current_snapshot = current_OB[current_start:current_start + window]
        next_snapshot = next_OB[next_start:next_start + window]

        X_train.append(np.concatenate((padding, current_snapshot, padding)))
        y_train.append(np.concatenate((
            np.zeros((level + jk_diff,)),
            next_snapshot,
            np.zeros((level - jk_diff,)),
        )))

    if not X_train:
        raise ValueError(
            "no consecutive histogram pair produced a complete centred window"
        )

    return np.vstack(X_train), np.vstack(y_train)
        

### Actually reading in data

In [None]:
# Load the raw snapshots, then encode the side column as 1/0 and cast to
# float32 in a single step.
data_reader = DataReader("orderbook_snapshots100.csv", rows_per_orderbook=100)
data_reader.read_csv()
X_train_raw = convert_ask_bid_int(data_reader.get_data())
print(X_train_raw.shape)

### Preprocess data

In [None]:
# One signed histogram per order book, all sharing the same bin edges.
histograms, price_min, price_max, bins = make_histogram_from_dataset(
    X_train_raw,
    rows_per_orderbook=100,
    bin_width=0.5,
    level=-1,
)
print(histograms.shape)
print(histograms)
# print(X_train_raw)

In [None]:
# Pair every histogram with its successor. With level=3 each row is
# 3 zeros + a 6-bin centred window + 3 zeros, i.e. 12 values wide, which is
# the same width as Generator's default n=12.
X_train, y_train = make_centred_LOB_snapshots(histograms, level=3)

In [None]:
# Each X_train row is [level zeros | 2*level centred bins | level zeros], so
# the negative-to-positive boundary sits in the middle of the row: the bin
# just left of the centre is the last negative (bid-side) bin and the bin at
# the centre is the first positive (ask-side) bin. The previous hard-coded
# columns 3 and 8 pointed at the two *outermost* bins of the window for
# level=3, not the ones adjacent to the boundary.
centre = X_train.shape[1] // 2
best_asks = X_train[:, centre]
best_bids = X_train[:, centre - 1]

plt.hist(best_bids, bins=30, density=True)
plt.show()
plt.hist(best_asks, bins=30, density=True)
plt.show()

## WGAN-GP Training

### Defining the Critic and Generator

In [None]:
class Generator(nn.Module):
    """Two-branch MLP that maps (noise, previous_state) to an ``n``-wide vector.

    The noise (width ``latent_dim``) and the previous state (width ``n``) are
    each passed through two 32-unit ReLU layers, the two 32-wide results are
    concatenated, and a 64-unit ReLU layer followed by a linear layer
    produces the output.
    """

    def __init__(self, latent_dim=5, n=12):
        super().__init__()
        # Attribute names and creation order are part of the state_dict
        # layout and of seeded initialisation, so they are kept as they were.
        self.h_11 = nn.Linear(latent_dim, 32)
        self.h_12 = nn.Linear(n, 32)
        self.h_21 = nn.Linear(32, 32)
        self.h_22 = nn.Linear(32, 32)
        self.h_32 = nn.Linear(64, 64)
        self.h_33 = nn.Linear(64, n)

    def forward(self, noise, previous_state):
        """Return the generated vector for a batch of noise/state pairs."""
        noise_features = torch.relu(self.h_21(torch.relu(self.h_11(noise))))
        state_features = torch.relu(self.h_22(torch.relu(self.h_12(previous_state))))

        # Noise branch first, state branch second, joined along the feature axis.
        merged = torch.cat((noise_features, state_features), dim=1)

        return self.h_33(torch.relu(self.h_32(merged)))

#Markovian setting
class Critic(nn.Module):
    """Two-branch MLP that scores a (previous_state, current_state) pair.

    Both states must be ``n`` wide. Each goes through its own pair of 32-unit
    ReLU layers; the results are concatenated (current-state branch first),
    passed through a 64-unit ReLU layer and reduced to a single unbounded
    score per sample by a final linear layer.
    """

    def __init__(self, n=32):
        super().__init__()
        # Attribute names and creation order are part of the state_dict
        # layout and of seeded initialisation, so they are kept as they were.
        self.h_11 = nn.Linear(n, 32)
        self.h_12 = nn.Linear(n, 32)
        self.h_21 = nn.Linear(32, 32)
        self.h_22 = nn.Linear(32, 32)
        self.h_32 = nn.Linear(64, 64)
        self.h_33 = nn.Linear(64, 1)

    def forward(self, previous_state, current_state):
        """Return one score per row for the given state transition."""
        current_features = torch.relu(self.h_21(torch.relu(self.h_11(current_state))))
        previous_features = torch.relu(self.h_22(torch.relu(self.h_12(previous_state))))

        # Current-state branch first, previous-state branch second.
        merged = torch.cat((current_features, previous_features), dim=1)

        return self.h_33(torch.relu(self.h_32(merged)))

### Defining the training loop

#### Gradient Penalty

In [None]:
def gradient_penalty(D, real_samples, fake_samples):
    """Return the WGAN-GP penalty ``mean((||grad D(x_hat)||_2 - 1) ** 2)``.

    ``x_hat`` is a per-sample random interpolation between the flattened real
    and fake batches; one mixing coefficient is drawn per sample and shared
    across that sample's features.

    Args:
        D: callable taking a (batch, features) tensor and returning one score
            per sample, shaped (batch, 1) or (batch,).
        real_samples: batch of real samples; the first axis is the batch.
        fake_samples: batch of generated samples with the same number of
            elements per sample as ``real_samples``.

    Returns:
        A scalar tensor that can be back-propagated through (the gradient is
        computed with ``create_graph=True``).
    """
    batch_size = real_samples.size(0)

    # reshape (rather than view) also accepts non-contiguous inputs.
    real_flat = real_samples.reshape(batch_size, -1)
    fake_flat = fake_samples.reshape(batch_size, -1)

    # Shape (batch, 1) broadcasts across the feature axis; no explicit
    # expand is needed.
    alpha = torch.rand(batch_size, 1, device=real_samples.device)
    interpolates = (alpha * real_flat + (1 - alpha) * fake_flat).requires_grad_(True)
    d_interpolates = D(interpolates)

    # ones_like matches the critic output's shape, dtype and device, so the
    # seed gradient is valid whatever shape D returns.
    seed = torch.ones_like(d_interpolates)
    gradients = grad(outputs=d_interpolates, inputs=interpolates, grad_outputs=seed,
                     create_graph=True, retain_graph=True)[0]

    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

#### Training Loop of WGAN with Gradient Penalty

In [None]:
# WGAN-GP training cell: for every batch the critic is stepped n_critic times,
# then (on some batches) the generator is stepped once; per-epoch average
# losses are printed.
#
# NOTE(review): this cell does not match the definitions visible in this
# notebook and will presumably not run against them as written:
#   * `data_loader`, `discriminator` and `compute_gradient_penalty` are not
#     defined in the visible cells (the penalty helper above is called
#     `gradient_penalty`, the model instantiated below is called `critic`);
#   * `generator(z)` passes a single 100-wide noise tensor, whereas
#     Generator.forward above takes (noise, previous_state) and defaults to
#     latent_dim=5;
#   * `optimizer_D` / `optimizer_G` are created in a later cell.
# Confirm the intended data source and model wiring before relying on it.

# Define variables to track progress
avg_d_loss = 0
avg_g_loss = 0
n_batches = len(data_loader)

# Define the number of critic updates per generator update
n_critic = 5  

for epoch in range(200):
    for i, (imgs, _) in enumerate(data_loader):
        
        # Configure input: flatten every sample to a 1-D vector
        real_imgs = imgs.view(imgs.size(0), -1)
        
        # ---------------------
        #  Train Discriminator
        # ---------------------
        for _ in range(n_critic):  # Update the discriminator n_critic times
            optimizer_D.zero_grad()

            # Sample noise as generator input
            z = torch.randn(imgs.shape[0], 100)  # Ensure noise_dim matches generator input

            # Generate a batch of images
            fake_imgs = generator(z)

            # Real images
            real_validity = discriminator(real_imgs)
            # Fake images
            fake_validity = discriminator(fake_imgs)
            # Gradient penalty (computed on detached copies via .data)
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
            
            # Wasserstein GAN loss w/ gradient penalty
            # NOTE(review): the penalty is added with weight 1 here; confirm
            # whether a penalty coefficient is applied inside the helper.
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty

            d_loss.backward()
            optimizer_D.step()
            avg_d_loss += d_loss.item() / n_critic  # Average over the n_critic updates
        
        # -----------------
        #  Train Generator
        # -----------------
        # NOTE(review): the critic already takes n_critic steps per batch
        # above, so gating the generator on i % n_critic as well gives one
        # generator step per n_critic * n_critic critic steps — confirm intent.
        if i % n_critic == 0:  # Update the generator every n_critic steps
            optimizer_G.zero_grad()

            # Generate a batch of images (re-uses the last z drawn above)
            gen_imgs = generator(z)
            # Loss measures generator's ability to fool the discriminator
            g_loss = -torch.mean(discriminator(gen_imgs))

            g_loss.backward()
            optimizer_G.step()
            avg_g_loss += g_loss.item()

        # Prints progress within the epoch
        if (i+1) % 100 == 0:  # Print every 100 steps
            print(f"Epoch [{epoch+1}/{200}], Step [{i+1}/{n_batches}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
    
    # Prints average loss per epoch
    avg_d_loss /= n_batches
    # NOTE(review): avg_g_loss is only accumulated on batches where the
    # generator was stepped, but is divided by the total batch count.
    avg_g_loss /= n_batches
    print(f"Epoch [{epoch+1}/{200}] completed. Avg D Loss: {avg_d_loss:.4f}, Avg G Loss: {avg_g_loss:.4f}")
    # Resets average losses for the next epoch
    avg_d_loss = 0
    avg_g_loss = 0


## 4. Instantiate Models & Optimisers

In [None]:
# Instantiate the generator and the critic.
generator = Generator()
# The critic scores (previous_state, current_state) pairs, so its input width
# has to equal the width of the states the generator consumes and produces.
# Critic() on its own defaults to n=32 while Generator() defaults to n=12,
# which would fail on the first forward pass; size it from the generator.
critic = Critic(n=generator.h_33.out_features)

# Define the optimisers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.001)
optimizer_D = torch.optim.Adam(critic.parameters(), lr=0.001)


## 5. Training Loop

In [None]:
# Binary-cross-entropy GAN training cell: each batch takes one discriminator
# step on real + generated data, then one generator step.
#
# NOTE(review): this cell does not match the definitions visible in this
# notebook and will presumably not run against them as written:
#   * `discriminator` and `data_tensor` are not defined in the visible cells
#     (the model instantiated above is called `critic`, and Critic.forward
#     takes two arguments);
#   * Critic ends in a plain linear layer, while nn.BCELoss expects inputs
#     in [0, 1];
#   * noise_dim = 2 and a (batch, 1) conditional input do not match the
#     Generator defaults above (latent_dim=5, n=12).
# Confirm which model/data this loop is meant for.
criterion = nn.BCELoss()
n_epochs = 200  # ADJUST
batch_size = 64  # ADJUST
noise_dim = 2  # Size of noise vector

for epoch in range(n_epochs):
    # `i` is the offset of the first sample of the batch, not a batch counter.
    for i in range(0, len(data_tensor), batch_size):
        # Prepare real data batch
        real_data = data_tensor[i:i+batch_size, 1:]  # Exclude timestamp from training
        real_labels = torch.ones(real_data.size(0), 1)
        fake_labels = torch.zeros(real_data.size(0), 1)
        
        # Train Discriminator
        optimizer_D.zero_grad()
        
        # Real data loss
        real_output = discriminator(real_data)
        d_loss_real = criterion(real_output, real_labels)
        
        # Generate fake data
        noise = torch.randn(real_data.size(0), noise_dim)
        # Conditional input: the first remaining column of the real batch,
        # reshaped to (batch, 1). Placeholder — to be replaced by the intended
        # conditioning signal.
        conditional_input = real_data[:, 0].unsqueeze(1)  # This should be modified based on your specific conditional input


        # Right before generator(noise, conditional_input) call
        # print("Conditional input shape before generator:", conditional_input.shape)

        fake_data = generator(noise, conditional_input)
        
        # Fake data loss
        fake_output = discriminator(fake_data.detach())  # Detach to avoid training generator on these labels
        d_loss_fake = criterion(fake_output, fake_labels)
        
        # Combine loss and update discriminator
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()
        
        # Train Generator
        optimizer_G.zero_grad()
        
        # Trick discriminator into thinking the generated data is real
        output = discriminator(fake_data)
        g_loss = criterion(output, real_labels)
        
        g_loss.backward()
        optimizer_G.step()
        
        # NOTE(review): because `i` advances by batch_size, this fires only
        # when the sample offset is a multiple of 100, and the "Step" shown
        # is a sample offset while the denominator is a batch count.
        if i % 100 == 0:  # Adjust printing frequency based on your preference
            print(f"Epoch [{epoch+1}/{n_epochs}], Step [{i+1}/{len(data_tensor)//batch_size}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")


## 6. Save Generator Model

In [None]:
# Save the generator's state dictionary (parameters only, not the class).
# NOTE(review): the destination is a hard-coded absolute path inside one
# user's home directory; presumably this fails on any machine where that
# directory does not exist — confirm or make the location configurable.
torch.save(generator.state_dict(), '/Users/sina/Downloads/SCRATCH_GENERATIVE_ADV_NET/Saved_Generator_States/generator_state_dict.pth')
print("Generator state has been saved.")

## 7. Generate Example Orderbook Snapshots

In [None]:
# Generate a sample order book snapshot
# NOTE(review): `scaler` is not defined in the visible cells (MinMaxScaler is
# imported but never fitted here), and a (1, 1) conditioning tensor plus
# noise of width `noise_dim` does not match the Generator defaults above
# (latent_dim=5, n=12) — confirm which generator this cell targets.
with torch.no_grad():  # inference only, no gradients are tracked
    test_noise = torch.randn(1, noise_dim)
    test_price_level = torch.tensor([[0.5]])  # Example price level, normalized
    generated_snapshot = generator(test_noise, test_price_level)
    # Map the generated values back through the scaler's inverse transform.
    inverse_transformed_snapshot = scaler.inverse_transform(generated_snapshot.numpy())
    print("Generated Order Book Snapshot:", inverse_transformed_snapshot)
