In [670]:
using Zygote
using ITensors

In [671]:
s = siteinds("S=1/2", 10);
mps = randomMPS(ComplexF64, s; linkdims=4);

In [672]:
function feature_map(x::Float64)
    s1 = exp(1im * (3π/2) * x) * cospi(0.5 * x)
    s2 = exp(-1im * (3π/2) * x) * sinpi(0.5 * x)
    return [s1, s2]
end

feature_map (generic function with 1 method)

In [673]:
ps = MPS([ITensor(feature_map(rand()), si) for si in s])

MPS
[1] ((dim=2|id=942|"S=1/2,Site,n=1"),)
[2] ((dim=2|id=490|"S=1/2,Site,n=2"),)
[3] ((dim=2|id=338|"S=1/2,Site,n=3"),)
[4] ((dim=2|id=932|"S=1/2,Site,n=4"),)
[5] ((dim=2|id=442|"S=1/2,Site,n=5"),)
[6] ((dim=2|id=681|"S=1/2,Site,n=6"),)
[7] ((dim=2|id=262|"S=1/2,Site,n=7"),)
[8] ((dim=2|id=868|"S=1/2,Site,n=8"),)
[9] ((dim=2|id=195|"S=1/2,Site,n=9"),)
[10] ((dim=2|id=586|"S=1/2,Site,n=10"),)


make bond tensor

In [674]:
BT = mps[1] * mps[2]

