In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
import os
import warnings
warnings.filterwarnings('ignore')
from include import *
from PIL import Image
import PIL
import pywt
import numpy as np
import torch
import torchvision
from torch.autograd import Variable
import copy
import time



# Use the GPU when one is visible; set GPU = False to force CPU execution.
GPU = True
# Global seed so that the random measurement matrix A and the random
# initialisations in later cells are reproducible under Restart & Run All.
SEED = 0

if GPU:
    # Restrict CUDA to the first device. CUDA_VISIBLE_DEVICES is only honoured
    # if CUDA has not been initialised yet in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    print("num GPUs",torch.cuda.device_count())

if GPU and torch.cuda.device_count() > 0:
    dtype = torch.cuda.FloatTensor
    device = 'cuda'
else:
    # CPU fallback: either the GPU was disabled or no CUDA device is visible.
    dtype = torch.FloatTensor
    device = 'cpu'

# torch.manual_seed also seeds every CUDA device (lazily, without forcing CUDA init).
torch.manual_seed(SEED)
np.random.seed(SEED)



# 1. Load a test image from the dataset (currently: CelebA, cropped to 128x128)

In [None]:
# Dataset of the test image: 'celeba' or 'mnist'.
dataset = 'celeba'
path = './test_data/' + dataset + '/'
img_name = dataset + '2' # 1-5 (for celeba), 1-6 (for mnist)
img_path = path + img_name + ".jpg"
# Open the file in a context manager; .copy() loads the pixel data so the file
# handle can be closed immediately instead of leaking.
with Image.open(img_path) as img_file:
    img_pil = img_file.copy()

if dataset == 'celeba':
    # Crop a (2*crop_half) x (2*crop_half) = 128x128 window centred on (cx, cy).
    crop_half = 64
    cx = 89
    cy = 121
    img_pil = img_pil.crop((cx - crop_half, cy - crop_half, cx + crop_half, cy + crop_half))

img_np = pil_to_np(img_pil)
print('Dimensions of input image:', img_np.shape)
# Rescale so the brightest pixel is 1; skip for an all-black image to avoid 0/0.
img_max = np.max(img_np)
if img_max > 0:
    img_np = img_np / img_max

# Independent copy of the normalised ground-truth image.
img_np_orig = img_np.copy()

img_var = np_to_var(img_np).type(dtype)
d = img_np.shape[1]
out_ch = img_np.shape[0]
d_image = img_np.size

# Normalize the pixels to [-1, 1] (assumes non-negative pixels, i.e. [0, 1] after rescaling).
img_var = 2*img_var - 1

grid = torchvision.utils.make_grid(img_var.clamp(min=-1, max=1), scale_each=True, normalize=True)
plt.imshow(grid.detach().permute(1, 2, 0).cpu().numpy())
plt.axis('off')

save_to = './recovery/'
# plt.savefig does not create missing directories; make sure the output folder exists.
os.makedirs(save_to, exist_ok=True)
save_path= save_to + 'Original'+'_'+img_name+'.png'
plt.savefig(save_path, bbox_inches='tight', pad_inches = 0)



In [None]:
# Compression rate: number of measurements as a fraction of the signal length (m / d).
f = 0.2
print('Compression rate is ', f)
m_image = int(f*d_image)
assert 0 < m_image <= d_image, 'compression rate must yield between 1 and d_image measurements'
print('Number of measurements is ',m_image, ' for signal of length ', d_image)

# Random Gaussian measurement matrix with i.i.d. N(0, 1/m) entries. It is drawn
# directly on `device` so the large (m x d) matrix is not first materialised on
# the host and then copied.
A = torch.randn(m_image, d_image, device=device) / np.sqrt(m_image)
x = img_var.to(device).reshape(d_image)
# A and x already live on `device`, so y does too.
y = torch.matmul(A, x)

print(A.shape, x.shape, y.shape)
mse = torch.nn.MSELoss()

# 2. Compressed sensing using generative models

## 2.1. Load a generative model pretrained on the dataset (currently: PGGAN)

