In [1]:
using Flux, MLDatasets, ImageCore, StatsBase
using MLDatasets: MNIST
using Flux: train!, update!
using Flux: onehot, throttle, crossentropy, relu, sigmoid
using StatsBase: sample

In [2]:
using Flux
using Flux: Data.DataLoader
using Flux: onehotbatch, onecold, crossentropy
using Flux: @epochs
using Statistics
using MLDatasets

# load full training set
train_x, train_y = MNIST.traindata();

train_x_vec = [vec(train_x[:, :, i]) for i = 1:60000];
train_y_hot = [onehot(train_y[i], 0:9) for i = 1:60000];

# load full test set
test_x,  test_y  = MNIST.testdata();

test_x_vec = [vec(test_x[:, :, i]) for i = 1:10000];
test_y_hot = [onehot(test_y[i], 0:9) for i = 1:10000];

traindata = [(train_x_vec[i], train_y_hot[i]) for i = 1:60000];
testdata = [(test_x_vec[i], test_y_hot[i]) for i = 1:10000];

In [3]:
## Train final linear layer with random standard neural network
## Train final linear layer with random convolutional neural network
## Hopefully the second one has a lower loss on average (If not need to rethink some things)
## Train a standard neural network parametrized by a linear function of each layers weights
## Evolve the linear parametrization using an evolutionary algorithm where the fitness is the average loss 
## after training of n (=10?) input parameters. Can the parameters of the linear reparameterization model
## be sparsified?

In [4]:
# TODO
# redefine everything using Julia "destructure" and "restructure" (which is differentiable)
# operating on the entire model rather than just individual layers
# This may not work... but have the custom layer accept a model and a
# reparameterization matrix as a parameter, use the reparameterized model operating
# on x as the method (something like re(m.R * p), where re is restructure, m.R is the 
# reparameterization matrix, and p is the set of parameters returned by destructure)
mutable struct AffineReparam
    input_param::Vector{Float64}
    W_param::Matrix{Float64}
    b_param::Vector{Float64}
    input_size::Int64
end

# Overload call, so the object can be used as a function
# Note that if input_params will not be differentiated, the result of
# the parameter multiplication can be stored and does not need to be constantly recomputed
(m::AffineReparam)(x) = relu.(reshape(m.W_param * m.input_param + m.b_param, :, m.input_size) * x)

In [5]:
mutable struct AffineReparamInputCache
    W_data::Matrix{Float64}
    
    prev_input_param::Vector{Float64}
    input_param::Vector{Float64}
    
    W_param::Matrix{Float64}
    b_param::Vector{Float64}
    
    input_size::Int64
end

AffineReparamInputCache(
    input_param::Vector{Float64}, 
    W_param::Matrix{Float64}, 
    b_param::Vector{Float64},
    input_size::Int64
    ) = AffineReparamInputCache(
        reshape(W_param * input_param + b_param, :, input_size), 
        input_param, 
        input_param, 
        W_param, 
        b_param, 
        input_size
    )

function method(x::Vector, aric::AffineReparamInputCache)
    if (aric.prev_input_param != aric.input_param)
        aric.W_data = reshape(aric.W_param * aric.input_param + aric.b_param, :, aric.input_size)
        aric.prev_input_param = aric.input_param
    end
    return relu.(aric.W_data * x)
end

(m::AffineReparamInputCache)(x::Vector) = method(x, m)

In [24]:
mutable struct AffineReparamWbCache
    W_data::Matrix{Float64}
    
    input_param::Vector{Float64}
    
    W_param::Matrix{Float64}
    prev_W_param::Matrix{Float64}
    
    b_param::Vector{Float64}
    prev_b_param::Vector{Float64}
    
    input_size::Int64
end

AffineReparamWbCache(
    input_param::Vector{Float64}, 
    W_param::Matrix{Float64}, 
    b_param::Vector{Float64},
    input_size::Int64
    ) = AffineReparamWbCache(
        reshape(W_param * input_param + b_param, :, m.input_size), 
        input_param, 
        input_param, 
        W_param,
        W_param,
        b_param,
        b_param,
        input_size
    )

# This has not been tested
function method(x::Vector, aric::AffineReparamWbCache)
    W_changed = aric.prev_W_param != aric.W_param
    b_changed = aric.prev_b_param != aric.b_param
    if (W_changed)
        aric.W_data = reshape(aric.W_param * aric.input_param + aric.b_param, :, aric.input_size)
        aric.prev_W_param = aric.W_param
        aric.prev_b_param = aric.b_param
    elseif (b_changed)
        aric.W_data = reshape(vec(aric.W_data) - aric.prev_b_param + aric.b_param, :, aric.input_size)
        aric.prev_b_param = aric.b_param
    end
    return relu.(aric.W_data * x)
end

(m::AffineReparamWbCache)(x::Vector) = method(x, m)

In [15]:
max_parameters = 10;
input = 784;
hidden_out = 784;

W_param = randn(input*hidden_out, max_parameters);
b_param = randn(input*hidden_out);

