In [449]:
using Zygote
using ITensors
using Random

In [450]:
"""
    angle_encoder(x::Real)

Map a single normalised data point `x ∈ [0, 1]` to a two-component complex
local state (angle encoding):

    [exp(3πi·x/2) · cos(π·x/2),  exp(-3πi·x/2) · sin(π·x/2)]

Because the phases have modulus 1 and `cos² + sin² = 1`, the returned vector has
unit 2-norm for every valid `x`.

# Throws
- `DomainError` if `x` lies outside `[0, 1]` (including `NaN`).
"""
function angle_encoder(x::Real)
    # A chained comparison is false for NaN, so NaN is rejected as well.
    0 <= x <= 1 || throw(DomainError(x,
        "Data points must be rescaled between 0 and 1 before encoding using the angle encoder."))
    s1 = exp(1im * (3π / 2) * x) * cospi(0.5 * x)
    s2 = exp(-1im * (3π / 2) * x) * sinpi(0.5 * x)
    return [s1, s2]
end

angle_encoder (generic function with 1 method)

In [451]:
"""
    normalised_data_to_product_state(sample, site_indices)

Convert a single normalised sample (every entry expected in `[0, 1]`) into a
product-state `MPS` with local dimension 2. Site `k` holds
`angle_encoder(sample[k])` on `site_indices[k]`. Each site tensor carries only
its site index, so the resulting MPS has no link indices.

# Throws
- `DimensionMismatch` if `sample` and `site_indices` differ in length.
- Whatever `angle_encoder` throws for an out-of-range entry.
"""
function normalised_data_to_product_state(
    sample::AbstractVector, site_indices::AbstractVector{<:Index}
)
    length(sample) == length(site_indices) ||
        throw(DimensionMismatch("Mismatch between number of sites and sample length."))

    # `zip` pairs entries positionally. This stays correct even when the two
    # vectors use different index ranges (e.g. an offset or view-backed sample).
    site_tensors = [ITensor(angle_encoder(x), idx) for (x, idx) in zip(sample, site_indices)]
    return MPS(site_tensors)
end

normalised_data_to_product_state (generic function with 1 method)

In [452]:
# Seed the global RNG so that the random MPS below, and the random sample drawn
# in the next cell, are reproducible between runs (`Random` is imported above).
Random.seed!(1234);
# 10 spin-1/2 site indices (local dimension 2 each).
s = siteinds("S=1/2", 10);
# Random complex-valued MPS on those sites with link (bond) dimension 5.
mps = randomMPS(ComplexF64, s; linkdims=5);

In [453]:
# Ten uniform Float64 draws from [0, 1): already inside the angle encoder's valid domain.
sample = rand(Float64, 10)

10-element Vector{Float64}:
 0.9674321008861496
 0.819229127804589
 0.08472206516123992
 0.1086790717560383
 0.5739626139787982
 0.8195350973458074
 0.21522023414995883
 0.04497022561381525
 0.9730120902802789
 0.9761161514693381

In [454]:
# Encode `sample` onto the site indices `s`, giving a product-state MPS with one tensor per site.
ps = normalised_data_to_product_state(sample, s)

MPS
[1] ((dim=2|id=718|"S=1/2,Site,n=1"),)
[2] ((dim=2|id=732|"S=1/2,Site,n=2"),)
[3] ((dim=2|id=60|"S=1/2,Site,n=3"),)
[4] ((dim=2|id=189|"S=1/2,Site,n=4"),)
[5] ((dim=2|id=884|"S=1/2,Site,n=5"),)
[6] ((dim=2|id=218|"S=1/2,Site,n=6"),)
[7] ((dim=2|id=140|"S=1/2,Site,n=7"),)
[8] ((dim=2|id=610|"S=1/2,Site,n=8"),)
[9] ((dim=2|id=334|"S=1/2,Site,n=9"),)
[10] ((dim=2|id=539|"S=1/2,Site,n=10"),)


In [474]:
"""
    raw_overlap(ps, mps)

Contract `conj(ps[k]) * mps[k]` at every site and multiply the per-site results
together left to right, starting from site 1. Return the resulting scalar,
i.e. the unnormalised overlap of the conjugated `ps` with `mps`.
"""
function raw_overlap(ps, mps)
    first_site = conj(ps[1]) * mps[1]
    remaining_sites = (conj(ps[k]) * mps[k] for k in 2:length(mps))
    return foldl(*, remaining_sites; init = first_site)[]
end

raw_overlap (generic function with 1 method)

In [372]:
# Merge the first two MPS tensors into the two-site tensor that is optimised below.
bt = mps[1] * mps[2]

ITensor ord=3 (dim=2|id=764|"S=1/2,Site,n=1") (dim=2|id=95|"S=1/2,Site,n=2") (dim=5|id=649|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [355]:
"""
    loss(bt, mps, ps)

Return `-log(real(p))` for the two-site tensor `bt`, which takes the place of
`mps[1] * mps[2]`. The scalar amplitude is

    yhat = conj(ps[1]) * conj(ps[2]) * ∏_{i=3}^{length(mps)} (mps[i] * conj(ps[i])) * bt

and `p = abs(yhat)^2 / sqrt(inner(yhat, yhat))`.

NOTE(review): assuming `inner(yhat, yhat)` is `|yhat|^2` for this order-0 ITensor,
the denominator equals `abs(yhat)`. So `p == abs(yhat)` and the loss is
`-log|yhat|`: no normalisation by the norm of the state takes place. Confirm
this is intended (`loss2` uses `|yhat|^2` with no denominator).
"""
function loss(bt, mps, ps)
    phi_tilde = conj(ps[1]) * conj(ps[2])
    # Contract every remaining site rather than a hard-coded count, so MPS of any
    # length (≥ 2) reduce to a scalar `yhat`.
    for i in 3:length(mps)
        phi_tilde *= mps[i] * conj(ps[i])
    end
    yhat = phi_tilde * bt
    norm_val = inner(yhat, yhat)
    p = abs(yhat[])^2 / sqrt(norm_val[])
    # `norm_val` may be complex-typed, which makes `p` complex too; keep its real part.
    return -log(real(p))
end

loss (generic function with 1 method)

In [373]:
# Closure over the global `mps` and `ps`, so `l` depends only on the two-site tensor
# (the single-argument form passed to `withgradient` below).
l = x -> loss(x, mps, ps)

#65 (generic function with 1 method)

In [374]:
# Starting point for gradient descent. This binds the same ITensor as `bt`.
# The update loop rebinds `bt_old` instead of mutating it, so `bt` is left unchanged.
bt_old = bt

ITensor ord=3 (dim=2|id=764|"S=1/2,Site,n=1") (dim=2|id=95|"S=1/2,Site,n=2") (dim=5|id=649|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [377]:
# Plain gradient descent on the two-site tensor with a fixed step size of 0.8.
# Each iteration prints the loss evaluated *before* that iteration's update.
for i in 1:100
    # Explicitly rebind the global. Without this, Julia's soft-scope rules make
    # `bt_old` an (undefined) local when the file is run as a script rather than
    # in a notebook/REPL.
    global bt_old
    f, (g,) = withgradient(l, bt_old)
    bt_new = bt_old - 0.8 * g
    println(f)
    bt_old = bt_new
end

0.9747101482636243
0.9494859734123181
0.9254879825864393
0.9026018258464243
0.8807284932305783
0.8597816814119945
0.8396857032719909
0.8203738111803388
0.8017868393257689
0.7838720948104204
0.7665824446683752
0.7498755586301509
0.7337132767606614
0.7180610780171822
0.7028876309740216
0.6881644119082481
0.6738653784656469
0.6599666894640668
0.6464464632132463
0.6332845681608003
0.6204624408055389
0.6079629267202402
0.5957701412479696
0.5838693470181279
0.5722468459003637
0.5608898833992414
0.5497865638078325
0.5389257746980178
0.5282971195400475
0.5178908574223855
0.5076978489917734
0.4977095078582302
0.48791775681465593
0.4783149883092854
0.46889402868431584
0.45964810575781884
0.450570819380475
0.4416561146452089
0.4328982574677675
0.4242918122906346
0.41583162169237314
0.4075127877101261
0.3993306547053042
0.3912807936218364
0.38335898750325226
0.37556121814962723
0.3678836538083322
0.36032263780387497
0.35287467802210487
0.34553643717283394
0.3383047237626958
0.33117648371693487
0.3

In [380]:
# Recompute the scalar amplitude `yhat` for the trained tensor `bt_old`. Sites 1-2
# come from the conjugated product state; each site i ≥ 3 contributes
# `mps[i] * conj(ps[i])`. `foldl` avoids rebinding a global inside a top-level
# loop (which breaks under soft-scope rules in a script) and uses the actual
# number of sites instead of a hard-coded 10.
phi_tilde = foldl(*, (mps[i] * conj(ps[i]) for i in 3:length(mps));
                  init = conj(ps[1]) * conj(ps[2]))
yhat = phi_tilde * bt_old

ITensor ord=0
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

-------

Alternative formulation: the unnormalised loss $-\log|\hat{y}|^2$.

In [478]:
"""
    loss2(bt, mps, ps)

Unnormalised loss `-log(|yhat|^2)` for the two-site tensor `bt`, which takes the
place of `mps[1] * mps[2]`. The scalar amplitude is

    yhat = conj(ps[1]) * conj(ps[2]) * ∏_{i=3}^{length(mps)} (mps[i] * conj(ps[i])) * bt

Returns `Inf` when `yhat == 0`.
"""
function loss2(bt, mps, ps)
    phi_tilde = conj(ps[1]) * conj(ps[2])
    # Contract every remaining site rather than a hard-coded count.
    for i in 3:length(mps)
        phi_tilde *= mps[i] * conj(ps[i])
    end
    yhat = phi_tilde * bt
    # abs2 gives |yhat|^2 directly: no sqrt-then-square round trip, and only Base is needed.
    return -log(abs2(yhat[]))
end

loss2 (generic function with 1 method)

In [479]:
# Rebuild the two-site tensor from the first two MPS tensors before trying the alternative loss.
bt = mps[1] * mps[2]

ITensor ord=3 (dim=2|id=718|"S=1/2,Site,n=1") (dim=2|id=732|"S=1/2,Site,n=2") (dim=5|id=432|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [480]:
# Closure over the global `mps` and `ps`, so `l2` depends only on the two-site tensor.
l2 = x -> loss2(x, mps, ps)

#79 (generic function with 1 method)

In [481]:
# Alternative loss at the starting two-site tensor.
l2(bt)

6.789787517503335

In [482]:
# Zygote gradient of `l2` with respect to `bt`, to compare against the analytic expression below.
gradient(l2, bt)

(ITensor ord=3
Dim 1: (dim=2|id=718|"S=1/2,Site,n=1")
Dim 2: (dim=2|id=732|"S=1/2,Site,n=2")
Dim 3: (dim=5|id=432|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×2×5
[:, :, 1] =
 -0.05414210766480898 + 0.0594908679666256im  …  0.17746053947913137 + 0.2109003899501478im
   1.3590547832759137 - 0.7880791112726404im      -2.059346083334616 - 4.973661666926607im

[:, :, 2] =
 0.0019243420069964459 - 0.013965008118887097im  …  -0.04655544407735828 - 0.012877784258328147im
   -0.1182351308082307 + 0.24863890317183862im        0.7907567184779841 + 0.5144802453209532im

[:, :, 3] =
 0.01719848219978476 - 0.03201931767150516im  …  -0.10093664313173749 - 0.07295264064605618im
 -0.5091424036824719 + 0.49463253366578985im         1.44869659446542 + 1.9538358884821951im

[:, :, 4] =
 -0.008303274108484755 - 0.009471652490615887im  …  -0.03593947585002528 + 0.023898916228253887im
   0.09869390571110129 + 0.22533743850778318im        0.8101351700168455 - 0.23285835575650804im

[:, :, 5

In [483]:
# Build the environment `phi_tilde` and the scalar amplitude `yhat` for `bt`.
# Sites 1-2 come from the conjugated product state; each site i ≥ 3 contributes
# `mps[i] * conj(ps[i])`. `foldl` avoids rebinding a global inside a top-level
# loop (which breaks under soft-scope rules in a script) and uses the actual
# number of sites instead of a hard-coded 10.
phi_tilde = foldl(*, (mps[i] * conj(ps[i]) for i in 3:length(mps));
                  init = conj(ps[1]) * conj(ps[2]))
yhat = phi_tilde * bt

ITensor ord=0
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [484]:
# Analytic gradient of loss2 with respect to bt. Since yhat = phi_tilde * bt is
# linear in bt and loss2 = -log|yhat|^2, we have ∂loss2/∂conj(bt) = -conj(phi_tilde)/conj(yhat).
# Zygote reports gradients of real-valued functions of complex inputs as
# ∂/∂Re + i·∂/∂Im = 2·∂/∂conj(bt), hence the factor -2.
@show -2*conj(phi_tilde)/conj(yhat)

(-2 * conj(phi_tilde)) / conj(yhat) = ITensor ord=3
Dim 1: (dim=2|id=718|"S=1/2,Site,n=1")
Dim 2: (dim=2|id=732|"S=1/2,Site,n=2")
Dim 3: (dim=5|id=432|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×2×5
[:, :, 1] =
 -0.05414210766480897 + 0.05949086796662561im  0.17746053947913143 + 0.2109003899501478im
   1.3590547832759137 - 0.7880791112726407im   -2.0593460833346167 - 4.973661666926608im

[:, :, 2] =
 0.0019243420069964443 - 0.013965008118887099im  -0.04655544407735829 - 0.012877784258328149im
  -0.11823513080823068 + 0.2486389031718387im      0.7907567184779842 + 0.514480245320953im

[:, :, 3] =
 0.01719848219978476 - 0.03201931767150517im  -0.1009366431317375 - 0.07295264064605618im
 -0.5091424036824718 + 0.4946325336657899im    1.4486965944654204 + 1.9538358884821951im

[:, :, 4] =
 -0.008303274108484757 - 0.009471652490615887im  -0.03593947585002527 + 0.023898916228253894im
   0.09869390571110133 + 0.2253374385077832im      0.8101351700168455 - 0.23285835575650812

ITensor ord=3 (dim=2|id=718|"S=1/2,Site,n=1") (dim=2|id=732|"S=1/2,Site,n=2") (dim=5|id=432|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

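Checks out: the analytic gradient $-2\,\overline{\tilde{\phi}}/\overline{\hat{y}}$ matches the Zygote gradient above to floating-point precision.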

In [485]:
# Keep the analytic gradient for a manual gradient-descent step.
g = -2*conj(phi_tilde)/conj(yhat)

ITensor ord=3 (dim=2|id=718|"S=1/2,Site,n=1") (dim=2|id=732|"S=1/2,Site,n=2") (dim=5|id=432|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [486]:
# One manual gradient-descent step with step size 0.4.
bt_new = bt - 0.4 * g

ITensor ord=3 (dim=2|id=718|"S=1/2,Site,n=1") (dim=2|id=732|"S=1/2,Site,n=2") (dim=5|id=432|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [487]:
# Loss after the single manual step.
l2(bt_new)

2.3705090817801664

In [491]:
# Restart gradient descent from `bt`. The loop below rebinds `bt_old` rather than mutating it.
bt_old = bt

ITensor ord=3 (dim=2|id=718|"S=1/2,Site,n=1") (dim=2|id=732|"S=1/2,Site,n=2") (dim=5|id=432|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [492]:
# Gradient descent on the alternative loss with a fixed step size of 0.4.
# Each iteration prints the loss evaluated *before* that iteration's update.
for i in 1:15
    # Explicitly rebind the global. Without this, Julia's soft-scope rules make
    # `bt_old` an (undefined) local when the file is run as a script rather than
    # in a notebook/REPL.
    global bt_old
    f, (g,) = withgradient(l2, bt_old)
    bt_new = bt_old - 0.4 * g
    bt_old = bt_new
    println(f)
end

6.789787517503335
2.3705090817801664
2.1840798107752595
2.0281574555554274
1.8940092592775732
1.7762162813743438
1.6711750231185483
1.5763635253822805
1.4899457435870969
1.4105423277661053
1.337090099139875
1.268751887818873
1.204856468143833
1.1448572614089865
1.088303167459855
