In [1]:
using Zygote
using Flux: σ, softmax, logitcrossentropy, Chain, Optimise, onehotbatch, onecold, Dense
using Flux.Data: MNIST
using SparseArrays
using StatsBase: sample, shuffle, mean
import Base: broadcast, broadcasted
#using Plots

┌ Info: CUDAdrv.jl failed to initialize, GPU functionality unavailable (set JULIA_CUDA_SILENT or JULIA_CUDA_VERBOSE to silence or expand this message)
└ @ CUDAdrv /home/olszewskip/.julia/packages/CUDAdrv/mCr0O/src/CUDAdrv.jl:69


In [2]:
labels = MNIST.labels();
label_pool = sort(Vector(collect(Set(labels))));
y = onehotbatch(labels, label_pool);

imgs = MNIST.images();
# lighter_than(pixel, threshold) = pixel.val > threshold
# function lightness_threshold(pixel_array::Array{T, 2}, threshold) where T
#     convert.(T, lighter_than.(pixel_array, threshold))
# end
# imgs_blackenwhite = lightness_threshold.(imgs, 0.7);
#X = hcat(convert.(Array{Float32, 1}, reshape.(imgs_blackenwhite, :))...);
X = hcat(convert.(Array{Float32, 1}, reshape.(imgs, :))...);

index_train = floor(Int, 0.7 * size(X, 2))
index_val = index_train + floor(Int, 0.2 * size(X, 2))
X_train = X[:, 1:index_train];
y_train = y[:, 1:index_train];
X_val = X[:, index_train+1:index_val];
y_val = y[:, index_train+1:index_val];
X_test = X[:, index_val+1:end];
y_test = y[:, index_val+1:end];

In [3]:
mutable struct SprAffine{S,T,F}
    W::S
    b::T
    σ::F
end

function Base.show(io::IO, layer::SprAffine)
  print(io, "SprAffine(", size(layer.W, 2), ", ", size(layer.W, 1))
  print(io, ", ", round(length(layer.W.nzval) / length(layer.W), sigdigits=6))
  layer.σ == identity || print(io, ", ", layer.σ)
  print(io, ")")
end

SprAffine(in::Number, out::Number, frac::AbstractFloat, σ::Function=identity) =
    #SprAffine(spones(out, in, frac), spzeros(out), fraq, σ)
    SprAffine(sprandn(Float32, out, in, frac), zeros(Float32, out), σ)

(layer::SprAffine)(x) = layer.σ.(layer.W * x .+ layer.b)

In [4]:
function loss(model, X, y)
    return logitcrossentropy(model(X), y)
end
accuracy(model, X, y) = mean(onecold(softmax(model(X))) .== onecold(y));

In [5]:
function get_model(in, hidden, out, fraction)
    return Chain(
            SprAffine(in, hidden, fraction),
            SprAffine(hidden, out, fraction),
           );
end

get_model (generic function with 1 method)

In [6]:
model = get_model(size(X, 1), 100, size(y, 1), 0.1)
loss(model, sparse(X[:, 1:32]), y[:, 1:32])
accuracy(model, sparse(X[:, 1:32]), y[:, 1:32])

0.09375

In [12]:
X_sp = sparse(X);

In [15]:
loss_local, back = pullback(model) do model
    loss(model,
         X_sp[:, 1:32],
         y[:, 1:32])
end;
grads = back(1)[1];

In [18]:
opt = Optimise.Descent()

Flux.Optimise.Descent(0.1)

In [33]:
I,J,_ = findnz(model[1].W)
weights_elems = Set(zip(I, J))

Set(Tuple{Int64,Int64}[(1, 151), (50, 506), (27, 552), (75, 782), (30, 197), (3, 694), (39, 486), (13, 569), (78, 513), (51, 182)  …  (52, 195), (70, 266), (58, 37), (89, 292), (88, 487), (22, 81), (83, 162), (39, 244), (76, 131), (35, 224)])

In [34]:
I,J,_ = findnz(getfield(grads.layers[1], 1).W)
grad_elems = Set(zip(I, J))

Set(Tuple{Int64,Int64}[(69, 598), (29, 164), (96, 716), (95, 522), (47, 153), (70, 631), (59, 479), (69, 239), (85, 547), (96, 544)  …  (80, 458), (97, 320), (86, 595), (73, 552), (59, 490), (2, 262), (22, 681), (83, 162), (21, 220), (39, 447)])

In [36]:
collect(intersect(weights_elems, grad_elems))[1]

(11, 444)

In [40]:
model[1].W[11, 444]

1.8742305f0

In [38]:
getfield(grads.layers[1], 1).W[11, 444]

-0.011893435f0

