# Quantum Deep Hedging in Quantum Environments
--------------

### Importing required packages


In [None]:
import warnings
from tqdm import tqdm

warnings.filterwarnings("ignore")

This notebook requires the following packages:
- `jax`: a numerical computing library that provides GPU acceleration and supports automatic differentiation.
- `haiku`: a neural network library built on top of JAX that provides a more user-friendly API for defining neural networks.
- `optax`: a library for implementing gradient-based optimization algorithms in JAX.

In [None]:
import numpy as np 
import haiku as hk 
import jax 

### Importing source code 
The imported source code comprises various functions utilized for the Quantum Deep Hedging algorithm.

- `env.py` includes environment functions:
  - `compute_black_scholes_delta` calculates the Black-Scholes delta for a given sequence of prices.
  - `compute_prices` computes the stock prices at each day, given the jump sequences and other parameters.
  - `compute_rewards` computes the rewards given a sequence of prices and deltas.
  - `compute_bounds` computes the bounds of R(t) for a given set of parameters.
  - `compute_returns` computes the sequence of cumulative returns.
  - `compute_utility` computes the utility of the sequence of rewards using exponential utility.
- `quantum.py` includes functions to build quantum compound NNs:
  - `get_brick_idxs` calculates the indices of the RBS gates for the Brick architecture and returns a nested list where each inner list contains pairs of indices indicating the RBS gates to be applied in parallel.
  - `make_ortho_fn` builds the mapping from a set of parameters to an orthogonal matrix and returns it as a JAX function.
  - `compute_compound` calculates the compound matrix given the orthogonal (unitary) matrix and the order k.
  - `decompose_state` decomposes a quantum state into a sum of weighted subspaces and their respective projections.
- `agent.py` includes:
  - `make_train_nn` creates a compound NN that uses a batching tree during training.
  - `make_test_nn` creates a compound NN that can take any sequence of jumps as input.
  - `make_train` creates agent trainers for the vanilla and actor-critic algorithms.
  - `make_test` creates agent evaluators for trained quantum policies.

In [None]:
import source

### Experiment

In [None]:
# Jump sequences used by `train_experiment` to evaluate the agent at every step.
# `train_experiment` slices this as `test_seq_jumps[:, :num_days]`, so it is
# assumed to be an (at least) 2-D array with days along axis 1 -- TODO confirm
# against the contents of the pickle file.
test_seq_jumps = source.utils.load_file("./data/hardware_seq_jumps.pkl")

def _evaluate_params(test_step, params, jumps):
    """Run ``test_step`` on ``jumps`` and return its metrics as Python floats.

    The second value returned by ``test_step`` is discarded.
    """
    metrics, _ = test_step(jumps, params)
    return jax.tree_util.tree_map(float, jax.device_get(metrics))


def train_experiment(seed, train_steps, **hparams):
    """Train an agent and keep the parameters with the best test utility.

    Before every training step the current parameters are evaluated on the
    global ``test_seq_jumps`` (truncated to ``hparams["num_days"]`` days).
    After the loop, the parameters produced by the final update (or the initial
    parameters when ``train_steps == 0``) are evaluated as well, so every set
    of parameters is a candidate for ``best_params``. Training stops early as
    soon as any training metric is NaN.

    Args:
        seed: Seed for the NumPy generator that derives the JAX PRNG seed.
        train_steps: Maximum number of training steps.
        **hparams: Forwarded to ``source.agent.make_train`` and
            ``source.agent.make_test``; must contain ``num_days``.

    Returns:
        Tuple ``(best_utility, bs_utility, best_params, all_metrics, all_params)``:
        the highest test utility seen, the ``"bs_utility"`` test metric from the
        most recent evaluation, the parameters that reached ``best_utility``, a
        list of ``(step, train_metrics, test_metrics)`` per executed step (test
        metrics are measured before that step's update and include the running
        ``"best"``), and a list of ``(step, params_after_update)``.
    """
    init_step, train_step = source.agent.make_train(**hparams)
    test_step = source.agent.make_test(**hparams)
    # Take the horizon from the hyperparameters (not a notebook global) so the
    # test paths always match the agent being trained; slice once, not per step.
    test_jumps = test_seq_jumps[:, :hparams["num_days"]]

    np_random = np.random.RandomState(seed=seed)
    # dtype=np.int64: with NumPy < 2 the default integer dtype is 32-bit on
    # Windows, where an upper bound of 2**32 - 1 raises ValueError. Where the
    # default is already int64 this draws exactly the same value.
    rng_key = hk.PRNGSequence(int(np_random.randint(0, 2**32 - 1, dtype=np.int64)))
    params, opt_state = init_step(next(rng_key))

    best_utility, best_params = -np.inf, None
    all_metrics, all_params = [], []
    diverged = False

    for it in tqdm(range(train_steps)):
        # Score the current parameters *before* this step's update.
        test_metrics = _evaluate_params(test_step, params, test_jumps)
        utility = test_metrics["utility"]
        bs_utility = test_metrics["bs_utility"]
        if utility > best_utility:
            best_utility, best_params = utility, params
        test_metrics["best"] = best_utility

        params, opt_state, train_metrics = train_step(
            next(rng_key),
            params,
            opt_state,
        )
        train_metrics = jax.tree_util.tree_map(float, jax.device_get(train_metrics))
        all_metrics.append((it, train_metrics, test_metrics))
        all_params.append((it, params))

        # Abort as soon as any training metric is NaN.
        if any(np.isnan(v) for v in train_metrics.values()):
            diverged = True
            break

    # Testing happens before each update, so the latest parameters have not been
    # scored yet. Skip this after divergence: the last update produced NaNs.
    if not diverged:
        final_metrics = _evaluate_params(test_step, params, test_jumps)
        bs_utility = final_metrics["bs_utility"]
        if final_metrics["utility"] > best_utility:
            best_utility, best_params = final_metrics["utility"], params

    return best_utility, bs_utility, best_params, all_metrics, all_params

