In [1]:
using Flux, DelimitedFiles                                                       
using LinearAlgebra, Random                                                      
using Distributions                                                              
using Flux: onehot, onehotbatch, onecold, throttle, kldivergence, chunk, batchseq
using Flux.Optimise: update!                                                     
using Zygote 

In [2]:
Zygote.@nograd onehotbatch

labels = 0:1
batch_size = 10

train_path = "tfim1D_samples"
psi_path = "tfim1D_psi"

data = Int.(readdlm(train_path))
psi = readdlm(psi_path)[:,1]

train_data = data[1:10000, :]
test_data = data[10001:end, :]

N = size(train_data,2)

m = Chain(
    GRU(length(labels), 100),
    GRU(100, length(labels)),
    softmax
)
m = f64(m)
opt = ADAM(0.01)

train_data = [train_data[i,:] for i = 1:size(train_data, 1)]
test_data = [test_data[i,:] for i = 1:size(test_data, 1)]

train_data = map(ch -> onehotbatch(ch, labels), train_data)
test_data = map(ch -> onehotbatch(ch, labels), test_data);

In [3]:
function probability(v)
    Flux.reset!(m)
    
    if length(size(v)) == 1
        v0 = onehot(0, labels)
        prob = dot(m(v0), v[1][:,1])
        
        for i in 1:N-1
            prob *= dot(m(v[1][:,i]), v[1][:,i+1])
        end
        Flux.reset!(m)

    else
        prob = zeros(size(v,1))
        for j in 1:size(v,1)
            v0 = onehot(0, labels)
            
            prob[j] = dot(m(v0), v[j][:,1])
            for i in 1:N-1
                prob[j] *= dot(m(v[j][:,i]), v[j][:,i+1])
            end
            
            Flux.reset!(m)
        end
    end

    return prob
end

function loss(v)
    # Negative log-likelihood
    if length(size(v)) == 1
        prob = probability(v)
        log_prob = log(prob)
        return -log_prob
    else
        prob = probability(v)
        log_prob = log.(prob)
        return -sum(log_prob) / length(prob)
    end
end

function generate_hilbert_space(;hot=false)
    dim = [i for i in 0:2^N-1]
    space = space = parse.(Int64, split(bitstring(dim[1])[end-N+1:end],""))
    
    for i in 2:length(dim)
        tmp = parse.(Int64, split(bitstring(dim[i])[end-N+1:end],""))
        space = hcat(space, tmp)
    end
    
    space = transpose(space)

    if hot
        space = [space[i,:] for i = 1:size(space, 1)]
        space = map(ch -> onehotbatch(ch, labels), space) 
        space = reshape(space, (length(dim),1))
    end
    
    return space
end

function fidelity(space, target)
    return dot(target, sqrt.(probability(space)))
end

fidelity (generic function with 1 method)

In [4]:
space = generate_hilbert_space(hot=true)
ps = Flux.params(m)

train_data = [view(train_data, k:k+batch_size-1, :) for k in 1:batch_size:size(train_data,1)]; 
test_data = [view(test_data, k:k+batch_size-1, :) for k in 1:batch_size:size(test_data,1)]; 
                                                                                 
epochs = 1:100                                                                   
num_batches = size(train_data,1) # needs to generalize                           
num_samples = 2000

println("fidelity: ", fidelity(space, psi))
println("loss: ", loss(test_data[1]))

fidelity: 0.8978442343132055
loss: 2.9472771669840547


In [5]:
for ep in epochs
    @show ep
    for b in 1:num_batches
        batch = train_data[b]
        l = loss(batch)
        gs = Flux.gradient(() -> l, ps)
        # gs = Flux.gradient(() -> loss(batch), Flux.params(m))
        # the code below doesn't work for some reason. 
        # Error: Mutating arrays is not supported
        
        update!(opt, ps, gs)
    end

    println("loss: ",loss(test_data[1]))
    println("fidelity: ", fidelity(space, psi))

end

ep = 1
loss: 2.9472771669840547
fidelity: 0.8978442343132055
ep = 2
loss: 2.9472771669840547
fidelity: 0.8978442343132055
ep = 3
loss: 2.9472771669840547
fidelity: 0.8978442343132055
ep = 4
loss: 2.9472771669840547
fidelity: 0.8978442343132055
ep = 5
loss: 2.9472771669840547
fidelity: 0.8978442343132055
ep = 6


LoadError: InterruptException: