In [None]:
using PauliPropagation
using PauliSampling
using Plots
using Statistics

In [None]:
# Number of qubits.
nq = 16

# Topology iterated below as pairs of qubit indices; also passed to the circuit constructor.
topology = bricklayertopology(nq);

# Observable H: one X term per qubit plus one ZZ term per topology pair,
# every term with coefficient 1.0.
H = PauliSum(nq)

for qind in 1:nq
    add!(H, :X, qind, 1.0)
end

for pair in topology
    add!(H, [:Z, :Z], collect(pair), 1.0)
end

# Number of circuit layers.
nl = 4

circuit = hardwareefficientcircuit(nq, nl; topology=topology)
nparams = countparameters(circuit)
# One standard-normal random value per circuit parameter.
thetas = randn(nparams);

# Truncation settings forwarded to `propagate` as keyword arguments.
max_freq = 25
max_weight = 5

# Time one propagation of H through the circuit with the truncations above.
@time psum = propagate(circuit, H, thetas; max_freq, max_weight);

function lossfunction(thetas)
    # Differentiation libraries trace the computation with their own number types,
    # so every coefficient has to share the element type of `thetas`.
    CoeffType = eltype(thetas)

    # Rebuild the Hamiltonian with coefficients of that type:
    # an X term on every qubit and a ZZ term on every pair of the global `topology`.
    hamiltonian = PauliSum(CoeffType, nq)
    for site in 1:nq
        add!(hamiltonian, :X, site, CoeffType(1.0))
    end
    for bond in topology
        add!(hamiltonian, [:Z, :Z], collect(bond), CoeffType(1.0))
    end

    # Wrap the coefficients in PauliFreqTracker so that `max_freq` truncation can be used.
    # This usually happens automatically, but the in-place propagate!() does not do it.
    tracked = wrapcoefficients(hamiltonian, PauliFreqTracker)

    # Use the in-place `!` variant, because the default propagate copies the Pauli sum.
    return overlapwithzero(propagate!(circuit, tracked, thetas; max_freq, max_weight))
end

  0.756964 seconds (817.97 k allocations: 46.493 MiB, 26.87% gc time, 92.32% compilation time)


lossfunction (generic function with 1 method)

In [None]:
using ReverseDiff: GradientTape, gradient!, compile
using ReverseDiff: gradient


In [None]:
# output buffer for the gradient, allocated with `similar(thetas)`
grad_array = similar(thetas);

# pre-record a GradientTape for `lossfunction`, using `thetas` as the example input
@time const simulation_tape = GradientTape(lossfunction, thetas)

# first evaluation compiles and is slower
@time gradient!(grad_array, simulation_tape, thetas)
# second evaluation
@time gradient!(grad_array, simulation_tape, thetas);

  1.999754 seconds (20.78 M allocations: 920.752 MiB, 27.73% gc time, 43.82% compilation time)
  0.528443 seconds (393.32 k allocations: 19.527 MiB, 27.46% compilation time)
  0.378913 seconds


In [None]:
# compile the recorded tape to make gradient evaluation even faster
@time const compiled_simulation_tape = compile(simulation_tape)

# separate output buffer for gradients computed with the compiled tape
grad_array_compiled = similar(thetas);

# first evaluation compiles and is slower
@time gradient!(grad_array_compiled, compiled_simulation_tape, thetas)
# second evaluation
@time gradient!(grad_array_compiled, compiled_simulation_tape, thetas);

  3.662580 seconds (27.75 M allocations: 1.156 GiB, 42.16% gc time, 6.47% compilation time)
  0.318169 seconds (74.83 k allocations: 3.678 MiB, 12.37% compilation time)
  0.334400 seconds


### using Maximum Mean Discrepancy

$$MMD^2 (P, Q) = \mathbb{E}_{x,x' \sim P} \left[ k(x,x') \right] + \mathbb{E}_{y,y' \sim Q} \left[ k(y,y') \right] - 2 \mathbb{E}_{x \sim P, y \sim Q} \left[ k(x,y) \right]$$

In [None]:
# using LinearAlgebra

"""
    hamming_distance_matrix(n)

Return the `2^n × 2^n` integer matrix of pairwise Hamming distances between all
`n`-bit bitstrings, where bitstrings are represented as the integers `0:(2^n - 1)`.

Entry `[i + 1, j + 1]` is the number of bit positions in which `i` and `j` differ
(the plain Hamming distance, not its square).
"""
function hamming_distance_matrix(n)
    N = 2^n
    H = zeros(Int, N, N)
    # Julia arrays are column-major, so let the row index vary fastest.
    for j in 0:N-1, i in 0:N-1
        H[i+1, j+1] = count_ones(i ⊻ j)
    end
    return H
