In [1]:
#it is recommended to run in Google Colab.

In [None]:
!pip install -i https://test.pypi.org/simple/ lightbridge

In [3]:
import os
import csv
from time import time
import random
import pathlib
import argparse
import numpy as np
from tqdm import tqdm
import pandas as pd
import torch
import torchvision
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import pickle
import lightbridge.utils as utils
import lightbridge.layers as layers

In [4]:
class NetCodesign(torch.nn.Module):
    def __init__(self, phase_func, intensity_func, wavelength=5.32e-7, pixel_size=0.000036, batch_norm=False, sys_size = 200, distance=0.1, num_layers=2, precision=256, amp_factor=6):
        super(NetCodesign, self).__init__()
        self.amp_factor = amp_factor
        self.size = sys_size
        self.distance = distance
        self.phase_func = phase_func.cuda()
        self.intensity_func = intensity_func.cuda()
        self.wavelength = wavelength
        self.pixel_size = pixel_size
        #self.phase_func = phase_func
        #self.intensity_func = intensity_func
        self.diffractive_layers = torch.nn.ModuleList([layers.DiffractiveLayer(self.phase_func, self.intensity_func, size=self.size,
                                                    wavelength = self.wavelength, pixel_size = self.pixel_size,
                                distance=self.distance, amplitude_factor = amp_factor, phase_mod=True) for _ in range(num_layers)])
        self.last_diffraction = layers.DiffractiveLayer(None, None, size=self.size, distance=self.distance, phase_mod=False)
        # 200 by 200 system siz det designe
        self.detector = layers.Detector(start_x = [46,46,46], start_y = [46,46,46], det_size = 20,
                                        gap_x = [19,20], gap_y = [27, 12, 27])
    def forward(self, x):
        for index, layer in enumerate(self.diffractive_layers):
            x = layer(x)
        x = self.last_diffraction(x)
        output = self.detector(x)
        return output


In [5]:
def train(model,train_dataloader, val_dataloader,input_padding, lambda1):
    criterion = torch.nn.MSELoss(reduction='sum').cuda()
    print('training starts.')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.7)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=20, gamma=0.5)
    for epoch in range(1, 21):
        log = [epoch]
        model.train()
        train_len = 0.0
        train_running_counter = 0.0
        train_running_loss = 0.0
        tk0 = tqdm(train_dataloader, ncols=150, total=int(len(train_dataloader)))
        for train_iter, train_data_batch in enumerate(tk0):
            train_images, train_labels = utils.data_to_cplex(train_data_batch)
            train_outputs = model(train_images)
            train_loss_ = lambda1 * criterion(train_outputs, train_labels)
            train_counter_ = torch.eq(torch.argmax(train_labels, dim=1), torch.argmax(train_outputs, dim=1)).float().sum()

            optimizer.zero_grad()
            train_loss_.backward(retain_graph=True)
            optimizer.step()
            train_len += len(train_labels)
            train_running_loss += train_loss_.item()
            train_running_counter += train_counter_

            train_loss = train_running_loss / train_len
            train_accuracy = train_running_counter / train_len

            tk0.set_description_str('Epoch {}/{} : Training'.format(epoch, 20))
            tk0.set_postfix({'Train_Loss': '{:.2f}'.format(train_loss), 'Train_Accuracy': '{:.5f}'.format(train_accuracy)})
        scheduler.step()
        log.append(train_loss)
        log.append(train_accuracy)

        with open('./result.csv', 'a', newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(log)
        val_loss, val_accuracy = eval(model, val_dataloader, epoch,input_padding)
        log.append(val_loss)
        log.append(val_accuracy)
    return train_loss, train_accuracy, val_loss, val_accuracy, log

In [6]:
def eval(model, val_dataloader, epoch, input_padding):
    criterion = torch.nn.MSELoss(reduction='sum').cuda()
    with torch.no_grad():
        model.eval()
        val_len = 0.0
        val_running_counter = 0.0
        val_running_loss = 0.0

        tk1 = tqdm(val_dataloader, ncols=100, total=int(len(val_dataloader)))
        for val_iter, val_data_batch in enumerate(tk1):
            val_images, val_labels = utils.data_to_cplex(val_data_batch)
            val_outputs = model(val_images)

            val_loss_ = criterion(val_outputs, val_labels)
            val_counter_ = torch.eq(torch.argmax(val_labels, dim=1), torch.argmax(val_outputs, dim=1)).float().sum()

            val_len += len(val_labels)
            val_running_loss += val_loss_.item()
            val_running_counter += val_counter_

            val_loss = val_running_loss / val_len
            val_accuracy = val_running_counter / val_len

            tk1.set_description_str('Epoch {}/{} : Validating'.format(epoch, 20))
            tk1.set_postfix({'Val_Loss': '{:.5f}'.format(val_loss), 'Val_Accuarcy': '{:.5f}'.format(val_accuracy)})
    return val_loss, val_accuracy


In [8]:
#Start training, example with Dataset: MNIST-10; Precision=256; phase and intensity files are user-uploaded; 
#depth of the model is 2; system_size is 200; distance between layers is 0.6604m; 
#amp_factor is 50; learning rate is 0.7; training epoch is 20;

torch.autograd.set_detect_anomaly(True)

transform = transforms.Compose([transforms.Resize((200),interpolation=2),transforms.ToTensor()])
print("training and testing on MNIST10 dataset")
train_dataset = torchvision.datasets.MNIST("./data", train=True, transform=transform, download=True)
val_dataset = torchvision.datasets.MNIST("./data", train=False, transform=transform, download=True)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=600, num_workers=8, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=600, num_workers=8, shuffle=False, pin_memory=True)
input_padding = 0

