In [1]:
using BSON: @load
using Flux
using Flux: chunk
using Flux.Data: DataLoader
using ImageFiltering
using Images
using ImageIO
using MLDatasets: FashionMNIST
using LinearAlgebra
using MLDatasets
using Plots
using Zygote

In [2]:
function PLUGIn_CS(G, y, A, max_iter, stepsize, tolerance, out_toggle)
    
    (_, z_dim) = size(Flux.params(G[1])[1]);
    W = I(z_dim)
  
    #normalize the weights of the network
    for i in 1:length(G)
        _, s, _ = svd(Flux.params(G[i])[1])
        W = Flux.params(G[i])[1] * W /s[1]
    end
  
    z = randn(z_dim)
    iter = 1
    succ_error = 1
  
    while iter <= max_iter && succ_error > tolerance
      
      # d gives the PLUGIn direction
      d = W'*A'*(A * G(z) - y)
      z -= stepsize * d
      succ_error = norm(stepsize * d)
      if iter % out_toggle == 0  
          println("====> In quasi-gradient: Iteration: $iter Successive error: $succ_error")
      end
      iter += 1
    end
    println("====> In quasi-gradient: Iteration: $iter Successive error: $succ_error")
  
    return z
  end

PLUGIn_CS (generic function with 1 method)

In [3]:
function GD_CS(G, y, A, max_iter, stepsize, tolerance, out_toggle)

    z = randn(20)
    iter = 1
    succ_error = 1
    while iter <= max_iter && succ_error > tolerance
        # d gives the PLUGIn direction
        d = gradient(z -> norm(y - A*G(z)), z)[1]
        z -= stepsize * d
        succ_error = norm(stepsize * d)
        if iter % out_toggle == 0  
            println("====> In quasi-gradient: Iteration: $iter Successive error: $succ_error")
        end
        iter += 1
    end
    println("====> In quasi-gradient: Iteration: $iter Successive error: $succ_error")

    return z
end

GD_CS (generic function with 1 method)

In [4]:
#setup a synthetic problem
G = Chain(
    Dense(20, 500, relu, bias = false; initW =(out,in) ->  randn(500, 20)/sqrt(500)),
    Dense(500, 500, relu, bias = false; initW =(out,in) -> randn(500, 500)/sqrt(500)),
    Dense(500, 784, relu, bias = false; initW =(out,in) -> randn(784, 500)/sqrt(784))
)


z = randn(20)
m = 300; A = randn(m, 784)/sqrt(m)
y = A*G(z) + 1e-14 * randn(m)

stepsize = 2
tolerance = 1e-14
max_iter = 10000
out_toggle = 1000
z_rec = PLUGIn_CS(G,y,A, max_iter, stepsize, tolerance, out_toggle)
recov_error = norm(z - z_rec)
recon_error = norm(G(z) - G(z_rec))
println("recovery error: $recov_error, reconstruction error: $recon_error")

====> In quasi-gradient: Iteration: 1000 Successive error: 1.1941512463987172e-6
====> In quasi-gradient: Iteration: 2000 Successive error: 3.1535913288548917e-10
====> In quasi-gradient: Iteration: 3000 Successive error: 7.572969627436639e-14
====> In quasi-gradient: Iteration: 3270 Successive error: 9.963993939425636e-15
recovery error: 7.194882961768809e-13, reconstruction error: 2.0801248898565373e-13


In [5]:
# experiments with MNIST dataset
function load_model(load_dir::String, epoch::Int)
    print("Loading model...")
    @load joinpath(load_dir, "model-$epoch.bson") encoder_μ encoder_logvar decoder
    println("Done")
    return encoder_μ, encoder_logvar, decoder
end

load_model (generic function with 1 method)

In [6]:
function get_train_loader(batch_size, shuffle::Bool)
    # The MNIST training set is made up of 60k 28 by 28 greyscale images
    train_x, train_y = MNIST.traindata(Float32)
    train_x = 1 .- reshape(train_x, (784, :))
    return DataLoader((train_x, train_y), batchsize=batch_size, shuffle=shuffle, partial=false)
end

get_train_loader (generic function with 1 method)

In [7]:
epoch_to_load = 20
# Load the model and test set loader
#encoder_mu, encoder_logvar, decoder = load_model("result", epoch_to_load)
encoder_mu, encoder_logvar, decoder = load_model("result/MNIST", epoch_to_load)

