In [None]:
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

from scipy.special import kl_div
import pennylane as qml 
from pennylane import qaoa
# from pennylane import numpy as np
import math
import random
import autograd.numpy as np
import numpy.linalg as la

import jax
import jax.numpy as jnp
import jax.lax as lax
import gc

import optax

from pandarallel import pandarallel

# Initialise pandarallel (provides DataFrame.parallel_apply).
# NOTE(review): no pandas / pandarallel call appears anywhere below in this notebook,
# so this initialisation looks unnecessary; confirm before removing.
pandarallel.initialize()
# Enable float64 / complex128 in JAX. The JAX docs recommend setting this at startup,
# before any arrays are created; later code relies on it (e.g. jnp.complex128 buffers).
jax.config.update('jax_enable_x64', True)

# Seed NumPy's global RNG. `np` here is autograd.numpy, whose random module is backed by
# NumPy's global RandomState (presumably also what qml.numpy.random uses — confirm), so
# this makes the np.random.* draws below reproducible.
np.random.seed(2025)

In [None]:
# number of the data qubits
n=8

# umber of the ancilla qubits
n_A=3

# number of the training cycles
T=30

# The number of circuit layers of the denosing circuit PQC
L=5

# The number of states in the state set taken during the denosing process
N_data=100

# The number of states in the state set taken during the forward diffusion process
N_All_set=500

# iterations
epoch=300

# learning rate
lr = 0.005

# Parameters used to control the rate of diffusion in the original QuDDPM
h_s = 1
h_e = 4

# # Parameters used to control the rate of diffusion in the improved QuDDPM
# h_s = 0.1
# h_e = 0.4

diff_hs = np.linspace(h_s, h_e, T)

file_path = 'original'

#### Forward diffusion circuit

In [None]:
dev = qml.device('default.qubit', wires=n)


@qml.qnode(dev, interface='jax')
def QSC(input_state, phi_t, g_t):
    """Apply one forward-diffusion (scrambling) step to a single n-qubit state.

    Args:
        input_state: state vector of length 2**n to start from.
        phi_t: per-qubit angles; qubit i gets RZ(phi_t[i][0]) RY(phi_t[i][1]) RZ(phi_t[i][2]).
        g_t: strength passed to qaoa.cost_layer for every Z_j Z_i coupling.

    Returns:
        The resulting state vector (length 2**n).
    """
    # qml.StatePrep replaces the deprecated qml.QubitStateVector with the same semantics.
    qml.StatePrep(input_state, wires=range(n))
    # Single-qubit rotation on every qubit.
    for i in range(n):
        qml.RZ(phi_t[i][0], wires=i)
        qml.RY(phi_t[i][1], wires=i)
        qml.RZ(phi_t[i][2], wires=i)
    # All-to-all Z_j Z_i couplings, each with coefficient 1/(2*sqrt(n)).
    coupling = 1/(2*n ** 0.5)
    for i in range(n):
        for j in range(i):
            cost_h = qml.Hamiltonian([coupling], [qml.Z(j) @ qml.Z(i)])
            qaoa.cost_layer(g_t, cost_h)
    return qml.state()


QSC_jit = jax.jit(QSC)

def forward_nosiy_process(input_state, diff_hs):
    """Run the forward diffusion: repeatedly scramble an ensemble of states with QSC.

    Args:
        input_state: initial ensemble, broadcastable to (N_All_set, 2**n).
        diff_hs: per-step diffusion scales; one diffusion step is performed per entry
            (previously the global T was used, which silently ignored extra entries
            and raised IndexError for shorter schedules).

    Returns:
        Complex array of shape (len(diff_hs) + 1, N_All_set, 2**n); entry 0 is the
        input ensemble and entry t + 1 is the ensemble after t + 1 steps.
    """
    n_steps = len(diff_hs)
    # Rotation angles are uniform in [-pi/8, pi/8); coupling strengths uniform in [0.4, 0.6).
    # Both are drawn up front (phi first, then g), then scaled per step by diff_hs[t].
    phi = np.random.uniform(0, 1, size=(n_steps, N_All_set, n, 3))*np.pi/4-np.pi/8
    g = np.random.uniform(0, 1, size=(n_steps, N_All_set))*0.2+0.4
    diff_state = np.zeros((n_steps+1, N_All_set, 2**n), dtype=complex)
    diff_state[0] = input_state
    for t, h in enumerate(diff_hs):
        for i in range(N_All_set):
            diff_state[t+1][i] = np.array(QSC_jit(diff_state[t][i], phi[t][i]*h, g[t][i]*h))
    return diff_state

