In [212]:
using ITensors
using Zygote
using Random

In [213]:
s = siteinds("S=1/2", 5)
l_index = Index(2, "f(x)")
l_tensor = randomITensor(ComplexF64, l_index)
mps = randomMPS(ComplexF64, s; linkdims=4)
mps[1] *= l_tensor;
normalize!(mps)

MPS
[1] ((dim=2|id=334|"S=1/2,Site,n=1"), (dim=4|id=117|"Link,l=1"), (dim=2|id=24|"f(x)"))
[2] ((dim=4|id=117|"Link,l=1"), (dim=2|id=111|"S=1/2,Site,n=2"), (dim=4|id=497|"Link,l=2"))
[3] ((dim=4|id=497|"Link,l=2"), (dim=2|id=51|"S=1/2,Site,n=3"), (dim=4|id=213|"Link,l=3"))
[4] ((dim=4|id=213|"Link,l=3"), (dim=2|id=207|"S=1/2,Site,n=4"), (dim=2|id=587|"Link,l=4"))
[5] ((dim=2|id=587|"Link,l=4"), (dim=2|id=554|"S=1/2,Site,n=5"))


In [214]:
function angle_encoder(x)
    return [exp(1im * (3π/2) * x) * cospi(0.5 * x), exp(-1im * (3π/2) * x) * sinpi(0.5 * x)]
end;

In [215]:
sample = rand(5)
ps = MPS([ITensor(angle_encoder(sample[i]), s[i]) for i in eachindex(sample)]);

In [216]:
function get_probas(mps, ps)
    amp = 1
    for i in 1:5
        amp *= conj(ps[i]) * mps[i]
    end
    abs_amp_sq = real(abs.(amp)).^2
    orthogonalize!(mps, 1)
    Z = conj(mps[1]) * mps[1]
    p = abs_amp_sq / abs(Z[])
    return p
end

get_probas (generic function with 1 method)

In [217]:
function loss1(mps, ps)
    # pretend the ground truth label is y = 1
    ground_truth_label = 1
    amp = 1
    for i in 1:5
        amp *= conj(ps[i]) * mps[i]
    end
    yhat = amp
    label_idx = first(inds(yhat))
    y = onehot(label_idx => (ground_truth_label + 1))
    f_ln = first(yhat * y)
    orthogonalize!(mps, 1)
    Z = conj(mps[1]) * mps[1]
    p = abs2.(f_ln) / abs(Z[])
    loss = -log(p)
    return loss
end

loss1 (generic function with 1 method)

In [218]:
loss1(mps, ps)

4.315999213703352

In [219]:
ll = x -> loss1(x, ps)

#69 (generic function with 1 method)

In [220]:
g, = gradient(ll, mps);

In [221]:
mps_new = mps .- 0.8 .* g.data

MPS
[1] ((dim=2|id=334|"S=1/2,Site,n=1"), (dim=4|id=117|"Link,l=1"), (dim=2|id=24|"f(x)"))
[2] ((dim=4|id=117|"Link,l=1"), (dim=2|id=111|"S=1/2,Site,n=2"), (dim=4|id=497|"Link,l=2"))
[3] ((dim=4|id=497|"Link,l=2"), (dim=2|id=51|"S=1/2,Site,n=3"), (dim=4|id=213|"Link,l=3"))
[4] ((dim=4|id=213|"Link,l=3"), (dim=2|id=207|"S=1/2,Site,n=4"), (dim=2|id=587|"Link,l=4"))
[5] ((dim=2|id=587|"Link,l=4"), (dim=2|id=554|"S=1/2,Site,n=5"))


In [222]:
ll(mps_new)

0.03174352780773894

In [223]:
g, = gradient(ll, mps_new);

In [224]:
mps_new2 = mps_new .- 0.95 .* g.data

MPS
[1] ((dim=2|id=334|"S=1/2,Site,n=1"), (dim=2|id=24|"f(x)"), (dim=4|id=515|"Link,l=1"))
[2] ((dim=2|id=111|"S=1/2,Site,n=2"), (dim=4|id=870|"Link,l=2"), (dim=4|id=515|"Link,l=1"))
[3] ((dim=2|id=51|"S=1/2,Site,n=3"), (dim=4|id=852|"Link,l=3"), (dim=4|id=870|"Link,l=2"))
[4] ((dim=2|id=207|"S=1/2,Site,n=4"), (dim=2|id=448|"Link,l=4"), (dim=4|id=852|"Link,l=3"))
[5] ((dim=2|id=554|"S=1/2,Site,n=5"), (dim=2|id=448|"Link,l=4"))


In [225]:
ll(mps_new2)

0.012319773114485781

In [226]:
get_probas(mps_new2, ps).tensor

Dim 1: (dim=2|id=24|"f(x)")
NDTensors.Dense{Float64, Vector{Float64}}
 2-element
 0.009468035091881044
 0.9877558046051496

In [227]:
norm(mps_new2)

34297.31615548319

With bond tensor?

In [228]:
bt = mps[1] * mps[2]

ITensor ord=4 (dim=2|id=334|"S=1/2,Site,n=1") (dim=2|id=24|"f(x)") (dim=2|id=111|"S=1/2,Site,n=2") (dim=4|id=497|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [234]:
function lossbt(bt, ps, mps)
    ground_truth_label = 1
    amp = conj(ps[1]) * conj(ps[2])
    for i in 3:5
        amp *= conj(ps[i]) * mps[i]
    end
    yhat = bt * amp
    label_idx = first(inds(yhat))
    y = onehot(label_idx => (ground_truth_label + 1))
    f_ln = first(yhat * y)
    orthogonalize!(mps, 1)
    Z = conj(mps[1]) * mps[1]
    p = abs2.(f_ln) / abs(Z[])
    loss = -log(p)
    return loss
end
    

lossbt (generic function with 2 methods)

In [236]:
lbt = x -> lossbt(x, ps, mps)

#73 (generic function with 1 method)

In [237]:
lbt(bt)

4.31599921370335

In [240]:
g, = gradient(lbt, bt);

In [250]:
bt_new = bt - 0.05 * g

ITensor ord=4 (dim=2|id=334|"S=1/2,Site,n=1") (dim=2|id=24|"f(x)") (dim=2|id=111|"S=1/2,Site,n=2") (dim=4|id=497|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [251]:
lbt(bt_new)

1.3161165727640542