In [1]:
using ITensors
using Zygote
using Random
using Optim

Let's create a random complex valued 5 site MPS to begin...

In [2]:
Random.seed!(42)
s = siteinds("S=1/2", 5)
mps = randomMPS(ComplexF64, s; linkdims=4)

MPS
[1] ((dim=2|id=325|"S=1/2,Site,n=1"), (dim=4|id=106|"Link,l=1"))
[2] ((dim=4|id=106|"Link,l=1"), (dim=2|id=135|"S=1/2,Site,n=2"), (dim=4|id=230|"Link,l=2"))
[3] ((dim=4|id=230|"Link,l=2"), (dim=2|id=984|"S=1/2,Site,n=3"), (dim=4|id=467|"Link,l=3"))
[4] ((dim=4|id=467|"Link,l=3"), (dim=2|id=815|"S=1/2,Site,n=4"), (dim=2|id=242|"Link,l=4"))
[5] ((dim=2|id=242|"Link,l=4"), (dim=2|id=691|"S=1/2,Site,n=5"))


In [3]:
function complex_feature_map(x::Float64)
    s1 = exp(1im * (3π/2) * x) * cospi(0.5 * x)
    s2 = exp(-1im * (2π/2) * x) * sinpi(0.5 * x)
    return [s1, s2]
end

complex_feature_map (generic function with 1 method)

In [4]:
function sample_to_product_state(sample::Vector, site_inds::Vector{Index{Int64}})
    n_sites = length(site_inds)
    product_state = MPS(ComplexF64, site_inds; linkdims=1)
    for j=1:n_sites
        T = ITensor(site_inds[j])
        zero_state, one_state = complex_feature_map(sample[j])
        T[1] = zero_state
        T[2] = one_state
        product_state[j] = T 
    end
    return product_state
end

sample_to_product_state (generic function with 1 method)

Generate sample and encode as a product state.

In [5]:
sample = zeros(5)

5-element Vector{Float64}:
 0.0
 0.0
 0.0
 0.0
 0.0

In [6]:
ps = sample_to_product_state(sample, s)

MPS
[1] ((dim=2|id=325|"S=1/2,Site,n=1"),)
[2] ((dim=2|id=135|"S=1/2,Site,n=2"),)
[3] ((dim=2|id=984|"S=1/2,Site,n=3"),)
[4] ((dim=2|id=815|"S=1/2,Site,n=4"),)
[5] ((dim=2|id=691|"S=1/2,Site,n=5"),)


## Create Bond Tensor

We will create a bond tensor between sites 1 and 2

In [7]:
BT12 = mps[1] * mps[2]

