In [7]:
using Flux, DelimitedFiles                                                       
using LinearAlgebra, Random                                                      
using Distributions                                                              
using Flux: onehot, onecold, throttle, logitcrossentropy, chunk, batchseq
using Flux.Optimise: update!       
using Base.Iterators: partition
using Zygote 

In [8]:
labels = 0:1
batch_size = 10

train_path = "samples_N=2_h=1"
psi_path = "psi_N=2_h=1"

data = Int.(readdlm(train_path))
psi = readdlm(psi_path)[:,1]

train_data = data[1:10000, :]
test_data = data[10001:end, :]

N = size(train_data,2)

m = Chain(
    GRU(length(labels), 100),
    GRU(100, length(labels)),
    softmax
)
m = f64(m)
opt = ADAM(0.01)

train_data = map(ch -> onehot(ch, labels), train_data)
train_data = collect(partition(batchseq(chunk(train_data, batch_size)), N));

In [9]:
function probability(v)
    # initialize the calculation with an all-zero configuration
    init_config = zeros(1,size(v[1],2))
    # onehot init_config
    init_config = map(ch -> onehot(ch, labels), init_config)
    init_config = batchseq(chunk(init_config, size(v[1],2)))

    # concactenate init_config and v[1:end-1]
    vp = vcat(init_config, v[1:end-1])

    # apply model to vp
    probs = m.(vp)

    # multiply conditionals to get probability vector
    probs = vcat(probs...)
    v = vcat(v...)
    probs = dot.(probs, v)
    
    probs[probs .== 0] .= 1
    probs = prod(probs, dims=1)
    
    return probs
end  

function loss(v)
    prob = probability(v)
    log_prob = log.(prob)
    return -sum(log_prob) / length(prob)
end

function generate_hilbert_space(;hot=false)
    # TODO: currently not outputting correct format for hot=True
    dim = [i for i in 0:2^N-1]
    space = reshape(parse.(Int64, split(bitstring(dim[1])[end-N+1:end],"")), (N,))
    for i in 2:length(dim)
        tmp = reshape(parse.(Int64, split(bitstring(dim[i])[end-N+1:end],"")), (N,))
        space = hcat(space, tmp)
    end
    
    space = reshape(space, (length(dim),N))
    
    if hot
        space = map(ch -> onehot(ch, labels), space)
        space = collect(partition(batchseq(chunk(space', length(dim))), N))[1]
    end
    
    return space
end

function fidelity(space, target)
    return dot(target, sqrt.(probability(space)))
end

fidelity (generic function with 1 method)

In [10]:
#=
v = [0 0; 1 1]
@show chunk(v',2)
v = map(ch -> onehot(ch, labels), v)
@show chunk(v', 2)
@show collect(partition(batchseq(chunk(v', 2)), 2))
=#

In [11]:
# TODO: fix generate hilbert space function
#space = generate_hilbert_space(hot=true)
# generate space for N=2 manually for now
space = [[1 1 0 0; 0 0 1 1], [1 0 1 0; 0 1 0 1]]
ps = Flux.params(m)
                                                                                 
epochs = 1:100                                                                   
num_batches = size(train_data,1)                           

1000

In [12]:
@show loss(train_data[1])

for ep in epochs
    @show ep
    for b in 1:num_batches
        batch = train_data[b]
        Flux.reset!(m)
        gs = Flux.gradient(() -> loss(batch), ps)
        
        update!(opt, ps, gs)
    end

    println("fidelity: ", fidelity(space, psi))
end

loss(train_data[1]) = 1.5534910535991213
ep = 1


LoadError: Mutating arrays is not supported