In [None]:
# Number of training steps run by `train_experiment`.
train_steps = 2000

# Environment / training hyperparameters; each is collected into `hparams`
# in a later cell. Every name is assigned exactly once so there is a single
# place to change it.
num_days = 10
num_trading_days = 30
initial_price = 100.0
mu = 0.0
sigma = 0.2
utility_lambda = 0.1
train_num_paths = 16

In [None]:
#====================================================================================================#
# Recorded experiment configurations, three seeds per (cost_eps, strike, critic_update).
# Each row: (cost_eps, strike, critic_update, seed, reported_utility)
#   - model is "vanilla" when critic_update == "vanilla", else "actor_critic"
#     (derived in the hparams cell).
#   - reported_utility is the "U =" value recorded for that run.
EXPERIMENTS = [
    # ---- cost_eps = 0.0, strike = 0.95 ----
    (0.0, 0.95, "vanilla", 1295890448, -7.215316772),
    (0.0, 0.95, "vanilla", 623440079, -7.354511261),
    (0.0, 0.95, "vanilla", 2035691052, -7.260621071),
    (0.0, 0.95, "expected", 1776499889, -7.778673172),
    (0.0, 0.95, "expected", 1189185057, -7.29262352),
    (0.0, 0.95, "expected", 2441524775, -7.527234554),
    (0.0, 0.95, "distributional", 134117753, -7.383306503),
    (0.0, 0.95, "distributional", 1498164396, -7.613495827),
    (0.0, 0.95, "distributional", 2057014442, -7.540843487),
    # ---- cost_eps = 0.0, strike = 1.0 ----
    (0.0, 1.0, "vanilla", 2980840230, -4.467153072),
    (0.0, 1.0, "vanilla", 2468570955, -4.304522514),
    (0.0, 1.0, "vanilla", 3107862872, -4.057483196),
    (0.0, 1.0, "expected", 1873257998, -4.186463356),
    (0.0, 1.0, "expected", 2400283990, -4.26928091),
    (0.0, 1.0, "expected", 267933897, -4.373553276),
    (0.0, 1.0, "distributional", 3878154799, -3.870510817),
    (0.0, 1.0, "distributional", 3215051902, -4.25199461),
    (0.0, 1.0, "distributional", 161099832, -4.075201511),
    # ---- cost_eps = 0.0, strike = 1.05 ----
    (0.0, 1.05, "vanilla", 1327560560, -3.11974144),
    (0.0, 1.05, "vanilla", 1307767281, -2.788729429),
    (0.0, 1.05, "vanilla", 3081439808, -2.866950035),
    (0.0, 1.05, "expected", 3593294047, -2.768203735),
    (0.0, 1.05, "expected", 2504117770, -2.653726578),
    (0.0, 1.05, "expected", 4244424418, -2.630569696),
    (0.0, 1.05, "distributional", 1013705573, -2.57022357),
    (0.0, 1.05, "distributional", 580205303, -2.649965525),
    (0.0, 1.05, "distributional", 4249105090, -2.713878632),
    #================================================================================================#
    # ---- cost_eps = 0.002, strike = 0.95 ----
    (0.002, 0.95, "vanilla", 2448734646, -7.586786747),
    (0.002, 0.95, "vanilla", 348376058, -7.456678391),
    (0.002, 0.95, "vanilla", 1419667174, -7.765349388),
    (0.002, 0.95, "expected", 1824290489, -8.031445503),
    (0.002, 0.95, "expected", 4272408077, -7.508294582),
    (0.002, 0.95, "expected", 1398375548, -7.78689909),
    (0.002, 0.95, "distributional", 3818513658, -7.737299919),
    (0.002, 0.95, "distributional", 3667969190, -7.654401779),
    (0.002, 0.95, "distributional", 3851442698, -7.907392502),
    # ---- cost_eps = 0.002, strike = 1.0 ----
    (0.002, 1.0, "vanilla", 4281868934, -4.73997736),
    (0.002, 1.0, "vanilla", 4030196172, -4.643164635),
    (0.002, 1.0, "vanilla", 1888259215, -4.695047379),
    (0.002, 1.0, "expected", 3784581555, -4.674203873),
    (0.002, 1.0, "expected", 398332362, -4.671273708),
    (0.002, 1.0, "expected", 555401847, -4.776445389),
    (0.002, 1.0, "distributional", 2551650142, -4.59783268),
    (0.002, 1.0, "distributional", 2513214051, -4.423602104),
    (0.002, 1.0, "distributional", 3133990731, -4.502878666),
    # ---- cost_eps = 0.002, strike = 1.05 ----
    (0.002, 1.05, "vanilla", 3007081222, -3.343116045),
    (0.002, 1.05, "vanilla", 4150151047, -3.063335657),
    (0.002, 1.05, "vanilla", 561024990, -3.23662281),
    (0.002, 1.05, "expected", 2655664254, -2.839249611),
    (0.002, 1.05, "expected", 2783584603, -2.984975338),
    (0.002, 1.05, "expected", 4269260147, -2.940672398),
    (0.002, 1.05, "distributional", 550553097, -2.790858746),
    (0.002, 1.05, "distributional", 1367318865, -3.111541033),
    (0.002, 1.05, "distributional", 2499673988, -2.870720625),
]

