In [1]:
using Flux

In [2]:
using Random
using Statistics
using Plots

In [3]:
# Generate dataset of NN matricies W and H and their product V
#set seed
Random.seed!(314)

# V ∈ R^(m × n), and r is the (low) rank of V
m, n = 15, 15
r = 5
true_W = abs.(randn((m, r)))

train_set_size = 5
test_set_size  = 1

train_H = [abs.(randn((r, n))) for _ ∈ 1:train_set_size]
test_H  = [abs.(randn((r, n))) for _ ∈ 1:test_set_size ]
train_V = [true_W*train_H[i]   for i ∈ 1:train_set_size]
test_V  = [true_W*test_H[i]    for i ∈ 1:test_set_size ];

# LATER can add small noise

1-element Vector{Matrix{Float64}}:
 [4.550121595210181 1.700274660121937 … 5.254907657778933 2.083064155509878; 3.8471541142042067 3.1283424181052757 … 3.8738437723677883 2.6841725312646405; … ; 1.8203071153870078 1.3698744445597113 … 2.0898543043652213 0.53601649623376; 4.384254098580098 2.4411220123602884 … 5.715499758606215 2.2778623971890495]

In [23]:
# Define the model (copied from https://fluxml.ai/Flux.jl/stable/models/advanced/)
struct MyLayer
    W
end

function MyLayer(m,r) # TODO add type stable-ness
    W = abs.(randn((m,r))) # TODO adjust initialization
    MyLayer(W)
end

function (a::MyLayer)((H, V)) #must pass a single item into chains, that is why (H,V) are grouped
    W = a.W
    return (H .* (W'*V) ./ (W'*W*H), V) #one update step
    # TODO can add regularization later
end
Flux.trainable(a::MyLayer) = (a.W,)
#Flux.@functor MyLayer 

struct MyModel
    chain::Chain
end

H_guess = abs.(randn((r,n))) # same random guess for all inputs

# this is automatic from the struct
#function MyModel(chain::Chain)
#    MyModel(chain) 
#end

function (m::MyModel)(V::Matrix{Float64}, H_guess=H_guess)
    H_out, V_out = m.chain((H_guess, V)) #must pass a single item into chains
    return H_out #just care about the first entry
end

function (m::MyModel)(V::Vector{Matrix{Float64}}) #TODO update Float64 to something more general
    return m.(V) #apply m to all elements in V
end
  
# Call @functor to allow for training. Described below in more detail.
Flux.@functor MyModel

my_layer = MyLayer(m,r)
chain = Chain([my_layer, my_layer, my_layer]) #three layer network, each layer shares parameters
model(V) = (MyModel(chain))(V)

model (generic function with 1 method)

In [19]:
Flux.params(my_layer)

Params([[0.8369810917757003 1.9484490853346115 … 1.7028301525984841 1.1849310086642832; 3.284342060186844 0.7522631357129886 … 1.380454609814672 1.6805481979043364; … ; 0.945360251000689 1.140117305034255 … 0.515135349502517 0.5831160684308441; 1.1630795992759264 0.8821387410872338 … 0.7217666823970634 1.0283045175298946]])

In [20]:
a = model(train_V)

5-element Vector{Matrix{Float64}}:
 [0.4563601479487161 0.05735049501675889 … 0.37823974090732115 0.04715936995397002; 0.5177975942067371 0.6318323707579199 … 0.4824335929629952 0.4767183550258751; … ; 0.2517072578420455 0.7165398334508908 … 0.3750541894372231 0.5731476511879695; 0.5038194799079283 0.5434531299255546 … 0.9799203856286202 0.2549237596127586]
 [0.9754424491803148 0.14985080268852766 … 0.3315935188437829 0.04711185170546529; 0.9040101232127244 1.6533579015213202 … 0.2928732321965072 0.42155361940300723; … ; 0.49406585419135735 1.8315940046013917 … 0.29782362752953284 0.6410324400109199; 0.9979864585426028 1.3304055180543664 … 0.9303867630634411 0.3395330899980648]
 [1.4553100751526253 0.04245870136254515 … 0.4464299800561515 0.09384326041147444; 1.04146143566033 0.6583061701314502 … 0.5210871014468536 1.193077702367736; … ; 0.6420278819506194 0.5390135192242497 … 0.47179915549946844 1.3663261481158055; 1.5017543009099987 0.3980717848945929 … 1.5085902168591232 0.543487181

In [24]:
# Train the model
using Flux: mse
loss(model, V_input, H_truth) = mean(mse.(model(V_input), H_truth)) # mean squared error
                                                                    # the outer mean is to handle batches
# TODO idealy the model can be applied on batches without needing the elementwise application dot here
#loss(model, train_V, train_H)

using Flux: train!

opt = Descent() #or Adam(), see https://fluxml.ai/Flux.jl/stable/training/optimisers/ for full list

Descent(0.1)

In [25]:
data = [(train_V, train_H)]
train!(loss, model, data, opt) #cant find trainable parameters, TODO need to pas them through MyModel

└ @ Optimisers C:\Users\Nicholas\.julia\packages\Optimisers\BT5bT\src\interface.jl:27


In [None]:
maxit = 10
training_loss = zeros(1,maxit)
testing_loss  = zeros(1,maxit)
for i ∈ 1:maxit
    train!(loss, model, data, opt)
    training_loss[i] = loss(model, train_V, train_H)
    testing_loss[i]  = loss(model, test_V , test_H )
end

In [None]:
# Evaluate the model performance on the training and testing set
plt(i,[training_loss,testing_loss])

In [None]:
# Observe how close the learned weights ̂W and the true W are
learned_W = Flux.params(model)
difference = abs.(learned_W - true_W)
heatmap(difference)