Imports

In [7]:
import os
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import wandb

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

HyperParameter Setup

In [2]:
ADAM_LR = 0.000018
BETAS = (0.75, 0.999)
NUM_EPOCHS = 3000

wandb_config = {
    "Learning_Rate": ADAM_LR,
    "Betas": BETAS,
    "Num_Epochs": NUM_EPOCHS
}

Hyper Parameter Sweep Config

In [3]:
sweep_config = {
    'method': 'random'
}

metric = {
    'name' : 'loss',
    'goal' : 'minimize'
}


parameters_dict = {
    'learning_rate': {
        'values': [0.000050, 0.000040, 0.000030, 0.000020, 0.000010, 0.000044, 0.000035, 0.000025, 0.000015]
    },
    'beta_val1': {
        'distribution': 'uniform',
        'min': 0.5,
        'max': 0.99999
    },
    'beta_val2': {
        'distribution': 'uniform',
        'min': 0.5,
        'max': 0.99999
    },
    'epochs': {
        'value': 3000
    }
    
}
sweep_config['metric'] = metric
sweep_config['parameters'] = parameters_dict

Weights and Biases Setup

In [4]:
wandb.login()

# sweep_id = wandb.sweep(sweep_config, project="ePIE - HyperParameter Sweep")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msrsbingresearch[0m ([33mbingsrs[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

Load Data

In [5]:
diff_grid = np.load('dataset/diff_grid.npz')['arr_0']
label = np.load('dataset/diffraction_label.npz')['arr_0']

Encoder

In [6]:
def conv(in_channels, out_channels):
  return nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=3, padding=1),
    nn.LeakyReLU(negative_slope=0.01, inplace=True),
    nn.BatchNorm2d(out_channels),
  )
  
def conv_max(in_channels, out_channels):
  return nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=3, padding=1),
    nn.MaxPool2d(3, stride=2),
    nn.LeakyReLU(negative_slope=0.01, inplace=True),
    nn.BatchNorm2d(out_channels),
  )

In [7]:
class Encoder(nn.Module):
  def __init__(self):
    super().__init__()
    self.block1 = conv(1, 16)
    self.block2 = conv_max(16, 32)
    self.block3 = conv_max(32, 64)
    self.block4 = conv_max(64, 128)

  def forward(self, x):
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.block4(x)
    return x

Decoder

In [8]:
def convTrans(in_channels, out_channels):
  return nn.Sequential(
    nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=2),
    #nn.ReLU(inplace = True),
    nn.LeakyReLU(negative_slope=0.01, inplace=True),
  )

def up_conv(in_channels, out_channels, padding):
  return nn.Sequential(
    nn.ConvTranspose2d(in_channels, out_channels, kernel_size=5, stride=2, padding=padding),
    #nn.ReLU(inplace = True),
    nn.LeakyReLU(negative_slope=0.01, inplace=True),
    nn.BatchNorm2d(out_channels),
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
  )

In [9]:
class Decoder(nn.Module):
  def __init__(self):
    super().__init__()
    self.block1 = up_conv(128, 64, 1)
    self.block2 = convTrans(64, 32)
    self.block3 = up_conv(32, 16, 2)
    self.block4 = convTrans(16, 1)

  def forward(self, x):
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.block4(x)
    return x

Encoder-Decoder Model

In [10]:
class Model(nn.Module):
  def __init__(self):
    super(Model, self).__init__()
    self.encoder = Encoder()
    self.phase_decoder = Decoder()
    self.amp_decoder = Decoder()
    self.tanh = nn.Tanh()
    self.sigmoid = nn.Sigmoid()

  def forward(self, diffraction):
    latent_z = self.encoder(diffraction)
    phase = self.tanh(self.phase_decoder(latent_z))
    phase = phase*cp.pi
    amp = self.sigmoid(self.amp_decoder(latent_z))
    return phase, amp

Training

In [11]:
lossfn =  nn.MSELoss()
def ModelLoss(preds1, targets1, preds2, targets2):
  loss1 = lossfn(preds1, targets1)
  loss2 = lossfn(preds2, targets2)
  # loss2 = nn.functional.binary_cross_entropy(preds2, targets2, reduction='mean')
  return loss1, loss2