# Select the run to reproduce by index into EXPERIMENTS (-1 = last row).
experiment_idx = -1
cost_eps, strike, critic_update, seed, reported_utility = EXPERIMENTS[experiment_idx]

In [None]:
actor_opt, actor_lr = "radam", 1E-3

# The "vanilla" update has no critic, so its critic optimizer settings are None;
# any other critic_update selects the actor-critic model with an Adam critic.
model, critic_opt, critic_lr = (
    ("vanilla", None, None)
    if critic_update == "vanilla"
    else ("actor_critic", "adam", 1E-2)
)

# Everything train_experiment forwards to source.agent.make_train / make_test.
hparams = {
    "model": model,
    "num_days": num_days,
    "num_trading_days": num_trading_days,
    "mu": mu,
    "sigma": sigma,
    "initial_price": initial_price,
    "strike": strike,
    "cost_eps": cost_eps,
    "utility_lambda": utility_lambda,
    "train_num_paths": train_num_paths,
    "actor_opt": actor_opt,
    "actor_lr": actor_lr,
    "critic_update": critic_update,
    "critic_opt": critic_opt,
    "critic_lr": critic_lr,
}


In [None]:
# Run the selected experiment with the hyperparameters assembled above.
best_utility, bs_utility, best_params, all_metrics, all_params = train_experiment(
    seed=seed,
    train_steps=train_steps,
    **hparams,
)

In [None]:
# Compare the best test utility reached by the trained agent with the
# "bs_utility" test metric returned by train_experiment.
print(f"best utility / black-scholes utility = {best_utility:.3f}/{bs_utility:.3f}")

In [None]:
import os

# Artifact file names encode cost_eps, strike, critic_update and seed.
prefix = (
    f"{hparams['cost_eps']}_{hparams['strike']}_{hparams['critic_update']}_{seed}"
)
# Ensure the output directory exists; writing into a missing ./logs would fail.
os.makedirs("./logs", exist_ok=True)
source.utils.save_file(f"./logs/{prefix}-best_params.pkl", best_params)
source.utils.save_file(f"./logs/{prefix}-all_metrics.pkl", all_metrics)
source.utils.save_file(f"./logs/{prefix}-all_params.pkl", all_params)
