In [1]:
using BSON: @load
using Flux
using Flux: chunk
using Flux.Data: DataLoader
using ImageFiltering
using Images
using ImageIO
using MLDatasets: FashionMNIST
using LinearAlgebra
using MLDatasets
using Plots
using Zygote
using FFTW
using Distributions
using SparseArrays

Consider the denoising problem of recovering $x\in\mathbb{R}^n$ from noisy measurements of the form

$$y = A(x) + \epsilon, $$

where $\epsilon\in\mathbb{R}^n$ is noise. We assume the unknown signal $x$ lives in the range of known generative model $G:\mathbb{R}^k \rightarrow \mathbb{R}^n$, i.e. $x = G(z)$ for some $z \in \mathbb{R}^k$. We assume the generative model $G$ is  fully-connected feedforward network of the form 

$$ G(x) = \sigma_d(A_d\sigma_{d-1}(A_{d-1} \cdots \sigma_1(A_1 z)\cdots)),$$

where $A_i \in \mathbb{R}^{n_i \times n_{i-1}}$ is the weight matrix and $\sigma_i$ is the activation function correpsonding to the $i\text{th}$ layer of $G$. Thus, the task of recovering $x$ can be reduced to recovering the corresponding $z$ such that $G(z) = x$. 


We solve this problem using the following iterative algorithm called Partially Linearized Updates for Generative Inversion (PLUGIn):

$$x^{k+1} = x^k -\eta A_1^{\top}\cdots A_d^{\top}A^\top\left(AG(x^k) - y \right) .$$

Here, $\eta$ is the stepsize that depends on the weight matrices and the activation functions. This algorithm is implemented below:



In [6]:
# output the plugin Iterate
function PLUGIN_denoise(G, W, A, y, z, stepsize)
    d = W'* A'*(A*G(z) - y )
    return z - stepsize * d
end

# output the plugin Iterate
function PLUGIN_denoise_regularized(G, W, A, y, z, stepsize, λ)
    d = W'* A'*(A*G(z) - y ) + 2*λ*z  
    return z - stepsize * d
    println("1")
end

# output the plugin Iterate
function PLUGIN_denoise_normalized(G, W, A, y, z, stepsize, scale)
    d = W'* A'*(A*G(z) - y)
    return (z - stepsize * d) * scale / norm(z - stepsize * d,2)
end

# output the plugin Iterate
function proj(z, scale)
    if norm(z) > scale
        z = scale * z/norm(z)
    end
    return z
end


function relative_error(x_true, x_estimate)
    return norm(x_true - x_estimate)/norm(x_true)
end

relative_error (generic function with 1 method)

One advantage of PLUGIn algorithm in general is that it allows us to pre-multiply the weight matrices. The following function is used to compute $A_d\cdot A_{d-1}\cdots A_1$ given a generative model $G$. It also normalizes each weight matrix w.r.t. its top singular value.

In [7]:
function normalized_weight_product(G)
    (_, z_dim) = size(Flux.params(G[1])[1]);
    W = I(z_dim)
    for i in 1:length(G)
        _, s, _ = svd(Flux.params(G[i])[1])
        W = Flux.params(G[i])[1] * W /(s[1]^2)
    end
    return W
end  

normalized_weight_product (generic function with 1 method)

In [8]:
function create_network(net_param)
    n_0 = net_param[1]
    n_1 = net_param[2]
    L = Chain(Dense(n_0, n_1, relu; initW =(out,in) ->  randn(n_1, n_0)/sqrt(n_1)))

    for i in 2:length(net_param)-1
        n_0 = net_param[i]
        n_1 = net_param[i+1]
        L = Chain(L, Dense(n_0, n_1, relu; initW =(out,in) ->  randn(n_1, n_0)/sqrt(n_1)))
    end
    return L
end

create_network (generic function with 1 method)

In [9]:
#########  SNR  #################
max_iter = 1000
stepsize_PLUGIn = .5
stepsize_GD = .5
noise_level_list = 0:.1:1
trials = 10

m = 500

recov_error_matrix_PLUGIn = zeros(length(noise_level_list))
recon_error_matrix_PLUGIn = zeros(length(noise_level_list))

recov_error_matrix_GD = zeros(length(noise_level_list))
recon_error_matrix_GD = zeros(length(noise_level_list))

recov_error_matrix_PLUGIn_proj = zeros(length(noise_level_list))
recon_error_matrix_PLUGIn_proj = zeros(length(noise_level_list))

recov_error_matrix_GD_proj = zeros(length(noise_level_list))
recon_error_matrix_GD_proj = zeros(length(noise_level_list))

recov_error_matrix_PLUGIn_reg = zeros(length(noise_level_list))
recon_error_matrix_PLUGIn_reg = zeros(length(noise_level_list))

recov_error_matrix_GD_reg = zeros(length(noise_level_list))
recon_error_matrix_GD_reg = zeros(length(noise_level_list))

