In [1010]:
using ITensors
using Zygote
using Random
using Pkg

Rather than defining the probability as the square of the overlap, an alternative formulation is define it as the normalised squared overlap:
$$p_i = \frac{|\langle y_i | \hat{y} \rangle|^2}{|\langle \hat{y} | \hat{y} \rangle|}$$
Here, $\langle y_i|$ is one for the i-th class and zero for all other classes.

In [1011]:
s = siteinds("S=1/2", 5)
l_index = Index(2, "f(x)")
l_tensor = randomITensor(ComplexF64, l_index)
mps = randomMPS(ComplexF64, s; linkdims=4)
mps[1] *= l_tensor;

Make product state

In [1012]:
function angle_encoder(x)
    return [exp(1im * (3π/2) * x) * cospi(0.5 * x), exp(-1im * (3π/2) * x) * sinpi(0.5 * x)]
end;

In [1013]:
Random.seed!(42)
sample = rand(5)
ps = MPS([ITensor(angle_encoder(sample[i]), s[i]) for i in eachindex(sample)]);

In [1014]:
function get_overlap(ps, mps)
    res = 1
    for i in eachindex(mps)
        res *= mps[i] * ps[i]
    end
    return res
end;

In [1015]:
function label_to_tensor(label, l_idx)
    tensor = onehot(l_idx => label + 1)
    return tensor
end;

In [1016]:
@show label_to_tensor(0, l_index)

label_to_tensor(0, l_index) = ITensor ord=1
Dim 1: (dim=2|id=482|"f(x)")
NDTensors.Dense{Float64, Vector{Float64}}
 2-element
 1.0
 0.0


ITensor ord=1 (dim=2|id=482|"f(x)")
NDTensors.Dense{Float64, Vector{Float64}}

In [1017]:
function get_probas(yhat)
    norm_val = abs((conj(yhat)*yhat)[])
    return [abs(yhat[i])^2 / norm_val for i in 1:dim(yhat)]
end;

In [1018]:
bt = mps[1] * mps[2]

