# 1 Compressed Sensing using Deep Decoders

## 1.1 Importing libraries

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
import os
import warnings
warnings.filterwarnings('ignore')
from include import *
from PIL import Image
import PIL
import pywt
import numpy as np
import torch
import torch.optim
from torch.autograd import Variable
from sklearn import linear_model

GPU = True  # set to False to force CPU execution

# Start from the CPU configuration and switch to CUDA only when it was
# requested AND at least one device is actually visible to torch.
dtype = torch.FloatTensor
device = 'cpu'
if GPU:
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    # Restrict this process to GPU 0; assigned before the device count is queried.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    n_gpus = torch.cuda.device_count()
    print("num GPUs", n_gpus)
    if n_gpus > 0:
        dtype = torch.cuda.FloatTensor
        device = 'cuda'
from scipy.fftpack import dct
from scipy.fftpack import idct
from scipy import io as sio 
import time

## 1.2. Loading image and preprocessing

In [None]:
# Choose which test image to load: dataset is 'mnist' or 'celeba'.
dataset = 'mnist'
path = './test_data/' + dataset + '/'
img_name = dataset + '1'  # image index: 1-5 (for celeba), 1-6 (for mnist)
img_path = path + img_name + ".jpg"

img_pil = Image.open(img_path)
if dataset == 'celeba':
    # Crop box (left, upper, right, lower) = (60, 80+20, 60+64, 80+84),
    # i.e. a 64 x 64 window of the original picture.
    img_pil = img_pil.crop((60, 100, 124, 164))

img_np = pil_to_np(img_pil)
print('Dimensions of input image:', img_np.shape)

# Rescale so that the largest entry becomes 1.
img_np = img_np / np.max(img_np)
# Keep an independent copy of the normalised image.
img_np_orig = img_np.copy()

Display the image x and convert it to a PyTorch variable.

In [None]:
# Show the image: channel-first data is moved to channel-last for the
# celeba case; otherwise only channel 0 is drawn, in grayscale.
if dataset != 'celeba':
    plt.imshow(img_np[0])
    plt.gray()
else:
    plt.imshow(np.transpose(img_np, (1, 2, 0)))
plt.axis('off')

# Tensor version of the image, cast to the tensor type selected during setup.
img_var = np_to_var(img_np).type(dtype)

# Sizes reused by later cells.
out_ch, d = img_np.shape[0], img_np.shape[1]  # leading axis / second axis
d_image = img_np.size                         # total number of entries

## 1.3. Setup model for compressed sensing

In [None]:
f = 0.2  # compression rate: measurements per signal entry
print('Compression rate is ', f)
m_image = int(f * d_image)
print('Number of measurements is ', m_image, ' for signal of length ', d_image)

# Random Gaussian measurement matrix A of shape (m_image, d_image),
# scaled by 1/sqrt(m_image). randn already yields float64, so no extra cast.
Ameas = np.random.randn(m_image, d_image) / np.sqrt(m_image)
Ameas_var = torch.from_numpy(Ameas).float().to(device)

# Measurements y = A x, with the image flattened into a (d_image, 1) column.
img_var_meas = Ameas_var @ img_var.to(device).reshape(d_image, 1)

## 1.4. Initialize the deep decoder

In [None]:
# Imported explicitly: `copy` is used below but is not imported anywhere else
# in this notebook, so it would otherwise only resolve if
# `from include import *` happened to expose it.
import copy

# Decoder flavour handed to autoencodernet: 'upsample' or 'transposeconv'
# (original note: decoder architecture vs. DC GAN architecture).
decodetype = 'upsample'

# Channel widths passed as num_channels_up, chosen per dataset.
if dataset == 'mnist':
    num_channels = [25, 15, 10]
elif dataset == 'celeba':
    num_channels = [120, 25, 15, 10]
else:
    num_channels = [512, 256, 128]

output_depth = img_np.shape[0]  # number of output channels
net = autoencodernet(
    num_output_channels=output_depth,
    num_channels_up=num_channels,
    need_sigmoid=True,
    Ameas=Ameas_var,
    decodetype=decodetype,
).type(dtype)

print("number of parameters: ", num_param(net))
# Print whichever decoder sub-module corresponds to the chosen flavour.
if decodetype == 'upsample':
    print(net.decoder)
elif decodetype == 'transposeconv':
    print(net.convdecoder)

# Independent snapshot of the freshly constructed network, taken before fitting.
net_in = copy.deepcopy(net)

## 1.5. Solving the minimization problem to recover the signal

In [None]:
# Optimisation settings forwarded to fit().
OPTIMIZER = 'adam'      # original note: SGD or adam
numit = 10000           # number of iterations
LR = 0.0001             # learning rate
lr_decay_epoch = 3000

# Secondary-optimiser settings; all passed through as None in this run.
optimizer2 = None
numit_inner = None
LR_LS = None

# NOTE(review): `mode` and `optim` are not defined in any cell visible here.
# Presumably they come from `from include import *` or an earlier cell;
# confirm, otherwise this call raises NameError.
t0 = time.time()
mse_t, ni, net, ni_mod, in_np_img = fit(
    # network and its architecture description
    net=net,
    num_channels=num_channels,
    decodetype=decodetype,
    out_channels=out_ch,
    code='uniform',
    # data: the fitting target is the measurement vector, with A alongside it
    img_clean_var=img_var_meas,
    Ameas=Ameas_var,
    # optimisation schedule
    num_iter=numit,
    numit_inner=numit_inner,
    LR=LR,
    LR_LS=LR_LS,
    OPTIMIZER=OPTIMIZER,
    optimizer2=optimizer2,
    lr_decay_epoch=lr_decay_epoch,
    find_best=True,
    model=mode,
    optim=optim,
)
t1 = time.time()
print('\ntime elapsed:', t1 - t0)

## 1.6. Plot loss w.r.t. number of measurements