### Backward denoising circuit

In [None]:
dev_QPC_measurement = qml.device('default.qubit', wires=n_A, shots=1)


def _pqc_layers(theta_t, num_wires):
    """Queue the L-layer PQC on wires 0..num_wires-1 (shared by both circuits below).

    Layer l reads its angles from theta_t starting at l*(n+n_A)*2: wire i uses entry
    +2i for RX and +2i+1 for RY. These are followed by CZs on pairs (0,1), (2,3), ...
    and then (1,2), (3,4), ... The layer stride is always (n+n_A)*2, so with
    num_wires=n_A the rotations use exactly the angles PQC_state applies to the ancillas.
    """
    for l in range(L):
        l_index = l*(n+n_A)*2
        for i in range(num_wires):
            qml.RX(theta_t[l_index+2*i], wires=i)
            qml.RY(theta_t[l_index+2*i+1], wires=i)
        for i in range(num_wires//2):
            qml.CZ(wires=[2*i, 2*i+1])
        for i in range((num_wires-1)//2):
            qml.CZ(wires=[2*i+1, 2*i+2])


# Obtain the measurement results
@qml.qnode(dev_QPC_measurement, interface='jax')
def PQC_measurement(theta_t):
    """Sample every ancilla once in the Z basis after the ancilla-only part of the PQC.

    Returns:
        List of n_A single-shot PauliZ samples (one per ancilla wire).

    NOTE(review): this circuit omits every gate acting on data wires (including any CZ
    between the last ancilla wire and the first data wire in PQC_state) and never sees
    the input state. Its outcome distribution is therefore not the Born-rule marginal
    of the state that `measure` projects.
    """
    _pqc_layers(theta_t, n_A)
    return [qml.sample(qml.PauliZ(wires=[wire])) for wire in range(n_A)]


PQC_measurement_jit = jax.jit(PQC_measurement)

# PQCs
dev_QPC_state = qml.device('default.qubit', wires=n+n_A)


@qml.qnode(dev_QPC_state, interface='jax')
def PQC_state(input_state, theta_t):
    """Load input_state on wires n_A..n+n_A-1 and apply the PQC to all n+n_A wires.

    The ancilla wires 0..n_A-1 are not prepared and so start in |0>.
    qml.StatePrep replaces the deprecated qml.QubitStateVector with the same semantics.

    Returns:
        The full (n+n_A)-qubit state vector.
    """
    qml.StatePrep(input_state, wires=range(n_A, n+n_A))
    _pqc_layers(theta_t, n+n_A)
    return qml.state()


PQC_state_jit = jax.jit(PQC_state)

In [None]:
def measure(input_state, theta_t):
    """Extract the data-qubit state selected by an ancilla measurement outcome.

    An outcome is sampled from PQC_measurement_jit(theta_t). The 2**n contiguous
    amplitudes of input_state whose ancilla bits match it are cut out and renormalised.

    Args:
        input_state: (n+n_A)-qubit state vector with the ancillas on wires 0..n_A-1.
        theta_t: PQC parameters of the current denoising step.

    Returns:
        Normalised data-qubit state vector of length 2**n.

    NOTE(review): the outcome is drawn from a circuit that does not depend on
    input_state (see PQC_measurement), so it need not follow the Born-rule
    probabilities of input_state. If the chosen branch has zero amplitude, the
    division below produces NaNs.
    """
    # PauliZ eigenvalue -1 corresponds to ancilla bit 1.
    ancilla_is_one = jnp.stack(PQC_measurement_jit(theta_t)) == -1
    # Wire 0 is the most significant bit: ancilla wire k carries weight 2**(n + n_A - 1 - k).
    bit_weights = 2**(n + np.arange(n_A, 0, -1) - 1)
    offset = np.sum(ancilla_is_one * bit_weights)
    branch = lax.dynamic_slice(input_state, (offset,), (2**n,))
    branch_norm = jnp.sqrt(jnp.sum(jnp.abs(branch)**2))
    return branch / branch_norm


measure_jit = jax.jit(measure)

def get_worker_state(input_state, theta_t):
    """Apply one denoising step (PQC + ancilla post-selection) to a batch of states.

    Args:
        input_state: batch of data-qubit states, presumably shape (batch, 2**n) — each
            row is loaded onto the data wires of PQC_state.
        theta_t: PQC parameters of the current denoising step.

    Returns:
        complex128 array of shape (batch, 2**n). The batch size is taken from the PQC
        output instead of the global N_data, so other ensemble sizes are no longer
        silently truncated (or rejected).
    """
    output_state = PQC_state_jit(input_state, theta_t)
    num_states = output_state.shape[0]
    worker_state = jnp.zeros((num_states, 2**n), dtype=jnp.complex128)
    # One measure_jit call per state so each state gets its own ancilla outcome.
    for i in range(num_states):
        worker_state = worker_state.at[i].set(measure_jit(output_state[i], theta_t))
    return worker_state


get_worker_state_jit = jax.jit(get_worker_state)

### Loss

In [None]:
def mean_fideliy(S_1, S_2):
    """Return 1 minus the mean pairwise fidelity between two ensembles of pure states.

    Computes 1 - mean_{m,k} |<S_1[m] | S_2[k]>|**2 over all row pairs. Despite the
    name, this is an infidelity-style distance: 0 when every pair coincides.

    Args:
        S_1, S_2: array-likes of state vectors, one state per row (same dimension).
    """
    bras = jnp.conj(jnp.array(S_1))
    kets = jnp.array(S_2)
    overlaps = jnp.einsum('mi,ni->mn', bras, kets)
    return 1. - jnp.mean(jnp.abs(overlaps)**2)

def loss_MMD(theta_t, input_state, S_y):
    """MMD-style loss between the PQC output ensemble and a target ensemble.

    With D = mean_fideliy (1 - mean fidelity), the loss is
    2*D(S_x, S_y) - D(S_y, S_y) - D(S_x, S_x), which equals
    F(S_x, S_x) + F(S_y, S_y) - 2*F(S_x, S_y) in terms of mean fidelities F.

    Args:
        theta_t: PQC parameters (the quantity differentiated by grad_loss_mmd).
        input_state: ensemble fed into the denoising step.
        S_y: target ensemble for this step.
    """
    S_x = get_worker_state(input_state, theta_t)
    cross_term = mean_fideliy(S_x, S_y)
    target_term = mean_fideliy(S_y, S_y)
    model_term = mean_fideliy(S_x, S_x)
    return 2*cross_term - target_term - model_term


loss_mmd_jit = jax.jit(loss_MMD)
grad_loss_mmd = jax.grad(loss_mmd_jit)

### Generate the initial data (extended GHZ states)

In [None]:
def init_ghz_state(num_states=None, num_qubits=None):
    """Sample an ensemble of generalized GHZ states.

    Each state is cos(t/2)|0...0> + sin(t/2)|1...1>, with t drawn uniformly from
    [0, 2*pi) by NumPy's global RNG.

    Args:
        num_states: number of states to draw; defaults to the global N_All_set.
        num_qubits: qubits per state; defaults to the global n.

    Returns:
        Real array of shape (num_states, 2**num_qubits), one normalised state per row.
    """
    if num_states is None:
        num_states = N_All_set
    if num_qubits is None:
        num_qubits = n
    dim = 2**num_qubits
    state = np.zeros((num_states, dim))
    angles = np.random.uniform(0, 2*np.pi, num_states)
    # Vectorised replacement of the per-row loop: amplitude on |0...0> and |1...1>.
    state[:, 0] = np.cos(angles/2)
    state[:, dim-1] = np.sin(angles/2)
    return state

### Forward diffusion process

In [None]:
# Sample the t = 0 data ensemble and diffuse it forward through all T steps.
initial_state = init_ghz_state()
# Same schedule as the config cell; recomputed here so this cell reflects the current h_s / h_e.
diff_hs = np.linspace(h_s, h_e, T)
# S_y has shape (T + 1, N_All_set, 2**n): S_y[0] is the data, S_y[T] the most diffused ensemble.
S_y = forward_nosiy_process(initial_state, diff_hs)

# NOTE(review): '%d' truncates fractional h_s / h_e (e.g. the improved 0.1 / 0.4 become 0 / 0).
state_diff_path = 'data/%s/state_diff/h%d_%dT%dNDate%depoch%dn%dn_A%d'%(file_path, h_s, h_e, T, N_data, epoch, n, n_A)
# np.save does not create missing parent directories.
os.makedirs(os.path.dirname(state_diff_path), exist_ok=True)
np.save(state_diff_path, np.array(S_y))

### Backward denoising process

In [None]:
# Reload the forward-diffusion ensembles: S_y[t] is the ensemble after t diffusion steps.
state_diff_file = 'data/%s/state_diff/h%d_%dT%dNDate%depoch%dn%dn_A%d.npy'%(file_path, h_s, h_e, T, N_data, epoch, n, n_A)
S_y = np.load(state_diff_file)

#### original

In [None]:
dev = qml.device("default.qubit", wires=n)


@qml.qnode(dev)
def haar_state(weights, seed=None):
    """Return the n-qubit state produced by a qml.RandomLayers circuit.

    Args:
        weights: RandomLayers rotation angles.
        seed: seed for RandomLayers' random circuit structure.

    NOTE(review): despite the name, the weights used below have shape (3, 3), i.e. a
    very shallow random circuit on n wires, which is unlikely to yield Haar-random
    states. Confirm this is the intended starting ensemble.
    """
    qml.RandomLayers(weights=weights, wires=range(n), seed=seed)
    return qml.state()


haar_state_jit = jax.jit(haar_state)

# Starting ensemble of the backward process: N_data random-circuit states.
n_layers = 3
S_x = np.zeros(shape=(N_data, 2**n), dtype=np.complex128)
for k in range(N_data):
    # Draw the angles first and the circuit seed second for each state.
    layer_weights = np.random.uniform(-np.pi, np.pi, size=(n_layers, 3))
    circuit_seed = np.random.randint(0, N_data*100)
    S_x[k] = haar_state(layer_weights, circuit_seed)

#### improved

In [None]:
# S_x = np.zeros(shape=(N_All_set, 2**n), dtype=np.complex128)
# S_x[:, 0] = 1.
# S_x = forward_nosiy_process(S_x, diff_hs)

# sample_index = np.random.choice(np.arange(0, N_All_set), size=N_data, replace=False)
# S_x=S_x[-1, sample_index]

#### train

In [None]:
# Backward (denoising) training. For t = T-1, ..., 0 a fresh PQC is trained so that the
# ensemble it produces from the current model states matches the diffused data S_y[t].
# The trained PQC is then applied to obtain the model states used at the next (smaller) t.
train_loss = np.zeros((T, epoch), dtype=float)          # row T-1-t: per-epoch loss of step t
theta = np.zeros((T, 2*(n+n_A)*L), dtype=float)         # row t: trained PQC parameters
Set = np.zeros((T, N_data, 2**n), dtype=np.complex128)  # row t: model ensemble after step t
grad_norm = []                                          # gradient norm of every optimiser step


def _result_path(kind):
    """Return the np.save target (without .npy) for one training output and create its directory."""
    path = 'data/%s/%s/h%d_%dT%dNDate%depoch%dn%dn_A%d'%(file_path, kind, h_s, h_e, T, N_data, epoch, n, n_A)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    return path


# The optimiser definition does not depend on t; only its state is re-initialised per step.
optimizer = optax.sgd(learning_rate=lr)

# Evolve a separate name so re-running this cell restarts from the initial S_x ensemble
# rather than from the previous run's output.
S_x_model = S_x

for t in range(T-1, -1, -1):
    gc.collect()

    # Uniform [0, 1) initial angles from NumPy's global RNG.
    theta_t = jnp.array(np.random.rand(2*(n+n_A)*L))
    opt_state = optimizer.init(theta_t)

    for step in range(epoch):
        # Fresh random mini-batch of N_data target states from the step-t data.
        y_sample_index = np.random.choice(np.arange(0, N_All_set), size=N_data, replace=False)
        S_y_t = S_y[t, y_sample_index]

        gradients = grad_loss_mmd(theta_t, S_x_model, S_y_t)
        updates, opt_state = optimizer.update(gradients, opt_state, theta_t)
        theta_t = optax.apply_updates(theta_t, updates)

        grad_norm.append(float(optax.global_norm(gradients)))
        # Loss is recorded after the update, using the same jitted function that is differentiated.
        train_loss[T-1-t][step] = loss_mmd_jit(theta_t, S_x_model, S_y_t)

        if step % 10 == 0:
            print(f'{step}------------------------------')
            print(f"grad_norm:{grad_norm[-1]}")
            print(f"Training loss: {train_loss[T-1-t][step]}")
            print('------------------------------')

    theta[t] = theta_t
    S_x_model = get_worker_state(S_x_model, theta_t)
    Set[t] = S_x_model
    print(f'--------------------------------------------------t:{t}')

np.save(_result_path('train'), np.array(Set))
np.save(_result_path('params'), np.array(theta))
np.save(_result_path('loss'), train_loss)
np.save(_result_path('grad'), grad_norm)


## Sampling

#### original

In [None]:
# Same circuit and ensemble as in the training section, redefined so the sampling
# section can be run on its own.
dev = qml.device("default.qubit", wires=n)


@qml.qnode(dev)
def haar_state(weights, seed=None):
    """Return the n-qubit state produced by a qml.RandomLayers circuit.

    NOTE(review): with (3, 3) weights this is a very shallow random circuit, unlikely
    to give Haar-random states despite the name. It must match the training-section
    definition so sampling starts from the same distribution used in training.
    """
    qml.RandomLayers(weights=weights, wires=range(n), seed=seed)
    return qml.state()


haar_state_jit = jax.jit(haar_state)

n_layers = 3
S_x = np.zeros(shape=(N_data, 2**n), dtype=np.complex128)
for row in range(N_data):
    rotation_angles = np.random.uniform(-np.pi, np.pi, size=(n_layers, 3))
    structure_seed = np.random.randint(0, N_data*100)
    S_x[row] = haar_state(rotation_angles, structure_seed)

#### improved

In [None]:
# S_x = np.zeros(shape=(N_All_set, 2**n), dtype=np.complex128)
# S_x[:, 0] = 1.
# S_x = forward_nosiy_process(S_x, diff_hs)

# sample_index = np.random.choice(np.arange(0, N_All_set), size=N_data, replace=False)
# S_x=S_x[-1, sample_index]

In [None]:
def kl_divergence(p, q, per_component=True):
    """Sum of scipy.special.kl_div terms between two discrete distributions.

    kl_div(x, y) = x*log(x/y) - x + y elementwise; for normalised p and q the sum
    equals KL(p || q). It is infinite wherever q_k == 0 < p_k.

    Args:
        p, q: array-likes of equal length (probability vectors).
        per_component: if True (default, the historical behaviour) the sum is divided
            by len(p), i.e. a per-component average rather than KL itself. Pass False
            to get KL(p || q).

    Returns:
        NumPy float.
    """
    p = np.asarray(p)
    q = np.asarray(q)
    total = np.sum(kl_div(p, q))
    return total/p.shape[0] if per_component else total

In [None]:
theta = np.load('data/%s/params/h%d_%dT%dNDate%depoch%dn%dn_A%d.npy'%(file_path, h_s, h_e, T, N_data, epoch, n, n_A))
S_y = np.load('data/%s/state_diff/h%d_%dT%dNDate%depoch%dn%dn_A%d.npy'%(file_path, h_s, h_e, T, N_data, epoch, n, n_A))

kl = np.zeros((T))

# Evolve a separate name so re-running this cell restarts from the sampled S_x ensemble.
S_x_generated = S_x
for t in range(T-1, -1, -1):
    y_sample_index = np.random.choice(np.arange(0, N_All_set), size=N_data, replace=False)
    # Each step's generated ensemble is compared with the t = 0 (data) ensemble.
    S_y_t = S_y[0, y_sample_index]
    S_x_generated = get_worker_state(S_x_generated, theta[t])
    # Computational-basis distributions, computed once per step instead of once per pair.
    p_data = np.abs(np.asarray(S_y_t))**2
    p_model = np.abs(np.asarray(S_x_generated))**2
    # Mean of kl_divergence over all N_data**2 (data, generated) pairs, summed in i/j order.
    kl[t] = sum(kl_divergence(p_i, q_j) for p_i in p_data for q_j in p_model)/(N_data**2)

# '%g' keeps the learning rate readable (the former '%d' turned lr=0.005 into 'lr0').
kl_path = 'data/%s/kl/h%d_%dT%dNDate%depoch%dn%dn_A%dlr%g'%(file_path, h_s, h_e, T, N_data, epoch, n, n_A, lr)
os.makedirs(os.path.dirname(kl_path), exist_ok=True)
np.save(kl_path, kl)