ITensor ord=3 (dim=2|id=325|"S=1/2,Site,n=1") (dim=2|id=135|"S=1/2,Site,n=2") (dim=4|id=230|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

Now let's define our loss function. Since we are dealing with a single product state and bond tensor, we can make it specific:

In [8]:
function loss(B::ITensor)
    y = 1.0 # train the bond tensor to maximise overlap
    # bond tensor on sites 1 and 2
    lid = 1
    rid = 2 
    phi_tilde = ps[lid] * ps[rid] * mps[3] * ps[3] * mps[4] * ps[4] * mps[5] * ps[5] # effective input
    yhat = B * phi_tilde
    diff_mod_sq = norm(yhat[] - y)^2
    loss = 0.5 * diff_mod_sq
    return loss
end

loss (generic function with 1 method)

Test the loss function...

In [9]:
loss(BT12)

0.4120511832985574

Now construct the proposed analytic gradient for comparison against AD derived gradient:

In [10]:
function analytic_gradient(B::ITensor)
    lid = 1
    rid = 2
    phi_tilde = ps[lid] * ps[rid] * mps[3] * ps[3] * mps[4] * ps[4] * mps[5] * ps[5]
    yhat = B * phi_tilde
    y = 1.0
    dP = yhat[] - y
    grad = 0.5 * dP * conj(phi_tilde)

    return grad
end

analytic_gradient (generic function with 1 method)

# Analytic Gradient Evaluation

In [11]:
analytic_grad = analytic_gradient(BT12)

ITensor ord=3 (dim=2|id=325|"S=1/2,Site,n=1") (dim=2|id=135|"S=1/2,Site,n=2") (dim=4|id=230|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

Inspect the values...

In [12]:
collect(Iterators.flatten(analytic_grad))

16-element Vector{ComplexF64}:
 -0.027823667541818135 - 0.09315554666007617im
                  -0.0 + 0.0im
                  -0.0 + 0.0im
                  -0.0 + 0.0im
  -0.08952810956088028 + 0.03174243040135726im
                  -0.0 + 0.0im
                  -0.0 + 0.0im
                  -0.0 + 0.0im
   0.04236007657464593 - 0.028288650564501885im
                  -0.0 + 0.0im
                  -0.0 + 0.0im
                  -0.0 + 0.0im
  0.039321688985887636 - 0.022450873496918794im
                  -0.0 + 0.0im
                  -0.0 + 0.0im
                  -0.0 + 0.0im

# Zygote Gradient Evaluation

In [13]:
zygote_grad = gradient(loss, BT12)

(ITensor ord=3
Dim 1: (dim=2|id=325|"S=1/2,Site,n=1")
Dim 2: (dim=2|id=135|"S=1/2,Site,n=2")
Dim 3: (dim=4|id=230|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×2×4
[:, :, 1] =
 -0.05564733508363627 - 0.18631109332015233im  -0.0 + 0.0im
                 -0.0 + 0.0im                  -0.0 + 0.0im

[:, :, 2] =
 -0.17905621912176056 + 0.06348486080271452im  -0.0 + 0.0im
                 -0.0 + 0.0im                  -0.0 + 0.0im

[:, :, 3] =
 0.08472015314929186 - 0.05657730112900377im  -0.0 + 0.0im
                -0.0 + 0.0im                  -0.0 + 0.0im

[:, :, 4] =
 0.07864337797177527 - 0.04490174699383759im  -0.0 + 0.0im
                -0.0 + 0.0im                  -0.0 + 0.0im,)

In [14]:
@show zygote_grad

zygote_grad = (ITensor ord=3
Dim 1: (dim=2|id=325|"S=1/2,Site,n=1")
Dim 2: (dim=2|id=135|"S=1/2,Site,n=2")
Dim 3: (dim=4|id=230|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×2×4
[:, :, 1] =
 -0.05564733508363627 - 0.18631109332015233im  -0.0 + 0.0im
                 -0.0 + 0.0im                  -0.0 + 0.0im

[:, :, 2] =
 -0.17905621912176056 + 0.06348486080271452im  -0.0 + 0.0im
                 -0.0 + 0.0im                  -0.0 + 0.0im

[:, :, 3] =
 0.08472015314929186 - 0.05657730112900377im  -0.0 + 0.0im
                -0.0 + 0.0im                  -0.0 + 0.0im

[:, :, 4] =
 0.07864337797177527 - 0.04490174699383759im  -0.0 + 0.0im
                -0.0 + 0.0im                  -0.0 + 0.0im,)


(ITensor ord=3
Dim 1: (dim=2|id=325|"S=1/2,Site,n=1")
Dim 2: (dim=2|id=135|"S=1/2,Site,n=2")
Dim 3: (dim=4|id=230|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×2×4
[:, :, 1] =
 -0.05564733508363627 - 0.18631109332015233im  -0.0 + 0.0im
                 -0.0 + 0.0im                  -0.0 + 0.0im

[:, :, 2] =
 -0.17905621912176056 + 0.06348486080271452im  -0.0 + 0.0im
                 -0.0 + 0.0im                  -0.0 + 0.0im

[:, :, 3] =
 0.08472015314929186 - 0.05657730112900377im  -0.0 + 0.0im
                -0.0 + 0.0im                  -0.0 + 0.0im

[:, :, 4] =
 0.07864337797177527 - 0.04490174699383759im  -0.0 + 0.0im
                -0.0 + 0.0im                  -0.0 + 0.0im,)

In [15]:
@show analytic_grad

analytic_grad = ITensor ord=3
Dim 1: (dim=2|id=325|"S=1/2,Site,n=1")
Dim 2: (dim=2|id=135|"S=1/2,Site,n=2")
Dim 3: (dim=4|id=230|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2×2×4
[:, :, 1] =
 -0.027823667541818135 - 0.09315554666007617im  -0.0 + 0.0im
                  -0.0 + 0.0im                  -0.0 + 0.0im

[:, :, 2] =
 -0.08952810956088028 + 0.03174243040135726im  -0.0 + 0.0im
                 -0.0 + 0.0im                  -0.0 + 0.0im

[:, :, 3] =
 0.04236007657464593 - 0.028288650564501885im  -0.0 + 0.0im
                -0.0 + 0.0im                   -0.0 + 0.0im

[:, :, 4] =
 0.039321688985887636 - 0.022450873496918794im  -0.0 + 0.0im
                 -0.0 + 0.0im                   -0.0 + 0.0im


ITensor ord=3 (dim=2|id=325|"S=1/2,Site,n=1") (dim=2|id=135|"S=1/2,Site,n=2") (dim=4|id=230|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

So it seems we are out by a factor of 2. The analytic gradient is half the zygote gradient.

# Test another bond tensor

In [16]:
BT23 = mps[2] * mps[3]

ITensor ord=4 (dim=4|id=106|"Link,l=1") (dim=2|id=135|"S=1/2,Site,n=2") (dim=2|id=984|"S=1/2,Site,n=3") (dim=4|id=467|"Link,l=3")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [17]:
function loss23(B::ITensor)
    y = 1.0 # train the bond tensor to maximise overlap
    # bond tensor on sites 1 and 2
    lid = 2
    rid = 3 
    phi_tilde = mps[1] * ps[1] * ps[lid] * ps[rid] * mps[4] * ps[4] * mps[5] * ps[5] # effective input
    yhat = B * phi_tilde
    diff_mod_sq = norm(yhat[] - y)^2
    loss = 0.5 * diff_mod_sq
    return loss
end

loss23 (generic function with 1 method)

In [18]:
loss23(BT23)

0.4120511832985574

In [19]:
function analytic_gradient23(B::ITensor)
    lid = 2
    rid = 3
    phi_tilde = mps[1] * ps[1] * ps[lid] * ps[rid] * mps[4] * ps[4] * mps[5] * ps[5] # effective input
    yhat = B * phi_tilde
    y = 1.0
    dP = yhat[] - y
    grad = 0.5 * dP * conj(phi_tilde)

    return grad
end

analytic_gradient23 (generic function with 1 method)

In [20]:
analytic_grad23 = analytic_gradient23(BT23)
@show analytic_grad23

analytic_grad23 = ITensor ord=4
Dim 1: (dim=4|id=106|"Link,l=1")
Dim 2: (dim=2|id=135|"S=1/2,Site,n=2")
Dim 3: (dim=2|id=984|"S=1/2,Site,n=3")
Dim 4: (dim=4|id=467|"Link,l=3")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 4×2×2×4
[:, :, 1, 1] =
   0.01011703916110151 + 0.032159849621781776im  -0.0 + 0.0im
  0.022678563632097453 - 0.013049850951575947im  -0.0 + 0.0im
  -0.02865986361853609 + 0.013492150565468316im  -0.0 + 0.0im
 0.0013413360956822478 - 0.025082537444381374im  -0.0 + 0.0im

[:, :, 2, 1] =
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -0.0 + 0.0im

[:, :, 1, 2] =
 0.030913869293510886 + 0.09641329962162927im   -0.0 + 0.0im
  0.06789129363109202 - 0.039564938278620884im  -0.0 + 0.0im
 -0.08584663909490425 + 0.040991989664557756im  -0.0 + 0.0im
 0.003614037868999127 - 0.0753483847867521im    -0.0 + 0.0im

[:, :, 2, 2] =
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -

ITensor ord=4 (dim=4|id=106|"Link,l=1") (dim=2|id=135|"S=1/2,Site,n=2") (dim=2|id=984|"S=1/2,Site,n=3") (dim=4|id=467|"Link,l=3")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [21]:
zygote_grad23 = gradient(loss23, BT23)

(ITensor ord=4
Dim 1: (dim=4|id=106|"Link,l=1")
Dim 2: (dim=2|id=135|"S=1/2,Site,n=2")
Dim 3: (dim=2|id=984|"S=1/2,Site,n=3")
Dim 4: (dim=4|id=467|"Link,l=3")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 4×2×2×4
[:, :, 1, 1] =
   0.02023407832220302 + 0.06431969924356355im   -0.0 + 0.0im
  0.045357127264194906 - 0.026099701903151893im  -0.0 + 0.0im
  -0.05731972723707218 + 0.026984301130936632im  -0.0 + 0.0im
 0.0026826721913644955 - 0.05016507488876275im   -0.0 + 0.0im

[:, :, 2, 1] =
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -0.0 + 0.0im

[:, :, 1, 2] =
  0.06182773858702177 + 0.19282659924325854im  -0.0 + 0.0im
  0.13578258726218403 - 0.07912987655724177im  -0.0 + 0.0im
  -0.1716932781898085 + 0.08198397932911551im  -0.0 + 0.0im
 0.007228075737998254 - 0.1506967695735042im   -0.0 + 0.0im

[:, :, 2, 2] =
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -0.0 + 0.0im
 -0.0 + 0.0im  -0.0 + 0.0im

[:, :, 1

Same result - analytic gradient is half the zygote gradient

In [22]:
wirtinger

UndefVarError: UndefVarError: `wirtinger` not defined

# Optim + Analytic Gradient

In [23]:
function flatten_bond_tensor(BT::ITensor)
    """Function to flatten an ITensor so that it can be fed into Optim
    as a vector."""
    # should probably return the indices as well
    # might need checks to ensure correct assignment of indices to values
    flattened_tensor = collect(Iterators.flatten(BT))
    return flattened_tensor, inds(BT)
end

flatten_bond_tensor (generic function with 1 method)

In [24]:
function reconstruct_bond_tensor(BT_flat::Vector, indices)
    BT = ITensor(indices)
    # ORDER OF ASSIGNMENT MUST MATCH THE ORDER OF FLATTENING
    for (n, val) in enumerate(BT_flat)
        BT[n] = val
    end

    return BT

end

reconstruct_bond_tensor (generic function with 1 method)

In [25]:
BT12_flat, BT12_inds = flatten_bond_tensor(BT12);

Redefine the gradient for flattened bond tensor:

In [26]:
function loss_flat(params::Vector, bt_inds)
    B = reconstruct_bond_tensor(params, bt_inds)
    loss_val = loss(B)
    return loss_val
end

loss_flat (generic function with 1 method)

Check loss is same as for ITensor (not flattened)

In [27]:
loss_flat(BT12_flat, BT12_inds)

0.4120511832985574

In [28]:
loss(BT12)

0.4120511832985574

In [29]:
function gradient_flat(params::Vector, bt_inds)
    B = reconstruct_bond_tensor(params, bt_inds)
    grad = analytic_gradient(B)
    return grad
end

gradient_flat (generic function with 1 method)

In [30]:
analytic_gradient(BT12) == gradient_flat(BT12_flat, BT12_inds)

true

In [31]:
cost = x -> loss_flat(x, BT12_inds)

#9 (generic function with 1 method)

In [32]:
grad = x -> gradient_flat(x, BT12_inds)

#11 (generic function with 1 method)

In [33]:
struct PState
    """Define a custom struct for product states"""
    pstate::MPS # product state as a vector of ITenors (MPS)
    label::Int # ground truth class label
    id::Int # identifier for caching
end

In [34]:
function generate_training_data(samples_per_class::Int; data_pts::Int=5)

    class_A_samples = zeros(samples_per_class, data_pts)
    class_B_samples = ones(samples_per_class, data_pts)
    all_samples = vcat(class_A_samples, class_B_samples)
    all_labels = Int.(vcat(zeros(size(class_A_samples)[1]), ones(size(class_B_samples)[1])))

    shuffle_idxs = shuffle(1:samples_per_class*2)


    return all_samples[shuffle_idxs, :], all_labels[shuffle_idxs]

end

generate_training_data (generic function with 1 method)

In [35]:

function sample_to_product_state(ts::Vector, site_inds::Vector{Index{Int64}})
    """Convert a SINGLE time series (ts) to a product state (mps)"""
    n_sites = length(site_inds)
    product_state = MPS(n_sites)
    for site in 1:n_sites
        # loop over each site, create an itensor and fill with encoded values
        T = ITensor(ComplexF64, site_inds[site])
        zero_state, one_state = complex_feature_map(ts[site]) # 
        T[1] = zero_state
        T[2] = one_state
        product_state[site] = T
    end

    return product_state

end

sample_to_product_state (generic function with 1 method)

In [36]:
function dataset_to_product_state(ts_dataset::Matrix, ts_labels::Vector{Int}, site_inds::Vector{Index{Int64}})
    """Convert ALL time series (ts) in a dataset to a vector of
    PStates"""
    dataset_shape = size(ts_dataset)
    #@assert dataset_shape[1] > dataset_shape[2] "Ensure time series are in rows"

    all_product_states = Vector{PState}(undef, dataset_shape[1])
    for p in 1:length(all_product_states)
        # note, now using column-major ordering, so ts stored in COLUMNS not rows
        time_series_as_product_state = sample_to_product_state(ts_dataset[p, :], site_inds)
        time_series_label = ts_labels[p]
        product_state = PState(time_series_as_product_state, time_series_label, p)
        all_product_states[p] = product_state
    end

    return all_product_states
end

dataset_to_product_state (generic function with 1 method)

In [37]:
function flatten_bond_tensor(BT::ITensor)
    """Function to flatten an ITensor so that it can be fed into Optim
    as a vector."""
    # should probably return the indices as well
    # might need checks to ensure correct assignment of indices to values
    flattened_tensor = collect(Iterators.flatten(BT))
    return flattened_tensor, inds(BT)
end

function reconstruct_bond_tensor(BT_flat::Vector, indices)
    BT = ITensor(indices)
    # ORDER OF ASSIGNMENT MUST MATCH THE ORDER OF FLATTENING
    for (n, val) in enumerate(BT_flat)
        BT[n] = val
    end

    return BT

end

reconstruct_bond_tensor (generic function with 1 method)

In [38]:
function construct_caches(mps::MPS, training_product_states::Vector{PState}; going_left=true)
    """Function to pre-allocate tensor contractions between the MPS and the product states.
    LE stores the left environment, i.e. all accumulate contractions from site 1 to site N
    RE stores the right env., all contractions from site N to site 1."""

    # get the number of training samples to pre-allocated caches
    n_train = length(training_product_states)
    n = length(mps)
    # make the caches
    LE = Matrix{ITensor}(undef, n_train, n)
    RE = Matrix{ITensor}(undef, n_train, n)

    for i in 1:n_train 
        # get the product state for the current training sample
        ps = training_product_states[i].pstate

        if going_left
            # initialise the first contraction
            LE[i, 1] = mps[1] * ps[1]
            for j in 2:n
                LE[i, j] = LE[i, j-1] * ps[j] * mps[j]
            end
            
        else
            # going right
            RE[i, n] = ps[n] * mps[n]
            # accumulate remaining sites
            for j in n-1:-1:1
                RE[i, j] = RE[i, j+1] * ps[j] * mps[j]
            end
        end
    end

    return LE, RE
    
end

construct_caches (generic function with 1 method)

In [39]:
samples, labels = generate_training_data(10);
all_pstates = dataset_to_product_state(samples, labels, s)

20-element Vector{PState}:
 PState(MPS
[1] ((dim=2|id=325|"S=1/2,Site,n=1"),)
[2] ((dim=2|id=135|"S=1/2,Site,n=2"),)
[3] ((dim=2|id=984|"S=1/2,Site,n=3"),)
[4] ((dim=2|id=815|"S=1/2,Site,n=4"),)
[5] ((dim=2|id=691|"S=1/2,Site,n=5"),)
, 0, 1)
 PState(MPS
[1] ((dim=2|id=325|"S=1/2,Site,n=1"),)
[2] ((dim=2|id=135|"S=1/2,Site,n=2"),)
[3] ((dim=2|id=984|"S=1/2,Site,n=3"),)
[4] ((dim=2|id=815|"S=1/2,Site,n=4"),)
[5] ((dim=2|id=691|"S=1/2,Site,n=5"),)
, 0, 2)
 PState(MPS
[1] ((dim=2|id=325|"S=1/2,Site,n=1"),)
[2] ((dim=2|id=135|"S=1/2,Site,n=2"),)
[3] ((dim=2|id=984|"S=1/2,Site,n=3"),)
[4] ((dim=2|id=815|"S=1/2,Site,n=4"),)
[5] ((dim=2|id=691|"S=1/2,Site,n=5"),)
, 1, 3)
 PState(MPS
[1] ((dim=2|id=325|"S=1/2,Site,n=1"),)
[2] ((dim=2|id=135|"S=1/2,Site,n=2"),)
[3] ((dim=2|id=984|"S=1/2,Site,n=3"),)
[4] ((dim=2|id=815|"S=1/2,Site,n=4"),)
[5] ((dim=2|id=691|"S=1/2,Site,n=5"),)
, 1, 4)
 PState(MPS
[1] ((dim=2|id=325|"S=1/2,Site,n=1"),)
[2] ((dim=2|id=135|"S=1/2,Site,n=2"),)
[3] ((dim=2|id=984|"S=1

In [40]:
LE, RE = construct_caches(mps, all_pstates; going_left=false)

(ITensor[#undef #undef … #undef #undef; #undef #undef … #undef #undef; … ; #undef #undef … #undef #undef; #undef #undef … #undef #undef], ITensor[ITensor ord=0
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 0-dimensional
0.09251635524612625 - 0.02399585592013451im ITensor ord=1
Dim 1: (dim=4|id=106|"Link,l=1")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 4-element
 -0.12586404888028077 + 0.036145725698430865im
   0.1662628247002933 + 0.014721348596851104im
  0.01652614991548789 + 0.027662784848295142im
  0.07247047733379401 - 0.06213913946417862im … ITensor ord=1
Dim 1: (dim=4|id=467|"Link,l=3")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 4-element
 -0.06259047707742635 - 0.2011623231544069im
 -0.18464593780559382 - 0.6051510565944971im
 -0.37109045035808635 + 0.6350020455167755im
  0.11469803682472071 - 0.034969536283704986im ITensor ord=1
Dim 1: (dim=2|id=242|"Link,l=4")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}
 2-element
 0.4800595159922907 - 0.5357791469433074im
 

In [41]:
function fg!(F, G, x, B_inds, LE::Matrix, RE::Matrix, pss::Vector{PState}, lid::Int, rid::Int)
    # common computations
    B = reconstruct_bond_tensor(x, B_inds)
    loss_accum = 0
    grad_accum = ITensor()

    for ps in pss
        prod_state = ps.pstate
        phi_tilde = prod_state[lid] * prod_state[rid]
        n = size(LE, 2) # number of mps sites
        if lid == 1
            phi_tilde *= RE[ps.id, rid+1]
        elseif rid == n
            phi_tilde *= LE[ps.id, lid-1]
        else
            phi_tilde *= LE[ps.id, lid-1] * RE[ps.id, rid+1]
        end
        yhat = B * phi_tilde
        y = ps.label
        dP = yhat[] - y
        diff_sq = norm(dP)^2
        loss_accum += 0.5 * diff_sq
    
        if G !== nothing
            # compute gradient
            grad = dP * conj(phi_tilde)
            grad_accum += grad
        end
    end

    if G !== nothing
        grad_overall = grad_accum ./ length(pss)
        copyto!(G, grad_overall)
    end

    if F !== nothing
        final_loss = loss_accum / length(pss)
        return final_loss
    end
end

fg! (generic function with 1 method)

In [42]:
function create_fg!(B_inds, LE::Matrix, RE::Matrix, pss::Vector{PState}, lid::Int, rid::Int)
    return function (F, G, x)
        # Your existing fg! function body here, using B_inds, LE, RE, pss, lid, rid
        # as captured from the surrounding scope
        B = reconstruct_bond_tensor(x, B_inds)
        loss_accum = 0
        grad_accum = ITensor()

        for ps in pss
            prod_state = ps.pstate
            phi_tilde = prod_state[lid] * prod_state[rid]
            n = size(LE, 2) # number of mps sites
            if lid == 1
                phi_tilde *= RE[ps.id, rid+1]
            elseif rid == n
                phi_tilde *= LE[ps.id, lid-1]
            else
                phi_tilde *= LE[ps.id, lid-1] * RE[ps.id, rid+1]
            end
            yhat = B * phi_tilde
            y = ps.label
            dP = yhat[] - y
            diff_sq = norm(dP)^2
            loss_accum += 0.5 * diff_sq
        
            if G !== nothing
                # compute gradient
                grad = dP * conj(phi_tilde)
                grad_accum += grad
            end
        end

        if G !== nothing
            grad_overall = grad_accum ./ length(pss)
            copyto!(G, grad_overall)
        end

        if F !== nothing
            final_loss = loss_accum / length(pss)
            return final_loss
        end
    end
end

create_fg! (generic function with 1 method)

In [43]:
BT12

ITensor ord=3 (dim=2|id=325|"S=1/2,Site,n=1") (dim=2|id=135|"S=1/2,Site,n=2") (dim=4|id=230|"Link,l=2")
NDTensors.Dense{ComplexF64, Vector{ComplexF64}}

In [44]:
BT_init = deepcopy(BT12)
BT12_flat, BT12_flat_inds = flatten_bond_tensor(BT_init);

In [45]:
fg! = create_fg!(BT12_flat_inds, LE, RE, all_pstates, 1, 2)

ErrorException: invalid redefinition of constant Main.fg!

In [46]:
Optim.optimize(Optim.only_fg!(fg!), BT12_flat, Optim.LBFGS()) 

MethodError: MethodError: no method matching fg!(::Float64, ::Vector{ComplexF64}, ::Vector{ComplexF64})

Closest candidates are:
  fg!(::Any, ::Any, ::Any, !Matched::Any, !Matched::Matrix, !Matched::Matrix, !Matched::Vector{PState}, !Matched::Int64, !Matched::Int64)
   @ Main ~/Documents/QuantumInspiredML/MPS_MSE/complex-opt/Julia/autograd/AD_versus_analytic.ipynb:1