In [39]:
update_!(opt, model[1].W, getfield(grads.layers[1], 1).W);

In [7]:
function broadcast(-, arr_sink::AbstractSparseMatrix{Tv,Ti},
                      arr_source::AbstractSparseMatrix{Tv,Ti}) where {Tv,Ti}
    for (col_index, (ptr_sink, end_sink)) in enumerate(zip(arr_sink.colptr[1:end-1], arr_sink.colptr[2:end]))
        ptr_source = arr_source.colptr[col_index]
        end_source = arr_source.colptr[col_index + 1]
        while ptr_sink < end_sink && ptr_source < end_source
            #println(ptr_sink, " ", arr_sink.rowval[ptr_sink], " ", ptr_source, " ", arr_source.rowval[ptr_source])
            if arr_sink.rowval[ptr_sink] < arr_source.rowval[ptr_source]
                #println("sink up")
                ptr_sink += 1
                continue
            end
            if arr_sink.rowval[ptr_sink] > arr_source.rowval[ptr_source]
                #println("source up")
                ptr_source += 1
                continue
            end
            row_index = arr_sink.rowval[ptr_sink]
            #println(row_index, " ", col_index, " ", arr_sink.nzval[ptr_sink], " ", arr_source.nzval[ptr_source])
            arr_sink.nzval[ptr_sink] -= arr_source.nzval[ptr_source]
            ptr_sink += 1
            ptr_source += 1
        end
    end
    return arr_sink
end

function broadcasted(-, arr_sink::AbstractSparseMatrix{Tv,Ti},
                        arr_source::AbstractSparseMatrix{Tv,Ti}) where {Tv,Ti}
    broadcast(-, arr_sink, arr_source)
end

broadcasted (generic function with 94 methods)

In [8]:
update_!(opt, model, grads::Nothing) = model

# function update_!(opt::Optimise.Descent, arr::AbstractSparseMatrix, d_arr::AbstractSparseMatrix)
#     for (col_index, (ptr_sink, end_sink)) in enumerate(zip(arr.colptr[1:end-1], d_arr.colptr[2:end]))
#         ptr_source = d_arr.colptr[col_index]
#         end_source = d_arr.colptr[col_index + 1]
#         while ptr_sink < end_sink && ptr_source < end_source
#             if arr.rowval[ptr_sink] < d_arr.rowval[ptr_source]
#                 #println("sink up")
#                 ptr_sink += 1
#                 continue
#             end
#             if arr.rowval[ptr_sink] > d_arr.rowval[ptr_source]
#                 #println("source up")
#                 ptr_source += 1
#                 continue
#             end
#             arr.nzval[ptr_sink] -= d_arr.nzval[ptr_source] * opt.eta
#             ptr_sink += 1
#             ptr_source += 1
#         end
#     end
# end

function update_!(opt, arr::AbstractArray, d_arr::AbstractArray)
    #print("abs_arr_update ")
    Optimise.apply!(opt, arr, d_arr)
    arr .-= d_arr
    return arr
end

function update_!(opt, model::Chain, grads) 
    for (layer, d_layer) in zip(model.layers, grads.layers)
        d_layer = getfield(d_layer, 1)
        @assert nfields(layer) == nfields(d_layer)
        for field_index in 1:nfields(layer)
            field = getfield(layer, field_index)
            d_field = getfield(d_layer, field_index)
            update_!(opt, field, d_field)
        end
    end
#     #print("down ")
#     @assert nfields(model) == nfields(grads)
#             ["nfields(model) $(nfields(model)) ≠ nfields(grads) $(nfields(grads))"]
#     for field_idx in 1:nfields(model)
#         field = getfield(model, field_idx)
#         d_field = getfield(grads, field_idx)
#         update_!(opt, field, d_field)
#     end
    return model
end

update_! (generic function with 3 methods)

In [9]:
augment!(x, y) = nothing

function augment!(layer::SprAffine, frac::AbstractFloat)
    # redraw fraction of the layer's Weights
    I, J, V = findnz(layer.W)
    #println(length(V))
    len_augmented = floor(Int, frac * length(V))
    len_preserved = length(V) - len_augmented
    #println(len_augmented, " ", len_preserved)
    elems = Set(zip(I, J))
    indices = sortperm(abs.(V), rev=true)[1:len_preserved]
    I = I[indices]
    J = J[indices]
    V = V[indices]
    I_ = similar(I, len_augmented)
    J_ = similar(J, len_augmented)
    V_ = similar(V, len_augmented)
    index = 1
    while index <= len_augmented
        i = sample(1:size(layer.W, 1))
        j = sample(1:size(layer.W, 2))
        if (i,j) in elems
            continue
        end
        push!(elems, (i, j))
        I_[index] = i
        J_[index] = j
        V_[index] = randn()
        index += 1
    end
    append!(I, I_)
    append!(J, J_)
    #println(length(V))
    append!(V, V_)