ITensor ord=4 (dim=2|id=27|"S=1/2,Site,n=1") (dim=2|id=482|"f(x)") (dim=2|id=744|"S=1/2,Site,n=2") (dim=4|id=730|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [1019]:
function loss(bt, mps, ps)
    gt_label = 1
    phi_tilde = conj(ps[1]) * conj(ps[2])
    for i in 3:5
        phi_tilde *= mps[i] * conj(ps[i])
    end
    yhat = bt * phi_tilde
    y = label_to_tensor(gt_label, l_index)
    y_yhat = (y * yhat)[]
    prob = abs(y_yhat)^2 / abs((conj(yhat)*yhat)[])
    return -log(prob)
end

loss (generic function with 2 methods)

In [1020]:
l(bt)

1.2542379326734892

In [1021]:
l = x -> loss(x, mps, ps);

In [1022]:
g, = gradient(l, bt)

(ITensor ord=4
Dim 1: (dim=2|id=482|"f(x)")
Dim 2: (dim=2|id=27|"S=1/2,Site,n=1")
Dim 3: (dim=2|id=744|"S=1/2,Site,n=2")
Dim 4: (dim=4|id=730|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×2×2×4
[:, :, 1, 1] =
 -1.5712768730217714 + 1.3437649030639047im  …  -2.9437952846202786 + 1.093829543731132im
  -2.051050245855548 - 2.5498655164088118im     -1.5902491231088167 - 4.709356966196219im

[:, :, 2, 1] =
 -0.419273344886655 - 1.7172868444585057im  …  0.30085708193018873 - 2.6681778545640764im
  2.736816456446854 - 0.5814979931250537im       4.206869266186243 + 0.6030900994858214im

[:, :, 1, 2] =
 3.287236480619436 - 1.6160117620716603im  …  5.5331363280507455 - 0.5840153311379886im
 2.400007559427803 + 5.277577549912979im      0.6603334143548532 + 8.781546928751053im

[:, :, 2, 2] =
 -0.03488142267370419 + 3.1316599632188615im   …  -1.6886492063633014 + 4.44732099497472im
   -4.952794876335148 - 0.20438341055656384im      -6.955453219833564 - 2.883415112880523im

[:, :, 

In [1023]:
bt_new = bt - 0.1 * g

ITensor ord=4 (dim=2|id=27|"S=1/2,Site,n=1") (dim=2|id=482|"f(x)") (dim=2|id=744|"S=1/2,Site,n=2") (dim=4|id=730|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [1024]:
l(bt_new)

0.26922639628722067

----

# Formulation 2

The more common formulation - the MPS Born Machine - is defined as the following:
$$p(x) = \frac{|\langle x | \Psi \rangle|^2}{|\langle \Psi | \Psi \rangle|}$$

In [1025]:
s = siteinds("S=1/2", 5)
l_index = Index(2, "f(x)")
l_tensor = randomITensor(ComplexF64, l_index)
mps = randomMPS(ComplexF64, s; linkdims=4)

MPS
[1] ((dim=2|id=487|"S=1/2,Site,n=1"), (dim=4|id=753|"Link,l=1"))
[2] ((dim=4|id=753|"Link,l=1"), (dim=2|id=264|"S=1/2,Site,n=2"), (dim=4|id=236|"Link,l=2"))
[3] ((dim=4|id=236|"Link,l=2"), (dim=2|id=596|"S=1/2,Site,n=3"), (dim=4|id=727|"Link,l=3"))
[4] ((dim=4|id=727|"Link,l=3"), (dim=2|id=290|"S=1/2,Site,n=4"), (dim=2|id=329|"Link,l=4"))
[5] ((dim=2|id=329|"Link,l=4"), (dim=2|id=653|"S=1/2,Site,n=5"))


In [1026]:
function angle_encoder(x)
    return [exp(1im * (3π/2) * x) * cospi(0.5 * x), exp(-1im * (3π/2) * x) * sinpi(0.5 * x)]
end;

In [1027]:
Random.seed!(42)
sample = rand(5)
ps = MPS([ITensor(angle_encoder(sample[i]), s[i]) for i in eachindex(sample)]);

In [1028]:
function loss(mps, ps)
    amp = 1
    for i in 1:5
        amp *= conj(ps[i]) * mps[i]
    end
    abs_amp_sq = abs(amp[])^2
    # partition function calculation, put mps into canonical form
    orthogonalize!(mps, 1)
    Z = conj(mps[1]) * mps[1]
    p = abs_amp_sq / abs(Z[])
    return -log(p)
end

loss (generic function with 2 methods)

In [1029]:
l1 = x -> loss(x, ps)

#191 (generic function with 1 method)

In [1030]:
l1(mps)

3.541627088028318

In [1031]:
(g,) = gradient(l1, mps)

((data = ITensor[ITensor ord=2
Dim 1: (dim=2|id=487|"S=1/2,Site,n=1")
Dim 2: (dim=4|id=753|"Link,l=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×4
 0.43549716503541136 + 0.31029825701430847im  …  1.3987087034994419 + 1.2495033822819541im
  1.5903788595416521 - 1.7216440912691144im      1.6991921165968031 + 2.8089533062661167im, ITensor ord=3
Dim 1: (dim=2|id=264|"S=1/2,Site,n=2")
Dim 2: (dim=4|id=753|"Link,l=1")
Dim 3: (dim=4|id=236|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×4×4
[:, :, 1] =
 -0.6128527944010728 - 1.3189705939402672im   …  -0.13991352468462265 - 0.595275636366965im
  1.2428322874298248 + 0.04112652277219919im       0.5081925274303595 + 0.12285321777643854im

[:, :, 2] =
 -0.4644675605533743 - 0.9036949260267133im    …  -0.11421247797249637 - 0.41165263080835807im
  0.8687206657118655 - 0.005831489545712898im      0.35816570823815785 + 0.07163597250012907im

[:, :, 3] =
 -0.40556134444356534 - 1.6482421465599821im   …  -0.02650547717959318 - 

In [1032]:
mps_new = mps .- 0.8 .* g.data

MPS
[1] ((dim=2|id=487|"S=1/2,Site,n=1"), (dim=4|id=753|"Link,l=1"))
[2] ((dim=4|id=753|"Link,l=1"), (dim=2|id=264|"S=1/2,Site,n=2"), (dim=4|id=236|"Link,l=2"))
[3] ((dim=4|id=236|"Link,l=2"), (dim=2|id=596|"S=1/2,Site,n=3"), (dim=4|id=727|"Link,l=3"))
[4] ((dim=4|id=727|"Link,l=3"), (dim=2|id=290|"S=1/2,Site,n=4"), (dim=2|id=329|"Link,l=4"))
[5] ((dim=2|id=329|"Link,l=4"), (dim=2|id=653|"S=1/2,Site,n=5"))


In [1033]:
loss1(mps_new, ps)

0.06541785766041702

In [1034]:
(g,) = gradient(l1, mps_new)

((data = ITensor[ITensor ord=2
Dim 1: (dim=2|id=487|"S=1/2,Site,n=1")
Dim 2: (dim=4|id=337|"Link,l=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×4
  5.471239265081177e-5 - 0.0005407774521311763im   …  -6.1454709414873e-5 - 0.00032728952798663686im
 -5.695570571341915e-5 + 0.00032362011842111834im     7.38620100945612e-5 - 0.0005069174975743613im, ITensor ord=3
Dim 1: (dim=2|id=264|"S=1/2,Site,n=2")
Dim 2: (dim=4|id=337|"Link,l=1")
Dim 3: (dim=4|id=292|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×4×4
[:, :, 1] =
  0.8935213260425252 + 1.1668990500270808im   …   0.05959893317786471 - 0.01576101437899264im
 -1.2350558346412357 + 0.23169525921331863im     -0.01096251339500321 + 0.051556317867300507im

[:, :, 2] =
   -0.367087312036387 + 0.1278030293135897im   …  -0.0009764537492605135 + 0.01627495073246405im
 0.044075669906422005 - 0.32940246160825043im      -0.012041929121852805 - 0.007022772405648199im

[:, :, 3] =
 -0.19761475001255785 + 0.2233693718633392im  

In [1035]:
mps_new2 = deepcopy(mps_new) .- 0.8 .* g.data

MPS
[1] ((dim=2|id=487|"S=1/2,Site,n=1"), (dim=4|id=337|"Link,l=1"))
[2] ((dim=2|id=264|"S=1/2,Site,n=2"), (dim=4|id=292|"Link,l=2"), (dim=4|id=337|"Link,l=1"))
[3] ((dim=2|id=596|"S=1/2,Site,n=3"), (dim=4|id=839|"Link,l=3"), (dim=4|id=292|"Link,l=2"))
[4] ((dim=2|id=290|"S=1/2,Site,n=4"), (dim=2|id=266|"Link,l=4"), (dim=4|id=839|"Link,l=3"))
[5] ((dim=2|id=653|"S=1/2,Site,n=5"), (dim=2|id=266|"Link,l=4"))


In [1036]:
loss1(mps_new2, ps)

0.010334078111464523