# Compressive sensing example


In [2]:
from __future__ import print_function
import matplotlib.pyplot as plt
#%matplotlib notebook

import os
import sigpy.mri as mr

import sigpy as sp
import sigpy.mri as mr
from os import listdir
from os.path import isfile, join

import warnings
warnings.filterwarnings('ignore')

from include import *
from PIL import Image
import PIL

import numpy as np
import torch
import torch.optim
from torch.autograd import Variable
#from models import *
#from utils.denoising_utils import *

# Request GPU execution. When no CUDA device is usable the notebook falls back
# to CPU tensors instead of failing on the first `.type(dtype)` call.
GPU = True
if GPU:
    # Restrict the visible devices before the first CUDA query so that the
    # restriction is taken into account.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    print("num GPUs",torch.cuda.device_count())

if GPU and torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    dtype = torch.cuda.FloatTensor
else:
    # GPU disabled or no CUDA device present: keep all tensors on the CPU.
    dtype = torch.FloatTensor

num GPUs 0


## Load example image

In [None]:
def crop_center(img,cropx,cropy):
    """Cut a centered window of `cropy` rows by `cropx` columns out of `img`.

    The last two axes of `img` are treated as (rows, columns).

    Returns:
        For a 2-D array, the centered window. For a 3-D array, the centered
        window of `img[0]` only (the remaining leading entries are dropped).
        For any other number of dimensions, None.
    """
    rows, cols = img.shape[-2], img.shape[-1]
    top = rows // 2 - (cropy // 2)
    left = cols // 2 - (cropx // 2)
    window = (slice(top, top + cropy), slice(left, left + cropx))

    ndim = len(img.shape)
    if ndim == 2:
        return img[window]
    if ndim == 3:
        return img[(0,) + window]
    return None

# Directory holding the test images and the name (without extension) of the
# image to load.
path = './test_data/'
img_name = "poster"
# Alternative test images:
#img_name = "F16_GT"
#img_name = "sf4_rgb"
#img_name  = 'library'
img_path = path + img_name + ".png"

img_pil = Image.open(img_path)
# pil_to_np is provided by `include`; presumably it returns a channel-first
# array -- TODO confirm.
img_np = pil_to_np(img_pil)

# Take img_np[0] (presumably the first channel), crop its central 128x128
# window and wrap it in a new leading axis of length 1.
img_np_small = np.array([crop_center(img_np[0],128,128)])
# np_to_var is provided by `include`; the result is cast to the selected dtype.
img_var = np_to_var(img_np_small).type(dtype)
# Size of the leading axis of the full image array (presumably its channel count).
output_depth = img_np.shape[0]

## Define forward model

In [None]:
# View the image as a single row vector; n is the number of its entries.
X = img_var.view(-1, np.prod(img_var.shape) )
n = X.shape[1]
# Number of measurements: one third of the number of unknowns (rounded down).
m = int(n/3)
# Random n-by-m measurement matrix with entries drawn uniformly from [-1, 1),
# then scaled in place by 1/sqrt(m).
A = torch.empty(n,m).uniform_(-1, 1).type(dtype)
A *= 1/np.sqrt(m)

def forwardm(img_var):
    """Apply the linear measurement operator to an image tensor.

    The tensor is viewed as one row vector containing all of its entries and
    is multiplied from the right by the module-level matrix `A`.

    Returns:
        A tensor with one row and as many columns as `A`.
    """
    num_entries = np.prod(img_var.shape)
    row_vector = img_var.view(-1, num_entries)
    return torch.mm(row_vector, A)

# Compressive measurement of the example image: a row vector with m entries.
measurement = forwardm(img_var)

## DD reconstruction and helper functions