#     for index in 1:length(layer.W.nzval)
#         layer.W.nzval[index] = 0.
#     end
#     droptol!(layer.W, Inf, trim=true)
    layer.W = sparse(I, J, V, size(layer.W)...)
    #println(length(V))
    #println(length(layer.W.nzval))
    #println(" ---")
    return layer
end

augment! (generic function with 2 methods)

In [10]:
function train_augmenting!(opt, loss, model, (X, y), (X_val, y_val), batch_size, num_epochs,
                           augmentation_period, augmentation_fraction)
    X_val_sp = sparse(X_val)
    for epoch in 1:num_epochs
        println("Epoch: $epoch")
        perm = shuffle(1:size(X, 2))
        X_sp = sparse(X[:, perm])
        y_ = y[:, perm]
        batch_indices = 1:batch_size:(size(X, 2) - batch_size + 1)
        for (index, batch_index) in enumerate(batch_indices)
            #print("Batch $index/$(length(batch_indices)) ")
            loss_local, back = pullback(model) do model
                loss(model,
                     X_sp[:, 1:batch_size],
                     y_[:, 1:batch_size])
            end;
            #print(round(loss_local, digits=4), " ")
            grads = back(1)[1];
            #print("g ")
            update_!(opt, model, grads)
            acc = accuracy(model, X_val_sp, y_val)
            print(round(acc, digits=4), " ")
            #print("u ")
            #println(length(model[1].W.nzval) / length(model[1].W), " ")
            is_very_last_batch = (epoch == num_epochs) && (batch_index > size(X, 2) - 2batch_size)
            if (index % augmentation_period == 0) && !is_very_last_batch
                for layer in model.layers
                    augment!(layer, augmentation_fraction)
                end
                print("\n")
                #println("Augmented $augmentation_fraction of weights.")
            end
        end
    end
end

train_augmenting! (generic function with 1 method)

In [41]:
model = get_model(size(X, 1), 100, size(y, 1), 0.1)
train_augmenting!(Optimise.Descent(0.1), loss, model, (X, y), (X_val, y_val),
                  64, 1, 6, 0.1)

Epoch: 1
0.1004 0.1088 0.1181 0.1255 0.1313 0.138 
0.1403 0.1461 0.153 0.1597 0.1664 0.1722 
0.1687 0.1731 0.1772 0.1807 0.1852 0.1898 
0.1732 0.1793 0.1849 0.1897 0.194 0.2014 
0.1707 0.1758 0.1851 0.1932 0.1973 0.2039 
0.2186 0.2231 0.2284 0.2307 0.2348 0.2361 
0.1933 0.206 0.2135 0.2192 0.2243 0.2292 
0.2372 0.2485 0.2538 0.2573 0.2608 0.2643 
0.2488 0.2522 0.2544 0.2559 0.2569 0.2593 
0.2359 0.2499 0.2577 0.2647 0.269 0.2714 
0.2832 0.2868 0.2888 0.2897 0.2917 0.2925 
0.2813 0.2862 0.2862 0.2858 0.2872 0.2876 
0.2887 0.2999 0.3052 0.3098 0.313 0.3142 
0.2486 0.2521 0.2552 0.2584 0.2607 0.2638 
0.2288 0.2368 0.2407 0.2471 0.2517 0.2536 
0.2612 0.2669 0.2699 0.2721 0.2751 0.277 
0.2698 0.2706 0.2745 0.2786 0.2797 0.2821 
0.2885 0.2933 0.2988 0.3029 0.3052 0.3083 
0.2921 0.3018 0.3065 0.3117 0.3164 0.3181 
0.2741 0.279 0.2832 0.2875 0.2921 0.2955 
0.2818 0.2848 0.2903 0.2938 0.2981 0.2997 
0.2893 0.2982 0.305 0.3094 0.3132 0.3168 
0.3303 0.3324 0.3352 0.3358 0.337 0.3402 
0.3137 0.320

InterruptException: InterruptException:

In [176]:
model = get_model(size(X, 1), 100, size(y, 1), 0.1)
train_augmenting!(Optimise.Descent(0.1), loss, model, (X, y), (X_val, y_val),
                  64, 1, 16, 0.5)

