In [1]:
import numpy as np # to handle matrix and data operation
import pandas as pd # to read csv and handle dataframe
from scipy import signal
from scipy import misc
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import Compose
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import math
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn.modules import Module
from torch.nn.modules.utils import _single, _pair, _triple

Load the CIFAR-10 dataset (training and test splits)

In [None]:
# Data pipeline: images are only converted to float tensors in [0, 1];
# neither split applies augmentation or normalization.
batch_size = 32

transform = transforms.Compose([transforms.ToTensor()])  # not used by the loaders below
train_transform = Compose([transforms.ToTensor()])
test_transform = Compose([transforms.ToTensor()])

trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    transform=train_transform,
    target_transform=None,
    download=True,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,      # every batch has exactly `batch_size` samples
    pin_memory=False,
    num_workers=8,
)

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    transform=test_transform,
    target_transform=None,
    download=True,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,      # note: the final partial test batch is skipped
    pin_memory=False,
    num_workers=8,
)
                                         


Select the compute device and define the network structure

In [None]:
#Check and print the GPU information
if not torch.cuda.is_available():
    print("GPU not detected! Please enable GPU for faster training!")
device = torch.device("cuda:0")
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime → "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

In [15]:
#Function to evaluate the inference accuracy
def test():
  net.eval()
  correct = 0
  total = 0
  with torch.no_grad():
      for data in testloader:
          images, labels = data
          images, labels = images.to(device), labels.to(device)
          outputs = net(images)
          _, predicted = torch.max(outputs.data, 1)
          total += labels.size(0)
          correct += (predicted == labels).sum().item()

  print('Accuracy of the network on the 10000 test images: %6.2f %%' % (
      100 * correct / total))
  return (100 * correct / total)

