based on https://pennylane.ai/qml/demos/tutorial_qcbm#liu

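Let $\{x'\}$ be a dataset of samples drawn independently from the same unknown probability distribution $\pi(x')$. We want to approximate $\pi$ with a QCBM, treated as an implicit generative model that returns a probability distribution $p_\theta(x)$ over the sample space, parameterized by the circuit parameters $\theta$. We train it by minimizing the squared maximum mean discrepancy (MMD) between $p_\theta$ and $\pi$, i.e. the squared distance between their mean embeddings under a feature map $\phi$: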

$$ L(\theta) = \left\| \sum_x p_\theta(x) \phi(x) - \sum_x \pi(x) \phi(x) \right\|^2 $$
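Expanding the square with the kernel $K(x, y) = \langle \phi(x), \phi(y) \rangle$, chosen here as an equal-weight mixture of $c$ Gaussian kernels with bandwidths $\sigma_i$,
$$ K(x, y) = \frac{1}{c} \sum_{i=1}^{c} \exp\left(-\frac{|x - y|^2}{2 \sigma_i^2}\right) $$
gives a loss that depends only on kernel evaluations: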
$$ L = \mathbb{E}_{x, y \sim p_\theta} \left[ K(x, y) \right] - 2 \mathbb{E}_{x \sim p_\theta, y \sim \pi} \left[ K(x, y) \right] + \mathbb{E}_{x, y \sim \pi} \left[ K(x, y) \right] $$



In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Maximum Mean Discrepancy:
# to train the QCBM, we use the squared maximum mean discrepancy (MMD) as the loss function. A kernel
# (here an equal-weight mixture of Gaussian kernels over the points of the sample space) implicitly maps
# samples to a feature space, and the MMD is the distance between the mean embeddings of the two
# distributions in that feature space. The MMD measures the difference between two distributions; for a
# characteristic kernel such as the Gaussian, it is zero if and only if the two distributions are the same.

class MMD:
    """Squared maximum mean discrepancy on a finite sample space.

    The kernel is an equal-weight mixture of Gaussian kernels, one per
    bandwidth in ``scales``:

        K(x, y) = (1 / c) * sum_i exp(-|x - y|**2 / (2 * sigma_i**2))

    It is evaluated once on every pair of points in ``space`` and cached as
    the Gram matrix ``self.K``.

    Args:
        scales: Kernel bandwidths sigma_i. A scalar, a Python sequence, or an
            array is accepted; it is converted to a 1-D JAX array.
        space: Points of the sample space (presumably the integer labels of
            the QCBM's bitstrings — confirm against caller); converted to a
            JAX array and indexed as 1-D.
    """

    def __init__(self, scales, space):
        # Normalize inputs: ``scales ** 2`` raises TypeError on a plain list
        # and ``len`` fails on a scalar, so coerce to a 1-D array first.
        scales = jnp.atleast_1d(jnp.asarray(scales))
        space = jnp.asarray(space)
        gammas = 1 / (2 * (scales**2))
        # Pairwise squared distances |x - y|^2 between all points of ``space``.
        sq_dists = jnp.abs(space[:, None] - space[None, :]) ** 2
        # Accumulate one Gaussian Gram matrix at a time (rather than stacking
        # all c of them) to keep peak memory at O(n^2).
        self.K = sum(jnp.exp(-gamma * sq_dists) for gamma in gammas) / gammas.shape[0]
        self.scales = scales

    def k_expval(self, px, py):
        """Return the kernel expectation ``px @ K @ py``.

        For probability vectors ``px`` and ``py`` over ``space`` this is
        E_{x ~ px, y ~ py}[K(x, y)].
        """
        return px @ self.K @ py

    def __call__(self, px, py):
        """Return the squared MMD between distributions ``px`` and ``py``.

        Computed as (px - py)^T K (px - py); since K is symmetric this equals
        k(px, px) - 2 k(px, py) + k(py, py).
        """
        pxy = px - py
        return self.k_expval(pxy, pxy)


In [4]:
from functools import partial


class QCBM:
    """Quantum circuit Born machine trained against a fixed target distribution.

    Args:
        circ: Callable mapping circuit parameters to the model distribution
            p_theta(x) (presumably a PennyLane QNode returning probabilities —
            confirm against caller).
        mmd: Loss callable invoked as ``mmd(px, py)``, e.g. an ``MMD`` instance.
        py: Target distribution π(x), over the same sample space as ``circ``.
    """

    def __init__(self, circ, mmd, py):
        self.circ, self.mmd, self.py = circ, mmd, py

    # jax.jit cannot trace ``self``, so argument 0 is marked static: the
    # instance is hashed by identity (no __hash__ is defined) and its
    # attributes are read as Python constants while tracing.
    # NOTE(review): reassigning self.py / self.circ after the first call will
    # NOT retrace — the compiled function keeps the old values.
    @partial(jax.jit, static_argnums=0)
    def MMD_Loss(self, params):
        """Compile and evaluate the training loss for ``params``.

        Returns:
            ``(loss, px)`` where ``px = self.circ(params)`` is the model
            distribution and ``loss = self.mmd(px, self.py)``. Returning ``px``
            as an auxiliary output presumably serves
            ``jax.value_and_grad(..., has_aux=True)`` in the training loop.
        """
        model_probs = self.circ(params)
        loss = self.mmd(model_probs, self.py)
        return loss, model_probs
