In [328]:
using Zygote
using Distributions
using ITensors
using Random
using Plots 
using Base.Threads

In [329]:
struct PState
    pstate::MPS
    label::Int
end

In [330]:
function complex_feature_map(x::Float64)
    s1 = exp(1im * (3π/2) * x) * cospi(0.5 * x)
    s2 = exp(-1im * (2π/2) * x) * sinpi(0.5 * x)
    return [s1, s2]
end

complex_feature_map (generic function with 1 method)

In [331]:
function generate_training_data(samples_per_class::Int)

    class_A_samples = zeros(samples_per_class, 3)
    class_B_samples = ones(samples_per_class, 3)
    all_samples = vcat(class_A_samples, class_B_samples)
    all_labels = Int.(vcat(zeros(size(class_A_samples)[1]), ones(size(class_B_samples)[1])))

    return all_samples, all_labels

end

generate_training_data (generic function with 1 method)

In [332]:
function sample_to_product_state(sample::Vector, site_inds::Vector{Index{Int64}})
    n_sites = length(site_inds)
    product_state = MPS(ComplexF64, site_inds; linkdims=1)
    for j=1:n_sites
        T = ITensor(site_inds[j])
        zero_state, one_state = complex_feature_map(sample[j])
        T[1] = zero_state
        T[2] = one_state
        product_state[j] = T 
    end
    return product_state
end

sample_to_product_state (generic function with 1 method)

In [333]:
function dataset_to_product_state(dataset::Matrix, labels::Vector, sites::Vector{Index{Int64}})

    all_product_states = Vector{PState}(undef, size(dataset)[1])
    for p=1:length(all_product_states)
        sample_pstate = sample_to_product_state(dataset[p, :], sites)
        sample_label = labels[p]
        product_state = PState(sample_pstate, sample_label)
        all_product_states[p] = product_state
    end

    return all_product_states

end

dataset_to_product_state (generic function with 1 method)

In [334]:
s = siteinds("S=1/2", 3)
mps = randomMPS(ComplexF64, s; linkdims=4)
all_samples, all_labels = generate_training_data(100)
all_pstates = dataset_to_product_state(all_samples, all_labels, s);

In [335]:
mps

MPS
[1] ((dim=2|id=703|"S=1/2,Site,n=1"), (dim=4|id=931|"Link,l=1"))
[2] ((dim=4|id=931|"Link,l=1"), (dim=2|id=922|"S=1/2,Site,n=2"), (dim=2|id=767|"Link,l=2"))
[3] ((dim=2|id=767|"Link,l=2"), (dim=2|id=503|"S=1/2,Site,n=3"))


Make loss function. Takes in:
- Bond Tensor
- product state
- LE
- RE 
\
Outputs:
- Loss

In [336]:
N_train = 200
num_sites = 3

3

In [337]:
ITensor()

ITensor ord=0
NDTensors.EmptyStorage{NDTensors.EmptyNumber, NDTensors.Dense{NDTensors.EmptyNumber, Vector{NDTensors.EmptyNumber}}}

In [338]:
LE = Matrix{ITensor}(undef, N_train, num_sites);
RE = Matrix{ITensor}(undef, N_train, num_sites);

In [339]:
for i = 1:N_train
    RE[i, num_sites] = conj(all_pstates[i].pstate[num_sites]) * mps[num_sites]
end

for j = (num_sites-1):-1:1
    for i = 1:N_train
        RE[i, j] = RE[i, j+1] * mps[j] * conj(all_pstates[i].pstate[j])
    end
end

In [340]:
BT = mps[1] * mps[2]

