In [1]:
import sys

if "google.colab" in sys.modules:
  %pip install optax qiskit qcware qcware-quasar
  ! rm -rf deep-hedging
  ! git clone https://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@github.com/SnehalRaj/jpmc-qcware-deephedging deep-hedging
  ! cp -r deep-hedging/* .
  from google.colab import drive
  drive.mount('/content/drive')

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Cloning into 'deep-hedging'...
remote: Enumerating objects: 572, done.[K
remote: Counting objects: 100% (379/379), done.[K
remote: Compressing objects: 100% (176/176), done.[K
remote: Total 572 (delta 262), reused 258 (delta 203), pack-reused 193[K
Receiving objects: 100% (572/572), 11.88 MiB | 20.65 MiB/s, done.
Resolving deltas: 100% (370/370), done.
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import itertools
import sys
import warnings
from math import factorial
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
import optax

warnings.filterwarnings("ignore")


class Agent(NamedTuple):
    init: Callable
    train_step: Callable
    eval_step: Callable


def binomial(n, k):
    return factorial(n) // factorial(k) // factorial(n - k)


def compute_black_scholes_deltas(
    seq_prices,
    *,
    num_days=8,
    num_trading_days=252,
    mu=0.0,
    sigma=0.5,
    strike=1.0,
):
    seq_prices = seq_prices[..., None]
    seq_prices = seq_prices[:, :-1]
    strike_price = seq_prices[0, 0] * strike
    T = jnp.arange(1, num_days + 1) / num_trading_days
    T = jnp.repeat(jnp.flip(T[None, :]), seq_prices.shape[0], 0)
    d1 = jnp.divide(
        jnp.log(seq_prices[..., 0] / strike_price) + (mu + 0.5 * sigma**2) * T,
        sigma * jnp.sqrt(T),
    )
    seq_deltas = jax.scipy.stats.norm.cdf(d1, 0.0, 1.0)
    return seq_deltas


def compute_prices(
    seq_jumps,
    *,
    num_trading_days=252,
    mu=0.0,
    sigma=0.5,
    initial_price=100.0,
):
    num_jumps = 1
    bernoulli_prob = 0.5
    seq_jumps = seq_jumps - bernoulli_prob  # mean 0
    seq_jumps /= np.sqrt(bernoulli_prob * (1 - bernoulli_prob))  # std 1
    num_paths, num_days = seq_jumps.shape
    seq_jumps = seq_jumps.reshape(num_paths, num_days * num_jumps)
    brownian = jnp.cumsum(seq_jumps, axis=1)
    brownian /= np.sqrt(num_jumps * num_trading_days)
    t = jnp.arange(1, 1 + num_days) / num_trading_days
    log_prices = (mu - sigma**2 / 2) * t + sigma * brownian
    seq_prices = jnp.exp(log_prices)
    seq_prices = jnp.concatenate([jnp.ones((num_paths, 1)), seq_prices], axis=1)
    seq_prices *= initial_price
    return seq_prices


def compute_rewards(seq_prices, seq_deltas, *, strike=0.9, cost_eps=0.0):
    seq_actions = [
        seq_deltas[:, [0]],
        seq_deltas[:, 1:] - seq_deltas[:, :-1],
        -seq_deltas[:, [-1]],
    ]
    seq_actions = jnp.concatenate(seq_actions, axis=1)
    payoff = -jnp.maximum(seq_prices[:, -1] - strike * seq_prices[:, 0], 0.0)
    costs = -(jnp.abs(seq_actions) * cost_eps + seq_actions) * seq_prices
    seq_rewards = costs.at[:, -1].add(payoff)
    return seq_rewards


def compute_bounds(
    num_days=8,
    num_trading_days=252,
    mu=0.0,
    sigma=0.5,
    initial_price=100.0,
    strike=0.9,
    cost_eps=0.0,
):
    # TODO: add cost_eps
    jumps_max = jnp.ones((num_days))
    jumps_min = jnp.zeros((num_days))
    seq_jumps = jnp.stack([jumps_min, jumps_max], axis=0)
    prices_min, prices_max = compute_prices(
        seq_jumps,
        num_trading_days=num_trading_days,
        mu=mu,
        sigma=sigma,
        initial_price=initial_price,
    )
    payoffs_min = -jnp.maximum(prices_max - strike * initial_price, 0)
    values_max = (2 * (prices_max - strike * initial_price))[::-1][:-1]
    values_min = (2 * (prices_min - strike * initial_price) + payoffs_min)[::-1][:-1]
    Gt_range = jnp.stack((values_min, values_max), axis=0)
    return Gt_range


def compute_returns(seq_rewards):
    seq_returns = jnp.cumsum(seq_rewards[:, ::-1], axis=1)[:, ::-1]
    return seq_returns


def compute_utility(seq_rewards, *, utility_lambda=1.0):
    returns = seq_rewards.sum(axis=1)
    utility = (
        -1 / utility_lambda * jnp.log(jnp.mean(jnp.exp(-utility_lambda * returns)))
    )
    return utility


def get_pyramid_idxs(num_qubits):
    num_max = num_qubits
    num_min = num_qubits - 1
    if num_max == num_min:
        num_min -= 1
    end_idxs = np.concatenate(
        [np.arange(1, num_max - 1), num_max - np.arange(1, num_min + 1)]
    )
    start_idxs = np.concatenate(
        [
            np.arange(end_idxs.shape[0] + num_min - num_max) % 2,
            np.arange(num_max - num_min),
        ]
    )
    rbs_idxs = [
        np.arange(start_idxs[i], end_idxs[i] + 1).reshape(-1, 2)
        for i in range(len(start_idxs))
    ]
    return rbs_idxs


def get_butterfly_idxs(num_qubits):
    def _get_butterfly_idxs(n):
        if n == 2:
            return np.array([[[0, 1]]])
        else:
            rbs_idxs = _get_butterfly_idxs(n // 2)
            first = np.concatenate([rbs_idxs, rbs_idxs + n // 2], 1)
            last = np.arange(n).reshape(1, 2, n // 2).transpose(0, 2, 1)
            rbs_idxs = np.concatenate([first, last], 0)
            return rbs_idxs

    rbs_idxs = _get_butterfly_idxs(int(2 ** np.ceil(np.log2(num_qubits))))
    rbs_idxs = [list(map(list, rbs_idx)) for rbs_idx in rbs_idxs]
    rbs_idxs = [
        [
            [i, j]
            for i, j in rbs_idx
            if (i in range(num_qubits)) and (j in range(num_qubits))
        ]
        for rbs_idx in rbs_idxs
    ]
    return rbs_idxs[::-1]


def get_triangle_idxs(num_qubits):
    rbs_idxs = [[(i, i + 1)] for i in range(num_qubits - 1)]
    rbs_idxs += rbs_idxs[::-1]
    return rbs_idxs


def get_iks_idxs(num_qubits):
    rbs_idxs_down = [[(i, i + 1)] for i in range(num_qubits - 1)]
    rbs_idxs_up = [[(i, i + 1)] for i in range(num_qubits - 1)][::-1]
    rbs_idxs = [
        (m + n if m != n else m) for m, n in zip(rbs_idxs_down, rbs_idxs_up)
    ] + rbs_idxs_down[num_qubits - 1 :]
    return rbs_idxs


def make_ortho_fn(rbs_idxs, num_qubits):
    rbs_idxs = [list(map(list, rbs_idx)) for rbs_idx in rbs_idxs]
    len_idxs = np.cumsum([0] + list(map(len, rbs_idxs)))

    def get_rbs_unary(theta):
        cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)
        unary = jnp.array(
            [
                [cos_theta, sin_theta],
                [-sin_theta, cos_theta],
            ]
        )
        unary = unary.transpose(*[*range(2, unary.ndim), 0, 1])
        return unary

    def get_rbs_unary_grad(theta):
        cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)
        unary = jnp.array(
            [
                [-sin_theta, cos_theta],
                [-cos_theta, -sin_theta],
            ]
        )
        unary = unary.transpose(*[*range(2, unary.ndim), 0, 1])
        return unary

    @jax.custom_jvp
    def get_parallel_rbs_unary(thetas):
        unitaries = []
        for i, idxs in enumerate(rbs_idxs):
            idxs = sum(idxs, [])
            sub_thetas = thetas[len_idxs[i] : len_idxs[i + 1]]
            rbs_blocks = get_rbs_unary(sub_thetas)
            eye_block = jnp.eye(num_qubits - len(idxs), dtype=thetas.dtype)
            permutation = idxs + [i for i in range(num_qubits) if i not in idxs]
            permutation = np.argsort(permutation)
            unary = jax.scipy.linalg.block_diag(*rbs_blocks, eye_block)
            unary = unary[permutation][:, permutation]
            unitaries.append(unary)
        unitaries = jnp.stack(unitaries)
        return unitaries

    @get_parallel_rbs_unary.defjvp
    def get_parallel_rbs_unary_jvp(primals, tangents):
        (thetas,) = primals
        (thetas_dot,) = tangents
        unitaries = []
        unitaries_dot = []
        for i, idxs in enumerate(rbs_idxs):
            idxs = sum(idxs, [])
            sub_thetas = thetas[len_idxs[i] : len_idxs[i + 1]]
            sub_thetas_dot = thetas_dot[len_idxs[i] : len_idxs[i + 1]]
            rbs_blocks = get_rbs_unary(sub_thetas)
            rbs_blocks_grad = get_rbs_unary_grad(sub_thetas)
            rbs_blocks_dot = sub_thetas_dot[..., None, None] * rbs_blocks_grad
            eye_block = jnp.eye(num_qubits - len(idxs), dtype=thetas.dtype)
            zero_block = jnp.zeros_like(eye_block)
            permutation = idxs + [i for i in range(num_qubits) if i not in idxs]
            permutation = np.argsort(permutation)
            unary = jax.scipy.linalg.block_diag(*rbs_blocks, eye_block)
            unary_dot = jax.scipy.linalg.block_diag(*rbs_blocks_dot, zero_block)
            unary = unary[permutation][:, permutation]
            unary_dot = unary_dot[permutation][:, permutation]
            unitaries.append(unary)
            unitaries_dot.append(unary_dot)
        primal_out = jnp.stack(unitaries)
        tangent_out = jnp.stack(unitaries_dot)
        return primal_out, tangent_out

    def orthogonal_fn(thetas):
        unitaries = get_parallel_rbs_unary(thetas)
        unary = jnp.linalg.multi_dot(unitaries[::-1])
        return unary

    return orthogonal_fn


def compute_compound(unary, order=1):
    num_qubits = unary.shape[-1]
    if (order == 0) or (order == num_qubits):
        return jnp.ones((1, 1))
    elif order == 1:
        return unary
    else:
        subsets = list(itertools.combinations(range(num_qubits), order))
        compounds = unary[subsets, ...][..., subsets].transpose(0, 2, 1, 3)
        compound = jnp.linalg.det(compounds)
    return compound


def decompose_state(state):
    num_qubits = int(np.log2(state.shape[-1]))
    batch_dims = state.shape[:-1]
    state = state.reshape(-1, 2**num_qubits)
    idxs = list(itertools.product(*[[0, 1]] * num_qubits))
    subspace_idxs = [
        [
            (np.array(idx) * 2 ** np.arange(num_qubits)[::-1]).sum()
            for idx in idxs
            if sum(idx) == weight
        ]
        for weight in range(num_qubits + 1)
    ]
    subspace_states = [
        state[..., subspace_idxs[weight]] for weight in range(num_qubits + 1)
    ]
    alphas = [
        jnp.linalg.norm(subspace_state, axis=-1) for subspace_state in subspace_states
    ]
    betas = [
        subspace_state / (alpha[..., None] + 1e-6)
        for alpha, subspace_state in zip(alphas, subspace_states)
    ]
    alphas = [alpha.reshape(*batch_dims, -1) for alpha in alphas]
    betas = [beta.reshape(*batch_dims, -1) for beta in betas]
    alphas = jnp.stack(alphas, -1)[..., 0, :]
    return alphas, betas

def get_square_idxs(num_qubits, num_layers=None):
    if num_layers is None:
        num_layers = 1+int(np.log2(num_qubits))
    rbs_idxs = [[(i,i+1) for i in range(0,num_qubits-1,2)]]
    rbs_idxs += [[(i,i+1) for i in range(1,num_qubits-1,2)]]
    return rbs_idxs * num_layers

#get_square_idxs = get_pyramid_idxs

In [3]:
np.set_printoptions(formatter={'float': "{0:0.3f}".format})

### Hardware exp circuit construction

In [4]:
# Global counter

global_number_of_circuits_executed = 0

# Global object keeping track of result
# Used for pickling
# Populated initially in DeepHedgingBenchmark().__test_model
# and with run results in run_circuit
# keeping track of batch_idx in scan (under "Models")

global_hardware_run_results_dict = {}

In [5]:
import qiskit

import quasar
from qcware_transpile.translations.quasar.to_qiskit import translate
from qiskit.compiler import assemble
import collections

from qio import loader

import numpy as np
from qnn import _get_butterfly_idxs, _get_pyramid_idxs, _make_orthogonal_fn
# fix for older versions of Qiskit
if qiskit.__version__ <= '0.37.1':
    import qiskit.providers.aer.noise as noise
else:
    import qiskit_aer.noise as noise
import json
import pickle
import time
import copy
from pathlib import Path
from tqdm import tqdm
import itertools
from utils import save_params, load_params
import datetime
def prepare_circuit(rbs_idxs, time_step, num_qubits, seq_jumps, thetas):
    def _get_layer_circuit(params):
      _params = np.array(params).astype('float')
      circuit_layer = quasar.Circuit()
      idx_angle = 0
      for gates_per_timestep in rbs_idxs:
        for gate in gates_per_timestep:
          circuit_layer.add_gate(quasar.Gate.RBS(theta=-_params[idx_angle]), tuple(gate))
          idx_angle+=1
      return circuit_layer

    first_gates = [quasar.Circuit().H(0)]*(num_qubits-2)  + [quasar.Circuit().I(0)] +[quasar.Circuit().X(0)]
    circuit = quasar.Circuit.join_in_qubits(first_gates)
    if time_step ==0:
      layer_circuit = _get_layer_circuit(thetas[0])
      circuit = quasar.Circuit.join_in_time([circuit, layer_circuit])
    else:
      thetas = thetas.reshape(2,time_step, -1)
      for idx,jump in enumerate(seq_jumps):
        layer_circuit = _get_layer_circuit(thetas[int(jump)][idx])
        circuit = quasar.Circuit.join_in_time([circuit, layer_circuit])
    # Translate from qcware-quasar to qiskit
    qiskit_circuit = translate(circuit)
    # qiskit_circuit.save_statevector()
    qiskit_circuit = qiskit.transpile(qiskit_circuit, optimization_level=0)
    c = qiskit.ClassicalRegister(num_qubits)
    qiskit_circuit.add_register(c)
    qiskit_circuit.barrier()
    qiskit_circuit.measure(qubit=range(num_qubits),cbit=c)
    return qiskit_circuit



def counter_to_dict(c):
    """Converts counter returned by pytket get_counts function
    to dictionary returned by qiskit
    canonical use:
    >>> result = backend.get_result(handle)
    >>> counts = result.get_counts(basis=BasisOrder.dlo)
    >>> counts_qiskit = counter_to_dict(counts)
    """
    d = {}
    for k, v in c.items():
        d[''.join(str(x) for x in k)] = int(v)
    return d

def run_circuit(circs,num_qubits, backend_name = 'qiskit_noisy'):
    """
    backend name accepted 
    """
    global global_number_of_circuits_executed
    global global_hardware_run_results_dict
    results = np.zeros((len(circs), 2**num_qubits))
    
    global_number_of_circuits_executed += len(circs)
    num_measurements = 1000
    
    if "qiskit" in backend_name:
        backend = qiskit.Aer.get_backend('qasm_simulator')
        if backend_name == 'qiskit_noiseless':
            measurement = qiskit.execute(circs, backend, shots=num_measurements)
        elif backend_name == 'qiskit_noisy': 
            # Error probabilities
            prob_1 = 0.001  # 1-qubit gate
            prob_2 = 0.01   # 2-qubit gate
            # Dylan's tunes error probabilities
            # prob_1 = 0  # 1-qubit gate
            # prob_2 = 3.5e-3   # 2-qubit gate

            # Depolarizing quantum errors
            error_1 = noise.depolarizing_error(prob_1, 1)
            error_2 = noise.depolarizing_error(prob_2, 2)

            # Add errors to noise model
            noise_model = noise.NoiseModel()
            noise_model.add_all_qubit_quantum_error(error_1, ['h', 'x', 'ry'])
            noise_model.add_all_qubit_quantum_error(error_2, ['cz'])

            # Get basis gates from noise model
            basis_gates = noise_model.basis_gates
            measurement = qiskit.execute(circs, backend,basis_gates=basis_gates, noise_mode=noise_model, shots=num_measurements)
        else:
            raise ValueError(f"Unexpected backend name {backend_name}")
        all_counts = measurement.result().get_counts()
    elif "quantinuum" in backend_name:
        # From docs: "Batches cannot exceed the maximum limit of 500 H-System Quantum Credits (HQCs) total"
        # Therefore batching is more or less useless on quantinuum
        from pytket.extensions.qiskit import qiskit_to_tk
        from pytket.circuit import BasisOrder
        from pytket.extensions.quantinuum import QuantinuumBackend
    
        outpath_stem = "_".join([
            "1103_30_points",
            global_hardware_run_results_dict['model_type'],
            backend_name,
            global_hardware_run_results_dict['layer_type'],
            str(global_hardware_run_results_dict['epsilon']),
            str(global_hardware_run_results_dict['batch_idx']),
        ])
        
        outpath_result_final = f"data/{outpath_stem}.json"
        outpath_handles = f"data/handles_{outpath_stem}.pickle"
        
        if Path(outpath_result_final).exists():
            # if precomputed results already present on disk, simply load
            print(f"Using precomputed counts from {outpath_result_final}")
            all_counts = json.load(open(outpath_result_final, "r"))['all_counts']
        else:
            if backend_name == "quantinuum_H1-2E":
                backend = QuantinuumBackend(device_name="H1-2E")
            elif backend_name == "quantinuum_H1-2":
                backend = QuantinuumBackend(device_name="H1-2")
            else:
                raise ValueError(f"Unknown Quantinuum backend: {backend_name}")
            if Path(outpath_handles).exists():
                # if circuits already submitted, simply load from disk
                print(f"Using pickled handles from {outpath_handles}")
                handles = pickle.load(open(outpath_handles, "rb"))
            else:
                # otherwise, submit circuits and pickle handles
                circs_tk = [qiskit_to_tk(circ) for circ in circs]
                for idx, circ in enumerate(circs_tk):
                    circ.name = f'{outpath_stem}_{idx+1}_of_{len(circs)}'
                compiled_circuits = backend.get_compiled_circuits(circs_tk, optimisation_level=2)
                handles = backend.process_circuits(compiled_circuits, n_shots=num_measurements)
                pickle.dump(handles, open(outpath_handles, "wb"))
                print(f"Dumped handles to {outpath_handles}")
            # retrieve results from handles
            result_list = []
            
            with tqdm(total=len(handles), desc='#jobs finished') as pbar:
                for handle in handles:
                    while True:
                        status = backend.circuit_status(handle).status
                        if status.name == 'COMPLETED':
                            result = backend.get_result(handle)
                            result_list.append(copy.deepcopy(result))
                            pbar.update(1)
                            break
                        else:
                            assert status.name in ['QUEUED', 'RUNNING'] 
                        time.sleep(1)
            global_hardware_run_results_dict['result_list'] = [x.to_dict() for x in result_list]
            # convert from tket counts format to qiskit
            all_counts = [
                counter_to_dict(
                    result.get_counts(basis=BasisOrder.dlo)
                ) for result in result_list
            ]
            global_hardware_run_results_dict['all_counts'] = all_counts
            # dump result on disk
            json.dump(global_hardware_run_results_dict, open(outpath_result_final, "w"))
    else:
        raise ValueError(f"Unexpected backend name {backend_name}")
    global_hardware_run_results_dict['batch_idx'] += 1 

    # Post processing
    for j in range(len(circs)):
        if len(circs) == 1:
            measurementRes = all_counts
        else:
            measurementRes = all_counts[j]
        num_qubits = len(list(measurementRes)[0]) 
        filtered_counts = {f"{i:0{num_qubits}b}":0 for i in range(2**num_qubits)}
        num_postselected = 0
        for bitstring, count in measurementRes.items():
            ham_weight = sum([int(x) for x in bitstring])
            if ham_weight == 0 or ham_weight == num_qubits:
                continue
            filtered_counts[bitstring] = count
            num_postselected+= count
        results[j] = np.sqrt([filtered_counts[k]/num_postselected for k in sorted(filtered_counts)])
    return results

In [6]:
def make_ortho_fn(rbs_idxs, num_qubits):
    rbs_idxs = [list(map(list, rbs_idx)) for rbs_idx in rbs_idxs]
    len_idxs = np.cumsum([0] + list(map(len, rbs_idxs)))
    def get_rbs_unary(theta):
        cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)
        unary = jnp.array(
            [
                [cos_theta, sin_theta],
                [-sin_theta, cos_theta],
            ]
        )
        unary = unary.transpose(*[*range(2, unary.ndim), 0, 1])
        return unary
    def get_parallel_rbs_unary(thetas):
        unitaries = []
        for i, idxs in enumerate(rbs_idxs):
            idxs = sum(idxs, [])
            sub_thetas = thetas[len_idxs[i] : len_idxs[i + 1]]
            rbs_blocks = get_rbs_unary(sub_thetas)
            eye_block = jnp.eye(num_qubits - len(idxs), dtype=thetas.dtype)
            permutation = idxs + [i for i in range(num_qubits) if i not in idxs]
            permutation = np.argsort(permutation)
            unary = jax.scipy.linalg.block_diag(*rbs_blocks, eye_block)
            unary = unary[permutation][:, permutation]
            unitaries.append(unary)
        unitaries = jnp.stack(unitaries)
        return unitaries

    def orthogonal_fn(thetas):
        unitaries = get_parallel_rbs_unary(thetas)
        if len(unitaries) > 1:
            unary = jnp.linalg.multi_dot(unitaries[::-1])
        else:
            unary = unitaries[0]
        return unary[::-1][:,::-1]

    return orthogonal_fn


def compute_compound(unary, order=1):
    num_qubits = unary.shape[-1]
    if (order == 0) or (order == num_qubits):
        return jnp.ones((1, 1))
    elif order == 1:
        return unary
    else:
        subsets = list(itertools.combinations(range(num_qubits), order))
        compounds = unary[subsets, ...][..., subsets].transpose(0, 2, 1, 3)
        compound = jnp.linalg.det(compounds)
    return compound

def decompose_state(state):
    num_qubits = int(np.log2(state.shape[-1]))
    batch_dims = state.shape[:-1]
    state = state.reshape(-1, 2**num_qubits)
    idxs = list(itertools.product(*[[0, 1]] * num_qubits))
    subspace_idxs = [
        [
            int((2**np.array(bla)).sum())
            for bla in itertools.combinations(range(num_qubits), weight)
        ]
        for weight in range(num_qubits + 1)
    ]
    subspace_states = [
        state[..., subspace_idxs[weight]] for weight in range(num_qubits + 1)
    ]
    alphas = [
        jnp.linalg.norm(subspace_state, axis=-1) for subspace_state in subspace_states
    ]
    betas = [
        subspace_state / (alpha[..., None] + 1e-6)
        for alpha, subspace_state in zip(alphas, subspace_states)
    ]
    alphas = [alpha.reshape(*batch_dims, -1) for alpha in alphas]
    betas = [beta.reshape(*batch_dims, -1) for beta in betas]
    alphas = jnp.stack(alphas, -1)[..., 0, :]
    return alphas, betas


In [58]:
def make_agent(
    num_days=14,
    num_jumps=1,
    num_trading_days=252,
    mu=0.0,
    sigma=0.2,
    initial_price=100.0,
    strike=1.0,
    cost_eps=0.0,
    train_num_paths=32,
    eval_num_paths=32,
    utility_lambda=0.1,
    model="vanilla",
):
    bernoulli_prob = 0.5

    def net_fn_apply(params, key, batch_jumps):
        for time_step in range(num_days):
            seq_jumps = batch_jumps[:,:time_step]
            num_qubits = num_days - time_step + 2
            rbs_idxs = get_square_idxs(num_qubits)
            num_params = sum(map(len, rbs_idxs))
            if time_step == 0:
                thetas_shape = (1, num_params)
            else:
                thetas_shape = (2 * time_step, num_params)
            thetas = params[0]["actor_thetas_{}".format(time_step)]
            state = jnp.ones((2 ** (num_days - time_step),)) / np.sqrt(
                2 ** (num_days - time_step)
            )
            state = jnp.kron(state, jnp.array([0.0, 1.0, 0.0, 0.0]))
            alphas, betas = decompose_state(state)
            thetas = thetas.reshape(-1, num_params)
            unaries = jax.vmap(make_ortho_fn(rbs_idxs, num_qubits))(thetas)
            if time_step == 0:
                seq_unaries = jnp.repeat(unaries, seq_jumps.shape[0], axis=0)
            else:
                unaries = unaries.reshape(2, time_step, num_qubits, num_qubits)
                seq_unaries = jnp.einsum("bt,tij->btij", seq_jumps, unaries[1])
                seq_unaries += jnp.einsum("bt,tij->btij", 1 - seq_jumps, unaries[0])
                if time_step > 1:
                    seq_unaries = jax.vmap(jnp.linalg.multi_dot)(seq_unaries[:,::-1,:,:])
                else:
                    seq_unaries = seq_unaries[:, 0]
            compounds = [
                jax.vmap(compute_compound, in_axes=(0, None))(seq_unaries, order)
                for order in range(num_qubits + 1)
            ]
            deltas_betas = [compound @ beta for compound, beta in zip(compounds, betas)]
            deltas_ranges = [(0, 1) for _ in range(len(deltas_betas))]
            deltas_dist = [
                beta**2 @ jnp.linspace(*delta_range, beta.shape[-1])
                for beta, delta_range in zip(deltas_betas, deltas_ranges)
            ]
            deltas_exp = [alpha**2 * dist for alpha, dist in zip(alphas, deltas_dist)]
            deltas_exp = jnp.array(deltas_exp).sum(0)
            if time_step == 0:
                seq_deltas_exp = [deltas_exp]
            else:
                seq_deltas_exp.append(deltas_exp)
        return (
            seq_jumps,
            seq_deltas_exp,
        )


    def hardware_net_fn_apply(params, key, batch_jumps):
        for time_step in range(num_days):
            seq_jumps = batch_jumps[:,:time_step]
            num_qubits = num_days - time_step + 2
            rbs_idxs = get_square_idxs(num_qubits)
            num_params = sum(map(len, rbs_idxs))
            if time_step == 0:
                thetas_shape = (1, num_params)
            else:
                thetas_shape = (2 * time_step, num_params)
            thetas = params[0]["actor_thetas_{}".format(time_step)]
            state = jnp.ones((2 ** (num_days - time_step),)) / np.sqrt(
                2 ** (num_days - time_step)
            )
            state = jnp.kron(state, jnp.array([0.0, 1.0, 0.0, 0.0]))
            alphas, betas = decompose_state(state)
            # Begin Quantum-HW
            circs = []
            for jumps in seq_jumps:
                circs.append(prepare_circuit(rbs_idxs, time_step, num_qubits, jumps, thetas))
            results = jnp.array(run_circuit(circs,num_qubits))
            deltas_alphas, deltas_betas = decompose_state(results)
            # End Quantum-HW
            deltas_ranges = [(0, 1) for _ in range(len(deltas_betas))]
            deltas_dist = [
                beta**2 @ jnp.linspace(*delta_range, beta.shape[-1])
                for beta, delta_range in zip(deltas_betas, deltas_ranges)
            ]
            deltas_exp = [alpha**2 * dist for alpha, dist in zip(alphas, deltas_dist)]
            deltas_exp = jnp.array(deltas_exp).sum(0)
            if time_step == 0:
                seq_deltas_exp = [deltas_exp]
            else:
                seq_deltas_exp.append(deltas_exp)
        
        return (
            seq_jumps,
            seq_deltas_exp,
        )
    
    def eval_step(params, batch_jumps):
        key = jax.random.PRNGKey(123)
        keys = jax.random.split(key, 2)
        keys = jax.random.split(key, 4)
        net_params = params
        
        seq_jumps, seq_deltas_exp = net_fn_apply(net_params, keys[0], batch_jumps)
        seq_jumps, seq_deltas_exp_hw = hardware_net_fn_apply(net_params, keys[0], batch_jumps)

        day_jumps = jax.random.bernoulli(
            keys[1], bernoulli_prob, (seq_jumps.shape[0], 1)
        )
        seq_jumps = jnp.concatenate([seq_jumps, day_jumps], axis=-1)
        seq_prices = compute_prices(
            seq_jumps,
            num_trading_days=num_trading_days,
            mu=mu,
            sigma=sigma,
            initial_price=initial_price,
        )
        seq_deltas_hw = jnp.stack(seq_deltas_exp_hw,axis=1)
        seq_deltas = jnp.stack(seq_deltas_exp, axis=1)
        seq_rewards = compute_rewards(
            seq_prices, seq_deltas, strike=strike, cost_eps=cost_eps
        )
        seq_bs_deltas = compute_black_scholes_deltas(
            seq_prices,
            num_days=num_days,
            num_trading_days=num_trading_days,
            mu=mu,
            sigma=sigma,
            strike=strike,
        )
        seq_rewards = compute_rewards(
            seq_prices, seq_deltas, strike=strike, cost_eps=cost_eps
        )
        seq_hw_rewards = compute_rewards(
            seq_prices, seq_deltas_hw, strike=strike, cost_eps=cost_eps
        ) 
        seq_bs_rewards = compute_rewards(
            seq_prices, seq_bs_deltas, strike=strike, cost_eps=cost_eps
        )
        returns = seq_rewards.sum(axis=1).mean()
        hw_returns = seq_hw_rewards.sum(axis=1).mean()
        bs_returns = seq_bs_rewards.sum(axis=1).mean()
        metrics = {
            "returns": returns,
            "hw_returns": hw_returns,
            "bs_returns": bs_returns,
            "seq_deltas": seq_deltas_exp[0][0],
            "seq_deltas_hw": seq_deltas_exp_hw[0][0],
        }
        utility_lambda = 1E-1
        utility = compute_utility(seq_rewards, utility_lambda=utility_lambda)
        hw_utility = compute_utility(seq_hw_rewards, utility_lambda=utility_lambda)
        bs_utility = compute_utility(seq_bs_rewards, utility_lambda=utility_lambda)
        metrics[f'U_{utility_lambda}'] = utility
        metrics[f'U_hw_{utility_lambda}'] = hw_utility
        metrics[f'U_bs_{utility_lambda}'] = bs_utility
        return metrics

    return Agent(init=None, train_step=None, eval_step=eval_step)


def experiment(hparams, seed, params_save_loc, jumps_save_loc):
    global global_number_of_circuits_executed
    global global_hardware_run_results_dict
    global_number_of_circuits_executed = 0
    global_hardware_run_results_dict = {
        'model_type' : hparams["model"],
        'measurementRes' : None,
        'epsilon' : hparams["cost_eps"],
        'backend_name' : None,
        'num_trading_days' : hparams["num_trading_days"],
        'batch_idx' : 0,
    }
    agent = make_agent(**hparams)
    params = load_params(params_save_loc)
    batch_jumps = load_params(jumps_save_loc)
    eval_metrics = agent.eval_step(params, batch_jumps)
    eval_metrics = jax.device_get(eval_metrics)
    eval_metrics = jax.tree_map(float, eval_metrics)
    print(f'Total number of circuits executed = {global_number_of_circuits_executed}')

    utility_lambda = 1E-1
    a = eval_metrics[f'U_{utility_lambda}']
    b = eval_metrics[f'U_bs_{utility_lambda}']
    c = eval_metrics[f'U_hw_{utility_lambda}']
    print("lambda {}, bs {:.2f}, agent {:.2f}, hw agent {:,.2f}".format(utility_lambda, b,a,c ))
            

In [59]:
num_days = 10
env_kwargs = dict(
    num_days=num_days,
    num_jumps=1,
    num_trading_days=30,
    mu=0.0,
    sigma=0.2,
    initial_price=100.0,
    strike=1.,
    cost_eps=0.,
    utility_lambda=0.1,
)

hparams = dict(env_kwargs)
hparams["model"] = "distributional"
experiment(hparams, seed=19983, params_save_loc='./10-0.0-1.0_distributional_20221114-141240-3.pkl', jumps_save_loc= './seq_jumps_10_days')

Total number of circuits executed = 160
lambda 0.1, bs -4.42, agent -4.13, hw agent -4.14
