# Installs and Imports

In [1]:
import optax
from jax import numpy as jnp
import jax
from functools import partial
from tqdm import tqdm, trange
import os

# Data

We begin by importing the modules from the repository. Then, to generate train and test data sets of geometric Brownian motion simulations, we use the `gen_paths()` function in `data.py`. 

Note: The data generation has been commented out here because we pickled a particular train and test split, which is loaded from disk further below. 

In [2]:
from source.models import simple_network, recurrent_network, lstm_network, attention_network
from source.qnn import linear, ortho_linear, ortho_linear_noisy
from source.train import build_train_fn, gen_paths, entropy_loss
from source.utils import train_test_split, get_batches, HyperParams
import numpy as np
import pickle

# Fixed seed for the JAX random key created below.
seed = 100
# Hyper-parameters used by the (commented-out) data generation; only n_steps and
# discrete_path are set explicitly here.
hps = HyperParams(discrete_path=False, n_steps=30)
# Root JAX PRNG key derived from the seed.
key = jax.random.PRNGKey(seed)

# Data
# S = gen_paths(hps)
# [S_train, S_test] = train_test_split([S], test_size=0.2)
# _, train_batches = get_batches(jnp.array(S_train[0]), batch_size=hps.batch_size)
# _, test_batches = get_batches(jnp.array(S_test[0]), batch_size=hps.batch_size)

In [3]:
# saving train and test batches
# pickle.dump(train_batches, open('/content/drive/MyDrive/JPMC/train_batches_30_days', 'wb'))
# pickle.dump(test_batches, open('/content/drive/MyDrive/JPMC/test_batches_30_days', 'wb'))

# DeepHedgingBenchmark()

This is a class to benchmark the performance of different models and layers for deep hedging. It has two private methods, `__train_model()` and `__test_model()`, which are used to train and test a single deep learning model. It also has two public methods, `train()` and `test()`, which run them for every combination of layer, epsilon and model on the given input data.

## Parameters
- `key`: A random key value used for jax random splitting.
- `eps`: A list of float values representing the hedge intervals to be used in training.
- `layers`: A list of string values representing the layer types to be used in 
training. It should only contain values from `['linear', 'ortho_pyramid', 'ortho_butterfly', 'noisy_ortho_pyramid', 'noisy_ortho_butterfly']`.
- `models`: A list of string values representing the model types to be used in training. It should only contain values from `['simple', 'recurrent', 'lstm', 'attention']`.

## Methods
- `__train_model(hps, train_batches)`: A private method that trains the model. It takes in hyperparameters hps and train batches train_batches. The hyperparameters include layer_type, model_type, n_steps, epsilon, and num_epochs. It returns the training losses, trained model parameters and time taken for training.
- `__test_model(hps,params, test_batches)`: A private method that tests the model. It takes in hyperparameters hps, trained model parameters params and test batches test_batches. The hyperparameters include layer_type, model_type, n_steps, epsilon, and num_epochs. It returns the testing loss.
- `train(inputs, save_loc)`: A method to train the models. It takes in inputs and save_loc. Inputs are training data, and save_loc is the path where the training results are to be saved. If the path exists, it loads the previously saved results from that location and continues training. It trains a model using the `__train_model()` method for each combination of layers, eps and models, skipping combinations that were already loaded. It saves the training information in train_info as a dictionary.
- `test(test_batches,save_loc)`:  A method to test the trained models. It takes in test_batches and save_loc. test_batches are testing data, and save_loc is the path where the training results were saved. If the path exists, it loads the saved results from that location. It then evaluates each saved model using the `__test_model()` method for each combination of layers, eps and models, and prints an error for any combination that has not been trained.

In [4]:
# Layer type names accepted by DeepHedgingBenchmark (validated in its __init__).
_LAYERS = [
    'linear',
    'ortho_pyramid', 'ortho_butterfly',
    'noisy_ortho_pyramid', 'noisy_ortho_butterfly',
]

