# Demonstration of the Observable-Tunable Expectation-Value-Sampler Quantum Generative Model (OT-EVS) on controlled experiments with synthetic datasets
### This notebook reproduces results similar to those in Sections V.A-C of the article "Shadow-Frugal Expectation-Value-Sampling Variational Quantum Generative Model" (arXiv:2412.17039).
### Below we train OT-EVS using classical-shadow measurements. See the other notebooks for training with conventional measurements.

## Import Modules

In [1]:
import sys
import os

from types import SimpleNamespace
from itertools import product
import copy
import math

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jaxtyping import PRNGKeyArray
from jax.tree_util import tree_map
import equinox as eqx
import optax
import tensorcircuit as tc

import faiss
from scipy.special import digamma

from datetime import datetime
from tqdm import tqdm

Please first ``pip install -U qiskit`` to enable related functionality in translation module
Please first ``pip install -U cirq`` to enable related functionality in translation module


## Experiment Configurations

In [2]:
config = SimpleNamespace(
    # --- critic architecture ---
    critic_layer_size=512,  # width of the critic's hidden layers
    critic_depth=4,  # depth of the critic MLP (NOTE(review): train() builds Critic without reading this or critic_layer_size)
    n_critic=5,  # number of critic updates per generator update

    # --- generator architecture ---
    latent_dim=2,  # dimension of the latent input z (train() samples it uniformly from [-pi, pi])
    data_dim=2,  # dimension of each generated sample
    nq=4,  # number of qubits
    nl=2,  # number of circuit layers
    k=1,  # maximum locality of the measured Pauli observables
    n_shots=4096,  # shots per observable (the shadow noise in train() uses n_shots * n_obs in total)

    # --- Adam settings for the circuit parameters of the generator ---
    lr_gq=0.001,
    b1_gq=0,
    b2_gq=0.9,

    # --- Adam settings for the observable (linear) parameters of the generator ---
    lr_gl=0.0001,
    b1_gl=0.9,
    b2_gl=0.9,

    # --- Adam settings for the critic ---
    lr_c=0.0001,
    b1_c=0.5,
    b2_c=0.9,

    lambda_gp=0.1,  # weight of the gradient-penalty term in the critic loss
    batch_size=256,  # mini-batch size

    n_iter=20000,  # number of training iterations (generator updates)
    eval_freq=200,  # evaluate the KLD every this many iterations
    train_size=65536,  # number of samples in the synthetic training set
    eval_size=2048,  # number of generated / target samples per KLD estimate
)

## Model Architecture

In [3]:
# Select JAX as the tensorcircuit backend; `K` is the backend handle whose
# `K.jit` / `K.vmap` are used by GeneratorQuantum below.
K = tc.set_backend('jax')

def get_all_k_local_observables(nq, k):
    '''
    Enumerate every non-identity Pauli string on ``nq`` qubits with at most ``k`` non-identity factors.

    Each string is a length-``nq`` tuple over {0, 1, 2, 3}; 0 marks the identity on that qubit
    (presumably 1/2/3 are X/Y/Z, following tensorcircuit's ``parameterized_measurements``
    encoding -- confirm). Strings appear in ``itertools.product`` order.

    Args:
        nq: number of qubits.
        k: locality, i.e. the maximum number of non-identity factors.

    Returns:
        Integer array of shape ``(n_obs, nq)``. When no string qualifies (e.g. ``k < 1``) the
        result has shape ``(0, nq)`` so callers can still rely on it being 2-D.
    '''
    # Keep strings whose number of non-identity factors lies in [1, k]
    # (the all-identity string carries no information and is excluded).
    valid_tuples = [
        t for t in product(range(4), repeat=nq)
        if 1 <= sum(1 for x in t if x != 0) <= k
    ]
    if not valid_tuples:
        return jnp.zeros((0, nq), dtype=int)
    return jnp.array(valid_tuples)


def get_circuit(nq, nl, inputs, weights):
    '''
    Build the layered variational circuit used by the generator.

    Every layer ``layer`` (0 <= layer < nl) applies:
      * on each qubit q: RX(inputs[layer]) followed by RY(weights[layer, q]);
      * on each neighbouring pair (q, q+1): CNOT(q, q+1), RY(weights[layer, nq + q]) on q+1,
        then CNOT(q, q+1) again.
    Each layer therefore consumes ``2 * nq - 1`` weights.

    Args:
        nq: number of qubits.
        nl: number of layers; ``inputs`` is indexed by layer.
        inputs: latent angles, one per layer.
        weights: trainable angles indexed as ``weights[layer, j]``.

    Returns:
        The tensorcircuit ``Circuit``.
    '''
    circ = tc.Circuit(nq)
    for layer in range(nl):
        encoding_angle = inputs[layer]
        layer_weights = weights[layer]
        # Single-qubit rotations: shared data-encoding RX, per-qubit trainable RY.
        for qubit in range(nq):
            circ.rx(qubit, theta=encoding_angle)
            circ.ry(qubit, theta=layer_weights[qubit])
        # Nearest-neighbour entangling blocks: CNOT - RY(target) - CNOT.
        for control in range(nq - 1):
            target = control + 1
            circ.cnot(control, target)
            circ.ry(target, theta=layer_weights[nq + control])
            circ.cnot(control, target)
    return circ