In [None]:
class _ConvNd(Module):
    """Minimal convolution base class modelled on PyTorch's internal ``_ConvNd``.

    Validates the channel/group configuration, stores the hyper-parameters
    (including an extra ``batch_size`` attribute), allocates ``weight`` and an
    optional ``bias`` and initialises them.  No forward pass is defined here.
    """

    __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size']

    def __init__(self, in_channels, out_channels, kernel_size, batch_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups, bias, padding_mode):
        super(_ConvNd, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.batch_size = batch_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.padding_mode = padding_mode
        # Uninitialised storage; filled by reset_parameters() below.
        if transposed:
            self.weight = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weights and fan-in-bounded uniform bias (PyTorch's conv default)."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        """Summary shown by ``print(module)``; non-default settings only."""
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        # Older pickles may predate ``padding_mode``.
        super(_ConvNd, self).__setstate__(state)
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'


class FTconvlayer(_ConvNd):
    """Simulated optical convolution layer that filters in the Fourier plane.

    For each input channel the image is quantised (``input_quant``), zero-padded
    to 208x208, propagated with ``propTF``, Fourier transformed, multiplied by the
    conjugate transpose of the lens pupil/aberration function ``H`` and by a
    binarised weight mask, and transformed back.  The field amplitude of the
    central crop (same size as the input image) is summed over input channels.

    ``kernel_size`` must be 208 so that the weight mask broadcasts against the
    padded 208x208 field.  ``bias`` is allocated by the base class but is not
    used by ``forward``.  Works on any device; buffers follow the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, batch_size = 16, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        # ``batch_size`` is only recorded as ``self.batch_size``; the forward pass
        # takes the batch size from its input.
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(FTconvlayer, self).__init__(
            in_channels, out_channels, kernel_size, batch_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

    def quantization_n(self, input, n = 1, max = 1):
        """Round ``input`` up onto the grid ``k * max / (2**n - 1)``, clamped to ``[0, max]``.

        With the defaults (n=1, max=1) positive values become 1 and non-positive
        values become 0.
        """
        intv = max/(2**n-1)
        qunt = torch.ceil(torch.mul(input,(1/intv)))
        out = torch.mul(qunt,intv)
        # Keep the quantised value inside [0, max]; anything larger saturates at max.
        out = torch.clamp(out, min=0, max=max)
        return out

    def input_quant(self, input, level = 5):
        """Floor ``input`` onto ``level`` uniform steps of its global maximum.

        The result is clamped to ``[0, max(input)]``.  An all-zero input is
        returned as zeros (the step size would otherwise be 0 and yield NaNs).
        """
        peak = torch.max(input)
        if peak == 0:
            return torch.zeros_like(input)
        intv = peak/level
        qunt = torch.floor(torch.mul(input,(1/intv)))
        out = torch.mul(qunt, intv)
        out = torch.clamp(out, min=0, max=peak)
        return out

    def weightclamp(self, input):
        """Clamp ``input`` in place to be non-negative and return it."""
        return input.clamp_(0)

    def make_complex(self, x):
        """Return ``x`` as a complex tensor with zero imaginary part (same device/precision)."""
        return torch.view_as_complex(torch.stack((x, torch.zeros_like(x)), -1))

    def neg_complex_exp(self, x):
        """Compute exp(-j*x) via Euler's formula: cos(-x) + j*sin(-x)."""
        x_cos = torch.cos(-x)
        x_sin = torch.sin(-x)
        x_euler = torch.stack((x_cos, x_sin), -1)
        return torch.view_as_complex(x_euler)

    def complex_mul(self, x, y):
        """Element-wise (broadcasting) complex product."""
        return x*y

    def conj_transpose(self, x):
        """Complex-conjugate ``x`` and transpose its last two dimensions."""
        x = torch.view_as_real(x)
        size = len(x.size())
        x_r = x[...,0]
        x_i = x[...,1]
        x_i_c =-x_i
        x_conj = torch.stack((x_r,x_i_c),-1)
        # size-1 is the real/imag axis, so size-3 and size-2 are the original last two dims.
        x_conj_t = torch.transpose(x_conj, size-3, size-2)
        return torch.view_as_complex(x_conj_t)

    def roll_n(self, X, axis, n):
        """Cyclically shift ``X`` left by ``n`` along ``axis``."""
        f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim()))
        b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim()))
        front = X[f_idx]
        back = X[b_idx]
        return torch.cat([back, front], axis)

    def batch_fftshift2d(self, x):
        """fftshift of the two spatial dims for a tensor whose last dim holds (real, imag).

        Legacy helper for the real/imag-pair representation; not used by forward().
        """
        real, imag = torch.unbind(x, -1)
        for dim in range(len(real.size())-2, len(real.size())):
            n_shift = real.size(dim)//2
            if real.size(dim) % 2 != 0:
                n_shift += 1  # for odd-sized images
            real = self.roll_n(real, axis=dim, n=n_shift)
            imag = self.roll_n(imag, axis=dim, n=n_shift)
        return torch.stack((real, imag), -1)  # last dim=2 (real&imag)

    def batch_ifftshift2d(self, x):
        """ifftshift counterpart of ``batch_fftshift2d``; not used by forward()."""
        real, imag = torch.unbind(x, -1)
        for dim in range(len(real.size()) - 1, len(real.size())-3, -1):
            real = self.roll_n(real, axis=dim, n=real.size(dim)//2)
            imag = self.roll_n(imag, axis=dim, n=imag.size(dim)//2)
        return torch.stack((real, imag), -1)  # last dim=2 (real&imag)

    def propTF(self, u1, L, lambdaa, z):
        """Propagate a batch of complex fields by distance ``z`` with a transfer function.

        Args:
            u1: complex tensor (batch, M, N); the grid is assumed square (M == N).
            L: physical side length of the field.
            lambdaa: wavelength.
            z: propagation distance.

        Returns:
            Complex tensor with the same shape as ``u1``.
        """
        batch, M, N = u1.shape
        # Centered frequency samples (k - M/2)/L, k = 0..M-1.  This is the grid of
        # arange(-1/(2*dx), 1/(2*dx), 1/L) with dx = L/M, without the risk that a
        # floating-point arange yields M+1 samples.
        fx = (torch.arange(M, device=u1.device) - M/2)/L
        FX, FY = torch.meshgrid(fx, fx)
        H = self.neg_complex_exp(math.pi*lambdaa*z*(FX**2+FY**2))
        H = torch.fft.fftshift(H)
        # Shift only the spatial dims; the batch axis must not be rolled.
        U1 = torch.fft.fft2(torch.fft.fftshift(u1, dim=(-2, -1)))
        U2 = self.complex_mul(H, U1)
        return torch.fft.ifftshift(torch.fft.ifft2(U2), dim=(-2, -1))

    def seidel_5(self, u0, v0, X, Y, wd, w040, w131, w222, w220, w311):
        """Evaluate the Seidel wavefront polynomial at pupil coordinates (X, Y).

        ``u0, v0`` give the field point (scalars).  The coordinates are rotated so
        that the field point lies on the rotated X axis, then the defocus (wd)
        and the w040/w131/w222/w220/w311 terms are summed.
        """
        beta = math.atan2(u0,v0)
        u0r=math.sqrt(u0**2+v0**2)
        Xr=X*math.cos(beta)+Y*math.sin(beta)
        Yr=-X*math.sin(beta)+Y*math.cos(beta)
        rho2=Xr**2+Yr**2
        w=wd*rho2+w040*rho2**2+w131*u0r*rho2*Xr+ w222*u0r**2*Xr**2+w220*u0r**2*rho2+w311*math.pow(u0r,3)*Xr
        return w

    def circ(self, r):
        """Boolean mask of |r| <= 1 (circular aperture when r is a normalised radius)."""
        return torch.abs(r)<=1

    def blockmean_batch(self, X, V, W):
        """Mean over non-overlapping V x W blocks of the last two dims of a 4-D tensor.

        Trailing rows/columns that do not fill a block are dropped; ``X`` is
        returned unchanged when no complete block fits.  PyTorch has no
        Fortran-order reshape, so a normal reshape is combined with permutes.
        """
        S=X.shape
        B1 = S[0]
        B2 = S[1]
        M = int(S[2] - S[2]%V)
        N = int(S[3] - S[3]%W)
        if M*N == 0:
            return X
        MV = int(M/V)
        NW = int(N/W)
        XM = X[:,:,0:M, 0:N].permute(0,1,3,2).reshape([B1,B2,NW, W, MV, V]).permute(0,1,5,4,3,2)
        Y = torch.sum(torch.sum(XM,2),3) * (1/(V*W))
        return Y

    def input_pad(self, input, padsize):
        """Zero-pad the last two dims of ``input`` to ``padsize`` x ``padsize``, centred.

        When the size difference is odd, the extra pixel goes to the bottom/right,
        so the result is always exactly ``padsize`` wide.
        """
        diff_h = padsize - input.shape[-2]
        diff_w = padsize - input.shape[-1]
        top = diff_h // 2
        left = diff_w // 2
        # F.pad order: (left, right, top, bottom) for the last two dims.
        p2d = (left, diff_w - left, top, diff_h - top)
        return F.pad(input, p2d, "constant", 0)

    def input_adjust(self, input):
        """Binarise ``input``: floor(5 * x) clamped to [0, 1]."""
        input = torch.mul(input,5)
        input = torch.floor(input)
        return torch.clamp(input,min=0,max=1)

    def evenkernel(self, input):
        """Strict upper triangle of ``input`` plus its flip over dims 1 and 2 (3-D input)."""
        uptri = torch.triu(input,diagonal = 1)
        downtri = torch.flip(torch.triu(input,diagonal = 1),[1,2])
        return uptri+downtri

    def extract_result(self, input, img_size):
        """Central ``img_size`` x ``img_size`` crop of the last two dims (size read from the last dim)."""
        size = input.shape[-1]
        start = int((size-img_size)/2)
        end = start + img_size
        return input[:,:,start:end,start:end]

    def norm(self, input):
        """Min-max normalise every (batch, channel) slice of a 4-D tensor to [0, 1].

        Constant slices divide by zero and yield NaN.
        """
        maxi = torch.amax(input, dim=(2, 3), keepdim=True)
        mini = torch.amin(input, dim=(2, 3), keepdim=True)
        return (input-mini)/(maxi-mini)

    def accurate_model_forward(self, input, weight):
        """Simulate the optical convolution of ``input`` with binarised Fourier-plane masks.

        Args:
            input: real tensor (batch, in_channels, H, W); square images no larger
                than 208 are assumed (CIFAR: 32x32) -- TODO confirm for other data.
            weight: real tensor (out_channels, in_channels, 208, 208).  Each slice is
                binarised with ``quantization_n`` using a straight-through gradient.

        Returns:
            Real tensor (batch, out_channels, W, W): the field amplitude of the central
            crop, summed over input channels.
        """
        err = 1e-8  # keeps sqrt() differentiable where the field amplitude is zero
        with torch.no_grad():
            input = self.input_quant(input)
        device = input.device
        batch = input.shape[0]
        img_size = input.shape[-1]

        xx = 208  # padded field size in pixels; must match the weight mask size
        yy = xx
        M, N = xx, yy
        L1=1.90e-2*xx/208  # physical side lengths of the sampled field
        L2=1.09e-2*yy/208
        lambdaa = 0.633e-6  # wavelength (presumably metres)
        k=2*math.pi/lambdaa

        '''Lens diffraction (aperture) and aberration'''
        Dxp = 5e-2
        wxp = Dxp/2
        zxp = 200e-3
        lz = lambdaa*zxp
        u0 = 0
        v0 = 0
        f0 = wxp/(lambdaa*zxp)  # radius of the circular pupil in frequency units
        '''Lens parameters for aberration (Seidel coefficients), wavefront deviation from a spherical wave'''
        wd=0*lambdaa
        w040=4.963*lambdaa
        w131=2.637*lambdaa
        w222=9.025*lambdaa
        w220=7.536*2*lambdaa
        w311=0.157*12*lambdaa

        # Centered frequency samples (k - M/2)/L1: the grid of
        # arange(-1/(2*du), 1/(2*du), 1/L1) with du = L1/M, without the risk that a
        # floating-point arange yields an extra sample.
        fu = (torch.arange(M, device=device) - M/2)/L1
        fv = (torch.arange(N, device=device) - N/2)/L2
        Fu, Fv = torch.meshgrid(fu, fv)
        Fu = torch.transpose(Fu,0,1)
        Fv = torch.transpose(Fv,0,1)
        W = self.seidel_5(u0,v0,-lz*Fu/wxp,-lz*Fv/wxp,wd,w040,w131,w222,w220,w311)
        # H = circ(|f| / f0) * exp(-j k W)
        pupil = self.make_complex(self.circ(torch.sqrt(Fu**2 + Fv**2)/f0).float())
        H = self.complex_mul(pupil, self.neg_complex_exp(k*W))
        H_ct = self.conj_transpose(H)  # loop-invariant

        # ----- per-channel optical pipeline (everything above is constant) -----
        output_full = input.new_zeros((batch, self.out_channels, img_size, img_size))
        for c_in in range(input.shape[1]):
            signal = input[:,c_in,:,:]        # (batch, H, W)
            weight_raw = weight[:,c_in,:,:]   # (out_channels, 208, 208)
            # Straight-through estimator: the forward value is exactly the binary
            # mask, and the gradient passes through to ``weight`` unchanged.
            weight_bin = self.quantization_n(weight_raw.detach(), 1, 1) + (weight_raw - weight_raw.detach())
            dmd_1 = self.make_complex(self.input_pad(signal, xx))  # (batch, 208, 208)

            u2 = self.propTF(dmd_1, 1.9e-2, lambdaa, 1.9e-2)

            '''Fourier transform after the first lens'''
            # Shift only the spatial dims: shifting every dim would also roll the
            # batch axis here and the output-channel axis below, permuting channels.
            Gg = torch.fft.fftshift(torch.fft.fft2(u2), dim=(-2, -1))
            Gi = self.complex_mul(Gg, H_ct).unsqueeze(1)               # (batch, 1, 208, 208)

            '''Product with the weight mask in the Fourier plane'''
            Gii = self.complex_mul(Gi, self.make_complex(weight_bin))   # (batch, out, 208, 208)

            '''Back to real space: amplitude of the central crop'''
            Grs = torch.view_as_real(torch.fft.ifft2(torch.fft.ifftshift(Gii, dim=(-2, -1))))
            Ii = torch.sqrt(Grs[...,0]**2+Grs[...,1]**2+err)
            output_full = output_full + self.extract_result(Ii, img_size)
        return output_full

    def forward(self, input):
        return self.accurate_model_forward(input, self.weight)


class FFTconv(nn.Module):
    """CIFAR-10 classifier built around one simulated optical convolution layer.

    Pipeline: FTconvlayer (3 -> 16 channels) -> 2x2 max-pool -> BatchNorm ->
    flatten to 16*16*16 features -> Linear(4096, 256) + ReLU -> Linear(256, 10).
    """

    def __init__(self):
        super(FFTconv, self).__init__()
        # Optical layer: 3 input channels, 16 output channels, 208x208 Fourier-plane masks.
        self.conv1 = FTconvlayer(3, 16, 208)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.fc01 = nn.Linear(16 * 16 * 16, 256)
        self.fc02 = nn.Linear(256, 10)
        # Registered but not applied in forward().
        self.drop_layer = nn.Dropout(p=0.3)

    def forward(self, x):
        optical = self.conv1(x)
        pooled = self.pool1(optical)
        normed = self.bn1(pooled)
        flat = normed.view(-1, 16 * 16 * 16)
        hidden = F.relu(self.fc01(flat))
        logits = self.fc02(hidden)
        return logits

net = FFTconv()

# Move the model to the device chosen in the GPU-check cell. A torch.device is
# always truthy, so no guard is needed; report the actual target device.
net.to(device)
print("put net onto %s" % device)
print(net)


In [None]:
import torch.optim as optim

n_epoch = 15
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 0.001 is also Adam's default learning rate
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Learning-rate schedule: step decay via adjust_learning_rate() at the start of each
# epoch. Do not combine it with a torch.optim.lr_scheduler: this function overwrites
# the learning rate every epoch, so any scheduler would have no effect.
def adjust_learning_rate(optimizer, epoch, init_lr, freq):
    """Set every param group's LR to ``init_lr`` halved once per ``freq`` epochs."""
    lr = init_lr * (0.5 ** (epoch // freq))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

best_acc = 0
for epoch in range(n_epoch):  # loop over the dataset multiple times
    print('Training epoch ...')
    # Re-enter training mode every epoch: test() below puts the network in eval mode.
    net.train()
    start_time = time.time()
    running_loss = 0.0
    adjust_learning_rate(optimizer, epoch, 0.001, 8)
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 200 == 199:    # print every 200 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0
    inf_acc = test()
    best_acc = max(best_acc, inf_acc)
    print('Epoch %d finished in %.1f s' % (epoch + 1, time.time() - start_time))

print('Training finished, best accuracy is ', best_acc)