# Model type names accepted by DeepHedgingBenchmark (validated in its __init__).
_MODELS = ['simple', 'recurrent', 'lstm', 'attention']
class DeepHedgingBenchmark():
  """
  Trains and evaluates every (layer, epsilon, model) combination it is given.

  Results are kept in ``self.train_info``, a nested dictionary indexed as
  ``train_info[layer][str(eps)][model]`` whose values are the tuples
  ``(train_losses, params, elapsed)`` returned by the private training routine.
  ``train`` checkpoints this dictionary to disk with pickle after each model so
  that an interrupted sweep can be resumed; ``test`` reloads it and reports the
  loss of each stored model on the test batches.
  """
  def __init__(self, key, eps,  layers, models):
    """
    Stores the sweep configuration and creates an empty ``train_info``.

    Args:
    - key: a JAX random key; it is split to initialise and drive every model
    - eps: a list of epsilon values to sweep over
    - layers: a list of layer types; each must be listed in ``_LAYERS``
    - models: a list of model types; each must be listed in ``_MODELS``

    Raises:
    - AssertionError: if a layer or model type is not a known one
    """
    assert all(layer in _LAYERS for layer in layers), "Layers don't have valid layer types."
    assert all(model in _MODELS for model in models), "Models don't have valid model types."
    self.__key = key
    self.__models = models
    self.__layers = layers
    self.__eps = eps
    self.train_info = {layer:{str(eps):{} for eps in self.__eps} for layer in self.__layers}

  def __make_hps(self, eps, model, layer):
    """Builds the HyperParams shared by training and testing for one combination."""
    return HyperParams(S0=100,
                       n_steps=30,
                       n_paths=120000,
                       discrete_path=False,
                       strike_price=100,
                       epsilon=eps,
                       sigma=0.2,
                       risk_free=0,
                       dividend=0,
                       model_type=model,
                       layer_type=layer,
                       n_features=16,
                       n_layers=3,
                       loss_param=1.0,
                       batch_size=256,
                       test_size=0.2,
                       optimizer='adam',
                       learning_rate=1E-3,
                       num_epochs=100
                       )

  def __build_network(self, hps):
    """
    Builds the network selected by ``hps.layer_type`` and ``hps.model_type``.

    Raises:
    - ValueError: if either type is not one this class knows how to build
    """
    layer_type = hps.layer_type
    if layer_type == 'linear':
      layer_func = linear
    else:
      # The layout is the last '_'-separated token, e.g. 'pyramid' or 'butterfly'.
      layout = layer_type.split('_')[-1]
      if layer_type.startswith('ortho'):
        layer_func = partial(ortho_linear, layout=layout)
      elif layer_type.startswith('noisy_ortho'):
        layer_func = partial(ortho_linear_noisy, layout=layout, noise_scale=0.01)
      else:
        raise ValueError(f'Unknown layer type: {layer_type!r}')

    # Built lazily inside the method so the constructors are only looked up on use.
    builders = {
        'simple': simple_network,
        'recurrent': recurrent_network,
        'lstm': lstm_network,
        'attention': attention_network,
    }
    if hps.model_type not in builders:
      raise ValueError(f'Unknown model type: {hps.model_type!r}')
    return builders[hps.model_type](hps=hps, layer_func=layer_func)

  def __save_train_info(self, save_loc):
    """Pickles ``self.train_info`` to ``save_loc``, creating its folder if needed."""
    folder = os.path.dirname(save_loc)
    if folder:
      os.makedirs(folder, exist_ok=True)
    with open(save_loc, 'wb') as fh:
      # pickle.dump takes the object first and the file second.
      pickle.dump(self.train_info, fh)

  def __train_model(self, hps, train_batches):
    """
    Trains a model with the given hyperparameters and training batches.

    Args:
    - hps: a HyperParams object specifying the hyperparameters for the model
    - train_batches: the training batches; the collection is iterated once per
      epoch, so it must be re-iterable (e.g. a list), not a one-shot generator

    Returns:
    - train_losses: a list with the mean batch loss of each epoch
    - params: the final model parameters
    - elapsed: the elapsed time reported by the progress bar at the last epoch
    """
    net = self.__build_network(hps)

    opt = optax.adam(1E-3)
    key, init_key = jax.random.split(self.__key)
    params, state, _ = net.init(init_key, (1, hps.n_steps, 1))
    opt_state = opt.init(params)

    # Training

    train_fn, _ = build_train_fn(hps, net, opt, entropy_loss)
    num_epochs = hps.num_epochs
    train_losses = []
    elapsed = 0
    with trange(1, num_epochs+1) as t:
      for epoch in t:
        loss_epoch = []
        for inputs in train_batches:
          # Append a trailing axis of size one to each batch.
          inputs = inputs[...,None]
          key, train_key = jax.random.split(key)
          params, state, opt_state, loss, _ = train_fn(
              params, state, opt_state, train_key, inputs)
          loss_epoch.append(loss)
        loss = jnp.mean(jnp.array(loss_epoch))
        train_losses.append(loss)
        t.set_postfix(loss=loss,model=hps.model_type, layer=hps.layer_type, eps=hps.epsilon)
        if epoch==num_epochs:
          elapsed = t.format_dict["elapsed"]
    return train_losses,params, elapsed

  def __test_model(self, hps,params, test_batches):
    """
    Evaluates trained parameters on the test batches and prints the mean loss.

    Args:
    - hps: a HyperParams object specifying the hyperparameters for the model
    - params: the trained model parameters
    - test_batches: an iterable of test batches

    Returns:
    - loss: the mean of the per-batch losses
    """
    net = self.__build_network(hps)

    opt = optax.adam(1E-3)
    key, init_key = jax.random.split(self.__key)
    # NOTE(review): the network state comes from a fresh init because only the
    # params are stored after training -- confirm this is intended.
    _, state, _ = net.init(init_key, (1, hps.n_steps, 1))

    # Testing

    _, loss_fn = build_train_fn(hps, net, opt, entropy_loss)
    # Wrap once, outside the loop, instead of re-wrapping for every batch.
    jitted_loss_fn = jax.jit(loss_fn)
    loss_epoch = []
    for inputs in test_batches:
      inputs = inputs[...,None]
      key, test_key = jax.random.split(key)
      loss,_ = jitted_loss_fn(params, state, test_key, inputs)
      loss_epoch.append(loss)
    loss = jnp.mean(jnp.array(loss_epoch))
    print(f'Model = {hps.model_type} | Layer = {hps.layer_type} | EPS = {hps.epsilon}| Loss = {loss}')
    return loss

  def train(self, inputs, save_loc):
    """
    Trains every configured combination, resuming from ``save_loc`` if it exists.

    Combinations already present in the loaded ``train_info`` are reported and
    skipped; every newly trained model is checkpointed to ``save_loc`` right away.

    Args:
    - inputs: the training batches handed to the training routine
    - save_loc: path of the pickle file holding ``train_info``
    """
    if os.path.exists(save_loc):
      # NOTE: unpickling can execute arbitrary code; only load trusted files.
      with open(save_loc, 'rb') as fh:
        self.train_info = pickle.load(fh)
    else:
      self.train_info = {layer:{str(eps):{} for eps in self.__eps} for layer in self.__layers}
    for layer in self.__layers:
      for eps in self.__eps:
        # A loaded checkpoint may predate this layer/eps; create the slot if missing.
        runs = self.train_info.setdefault(layer, {}).setdefault(str(eps), {})
        for model in self.__models:
            if model in runs:
              train_losses,params, elapsed = runs[model]

              loss = min(train_losses)
              num_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
              epochs = len(train_losses)
              print(f'[eps={eps}, layer={layer}, loss={loss}, model={model}, num_params = {num_params}, elapsed = {elapsed}, num_epochs = {epochs}] already saved, continuing...')
              continue

            hps = self.__make_hps(eps=eps, model=model, layer=layer)
            runs[model] = self.__train_model(hps, inputs)
            self.__save_train_info(save_loc)

  def test(self, test_batches, save_loc):
    """
    Evaluates every configured combination that has stored parameters.

    Args:
    - test_batches: an iterable of test batches
    - save_loc: path of the pickle file holding ``train_info``; loaded if it exists
    """
    if os.path.exists(save_loc):
      # NOTE: unpickling can execute arbitrary code; only load trusted files.
      with open(save_loc, 'rb') as fh:
        self.train_info = pickle.load(fh)
    for model in self.__models:
      for eps in self.__eps:
        for layer in self.__layers:
            # Missing layer/eps entries are reported below rather than raising KeyError.
            saved = self.train_info.get(layer, {}).get(str(eps), {})
            if model in saved:
              _,params, _ = saved[model]
              hps = self.__make_hps(eps=eps, model=model, layer=layer)
              self.__test_model(hps,params, test_batches)
            else:
              print(f"Error! layer={layer}, eps={eps}, model={model} not found")