class GeneratorQuantum(eqx.Module):
    '''
    The generator (circuit part).

    Maps a latent vector to the measured values of every non-identity Pauli string of
    locality at most ``k`` on the circuit produced by ``get_circuit``.
    '''
    nq: int = eqx.field(static=True)  # number of qubits
    nl: int = eqx.field(static=True)  # number of circuit layers
    k: int = eqx.field(static=True)  # locality of the measured observables
    weights: jax.Array  # trainable circuit angles, indexed as weights[layer, j]

    @K.jit
    def evaluate_circuit(self, inputs, observable):
        '''
        Measure a single Pauli string on the circuit built from ``inputs``.

        ``observable`` is one row of ``get_all_k_local_observables`` (presumably the expectation
        value of that Pauli string is returned by ``parameterized_measurements`` -- confirm).
        '''
        circ = get_circuit(self.nq, self.nl, inputs, self.weights)
        return tc.templates.measurements.parameterized_measurements(circ, observable, onehot=True)

    def _measure_all(self, x, locality):
        '''Evaluate every Pauli string of locality <= ``locality`` for the latent vector ``x``.'''
        observables = get_all_k_local_observables(self.nq, locality)
        measure = K.vmap(self.evaluate_circuit, vectorized_argnums=1)
        return measure(x, observables)

    ### Call this function when using the classical shadows method
    def get_2k_values(self, x):
        '''Values of all Pauli strings of locality <= min(2k, nq), used to build the shadow covariance.'''
        return self._measure_all(x, int(min(2 * self.k, self.nq)))

    def __call__(self, x):
        '''Values of all Pauli strings of locality <= k for the latent vector ``x``.'''
        return self._measure_all(x, self.k)

In [4]:
class GeneratorLinear(eqx.Module):
    '''
    The generator (observable part).

    An affine map from the ``n_obs`` circuit outputs to a ``data_dim``-dimensional sample.
    '''
    model: eqx.Module  # the wrapped eqx.nn.Linear layer

    def __init__(self, n_obs, data_dim, key):
        super().__init__()
        self.model = eqx.nn.Linear(in_features=n_obs, out_features=data_dim, key=key)

    def __call__(self, x):
        '''Apply the linear layer to one (unbatched) vector; callers vmap over the batch.'''
        return self.model(x)

In [5]:
class Critic(eqx.Module):
    '''
    The critic: a ReLU MLP mapping one ``data_dim``-dimensional sample to a score of shape (1,).

    Args:
        data_dim: dimension of the input samples.
        key: PRNG key used to initialise the Linear layers.
        width: width of every hidden layer (default 512, the value of ``config.critic_layer_size``).
        depth: number of Linear layers (default 4, the value of ``config.critic_depth``),
            i.e. ``depth - 1`` hidden ReLU layers.

    Raises:
        ValueError: if ``depth < 1``.
    '''
    layers: list

    def __init__(self, data_dim, key, width=512, depth=4):
        if depth < 1:
            raise ValueError(f"depth must be at least 1, got {depth}")
        layer_keys = jax.random.split(key, depth)
        # Layer sizes data_dim -> width -> ... -> width -> 1. With the defaults this builds
        # Linear(data_dim, 512), two Linear(512, 512) and Linear(512, 1), the i-th Linear layer
        # being initialised with the i-th key of split(key, 4).
        sizes = [data_dim] + [width] * (depth - 1) + [1]
        layers = []
        for idx in range(depth):
            layers.append(eqx.nn.Linear(sizes[idx], sizes[idx + 1], key=layer_keys[idx]))
            if idx < depth - 1:  # no activation after the output layer
                layers.append(jax.nn.relu)
        self.layers = layers

    def __call__(self, x):
        '''Score a single (unbatched) sample; callers vmap over the batch.'''
        for layer in self.layers:
            x = layer(x)
        return x

## KLD Estimator