end

"""
    compute_mmd(p, q, K)

Compute MMD² between two distributions `p` and `q` using the kernel matrix `K`,
evaluated as `p' K p + q' K q - 2 p' K q`.

# Arguments
- `p`, `q`: real vectors holding the two distributions over the same outcomes.
- `K`: real kernel matrix; `K[i, j]` is the kernel value between outcomes `i` and `j`.

Any real element type and any `AbstractVector` / `AbstractMatrix` is accepted
(views, integer arrays, tracked arrays of differentiation libraries), not only
`Vector{Float64}` / `Matrix{Float64}`.
"""
function compute_mmd(
    p::AbstractVector{<:Real}, q::AbstractVector{<:Real}, K::AbstractMatrix{<:Real}
)
    return p' * K * p + q' * K * q - 2 * p' * K * q
end

"""
    rbf_kernel_hamming(n::Int; gamma::Real=1.0)

Return the `2^n × 2^n` kernel matrix over all `n`-bit bitstrings with entries
`exp(-gamma * d)`, where `d` is the Hamming distance between the two bitstrings
(taken from `hamming_distance_matrix(n)`).

# Keywords
- `gamma`: factor multiplying the Hamming distance in the exponent. Any real
  number is accepted (e.g. `gamma=1`), not only `Float64`.
"""
function rbf_kernel_hamming(n::Int; gamma::Real=1.0)
    distances = hamming_distance_matrix(n)
    # Convert to Float64 first so the result is a floating-point matrix
    # regardless of the type of `gamma`.
    return exp.(-gamma .* Float64.(distances))
end


rbf_kernel_hamming

In [None]:
# Small system: the kernel matrix built below has 2^nq × 2^nq entries.
nq = 5
# NOTE(review): `zero_state`, `build_circuit`, `get_dist` and `approximate_prob` are not
# defined in this file — presumably provided by PauliSampling; confirm.
init_psum = zero_state(nq)
circuit = build_circuit(nq, topology_type=:staircase, circuit_type=:heisenbergtrotter)
nparams = countparameters(circuit)
# Random parameters, standard normal scaled by 0.5.
thetas = randn(nparams) * 0.5
# Reference propagation without any truncation keyword.
exact_psum = propagate(circuit, init_psum, thetas)
# Same propagation with a `max_weight` truncation.
# NOTE(review): with nq = 5, max_weight = 5 presumably truncates nothing — confirm.
max_weight = 5
trunc_psum = propagate(circuit, init_psum, thetas; max_weight=max_weight)

# Distributions obtained from the two propagated sums (p: untruncated, q: truncated).
p = get_dist(exact_psum, approximate_prob)
q = get_dist(trunc_psum, approximate_prob)

# MMD² between p and q under the Hamming-distance kernel with gamma = 0.5.
K = rbf_kernel_hamming(nq, gamma=0.5)
mmd_value = compute_mmd(p, q, K)


## MMD from samples

Given two sets of bitstring samples:

- $X = \{x_1 , \dots , x_m\} \sim P$
- $Y = \{y_1 , \dots , y_n\} \sim Q$

The empirical unbiased MMD estimator is
$$MMD^2(P,Q) = \frac{1}{m(m-1)} \sum_{i\neq j} k(x_i,x_j) + \frac{1}{n(n-1)} \sum_{i\neq j} k(y_i,y_j) - \frac{2}{mn} \sum_{i=1}^m\sum_{j=1}^n k(x_i,y_j) $$ 

In [None]:

"""
    hamming(x::Integer, y::Integer)

Return the Hamming distance between two bitstrings represented as integers,
i.e. the number of bit positions in which `x` and `y` differ.

Any integer types are accepted (e.g. `UInt8`), not only `Int`.
"""
hamming(x::Integer, y::Integer) = count_ones(x ⊻ y)

"""
    rbf_kernel(x::Int, y::Int; gamma::Real=1.0)

Return the kernel value `exp(-gamma * d)` between two bitstrings given as integers,
where `d = hamming(x, y)` is their Hamming distance.

# Keywords
- `gamma`: factor multiplying the Hamming distance in the exponent. Any real
  number is accepted (e.g. `gamma=1`), not only `Float64`.
"""
function rbf_kernel(x::Int, y::Int; gamma::Real=1.0)
    distance = hamming(x, y)
    return exp(-gamma * distance)
end

