In [None]:
import torch
import scipy.io as sio
import numpy as np
import os
import skimage.io
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from scipy.io import loadmat
#import flatcam
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import scipy.misc


**Download the pretrained models**

In [None]:
!wget -q -O flatnet_separable_pointGrey_transposeInit https://www.dropbox.com/s/zh8ucresezrfb5o/flatnet_separable_pointGrey_transposeInit?dl=0
!wget -q -O flatnet_separable_pointGrey_randomInit https://www.dropbox.com/s/kve8ki2wll9lytg/flatnet_separable_pointGrey_randomInit?dl=0

**Mount Google Drive and set the path to the trained model**

In [None]:
# Make the Drive files used below (checkpoint, measurements, calibration) visible under /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')


In [None]:
# Path to one of the pretrained weight files downloaded above; swap the value to pick a variant:
#   'flatnet_separable_pointGrey_transposeInit'  -> Proposed-T
#   'flatnet_separable_pointGrey_randomInit'     -> Proposed-R
modelRoot = 'flatnet_separable_pointGrey_randomInit'

In [None]:
from skimage import transform

# Fixed small rotation (1.74e-3 rad, roughly 0.1 degrees) applied to the Bayer planes in `evaluate`.
tform = transform.SimilarityTransform(rotation=1.74e-3)

In [None]:
class double_conv(nn.Module):
    '''Two 3x3 conv layers, each followed by BatchNorm and ReLU: (conv => BN => ReLU) * 2.

    Stride 1 with padding 1, so H and W are preserved. The order of layers in
    ``self.conv`` defines the state-dict keys (``conv.0.*``, ``conv.1.*``, ...)
    and must stay fixed for pretrained checkpoints to load.
    '''
    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for src_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(src_ch, out_ch, 3, padding=1),
                # NOTE(review): PyTorch's momentum weights the *new* batch statistic,
                # so 0.99 makes the running stats follow the latest batch almost
                # entirely; it has no effect in eval mode and is left unchanged.
                nn.BatchNorm2d(out_ch, momentum=0.99),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

    
    
class double_conv2(nn.Module):
    '''Down-sampling variant of ``double_conv``: (conv => BN => ReLU) * 2.

    The first 3x3 conv uses stride 2 (output size ceil(H/2) x ceil(W/2)); the
    second uses stride 1. Layer order fixes the ``conv.<i>.*`` state-dict keys.
    '''
    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for src_ch, step in ((in_ch, 2), (out_ch, 1)):
            layers += [
                nn.Conv2d(src_ch, out_ch, 3, stride=step, padding=1),
                nn.BatchNorm2d(out_ch, momentum=0.99),  # PyTorch-style momentum; unused in eval mode
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

    
    

class inconv(nn.Module):
    '''U-Net input stage: one ``double_conv`` at full resolution (state-dict prefix ``conv.``).'''
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)


class down(nn.Module):
    '''U-Net encoder stage: a single ``double_conv2`` wrapped in ``nn.Sequential``.

    There is no pooling layer; spatial size is reduced by the stride-2 first
    convolution of ``double_conv2``. The ``mpconv`` attribute and its Sequential
    wrapper are kept because they form the state-dict keys (``mpconv.0.conv.*``).
    '''
    def __init__(self, in_ch, out_ch):
        super().__init__()
        stage = double_conv2(in_ch, out_ch)
        self.mpconv = nn.Sequential(stage)

    def forward(self, x):
        return self.mpconv(x)