In [6]:
def kld_estimator(s1, s2):
    '''
    Nearest-neighbour estimate of the KL divergence D(P_s1 || P_s2) (equation 25 of the reference paper).

    For every point x_i of ``s1``:
      * rho_i: distance to its nearest *other* point of ``s1``;
      * nu_i:  distance to its nearest point of ``s2``;
      * eps_i = max(rho_i, nu_i);
      * l_i:   number of other ``s1`` points within eps_i;
      * k_i:   number of ``s2`` points within eps_i.
    The estimate is ``mean(digamma(l_i) - digamma(k_i)) + log(m / (n - 1))``.

    Args:
        s1: array-like of shape (n, d), samples of the first distribution (n >= 2).
        s2: array-like of shape (m, d), samples of the second distribution (m >= 1).

    Returns:
        The scalar estimate.

    Raises:
        ValueError: if the inputs are not 2-D with the same feature dimension, or are too small.
    '''
    # faiss only accepts C-contiguous float32 matrices.
    s1 = np.ascontiguousarray(s1, dtype=np.float32)
    s2 = np.ascontiguousarray(s2, dtype=np.float32)
    if s1.ndim != 2 or s2.ndim != 2 or s1.shape[1] != s2.shape[1]:
        raise ValueError(f"expected 2-D inputs with equal feature dimension, got {s1.shape} and {s2.shape}")
    n, m = s1.shape[0], s2.shape[0]
    if n < 2 or m < 1:
        raise ValueError(f"need at least 2 points in s1 and 1 in s2, got n={n}, m={m}")
    d = s1.shape[1]

    #res = faiss.StandardGpuResources()

    index_s1 = faiss.IndexFlatL2(d)
    #index_s1=faiss.index_cpu_to_gpu(res,0,index_s1)
    index_s1.add(s1)

    index_s2 = faiss.IndexFlatL2(d)
    #index_s2=faiss.index_cpu_to_gpu(res,0,index_s2)
    index_s2.add(s2)

    # Search every point, so each row holds all distances in ascending order. IndexFlatL2 returns
    # squared distances, which rounding can push slightly below 0; clip to avoid NaN from sqrt.
    dist_within = np.sqrt(np.maximum(index_s1.search(s1, n)[0], 0.0))   # shape (n, n)
    dist_across = np.sqrt(np.maximum(index_s2.search(s1, m)[0], 0.0))   # shape (n, m)

    rho = dist_within[:, 1]  # column 0 is normally the query point itself (distance 0)
    nu = dist_across[:, 0]
    eps = np.maximum(rho, nu)

    # Rows are sorted, so counting entries <= eps equals searchsorted(row, eps, side='right').
    # Both counts run over all n query rows (one per point of s1).
    l_counts = np.count_nonzero(dist_within <= eps[:, None], axis=1) - 1  # exclude the point itself
    k_counts = np.count_nonzero(dist_across <= eps[:, None], axis=1)

    return np.mean(digamma(l_counts) - digamma(k_counts)) + np.log(m / (n - 1))

## Train Loop

