# Quantum Deep Hedging in Quantum Environments
--------------

### Importing required packages


In [1]:
import warnings
from tqdm import tqdm

# Suppress every Python warning for the rest of the session to keep cell
# outputs uncluttered.
# NOTE(review): this also hides deprecation warnings emitted by the imported
# libraries — consider narrowing the filter while developing.
warnings.filterwarnings("ignore")

This notebook requires the following packages:
- `jax`: a numerical computing library that provides GPU acceleration and supports automatic differentiation.
- `haiku`: a neural network library built on top of JAX that provides a more user-friendly API for defining neural networks.
- `optax`: a library for implementing gradient-based optimization algorithms in JAX.

In [2]:
import numpy as np 
import haiku as hk 
import jax 

### Importing source code 
The source code imported here consists of the functions used by the Quantum Deep Hedging algorithm.
- `env.py` contains Environment functions
  - `compute_black_scholes_delta` computes the Black-Scholes delta for a given sequence of prices.
  - `compute_prices` computes the stock prices at each day given the jump sequences and other parameters.
  - `compute_rewards` computes the rewards given a sequence of prices and deltas.
  - `compute_bounds` computes the bounds of R(t) for a given set of parameters.
  - `compute_returns` computes the sequence of cumulative returns.
  - `compute_utility` computes the utility of the sequence of rewards using exponential utility.
- `quantum.py` contains functions to build quantum compound NNs
  - `get_brick_idxs` computes the indices of the RBS gates for the Brick architecture and returns a nested list where each inner list contains pairs of indices indicating the RBS gates to be applied in parallel.
  - `make_ortho_fn` creates a mapping between a set of parameters to an orthogonal matrix and returns it as a JAX function that maps a set of parameters to an orthogonal matrix.
  - `compute_compound` computes the compound matrix given the orthogonal matrix (unary) and the order k.
  - `decompose_state` decomposes a quantum state into a sum of weighted subspaces and their respective projections.
- `agent.py` contains functions to build, train and evaluate agents
  - `make_train_nn` creates a compound NN that uses a batching tree during training
  - `make_test_nn` creates a compound NN that can take any sequence of jumps as input
  - `make_train` creates agent trainers for the vanilla and actor-critic algorithms
  - `make_test` creates agent evaluators for trained quantum policies

In [3]:
import source

### Experiment

In [4]:
# Jump sequences used for evaluation; `train_experiment` below reads this
# variable as a module-level global.
# NOTE(review): the `.pkl` extension suggests a pickle file — only load it
# from a trusted source.
test_seq_jumps = source.utils.load_file("./data/hardware_seq_jumps.pkl")

def train_experiment(seed, train_steps, **hparams):
    """Train an agent for a given number of steps and keep the best parameters.

    Every iteration first evaluates the current parameters with the test step
    on the module-level ``test_seq_jumps`` (truncated along its second axis to
    ``hparams['num_days']`` entries), then performs one training step. The
    loop ends early as soon as any training metric is NaN.

    Args:
        seed: Seed of the NumPy ``RandomState`` from which the seed of the
            Haiku ``PRNGSequence`` is drawn.
        train_steps: Maximum number of iterations to run.
        **hparams: Forwarded unchanged to ``source.agent.make_train`` and
            ``source.agent.make_test``. Must contain ``"num_days"``.

    Returns:
        A tuple ``(best_utility, bs_utility, best_params, all_metrics,
        all_params)`` where

        - ``best_utility`` is the highest test ``"utility"`` observed
          (``-inf`` if no iteration ran),
        - ``bs_utility`` is the test ``"bs_utility"`` of the last iteration
          (NaN if no iteration ran),
        - ``best_params`` are the parameters that produced ``best_utility``
          (``None`` if no utility ever exceeded ``-inf``),
        - ``all_metrics`` is a list of ``(iteration, train_metrics,
          test_metrics)`` tuples,
        - ``all_params`` is a list of ``(iteration, params)`` tuples holding
          the parameters returned by that iteration's training step.
    """
    init_step, train_step = source.agent.make_train(**hparams)
    test_step = source.agent.make_test(**hparams)
    np_random = np.random.RandomState(seed=seed)
    rng_key = hk.PRNGSequence(np_random.randint(0, 2**32 - 1))
    params, opt_state = init_step(next(rng_key))
    best_utility, best_params = -np.inf, None
    # Defined before the loop so the return statement stays valid when
    # train_steps == 0 (previously an UnboundLocalError).
    bs_utility = float("nan")
    all_metrics, all_params = [], []

    for it in tqdm(range(train_steps)):
        # --- test step: evaluate the parameters *before* this iteration's update
        # TODO: remove the num_days truncation after debugging and evaluate on
        # the full sequences: test_step(test_seq_jumps, params)
        test_metrics = test_step(test_seq_jumps[:, :hparams["num_days"]], params)
        test_metrics = jax.device_get(test_metrics)
        # jax.tree_util.tree_map is the stable spelling; the top-level
        # jax.tree_map alias is deprecated.
        test_metrics = jax.tree_util.tree_map(float, test_metrics)
        utility = test_metrics["utility"]
        bs_utility = test_metrics["bs_utility"]
        if utility > best_utility:
            best_utility = utility
            best_params = params
        test_metrics["best"] = best_utility

        # --- train step
        params, opt_state, train_metrics = train_step(
            next(rng_key),
            params,
            opt_state,
        )
        train_metrics = jax.device_get(train_metrics)
        train_metrics = jax.tree_util.tree_map(float, train_metrics)
        all_metrics.append((it, train_metrics, test_metrics))
        all_params.append((it, params))

        # Abort once training has diverged; the NaN metrics of this iteration
        # are still recorded above.
        if any(np.isnan(value) for value in train_metrics.values()):
            break

    return best_utility, bs_utility, best_params, all_metrics, all_params

