In [73]:
using ITensors
using Random
include("MPSLLOptimkit.jl");

In [74]:
sites = siteinds("S=1/2", 100);

In [75]:
function GenerateSumState(training_pstates::Vector{PState}, sites; n_initial=50, maxdim=5, 
    cutoff=1E-10, random_state=42)

    Random.seed!(random_state)

    labels = [state.label for state in training_pstates]
    num_classes = length(unique(labels))
    label_idx = Index(num_classes, "f(x)")
    
    # determine min # of samples per class
    sample_counts = [length(findall(labels .== class)) for class in unique(labels)]
    min_samples = minimum(sample_counts)

    n_samples = min(n_initial, min_samples)

    # store individual label states for summing
    label_states_store = []

    @Threads.threads for class in sort(unique(labels))
        class_indices = findall(labels .== class)
        selected_indices = StatsBase.sample(class_indices, n_samples; replace=false)
        selected_samples = training_pstates[selected_indices]
        sample_mps = [sample.pstate for sample in selected_samples]
        println(sample_mps)
        label_state = +(sample_mps...; cutoff=cutoff, maxdim=maxdim)
        # add random noise as mps
        perturbation = 100 * randomMPS(sites; linkdims=maxdim)
        perturbed_label_state = +(label_state, perturbation; cutoff=cutoff, maxdim=maxdim)
        # construct label index
        label = onehot(label_idx => (class + 1))
        perturbed_label_state[1] *= label
        push!(label_states_store, perturbed_label_state)
    end

    # sum together class-specific label states to get an MPS
    W = +(label_states_store...; cutoff=cutoff, maxdim=maxdim)

    # normalize
    normalize!(W)
    print(W)

    return W

end

GenerateSumState (generic function with 2 methods)

In [76]:
X_train = rand(50, 100)
y_train = rand([0, 1], 50);
X_binarised = BinariseDataset(X_train);
ps = GenerateAllProductStates(X_binarised, y_train, "train", sites);

In [77]:
W = GenerateSumState(ps; n_initial=4, maxdim=40, cutoff=1E-10, random_state=42)

MPS
[1] ((dim=2|id=404|"S=1/2,Site,n=1"), (dim=2|id=325|"f(x)"), (dim=4|id=715|"Link,l=1"))
[2] ((dim=2|id=840|"S=1/2,Site,n=2"), (dim=5|id=918|"Link,l=2"), (dim=4|id=715|"Link,l=1"))
[3] ((dim=2|id=254|"S=1/2,Site,n=3"), (dim=6|id=394|"Link,l=3"), (dim=5|id=918|"Link,l=2"))
[4] ((dim=2|id=300|"S=1/2,Site,n=4"), (dim=7|id=160|"Link,l=4"), (dim=6|id=394|"Link,l=3"))
[5] ((dim=2|id=689|"S=1/2,Site,n=5"), (dim=7|id=165|"Link,l=5"), (dim=7|id=160|"Link,l=4"))
[6] ((dim=2|id=393|"S=1/2,Site,n=6"), (dim=8|id=90|"Link,l=6"), (dim=7|id=165|"Link,l=5"))
[7] ((dim=2|id=249|"S=1/2,Site,n=7"), (dim=8|id=927|"Link,l=7"), (dim=8|id=90|"Link,l=6"))
[8] ((dim=2|id=703|"S=1/2,Site,n=8"), (dim=8|id=82|"Link,l=8"), (dim=8|id=927|"Link,l=7"))
[9] ((dim=2|id=300|"S=1/2,Site,n=9"), (dim=8|id=822|"Link,l=9"), (dim=8|id=82|"Link,l=8"))
[10] ((dim=2|id=523|"S=1/2,Site,n=10"), (dim=8|id=255|"Link,l=10"), (dim=8|id=822|"Link,l=9"))
[11] ((dim=2|id=365|"S=1/2,Site,n=11"), (dim=8|id=277|"Link,l=11"), (dim=8|id=255

In [78]:
res = 1
for i=1:100
    res *= W[i] * ps[15].pstate[i]
end

In [83]:
vector(res)

2-element Vector{Float64}:
 0.0
 0.0