In [5]:
# Sweep configuration: model types, layer types and epsilon values to benchmark.
MODELS = ['simple', 'recurrent', 'lstm', 'attention']
LAYERS = ['linear', 'ortho_pyramid', 'ortho_butterfly']
EPS = [0.0, 0.01]

# Root JAX PRNG key handed to the benchmark.
seed = 100
key = jax.random.PRNGKey(seed)

dhb = DeepHedgingBenchmark(
    key=key,
    eps=EPS,
    layers=LAYERS,
    models=MODELS,
)

In [6]:
# Load the pickled training batches; the with-block closes the file handle.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open('data/train_batches_30_days', 'rb') as fh:
    train_batches = pickle.load(fh)

In [7]:
# Train every configured combination; results are loaded from / stored at save_loc.
dhb.train(inputs=train_batches, save_loc='params/train_info.pkl')

[eps=0.0, layer=linear, loss=2.8535077571868896, model=simple, num_params = 881, elapsed = 68.77824378013611, num_epochs = 100] already saved, continuing...
[eps=0.0, layer=linear, loss=2.9285597801208496, model=recurrent, num_params = 881, elapsed = 285.0432679653168, num_epochs = 100] already saved, continuing...
[eps=0.0, layer=linear, loss=2.850743532180786, model=lstm, num_params = 569, elapsed = 217.25190377235413, num_epochs = 100] already saved, continuing...
[eps=0.0, layer=linear, loss=2.8527638912200928, model=attention, num_params = 1905, elapsed = 68.99567246437073, num_epochs = 100] already saved, continuing...
[eps=0.01, layer=linear, loss=5.03331995010376, model=simple, num_params = 881, elapsed = 62.16713500022888, num_epochs = 100] already saved, continuing...
[eps=0.01, layer=linear, loss=5.045419692993164, model=recurrent, num_params = 881, elapsed = 283.9681055545807, num_epochs = 100] already saved, continuing...
[eps=0.01, layer=linear, loss=4.733962059020996, mo

In [8]:
# Load the pickled test batches; the with-block closes the file handle.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open('data/test_batches_30_days', 'rb') as fh:
    test_batches = pickle.load(fh)

In [9]:
# Evaluate every stored model on the test batches and print each loss.
dhb.test(test_batches=test_batches, save_loc='params/train_info.pkl')

Model = simple | Layer = linear | EPS = 0.0| Loss = 2.8678064346313477
Model = simple | Layer = ortho_pyramid | EPS = 0.0| Loss = 2.8730738162994385
Model = simple | Layer = ortho_butterfly | EPS = 0.0| Loss = 2.8743155002593994
Model = simple | Layer = linear | EPS = 0.01| Loss = 5.0649566650390625
Model = simple | Layer = ortho_pyramid | EPS = 0.01| Loss = 5.048909664154053
Model = simple | Layer = ortho_butterfly | EPS = 0.01| Loss = 5.043356418609619
Model = recurrent | Layer = linear | EPS = 0.0| Loss = 2.933910369873047
Model = recurrent | Layer = ortho_pyramid | EPS = 0.0| Loss = 2.939173460006714
Model = recurrent | Layer = ortho_butterfly | EPS = 0.0| Loss = 2.930788516998291
Model = recurrent | Layer = linear | EPS = 0.01| Loss = 5.075922966003418
Model = recurrent | Layer = ortho_pyramid | EPS = 0.01| Loss = 5.101611137390137
Model = recurrent | Layer = ortho_butterfly | EPS = 0.01| Loss = 4.854455947875977
Model = lstm | Layer = linear | EPS = 0.0| Loss = 2.852872371673584