In [5]:
# --- Run control -----------------------------------------------------------
seed = 123
train_steps = 10

# --- Settings forwarded through `hparams` ----------------------------------
num_days = 5
num_trading_days = 30
initial_price = 100.0
mu = 0.0
sigma = 0.2
strike = 1.0
cost_eps = 0.0
utility_lambda = 0.1  # assigned once (the cell used to set it twice)
train_num_paths = 16

# --- Algorithm selection ---------------------------------------------------
# "vanilla" selects the vanilla model without a critic optimizer; any other
# value selects the actor-critic model with an Adam critic.
critic_update = "vanilla"

# The actor settings are identical for both algorithms.
actor_opt = "radam"
actor_lr = 1E-3

if critic_update == "vanilla":
    model = "vanilla"
    critic_opt = None
    critic_lr = None
else:
    model = "actor_critic"
    critic_opt = "adam"
    critic_lr = 1E-2

# Keyword arguments handed to `train_experiment`.
hparams = dict(
    model=model,
    num_days=num_days,
    num_trading_days=num_trading_days,
    mu=mu,
    sigma=sigma,
    initial_price=initial_price,
    strike=strike,
    cost_eps=cost_eps,
    utility_lambda=utility_lambda,
    train_num_paths=train_num_paths,
    actor_opt=actor_opt,
    actor_lr=actor_lr,
    critic_update=critic_update,
    critic_opt=critic_opt,
    critic_lr=critic_lr,
)


In [6]:
# Run the experiment with the seed, step budget and hyper-parameters
# configured in the previous cell.
(
    best_utility,
    bs_utility,
    best_params,
    all_metrics,
    all_params,
) = train_experiment(seed=seed, train_steps=train_steps, **hparams)

100%|██████████| 10/10 [00:16<00:00,  1.65s/it]


In [7]:
# Compare the best test utility reached during training with `bs_utility`
# (labelled as the Black-Scholes utility in the message), both to 3 decimals.
print(f"best utility / black-scholes utility = {best_utility:.3f}/{bs_utility:.3f}")

best utility / black-scholes utility = -3.174/-3.422


In [8]:
# File-name prefix identifying the run: cost_eps, strike, critic_update, seed.
prefix = (
    f"{hparams['cost_eps']}_{hparams['strike']}_{hparams['critic_update']}_{seed}"
)

# Destination path -> object to store; dict insertion order keeps the
# original write order.
log_targets = {
    f"./logs/{prefix}-best_params.pkl": best_params,
    f"./logs/{prefix}-all_metrics.pkl": all_metrics,
    f"./logs/{prefix}-all_params.pkl": all_params,
}
for log_path, payload in log_targets.items():
    source.utils.save_file(log_path, payload)
    