In [12]:
PtychoModel = Model().to(device)
PtychoModel.load_state_dict(torch.load('models/overfit4.pth'))
diff = torch.tensor(diff_grid,device=device).float()
phase = torch.tensor(label[:, 0],device=device).float()
amp = torch.tensor(label[:, 1],device=device).float()

In [13]:
def train():
  # Look into Scheduler: (varies the learning rate of optimizer) scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=, max_lr=, step_size_up=)
  optimizer = torch.optim.Adam(PtychoModel.parameters(), lr=ADAM_LR, betas=BETAS)
  for epoch in range(NUM_EPOCHS):

    PtychoModel.train()
    phase_pred, amp_pred = PtychoModel(diff)
    loss1, loss2 = ModelLoss(phase_pred, phase, amp_pred, amp)
    loss = loss1+loss2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch+1) % 100 == 0:
      print("Epoch: ", (epoch+1), "Training Loss: ", round(loss.item(), 5), round(loss1.item(), 7), round(loss2.item(), 7))
    wandb.log({
      'loss': round(loss.item(),3),
      'loss1': round(loss1.item(),4),
      'loss2': round(loss2.item(),4)
    })

In [14]:
# wandb.agent(sweep_id, train, count = 10)
wandb.init(config=wandb_config)
train()
wandb.finish()

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  return F.mse_loss(input, target, reduction=self.reduction)


In [None]:
PtychoModel.eval()
phase_pred, amp_pred = PtychoModel(diff)

In [None]:
f, ax = plt.subplots(4,4, figsize=(12, 12), facecolor='white')
ax[0,0].set_ylabel('PtychoNeuralNetwork', fontsize = 12.0)
ax[1,0].set_ylabel('E-Pie (300 Iterations)', fontsize = 12.0)
ax[2,0].set_ylabel('PtychoNeuralNetwork', fontsize = 12.0)
ax[3,0].set_ylabel('E-Pie (300 Iterations)', fontsize = 12.0)

ax[0,0].imshow(phase_pred[0].cpu().detach().numpy().reshape((650,650)))
ax[0,1].imshow(phase_pred[1].cpu().detach().numpy().reshape((650,650)))
ax[0,2].imshow(phase_pred[2].cpu().detach().numpy().reshape((650,650)))
ax[0,3].imshow(phase_pred[3].cpu().detach().numpy().reshape((650,650)))
ax[1,0].imshow(phase[0].cpu().detach().numpy().reshape((650,650)))
ax[1,1].imshow(phase[1].cpu().detach().numpy().reshape((650,650)))
ax[1,2].imshow(phase[2].cpu().detach().numpy().reshape((650,650)))
ax[1,3].imshow(phase[3].cpu().detach().numpy().reshape((650,650)))
ax[2,0].imshow(amp_pred[0].cpu().detach().numpy().reshape((650,650)))
ax[2,1].imshow(amp_pred[1].cpu().detach().numpy().reshape((650,650)))
ax[2,2].imshow(amp_pred[2].cpu().detach().numpy().reshape((650,650)))
ax[2,3].imshow(amp_pred[3].cpu().detach().numpy().reshape((650,650)))
ax[3,0].imshow(amp[0].cpu().detach().numpy().reshape((650,650)))
ax[3,1].imshow(amp[1].cpu().detach().numpy().reshape((650,650)))
ax[3,2].imshow(amp[2].cpu().detach().numpy().reshape((650,650)))
ax[3,3].imshow(amp[3].cpu().detach().numpy().reshape((650,650)))


In [None]:
f, ax = plt.subplots(2,2, figsize=(11, 10), facecolor='white')
ax[0,0].set_ylabel('PtychoNeuralNetwork', fontsize = 20.0)
ax[0,0].set_title('Amplitude', fontsize = 20.0)
ax[0,0].imshow(phase_pred[0].cpu().detach().numpy().reshape((650,650)))
ax[0,1].set_title('Phase', fontsize = 20.0)
ax[0,1].imshow(amp_pred[0].cpu().detach().numpy().reshape((650,650)))
ax[1,0].set_ylabel('E-Pie (300 Iterations)', fontsize = 20.0)
ax[1,0].imshow(phase[0].cpu().detach().numpy().reshape((650,650)))
ax[1,1].imshow(amp[0].cpu().detach().numpy().reshape((650,650)))

In [None]:
torch.save(PtychoModel.state_dict(), 'overfit4.pth')