In [16]:
input_params1 = (1/max_parameters)*(rand(max_parameters) - 0.5*ones(max_parameters));
input_params2 = (1/(max_parameters*hidden_out))*(rand(max_parameters) - 0.5*ones(max_parameters));

hidden1 = AffineReparam(input_params1, W_param, b_param, input);
hidden2 = AffineReparamInputCache(input_params2, W_param, b_param, input);

Flux.trainable(hidden1::AffineReparamInputCache) = (hidden1.input_param, )
Flux.trainable(hidden2::AffineReparamInputCache) = (hidden2.input_param, )

In [17]:
data_model = Chain(
      hidden2,
      Dense(hidden_out, 10),
      softmax
)

num_hidden_layers = 1


ps_hidden = Flux.params(data_model[1:num_hidden_layers])
ps_out = Flux.params(data_model[num_hidden_layers+1:end])
ps = Flux.params(data_model)

Params([[4.621443019822122e-5, -4.098006512920417e-5, 4.321213379893864e-5, -3.74315847452869e-5, -4.679209992896548e-5, -9.17439277670436e-6, 3.44770393508489e-5, -6.284720797634103e-5, 2.4403976087266596e-5, -1.900403255369029e-5], Float32[0.08512958 0.029645827 … 0.05858114 -0.051549822; -0.05612156 -0.054823354 … 0.040251385 0.06866392; … ; -0.0019142713 -0.02679043 … -0.017604735 0.057830133; 0.0020862932 -0.001985878 … -0.055465598 -0.083531804], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])

In [10]:
trainsample = sample(traindata, 6000, replace=false);

In [54]:
data_model = Chain(
    Dense(input, hidden_out, relu),
    Dense(hidden_out, 10),
    softmax
)
ps_out = Flux.params(data_model[num_hidden_layers+1:end])
ps = Flux.params(data_model)

Params([Float32[0.019757971 -0.06254386 … -0.042520065 -0.06680834; -0.027373645 0.036480814 … 0.03772269 0.07769612; … ; 0.013659548 0.07833217 … 0.07321512 0.049626846; -0.0064494917 0.03465075 … -0.031358644 0.046776194], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.17692272 0.33308834 … 0.3083141 0.12433642; 0.38976818 -0.5449518 … -0.13594863 0.51308155; … ; 0.445382 -0.14179541 … 0.47729635 0.3308579; -0.28386572 -0.13055067 … -0.5003385 0.035294075], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])

In [18]:
loss(x, y) = Flux.Losses.crossentropy(data_model(x), y)
opt = Descent(.001)

epochs = 1
for epoch in 1:epochs
    loss_total = 0
    for (i, d) in enumerate(trainsample)
        diff_ps = ps_out
        gs = gradient(diff_ps) do
          training_loss = loss(d...)
          # Code inserted here will be differentiated, unless you need that gradient information
          # it is better to do the work outside this block.
          return training_loss
        end
        loss_total += loss(d...)
        if i%100==0
            println(loss_total/i)
        end
        # Insert whatever code you want here that needs training_loss, e.g. logging.
        # logging_callback(training_loss)
        # Insert what ever code you want here that needs gradient.
        # E.g. logging with TensorBoardLogger.jl as histogram so you can see if it is becoming huge.
        update!(opt, diff_ps, gs)
    end
    print(epoch)
end

13.737920492781297
11.292703476288668
9.687162763882101
8.976711613548247
8.212386323411408
7.662491690856086
7.133717443827318
6.946730588710791
6.748966088249661
6.676684505510985
6.442453409085176
6.436664022720092
6.14534293009462
5.981560920160931
5.716034799171579
5.619660653357848
5.60205436060706
5.55411420492969
5.407255763298242
5.258720546464656
5.1143038645518395
5.063784894200569
4.984624834641118
4.920884895694458
4.849307204589918
4.758002899784046
4.697330896490663
4.6712431141644455
4.638297400090243
4.564294478198112
4.519527232731251
4.434911797064934
4.386404096544667
4.3747374347035635
4.324686194850671
4.306626746313406
4.2785413258261045
4.256428926563204
4.213508861671226
4.213840574720204
4.168278752848328
4.120812441340383
4.142371699890549
4.119482964598424
4.094899818521346
4.061786627986874
4.034886211578239
3.999136116449572
3.9592355422799086
3.9371193897402277
3.9044702020109754
3.876623150586507
3.8464911223410585
3.825075802133967
3.811055239224025
3.8

In [19]:
accuracy(x, y) = mean(onecold.(x) .== onecold.(y))
#accuracy(m.(train_x_vec), train_y_hot)
ps_hidden

Params([[4.621443019822122e-5, -4.098006512920417e-5, 4.321213379893864e-5, -3.74315847452869e-5, -4.679209992896548e-5, -9.17439277670436e-6, 3.44770393508489e-5, -6.284720797634103e-5, 2.4403976087266596e-5, -1.900403255369029e-5]])

In [20]:
accuracy(data_model.(test_x_vec), test_y_hot)

0.8418