In [1]:
import warnings
import pickle

import jax
import jax.numpy as jnp
import numpy as np
import source
from source.hardware_utils import prepare_circuit_compound, run_circuit_compound

# Suppress every warning for the rest of the session so the cell outputs stay
# readable. NOTE(review): this also hides deprecation/runtime warnings raised
# by the libraries used below.
warnings.filterwarnings("ignore")
# Print NumPy floats with three decimal places (applies to printed arrays).
np.set_printoptions(formatter={"float": "{0:0.3f}".format})

In [2]:
def make_agent(
    num_days=14,
    num_trading_days=252,
    mu=0.0,
    sigma=0.2,
    initial_price=100.0,
    strike=1.0,
    cost_eps=0.0,
    train_num_paths=32,
    eval_num_paths=32,
    utility_lambda=0.1,
    model="vanilla",
):
    """Build the evaluation step comparing the simulated and hardware agents.

    Args:
        num_days: Number of time steps; one delta is produced per step.
        num_trading_days: Forwarded to ``source.env.compute_prices`` and
            ``source.env.compute_black_scholes_deltas``.
        mu: Forwarded to the same ``source.env`` helpers.
        sigma: Forwarded to the same ``source.env`` helpers.
        initial_price: Forwarded to ``source.env.compute_prices``.
        strike: Forwarded to the reward and Black-Scholes helpers.
        cost_eps: Forwarded to ``source.env.compute_rewards``.
        train_num_paths: Not used by this function; accepted so that a shared
            hyper-parameter dict can be splatted into the call.
        eval_num_paths: Not used by this function (see ``train_num_paths``).
        utility_lambda: Forwarded to ``source.env.compute_utility`` and used
            to build the ``U_*`` metric keys.
        model: Not used by this function (see ``train_num_paths``).

    Returns:
        ``eval_step(params, batch_jumps, device_id, global_info, backend_name)``
        which returns ``(metrics, global_info)``.
    """

    def _brick_layout(time_step):
        """Return ``(num_qubits, rbs_idxs)`` for the circuit at ``time_step``."""
        num_qubits = num_days - time_step + 2
        log_qubits = int(np.log2(num_qubits))
        depth = 2 * max(1, time_step) * log_qubits
        # Halve the number of layers (but keep at least one) once the depth
        # estimate exceeds 10.
        if depth <= 10:
            num_layers = log_qubits
        else:
            num_layers = max(log_qubits // 2, 1)
        rbs_idxs = source.quantum.get_brick_idxs(num_qubits, num_layers=num_layers)
        return num_qubits, rbs_idxs

    def _decompose_initial_state(time_step):
        """Decompose the input state used at ``time_step``.

        The state is a uniform normalised vector of length
        ``2 ** (num_days - time_step)`` Kronecker-multiplied with
        ``[0, 1, 0, 0]``.
        """
        dim = 2 ** (num_days - time_step)
        state = jnp.ones((dim,)) / np.sqrt(dim)
        state = jnp.kron(state, jnp.array([0.0, 1.0, 0.0, 0.0]))
        return source.quantum.decompose_state(state)

    def _expected_deltas(alphas, deltas_betas):
        """Combine per-component amplitudes into one expected delta per path.

        Each squared ``beta`` weights an evenly spaced grid on [0, 1]; the
        results are then weighted by the squared ``alpha`` of the component
        and summed over components.
        """
        deltas_dist = [
            beta**2 @ jnp.linspace(0, 1, beta.shape[-1]) for beta in deltas_betas
        ]
        deltas_exp = [alpha**2 * dist for alpha, dist in zip(alphas, deltas_dist)]
        return jnp.array(deltas_exp).sum(0)

    def net_fn_apply(params, batch_jumps):
        """Compute the per-step expected deltas with the classical simulation.

        Returns ``(seq_jumps, seq_deltas_exp)`` where ``seq_jumps`` is the
        slice of ``batch_jumps`` used at the last time step and
        ``seq_deltas_exp`` is a list with one entry per time step.
        """
        seq_deltas_exp = []
        for time_step in range(num_days):
            # Only the jumps strictly before this step condition the action.
            seq_jumps = batch_jumps[:, :time_step]
            num_qubits, rbs_idxs = _brick_layout(time_step)
            num_params = sum(map(len, rbs_idxs))
            thetas = params[0]["actor_thetas_{}".format(time_step)]
            alphas, betas = _decompose_initial_state(time_step)
            thetas = thetas.reshape(-1, num_params)
            unaries = jax.vmap(source.quantum.make_ortho_fn(rbs_idxs, num_qubits))(
                thetas
            )
            if time_step == 0:
                # No history yet: replicate the matrices along the batch axis.
                seq_unaries = jnp.repeat(unaries, seq_jumps.shape[0], axis=0)
            else:
                # Index 1 holds the matrices applied when the jump is 1 and
                # index 0 those applied when it is 0.
                unaries = unaries.reshape(2, time_step, num_qubits, num_qubits)
                seq_unaries = jnp.einsum("bt,tij->btij", seq_jumps, unaries[1])
                seq_unaries += jnp.einsum("bt,tij->btij", 1 - seq_jumps, unaries[0])
                if time_step > 1:
                    # Multiply the per-step matrices in reverse time order.
                    seq_unaries = jax.vmap(jnp.linalg.multi_dot)(
                        seq_unaries[:, ::-1, :, :]
                    )
                else:
                    seq_unaries = seq_unaries[:, 0]
            compounds = [
                jax.vmap(source.quantum.compute_compound, in_axes=(0, None))(
                    seq_unaries, order
                )
                for order in range(num_qubits + 1)
            ]
            deltas_betas = [compound @ beta for compound, beta in zip(compounds, betas)]
            seq_deltas_exp.append(_expected_deltas(alphas, deltas_betas))
        return (
            seq_jumps,
            seq_deltas_exp,
        )

    def hardware_net_fn_apply(params, batch_jumps, device_id, global_info, backend_name):
        """Compute the per-step expected deltas from circuit execution results.

        Returns ``((seq_jumps, seq_deltas_exp), global_info)`` where
        ``global_info`` is the value handed back by ``run_circuit_compound``.
        """
        seq_deltas_exp = []
        for time_step in range(num_days):
            seq_jumps = batch_jumps[:, :time_step]
            num_qubits, rbs_idxs = _brick_layout(time_step)
            thetas = params[0]["actor_thetas_{}".format(time_step)]
            alphas, _ = _decompose_initial_state(time_step)
            # Begin Quantum-HW
            # One circuit per path in the batch.
            circs = [
                prepare_circuit_compound(rbs_idxs, time_step, num_qubits, jumps, thetas)
                for jumps in seq_jumps
            ]
            results, global_info = run_circuit_compound(
                circs, num_qubits, device_id, global_info, backend_name
            )
            _, deltas_betas = source.quantum.decompose_state(jnp.array(results))
            # End Quantum-HW
            seq_deltas_exp.append(_expected_deltas(alphas, deltas_betas))

        return (
            seq_jumps,
            seq_deltas_exp,
        ), global_info

    def eval_step(params, batch_jumps, device_id, global_info, backend_name):
        """Evaluate simulated, hardware and Black-Scholes hedging on a batch.

        Returns ``(metrics, global_info)``; ``metrics`` holds the returns and
        deltas of both agents, the three utilities and the price paths.
        """
        key = jax.random.PRNGKey(123)
        # Keep the 4-way split: only keys[1] is consumed, but changing the
        # split would change the sampled final-day jumps.
        keys = jax.random.split(key, 4)
        net_params = params

        seq_jumps, seq_deltas_exp = net_fn_apply(net_params, batch_jumps)
        (seq_jumps, seq_deltas_exp_hw), global_info = hardware_net_fn_apply(
            net_params, batch_jumps, device_id, global_info, backend_name
        )

        # Only needed for reproducibility: backend names ending in "E" draw
        # the final-day jump with p=0.5, every other backend with p=1.0.
        if backend_name[-1] == "E":
            final_day_prob = 0.5
        else:
            final_day_prob = 1.0
        day_jumps = jax.random.bernoulli(
            keys[1], final_day_prob, (seq_jumps.shape[0], 1)
        )
        seq_jumps = jnp.concatenate([seq_jumps, day_jumps], axis=-1)
        seq_prices = source.env.compute_prices(
            seq_jumps,
            num_trading_days=num_trading_days,
            mu=mu,
            sigma=sigma,
            initial_price=initial_price,
        )
        seq_deltas_hw = jnp.stack(seq_deltas_exp_hw, axis=1)
        seq_deltas = jnp.stack(seq_deltas_exp, axis=1)
        seq_bs_deltas = source.env.compute_black_scholes_deltas(
            seq_prices,
            num_days=num_days,
            num_trading_days=num_trading_days,
            mu=mu,
            sigma=sigma,
            strike=strike,
        )
        seq_rewards = source.env.compute_rewards(
            seq_prices, seq_deltas, strike=strike, cost_eps=cost_eps
        )
        seq_hw_rewards = source.env.compute_rewards(
            seq_prices, seq_deltas_hw, strike=strike, cost_eps=cost_eps
        )
        seq_bs_rewards = source.env.compute_rewards(
            seq_prices, seq_bs_deltas, strike=strike, cost_eps=cost_eps
        )
        returns = seq_rewards.sum(axis=1)
        hw_returns = seq_hw_rewards.sum(axis=1)
        metrics = {
            "returns": jnp.array(returns),
            "hw_returns": jnp.array(hw_returns),
            "seq_deltas": jnp.array(seq_deltas_exp),
            "seq_deltas_hw": jnp.array(seq_deltas_exp_hw),
        }
        utility = source.env.compute_utility(seq_rewards, utility_lambda=utility_lambda)
        hw_utility = source.env.compute_utility(
            seq_hw_rewards, utility_lambda=utility_lambda
        )
        bs_utility = source.env.compute_utility(
            seq_bs_rewards, utility_lambda=utility_lambda
        )
        metrics[f"U_{utility_lambda}"] = utility
        metrics[f"U_hw_{utility_lambda}"] = hw_utility
        metrics[f"U_bs_{utility_lambda}"] = bs_utility
        metrics["seq_prices"] = seq_prices
        return metrics, global_info

    return eval_step


def experiment(hparams, seed, params_save_loc, jumps_save_loc, device_id, backend_name):
    """Evaluate one saved agent on a saved batch of jumps and print the results.

    Args:
        hparams: Keyword arguments for ``make_agent``; must contain ``"model"``,
            ``"cost_eps"`` and ``"num_trading_days"``.
        seed: Not used by this function; kept for interface compatibility.
        params_save_loc: Path of the pickled parameters.
        jumps_save_loc: Path of the pickled batch of jumps.
        device_id: Forwarded to ``eval_step``.
        backend_name: Forwarded to ``eval_step``. Names ending in ``"E"``
            additionally print utilities and terminal PnL.

    Returns:
        Tuple ``(seq_deltas, seq_deltas_hw)`` taken from the evaluation metrics.
    """
    global_number_of_circuits_executed = 0
    global_hardware_run_results_dict = {
        "model_type": hparams["model"],
        "measurementRes": None,
        "epsilon": hparams["cost_eps"],
        "layer_type": "actor-critic",
        "backend_name": None,
        "num_trading_days": hparams["num_trading_days"],
        "batch_idx": 0,
    }
    global_info = (global_number_of_circuits_executed, global_hardware_run_results_dict)
    eval_step = make_agent(**hparams)

    # For the "vanilla" model only the entry under the "~" key is used.
    is_vanilla = hparams["model"] == "vanilla"
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(params_save_loc, "rb") as params_file:
        loaded_params = pickle.load(params_file)
    params = (loaded_params["~"],) if is_vanilla else loaded_params
    with open(jumps_save_loc, "rb") as jumps_file:
        batch_jumps = np.array(pickle.load(jumps_file))

    eval_metrics, global_info = eval_step(params, batch_jumps, device_id, global_info, backend_name)
    global_number_of_circuits_executed, global_hardware_run_results_dict = global_info
    eval_metrics = jax.device_get(eval_metrics)
    print(f"Total number of circuits executed = {global_number_of_circuits_executed}")

    # Use the same lambda that make_agent used to build the metric keys, so a
    # non-default value in hparams does not end in a KeyError.
    utility_lambda = hparams.get("utility_lambda", 0.1)
    utility_agent = eval_metrics[f"U_{utility_lambda}"]
    utility_hw_agent = eval_metrics[f"U_hw_{utility_lambda}"]

    # For Hardware we only print deltas and compute utility and PnL later
    is_emulator = backend_name[-1] == "E"
    if is_emulator:
        print("---" * 10 + "Utility" + "---" * 10)
        print(
            "Agent {:.2f}, Hardware Agent {:,.2f}".format(
                utility_agent, utility_hw_agent
            )
        )
    print("---" * 10 + "Deltas" + "---" * 10)
    print(f'Agent :\n {eval_metrics["seq_deltas"]}')
    print(f'Hardware Agent :\n {eval_metrics["seq_deltas_hw"]}')
    if is_emulator:
        print("---" * 10 + "Terminal PnL" + "---" * 10)
        print(f'Agent :\n {eval_metrics["returns"]}')
        print(f'Hardware Agent :\n {eval_metrics["hw_returns"]}')
    return eval_metrics["seq_deltas"], eval_metrics["seq_deltas_hw"]

# Hardware Emulator Backend

In [3]:
# Environment shared by all emulator runs.
num_days = 10
env_kwargs = {
    "num_days": num_days,
    "num_trading_days": 30,
    "mu": 0.0,
    "sigma": 0.2,
    "initial_price": 100.0,
    "strike": 1.0,
    "cost_eps": 0.002,
    "utility_lambda": 0.1,
}

hparams = dict(env_kwargs)

# Saved parameters, keyed by model type.
params_dict = {
    "distributional": 'params/20221116-121451_10-0.002-1.0_distributional.pkl',
    "expected": 'params/20221117-131056_10-0.002-1.0_expected.pkl',
    "vanilla": 'params/20221117-123435_10-0.002-1.0_vanilla.pkl',
}

# Evaluate every model on the same jumps with the emulator backend.
for key, params_save_loc in params_dict.items():
    hparams["model"] = key
    experiment(
        hparams,
        seed=19983,
        params_save_loc=params_save_loc,
        jumps_save_loc='data/seq_jumps_10_days',
        device_id="1128_part_2",
        backend_name='quantinuum_H1-1E',
    )

Using precomputed counts from data/1128_part_2_distributional_quantinuum_H1-1E_actor-critic_0.002_0.json
Using precomputed counts from data/1128_part_2_distributional_quantinuum_H1-1E_actor-critic_0.002_1.json
Using precomputed counts from data/1128_part_2_distributional_quantinuum_H1-1E_actor-critic_0.002_2.json
Using precomputed counts from data/1128_part_2_distributional_quantinuum_H1-1E_actor-critic_0.002_3.json
Using precomputed counts from data/1128_part_2_distributional_quantinuum_H1-1E_actor-critic_0.002_4.json
Using precomputed counts from data/1128_part_2_distributional_quantinuum_H1-1E_actor-critic_0.002_5.json
Using precomputed counts from data/1128_part_2_distributional_quantinuum_H1-1E_actor-critic_0.002_6.json
Using precomputed counts from data/1128_part_2_distributional_quantinuum_H1-1E_actor-critic_0.002_7.json
Using precomputed counts from data/1128_part_2_distributional_quantinuum_H1-1E_actor-critic_0.002_8.json
Using precomputed counts from data/1128_part_2_distribu

# Hardware Backend

In [4]:
# Environment shared by all hardware runs.
num_days = 10
env_kwargs = {
    "num_days": num_days,
    "num_trading_days": 30,
    "mu": 0.0,
    "sigma": 0.2,
    "initial_price": 100.0,
    "strike": 1.0,
    "cost_eps": 0.002,
    "utility_lambda": 0.1,
}

hparams = dict(env_kwargs)

# Experiments 1 and 2 evaluate the same distributional parameters on two
# different batches of jumps.
distributional_params_loc = "params/20221116-121451_10-0.002-1.0_distributional.pkl"

# Experiment 1
hparams["model"] = "distributional"
deltas_1 = experiment(
    hparams,
    19983,
    distributional_params_loc,
    "data/seq_jumps_10_days_hardware_exp_1",
    "1118_device",
    "quantinuum_H1-1",
)

# Experiment 2
deltas_2 = experiment(
    hparams,
    19983,
    distributional_params_loc,
    "data/seq_jumps_10_days_hardware_exp_2",
    "1122_part2_2_device",
    "quantinuum_H1-1",
)

# Experiment 3
hparams["model"] = "expected"
deltas_3 = experiment(
    hparams,
    19983,
    "params/20221117-131056_10-0.002-1.0_expected.pkl",
    "data/seq_jumps_10_days_hardware_exp",
    "1207_part_3_device",
    "quantinuum_H1-2",
)


Using precomputed counts from data/1118_device_distributional_quantinuum_H1-1_actor-critic_0.002_0.json
Using precomputed counts from data/1118_device_distributional_quantinuum_H1-1_actor-critic_0.002_1.json
Using precomputed counts from data/1118_device_distributional_quantinuum_H1-1_actor-critic_0.002_2.json
Using precomputed counts from data/1118_device_distributional_quantinuum_H1-1_actor-critic_0.002_3.json
Using precomputed counts from data/1118_device_distributional_quantinuum_H1-1_actor-critic_0.002_4.json
Using precomputed counts from data/1118_device_distributional_quantinuum_H1-1_actor-critic_0.002_5.json
Using precomputed counts from data/1118_device_distributional_quantinuum_H1-1_actor-critic_0.002_6.json
Using precomputed counts from data/1118_device_distributional_quantinuum_H1-1_actor-critic_0.002_7.json
Using precomputed counts from data/1118_device_distributional_quantinuum_H1-1_actor-critic_0.002_8.json
Using precomputed counts from data/1118_device_distributional_qu

In [5]:
# Each experiment returned (simulated deltas, hardware deltas).
classical_deltas_exp, hardware_deltas_exp = deltas_3[0], deltas_3[1]

# The distributional agent was run in two parts; join experiment 2 and
# experiment 1 (in that order) along axis 1.
classical_deltas_dist = jnp.concatenate([deltas_2[0], deltas_1[0]], axis=1)
hardware_deltas_dist = jnp.concatenate([deltas_2[1], deltas_1[1]], axis=1)


#### Use actions to compute utility

In [7]:
def _load_pickle(path):
    """Load a pickled object and close the file afterwards.

    NOTE: pickle.load can execute arbitrary code; only load trusted files.
    """
    with open(path, "rb") as pickle_file:
        return pickle.load(pickle_file)


def _rewards_pnl_utility(seq_prices, deltas):
    """Return ``(seq_rewards, pnl, utility)`` for ``deltas`` on ``seq_prices``.

    ``deltas`` is stacked along axis 1 before being handed to
    ``source.env.compute_rewards``; the PnL is the reward summed over axis 1.
    """
    seq_rewards = source.env.compute_rewards(
        seq_prices, jnp.stack(deltas, axis=1), strike=1.0, cost_eps=0.002
    )
    pnl = seq_rewards.sum(axis=1)
    utility = source.env.compute_utility(seq_rewards, utility_lambda=0.1)
    return seq_rewards, pnl, utility


# Load seq of jumps for 10 days for hardware experiments
batch_jumps = jnp.array(_load_pickle("data/seq_jumps_10_days_hardware_exp"))

# Pick a random sample of jumps on the final day. As we already know, the
# price at the final day has no consequence on the actions. We just pick
# a random sample of jumps on the final day and save it here for reproducibility.
day_jumps = _load_pickle("data/final_day_jumps_10_days_hardware_exp")


seq_jumps = jnp.concatenate([batch_jumps, day_jumps], axis=-1)

seq_prices = source.env.compute_prices(
    seq_jumps,
    num_trading_days=30,
    mu=0.0,
    sigma=0.2,
    initial_price=100.0,
)

bs_deltas = source.env.compute_black_scholes_deltas(
    seq_prices,
    num_days=10,
    num_trading_days=30,
    mu=0.0,
    sigma=0.2,
    strike=1.0,
)

# Black Scholes PnL and Utility
seq_rewards_bs, pnl_bs, utility_bs = _rewards_pnl_utility(seq_prices, bs_deltas.T)

# Expected Actor-Critic PnL and Utility
seq_rewards_exp_sim, pnl_exp_sim, utility_exp_sim = _rewards_pnl_utility(
    seq_prices, classical_deltas_exp
)
seq_rewards_exp_hw, pnl_exp_hw, utility_exp_hw = _rewards_pnl_utility(
    seq_prices, hardware_deltas_exp
)

# Distributional Actor-Critic PnL and Utility
seq_rewards_dist_sim, pnl_dist_sim, utility_dist_sim = _rewards_pnl_utility(
    seq_prices, classical_deltas_dist
)
seq_rewards_dist_hw, pnl_dist_hw, utility_dist_hw = _rewards_pnl_utility(
    seq_prices, hardware_deltas_dist
)


print(f"Black-Scholes PnL = {pnl_bs}")
print(f"Black-Scholes Utility = {utility_bs}")

print(f"Expected Actor-Critic Simulator PnL = {pnl_exp_sim}")
print(f"Expected Actor-Critic Hardware PnL = {pnl_exp_hw}")

print(f"Expected Actor-Critic Simulator Utility = {utility_exp_sim}")
print(f"Expected Actor-Critic Hardware Utility = {utility_exp_hw}")

print(f"Distributional Actor-Critic Simulator PnL = {pnl_dist_sim}")
print(f"Distributional Actor-Critic Hardware PnL = {pnl_dist_hw}")

print(f"Distributional Actor-Critic Simulator Utility = {utility_dist_sim}")
print(f"Distributional Actor-Critic Hardware Utility = {utility_dist_hw}")


Black-Scholes PnL = [-4.602 -5.373 -4.614 -4.263 -5.173 -5.030 -5.017 -4.962]
Black-Scholes Utility = -4.884951114654541
Expected Actor-Critic Simulator PnL = [0.078 -6.204 -0.203 0.967 -6.768 -3.071 -2.984 -6.689]
Expected Actor-Critic Hardware PnL = [0.213 -6.666 -0.556 1.067 -6.895 -2.315 -2.569 -6.556]
Expected Actor-Critic Simulator Utility = -3.547945976257324
Expected Actor-Critic Hardware Utility = -3.501993179321289
Distributional Actor-Critic Simulator PnL = [-1.807 -8.313 -3.803 1.464 -2.736 -1.934 -2.669 -3.944]
Distributional Actor-Critic Hardware PnL = [-1.802 -8.214 -3.648 1.367 -2.993 -2.047 -2.803 -4.200]
Distributional Actor-Critic Simulator Utility = -3.309537410736084
Distributional Actor-Critic Hardware Utility = -3.369210958480835