class up(nn.Module):
    '''U-Net decoder stage: upsample ``x1``, concatenate the skip tensor ``x2``, then ``double_conv``.

    Args:
        in_ch: channels after concatenation (skip + upsampled). With
            ``bilinear=False`` the transposed convolution maps ``in_ch // 2``
            channels to ``in_ch // 2`` channels.
        out_ch: output channels of the stage.
        bilinear: upsample by fixed bilinear interpolation instead of a learned
            2x2 stride-2 transposed convolution.
    '''
    def __init__(self, in_ch, out_ch, bilinear=False):
        super(up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Match the skip tensor to the upsampled one. F.pad takes pads starting
        # from the LAST dim: (left, right, top, bottom), so width comes first.
        # `d - d // 2` keeps the odd remainder so the total pad always equals d
        # (negative d crops instead of padding).
        diff_h = x1.size(2) - x2.size(2)
        diff_w = x1.size(3) - x2.size(3)
        x2 = F.pad(x2, (diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2))
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class outconv(nn.Module):
    '''Final 3x3 convolution (padding 1) mapping ``in_ch`` feature maps to ``out_ch`` outputs at the same H x W.'''
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

In [None]:
class FlatNet(nn.Module):
    '''FlatNet (separable variant): learned separable inversion followed by a U-Net.

    Each input plane ``X`` of shape (meas_rows, meas_cols) is mapped to a
    (recon_size, recon_size) image ``leaky_relu(PhiL^T @ X @ PhiR)``. The planes
    are stacked in reverse channel order, batch-normalised, and refined by a
    4-level U-Net whose 3-channel output is squashed to (0, 1) by a sigmoid.

    Args:
        n_channels: number of input measurement planes (4 Bayer planes in this notebook).
        meas_rows, meas_cols: size of each input plane; the defaults match the
            500x620 planes built by ``evaluate``.
        recon_size: side length of the square intermediate reconstruction.

    Attribute names (``inc``, ``down1``, ..., ``PhiL``, ``PhiR``, ``bn``) are the
    state-dict keys of the pretrained checkpoints and must not be renamed.
    '''
    def __init__(self, n_channels=4, meas_rows=500, meas_cols=620, recon_size=256):
        super(FlatNet, self).__init__()
        # Creation order is kept as-is so seeded random initialisation is unchanged.
        self.inc = inconv(n_channels, 128)
        self.down1 = down(128, 256)
        self.down2 = down(256, 512)
        self.down3 = down(512, 1024)
        self.down4 = down(1024, 1024)
        self.up1 = up(2048, 512)
        self.up2 = up(1024, 256)
        self.up3 = up(512, 128)
        self.up4 = up(256, 128)
        self.outc = outconv(128, 3)
        # Separable transform matrices, stored with a trailing singleton dimension.
        self.PhiL = nn.Parameter(torch.randn(meas_rows, recon_size, 1))
        self.PhiR = nn.Parameter(torch.randn(meas_cols, recon_size, 1))
        # Normalises the stacked planes; sized by n_channels (previously fixed at 4,
        # which made every other n_channels value unusable).
        self.bn = nn.BatchNorm2d(n_channels, momentum=0.99)

    def _separable_inverse(self, plane):
        '''Map a (B, meas_rows, meas_cols) plane to leaky_relu(PhiL^T @ X @ PhiR) of shape (B, recon_size, recon_size).'''
        phi_l = self.PhiL[:, :, 0]
        phi_r = self.PhiR[:, :, 0]
        # (X @ PhiR)^T @ PhiL, transposed back: same operation order as before.
        out = torch.matmul(torch.matmul(plane, phi_r).permute(0, 2, 1), phi_l)
        return F.leaky_relu(out.permute(0, 2, 1))

    def forward(self, Xinp):
        '''Reconstruct a 3-channel image from a (B, n_channels, meas_rows, meas_cols) batch.

        Output is (B, 3, recon_size, recon_size) when recon_size is a multiple of 16.
        '''
        # bn.num_features == n_channels. Planes are stacked in reverse order:
        # input channel n-1 becomes channel 0 of the U-Net input.
        n_planes = self.bn.num_features
        planes = [self._separable_inverse(Xinp[:, c, :, :])
                  for c in reversed(range(n_planes))]
        x = self.bn(torch.stack(planes, dim=1))
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return torch.sigmoid(self.outc(x))



In [None]:
print(a.keys())

In [None]:
a = torch.load('/content/drive/MyDrive/flatnet-flatnet-sep/latest.tar',map_location=torch.device('cpu'))

In [None]:
print(a['gen_state_dict'].keys())

**Create the model and load the trained weights into its state dictionary**

In [None]:
flatnet = FlatNet(4)
# Alternative: the pretrained Dropbox weights downloaded at the top of the notebook.
#flatnet.load_state_dict(torch.load(modelRoot,map_location=torch.device('cpu')))
# NOTE(review): this cell used to load `flatnet_casia`, which no cell defines; the
# generator weights stored in the checkpoint `a` are assumed to be the intended ones -- confirm.
flatnet.load_state_dict(a['gen_state_dict'])

flatnet.eval()

**Function to pre-process a raw measurement and reconstruct it with the model**

In [None]:
def evaluate(X, model=None, max_value=65535.0):
    '''Pre-process one raw measurement and reconstruct it with FlatNet.

    Pipeline: scale by ``max_value``, split the Bayer mosaic into four 500x620
    planes, apply the small rotation ``tform``, subtract row/column means, run
    the network and min-max normalise its output for display.

    Args:
        X: single-channel raw measurement of shape (1000, 1240).
        model: network to run; defaults to the global ``flatnet``.
        max_value: full-scale sensor value used for scaling (default 65535.0, i.e. 16-bit data).

    Returns:
        (H, W, C) array of the model output scaled to [0, 1]
        (all zeros if the output is constant).

    Raises:
        ValueError: if ``X`` does not have shape (1000, 1240).
    '''
    net = flatnet if model is None else model
    rows, cols = 500, 620  # size of each Bayer plane fed to the network

    X = np.asarray(X) / max_value
    if X.shape != (2 * rows, 2 * cols):
        raise ValueError('expected a single-channel raw measurement of shape '
                         f'{(2 * rows, 2 * cols)}, got {X.shape}')

    # Split the mosaic into 4 planes, channel-last (H, W, 4) for skimage.
    im1 = np.stack([X[0::2, 0::2],   # b
                    X[0::2, 1::2],   # gb
                    X[1::2, 0::2],   # gr
                    X[1::2, 1::2]],  # r
                   axis=-1).astype(np.float64)
    im1 = transform.warp(im1, tform)

    # Remove the per-row and per-column means of every plane.
    # NOTE(review): allMean is the mean over all four planes, not per plane; left
    # unchanged because altering it changes the network input -- confirm against training.
    rowMeans = im1.mean(axis=1, keepdims=True)
    colMeans = im1.mean(axis=0, keepdims=True)
    allMean = rowMeans.mean()
    im1 = im1 - rowMeans - colMeans + allMean

    # (H, W, 4) -> contiguous (1, 4, H, W) float32 batch.
    X_val = torch.from_numpy(
        np.ascontiguousarray(im1.transpose(2, 0, 1)[np.newaxis], dtype=np.float32))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out = net(X_val)

    ims = out.detach().cpu().numpy()[0].transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    lo, hi = ims.min(), ims.max()
    if hi > lo:
        return (ims - lo) / (hi - lo)
    # Constant output: avoid 0/0 producing an all-NaN image.
    return np.zeros_like(ims)


**Load the measurement and evaluate it on the trained model**

In [None]:
import torchvision.transforms.functional as TF
from PIL import Image

# Inspect one measurement as a tensor; `to_tensor` returns it channel-first.
image = Image.open(
    '/content/drive/MyDrive/RAFDB_Alligned/measurements/0013_meas.png'
)
x = TF.to_tensor(image)
print(x.shape)

In [None]:
# Uncomment the below line if using local jupyter server
# !wget -q -O fc_8.png https://raw.githubusercontent.com/siddiquesalman/flatnet/flatnet-sep/example_data/fc_8.png
# Replace the input path with `fc_8.png`
# (or read the example directly: skimage.io.imread("https://raw.githubusercontent.com/siddiquesalman/flatnet/flatnet-sep/example_data/fc_8.png"))

X = skimage.io.imread('/content/drive/MyDrive/RAFDB_Alligned/measurements/0014_meas.png')
# evaluate() expects a single-channel (1000, 1240) array scaled by its max_value (65535 by default).
print(X.shape, X.dtype)
recn = evaluate(X)

# skimage.io.imshow is deprecated; display with matplotlib instead.
fig, ax = plt.subplots()
ax.imshow(recn)
ax.set_title('FlatNet reconstruction')
ax.axis('off')
plt.show()
    

In [None]:
mat = loadmat('/content/drive/MyDrive/Colab Notebooks/lensless_imaging/flatcam_calibdata.mat')
# Crop size used as Raw2Bayer's default `crop_size`, indexed there as [rows, cols].
cSize = mat['cSize'].squeeze().astype(int)


In [None]:
# Same mount as the setup cell near the top; presumably a no-op if Drive is already mounted.
import google.colab.drive as drive
drive.mount(mountpoint='/content/drive')

In [None]:
def Raw2Bayer(x, crop_size = cSize, is_rotate = False):
    r''' Convert FlatCam raw data to 4 Bayer planes and centre-crop them.

    Args:
        x: raw tensor of shape (N, C, H, W); only channel 0 is read unless
            ``is_rotate`` is True.
        crop_size: ``(rows, cols)`` of the centre crop; defaults to ``cSize``
            from the calibration file (bound when this cell is executed).
        is_rotate: experimental rotation path -- see the note in the body.

    Returns:
        Tensor of shape (N, 4, rows, cols) in torch's default float dtype, on
        the same device as ``x``.
    '''
    n, _, h, w = x.size()
    h2, w2 = h // 2, w // 2

    # Allocate on the input's device (was hard-coded to CUDA, which fails on CPU runtimes).
    y = torch.zeros((n, 4, h2, w2), device = x.device)

    if is_rotate:                       # ---> THIS MODE DOES NOT WORK YET!!! (2019.07.14)
        # NOTE(review): `kr` is never imported in this notebook (presumably kornia);
        # this branch raises NameError unless it is imported elsewhere.
        scale = torch.ones(1)
        angle = torch.ones(1) * 0.05 * 360              # 0.05 is angle collected from data measurements
        center = torch.ones(1, 2)
        center[..., 0] = int(h / 4)  # x
        center[..., 1] = int(w / 4)  # y
        M = kr.get_rotation_matrix2d(center, angle, scale).to(x.device)
        _, _, h, w = y.size()

        y[:, 0, :, : ] = kr.warp_affine(x[:, :, 1::2, 1::2], M, dsize = (h, w))
        y[:, 1, :, : ] = kr.warp_affine(x[:, :, 0::2, 1::2], M, dsize = (h, w))
        y[:, 2, :, : ] = kr.warp_affine(x[:, :, 1::2, 0::2], M, dsize = (h, w))
        y[:, 3, :, : ] = kr.warp_affine(x[:, :, 0::2, 0::2], M, dsize = (h, w))

    else:
        # Each output channel takes one (row, col) phase of the 2x2 mosaic. Slices are
        # trimmed to (h2, w2) so odd-sized inputs no longer raise a shape mismatch.
        for ch, (r0, c0) in enumerate(((1, 1), (0, 1), (1, 0), (0, 0))):
            y[:, ch, :, :] = x[:, 0, r0::2, c0::2][..., :h2, :w2]

    # Centre crop to crop_size = (rows, cols).
    start_row = int((y.size()[2] - crop_size[0]) / 2)
    end_row = start_row + crop_size[0]
    start_col = int((y.size()[3] - crop_size[1]) / 2)
    end_col = start_col + crop_size[1]
    return y[:, :, start_row:end_row, start_col:end_col]