In [7]:
def train(config, seed_data, seed_initial):
    '''
    Train an OT-EVS generator against a WGAN-GP critic on a synthetic target distribution.

    The target ("real") data come from a randomly drawn OT-EVS model (circuit + sparse linear
    observable head). The trainable generator has the same architecture; during training its
    circuit outputs are perturbed with correlated classical-shadow noise.

    Args:
        config: namespace of hyper-parameters (see the configuration cell).
        seed_data: seed for the target model, the training set and the data-loader shuffling.
        seed_initial: seed for the trainable models' initialisation and all training randomness.

    Side effects:
        Creates a timestamped folder in the current working directory and writes into it the
        target-model parameters, generator checkpoints (fixed milestones plus every new best
        KLD), ``loss_history.npy`` and ``training_curves.png``.
    '''
    ############################# Data loader #####################################
    def dataloader(data, batch_size, *, key):
        '''
        Yield shuffled mini-batches of ``data`` forever, reshuffling at the start of every epoch.

        Only full batches are yielded; a trailing partial batch (if any) is skipped.
        '''
        dataset_size = data.shape[0]
        if not 0 < batch_size <= dataset_size:
            # Otherwise the loop below would spin forever without yielding anything.
            raise ValueError(f"batch_size must be between 1 and {dataset_size}, got {batch_size}")
        indices = jnp.arange(dataset_size)
        while True:
            key, subkey = jax.random.split(key, 2)
            perm = jax.random.permutation(subkey, indices)
            # The last start offset is dataset_size - batch_size, so the final full batch of
            # every epoch is included.
            for start in range(0, dataset_size - batch_size + 1, batch_size):
                yield data[perm[start:start + batch_size]]

    ############################ Add shot noise (conventional method) to ideal outputs ############################
    # Kept for the conventional-measurement variant; the classical-shadow training step below does not call it.
    @jax.jit
    def add_sampling_error(exact, n_shots, key):
        p = jnp.clip((1 - exact) / 2, 0, 1)
        mean = n_shots * p
        std = jnp.sqrt(jnp.clip(n_shots * p * (1 - p), min=1e-16))
        return 1 - 2 * jnp.clip((jax.random.normal(key) * std + mean) / n_shots, 0, 1)

    ############################ Add shot noise (classical shadow method) to ideal outputs ############################
    def index_and_factor_matrices(nq, k):
        '''
        For every pair (i, j) of k-local observables, return the index of their qubit-wise product
        in the list [identity] + (<= 2k)-local observables, and the shadow factor of the pair.

        Per ``rule_map``: equal non-identity Paulis on a qubit contribute factor 3 and identity;
        different non-identity Paulis contribute factor 0 (the index then defaults to 0).
        '''
        rule_map = {(0, 0): (0, 1), (0, 1): (1, 1), (0, 2): (2, 1), (0, 3): (3, 1),
                    (1, 0): (1, 1), (1, 1): (0, 3), (1, 2): (-1, 0), (1, 3): (-1, 0),
                    (2, 0): (2, 1), (2, 1): (-1, 0), (2, 2): (0, 3), (2, 3): (-1, 0),
                    (3, 0): (3, 1), (3, 1): (-1, 0), (3, 2): (-1, 0), (3, 3): (0, 3)}

        def rule(pauli_a, pauli_b):
            '''Optimized rule lookup using pre-defined map.'''
            return rule_map[(pauli_a, pauli_b)]

        obs_k = np.array(get_all_k_local_observables(nq, k))
        obs_2k = np.array(get_all_k_local_observables(nq, min(2*k, nq)))
        # Row 0 is the all-identity string, whose expectation value is 1.
        obs_2k = np.concatenate([np.zeros((1, nq), dtype=int), obs_2k], axis=0)

        num_k = len(obs_k)

        indices = np.zeros((num_k, num_k), dtype=int)
        factors = np.zeros((num_k, num_k), dtype=int)

        obs_2k_dict = {tuple(row): idx for idx, row in enumerate(obs_2k)}

        for i in range(num_k):
            for j in range(num_k):
                output = np.array([rule(obs_k[i][q], obs_k[j][q]) for q in range(nq)])

                obs_out = tuple(output[:, 0])  # convert the array to a tuple for hashing
                factor = np.prod(output[:, 1])

                indices[i][j] = obs_2k_dict.get(obs_out, 0)  # default to 0 if not found
                factors[i][j] = factor

        return jnp.array(indices, dtype=int), jnp.array(factors, dtype=int)


    def generate_covariance_matrix(obs_k, obs_2k, indices_matrix, factors_matrix):
        '''Single-sample shadow covariance: factor * <O_i O_j> - <O_i><O_j>.'''
        return obs_2k[indices_matrix] * factors_matrix - jnp.outer(obs_k, obs_k)


    def add_shadow_error(mean, cov, n_shots, key):
        '''Draw one Gaussian sample with the given mean and covariance ``cov / n_shots`` (shape (1, n_obs)).'''
        L = jnp.linalg.cholesky(cov / n_shots)

        return mean + jnp.dot(jax.random.normal(key, shape=(1, mean.shape[0])), L.T)

    ###################################### One iteration during training  ######################################
    @eqx.filter_jit
    def train_step(generator_quantum_params, generator_linear_params, critic_params, generator_quantum_opt_state, generator_linear_opt_state, critic_opt_state, key, real_batches):
        '''
        Run ``config.n_critic`` critic updates followed by one generator update.

        ``real_batches`` has shape (n_critic, batch_size, data_dim) and must be drawn by the caller
        on every call: batches pulled from a Python iterator inside this jitted function would be
        read only once, at trace time, and then reused as compiled-in constants on every step.
        '''

        ############################### Classical Shadows Method ##############################
        def generate_fake_batch(generator_quantum_params, generator_linear_params, z, keys):
            '''Map latent inputs ``z`` to generated samples, adding classical-shadow noise to the circuit outputs.'''
            generator_quantum = eqx.combine(generator_quantum_params, generator_quantum_static)
            # Ideal k-local values: the means of the shadow estimates.
            fake_batch_intermediate = jax.vmap(generator_quantum, in_axes=0, out_axes=0)(z)
            # Values of the (<= 2k)-local products for the covariance; the leading column of ones
            # is the identity string (index 0 in index_and_factor_matrices).
            obs_2k_batch = jax.vmap(generator_quantum.get_2k_values, in_axes=0, out_axes=0)(z)
            obs_2k_batch = jnp.concatenate([jnp.ones((z.shape[0], 1)), obs_2k_batch], axis=1)
            cov_batch = jax.vmap(generate_covariance_matrix, in_axes=(0, 0, None, None))(fake_batch_intermediate, obs_2k_batch, indices_matrix, factors_matrix)
            fake_batch_sampled = jnp.squeeze(jax.vmap(add_shadow_error, in_axes=(0, 0, None, 0))(fake_batch_intermediate, cov_batch, config.n_shots * n_obs, keys))
            generator_linear = eqx.combine(generator_linear_params, generator_linear_static)
            return jax.vmap(generator_linear, in_axes=0, out_axes=0)(fake_batch_sampled)

        def generator_loss(generator_quantum_params, generator_linear_params, critic_params, z, keys):
            '''Wasserstein generator loss: minus the mean critic score of generated samples.'''
            fake_batch = generate_fake_batch(generator_quantum_params, generator_linear_params, z, keys)
            critic = eqx.combine(critic_params, critic_static)
            fake_value = jax.vmap(critic, in_axes=0, out_axes=0)(fake_batch)
            return -fake_value.mean()

        # Evaluate the generator (circuit part): gradient w.r.t. the first argument.
        @eqx.filter_value_and_grad(has_aux=False)
        def compute_grads_generator_quantum(generator_quantum_params, generator_linear_params, critic_params, z, keys):
            return generator_loss(generator_quantum_params, generator_linear_params, critic_params, z, keys)

        # Evaluate the generator (observable part): gradient w.r.t. the first argument.
        @eqx.filter_value_and_grad(has_aux=False)
        def compute_grads_generator_linear(generator_linear_params, generator_quantum_params, critic_params, z, keys):
            return generator_loss(generator_quantum_params, generator_linear_params, critic_params, z, keys)

        # Gradient of the critic score w.r.t. its input, vmapped over the batch.
        @eqx.filter_vmap(in_axes=(0, None))
        @eqx.filter_grad(has_aux=False)
        def critic_forward(input_data, critic):
            """Helper function to calculate the gradients with respect to the input."""
            value = critic(input_data)
            return value[0]

        # Evaluate the critic
        @eqx.filter_value_and_grad(has_aux=False)
        def compute_grads_critic(critic_params, generator_quantum_params, generator_linear_params, real_batch, z, key, keys):
            '''WGAN-GP critic loss: mean fake score - mean real score + lambda_gp * gradient penalty.'''
            fake_batch = generate_fake_batch(generator_quantum_params, generator_linear_params, z, keys)
            critic = eqx.combine(critic_params, critic_static)
            fake_value = jax.vmap(critic, in_axes=0, out_axes=0)(fake_batch)
            real_value = jax.vmap(critic, in_axes=0, out_axes=0)(real_batch)

            # Gradient penalty on random interpolations between real and generated samples.
            epsilon = jax.random.uniform(key, shape=(real_batch.shape[0], 1), minval=0, maxval=1)
            data_mix = real_batch * epsilon + fake_batch * (1 - epsilon)

            grads = critic_forward(data_mix, critic)
            grad_norm = jnp.linalg.norm(grads, axis=1)
            gradient_penalty = jnp.mean((grad_norm - 1) ** 2)

            return -jnp.mean(real_value) + jnp.mean(fake_value) + config.lambda_gp * gradient_penalty


        ### The block below is what differs for the three training algorithms. Comment and Uncomment the blocks to switch algorithms

        ########################################### Asynchronous version ########################################################
        '''
        for c in range(config.n_critic):
            real_batch = real_batches[c]
            key, subkey, subsubkey, subsubsubkey = jax.random.split(key, 4)
            z = jax.random.uniform(subkey, shape=(config.batch_size, config.latent_dim), minval=-jnp.pi, maxval=jnp.pi)
            keys = jax.random.split(subsubsubkey, config.batch_size).reshape(config.batch_size, 2)

            loss_critic, grads = compute_grads_critic(critic_params, generator_quantum_params, generator_linear_params, real_batch, z, subsubkey, keys)
            updates, critic_opt_state = tx_c.update(grads, critic_opt_state)
            critic_params = eqx.apply_updates(critic_params, updates)

        key, subkey, subsubkey = jax.random.split(key, 3)
        z = jax.random.uniform(subkey, shape=(config.batch_size, config.latent_dim), minval=-jnp.pi, maxval=jnp.pi)
        keys = jax.random.split(subsubkey, config.batch_size).reshape(config.batch_size, 2)

        for _ in range(config.n_critic):
            loss_generator_linear, grads = compute_grads_generator_linear(generator_linear_params, generator_quantum_params, critic_params, z, keys)
            updates, generator_linear_opt_state = tx_gl.update(grads, generator_linear_opt_state)
            generator_linear_params = eqx.apply_updates(generator_linear_params, updates)

        loss_generator_quantum, grads = compute_grads_generator_quantum(generator_quantum_params, generator_linear_params, critic_params, z, keys)
        updates, generator_quantum_opt_state = tx_gq.update(grads, generator_quantum_opt_state)
        generator_quantum_params = eqx.apply_updates(generator_quantum_params, updates)
        '''
        ########################################### Decoupled version ##########################################################
        '''
        for c in range(config.n_critic):
            real_batch = real_batches[c]
            key, subkey, subsubkey, subsubsubkey = jax.random.split(key, 4)
            z = jax.random.uniform(subkey, shape=(config.batch_size, config.latent_dim), minval=-jnp.pi, maxval=jnp.pi)
            keys = jax.random.split(subsubsubkey, config.batch_size).reshape(config.batch_size, 2)

            loss_critic, grads = compute_grads_critic(critic_params, generator_quantum_params, generator_linear_params, real_batch, z, subsubkey, keys)
            updates, critic_opt_state = tx_c.update(grads, critic_opt_state)
            critic_params = eqx.apply_updates(critic_params, updates)

            loss_generator_linear, grads = compute_grads_generator_linear(generator_linear_params, generator_quantum_params, critic_params, z, keys)
            updates, generator_linear_opt_state = tx_gl.update(grads, generator_linear_opt_state)
            generator_linear_params = eqx.apply_updates(generator_linear_params, updates)

        key, subkey, subsubkey = jax.random.split(key, 3)
        z = jax.random.uniform(subkey, shape=(config.batch_size, config.latent_dim), minval=-jnp.pi, maxval=jnp.pi)
        keys = jax.random.split(subsubkey, config.batch_size).reshape(config.batch_size, 2)

        loss_generator_quantum, grads = compute_grads_generator_quantum(generator_quantum_params, generator_linear_params, critic_params, z, keys)
        updates, generator_quantum_opt_state = tx_gq.update(grads, generator_quantum_opt_state)
        generator_quantum_params = eqx.apply_updates(generator_quantum_params, updates)
        '''

        ########################################### Joint version ##############################################################

        for c in range(config.n_critic):
            real_batch = real_batches[c]
            key, subkey, subsubkey, subsubsubkey = jax.random.split(key, 4)
            z = jax.random.uniform(subkey, shape=(config.batch_size, config.latent_dim), minval=-jnp.pi, maxval=jnp.pi)
            keys = jax.random.split(subsubsubkey, config.batch_size).reshape(config.batch_size, 2)

            loss_critic, grads = compute_grads_critic(critic_params, generator_quantum_params, generator_linear_params, real_batch, z, subsubkey, keys)
            updates, critic_opt_state = tx_c.update(grads, critic_opt_state)
            critic_params = eqx.apply_updates(critic_params, updates)

        key, subkey, subsubkey = jax.random.split(key, 3)
        z = jax.random.uniform(subkey, shape=(config.batch_size, config.latent_dim), minval=-jnp.pi, maxval=jnp.pi)
        keys = jax.random.split(subsubkey, config.batch_size).reshape(config.batch_size, 2)
        # Both generator gradients are taken at the same (pre-update) parameters.
        loss_generator_linear, grads_linear = compute_grads_generator_linear(generator_linear_params, generator_quantum_params, critic_params, z, keys)
        loss_generator_quantum, grads_quantum = compute_grads_generator_quantum(generator_quantum_params, generator_linear_params, critic_params, z, keys)
        updates_linear, generator_linear_opt_state = tx_gl.update(grads_linear, generator_linear_opt_state)
        generator_linear_params = eqx.apply_updates(generator_linear_params, updates_linear)
        updates_quantum, generator_quantum_opt_state = tx_gq.update(grads_quantum, generator_quantum_opt_state)
        generator_quantum_params = eqx.apply_updates(generator_quantum_params, updates_quantum)

        return generator_quantum_params, generator_linear_params, critic_params, generator_quantum_opt_state, generator_linear_opt_state, critic_opt_state, loss_generator_quantum, loss_generator_linear, loss_critic, key


    # Generate fake samples to evaluate the model (ideal circuit outputs, no shot noise)
    @eqx.filter_jit
    def evaluate_fake(generator_quantum_params, generator_linear_params, key):
        z = jax.random.uniform(key, shape=(config.eval_size, config.latent_dim), minval=-jnp.pi, maxval=jnp.pi)  # now a full-size sample, not just a batch.

        generator_quantum = eqx.combine(generator_quantum_params, generator_quantum_static)
        fake_imgs_intermediate = jax.vmap(generator_quantum, in_axes=0, out_axes=0)(z)
        generator_linear = eqx.combine(generator_linear_params, generator_linear_static)
        fake_imgs = jax.vmap(generator_linear, in_axes=0, out_axes=0)(fake_imgs_intermediate)

        return fake_imgs

    # Generate real samples from the target model to evaluate the model
    @eqx.filter_jit
    def evaluate_real(key):
        z = jax.random.uniform(key, shape=(config.eval_size, config.latent_dim), minval=-jnp.pi, maxval=jnp.pi)  # now a full-size sample, not just a batch.
        real_imgs_intermediate = jax.vmap(generator_quantum_real, in_axes=0, out_axes=0)(z)
        real_imgs = jax.vmap(generator_linear_real, in_axes=0, out_axes=0)(real_imgs_intermediate)

        return real_imgs


    ################################################### make experiment folder #######################################################
    now = datetime.now()
    timestamp = now.strftime('%d_%m_%Y_%H_%M_%S')
    current_folder = os.path.abspath(os.getcwd()) +'/'

    exp_folder = current_folder + '_' + timestamp + '_' + '_' + str(seed_data) + '_' + str(seed_initial) + '/'
    os.makedirs(exp_folder, exist_ok=True)


    ################################################### prepare training set #########################################################
    # set up keys
    key = jax.random.PRNGKey(seed=seed_data)
    # `key` draws the latent inputs of the training set below; `key_real` seeds the target model and
    # the data loader. (The training key is re-created from seed_initial further down.)
    key, key_real = jax.random.split(key, 2)

    ### The blocks below are for preparing real data generator
    key_gq, key_gq2, key_gl, key_gl2, key_loader = jax.random.split(key_real, 5)

    # here quantum params are oriented
    theta = jax.random.normal(key_gq, shape=(config.nl, 2*config.nq-1)) * np.pi / 8 + jax.random.uniform(key_gq2, minval=-1, maxval=1) * np.pi
    generator_quantum_real = GeneratorQuantum(nq=config.nq, nl=config.nl, k=config.k, weights=theta)   # create model with initialised weights

    # here observable weights are chosen sparse, bias are 0.
    n_obs = len(get_all_k_local_observables(config.nq, config.k))
    generator_linear_1 = GeneratorLinear(n_obs=n_obs, data_dim=config.data_dim, key=key_gl)    # create model
    bias = jnp.zeros(config.data_dim)
    get_bias = lambda m: m.model.bias
    generator_linear_2 = eqx.tree_at(get_bias, generator_linear_1, bias)    # initialise bias

    def place_values_in_row(row_key):
        '''One weight row with the values 1, 4, 9 at three distinct random positions and zeros elsewhere.'''
        idxs = jax.random.choice(row_key, n_obs, shape=(3,), replace=False)
        return jnp.zeros(n_obs).at[idxs].set(jnp.array([1, 4, 9]))

    keys = jax.random.split(key_gl2, config.data_dim)
    weight_unnormalized = jnp.vstack([place_values_in_row(k) for k in keys])
    weight = weight_unnormalized / jnp.linalg.norm(weight_unnormalized)
    get_weight = lambda m: m.model.weight
    generator_linear_real = eqx.tree_at(get_weight, generator_linear_2, weight)    # initialise weight

    generator_quantum_params_real, generator_quantum_static_real = eqx.partition(generator_quantum_real, eqx.is_array)
    generator_linear_params_real, generator_linear_static_real = eqx.partition(generator_linear_real, eqx.is_array)

    eqx.tree_serialise_leaves(exp_folder + "generator_quantum_real.eqx", copy.deepcopy(generator_quantum_params_real))
    eqx.tree_serialise_leaves(exp_folder + "generator_linear_real.eqx", copy.deepcopy(generator_linear_params_real))

    # prepare training set and data loader
    z = jax.random.uniform(key, shape=(config.train_size, config.latent_dim), minval=-jnp.pi, maxval=jnp.pi)
    dataset_intermediate = jax.vmap(generator_quantum_real, in_axes=0, out_axes=0)(z)
    dataset = jax.vmap(generator_linear_real, in_axes=0, out_axes=0)(dataset_intermediate)
    train_iter = dataloader(jnp.array(dataset), batch_size=config.batch_size, key=key_loader)


    ################################################### initialize models #########################################################
    key = jax.random.PRNGKey(seed=seed_initial)
    key, key_gq, key_gl, key_c = jax.random.split(key, 4)

    # quantum parameters are uniformly distributed
    theta = jax.random.uniform(key_gq, shape=(config.nl, 2*config.nq-1), minval=-jnp.pi, maxval=jnp.pi)
    generator_quantum = GeneratorQuantum(nq=config.nq, nl=config.nl, k=config.k, weights=theta)

    # generator parameters are initialised by default (Kaiming uniform)
    generator_linear = GeneratorLinear(n_obs=n_obs, data_dim=config.data_dim, key=key_gl)

    # critic parameters are also initialised by default (Kaiming uniform)
    critic = Critic(data_dim=config.data_dim, key=key_c)

    generator_quantum_params, generator_quantum_static = eqx.partition(generator_quantum, eqx.is_array)
    generator_linear_params, generator_linear_static = eqx.partition(generator_linear, eqx.is_array)
    critic_params, critic_static = eqx.partition(critic, eqx.is_array)

    indices_matrix, factors_matrix = index_and_factor_matrices(config.nq, config.k)

    ################################################### initialize optimizers #########################################################
    tx_gq = optax.adam(learning_rate=config.lr_gq, b1=config.b1_gq, b2=config.b2_gq)
    tx_gl = optax.adam(learning_rate=config.lr_gl, b1=config.b1_gl, b2=config.b2_gl)
    tx_c = optax.adam(learning_rate=config.lr_c, b1=config.b1_c, b2=config.b2_c)

    generator_quantum_opt_state = tx_gq.init(generator_quantum_params)
    generator_linear_opt_state = tx_gl.init(generator_linear_params)
    critic_opt_state = tx_c.init(critic_params)

    loss_history = []  # one (generator quantum loss, critic loss, kld) tuple per evaluation
    best = np.inf  # np.infty no longer exists in NumPy 2.x
    checkpoint_iters = {0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000}

    ################################################### train loop starts here #########################################################
    for i in tqdm(range(config.n_iter + 1)):
        key, subkey, subsubkey = jax.random.split(key, 3)
        # Draw fresh real batches on the host for every step and pass them into the jitted step.
        real_batches = jnp.stack([next(train_iter) for _ in range(config.n_critic)])
        (generator_quantum_params, generator_linear_params, critic_params,
         generator_quantum_opt_state, generator_linear_opt_state, critic_opt_state,
         loss_generator_quantum, loss_generator_linear, loss_critic, key) = train_step(
            generator_quantum_params, generator_linear_params, critic_params,
            generator_quantum_opt_state, generator_linear_opt_state, critic_opt_state,
            key, real_batches)

        # Evaluation
        is_new_best = False
        if i % config.eval_freq == 0:
            fake_imgs = evaluate_fake(generator_quantum_params, generator_linear_params, subkey)
            real_imgs = evaluate_real(subsubkey)

            kld = kld_estimator(fake_imgs, real_imgs)
            loss_history.append((loss_generator_quantum.item(), loss_critic.item(), kld))
            # Only a KLD computed for the current parameters may mark them as the best so far.
            if kld < best:
                best = kld
                is_new_best = True

        if i in checkpoint_iters or is_new_best:
            eqx.tree_serialise_leaves(exp_folder + str(i) + "_generator_quantum.eqx", copy.deepcopy(generator_quantum_params))
            eqx.tree_serialise_leaves(exp_folder + str(i) + "_generator_linear.eqx", copy.deepcopy(generator_linear_params))

    loss_history = np.array(loss_history)
    np.save(exp_folder+"loss_history.npy", np.array(loss_history))

    # plot loss curves (columns: generator quantum loss, critic loss, KLD)
    steps = np.arange(len(loss_history)) * config.eval_freq
    fig, ax = plt.subplots(1, 1, figsize=(8,8))
    ax.plot(steps, loss_history[:,0], label='Generator Loss')
    ax.plot(steps, loss_history[:,1], label='Critic Loss')
    ax.plot(steps, loss_history[:,2], label='KLD')
    ax.set_xlabel('Updates of Quantum Parameters')
    ax.set_ylabel('Metrics')
    ax.legend()
    fig.tight_layout()
    plt.savefig(exp_folder+"training_curves.png")
    plt.close()

In [None]:
# Run one experiment: seed_data fixes the synthetic target model, training set and loader
# shuffling; seed_initial fixes the trainable models' initialisation and training randomness.
train(config=config, seed_data=6, seed_initial=90)

  0%|                                                 | 0/20001 [00:00<?, ?it/s]