In [None]:
# Load the GAN on the device chosen in the configuration cell, so G, A and y
# always share a device (re-checking torch.cuda.is_available() here would put G
# on the GPU even when GPU = False left A and y on the CPU).
use_gpu = device == 'cuda'

# PGAN trained on the CelebA dataset ('celeba'). Its output must match the
# 128 x 128 crop of the test image, since recoveries are compared to img_var with MSE.
model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub',
                       'PGAN', model_name='celeba',
                       pretrained=True, useGPU=use_gpu)
# Alternative model that outputs 256 x 256 images (would not match the 128 x 128
# test crop without resizing):
# model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub',
#                        'PGAN', model_name='celebAHQ-256',
#                        pretrained=True, useGPU=use_gpu)
G = model.netG
# NOTE(review): G is used without an explicit G.eval(); confirm which mode CSGM2 expects.
latentDim = model.config.noiseVectorDim

## 2.2. CS using the loaded GAN

In [None]:
t0 = time.time()

# Recover a latent code from the measurements (y, A) with CSGM2 from `include`
# (presumably minimising ||A G(z) - y||^2 over z -- TODO confirm).
z0, mse_wrt_loss = CSGM2(G, latentDim, y, A, device, num_iter = 4000)
# Evaluation only: no autograd graph is needed for the reconstruction.
with torch.no_grad():
    x0 = G(z0)
# Stop the clock before plotting/saving so the reported time covers recovery only.
t1 = time.time()

grid = torchvision.utils.make_grid(x0.clamp(min=-1, max=1), scale_each=True, normalize=True)
plt.axis('off')
plt.imshow(grid.detach().permute(1, 2, 0).cpu().numpy())

save_path= save_to +'PGGAN'+'_'+img_name+'.png'
plt.savefig(save_path,bbox_inches='tight', pad_inches = 0)

print('\nTime elapsed:',t1-t0)

# Mean squared error (torch.nn.MSELoss) between the reconstruction and the ground truth.
error_wrt_truth = mse(x0, img_var).item()
print('\nl2-recovery error:', error_wrt_truth)

In [None]:
# Disabled diagnostic plot of the CSGM error curve (kept as a no-op string literal).
# NOTE(review): the CSGM cell in section 2.2 binds its second return value to
# `mse_wrt_loss`; `mse_wrt_truth` is only bound later by the hybrid cell (4.2).
# If re-enabled as written, this raises NameError on a fresh Run All, or plots
# the hybrid initialisation run if the kernel already holds that state.
'''plt.xlabel('optimizer iteration')
plt.ylabel('recovery error')
plt.semilogy(mse_wrt_truth)'''

# 3. Compressed sensing using a Deep Decoder

## 3.1. Define the network

In [None]:
# Decoder flavour for the untrained network: 'upsample' or 'transposeconv'.
decodetype = 'upsample'

# Channel counts handed to autoencodernet as num_channels_up.
num_channels = [120, 40, 20, 15, 10]

# One output channel per channel of the test image.
output_depth = img_np.shape[0]
net = autoencodernet(
    num_output_channels=output_depth,
    num_channels_up=num_channels,
    need_sigmoid=True,
    decodetype=decodetype,
)
net = net.type(dtype)

print("number of parameters: ", num_param(net))
# Show whichever decoder sub-module the chosen decodetype uses.
if decodetype == 'transposeconv':
    print(net.convdecoder)
elif decodetype == 'upsample':
    print(net.decoder)

# Pristine copy of the freshly initialised network; `net` itself is handed to
# CS_DD in section 3.2.
net_in = copy.deepcopy(net)



## 3.2. CS using the untrained network

In [None]:
t0 = time.time()

# Fit the untrained decoder to the measurements (y, A) with CS_DD from `include`.
# NOTE(review): an earlier setting appears to have been num_iter=12000, lr_decay_epoch=3000.
net, net_input, loss = CS_DD(net, num_channels, d_image, y=y, A=A, device= device,
                             num_iter = 5000, lr_decay_epoch=2000)