In [None]:
def get_net_input(num_channels,w=128,h=128):
    """Create the random input tensor that is fed to the decoder network.

    Args:
        num_channels: list of channel counts. Its length determines the total
            upsampling factor ``2**len(num_channels)`` (presumably one 2x
            upsampling stage per entry) and its first entry the number of
            channels of the input tensor.
        w: width of the image the network should produce (default 128).
        h: height of the image the network should produce (default 128).

    Returns:
        A tensor of shape ``[1, num_channels[0], h/f, w/f]`` (f being the
        total upsampling factor, sizes truncated to integers) with entries
        drawn uniformly from [0, 0.1), cast to the module-level `dtype`.
    """
    totalupsample = 2**len(num_channels)
    # Honour the requested size; it used to be hard-coded to 128 x 128.
    width = int(w/totalupsample)
    height = int(h/totalupsample)
    # (N, C, H, W) layout: rows come before columns, as in crop_center.
    shape = [1, num_channels[0], height, width]
    net_input = Variable(torch.zeros(shape)).type(dtype)
    net_input.data.uniform_()
    net_input.data *= 1./10
    return net_input

def get_random_img(num_channels,ni=None):
    """Return the output of a freshly constructed decoder network.

    Args:
        num_channels: channel counts handed to `decodernw` as
            `num_channels_up`.
        ni: network input to evaluate on; when None a new one is drawn with
            `get_net_input(num_channels)`.

    Returns:
        The network output for the chosen input. The channel list and the
        value of `num_param` for the network are printed as a side effect.
    """
    net_input = get_net_input(num_channels) if ni is None else ni
    decoder = decodernw(1,num_channels_up=num_channels,need_sigmoid=True).type(dtype)
    print("generated random image with", num_channels, " network has ", num_param(decoder) )
    return decoder(net_input)

def myimgshow(plt,img):
    """Show a channel-first image on `plt` with values clipped to [0, 1].

    An image whose leading axis has length 1 is drawn in grayscale from
    `img[0]`; any other image is transposed to channel-last before drawing.
    The axes are switched off afterwards.
    """
    if img.shape[0] != 1:
        channel_last = img.transpose(1, 2, 0)
        plt.imshow(np.clip(channel_last, 0, 1))
    else:
        plt.imshow(np.clip(img[0], 0, 1), cmap='gray')
    plt.axis('off')
    
def plot_img(img_ref):
    """Draw `img_ref` in grayscale, without axes, on a new figure.

    The figure is created with figsize (15, 15) and the image is placed in
    the first cell of a 2 x 3 subplot grid.
    """
    figure = plt.figure(figsize=(15, 15))
    axis = figure.add_subplot(2, 3, 1)
    axis.imshow(img_ref, cmap='gray')
    axis.axis('off')
    
def init_weights(net):
    """Re-draw the weights of every 2-D convolution in `net` from a normal distribution.

    Only the `weight` tensors of `nn.Conv2d` modules are modified (in place,
    using the defaults of `torch.nn.init.normal_`); biases and all other
    module types are left untouched.
    """
    conv_layers = [layer for layer in net.modules() if isinstance(layer, nn.Conv2d)]
    for layer in conv_layers:
        torch.nn.init.normal_(layer.weight)

def snr(x_hat,x_true):
    """Return the relative squared error of `x_hat` with respect to `x_true`.

    Despite its name this is not a signal-to-noise ratio: the value is
    ``sum((x_hat - x_true)**2) / sum(x_true**2)``, so smaller is better and
    0 means an exact match. Both inputs are flattened first.
    """
    estimate = x_hat.flatten()
    reference = x_true.flatten()
    residual_energy = np.sum(np.square(estimate - reference))
    reference_energy = np.sum(np.square(reference))
    return residual_energy / reference_energy

