In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Fetch the CS-DIP reference implementation into ./compsensing_dip.
!git clone https://github.com/davevanveen/compsensing_dip.git
# NOTE(review): `!cd` runs in a throwaway subshell, so the notebook's working
# directory is unchanged after this line; `%cd` is the magic that persists.
!cd compsensing_dip

Cloning into 'compsensing_dip'...
remote: Enumerating objects: 9970, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 9970 (delta 0), reused 3 (delta 0), pack-reused 9967[K
Receiving objects: 100% (9970/9970), 839.96 MiB | 30.04 MiB/s, done.
Resolving deltas: 100% (1336/1336), done.
Checking out files: 100% (5223/5223), done.


In [None]:
# Install the dependencies listed by the cloned repository.
# NOTE(review): the recorded output shows pip rejecting the `python==2.7`
# requirement, so this install did not complete - confirm what is really needed.
!pip install -r 'compsensing_dip/requirements.txt'

[31mERROR: Could not find a version that satisfies the requirement python==2.7 (from versions: none)[0m
[31mERROR: No matching distribution found for python==2.7[0m


In [None]:
#utils
import numpy as np
import os
import errno
import parser

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets,transforms

# Number of images handled per step; sizes the output arrays and the DataLoader batches.
BATCH_SIZE = 1

class DCGAN_XRAY(nn.Module):
    """Transposed-convolution generator: six ConvTranspose/BatchNorm/ReLU stages, then tanh.

    Args:
        nz: number of channels of the latent input.
        ngf: number of feature maps used by every hidden layer.
        output_size: height/width requested from the final transposed convolution.
        nc: number of channels of the generated image.
        num_measurements: accepted for signature compatibility; not used here.

    For a 1x1 latent the kernel/stride/padding choices give a 4x4 map after
    ``conv1`` and double the spatial size at each later layer (256x256 at the end).
    """

    def __init__(self, nz, ngf=64, output_size=256, nc=3, num_measurements=1000):
        super(DCGAN_XRAY, self).__init__()
        self.nc = nc
        self.output_size = output_size

        # Shared settings of the hidden size-doubling layers.
        doubling = dict(kernel_size=6, stride=2, padding=2, bias=False)

        # Layers are registered in forward order so parameter ordering stays fixed.
        self.conv1 = nn.ConvTranspose2d(nz, ngf, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(ngf)
        self.conv2 = nn.ConvTranspose2d(ngf, ngf, **doubling)
        self.bn2 = nn.BatchNorm2d(ngf)
        self.conv3 = nn.ConvTranspose2d(ngf, ngf, **doubling)
        self.bn3 = nn.BatchNorm2d(ngf)
        self.conv4 = nn.ConvTranspose2d(ngf, ngf, **doubling)
        self.bn4 = nn.BatchNorm2d(ngf)
        self.conv5 = nn.ConvTranspose2d(ngf, ngf, **doubling)
        self.bn5 = nn.BatchNorm2d(ngf)
        self.conv6 = nn.ConvTranspose2d(ngf, ngf, **doubling)
        self.bn6 = nn.BatchNorm2d(ngf)
        # Final layer maps the features to image channels.
        self.conv7 = nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False)

    def forward(self, z):
        """Map the latent tensor ``z`` to an image with values in [-1, 1]."""
        hidden_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
            (self.conv6, self.bn6),
        )
        features = z
        for conv, bn in hidden_stages:
            features = F.relu(bn(conv(features)))

        target_shape = (-1, self.nc, self.output_size, self.output_size)
        return torch.tanh(self.conv7(features, output_size=target_shape))

class DCGAN_MNIST(nn.Module):
    """Small generator mixing transposed convolutions with 2x interpolation, ending in tanh.

    Args:
        nz: number of channels of the latent input.
        ngf: base number of feature maps (hidden widths are 8x, 4x, 2x and 1x this).
        output_size: height/width requested from the final transposed convolution.
        nc: number of channels of the generated image.
        num_measurements: accepted for signature compatibility; not used here.
    """

    def __init__(self, nz, ngf=64, output_size=28, nc=1, num_measurements=10):
        super(DCGAN_MNIST, self).__init__()
        self.nc = nc
        self.output_size = output_size

        wide, medium, narrow = ngf*8, ngf*4, ngf*2
        # 3x3 / stride 1 / padding 1 keeps the spatial size unchanged.
        size_preserving = dict(kernel_size=3, stride=1, padding=1, bias=False)

        self.conv1 = nn.ConvTranspose2d(nz, wide, kernel_size=2, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(wide)
        self.conv2 = nn.ConvTranspose2d(wide, medium, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(medium)
        self.conv3 = nn.ConvTranspose2d(medium, narrow, **size_preserving)
        self.bn3 = nn.BatchNorm2d(narrow)
        self.conv4 = nn.ConvTranspose2d(narrow, ngf, **size_preserving)
        self.bn4 = nn.BatchNorm2d(ngf)
        self.conv5 = nn.ConvTranspose2d(ngf, nc, **size_preserving)

    def forward(self, x):
        """Map the latent tensor ``x`` to an image with values in [-1, 1]."""
        # F.interpolate is used for the 2x upsampling steps (older code used F.upsample).
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.interpolate(h, scale_factor=2)

        h = F.relu(self.bn2(self.conv2(h)))

        h = F.relu(self.bn3(self.conv3(h)))
        h = F.interpolate(h, scale_factor=2)

        h = F.relu(self.bn4(self.conv4(h)))
        h = F.interpolate(h, scale_factor=2)

        target_shape = (-1, self.nc, self.output_size, self.output_size)
        return torch.tanh(self.conv5(h, output_size=target_shape))

class DCGAN_RETINO(nn.Module):
    """Transposed-convolution generator: five ConvTranspose/BatchNorm/ReLU stages, then tanh.

    Args:
        nz: number of channels of the latent input.
        ngf: number of feature maps used by every hidden layer.
        output_size: height/width requested from the final transposed convolution.
        nc: number of channels of the generated image.
        num_measurements: accepted for signature compatibility; not used here.

    NOTE(review): from a 1x1 latent these layers reach 4 -> 8 -> 16 -> 32 -> 64 -> 128,
    so ``output_size`` presumably has to be 128 for that input; the default of 256
    does not match it - confirm against the configuration that is actually used.
    """

    def __init__(self, nz, ngf=64, output_size=256, nc=3, num_measurements=1000):
        super(DCGAN_RETINO, self).__init__()
        self.nc = nc
        self.output_size = output_size

        # Shared settings of the hidden size-doubling layers.
        doubling = dict(kernel_size=6, stride=2, padding=2, bias=False)

        # Layers are registered in forward order so parameter ordering stays fixed.
        self.conv1 = nn.ConvTranspose2d(nz, ngf, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(ngf)
        self.conv2 = nn.ConvTranspose2d(ngf, ngf, **doubling)
        self.bn2 = nn.BatchNorm2d(ngf)
        self.conv3 = nn.ConvTranspose2d(ngf, ngf, **doubling)
        self.bn3 = nn.BatchNorm2d(ngf)
        self.conv4 = nn.ConvTranspose2d(ngf, ngf, **doubling)
        self.bn4 = nn.BatchNorm2d(ngf)
        self.conv5 = nn.ConvTranspose2d(ngf, ngf, **doubling)
        self.bn5 = nn.BatchNorm2d(ngf)
        # Final layer maps the features to image channels.
        self.conv6 = nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False)

    def forward(self, x):
        """Map the latent tensor ``x`` to an image with values in [-1, 1]."""
        hidden_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        features = x
        for conv, bn in hidden_stages:
            features = F.relu(bn(conv(features)))

        target_shape = (-1, self.nc, self.output_size, self.output_size)
        return torch.tanh(self.conv6(features, output_size=target_shape))

NGF = 64  # number of generator feature maps passed to every DCGAN variant
def init_dcgan(args):
    """Build the generator network that matches ``args.DATASET``.

    Args:
        args: configuration object providing DATASET ('xray', 'mnist' or 'retino'),
            Z_DIM, IMG_SIZE, NUM_CHANNELS and NUM_MEASUREMENTS.

    Returns:
        A freshly initialised DCGAN_XRAY, DCGAN_MNIST or DCGAN_RETINO instance.

    Raises:
        ValueError: if ``args.DATASET`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError on ``return``).
    """
    if args.DATASET == 'xray':
        generator_cls = DCGAN_XRAY
    elif args.DATASET == 'mnist':
        generator_cls = DCGAN_MNIST
    elif args.DATASET == 'retino':
        generator_cls = DCGAN_RETINO
    else:
        raise ValueError("Unsupported DATASET {0!r}; expected 'xray', 'mnist' or 'retino'."
                         .format(args.DATASET))

    return generator_cls(args.Z_DIM, NGF, args.IMG_SIZE,
                         args.NUM_CHANNELS, args.NUM_MEASUREMENTS)

def init_output_arrays(args):
    """Allocate zero-filled per-restart buffers for losses and reconstructions.

    Returns:
        (loss_re, recons_re) with shapes (NUM_RESTARTS, BATCH_SIZE) and
        (NUM_RESTARTS, BATCH_SIZE, NUM_CHANNELS, IMG_SIZE, IMG_SIZE).
    """
    loss_shape = (args.NUM_RESTARTS, BATCH_SIZE)
    image_shape = (args.NUM_CHANNELS, args.IMG_SIZE, args.IMG_SIZE)
    return np.zeros(loss_shape), np.zeros(loss_shape + image_shape)

# Per-dataset weights of the total-variation and learned-regularization loss terms.
lambdas_tv = {'mnist': 1e-2, 'xray': 5e-2, 'retino': 2e-2}
lambdas_lr = {'mnist': 0, 'xray': 100, 'retino': 1000}
def get_constants(args, dtype):
    """Load the learned-regularization statistics and pick the loss weights.

    Reads ``mu_<NUM_MEASUREMENTS>.npy`` and ``sig_<NUM_MEASUREMENTS>.npy`` from
    ``args.LR_FOLDER``.

    Args:
        args: configuration object providing NUM_MEASUREMENTS and LR_FOLDER, and
            optionally DATASET.
        dtype: tensor type the loaded arrays are cast to.

    Returns:
        (mu, sig_inv, tvc, lrc): the loaded mean as a tensor, the inverse of the
        loaded ``sig`` matrix as a tensor, and the TV / learned-regularization
        weights for the dataset. When DATASET is absent or not in the tables the
        weights fall back to 1e-2 and 0.
    """
    mu_path = os.path.join(args.LR_FOLDER, 'mu_{0}.npy'.format(args.NUM_MEASUREMENTS))
    sig_path = os.path.join(args.LR_FOLDER, "sig_{0}.npy".format(args.NUM_MEASUREMENTS))
    mu_ = np.load(mu_path)
    sig_ = np.load(sig_path)

    mu = torch.FloatTensor(mu_).type(dtype)
    sig_inv = torch.FloatTensor(np.linalg.inv(sig_)).type(dtype)

    # A missing dictionary key raises KeyError, not AttributeError, so the old
    # try/except never applied the defaults to unknown datasets. Cover both the
    # missing-attribute and the unknown-name case explicitly.
    dataset = getattr(args, 'DATASET', None)
    tvc = lambdas_tv.get(dataset, 1e-2)
    lrc = lambdas_lr.get(dataset, 0)
    return mu, sig_inv, tvc, lrc

def renorm(x):
    """Apply the affine map 0.5*x + 0.5 (sends -1 to 0 and 1 to 1)."""
    halved = 0.5*x
    return halved + 0.5

def plot(x,renormalize=True):
    """Show the first element of ``x`` as a grayscale image.

    Args:
        x: tensor-like object exposing ``.data``; its first entry is drawn.
        renormalize: when true, ``renorm`` is applied before plotting.

    NOTE(review): ``plt`` is not imported in the visible cells; presumably
    ``matplotlib.pyplot`` must already be in scope - confirm.
    """
    shown = renorm(x) if renormalize else x
    plt.imshow(shown.data[0].cpu().numpy(), cmap='gray')


exit_window = 50 # number of consecutive MSE values upon which we compare
thresh_ratio = 45 # number of MSE values that must be larger for us to exit
def exit_check(window, i): # if converged, then exit current experiment
    """Decide whether the loss has stopped improving over ``window``.

    Args:
        window: sequence of MSE values (list or array) with at least
            ``exit_window`` entries; the first one is the reference value.
        i: iteration index; accepted for compatibility, not used.

    Returns:
        (True, first value) when at least ``thresh_ratio`` entries of the window
        are larger than its first value, otherwise (False, value at index
        ``exit_window - 1``).
    """
    # Plain Python lists do not support element-wise comparison, so normalise first.
    window = np.asarray(window)
    mse_base = window[0] # get first mse value in window

    num_worse = np.count_nonzero(window > mse_base)
    if num_worse >= thresh_ratio: # enough values are higher than mse_base
        return True, mse_base

    mse_last = window[exit_window-1] # get the last value of MSE in window
    return False, mse_last


def define_compose(NC, IMG_SIZE): # define compose based on NUM_CHANNELS, IMG_SIZE
    """Build the torchvision preprocessing pipeline for the given channel count.

    Args:
        NC: number of image channels; 1 (grayscale) or 3 (rgb).
        IMG_SIZE: images are resized to (IMG_SIZE, IMG_SIZE).

    Returns:
        A ``transforms.Compose`` that resizes, optionally converts to grayscale,
        converts to a tensor and normalises each channel with mean 0.5 / std 0.5.

    Raises:
        ValueError: if ``NC`` is neither 1 nor 3 (previously an UnboundLocalError).
    """
    if NC == 1: #grayscale
        # One mean/std entry per channel: three-channel statistics do not
        # broadcast onto a single-channel tensor in current torchvision.
        channel_stats = (.5,)
    elif NC == 3: #rgb
        channel_stats = (.5,.5,.5)
    else:
        raise ValueError('NUM_CHANNELS must be 1 (grayscale) or 3 (rgb), got ' + str(NC))

    steps = [transforms.Resize((IMG_SIZE,IMG_SIZE))]
    if NC == 1:
        steps.append(transforms.Grayscale())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(channel_stats, channel_stats))
    return transforms.Compose(steps)

def set_dtype(CUDA):
    """Return the float tensor type to use: the CUDA one when ``CUDA`` is truthy, else CPU."""
    return torch.cuda.FloatTensor if CUDA else torch.FloatTensor

def get_path_out(args, path_in):
    """Compose the absolute output path for one reconstruction.

    Args:
        args: configuration object providing DATASET, ALG and NUM_MEASUREMENTS.
        path_in: sequence whose first element is the input image path.

    Returns:
        ``<cwd>/reconstructions/<DATASET>/<ALG>/meas<NUM_MEASUREMENTS>/im<name>.<ext>``
        where ``ext`` is 'mat' for the Matlab algorithms ('bm3d', 'tval3') and
        'npy' otherwise.
    """
    fn = path_leaf(path_in[0]) # image name without directory or extension

    matlab_algs = ('bm3d', 'tval3')
    file_ext = 'mat' if args.ALG in matlab_algs else 'npy'

    path_out = 'reconstructions/{0}/{1}/meas{2}/im{3}.{4}'.format(
        args.DATASET, args.ALG, args.NUM_MEASUREMENTS, fn, file_ext)

    return os.getcwd() + '/' + path_out


def recons_exists(args, path_in):
    """Print the output path for this configuration and report whether that file exists."""
    path_out = get_path_out(args, path_in)
    print(path_out)
    return os.path.isfile(path_out)

def save_reconstruction(x_hat, args, path_in):
    """Save ``x_hat`` with ``np.save`` at the path derived from ``args`` and ``path_in``.

    The parent directory is created first when it does not exist.
    """
    path_out = get_path_out(args, path_in)
    out_dir = os.path.dirname(path_out)

    if not os.path.exists(out_dir):
        try:
            os.makedirs(out_dir)
        except OSError as err:
            # Another process may have created the directory in the meantime;
            # only that specific failure is tolerated.
            if err.errno != errno.EEXIST:
                raise

    np.save(path_out, x_hat)

def check_args(args): # check args for correctness
    """Validate the measurement counts and the DEMO flag.

    Raises:
        ValueError: if any NUM_MEASUREMENTS value exceeds
            IMG_SIZE * IMG_SIZE * NUM_CHANNELS, or if DEMO is neither the
            string 'True' nor the string 'False'.
    """
    im_dimn = args.IMG_SIZE * args.IMG_SIZE * args.NUM_CHANNELS

    # A single int and a collection of ints are checked the same way.
    if isinstance(args.NUM_MEASUREMENTS, int):
        candidates = [args.NUM_MEASUREMENTS]
    else:
        candidates = args.NUM_MEASUREMENTS

    for num_measurements in candidates:
        if num_measurements > im_dimn:
            raise ValueError('NUM_MEASUREMENTS must be less than image dimension '
                             + str(im_dimn))

    if args.DEMO not in ('False', 'True'):
        raise ValueError('DEMO must be either True or False.')

def convert_to_list(args): # returns list for NUM_MEAS, ALG
    """Return (NUM_MEASUREMENTS, ALG) with each value wrapped in a list unless it already is one."""
    def as_list(value):
        return value if isinstance(value, list) else [value]

    return as_list(args.NUM_MEASUREMENTS), as_list(args.ALG)

def path_leaf(path):
    """Return the final path component with its file extension removed.

    Raises:
        ValueError: if ``path`` contains no '.' at all.
    """
    if '.' not in path:
        raise ValueError('Filename does not contain extension')

    stem = os.path.splitext(path)[0] # drop the file extension
    parent, leaf = os.path.split(stem)
    if leaf:
        return leaf
    # Path ended with a separator: fall back to the last directory name.
    return os.path.basename(parent)

def get_data(args):
    """Create a non-shuffling DataLoader over the dataset's image folder.

    Images are read from ``data/<DATASET>_demo/`` when ``args.DEMO`` is the string
    'True' and from ``data/<DATASET>/`` otherwise; each item also carries its file path.
    """
    compose = define_compose(args.NUM_CHANNELS, args.IMG_SIZE)

    folder_template = 'data/{0}_demo/' if args.DEMO == 'True' else 'data/{0}/'
    image_direc = folder_template.format(args.DATASET)

    dataset = ImageFolderWithPaths(image_direc, transform = compose)
    return torch.utils.data.DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE)

class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder variant whose items also carry the image file path.

    Extends torchvision.datasets.ImageFolder.
    """
    def __getitem__(self, index):
        """Return the parent class's item tuple with the file path appended as last element."""
        base_item = super(ImageFolderWithPaths, self).__getitem__(index)
        file_path = self.imgs[index][0]
        return base_item + (file_path,)


In [None]:
#baselines
from sklearn.linear_model import Lasso
import scipy.fftpack as fftpack
import pywt
import copy
import numpy as np

# Regularization strength handed to solve_lasso by the DCT and wavelet estimators.
LMBD = 1e-5

def solve_lasso(A_val, y_val, lmbd=1e-1):
    """Fit a LASSO regression of ``y_val`` on ``A_val.T`` and return the coefficients.

    Args:
        A_val: measurement matrix; its transpose is used as the design matrix.
        y_val: measurements, reshaped to a flat vector of length ``y_val.shape[0]``.
        lmbd: value passed to ``Lasso`` as ``alpha``.

    Returns:
        The fitted coefficient vector, flattened to one dimension.
    """
    targets = y_val.reshape(y_val.shape[0])
    model = Lasso(alpha=lmbd)
    model.fit(A_val.T, targets)
    return np.reshape(model.coef_, [-1])

def dct2(image_channel):
    """Orthonormal 2D DCT: a 1D DCT along one axis, then along the other."""
    first_pass = fftpack.dct(image_channel.T, norm='ortho')
    return fftpack.dct(first_pass.T, norm='ortho')

def idct2(image_channel):
    """Orthonormal 2D inverse DCT: a 1D inverse DCT along one axis, then along the other."""
    first_pass = fftpack.idct(image_channel.T, norm='ortho')
    return fftpack.idct(first_pass.T, norm='ortho')

def db4(image_channel):
    """Decompose ``image_channel`` with ``pywt.wavedec2`` using the 'db4' wavelet.

    Returns:
        (arr, coeff_slices) as produced by ``pywt.coeffs_to_array``; the slices
        are what ``idb4`` needs to undo the packing.
    """
    packed = pywt.coeffs_to_array(pywt.wavedec2(image_channel,'db4'))
    coeff_arr, coeff_slices = packed
    return coeff_arr, coeff_slices

def idb4(image_channel, coeff_slices):
    """Unpack a coefficient array with ``coeff_slices`` and reconstruct via ``pywt.waverec2`` ('db4')."""
    coeffs = pywt.array_to_coeffs(image_channel, coeff_slices, output_format='wavedec2')
    reconstructed = pywt.waverec2(coeffs,'db4')
    return reconstructed

def vec(channels,num_channels):
    """Stack 2D channels into a (num_channels, H, W) float array and flatten it.

    H and W are taken from the first channel.
    """
    rows, cols = channels[0].shape[0], channels[0].shape[1]
    stacked = np.zeros((num_channels, rows, cols))
    for idx, plane in enumerate(channels):
        stacked[idx] = plane
    return stacked.reshape(-1)

def devec(vector,num_channels):
    """Split a flat vector into ``num_channels`` square 2D channels.

    The side length is ``int(sqrt(len(vector) / num_channels))``.
    """
    values_per_channel = vector.shape[0]/num_channels
    side = int(np.sqrt(values_per_channel))
    cube = np.reshape(vector, [num_channels, side, side])
    return [cube[idx, :, :] for idx in range(num_channels)]

def lasso_dct_estimator(args):  #pylint: disable = W0613
    """LASSO with DCT.

    Returns an ``estimator(A_val, y_batch_val, args)`` callable. The outer ``args``
    is not used; the estimator reads ``NUM_CHANNELS`` from the ``args`` it is given.
    """
    def estimator(A_val, y_batch_val, args):
        num_channels = args.NUM_CHANNELS

        # Work on a copy: every column of A is reshaped into channels, passed
        # through the 2D DCT and flattened back, so LASSO is solved in the DCT basis.
        A_dct = copy.deepcopy(A_val)
        for col in range(A_val.shape[1]):
            planes = devec(A_dct[:, col], num_channels)
            A_dct[:, col] = vec([dct2(plane) for plane in planes], num_channels)

        # Only the first entry of the measurement batch is used.
        z_hat = solve_lasso(A_dct, y_batch_val[0], LMBD)

        # Back to pixel space with the inverse 2D DCT, then limit values to [-1, 1].
        pixel_planes = [idct2(plane) for plane in devec(z_hat, num_channels)]
        x_hat = vec(pixel_planes, num_channels).T
        return np.maximum(np.minimum(x_hat, 1), -1)
    return estimator

def lasso_wavelet_estimator(args):  #pylint: disable = W0613
    """LASSO with DWT.

    Returns an ``estimator(A_val, y_batch_val, args)`` callable. The outer ``args``
    is not used; the estimator reads ``NUM_CHANNELS`` from the ``args`` it is given.
    """
    def estimator(A_val, y_batch_val, args):
        num_channels = args.NUM_CHANNELS
        num_columns = A_val.shape[1]
        A_src = copy.deepcopy(A_val)

        # Transform one channel up front to learn the coefficient array shape
        # and the slices needed later for the inverse transform.
        probe, coeff_slices = db4(devec(A_src[:,0], num_channels)[0])
        A_wav = np.zeros((num_channels*probe.shape[0]*probe.shape[1], num_columns))

        # Every column of A is reshaped into channels, wavelet-transformed and flattened.
        for col in range(num_columns):
            planes = devec(A_src[:, col], num_channels)
            A_wav[:, col] = vec([db4(plane)[0] for plane in planes], num_channels)

        # Only the first entry of the measurement batch is used.
        z_hat = solve_lasso(A_wav, y_batch_val[0], LMBD)

        # Back to pixel space with the inverse transform, then limit values to [-1, 1].
        pixel_planes = [idb4(plane, coeff_slices) for plane in devec(z_hat, num_channels)]
        x_hat = vec(pixel_planes, num_channels).T
        return np.maximum(np.minimum(x_hat, 1), -1)
    return estimator

def get_A(dimension,num_measurements):
    """Draw a (dimension, num_measurements) Gaussian matrix scaled by sqrt(1/num_measurements)."""
    scale = np.sqrt(1.0/num_measurements)
    gaussian = np.random.randn(dimension,num_measurements)
    return scale*gaussian

In [None]:
#cs_dip
import numpy as np
import parser
import torch
from torch.autograd import Variable
#import baselines

#import utils
import time

#args = parser.parse_args('configs.json') 

CUDA = torch.cuda.is_available()
dtype = set_dtype(CUDA)
# Element-wise squared error; it is summed and averaged explicitly below.
se = torch.nn.MSELoss(reduction='none').type(dtype)

BATCH_SIZE = 1
EXIT_WINDOW = 51 # number of final iterations whose reconstructions are kept as candidates

def dip_estimator(args):
    """Return the CS-DIP estimator callable.

    The outer ``args`` is not used; the returned
    ``estimator(A_val, y_batch_val, args)`` reads NUM_RESTARTS, NUM_ITER and Z_DIM
    from the ``args`` it is given, and passes that object on to the helper functions.
    """
    def estimator(A_val, y_batch_val, args):
        """Fit a randomly initialised generator to the measurements and return its best image.

        For each restart a fresh network is optimised for ``args.NUM_ITER``
        iterations on measurement loss + weighted learned-regularization loss +
        weighted total-variation loss. Within a restart the reconstruction with
        the lowest total loss among the last EXIT_WINDOW iterations is kept; across
        restarts the one with the lowest final measurement loss is returned.
        """
        # Allocated per call: ``args`` only exists here, not when the cell is
        # first executed (doing this at module level raised a NameError).
        loss_re, recons_re = init_output_arrays(args)

        y = torch.FloatTensor(y_batch_val).type(dtype) # init measurements y
        A = torch.FloatTensor(A_val).type(dtype)       # init measurement matrix A

        # The helpers are defined directly in this notebook, so they are called
        # without a ``utils.`` prefix (that module is never imported here).
        mu, sig_inv, tvc, lrc = get_constants(args, dtype)

        for j in range(args.NUM_RESTARTS):

            net = init_dcgan(args)

            z = torch.zeros(BATCH_SIZE*args.Z_DIM).type(dtype).view(BATCH_SIZE,args.Z_DIM,1,1)
            z.data.normal_().type(dtype) #init random input seed
            if CUDA:
                net.cuda() # cast network to GPU if available

            optim = torch.optim.RMSprop(net.parameters(),lr=0.001, momentum=0.9, weight_decay=0)
            loss_iter = []
            recons_iter = []

            for i in range(args.NUM_ITER):

                optim.zero_grad()

                # calculate measurement loss || y - A*G(z) ||
                G = net(z)
                AG = torch.matmul(G.view(BATCH_SIZE,-1),A) # A*G(z)
                y_loss = torch.mean(torch.sum(se(AG,y),dim=1))

                # calculate total variation loss (absolute differences of neighbouring pixels)
                tv_loss = (torch.sum(torch.abs(G[:,:,:,:-1] - G[:,:,:,1:]))\
                            + torch.sum(torch.abs(G[:,:,:-1,:] - G[:,:,1:,:])))

                # calculate learned regularization loss on the per-parameter means
                layers = net.parameters()
                layer_means = torch.cat([layer.mean().view(1) for layer in layers])
                lr_loss = torch.matmul(layer_means-mu,torch.matmul(sig_inv,layer_means-mu))

                total_loss = y_loss + lrc*lr_loss + tvc*tv_loss # total loss for iteration i

                # keep candidates from the final EXIT_WINDOW iterations only
                if i >= args.NUM_ITER - EXIT_WINDOW:
                    recons_iter.append(G.data.cpu().numpy())
                    loss_iter.append(total_loss.data.cpu().numpy())

                total_loss.backward() # backprop
                optim.step()

            # best candidate of this restart (lowest total loss in the window)
            idx_iter = np.argmin(loss_iter)
            recons_re[j] = recons_iter[idx_iter]
            loss_re[j] = y_loss.data.cpu().numpy()

        idx_re = np.argmin(loss_re,axis=0)
        x_hat = recons_re[idx_re]

        return x_hat

    return estimator


NameError: ignored

In [None]:
import sys
# The repository was cloned as `compsensing_dip` (not `compressing_dip`), and the
# path is only added once so re-running the cell does not keep growing sys.path.
REPO_DIR = '/content/compsensing_dip/'
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)
print(sys.path)
import numpy as np
import pickle as pkl
import os
# NOTE(review): on Python 3.7 `import parser` may resolve to the standard-library
# module of the same name rather than the repository's parser.py - confirm that
# `parser.parse_args` below is the intended function.
import parser

import torch
from torchvision import datasets

import utils  # used as `utils.<helper>` throughout this cell
import cs_dip
import baselines as baselines
import time

NEW_RECONS = False

# NOTE(review): 'configs.json' and the data folders are relative paths, so this
# presumably expects the working directory to be the repository root - confirm.
args = parser.parse_args('configs.json')
print(args)

NUM_MEASUREMENTS_LIST, ALG_LIST = utils.convert_to_list(args)

dataloader = utils.get_data(args) # get dataset of images

for num_meas in NUM_MEASUREMENTS_LIST:
    args.NUM_MEASUREMENTS = num_meas

    # init measurement matrix
    A = baselines.get_A(args.IMG_SIZE*args.IMG_SIZE*args.NUM_CHANNELS, args.NUM_MEASUREMENTS)

    for _, (batch, _, im_path) in enumerate(dataloader):

        eta_sig = 0 # set value to induce noise
        eta = np.random.normal(0, eta_sig * (1.0 / args.NUM_MEASUREMENTS) ,args.NUM_MEASUREMENTS)

        x = batch.view(1,-1).cpu().numpy() # define image
        y = np.dot(x,A) + eta

        for alg in ALG_LIST:
            args.ALG = alg

            if utils.recons_exists(args, im_path): # to avoid redundant reconstructions
                continue
            NEW_RECONS = True

            if alg == 'csdip':
                estimator = cs_dip.dip_estimator(args)
            elif alg == 'dct':
                estimator = baselines.lasso_dct_estimator(args)
            elif alg == 'wavelet':
                estimator = baselines.lasso_wavelet_estimator(args)
            elif alg == 'bm3d' or alg == 'tval3':
                raise NotImplementedError('BM3D-AMP and TVAL3 are implemented in Matlab. \
                                            Please see GitHub repository for details.')
            else:
                raise NotImplementedError

            x_hat = estimator(A, y, args)

            utils.save_reconstruction(x_hat, args, im_path)

if not NEW_RECONS:
    print('Duplicate experiment configurations. No new data generated.')
else:
    print('Reconstructions generated!')


['/content/compressing_dip/', '/content/compressing_dip/', '/content/compressing_dip/', '/content/compressing_dip', '', '/content', '/env/python', '/usr/lib/python37.zip', '/usr/lib/python3.7', '/usr/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/dist-packages', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.7/dist-packages/IPython/extensions', '/root/.ipython']


ModuleNotFoundError: ignored