# The decoder was built with need_sigmoid=True; 2*s - 1 maps its output to the
# [-1, 1] range used for img_var. Evaluation only, so no autograd graph is built.
with torch.no_grad():
    x_DD = 2*net( net_input.type(dtype) )-1

t1 = time.time()
grid = torchvision.utils.make_grid(x_DD.clamp(min=-1, max=1), scale_each=True, normalize=True)
plt.imshow(grid.detach().permute(1, 2, 0).cpu().numpy())
plt.axis('off')

print('\n time elapsed:', t1-t0)

# Mean squared error (torch.nn.MSELoss) between the reconstruction and the ground truth.
error_wrt_truth = mse(x_DD, img_var).item()
print('\nl2-recovery error:', error_wrt_truth)

save_path= save_to +'DD'+'_'+img_name+'.png'
plt.savefig(save_path,bbox_inches='tight', pad_inches = 0)


# 4. Compressed sensing using a hybrid model

## 4.1. Define the untrained network used by the hybrid model

In [None]:
# Same decoder configuration as in section 3.1.
decodetype = 'upsample' # transposeconv / upsample

num_channels = [120,40,20,15,10]

output_depth = img_np.shape[0] # number of output channels
# Build a FRESH, untrained decoder for the hybrid model. Reusing `net` here would
# start the hybrid run from the network CS_DD already optimised for 5000
# iterations in section 3.2, contaminating the comparison between methods.
net = autoencodernet(num_output_channels=output_depth,num_channels_up=num_channels,need_sigmoid=True,
                     decodetype=decodetype
                     ).type(dtype)

print("number of parameters: ", num_param(net))
if decodetype == 'upsample':
    print(net.decoder)
elif decodetype == 'transposeconv':
    print(net.convdecoder)
# Pristine copy of the freshly initialised hybrid decoder.
net_copy = copy.deepcopy(net)

## 4.2. CS using the hybrid model

In [None]:
t0 = time.time()

# Initialisation: short CSGM run for the latent code and short CS_DD run for the decoder.
print('Initialization step')
z0, mse_wrt_truth = CSGM2(G=G, latentDim=latentDim, y=y, A=A, device=device, num_iter=500)
# Initial estimates are kept under their own names so the section 2.2 / 3.2
# results held in x0 and x_DD are not overwritten. Evaluation only: no graph.
with torch.no_grad():
    x0_init = G(z0)

net, net_input, loss = CS_DD(net, num_channels, d_image, y=y, A=A, device= device,
                             num_iter = 500, lr_decay_epoch=1000)
with torch.no_grad():
    x_DD_init = 2*net( net_input.type(dtype) )-1


# Joint optimisation of the latent code, the decoder and the mixing weights.
print('Performing optimization')
net, net_input, z, alpha, beta, loss = CS_hybrid2(G, net, net_input, num_channels, d_image, y, A, device = device, z_0 = z0,
                                            latentDim=latentDim, num_iter = 4000, lr_decay_epoch = 2000)

# Hybrid estimate: GAN output and decoder output (mapped to [-1, 1]) mixed with
# weights alpha and beta, each clamped to [0, 1].
with torch.no_grad():
    x_hat = alpha.clamp(0,1)*G(z) + beta.clamp(0,1)*(2*net(net_input.type(dtype)) - 1)
# Stop the clock before plotting/saving so the reported time covers recovery only.
t1 = time.time()

print(alpha,beta)

grid = torchvision.utils.make_grid(x_hat.clamp(min=-1, max=1), scale_each=True, normalize=True)
plt.imshow(grid.detach().permute(1, 2, 0).cpu().numpy())
plt.axis('off')

print('\n time elapsed:', t1-t0)

# Mean squared error (torch.nn.MSELoss) between the reconstruction and the ground truth.
error_wrt_truth = mse(x_hat, img_var).item()
print('\nl2-recovery error:', error_wrt_truth)


save_path= save_to +'Hybrid'+'_'+img_name+'.png'
plt.savefig(save_path,bbox_inches='tight', pad_inches = 0)