In [None]:
def dd_recovery(measurement,img_var,num_channnels,num_iter=6000,apply_f=forwardm,ni=None):
    """Fit a deep decoder to `measurement` and return the image it produces.

    Args:
        measurement: tensor passed to `fit` as `img_noisy_var` (presumably
            the target the network output is fitted to after `apply_f`).
        img_var: reference image passed to `fit` as `img_clean_var`.
        num_channnels: channel counts of the decoder. The spelling with
            three n's is kept so that existing keyword callers keep working.
        num_iter: number of iterations passed to `fit`.
        apply_f: forward operator passed to `fit`. The default is the
            `forwardm` that existed when this function was defined.
        ni: optional network input, passed to `fit` as `net_input`.

    Returns:
        The output of the fitted network evaluated on the network input
        returned by `fit`. The value of `num_param` for the network is
        printed as a side effect.
    """
    # Use the argument that was passed in. The body used to read the
    # module-level `num_channels`, silently ignoring this parameter.
    num_channels = num_channnels
    net = decodernw(1,num_channels_up=num_channels,need_sigmoid=True).type(dtype)
    #net.apply(init_weights)
    _, _, ni, net = fit( num_channels=num_channels,
                        net_input=ni,
                        reg_noise_std=0.0,num_iter=num_iter,LR = 0.005,
                        img_noisy_var=measurement.type(dtype),
                        net=net,apply_f = apply_f,img_clean_var=img_var.type(dtype),
                        upsample_mode='bilinear',
                        )
    print(num_param(net))
    out_img_var = net( ni.type(dtype) )
    return out_img_var

## Example reconstruction

This demonstrates that reconstruction with a deep decoder works well, but a deconvolutional decoder does not enable good reconstructions.

In [None]:
# Decoder configuration: a list of four entries, each equal to k.
k=22
num_channels = [k]*4
# Measure the example image with the current forward operator.
measurement = forwardm(img_var).type(dtype)
# Reconstruct from the measurement with the default iteration count and
# forward operator of dd_recovery; img_var serves as the clean reference.
out_img_var = dd_recovery(measurement,img_var,num_channels)

In [None]:
def dconv_recovery(img_var):
    """Reconstruct `img_var` from its measurement with a deconvolutional decoder.

    The image is measured with the current module-level `forwardm`, a
    `deconv_decoder` with six entries of 6 channels is fitted to that
    measurement via `fit`, and the fitted network is evaluated on the
    network input returned by `fit`.

    Returns:
        The output of the fitted network. The value of `num_param` for the
        network is printed as a side effect.
    """
    channels = [6]*6
    target = forwardm(img_var).type(dtype)
    decoder = deconv_decoder(1,num_channels_up=channels,filter_size=4,stride=2,padding=1).type(dtype)
    _, _, net_input, decoder = fit(
        num_channels=channels,
        reg_noise_std=0.0,
        num_iter=5000,
        LR=0.0025,
        img_noisy_var=target,
        net=decoder,
        apply_f=forwardm,
        img_clean_var=img_var.type(dtype),
        upsample_mode='deconv',
    )
    print(num_param(decoder))
    return decoder(net_input.type(dtype))

# Run the deconvolutional-decoder baseline on the example image.
out_img_dc_var = dconv_recovery(img_var)

In [None]:
# Show the original image; uncomment the lines below to also show the
# deep-decoder and deconvolutional-decoder reconstructions.
plot_img(img_var.data.cpu().numpy()[0,0])
#plot_img(out_img_var.data.cpu().numpy()[0,0])
#plot_img(out_img_dc_var.data.cpu().numpy()[0,0])

def savefig(filename,img):
    """Draw `img` in grayscale without axes and write the current figure to `filename`.

    All calls go through the pyplot state machine, so the image is drawn on
    whatever figure is current at call time.
    """
    plt.imshow(img,cmap='gray')
    plt.axis('off')
    # bbox_inches='tight' trims the surrounding whitespace in the saved file.
    plt.savefig(filename,bbox_inches='tight')
    
# Write the original image to "<img_name>_orig.png" in the working directory.
savefig(img_name + '_orig.png',img_var.data.cpu().numpy()[0,0])

# Compressive sensing on random images

Our main result shows that taking a number of random linear measurements on the order of the number of parameters of the deep decoder is sufficient for recovery. In order to see whether that is also necessary, and thus whether the number of parameters captures the complexity of the range space of the deep decoder, we conduct the following experiment to recover an image in the range of the deep decoder.