for trial in 1:trials

    net_param = [20, 200, 750]
    G = create_network(net_param)
    W = normalized_weight_product(G)

    depth = length(Flux.params(G))/2

    (_, z_dim) = size(Flux.params(G)[1]);
    (x_dim, _ ) = size(Flux.params(G)[Int(depth*2 - 1)])
    W = normalized_weight_product(G)

    A = randn(m, 750)/sqrt(m)


    recov_error_PLUGIn = []
    recon_error_PLUGIn = []

    recov_error_GD = []
    recon_error_GD = []

    recov_error_PLUGIn_proj = []
    recon_error_PLUGIn_proj = []

    recov_error_GD_proj = []
    recon_error_GD_proj = []

    recov_error_PLUGIn_reg = []
    recon_error_PLUGIn_reg = []

    recov_error_GD_reg = []
    recon_error_GD_reg = []

    z = randn(z_dim)
    z = z/norm(z) #unit sphere
    
    for noise_level in noise_level_list
        y = A*G(z) + noise_level * randn(x_dim)
        descent(z) = gradient(z->norm(A*G(z) - y,2)^2, z)[1]   
        descent_regularized(z, λ) = gradient(z->norm(A*G(z) - y,2)^2 +λ*norm(z,2)^2, z)[1]        

        z_est_PLUGIn = randn(z_dim)
        z_est_GD = randn(z_dim)

        z_est_PLUGIn_proj = randn(z_dim)
        z_est_GD_proj = randn(z_dim)

        z_est_PLUGIn_reg = randn(z_dim)
        z_est_GD_reg = randn(z_dim)

        iter = 1
        for iter in 1:max_iter
            z_est_PLUGIn = PLUGIN_denoise(G, W, A, y, z_est_PLUGIn, stepsize_PLUGIn)
            d = descent(z_est_GD); z_est_GD -= stepsize_GD *d

            z_est_PLUGIn_proj = PLUGIN_denoise(G, W, A, y, z_est_PLUGIn_proj, stepsize_PLUGIn)
            z_est_PLUGIn_proj = proj(z_est_PLUGIn_proj, 2 * sqrt(20))

            d = descent(z_est_GD_proj); z_est_GD_proj -= stepsize_GD *d; z_est_GD_proj = proj(z_est_GD_proj, 2 * sqrt(20))

            z_est_PLUGIn_reg = PLUGIN_denoise_regularized(G, W, A, y, z_est_PLUGIn_reg, stepsize_PLUGIn, 0.05)

            d = descent_regularized(z_est_GD_reg, 0.05); z_est_GD_reg -= stepsize_GD *d

        end

        push!(recov_error_PLUGIn, relative_error(z, z_est_PLUGIn))
        push!(recon_error_PLUGIn,  relative_error(G(z), G(z_est_PLUGIn)) )

        push!(recov_error_GD, relative_error(z, z_est_GD))
        push!(recon_error_GD, relative_error(G(z), G(z_est_GD)) )  

        push!(recov_error_PLUGIn_proj, relative_error(z, z_est_PLUGIn_proj))
        push!(recon_error_PLUGIn_proj,  relative_error(G(z), G(z_est_PLUGIn_proj)) )

        push!(recov_error_GD_proj, relative_error(z, z_est_GD_proj))
        push!(recon_error_GD_proj, relative_error(G(z), G(z_est_GD_proj)) )  

        push!(recov_error_PLUGIn_reg, relative_error(z, z_est_PLUGIn_reg))
        push!(recon_error_PLUGIn_reg,  relative_error(G(z), G(z_est_PLUGIn_reg)) )

        push!(recov_error_GD_reg, relative_error(z, z_est_GD_reg))
        push!(recon_error_GD_reg, relative_error(G(z), G(z_est_GD_reg)) )  
    end

    recov_error_matrix_PLUGIn = hcat(recov_error_matrix_PLUGIn, recov_error_PLUGIn)
    recon_error_matrix_PLUGIn =  hcat(recon_error_matrix_PLUGIn, recon_error_PLUGIn)

    recov_error_matrix_GD = hcat(recov_error_matrix_GD, recov_error_GD)
    recon_error_matrix_GD = hcat(recon_error_matrix_GD, recon_error_GD)

    recov_error_matrix_PLUGIn_proj = hcat(recov_error_matrix_PLUGIn_proj, recov_error_PLUGIn_proj)
    recon_error_matrix_PLUGIn_proj =  hcat(recon_error_matrix_PLUGIn_proj, recon_error_PLUGIn_proj)

    recov_error_matrix_GD_proj = hcat(recov_error_matrix_GD_proj, recov_error_GD_proj)
    recon_error_matrix_GD_proj = hcat(recon_error_matrix_GD_proj, recon_error_GD_proj)

    recov_error_matrix_PLUGIn_reg = hcat(recov_error_matrix_PLUGIn_reg, recov_error_PLUGIn_reg)
    recon_error_matrix_PLUGIn_reg =  hcat(recon_error_matrix_PLUGIn_reg, recon_error_PLUGIn_reg)

    recov_error_matrix_GD_reg = hcat(recov_error_matrix_GD_reg, recov_error_GD_reg)
    recon_error_matrix_GD_reg = hcat(recon_error_matrix_GD_reg, recon_error_GD_reg)

end


save("result/talk/CS_error_denoise_recov.jld", "error_GD", recov_error_matrix_GD, "error_PLUGIn", recov_error_matrix_PLUGIn, "error_GD_proj", recov_error_matrix_GD_proj, "error_PLUGIn_proj", recov_error_matrix_PLUGIn_proj, "error_GD_reg", recov_error_matrix_GD_reg, "error_PLUGIn_reg", recov_error_matrix_PLUGIn_reg)
save("result/talk/CS_error_denoise_recon.jld", "error_GD", recon_error_matrix_GD, "error_PLUGIn", recon_error_matrix_PLUGIn, "error_GD_proj", recon_error_matrix_GD_proj, "error_PLUGIn_proj", recon_error_matrix_PLUGIn_proj, "error_GD_reg", recon_error_matrix_GD_reg, "error_PLUGIn_reg", recon_error_matrix_PLUGIn_reg)


DimensionMismatch: DimensionMismatch("dimensions must match: a has dims (Base.OneTo(500),), b has dims (Base.OneTo(750),), mismatch at 1")