# Imports

In [1]:

import sys

if "google.colab" in sys.modules:
  %pip install QuantLib
  %pip install optax
  %pip install qiskit
  %pip install qcware

  %pip install qcware-quasar
  ! rm -rf deep-hedging
  ! git clone https://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@github.com/SnehalRaj/jpmc-qcware-deephedging deep-hedging
  ! cp -r deep-hedging/* .

In [2]:
import qiskit

import quasar
from qcware_transpile.translations.quasar.to_qiskit import translate
from qiskit.compiler import assemble
import collections

from qio import loader

# qnn


In [3]:
from typing import (Callable, List, Mapping, NamedTuple, Optional, Sequence,
                    Tuple, Union)

import jax
import numpy as np
from jax import lax
from jax import numpy as jnp
import itertools
# Typing
# -----------------------------------------------------------------------------

Array = jnp.ndarray
Shape = Sequence[int]
Dtype = Union[jnp.float32, jnp.float64]
PRNGKey = Array
Params = Mapping[str, Mapping[str, jnp.ndarray]]
State = Mapping[str, Mapping[str, jnp.ndarray]]
InitializerFn = Callable[[PRNGKey, Shape, Dtype], Array]
Initializer = Callable[..., InitializerFn]
Module = Callable[..., InitializerFn]


class ModuleFn(NamedTuple):
    apply: Callable[..., Tuple[Array, State]]
    init: Optional[Callable[..., Tuple[Params, State, Array]]] = None


def add_scope_to_params(scope, params):
    return dict((f"{scope}/{key}", array) for key, array in params.items())


def get_params_by_scope(scope, params):
    return dict((key[len(scope) + 1:], array) for key, array in params.items()
                if key.startswith(scope + '/'))


# Initializers
# -----------------------------------------------------------------------------


def constant(val: float, ) -> InitializerFn:
    """ Initialize with a constant value. 

    Args:
        val: The value to initialize with.
    """
    def init_fn(key, shape, dtype=jnp.float32):
        return jnp.broadcast_to(val, shape).astype(dtype)

    return init_fn


def zeros() -> InitializerFn:
    """ Initialize with zeros."""
    return constant(0.)


def ones() -> InitializerFn:
    """ Initialize with ones."""
    return constant(1.)


def uniform(
    minval: float = 0.,
    maxval: float = 1.,
) -> InitializerFn:
    """ Initialize with a uniform distribution.

    Args:
        minval: The minimum value of the uniform distribution. 
        maxval: The maximum value of the uniform distribution.
    """
    def init_fn(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, minval, maxval)

    return init_fn


def normal(
    mean: float = 0.,
    std: float = 1.,
) -> InitializerFn:
    """ Initialize with a normal distribution.

    Args:
        mean: The mean of the normal distribution.
        std: The standard deviation of the normal distribution.
    """
    def init_fn(key, shape, dtype=jnp.float32):
        _mean = lax.convert_element_type(mean, dtype)
        _std = lax.convert_element_type(std, dtype)
        return _mean + _std * jax.random.normal(key, shape, dtype)

    return init_fn


def truncated_normal(
    mean: float = 0.,
    std: float = 1.,
) -> InitializerFn:
    """ Initialize with a truncated normal distribution.

    Args:
        mean: The mean of the truncated normal distribution.
        std: The standard deviation of the truncated normal distribution.
    """
    def init_fn(key, shape, dtype=jnp.float32):
        _mean = lax.convert_element_type(mean, dtype)
        _std = lax.convert_element_type(std, dtype)
        return _mean + _std * jax.random.truncated_normal(
            key, -2., 2., shape, dtype)

    return init_fn


# Modules
# -----------------------------------------------------------------------------


def quax_wrapper(layer_fn):
    """ Create a module from a quax layer. """
    def module(*args, **kwargs):
        init_fn, apply_fn = layer_fn(*args, **kwargs)

        def _apply_fn(params, state, key, inputs, **kwargs):
            outputs = apply_fn(params, inputs, **kwargs)
            return outputs, state

        def _init_fn(key, inputs_shape):
            shape, params = init_fn(key, inputs_shape)
            state = None
            return params, state, shape

        return ModuleFn(_apply_fn, init=_init_fn)

    return module


def haiku_wrapper(layer_fn):
    """ Create a module from a Haiku layer. """
    def module(*args, **kwargs):
        import haiku as hk
        layer = hk.transform_with_state(layer_fn(*args, **kwargs))

        def _apply_fn(params, state, key, inputs, **kwargs):
            outputs, state = layer.apply(params, state, key, inputs, **kwargs)
            return outputs, state

        def _init_fn(key, inputs_shape):
            params, state = layer.init(key, inputs_shape)
            outputs, _ = layer.apply(params, state, key, inputs_shape,
                                     **kwargs)
            shape = outputs.shape
            return params, state, shape

        return ModuleFn(_apply_fn, init=_init_fn)

    return module


def elementwise(elementwise_fn: Callable[[Array], Array], ) -> ModuleFn:
    """ Create an elementwise layer from a JAX function. 

        Args:
            elementwise_fn: The JAX function to apply to each element.
    """
    return ModuleFn(apply=elementwise_fn)


def linear(
    n_features: int,
    with_bias: bool = True,
    w_init: Optional[InitializerFn] = None,
    b_init: Optional[InitializerFn] = None,
) -> ModuleFn:
    """ Create a linear layer.

    Args:
        n_features: The number of features in the output.
        with_bias: Whether to include a bias term.
        w_init: The initializer for the weights.
        b_init: The initializer for the bias.
    """
    def apply_fn(params, state, key, inputs, **kwargs):
        outputs = jnp.dot(inputs, params['w'])

        if with_bias:
            outputs += params['b']
        return outputs, None

    def init_fn(key, inputs_shape):
        params, state = {}, None
        key, w_key, b_key = jax.random.split(key, 3)
        w_init_ = w_init or truncated_normal(std=1. / inputs_shape[-1])
        w_shape = (inputs_shape[-1], n_features)
        params['w'] = w_init_(w_key, w_shape)
        if with_bias:
            b_init_ = b_init or zeros()
            b_shape = (n_features, )
            params['b'] = b_init_(b_key, b_shape)
        shape = inputs_shape[:-1] + (n_features, )
        return params, state, shape

    return ModuleFn(apply_fn, init=init_fn)


def layer_norm(
    with_scale: bool = True,
    with_bias: bool = True,
    s_init: Optional[InitializerFn] = None,
    b_init: Optional[InitializerFn] = None,
) -> ModuleFn:
    """ Create a normalization layer. 

    Args:
        with_scale: Whether to use a scale parameter.
        with_bias: Whether to include a bias term.
        s_init: The initializer for the scale.
        b_init: The initializer for the bias.
    """
    def init_fn(key, inputs_shape):
        params = {}
        state = None
        s_key, b_key = jax.random.split(key)
        n_features = inputs_shape[-1]
        if with_scale:
            s_init_ = s_init or ones()
            s_shape = (n_features, )
            params['s'] = s_init_(s_key, s_shape)
        if with_bias:
            b_init_ = b_init or zeros()
            b_shape = (n_features, )
            params['b'] = b_init_(b_key, b_shape)
        return params, state, inputs_shape

    def apply_fn(params, state, key, inputs, **kwargs):
        mean = jnp.mean(inputs, axis=-1, keepdims=True)
        var = jnp.var(inputs, axis=-1, keepdims=True) + 1e-5
        outputs = params['s'] * (inputs - mean) / jnp.sqrt(var) + params['b']
        return outputs, state

    return ModuleFn(apply_fn, init=init_fn)


def sequential(*modules: List[ModuleFn], ) -> ModuleFn:
    """ Create a sequential module from a list of modules.

    Args:
        modules: A list of modules.
    """
    def apply_fn(params, state, key, inputs, **kwargs):
        outputs = inputs
        if key is not None:
            key = jax.random.split(key, len(modules))
        else:
            key = len(modules) * [None]
        new_state = dict(
            ('layer_{}'.format(idx), None) for idx in range(len(modules)))
        if state is None:
            state = new_state
        for idx, module in enumerate(modules):
            if module.init is not None:
                outputs, new_module_state = module.apply(
                    params['layer_{}'.format(idx)],
                    state['layer_{}'.format(idx)],
                    key[idx],
                    outputs,
                    **kwargs,
                )
                new_state['layer_{}'.format(idx)] = new_module_state
            else:
                outputs = module.apply(outputs)

        state = new_state
        return outputs, state

    def init_fn(key, inputs_shape):
        params = dict(
            ('layer_{}'.format(idx), None) for idx in range(len(modules)))
        state = dict(
            ('layer_{}'.format(idx), None) for idx in range(len(modules)))
        key = jax.random.split(key, len(modules))
        shape = inputs_shape
        for idx, module in enumerate(modules):
            if module.init is not None:
                module_params, module_state, shape = module.init(
                    key[idx], shape)
                params['layer_{}'.format(idx)] = module_params
                state['layer_{}'.format(idx)] = module_state
            else:
                shape = module.apply(jnp.zeros(shape)).shape

        return params, state, shape

    return ModuleFn(apply_fn, init=init_fn)


def orthogonalize_weights(weights):
    """Take the current weight matrices for each layer, apply SVD decomposition on each one, 
    then transform the singular values, and finally recompose to make the weight matrix orthogonal.
    U,s,V = SVD(W). then all singular values must be ~1. 
    Output : update the self.weights matrices. 
    Reference : Orthogonal Deep Neural Networks, K.Juia et al. 2019"""
    epsilon = 0.5
    U, s, V = jnp.linalg.svd(weights, full_matrices=False)
    s = jnp.clip(s, 1/(1+epsilon), 1+epsilon)
    # reform with the new singular values
    weights = jnp.dot(U, jnp.dot(jnp.diag(s), V))
    return weights


def orthogonalize_params(params):
    """Take a dictionary of params and orthogonalize the weights
    """
    for k1 in params.keys():
        if params[k1] != None:
            for k2 in params[k1].keys():
                if k2.split('/')[-1] == 'w':
                    params[k1][k2] = orthogonalize_weights(params[k1][k2])

    return params


def _make_orthogonal_fn(rbs_idxs, size):
    num_thetas = sum(map(len, rbs_idxs))
    rbs_idxs = [list(map(list, rbs_idx)) for rbs_idx in rbs_idxs]
    len_idxs = np.cumsum([0] + list(map(len, rbs_idxs)))

    def _get_rbs_unitary(theta):
        """ Returns the unitary matrix for a single RBS gate. """
        cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)
        unitary = jnp.array([
            [cos_theta, sin_theta],
            [-sin_theta, cos_theta],
        ])
        unitary = unitary.transpose(*[*range(2, unitary.ndim), 0, 1])
        return unitary

    def _get_rbs_unitary_grad(theta):
        """ Returns the unitary matrix for a single RBS gate. """
        cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)
        unitary = jnp.array([
            [-sin_theta, cos_theta],
            [-cos_theta, -sin_theta],
        ])
        unitary = unitary.transpose(*[*range(2, unitary.ndim), 0, 1])
        return unitary

    @jax.custom_jvp
    def _get_parallel_rbs_unitary(thetas):
        """ Returns the unitary matrix for parallel RBS gates. """
        unitaries = []
        for i, idxs in enumerate(rbs_idxs):
            idxs = sum(idxs, [])
            sub_thetas = thetas[len_idxs[i]:len_idxs[i + 1]]
            rbs_blocks = _get_rbs_unitary(sub_thetas)
            eye_block = jnp.eye(size - len(idxs), dtype=thetas.dtype)
            permutation = idxs + [i for i in range(size) if i not in idxs]
            permutation = np.argsort(permutation)
            unitary = jax.scipy.linalg.block_diag(*rbs_blocks, eye_block)
            unitary = unitary[permutation][:, permutation]
            unitaries.append(unitary)
        unitaries = jnp.stack(unitaries)
        return unitaries

    @_get_parallel_rbs_unitary.defjvp
    def get_parallel_rbs_unitary_jvp(primals, tangents):
        thetas, = primals
        thetas_dot, = tangents
        unitaries = []
        unitaries_dot = []
        for i, idxs in enumerate(rbs_idxs):
            idxs = sum(idxs, [])
            sub_thetas = thetas[len_idxs[i]:len_idxs[i + 1]]
            sub_thetas_dot = thetas_dot[len_idxs[i]:len_idxs[i + 1]]
            rbs_blocks = _get_rbs_unitary(sub_thetas)
            rbs_blocks_grad = _get_rbs_unitary_grad(sub_thetas)
            rbs_blocks_dot = sub_thetas_dot[..., None, None] * rbs_blocks_grad
            eye_block = jnp.eye(size - len(idxs), dtype=thetas.dtype)
            zero_block = jnp.zeros_like(eye_block)
            permutation = idxs + [i for i in range(size) if i not in idxs]
            permutation = np.argsort(permutation)
            unitary = jax.scipy.linalg.block_diag(*rbs_blocks, eye_block)
            unitary_dot = jax.scipy.linalg.block_diag(*rbs_blocks_dot,
                                                      zero_block)
            unitary = unitary[permutation][:, permutation]
            unitary_dot = unitary_dot[permutation][:, permutation]
            unitaries.append(unitary)
            unitaries_dot.append(unitary_dot)
        primal_out = jnp.stack(unitaries)
        tangent_out = jnp.stack(unitaries_dot)
        return primal_out, tangent_out

    def orthogonal_fn(thetas, precision=None):
        """ Returns the unitary matrix for a sequence of parallel RBS gates. """
        assert thetas.shape[0] == num_thetas, "Wrong number of thetas."
        unitaries = _get_parallel_rbs_unitary(thetas)
        unitary = jnp.linalg.multi_dot(unitaries[::-1], precision=precision)
        return unitary

    return orthogonal_fn

def make_general_orthogonal_fn(rbs_idxs, size):
    num_thetas = sum(map(len, rbs_idxs))
    rbs_idxs = [list(map(list, rbs_idx)) for rbs_idx in rbs_idxs]
    len_idxs = np.cumsum([0] + list(map(len, rbs_idxs)))

    def _get_rbs_unitary(theta):
        """ Returns the unitary matrix for a single RBS gate. """
        cos_t, sin_t = jnp.cos(theta), jnp.sin(theta)
        zeros = jnp.zeros_like(cos_t)
        ones = jnp.ones_like(cos_t)
        unitary = jnp.array([
            [ones, zeros, zeros, zeros],
            [zeros, cos_t, -sin_t, zeros],
            [zeros, sin_t, cos_t, zeros],
            [zeros, zeros, zeros, ones],
        ])
        unitary = unitary.transpose(*[*range(2, unitary.ndim), 0, 1])
        return unitary

    def _get_parallel_rbs_unitary(thetas):
        """ Returns the unitary matrix for parallel RBS gates. """
        unitaries = []
        num_qubits = size
        map_qubits = [[0, 2**q] for q in range(num_qubits)]
        for i, idxs in enumerate(rbs_idxs):
            idxs = sum(idxs, [])
            sub_thetas = thetas[len_idxs[i]:len_idxs[i + 1]]
            rbs_blocks = _get_rbs_unitary(sub_thetas)
            eye_block = jnp.eye(2**(size - len(idxs)) , dtype=thetas.dtype)
            unitary =  tensordot_unitary([*rbs_blocks, eye_block])
            unitary_qubits = idxs + [
            q for q in range(num_qubits) if q not in idxs
            ]
            permutation = np.argsort([
            sum(binary)
            for binary in itertools.product(*(map_qubits[q]
                                              for q in unitary_qubits)) 
            ])
            unitary = unitary[permutation][:, permutation]
            unitaries.append(unitary)
        unitaries = jnp.stack(unitaries)
        
        return unitaries


    def orthogonal_fn(thetas, precision=None):
        """ Returns the unitary matrix for a sequence of parallel RBS gates. """
        assert thetas.shape[0] == num_thetas, "Wrong number of thetas."
        unitaries = _get_parallel_rbs_unitary(thetas)
        unitary = jnp.linalg.multi_dot(unitaries[::-1], precision=precision)
        return unitary

    return orthogonal_fn

def _get_pyramid_idxs(num_inputs, num_outputs):
    num_max = max(num_inputs, num_outputs)
    num_min = min(num_inputs, num_outputs)
    if num_max == num_min:
        num_min -= 1
    end_idxs = np.concatenate(
        [np.arange(1, num_max - 1), num_max - np.arange(1, num_min + 1)])
    start_idxs = np.concatenate([
        np.arange(end_idxs.shape[0] + num_min - num_max) % 2,
        np.arange(num_max - num_min)
    ])
    if num_inputs < num_outputs:
        start_idxs = start_idxs[::-1]
        end_idxs = end_idxs[::-1]
    rbs_idxs = [
        np.arange(start_idxs[i], end_idxs[i] + 1).reshape(-1, 2)
        for i in range(len(start_idxs))
    ]
    return rbs_idxs


def _get_butterfly_idxs(num_inputs, num_outputs):
    def _get_butterfly_idxs(n):
        if n == 2:
            return np.array([[[0, 1]]])
        else:
            rbs_idxs = _get_butterfly_idxs(n // 2)
            first = np.concatenate([rbs_idxs, rbs_idxs + n // 2], 1)
            last = np.arange(n).reshape(1, 2, n // 2).transpose(0, 2, 1)
            rbs_idxs = np.concatenate([first, last], 0)
            return rbs_idxs

    circuit_dim = int(2**np.ceil(np.log2(max(num_inputs, num_outputs))))
    rbs_idxs = _get_butterfly_idxs(circuit_dim)
    if num_inputs < num_outputs:
        rbs_idxs = rbs_idxs[::-1]
    return rbs_idxs


def ortho_linear(
    n_features: int,
    layout: Union[str, List[List[Tuple[int, int]]]] = 'butterfly',
    normalize_inputs: bool = False,
    normalize_outputs: bool = True,
    normalize_stop_gradient: bool = True,
    with_scale: bool = True,
    with_bias: bool = True,
    t_init: Optional[InitializerFn] = None,
    s_init: Optional[InitializerFn] = None,
    b_init: Optional[InitializerFn] = None,
) -> ModuleFn:
    """ Create an orthogonal layer from a layout of RBS gates.

    Args:
        n_features: The number of features in the output.
        layout: The layout of the RBS gates.
        normalize_inputs: Whether to normalize the inputs.
        normalize_outputs: Whether to normalize the outputs.
        normalize_stop_gradient: Whether to stop the gradient of the norm.
        with_scale: Whether to use a scale parameter.
        with_bias: Whether to include a bias term.
        t_init: The initializer for the angles.
        s_init: The initializer for the scale.
        b_init: The initializer for the bias.
    """
    def apply_fn(params, state, key, inputs, **kwargs):
        if layout == 'butterfly':
            rbs_idxs = _get_butterfly_idxs(inputs.shape[-1], n_features)
            circuit_dim = int(2**np.ceil(
                np.log2(max(inputs.shape[-1], n_features))))
        elif layout == 'pyramid':
            rbs_idxs = _get_pyramid_idxs(inputs.shape[-1], n_features)
            make_unitary = _get_pyramid_idxs(inputs.shape[-1], n_features)
            circuit_dim = max(inputs.shape[-1], n_features)
        else:
            rbs_idxs = layout
            circuit_dim = max(
                [max(idxs) for moment in layout for idxs in moment])
        make_unitary = _make_orthogonal_fn(rbs_idxs[::-1], circuit_dim)
        if normalize_inputs:
            norm = jnp.linalg.norm(inputs, axis=-1)[..., None]
            if normalize_stop_gradient:
                norm = lax.stop_gradient(norm)
            inputs /= norm
        if inputs.shape[-1] < circuit_dim:
            zeros = jnp.zeros(
                (*inputs.shape[:-1], circuit_dim - inputs.shape[-1]), )
            inputs = jnp.concatenate([zeros, inputs], axis=-1)
        unitary = make_unitary(params['t'][::-1])
        outputs = jnp.dot(inputs, unitary.T)[..., -n_features:]
        if normalize_outputs:
            norm = jnp.linalg.norm(outputs, axis=-1)[..., None]
            if normalize_stop_gradient:
                norm = lax.stop_gradient(norm)
            outputs /= norm
        if with_scale:
            outputs *= params['s']
        if with_bias:
            outputs += params['b']
        return outputs, None

    def init_fn(key, inputs_shape):
        if layout == 'butterfly':
            rbs_idxs = _get_butterfly_idxs(inputs_shape[-1], n_features)
        elif layout == 'pyramid':
            rbs_idxs = _get_pyramid_idxs(inputs_shape[-1], n_features)
        else:
            rbs_idxs = layout
        n_angles = sum(map(len, rbs_idxs))
        params, state = {}, None
        key, t_key, b_key, s_key = jax.random.split(key, 4)
        t_init_ = t_init or uniform(-np.pi, np.pi)
        t_shape = (n_angles, )
        params['t'] = t_init_(t_key, t_shape)
        if with_scale:
            s_init_ = s_init or ones()
            s_shape = (n_features, )
            params['s'] = s_init_(s_key, s_shape)
        if with_bias:
            b_init_ = b_init or zeros()
            b_shape = (n_features, )
            params['b'] = b_init_(b_key, b_shape)
        shape = inputs_shape[:-1] + (n_features, )
        return params, state, shape

    return ModuleFn(apply_fn, init=init_fn)



def ortho_linear_noisy(
    n_features: int,
    noise_scale: float = 0.01,
    layout: Union[str, List[List[Tuple[int, int]]]] = 'butterfly',
    normalize_inputs: bool = True,
    normalize_outputs: bool = True,
    normalize_stop_gradient: bool = True,
    with_scale: bool = True,
    with_bias: bool = True,
    t_init: Optional[InitializerFn] = None,
    s_init: Optional[InitializerFn] = None,
    b_init: Optional[InitializerFn] = None,
) -> ModuleFn:
    """ Create an orthogonal layer from a layout of RBS gates.
    Args:
        n_features: The number of features in the output.
        layout: The layout of the RBS gates.
        normalize_inputs: Whether to normalize the inputs.
        normalize_outputs: Whether to normalize the outputs.
        normalize_stop_gradient: Whether to stop the gradient of the norm.
        with_scale: Whether to use a scale parameter.
        with_bias: Whether to include a bias term.
        t_init: The initializer for the angles.
        s_init: The initializer for the scale.
        b_init: The initializer for the bias.
    """
    def apply_fn(params, state, key, inputs, **kwargs):
        if layout == 'butterfly':
            rbs_idxs = _get_butterfly_idxs(inputs.shape[-1], n_features)
            circuit_dim = int(2**np.ceil(
                np.log2(max(inputs.shape[-1], n_features))))
        elif layout == 'pyramid':
            rbs_idxs = _get_pyramid_idxs(inputs.shape[-1], n_features)
            make_unitary = _get_pyramid_idxs(inputs.shape[-1], n_features)
            circuit_dim = max(inputs.shape[-1], n_features)
        else:
            rbs_idxs = layout
            circuit_dim = max(
                [max(idxs) for moment in layout for idxs in moment])
        make_unitary = _make_orthogonal_fn(rbs_idxs[::-1], circuit_dim)
        if normalize_inputs:
            norm = jnp.linalg.norm(inputs, axis=-1)[..., None]
            if normalize_stop_gradient:
                norm = lax.stop_gradient(norm)
            inputs /= norm
        if inputs.shape[-1] < circuit_dim:
            zeros = jnp.zeros(
                (*inputs.shape[:-1], circuit_dim - inputs.shape[-1]), )
            inputs = jnp.concatenate([zeros, inputs], axis=-1)
        unitary = make_unitary(params['t'][::-1])
        outputs = jnp.dot(inputs, unitary.T)
        if normalize_outputs:
            norm = jnp.linalg.norm(outputs, axis=-1)[..., None]
            if normalize_stop_gradient:
                norm = lax.stop_gradient(norm)
            outputs /= norm
        outputs = jnp.einsum('...i,...i->...i',outputs,outputs)
        outputs = outputs[..., -n_features:]
        if with_scale:
            outputs *= params['s']
        if with_bias:
            outputs += params['b']
        key, _ = jax.random.split(key)
        outputs += noise_scale*jax.random.normal(key, outputs.shape)
        return outputs, state

    def init_fn(key, inputs_shape):
        if layout == 'butterfly':
            rbs_idxs = _get_butterfly_idxs(inputs_shape[-1], n_features)
        elif layout == 'pyramid':
            rbs_idxs = _get_pyramid_idxs(inputs_shape[-1], n_features)
        else:
            rbs_idxs = layout
        n_angles = sum(map(len, rbs_idxs))
        params, state = {}, None
        key, t_key, b_key, s_key = jax.random.split(key, 4)
        t_init_ = t_init or uniform(-np.pi, np.pi)
        t_shape = (n_angles, )
        params['t'] = t_init_(t_key, t_shape)
        if with_scale:
            s_init_ = s_init or ones()
            s_shape = (n_features, )
            params['s'] = s_init_(s_key, s_shape)
        if with_bias:
            b_init_ = b_init or zeros()
            b_shape = (n_features, )
            params['b'] = b_init_(b_key, b_shape)
        shape = inputs_shape[:-1] + (n_features, )
        return params, state, shape

    return ModuleFn(apply_fn, init=init_fn)

# Main

## Circuit constructions

In [4]:
# Global counter

global_number_of_circuits_executed = 0

# Global object keeping track of result
# Used for pickling
# Populated initially in DeepHedgingBenchmark().__test_model
# and with run results in run_circuit

global_hardware_run_results_dict = {}

In [5]:
%load_ext autoreload
%autoreload 2
import numpy as np
from qnn import _get_butterfly_idxs, _get_pyramid_idxs, _make_orthogonal_fn
# fix for older versions of Qiskit
if qiskit.__version__ <= '0.37.1':
    import qiskit.providers.aer.noise as noise
else:
    import qiskit_aer.noise as noise
import json
import pickle
import time
import copy
from pathlib import Path
from tqdm import tqdm

def RBS_gate(theta, bla):

    def operator_function(parameters):
            theta = parameters['theta']
            c = np.cos(theta)
            s = np.sin(theta)
            return np.array([[1.0, 0.0, 0.0, 0.0],
                             [0.0, c, s, 0.0],
                             [0.0, -s, c, 0.0],
                             [0.0, 0.0, 0.0, 1.0]], dtype=np.complex128)

    return quasar.Gate(
        nqubit=2,
        operator_function=operator_function,
        parameters=collections.OrderedDict([('theta', theta)]),
        name='RBS',
        ascii_symbols=['B', 'S'])



def prepare_circuit(input, params, loader_layout='parallel', layer_layout='butterfly'):
    def _get_layer_circuit():
      _params = np.array(params).astype('float')
      if layer_layout == 'butterfly':
        rbs_idxs = _get_butterfly_idxs(num_qubits, num_qubits)
      elif layer_layout == 'pyramid':
        rbs_idxs = _get_pyramid_idxs(num_qubits, num_qubits)
      circuit_layer = quasar.Circuit()
      idx_angle = 0
      for gates_per_timestep in rbs_idxs[::-1]:
        for gate in gates_per_timestep:
          circuit_layer.add_gate(quasar.Gate.RBS(theta=-_params[::-1][idx_angle]), tuple(gate))
          idx_angle+=1
      return circuit_layer
    
    num_qubits = len(input)
    loader_circuit = loader(np.array(input),mode=loader_layout,initial=True,controlled=False)
    layer_circuit = _get_layer_circuit()
    circuit = quasar.Circuit.join_in_time([loader_circuit, layer_circuit])
    # Translate from qcware-quasar to qiskit
    qiskit_circuit = translate(circuit)
    
    # qiskit_circuit.save_statevector()    

    qiskit_circuit = qiskit.transpile(qiskit_circuit, optimization_level=3)
    c = qiskit.ClassicalRegister(num_qubits)
    qiskit_circuit.add_register(c)
    qiskit_circuit.barrier()
    qiskit_circuit.measure(qubit=range(num_qubits),cbit=c)
    return qiskit_circuit

def counter_to_dict(c):
    """Converts counter returned by pytket get_counts function
    to dictionary returned by qiskit
    canonical use:
    >>> result = backend.get_result(handle)
    >>> counts = result.get_counts(basis=BasisOrder.dlo)
    >>> counts_qiskit = counter_to_dict(counts)
    """
    d = {}
    for k, v in c.items():
        d[''.join(str(x) for x in k)] = int(v)
    return d

def run_circuit(circs, backend_name = 'quantinuum_H1-2E'):
    """
    backend name accepted 
    """
    global global_number_of_circuits_executed
    global global_hardware_run_results_dict
    input_size = 8
    results = np.zeros((len(circs), input_size))
    
    #TODO
    # if 'qiskit' in backend:
    #elif 'quantinuum' in backend:
    
    global_number_of_circuits_executed += len(circs)
    num_measurements = 1000
    
    if "qiskit" in backend_name:
        backend = qiskit.Aer.get_backend('qasm_simulator')
        if backend_name == 'qiskit_noiseless':
            measurement = qiskit.execute(circs, backend, shots=num_measurements)
        elif backend_name == 'qiskit_noisy': 
            # Error probabilities
            prob_1 = 0.001  # 1-qubit gate
            prob_2 = 0.01   # 2-qubit gate
            # Dylan's tunes error probabilities
            # prob_1 = 0  # 1-qubit gate
            # prob_2 = 3.5e-3   # 2-qubit gate

            # Depolarizing quantum errors
            error_1 = noise.depolarizing_error(prob_1, 1)
            error_2 = noise.depolarizing_error(prob_2, 2)

            # Add errors to noise model
            noise_model = noise.NoiseModel()
            noise_model.add_all_qubit_quantum_error(error_1, ['h', 'x', 'ry'])
            noise_model.add_all_qubit_quantum_error(error_2, ['cz'])

            # Get basis gates from noise model
            basis_gates = noise_model.basis_gates
            measurement = qiskit.execute(circs, backend,basis_gates=basis_gates, noise_mode=noise_model, shots=num_measurements)
        else:
            raise ValueError(f"Unexpected backend name {backend_name}")
        all_counts = measurement.result().get_counts()
    elif "quantinuum" in backend_name:
        # From docs: "Batches cannot exceed the maximum limit of 500 H-System Quantum Credits (HQCs) total"
        # Therefore batching is more or less useless on quantinuum
        from pytket.extensions.qiskit import qiskit_to_tk
        from pytket.circuit import BasisOrder
        from pytket.extensions.quantinuum import QuantinuumBackend
     
        if global_hardware_run_results_dict['model_type'] != 'simple':
            raise NotImplementedError(f"Model {global_hardware_run_results_dict['model_type']} not supported yet, only simple model is supported.")
    
        outpath_stem = f"1031_{global_hardware_run_results_dict['model_type']}_{backend_name}_{global_hardware_run_results_dict['layer_type']}_{global_hardware_run_results_dict['epsilon']}"
        
        outpath_result_final = f"data/{outpath_stem}.json"
        outpath_handles = f"data/handles_{outpath_stem}.pickle"
        if Path(outpath_result_final).exists():
            # if precomputed results already present on disk, simply load
            print(f"Using precomputed counts from {outpath_result_final}")
            all_counts = json.load(open(outpath_result_final, "r"))['all_counts']
        else:
            if backend_name == "quantinuum_H1-2E":
                backend = QuantinuumBackend(device_name="H1-2E")
            elif backend_name == "quantinuum_H1-2":
                backend = QuantinuumBackend(device_name="H1-2")
            else:
                raise ValueError(f"Unknown Quantinuum backend: {backend_name}")
            if Path(outpath_handles).exists():
                # if circuits already submitted, simply load from disk
                print(f"Using pickled handles from {outpath_handles}")
                handles = pickle.load(open(outpath_handles, "rb"))
            else:
                # otherwise, submit circuits and pickle handles
                circs_tk = [qiskit_to_tk(circ) for circ in circs]
                for idx, circ in enumerate(circs_tk):
                    circ.name = f'{outpath_stem}_{idx+1}_of_{len(circs)}'
                compiled_circuits = backend.get_compiled_circuits(circs_tk, optimisation_level=2)
                handles = backend.process_circuits(compiled_circuits, n_shots=num_measurements)
                pickle.dump(handles, open(outpath_handles, "wb"))
                print(f"Dumped handles to {outpath_handles}")
            # retrieve results from handles
            result_list = []
            
            with tqdm(total=len(handles), desc='#jobs finished') as pbar:
                for handle in handles:
                    while True:
                        status = backend.circuit_status(handle).status
                        if status.name == 'COMPLETED':
                            result = backend.get_result(handle)
                            result_list.append(copy.deepcopy(result))
                            pbar.update(1)
                            break
                        else:
                            assert status.name in ['QUEUED', 'RUNNING'] 
                        time.sleep(1)
            global_hardware_run_results_dict['result_list'] = [x.to_dict() for x in result_list]
            # convert from tket counts format to qiskit
            all_counts = [
                counter_to_dict(
                    result.get_counts(basis=BasisOrder.dlo)
                ) for result in result_list
            ]
            global_hardware_run_results_dict['all_counts'] = all_counts
            # dump result on disk
            json.dump(global_hardware_run_results_dict, open(outpath_result_final, "w"))
    else:
        raise ValueError(f"Unexpected backend name {backend_name}")
        
    # Post processing
    # Discard bitstrings that do not correspond to unary encoding (not Hamming weight 1)
    # We build a dictionary with all unary bitstrings and only add counts corresponding to unary bitstrings
    # Note: f"{2**i:0{input_size}b}" converts 2**i to its binary string representation.
    for j in range(len(circs)):
        measurementRes = all_counts[j]
        num_postselected = 0
        filtered_counts = {f"{2**i:0{input_size}b}":0 for i in range(input_size)}
        for bitstring, count in measurementRes.items():
            if sum([int(x) for x in bitstring]) != 1:
                continue
            filtered_counts[bitstring] += count
            num_postselected+= count
        results[j] = [filtered_counts[k]/num_postselected for k in sorted(filtered_counts)][::-1]    
    return results

## Definition of layers

In [6]:

def ortho_linear_hardware(
    n_features: int,
    layout: Union[str, List[List[Tuple[int, int]]]] = 'butterfly',
    normalize_inputs: bool = True,
    normalize_outputs: bool = True,
    normalize_stop_gradient: bool = True,
    with_scale: bool = True,
    with_bias: bool = True,
    t_init: Optional[InitializerFn] = None,
    s_init: Optional[InitializerFn] = None,
    b_init: Optional[InitializerFn] = None,
) -> ModuleFn:
    """ Create an orthogonal layer from a layout of RBS gates.
    Args:
        n_features: The number of features in the output.
        layout: The layout of the RBS gates.
        normalize_inputs: Whether to normalize the inputs.
        normalize_outputs: Whether to normalize the outputs.
        normalize_stop_gradient: Whether to stop the gradient of the norm.
        with_scale: Whether to use a scale parameter.
        with_bias: Whether to include a bias term.
        t_init: The initializer for the angles.
        s_init: The initializer for the scale.
        b_init: The initializer for the bias.
    """
    def apply_fn(params, state, key, inputs, **kwargs):
        # Step 1: preprocess the inputs
        if layout == 'butterfly':
            rbs_idxs = _get_butterfly_idxs(inputs.shape[-1], n_features)
            circuit_dim = int(2**np.ceil(
                np.log2(max(inputs.shape[-1], n_features))))
        elif layout == 'pyramid':
            rbs_idxs = _get_pyramid_idxs(inputs.shape[-1], n_features)
            make_unitary = _get_pyramid_idxs(inputs.shape[-1], n_features)
            circuit_dim = max(inputs.shape[-1], n_features)
        else:
            rbs_idxs = layout
            circuit_dim = max(
                [max(idxs) for moment in layout for idxs in moment])
        if normalize_inputs:
            norm = jnp.linalg.norm(inputs, axis=-1)[..., None]
            if normalize_stop_gradient:
                norm = lax.stop_gradient(norm)
            inputs /= norm
        if inputs.shape[-1] < circuit_dim:
            zeros = jnp.zeros(
                (*inputs.shape[:-1], circuit_dim - inputs.shape[-1]), )
            inputs = jnp.concatenate([zeros, inputs], axis=-1)
        # Step 2: generate the circuits
        circs = []
        out_shape = inputs.shape[:-1]+(n_features,)
        for input in inputs.reshape(-1,circuit_dim):
            circs.append(prepare_circuit(input,params['t']))
        # run circuits and truncate to desired number of outputs
        outputs = jnp.array(run_circuit(circs))[..., -n_features:]
        
        outputs = outputs.reshape(out_shape)
        # unitary = make_unitary(params['t'])
        # outputs = jnp.dot(inputs, unitary.T)[..., -n_features:]
        # outputs = inputs
        if with_scale:
            outputs *= params['s']
        if with_bias:
            outputs += params['b']
        return outputs, state

    def init_fn(key, inputs_shape):
        if layout == 'butterfly':
            rbs_idxs = _get_butterfly_idxs(inputs_shape[-1], n_features)
        elif layout == 'pyramid':
            rbs_idxs = _get_pyramid_idxs(inputs_shape[-1], n_features)
        else:
            rbs_idxs = layout
        n_angles = sum(map(len, rbs_idxs))
        params, state = {}, None
        key, t_key, b_key, s_key = jax.random.split(key, 4)
        t_init_ = t_init or uniform(-np.pi, np.pi)
        t_shape = (n_angles, )
        params['t'] = t_init_(t_key, t_shape)
        if with_scale:
            s_init_ = s_init or ones()
            s_shape = (n_features, )
            params['s'] = s_init_(s_key, s_shape)
        if with_bias:
            b_init_ = b_init or zeros()
            b_shape = (n_features, )
            params['b'] = b_init_(b_key, b_shape)
        shape = inputs_shape[:-1] + (n_features, )
        return params, state, shape

    return ModuleFn(apply_fn, init=init_fn)

# Models

In [7]:
from typing import Any, TypeVar
import itertools
import jax
from jax import numpy as jnp
import qnn
from qnn import ModuleFn, elementwise, linear, sequential, make_general_orthogonal_fn, _get_butterfly_idxs, _get_pyramid_idxs 


relu = elementwise(jax.nn.relu)
gelu = elementwise(jax.nn.gelu)
log_softmax = elementwise(jax.nn.log_softmax)
sigmoid = elementwise(jax.nn.sigmoid)

def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, jnp.stack(ys)

def recurrent_network(hps, layer_func: ModuleFn = linear, **kwargs) -> ModuleFn:
    """ Create a Recurrent Network.
    Args:
        n_features: The number of features.
        n_layers: The number of layers.
        layer_func: The type of layers to use.
    """

    preprocessing = [linear(hps.n_features), sigmoid]
    features = hps.n_layers * [layer_func(hps.n_features), relu]
    postprocessing = [linear(1), sigmoid]
    layers = preprocessing + features + postprocessing
    net = sequential(*layers)

    def init_fn(key, inputs_shape):
        params = net.init(
            key, (inputs_shape[0], inputs_shape[1], 2*inputs_shape[2]))[0]
        return params, None, inputs_shape

    def apply_fn(params, state, key, inputs):
        def cell_fn(prev_outputs, inputs):
            inp = inputs[None, ...]
            inp = jnp.concatenate([prev_outputs, inp], axis=-1)
            delta = net.apply(params, None, key, inp)[0]
            # print(f'inputs shape = {inp.shape} deltas shape = {delta.shape}')
            return delta, delta

        prev_state = jnp.zeros((1, inputs.shape[0], inputs.shape[-1]))
        inputs = inputs.transpose(1, 0, 2)
        _, outputs = scan(cell_fn, prev_state, inputs)
        outputs = jnp.squeeze(outputs, 1)
        outputs = outputs.transpose(1, 0, 2)
        return outputs, state
    return qnn.ModuleFn(apply_fn, init_fn)


def lstm_cell(hps,  layer_func: ModuleFn = linear, **kwargs) -> ModuleFn:
    """ Create an LSTM Cell.
    Args:
        n_features: The number of features.
        layer_func: The type of layers to use.
    """

    _linear = layer_func(n_features=int(hps.n_features/2), with_bias=True)

    def init_fn(key, inputs_shape):
        keys = jax.random.split(key, num=4)
        params = {}
        layer_idx = ['i', 'g', 'f', 'o']
        _shape = (inputs_shape[0], inputs_shape[1], 2*inputs_shape[2])
        _init_params = {}
        for i, id in enumerate(layer_idx):
            _init_params[id] = _linear.init(keys[i],  _shape)[0]
            params.update(qnn.add_scope_to_params(id, _init_params[id]))
        return params, None, inputs_shape

    def apply_fn(params, state, key, inputs):
        def cell_fn(prev_state, inputs):
            prev_hidden, prev_cell = prev_state
            x_and_h = jnp.concatenate([inputs, prev_hidden], axis=-1)
            layer_idx = ['i', 'g', 'f', 'o']
            _apply_params = {}
            for i, id in enumerate(layer_idx):
                _apply_params[id] = qnn.get_params_by_scope(id, params)
            i = _linear.apply(_apply_params['i'], None, key, x_and_h)[0]
            g = _linear.apply(_apply_params['g'], None, key, x_and_h)[0]
            f = _linear.apply(_apply_params['f'], None, key, x_and_h)[0]
            o = _linear.apply(_apply_params['o'], None, key, x_and_h)[0]
            # i = input, g = cell_gate, f = forget_gate, o = output_gate
            f = jax.nn.sigmoid(f + 1)
            c = f * prev_cell + jax.nn.sigmoid(i) * jnp.tanh(g)
            h = jax.nn.sigmoid(o) * jnp.tanh(c)
            return jnp.stack([h, c], axis=0), h

        prev_state = jnp.zeros((2, inputs.shape[0], inputs.shape[-1]))
        inputs = inputs.transpose(1, 0, 2)
        _, outputs = scan(cell_fn, prev_state, inputs)
        outputs = outputs.transpose(1, 0, 2)
        return outputs, state
    return qnn.ModuleFn(apply_fn, init_fn)


def lstm_network(hps, layer_func: ModuleFn = linear, **kwargs) -> ModuleFn:
    """ Create an LSTM Network.
    Args:
        n_features: The number of features.
        layer_func: The type of layers to use.
    """
    preprocessing = [linear( int(hps.n_features/2) ), sigmoid]
    features = [lstm_cell(hps=hps, layer_func=layer_func)]
    postprocessing = [linear(1), sigmoid]
    layers = preprocessing + features + postprocessing
    net = sequential(*layers)
    return net


In [8]:
# from models import simple_network, attention_network
from qnn import linear
from train import build_train_fn
from qnn import ortho_linear, ortho_linear_noisy
from models import simple_network, attention_network
from loss_metrics import entropy_loss
from data import gen_paths
from utils import train_test_split, get_batches, HyperParams
import numpy as np
from tqdm import tqdm
import optax
from functools import partial 
from utils import HyperParams
seed = 100
key = jax.random.PRNGKey(seed)
hps = HyperParams(S0=100,
                  n_steps=5,
                  n_paths=10000,
                  discrete_path=False,
                  strike_price=100,
                  epsilon=0.0,
                  sigma=0.2,
                  risk_free=0,
                  dividend=0,
                  model_type='simple',
                  layer_type='noisy_ortho',
                  n_features=8,
                  n_layers=1,
                  loss_param=1.0,
                  batch_size=5,
                  test_size=0.2,
                  optimizer='adam',
                  learning_rate=1E-3,
                  num_epochs=100
                  )



# Data
S = gen_paths(hps)
[S_train, S_test] = train_test_split([S], test_size=0.2)
_, test_batches = get_batches(jnp.array(S_test[0]), batch_size=hps.batch_size)
test_batch = test_batches[0]


100%|██████████| 5/5 [00:00<00:00, 1907.89it/s]


In [9]:
from utils import load_params
class DeepHedgingBenchmark():
  """
  Runs the benchmark with different models / layers
  Input: test_batch above
  test_batch has 8 datapoints
  """
  def __init__(self, key, eps,  layers, models):
      self.__key = key
      self.__models = models
      self.__layers = layers
      self.__eps = eps
      self.test_info = {layer:{str(eps):{} for eps in self.__eps} for layer in self.__layers}
  def __test_model(self, hps, test_batch, save_dir = 'params_all_models_5_days.pkl'):
    # set up global objects for pickling circuit execution results
    global global_number_of_circuits_executed
    global global_hardware_run_results_dict
    global_number_of_circuits_executed = 0
    global_hardware_run_results_dict = {
        'model_type' : hps.model_type,
        'measurementRes' : None,
        'epsilon' : hps.epsilon,
        'backend_name' : None,
        'layer_type' : hps.layer_type,
    }
    if hps.layer_type in ['linear','linear_svb']:
      layer_func = linear
    elif hps.layer_type=='ortho':
      layer_func = ortho_linear
    elif hps.layer_type=='noisy_ortho':
      layer_func = partial(ortho_linear_noisy,noise_scale=0.01)
    elif hps.layer_type=='hardware_ortho':
      # TODO want to run this on Quantinuum device 
      layer_func = ortho_linear_hardware

    if hps.model_type == 'simple':
      net = simple_network(hps=hps, layer_func=layer_func)
    elif hps.model_type == 'recurrent':
      net = recurrent_network(hps=hps, layer_func=layer_func)
    elif hps.model_type == 'lstm':
      net = lstm_network(hps=hps, layer_func=layer_func)
    elif hps.model_type == 'attention':
      net = attention_network(hps=hps, layer_func=layer_func)
    
    opt = optax.adam(1E-3)
    key, init_key = jax.random.split(self.__key)
    _, state, _ = net.init(init_key, (1, hps.n_steps, 1))
    loss_metric = entropy_loss

    # Training

    train_fn, loss_fn = build_train_fn(hps, net, opt, loss_metric)

    train_info = load_params(save_dir)
    layer_type = "noisy_ortho" if hps.layer_type == 'hardware_ortho' else hps.layer_type
    train_losses, params = train_info[layer_type][str(hps.epsilon)][hps.model_type]
    loss, _ = loss_fn(params, state, key, test_batch[...,None])
    print(f'Model = {hps.model_type} | Layer = {hps.layer_type} | EPS = {hps.epsilon}| Loss = {loss} | #circs = {global_number_of_circuits_executed}')
    return loss
  def test(self, inputs):
    for model in self.__models:
      for eps in self.__eps:
        for layer in self.__layers:
            hps = HyperParams(S0=100,
                  n_steps=5,
                  n_paths=120000,
                  discrete_path=True,
                  strike_price=100,
                  epsilon=eps,
                  sigma=0.2,
                  risk_free=0,
                  dividend=0,
                  model_type=model,
                  layer_type=layer,
                  n_features=8,
                  n_layers=1,
                  loss_param=1.0,
                  batch_size=5,
                  test_size=0.2,
                  optimizer='adam',
                  learning_rate=1E-3,
                  num_epochs=100)
            self.test_info[layer][str(eps)][model] = self.__test_model(hps, inputs)
    



In [10]:
seed = 100
key = jax.random.PRNGKey(seed)

# LAYERS = ['hardware_ortho']
# EPS = [ 0.01]
# MODELS = ['simple','recurrent','lstm','attention']

# LAYERS = ['linear','ortho','noisy_ortho','hardware_ortho']
# EPS = [ 0.01]
# MODELS = ['lstm']

# test only

LAYERS = ['hardware_ortho']
EPS = [ 0.01]
MODELS = ['simple','recurrent']


dhb = DeepHedgingBenchmark(key=key,eps=EPS, layers=LAYERS, models=MODELS)

In [11]:
dhb.test(test_batch)

Using precomputed counts from data/1031_simple_quantinuum_H1-2E_hardware_ortho_0.01.json
Model = simple | Layer = hardware_ortho | EPS = 0.01| Loss = 2.0746476650238037 | #circs = 25


NotImplementedError: Model recurrent not supported yet, only simple model is supported.