"""
    mmd_from_samples(X, Y; gamma=1.0)

Compute the unbiased empirical MMD² estimate from two sample sets of bitstrings,
each encoded as an integer in `[0, 2^n - 1]`:

    1/(m(m-1)) Σ_{i≠j} k(x_i, x_j) + 1/(n(n-1)) Σ_{i≠j} k(y_i, y_j)
        - 2/(mn) Σ_{i,j} k(x_i, y_j)

with `k = rbf_kernel(·, ·; gamma)`.

# Arguments
- `X`, `Y`: sample vectors with `Int` elements (any `AbstractVector{Int}`).

# Keywords
- `gamma`: kernel parameter; converted to `Float64` before being passed on.

# Throws
- `ArgumentError` if `X` or `Y` holds fewer than two samples, because the
  unbiased estimator divides by `m - 1` and `n - 1`.

The kernel is evaluated for every ordered pair of samples, so the cost grows
quadratically with the number of samples.
"""
function mmd_from_samples(X::AbstractVector{Int}, Y::AbstractVector{Int}; gamma::Real=1.0)
    m = length(X)
    n = length(Y)
    m >= 2 || throw(ArgumentError("X must contain at least 2 samples, got $m"))
    n >= 2 || throw(ArgumentError("Y must contain at least 2 samples, got $n"))

    # `rbf_kernel` is called with a Float64 gamma whatever real type was supplied.
    g = Float64(gamma)
    k(a, b) = rbf_kernel(a, b; gamma=g)

    # Within-set terms skip the diagonal (i == j), as required by the unbiased estimator.
    k_xx = sum(k(X[i], X[j]) for i in eachindex(X), j in eachindex(X) if i != j) / (m * (m - 1))
    k_yy = sum(k(Y[i], Y[j]) for i in eachindex(Y), j in eachindex(Y) if i != j) / (n * (n - 1))
    k_xy = sum(k(X[i], Y[j]) for i in eachindex(X), j in eachindex(Y)) / (m * n)

    return k_xx + k_yy - 2 * k_xy
end


mmd_from_samples

In [None]:
nq = 5
# NOTE(review): `zero_state`, `build_circuit` and `sample_bitstring` are not defined in
# this file — presumably provided by PauliSampling; confirm.
init_psum = zero_state(nq)
circuit = build_circuit(nq)
nparams = countparameters(circuit)
# Random parameters, standard normal scaled by 0.5.
thetas = randn(nparams) * 0.5
# Reference propagation without any truncation keyword.
exact_psum = propagate(circuit, init_psum, thetas)
# Same propagation with a `max_weight` truncation.
max_weight = 3
trunc_psum = propagate(circuit, init_psum, thetas; max_weight=max_weight)

# Sample containers: X from the untruncated sum, Y from the truncated one.
X_samples = Vector{BitVector}()
Y_samples = Vector{BitVector}()
n_sample = 10000

# Draw `n_sample` bitstrings from each propagated sum.
for i in 1:n_sample
    push!(X_samples, sample_bitstring(exact_psum))
    push!(Y_samples, sample_bitstring(trunc_psum))
end

"""
    bitvec_to_int(b::BitVector)

Convert a `BitVector` to the integer whose binary digits are the entries of `b`,
with `b[1]` as the most significant bit. An empty vector maps to `0`.

# Throws
- `OverflowError` if the value does not fit in an `Int`.
"""
function bitvec_to_int(b::BitVector)
    # Fold the bits directly instead of round-tripping through a string.
    # Checked arithmetic keeps overflow an error rather than wrapping silently.
    return foldl(b; init=0) do acc, bit
        Base.checked_add(Base.checked_mul(acc, 2), Int(bit))
    end
end
# Integer encodings of the sampled bitstrings.
X = bitvec_to_int.(X_samples)
Y = bitvec_to_int.(Y_samples)

# NOTE(review): σ is computed but not used below; the bandwidth heuristic that
# would use it is commented out.
σ = std(vcat(X, Y))
# gamma = 1 / (2 * σ^2)
gamma = 1 / 1
# Evaluates the kernel for every ordered pair of samples, so the cost grows
# quadratically with `n_sample`.
mmd2 = mmd_from_samples(X, Y; gamma)


0.013156987331881587

In [None]:
# Element type of the parameters, passed on to `zero_state`.
T = eltype(thetas)

init_psum = zero_state(nq, T)
# Wrap the coefficients in PauliFreqTracker before calling the in-place propagate!().
wrapped_psum = wrapcoefficients(init_psum, PauliFreqTracker)