ITensor ord=3 (dim=2|id=703|"S=1/2,Site,n=1") (dim=2|id=922|"S=1/2,Site,n=2") (dim=2|id=767|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [341]:
function loss_per_sample(BT::ITensor, ps::PState, RE::Matrix, psid::Int, lid::Int, rid::Int)
    y = ps.label
    yhat = BT * conj(ps.pstate[lid]) * conj(ps.pstate[rid])
    yhat *= RE[psid, rid+1]
    
    yhat = abs(yhat[])

    diff_sq = (yhat - y)^2

    loss = 0.5 * diff_sq

    return loss
    
end


loss_per_sample (generic function with 6 methods)

In [342]:
loss_per_sample(BT, all_pstates[101], RE, 1, 1, 2)

0.2665887185082914

Returns gradient of loss function with resepct to each argument, so it will return 6 outputs. We only want the first (w.r.t. BT)

In [343]:
f, (∇,) = withgradient(loss_per_sample, BT, all_pstates[101], RE, 101, 1, 2)

(val = 0.28868500425883253, grad = (ITensor ord=3
Dim 1: (dim=2|id=767|"Link,l=2")
Dim 2: (dim=2|id=922|"S=1/2,Site,n=2")
Dim 3: (dim=2|id=703|"S=1/2,Site,n=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×2×2
[:, :, 1] =
 0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im

[:, :, 2] =
 0.0 + 0.0im  0.30019180092467457 + 0.1113492473465553im
 0.0 + 0.0im  0.42217271984552895 + 0.544634217533979im, (pstate = (data = Union{Nothing, ITensor}[ITensor ord=1
Dim 1: (dim=2|id=703|"S=1/2,Site,n=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2-element
 -0.20129162898120856 + 0.3819683557724241im
   0.1824786662312385 + 2.176810701100838e-17im, ITensor ord=1
Dim 1: (dim=2|id=922|"S=1/2,Site,n=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2-element
 -0.40113790112571196 + 0.3815515686438984im
    0.364957332462477 + 3.8913218532997554e-17im, nothing], llim = nothing, rlim = nothing), label = 0.7598486747489036), Union{Nothing, ITensor}[nothing nothing nothing; nothing nothing n

In [344]:
nabs = Vector{ITensor}(undef, 200)
fs = Vector{Float64}(undef, 200)
for i=1:200
    f, (∇,) = withgradient(loss_per_sample, BT, all_pstates[i], RE, i, 1, 2)
    nabs[i] = ∇
    fs[i] = f
end

In [345]:
∇_total = sum(nabs)
fs_total = sum(fs)

38.08994033569402

In [346]:
gradient_tensor = ∇_total ./ 200
loss_final = fs_total / 200

0.19044970167847008

Apply update

In [347]:
BT_old = BT

ITensor ord=3 (dim=2|id=703|"S=1/2,Site,n=1") (dim=2|id=922|"S=1/2,Site,n=2") (dim=2|id=767|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [348]:
lr = 0.5
BT_new = BT_old - lr * gradient_tensor

ITensor ord=3 (dim=2|id=703|"S=1/2,Site,n=1") (dim=2|id=922|"S=1/2,Site,n=2") (dim=2|id=767|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

#### SVD apart new bond tensor

In [349]:
BT_new

ITensor ord=3 (dim=2|id=703|"S=1/2,Site,n=1") (dim=2|id=922|"S=1/2,Site,n=2") (dim=2|id=767|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [350]:
left_site_index = uniqueinds(mps[1], mps[2])

1-element Vector{Index{Int64}}:
 (dim=2|id=703|"S=1/2,Site,n=1")

In [351]:
U, S, V = svd(BT_new, left_site_index; lefttags="Link,l=1");

In [352]:
U

ITensor ord=2 (dim=2|id=703|"S=1/2,Site,n=1") (dim=2|id=986|"Link,l=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [353]:
S

ITensor ord=2 (dim=2|id=986|"Link,l=1") (dim=2|id=211|"Link,v")
NDTensors.Diag{Float64, Vector{Float64}}

In [354]:
V

ITensor ord=3 (dim=2|id=922|"S=1/2,Site,n=2") (dim=2|id=767|"Link,l=2") (dim=2|id=211|"Link,v")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [355]:
left_site_new = U
right_site_new = S * V

ITensor ord=3 (dim=2|id=986|"Link,l=1") (dim=2|id=922|"S=1/2,Site,n=2") (dim=2|id=767|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [356]:
for i = 1:200
    LE[i, 1] = left_site_new * conj(all_pstates[i].pstate[1])
end

Add sites back into the mps

In [357]:
mps[1] = left_site_new

ITensor ord=2 (dim=2|id=703|"S=1/2,Site,n=1") (dim=2|id=986|"Link,l=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [358]:
mps[2] = right_site_new

ITensor ord=3 (dim=2|id=986|"Link,l=1") (dim=2|id=922|"S=1/2,Site,n=2") (dim=2|id=767|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [359]:
normalize!(mps)

MPS
[1] ((dim=2|id=703|"S=1/2,Site,n=1"), (dim=2|id=986|"Link,l=1"))
[2] ((dim=2|id=986|"Link,l=1"), (dim=2|id=922|"S=1/2,Site,n=2"), (dim=2|id=767|"Link,l=2"))
[3] ((dim=2|id=767|"Link,l=2"), (dim=2|id=503|"S=1/2,Site,n=3"))


# Now for sites 2-3

In [360]:
BT_s2s3 = mps[2] * mps[3]

ITensor ord=3 (dim=2|id=986|"Link,l=1") (dim=2|id=922|"S=1/2,Site,n=2") (dim=2|id=503|"S=1/2,Site,n=3")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [361]:
function loss_per_sample2(BT::ITensor, ps::PState, LE::Matrix, psid::Int, lid::Int, rid::Int)
    y = ps.label
    yhat = BT * conj(ps.pstate[lid]) * conj(ps.pstate[rid])
    yhat *= LE[psid, lid-1]
    
    yhat = abs(yhat[])

    diff_sq = (yhat - y)^2

    loss = 0.5 * diff_sq

    return loss
    
end


loss_per_sample2 (generic function with 2 methods)

In [362]:
loss_per_sample2(BT_s2s3, all_pstates[1], LE, 1, 2, 3)

0.05070177352400256

In [375]:
LE[:, 2] = LE[:, 1]
LE[:, 3] = LE[:, 1]

200-element Vector{ITensor}:
 ITensor ord=1
Dim 1: (dim=2|id=986|"Link,l=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2-element
 -0.7171921444508376 + 4.6419259161868125e-18im
 -0.6968754752019972 - 4.9483564363427824e-18im
 ITensor ord=1
Dim 1: (dim=2|id=986|"Link,l=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2-element
 -0.7171921444508376 + 4.6419259161868125e-18im
 -0.6968754752019972 - 4.9483564363427824e-18im
 ITensor ord=1
Dim 1: (dim=2|id=986|"Link,l=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2-element
 -0.7171921444508376 + 4.6419259161868125e-18im
 -0.6968754752019972 - 4.9483564363427824e-18im
 ITensor ord=1
Dim 1: (dim=2|id=986|"Link,l=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2-element
 -0.7171921444508376 + 4.6419259161868125e-18im
 -0.6968754752019972 - 4.9483564363427824e-18im
 ITensor ord=1
Dim 1: (dim=2|id=986|"Link,l=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2-element
 -0.7171921444508376 + 4.6419259161868125e-18im
 -0.

In [377]:
nab = gradient(loss_per_sample2, BT_s2s3, all_pstates[1], LE, 1, 2, 3)

(ITensor ord=3
Dim 1: (dim=2|id=986|"Link,l=1")
Dim 2: (dim=2|id=503|"S=1/2,Site,n=3")
Dim 3: (dim=2|id=922|"S=1/2,Site,n=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×2×2
[:, :, 1] =
 -0.049483289637400146 + 0.2229569378368305im   0.0 + 0.0im
  -0.04808152354628181 + 0.21664099810183296im  0.0 + 0.0im

[:, :, 2] =
 0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im, (pstate = (data = Union{Nothing, ITensor}[nothing, ITensor ord=1
Dim 1: (dim=2|id=922|"S=1/2,Site,n=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2-element
 0.10140354704800512 + 1.2042038752093967e-18im
 0.07096073107074409 + 0.05167708915573277im, ITensor ord=1
Dim 1: (dim=2|id=503|"S=1/2,Site,n=3")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2-element
 0.20280709409601025 + 2.4084077504187934e-18im
 0.09035702443087401 - 0.04775258323717094im], llim = nothing, rlim = nothing), label = -0.3184392360372778), Union{Nothing, ITensor}[ITensor ord=1
Dim 1: (dim=2|id=986|"Link,l=1")
NDTensors.Dense{Complex

In [378]:
nab[1]

ITensor ord=3 (dim=2|id=986|"Link,l=1") (dim=2|id=503|"S=1/2,Site,n=3") (dim=2|id=922|"S=1/2,Site,n=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}