In [441]:
using Zygote
using ITensors
using Random
using Plots 
using Base.Threads

In [442]:
struct PState
    pstate::MPS
    label::Int
end

In [443]:
function complex_feature_map(x::Float64)
    s1 = exp(1im * (3π/2) * x) * cospi(0.5 * x)
    s2 = exp(-1im * (2π/2) * x) * sinpi(0.5 * x)
    return [s1, s2]
end

complex_feature_map (generic function with 1 method)

In [444]:
function generate_training_data(samples_per_class::Int)

    class_A_samples = zeros(samples_per_class, 5)
    class_B_samples = ones(samples_per_class, 5)
    all_samples = vcat(class_A_samples, class_B_samples)
    all_labels = Int.(vcat(zeros(size(class_A_samples)[1]), ones(size(class_B_samples)[1])))

    return all_samples, all_labels

end

generate_training_data (generic function with 1 method)

In [445]:
function sample_to_product_state(sample::Vector, site_inds::Vector{Index{Int64}})
    n_sites = length(site_inds)
    product_state = MPS(ComplexF64, site_inds; linkdims=1)
    for j=1:n_sites
        T = ITensor(site_inds[j])
        zero_state, one_state = complex_feature_map(sample[j])
        T[1] = zero_state
        T[2] = one_state
        product_state[j] = T 
    end
    return product_state
end

sample_to_product_state (generic function with 1 method)

In [446]:
function dataset_to_product_state(dataset::Matrix, labels::Vector, sites::Vector{Index{Int64}})

    all_product_states = Vector{PState}(undef, size(dataset)[1])
    for p=1:length(all_product_states)
        sample_pstate = sample_to_product_state(dataset[p, :], sites)
        sample_label = labels[p]
        product_state = PState(sample_pstate, sample_label)
        all_product_states[p] = product_state
    end

    return all_product_states

end

dataset_to_product_state (generic function with 1 method)

In [447]:
s = siteinds("S=1/2", 5)
mps = randomMPS(ComplexF64, s; linkdims=4)
all_samples, all_labels = generate_training_data(200)
all_product_states = dataset_to_product_state(all_samples, all_labels, s)
shuffle!(all_product_states);

create and attach label index

In [448]:
label_idx = Index(2, "f(x)")
old_mps_site1 = deepcopy(mps[1])
old_mps_site1_inds = inds(old_mps_site1)
new_mps_site1_inds = old_mps_site1_inds, label_idx
new_mps_site1_tensor = randomITensor(ComplexF64, new_mps_site1_inds);

In [449]:
mps[1] = new_mps_site1_tensor
normalize!(mps)

MPS
[1] ((dim=2|id=754|"S=1/2,Site,n=1"), (dim=4|id=455|"Link,l=1"), (dim=2|id=567|"f(x)"))
[2] ((dim=4|id=455|"Link,l=1"), (dim=2|id=485|"S=1/2,Site,n=2"), (dim=4|id=682|"Link,l=2"))
[3] ((dim=4|id=682|"Link,l=2"), (dim=2|id=432|"S=1/2,Site,n=3"), (dim=4|id=644|"Link,l=3"))
[4] ((dim=4|id=644|"Link,l=3"), (dim=2|id=970|"S=1/2,Site,n=4"), (dim=2|id=495|"Link,l=4"))
[5] ((dim=2|id=495|"Link,l=4"), (dim=2|id=678|"S=1/2,Site,n=5"))


# Start with site 1 and 2

In [469]:
# form the bond tensor
BT_s1_s2 = mps[1] * mps[2]

ITensor ord=4 (dim=2|id=754|"S=1/2,Site,n=1") (dim=2|id=567|"f(x)") (dim=2|id=485|"S=1/2,Site,n=2") (dim=4|id=682|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [470]:
LE = Matrix{ITensor}(undef, 400, 5);
RE = Matrix{ITensor}(undef, 400, 5);

In [471]:
N_train = 400

400

In [472]:
for i = 1:N_train
    RE[i, 5] = conj(all_product_states[i].pstate[5]) * mps[5]
end

for j = (5-1):-1:1
    for i = 1:N_train
        RE[i, j] = RE[i, j+1] * mps[j] * conj(all_product_states[i].pstate[j])
    end
end

In [473]:
function loss_per_sample(BT, mps, RE, ps, psid)
    """Loss for BT1 w.r.t. the mps for a single sample"""
    # first construct the prediction yhat
    y = ps.label

    yhat = BT
    BT *= RE[psid, 3]

    yhat *= conj(ps.pstate[psid]) * conj(ps.pstate[psid])
    label_idx = inds(yhat)[1]
    
    yhat1 = yhat[1]
    yhat2 = yhat[2]

    yhat1 = abs(yhat1[])
    yhat2 = abs(yhat2[])


    diff_sq1 = (yhat1 - y)^2
    diff_sq2 = (yhat2 - y)^2

    sum_of_sq_diff = diff_sq1 + diff_sq2

    loss = 0.5 * sum_of_sq_diff

    return loss
end

loss_per_sample (generic function with 4 methods)

In [489]:
loss_per_sample(BT_s1_s2, mps, RE, all_product_states[1], 1)

0.7692825510879107

In [490]:
f, (∇,) = withgradient(loss_per_sample, BT_s1_s2, mps, RE, all_product_states[1], 1);