Epoch: 1
0.1572 0.1625 0.1689 0.1762 0.18 0.1831 0.1843 0.1878 0.1905 0.1938 0.1971 0.2008 0.2052 0.2088 0.2135 0.2158 
0.1507 0.1572 0.164 0.1698 0.1743 0.1794 0.1835 0.1883 0.1925 0.1964 0.2006 0.2042 0.2078 0.2106 0.2133 0.2178 
0.2382 0.2448 0.2452 0.2502 0.2525 0.2568 0.2615 0.2672 0.2737 0.2792 0.2855 0.292 0.2966 0.3006 0.3056 0.3098 
0.2362 0.2418 0.2439 0.2507 0.2576 0.2638 0.2713 0.2782 0.2844 0.2917 0.2967 0.301 0.3053 0.3098 0.3136 0.3159 
0.2099 0.221 0.2312 0.2418 0.2512 0.2583 0.2638 0.2696 0.2751 0.2777 0.281 0.2852 0.2889 0.2902 0.2921 0.2949 
0.2578 0.2673 0.2733 0.2805 0.2864 0.2945 0.3026 0.3104 0.3167 0.323 0.3282 0.3312 0.3362 0.3398 0.344 0.3452 
0.2419 0.2656 0.2853 0.2977 0.3051 0.3093 0.3132 0.3177 0.3219 0.3268 0.3312 0.3356 0.3392 0.3442 0.3481 0.3509 
0.2558 0.2774 0.2995 0.3192 0.3264 0.3355 0.3398 0.3445 0.3485 0.3502 0.3529 0.3548 0.3562 0.3581 0.3597 0.3621 
0.1605 0.1744 0.1862 0.1984 0.2089 0.2185 0.2252 0.2314 0.2358 0.2408 0.2456 0.2498 0.2536 0.256

InterruptException: InterruptException:

In [None]:
train_augmenting!(Optimise.Descent(0.2), loss, model, (X, y), (X_val, y_val),
                  64, 1, 5, 0.2)

In [158]:
accuracy(model, sparse(X_test), y_test)

0.625

In [136]:
Array(model(sparse(X)))

10×60000 Array{Float32,2}:
 -15.2978    39.3126     0.471826  …   11.5914     27.8094    16.3412  
 -14.4236   -31.6884   -19.7033       -24.9329    -18.037    -15.8129  
   1.94016  -64.7316     8.30274      -31.5712    -10.1424   -36.3982  
  21.5777    21.9106    12.6699       -16.7958      9.84951    0.533863
 -49.4807   -50.8089   -13.4105       -28.6316    -32.741    -12.1242  
  11.3429     1.226     -2.5209    …    9.23097    29.3412    -6.88847 
 -12.4105   -22.5244    20.2417        -0.273212    5.67532    7.27039 
  -2.42519   28.6859    -4.85765       -6.10574     7.92254   29.8143  
   3.65144   -6.33468    8.74955       15.3897    -11.0637    -2.80236 
 -14.0099     7.8312     0.470142      15.9757     11.3657    12.9205  

In [133]:
onecold(y[:, 1:10])

10-element Array{Int64,1}:
  6
  1
  5
  2
 10
  3
  2
  4
  2
  5

In [44]:
# model = get_model()
# train_augmenting!(Optimise.Descent(0.3), loss, model, (X, y), 128, 10, 10, 0.3)

Batch 1/468 Loss: 12.903942
Batch 2/468 Loss: 7.998606
Batch 3/468 Loss: 5.978869
Batch 4/468 Loss: 4.7153387
Batch 5/468 Loss: 3.8186448
Batch 6/468 Loss: 3.1671693
Batch 7/468 Loss: 2.6950784
Batch 8/468 Loss: 2.348345
Batch 9/468 Loss: 2.077857
Batch 10/468 Loss: 1.8604236
Augmented 0.3 of weights.
Batch 11/468 Loss: 6.780545
Batch 12/468 Loss: 4.6947923
Batch 13/468 Loss: 3.576496
Batch 14/468 Loss: 3.0275183
Batch 15/468 Loss: 2.3009427
Batch 16/468 Loss: 1.9193074
Batch 17/468 Loss: 1.6593554
Batch 18/468 Loss: 1.4660087
Batch 19/468 Loss: 1.3267179
Batch 20/468 Loss: 1.2275671
Augmented 0.3 of weights.
Batch 21/468 Loss: 6.5453057
Batch 22/468 Loss: 4.088305
Batch 23/468 Loss: 2.9599335
Batch 24/468 Loss: 1.8741505
Batch 25/468 Loss: 1.1792691
Batch 26/468 Loss: 0.8042897
Batch 27/468 Loss: 0.651187
Batch 28/468 Loss: 0.50910723
Batch 29/468 Loss: 0.408691
Batch 30/468 Loss: 0.34578425
Augmented 0.3 of weights.
Batch 31/468 Loss: 8.25316
Batch 32/468 Loss: 3.3610651
Batch 33/468

InterruptException: InterruptException: