In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import Encoder_Decoder_Model

# Pick the compute device: a CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
device

Load Data

In [None]:
# Diffraction inputs: the compressed 1775x1775 diff grid is used here. The
# non-compressed 7100x7100 grid is stored at 'dataset/diff_grid.npz' (same 'arr_0' key).
# np.load on an .npz returns an NpzFile that keeps the archive open; using it as a
# context manager closes the file handle once the array has been read into memory.
with np.load('dataset/compressed_diff_grid.npz') as diff_npz:
  diff_grid = diff_npz['arr_0']
with np.load('dataset/norm_diffraction_label.npz') as label_npz:
  label = label_npz['arr_0']

In [None]:
# Element-wise binary cross-entropy, averaged over all elements.
# NOTE(review): nn.BCELoss requires predictions already in [0, 1] (e.g. a sigmoid
# output) — confirm the model's output heads guarantee this.
lossfn = nn.BCELoss(reduction='mean')

def ModelLoss(preds1, targets1, preds2, targets2):
  """Compute `lossfn` separately for two (prediction, target) pairs.

  Args:
    preds1, targets1: first prediction/target pair (the training loop passes phase).
    preds2, targets2: second prediction/target pair (the training loop passes amplitude).

  Returns:
    A tuple ``(loss1, loss2)`` of scalar loss tensors, in argument order.
  """
  pairs = ((preds1, targets1), (preds2, targets2))
  return tuple(lossfn(pred, target) for pred, target in pairs)

In [None]:
# Build the encoder-decoder model and warm-start it from a saved checkpoint.
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint opened
# on a CPU-only machine) onto the selected device instead of failing.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
PtychoModel = Encoder_Decoder_Model.Model().to(device)
PtychoModel.load_state_dict(torch.load('models/MSE_10000.pth', map_location=device))

# Move inputs and targets to the device as float32. Passing dtype converts in one
# step instead of first copying in the NumPy dtype and then casting with .float().
diff = torch.tensor(diff_grid, dtype=torch.float32, device=device)
phase = torch.tensor(label[:, 0], dtype=torch.float32, device=device)  # label column 0: phase target
amp = torch.tensor(label[:, 1], dtype=torch.float32, device=device)    # label column 1: amplitude target

# Optimiser and learning-rate schedule.
LR = 0.00013
step_size = 8000  # scheduler steps (one per epoch in the training loop) in the rising half of a cycle
BETAS = (0.59418, 0.8699)  # NOTE(review): non-default Adam betas, presumably from a tuning run
optimizer = optim.AdamW(PtychoModel.parameters(), lr=LR, betas=BETAS)
# 'triangular2' cycles the LR between LR/10 and LR, halving the amplitude every cycle.
scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=LR/10, max_lr=LR, step_size_up=step_size, cycle_momentum=False, mode='triangular2')

In [None]:
num_epochs = 10000
LOG_EVERY = 50                         # print losses every LOG_EVERY epochs
SAVE_PATH = 'models/model_name.pth'    # NOTE(review): placeholder filename — rename per run

# Full-batch training: every epoch is one forward/backward pass over all of `diff`.
# Nothing inside the loop changes the module mode, so set train() once up front.
PtychoModel.train()
for epoch in range(num_epochs):
  phase_pred, amp_pred = PtychoModel(diff)
  loss1, loss2 = ModelLoss(phase_pred, phase, amp_pred, amp)
  loss = loss1 + loss2  # unweighted sum of the phase and amplitude losses

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  scheduler.step()  # CyclicLR advances once per epoch (= once per optimizer step here)

  if (epoch + 1) % LOG_EVERY == 0:
    print(f"Epoch: {epoch + 1}  Training Loss: {loss.item():.5f}  L1: {loss1.item():.7f}  L2: {loss2.item():.7f}")

torch.save(PtychoModel.state_dict(), SAVE_PATH)

In [None]:
# Inference over the full input set. torch.no_grad() skips building the autograd
# graph, which plotting does not need and which would otherwise keep activations alive.
PtychoModel.eval()
with torch.no_grad():
  phase_pred, amp_pred = PtychoModel(diff)

In [None]:
# Compare network predictions with the E-Pie reconstructions for the first few samples.
# NOTE(review): assumes every flattened sample has 650*650 elements — confirm against the data.
IMG_SHAPE = (650, 650)
N_SHOWN = 4  # number of samples (columns) to display

# One entry per figure row: (y-axis label, tensor whose first N_SHOWN samples are shown).
plot_rows = [
  ('PtychoNeuralNetwork\n(phase)', phase_pred),
  ('E-Pie (300 Iterations)\n(phase)', phase),
  ('PtychoNeuralNetwork\n(amplitude)', amp_pred),
  ('E-Pie (300 Iterations)\n(amplitude)', amp),
]

f, ax = plt.subplots(len(plot_rows), N_SHOWN, figsize=(12, 12), facecolor='white')
for row, (ylabel, images) in enumerate(plot_rows):
  ax[row, 0].set_ylabel(ylabel, fontsize=12.0)
  for col in range(N_SHOWN):
    ax[row, col].imshow(images[col].detach().cpu().numpy().reshape(IMG_SHAPE))
plt.show()