Loading model...Done


(Chain(Chain(Dense(784, 500, relu), Dense(500, 500, relu)), Dense(500, 20)), Chain(Chain(Dense(784, 500, relu), Dense(500, 500, relu)), Dense(500, 20)), Chain(Dense(20, 500, relu; bias=false), Dense(500, 500, relu; bias=false), Dense(500, 784, σ; bias=false)))

In [9]:
# pick a image in MNIST to denoise
num = 8
batch_size = 64
shuffle_data = true
dataloader = get_train_loader(batch_size, shuffle_data)

stdin>  y


This program has requested access to the data dependency MNIST.
which is not currently installed. It can be installed automatically, and you will not see this message again.

Dataset: THE MNIST DATABASE of handwritten digits
Authors: Yann LeCun, Corinna Cortes, Christopher J.C. Burges
Website: http://yann.lecun.com/exdb/mnist/

[LeCun et al., 1998a]
    Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner.
    "Gradient-based learning applied to document recognition."
    Proceedings of the IEEE, 86(11):2278-2324, November 1998

The files are available for download at the offical
website linked above. Note that using the data
responsibly and respecting copyright remains your
responsibility. The authors of MNIST aren't really
explicit about any terms of use, so please read the
website to make sure you want to download the
dataset.



Do you want to download the dataset from ["https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", "https://ossci-datasets.s3.amazonaws.com/mn

DataLoader{Tuple{Matrix{Float32}, Vector{Int64}}, Random._GLOBAL_RNG}((Float32[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; … ; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0], [5, 0, 4, 1, 9, 2, 1, 3, 1, 4  …  9, 2, 9, 5, 1, 8, 3, 5, 6, 8]), 64, 60000, false, 59937, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  59991, 59992, 59993, 59994, 59995, 59996, 59997, 59998, 59999, 60000], true, Random._GLOBAL_RNG())

In [None]:
(x_batch, y_batch) = first(dataloader)
i = 1
while y_batch[i] != num
    i += 1
end

x = x_batch[:,i]
noise_level = .1

m = 300; A = randn(m, 784)/sqrt(m)

y = A*x + noise_level * randn(m)

stepsize = 1
tolerance = 1e-7
max_iter = 10000
out_toggle = 1000
z_rec_PLUGIn = PLUGIn_CS(decoder, y, A, max_iter, stepsize, tolerance, out_toggle)
error = norm(x - decoder(z_rec_PLUGIn))
println("reconstruction error: $error")

z_rec_GD = GD_CS(decoder, y, A, max_iter, stepsize, tolerance, out_toggle)

recovered_image_PLUGIn = colorview(Gray, reshape(decoder(z_rec_PLUGIn), 28,28)' )
recovered_image_GD = colorview(Gray, reshape(decoder(z_rec_GD), 28,28)' )

true_image = colorview(Gray, reshape(x, 28,28)' );
p1 = plot(true_image, framestyle = :none, bg =:black, title = "original image")
p2 = plot(recovered_image_PLUGIn, framestyle = :none, bg =:black, title = "recovered image PLUGIn,\n m = $m")
p3 = plot(recovered_image_GD, framestyle = :none, bg =:black, title = "recovered image GD,\n m = $m")
plot(p1, p2, p3, layout = 3)

====> In quasi-gradient: Iteration: 1000 Successive error: 0.053126221360445224
====> In quasi-gradient: Iteration: 2000 Successive error: 0.05623084468364893
====> In quasi-gradient: Iteration: 3000 Successive error: 0.05459704474526513
====> In quasi-gradient: Iteration: 4000 Successive error: 0.055005210686865535
====> In quasi-gradient: Iteration: 5000 Successive error: 0.05638883476486583
====> In quasi-gradient: Iteration: 6000 Successive error: 0.056799390446632396
====> In quasi-gradient: Iteration: 7000 Successive error: 0.05704838245217213
====> In quasi-gradient: Iteration: 8000 Successive error: 0.0572244139312179
====> In quasi-gradient: Iteration: 9000 Successive error: 0.05735688982888102
====> In quasi-gradient: Iteration: 10000 Successive error: 0.057462034333574497
====> In quasi-gradient: Iteration: 10001 Successive error: 0.057462034333574497
reconstruction error: 6.0879239926856465
====> In quasi-gradient: Iteration: 1000 Successive error: 0.9167083536621017
====> 