phase_file = "phase.csv"
phase_function = utils.phase_func(phase_file,  i_k=256)
with open('phase_file.npy', 'wb') as f_phase:
        np.save(f_phase, phase_function.cpu().numpy())
intensity_file = "intensity.csv"
intensity_function = utils.intensity_func(intensity_file,  i_k=256)
with open('intensity_file.npy', 'wb') as f_amp:
        np.save(f_amp, intensity_function.cpu().numpy())

model = NetCodesign(num_layers=2, batch_norm =False, wavelength=5.32e-7, pixel_size=0.000036, sys_size=200, distance=0.6604,phase_func=phase_function, intensity_func=intensity_function, precision=256, amp_factor=50)
model.cuda()
lambda1= 1

start_time = time()
train(model, train_dataloader, val_dataloader, input_padding, lambda1)
print('run time', time()-start_time)
exit()

  "Argument interpolation should be of type InterpolationMode instead of int. "
  cpuset_checked))


training and testing on MNIST10 dataset


  0%|                                                                                                                         | 0/100 [00:00<?, ?it/s]

training starts.


Epoch 1/20 : Training: 100%|███████████████████████████████████████████████| 100/100 [00:33<00:00,  2.94it/s, Train_Loss=0.76, Train_Accuracy=0.58450]
Epoch 1/20 : Validating: 100%|█| 17/17 [00:05<00:00,  3.40it/s, Val_Loss=0.39255, Val_Accuarcy=0.786
Epoch 2/20 : Training: 100%|███████████████████████████████████████████████| 100/100 [00:33<00:00,  3.01it/s, Train_Loss=0.26, Train_Accuracy=0.85677]
Epoch 2/20 : Validating: 100%|█| 17/17 [00:05<00:00,  3.39it/s, Val_Loss=0.15603, Val_Accuarcy=0.913
Epoch 3/20 : Training: 100%|███████████████████████████████████████████████| 100/100 [00:33<00:00,  3.02it/s, Train_Loss=0.16, Train_Accuracy=0.90960]
Epoch 3/20 : Validating: 100%|█| 17/17 [00:05<00:00,  3.33it/s, Val_Loss=0.16081, Val_Accuarcy=0.911
Epoch 4/20 : Training: 100%|███████████████████████████████████████████████| 100/100 [00:33<00:00,  3.01it/s, Train_Loss=0.13, Train_Accuracy=0.92610]
Epoch 4/20 : Validating: 100%|█| 17/17 [00:05<00:00,  3.33it/s, Val_Loss=0.11417, Val_Accuarc

run time 768.3078391551971