In order to generate an image, we can in principle simply choose the coefficients of the deep decoder at random. However, for a deep decoder with a fixed number of parameters, this tends to generate simple images, in that often a deep decoder with far fewer coefficients can represent them well. To ensure that we generate a sufficiently complex image, we generate an image in the range of the generator by finding the best representation of noise with the deep decoder.

In [None]:
# Measurement counts to sweep: numpoints values that start at 100 and grow
# geometrically by a factor exp(5.5/numpoints) per step.
numpoints = 8
ms = [ int(100*np.exp(5.5/numpoints*i)) for i in range(numpoints) ] # from 100 up to roughly 1.2e4
print(ms)
# Decoder widths to sweep; each network below uses num_channels = [k]*4.
ks = [10,20,30,50,150,250]
# err[j, ell] accumulates the mean error for ms[j] measurements and width ks[ell].
err = np.zeros((len(ms), len(ks)))

# Number of repetitions that are averaged for every (m, k) pair.
numit = 10

# Random-image experiment: for every repetition q, measurement count m and
# decoder width k, build an image in the range of the decoder, measure it with
# a fresh random matrix, recover it and accumulate the recovery error.
# NOTE: the loop rebinds the module-level names m, A, forwardm, measurement,
# num_channels and out_img_var on every iteration.
for q in range(numit):
    for j,m in enumerate(ms):
        for ell,k in enumerate(ks):
            # generate input
            num_channels = [k]*4
            ni = get_net_input(num_channels)
        
            # get random noise, and find approximation to it in the range of the generator
            # (img_var is overwritten in place with uniform noise, so the
            # previously loaded image is lost; apply_f=None is forwarded to
            # fit -- presumably the network is then fitted to the noise
            # directly, TODO confirm in fit)
            img_var.data.uniform_()
            img_approx = Variable(dd_recovery(img_var,img_var,num_channels,ni=ni,apply_f=None,num_iter=3000))

            # k**2*4 + k is the parameter-count estimate used for reporting
            # (presumably the number of decoder weights -- compare with num_param);
            # n is the number of image entries from the forward-model cell.
            print("number useful variables / number observations", (k**2*4 + k) /m)
            print("number observations / number of variables", m/n)
            print("m,n,nump",m,n,k**2*4 + k)
            
            # generate random matrix
            # (n-by-m Gaussian with standard deviation 1/sqrt(m), scaled by 10)
            A = 10*torch.empty(n,m).normal_(0, 1/np.sqrt(m)).type(dtype)
            
            # Forward operator for this iteration; it reads the module-level A
            # at call time, i.e. the matrix drawn just above.
            def forwardm(img):
                X = img.view(-1 , np.prod(img.shape) )
                return torch.mm(X,A)

            # Measure the generated image and recover it with the same
            # network input ni that was used to generate it.
            measurement = forwardm(img_approx).type(dtype)
            out_img_var = dd_recovery(measurement,img_approx,num_channels,ni=ni,apply_f=forwardm,num_iter=10000)
    
            #plot_img(img_approx.data.cpu().numpy()[0,0])
            #plot_img(out_img_var.data.cpu().numpy()[0,0])
    
            # snr returns the squared error normalized by the energy of the
            # reference; dividing by numit turns the running sum into a mean.
            error = snr(out_img_var.data.cpu().numpy()[0] , img_approx.data.cpu().numpy()[0])
            print("error: ", error, "\n")
            err[j,ell] += error/numit

In [None]:
# plot and save
# One curve per decoder width k (one column of err) against the number of
# measurements on a logarithmic axis. Every column is drawn with its own
# color; previously only five of the six columns were plotted and 'b' was
# used twice.
plt.xscale('log')
colors = ['b','r','g','y','c','m']
for ell,k in enumerate(ks):
    plt.plot(ms,err[:,ell],colors[ell % len(colors)],label="k=%d" % k)
plt.legend()
plt.show()