ITensor ord=3 (dim=2|id=942|"S=1/2,Site,n=1") (dim=2|id=490|"S=1/2,Site,n=2") (dim=4|id=87|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [675]:
function nll(BT, mps, ps)
    yhat = BT * ps[1] * ps[2]
    for i = 3:length(mps)
        yhat *= mps[i] * ps[i]
    end
    yhat = abs(yhat[])^2
    return -log(yhat)
end

nll (generic function with 1 method)

In [676]:
nll(BT, mps, ps)

7.528188588716975

In [677]:
loss = x -> nll(x, mps, ps)

#111 (generic function with 1 method)

In [678]:
f, (g,) = withgradient(loss, BT)

(val = 7.528188588716975, grad = (ITensor ord=3
Dim 1: (dim=4|id=87|"Link,l=2")
Dim 2: (dim=2|id=490|"S=1/2,Site,n=2")
Dim 3: (dim=2|id=942|"S=1/2,Site,n=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 4×2×2
[:, :, 1] =
 0.9567526545414118 + 0.1112409821442079im  …  -0.33630065193966974 + 0.23497961924904395im
   5.40396809656421 - 2.746919788791954im       -0.9454992260385585 + 2.4027020287540477im
 4.5766573682735965 - 1.5536554793645652im      -1.0191610795819446 + 1.7886431020212867im
 -2.039171249344322 + 1.9128432150440846im       0.1090955329529418 - 1.1858759867348216im

[:, :, 2] =
 -0.5683576606167874 - 0.15836135866804887im  …   0.2258614841165202 - 0.11018573227695233im
 -3.5314194977196465 + 1.147936975314718im         0.800783758307358 - 1.3639298806560691im
 -2.9172475354115375 + 0.5046066957467812im       0.7869215082831295 - 0.9853447417741049im
  1.4159625382358079 - 0.9634321099435424im      -0.1788673044670502 + 0.707207402237001im,))

In [679]:
#BT_old = BT

In [680]:
# for i in 1:100
#     f, (g,) = withgradient(loss, BT_old)
#     BT_new = BT_old - 0.4 * g
#     println(f)
#     BT_old = BT_new
# end

In [681]:
g.tensor

Dim 1: (dim=4|id=87|"Link,l=2")
Dim 2: (dim=2|id=490|"S=1/2,Site,n=2")
Dim 3: (dim=2|id=942|"S=1/2,Site,n=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 4×2×2
[:, :, 1] =
 0.9567526545414118 + 0.1112409821442079im  …  -0.33630065193966974 + 0.23497961924904395im
   5.40396809656421 - 2.746919788791954im       -0.9454992260385585 + 2.4027020287540477im
 4.5766573682735965 - 1.5536554793645652im      -1.0191610795819446 + 1.7886431020212867im
 -2.039171249344322 + 1.9128432150440846im       0.1090955329529418 - 1.1858759867348216im

[:, :, 2] =
 -0.5683576606167874 - 0.15836135866804887im  …   0.2258614841165202 - 0.11018573227695233im
 -3.5314194977196465 + 1.147936975314718im         0.800783758307358 - 1.3639298806560691im
 -2.9172475354115375 + 0.5046066957467812im       0.7869215082831295 - 0.9853447417741049im
  1.4159625382358079 - 0.9634321099435424im      -0.1788673044670502 + 0.707207402237001im

In [682]:
BT_new = BT - 0.4 * g

ITensor ord=3 (dim=2|id=942|"S=1/2,Site,n=1") (dim=2|id=490|"S=1/2,Site,n=2") (dim=4|id=87|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [683]:
loss(BT_new)

1.2250030445176554

In [684]:
BT_old = BT

ITensor ord=3 (dim=2|id=942|"S=1/2,Site,n=1") (dim=2|id=490|"S=1/2,Site,n=2") (dim=4|id=87|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

Analytical gradient

In [685]:
phi_tilde = ps[1] * ps[2]
for i in 3:length(mps)
    phi_tilde *= mps[i] * ps[i]
end

In [686]:
f_out = (BT * phi_tilde)[]

-0.007234770243740869 + 0.022031102810516658im

In [687]:
full_analytical_gradient = conj(phi_tilde)/conj(f_out)

ITensor ord=3 (dim=2|id=942|"S=1/2,Site,n=1") (dim=2|id=490|"S=1/2,Site,n=2") (dim=4|id=87|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [688]:
permute(g, (inds(g)[3], inds(g)[2], inds(g)[1])).tensor

Dim 1: (dim=2|id=942|"S=1/2,Site,n=1")
Dim 2: (dim=2|id=490|"S=1/2,Site,n=2")
Dim 3: (dim=4|id=87|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×2×4
[:, :, 1] =
  0.9567526545414118 + 0.1112409821442079im   …  -0.33630065193966974 + 0.23497961924904395im
 -0.5683576606167874 - 0.15836135866804887im       0.2258614841165202 - 0.11018573227695233im

[:, :, 2] =
    5.40396809656421 - 2.746919788791954im  …  -0.9454992260385585 + 2.4027020287540477im
 -3.5314194977196465 + 1.147936975314718im       0.800783758307358 - 1.3639298806560691im

[:, :, 3] =
  4.5766573682735965 - 1.5536554793645652im  …  -1.0191610795819446 + 1.7886431020212867im
 -2.9172475354115375 + 0.5046066957467812im      0.7869215082831295 - 0.9853447417741049im

[:, :, 4] =
 -2.039171249344322 + 1.9128432150440846im  …   0.1090955329529418 - 1.1858759867348216im
 1.4159625382358079 - 0.9634321099435424im     -0.1788673044670502 + 0.707207402237001im

In [689]:
(-full_analytical_gradient * 2).tensor

Dim 1: (dim=2|id=942|"S=1/2,Site,n=1")
Dim 2: (dim=2|id=490|"S=1/2,Site,n=2")
Dim 3: (dim=4|id=87|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×2×4
[:, :, 1] =
  0.9567526545414133 + 0.1112409821442075im   …  -0.33630065193967007 + 0.2349796192490445im
 -0.5683576606167882 - 0.15836135866804765im      0.22586148411652027 - 0.11018573227695293im

[:, :, 2] =
   5.403968096564213 - 2.746919788791959im   …  -0.9454992260385587 + 2.402702028754051im
 -3.5314194977196505 + 1.1479369753147206im      0.8007837583073587 - 1.3639298806560716im

[:, :, 3] =
   4.576657368273598 - 1.5536554793645678im  …  -1.0191610795819446 + 1.7886431020212892im
 -2.9172475354115397 + 0.5046066957467831im      0.7869215082831297 - 0.985344741774106im

[:, :, 4] =
 -2.039171249344323 + 1.9128432150440875im  …   0.10909553295294167 - 1.1858759867348228im
 1.4159625382358079 - 0.9634321099435432im     -0.17886730446705032 + 0.7072074022370017im

In [690]:
BT_new_ag = BT - 0.4 * full_analytical_gradient

ITensor ord=3 (dim=2|id=942|"S=1/2,Site,n=1") (dim=2|id=490|"S=1/2,Site,n=2") (dim=4|id=87|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [691]:
loss(BT_new_ag)

2.8860361193259614

In [692]:
BT_old = mps[1] * mps[2]

ITensor ord=3 (dim=2|id=942|"S=1/2,Site,n=1") (dim=2|id=490|"S=1/2,Site,n=2") (dim=4|id=87|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [693]:
loss(BT_old)

7.528188588716975