# In-place variant; presumably updates `wrapped_psum` itself, which is sampled from later.
propagate!(circuit, wrapped_psum, thetas; max_weight=max_weight)

PauliSum(nqubits: 5, 376 Pauli terms:
 PauliFreqTracker(0.0067474 + 0.0im) * IYZII
 PauliFreqTracker(-0.00042681 + 0.0im) * ZIYII
 PauliFreqTracker(0.0013812 + 0.0im) * IZZII
 PauliFreqTracker(-0.010162 + 0.0im) * IZXZI
 PauliFreqTracker(-0.0011018 + 0.0im) * IXZIX
 PauliFreqTracker(-0.00096645 + 0.0im) * ZIIXI
 PauliFreqTracker(-0.0030523 + 0.0im) * IYYIZ
 PauliFreqTracker(-0.0014657 + 0.0im) * XYIXI
 PauliFreqTracker(-0.00013777 + 0.0im) * ZIZXI
 PauliFreqTracker(0.0047889 + 0.0im) * IIZXI
 PauliFreqTracker(0.0038888 + 0.0im) * XZIIY
 PauliFreqTracker(-0.0019058 + 0.0im) * IYIXY
 PauliFreqTracker(-0.0040366 + 0.0im) * IIYIY
 PauliFreqTracker(-0.006907 + 0.0im) * IYYYI
 PauliFreqTracker(0.0018432 + 0.0im) * ZZIZI
 PauliFreqTracker(0.0010347 + 0.0im) * ZXIIX
 PauliFreqTracker(0.0009279 + 0.0im) * YXIIY
 PauliFreqTracker(-0.00011072 + 0.0im) * YIZZI
 PauliFreqTracker(0.0027144 + 0.0im) * IIYZI
 PauliFreqTracker(-0.0037907 + 0.0im) * XIZYI
  ⋮)

In [None]:
# Fresh sample container; this rebinds the global `X_samples`.
X_samples = Vector{BitVector}()
n_sample = 1000

# Draw `n_sample` bitstrings from the propagated, wrapped Pauli sum.
for _ in 1:n_sample
    push!(X_samples, sample_bitstring(wrapped_psum))
end

# Convert a BitVector to an integer by joining its bits into a binary string and
# parsing it (first entry is the most significant bit).
bitvec_to_int(b::BitVector) = parse(Int, join(b .* 1), base=2)
X = bitvec_to_int.(X_samples)

1000-element Vector{Int64}:
 18
 26
  8
 22
 22
 26
 17
 17
 23
 19
  ⋮
 23
  8
  0
 19
 18
 22
 27
 18
  3

In [None]:
"""
    lossfunction(thetas)

Sample-based MMD² loss. Propagates the zero state through the global `circuit`
with parameters `thetas` (truncated at the global `max_weight`), draws 1000
bitstring samples from the result, and returns `mmd_from_samples` between those
samples and the global reference samples `Y`.

Reads the globals `nq`, `circuit`, `max_weight` and `Y`. The result varies from
call to call because it is computed from freshly drawn samples.

NOTE(review): the returned value is built from discrete integer samples; whether a
differentiation library can propagate derivatives with respect to `thetas` through
the sampling step is not evident from this code — confirm before relying on its
gradient.
"""
function lossfunction(thetas)
    T = eltype(thetas)

    init_psum = zero_state(nq, T)
    wrapped_psum = wrapcoefficients(init_psum, PauliFreqTracker)

    propagate!(circuit, wrapped_psum, thetas; max_weight=max_weight)

    n_sample = 1000
    X_samples = Vector{BitVector}()
    sizehint!(X_samples, n_sample)
    for _ in 1:n_sample
        push!(X_samples, sample_bitstring(wrapped_psum))
    end

    # Encode each sampled bitstring as an integer (first entry = most significant bit).
    to_index(b::BitVector) = parse(Int, join(b .* 1), base=2)
    X = to_index.(X_samples)

    # Fixed kernel parameter. This used to be assigned to a misspelled local (`gamm`),
    # so the call below silently picked up whatever the global `gamma` happened to be.
    gamma = 1.0
    return mmd_from_samples(X, Y; gamma)
end


lossfunction (generic function with 1 method)

In [None]:
# Time a single evaluation of the sample-based MMD² loss.
@time lossfunction(thetas)


  1.066973 seconds (7.50 M allocations: 413.353 MiB, 3.40% gc time, 34.75% compilation time)


0.0001634392591330358

In [None]:
# NOTE(review): `lossfunction` returns a value computed from discrete samples;
# confirm that ReverseDiff can produce a meaningful gradient for it.
gradient(lossfunction, thetas)