# Tab-separated table: first column the measurement counts, then one error
# column per decoder width.
np.savetxt("csrandimg_"+img_name+".csv", np.vstack([ np.array(ms) ,np.array(err).T]).T , delimiter="\t")

## Compressive sensing on a natural image for varying number of parameters and number of measurements

In [None]:
# get a small image
# (reloaded from disk because the random-image experiment overwrites img_var
# in place with noise)
img_name = "poster" # "F16_GT"
img_path = path + img_name + ".png"
img_pil = Image.open(img_path)
img_np = pil_to_np(img_pil)
# Central 128x128 crop of img_np[0], wrapped in a leading axis of length 1.
img_np_small = np.array([crop_center(img_np[0],128,128)])
img_var = np_to_var(img_np_small).type(dtype)

# Measurement counts to sweep: numpoints values that start at 100 and grow
# geometrically by a factor exp(5.5/numpoints) per step.
numpoints = 8
ms = [ int(100*np.exp(5.5/numpoints*i)) for i in range(numpoints) ] # from 100 up to roughly 1.2e4
# Decoder widths to sweep; each network below uses num_channels = [k]*4.
ks = [10,20,30,50,150,250]

# err2[j, ell] accumulates the mean error for ms[j] measurements and width ks[ell].
err2 = np.zeros((len(ms), len(ks)))

# Number of repetitions that are averaged for every (m, k) pair.
numit = 10

# Natural-image experiment: for every repetition q, measurement count m and
# decoder width k, measure the loaded image with a fresh random matrix,
# recover it with a deep decoder and accumulate the recovery error.
# NOTE: the loop rebinds the module-level names m, A, forwardm, measurement,
# num_channels and out_img_var on every iteration.
for q in range(numit):
    for j,m in enumerate(ms):
        for ell,k in enumerate(ks):
        
            # generate fixed input
            num_channels = [k]*4
            ni = get_net_input(num_channels)
        
            # k**2*4 + k is the parameter-count estimate used for reporting
            # (presumably the number of decoder weights -- compare with num_param);
            # n is the number of image entries from the forward-model cell.
            #print("number useful variables / number observations", num_param(net)/m)
            print("number useful variables / number observations", (k**2*4 + k) /m)
            print("number observations / number of variables", m/n)
            print("m,n,nump",m,n,k**2*4 + k)

            # Random n-by-m Gaussian matrix with standard deviation 1/sqrt(m),
            # scaled by 10.
            A = 10*torch.empty(n,m).normal_(0, 1/np.sqrt(m)).type(dtype)
            
            # Forward operator for this iteration; it reads the module-level A
            # at call time, i.e. the matrix drawn just above.
            def forwardm(img):
                X = img.view(-1 , np.prod(img.shape) )
                return torch.mm(X,A)
            
            # take measurement of original image
            measurement = forwardm(img_var).type(dtype)
            out_img_var = dd_recovery(measurement,img_var,num_channels,ni=ni,apply_f=forwardm,num_iter=6000)
        
            # snr returns the squared error normalized by the energy of the
            # reference; dividing by numit turns the running sum into a mean.
            error = snr(out_img_var.data.cpu().numpy()[0] , img_var.data.cpu().numpy()[0])
            print("error: ", error, "\n")
            err2[j,ell] += error/numit

In [None]:
# plot and save
# One curve per decoder width k (one column of err2) against the number of
# measurements on a logarithmic axis. Every column is drawn as a line with its
# own color; previously the last format string was 'o', which is a marker
# rather than a color, and 'b' was used twice.
plt.xscale('log')
colors = ['b','r','g','y','c','m']
for ell,k in enumerate(ks):
    plt.plot(ms,err2[:,ell],colors[ell % len(colors)],label="k=%d" % k)
plt.legend()
plt.show()

# Tab-separated table: first column the measurement counts, then one error
# column per decoder width.
np.savetxt("csf16img_"+img_name+".csv", np.vstack([ np.array(ms) ,np.array(err2).T]).T , delimiter="\t")