In [None]:
import os
import cv2
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch import optim
#from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
import ViT_Model

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch (model initialisation, DataLoader shuffling) and numpy so that a
# Restart & Run All starts from the same random state. GPU kernels may still
# introduce some run-to-run nondeterminism.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device

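Load data: the compressed diffraction grids (model inputs) and the normalized diffraction labels (phase and amplitude targets).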

In [None]:
class SegData(Dataset):
  """In-memory dataset of diffraction grids paired with phase/amplitude labels.

  Both inputs are ``.npz`` archives whose array is stored under the default
  key ``arr_0`` (the name ``np.savez`` gives a positional array). The label
  array is split along axis 1: index 0 is exposed as ``phase`` and index 1 as
  ``amp``. Everything is converted to float32 tensors at construction time.

  Args:
    diff_path: path to the diffraction-grid archive (model inputs).
    label_path: path to the label archive.

  Raises:
    ValueError: if the inputs and labels have different numbers of samples.
  """
  def __init__(self, diff_path='../dataset/compressed_diff_grid.npz',
               label_path='../dataset/norm_diffraction_label.npz'):
    # NpzFile keeps the archive open; using it as a context manager closes the
    # handle as soon as the array has been read into memory.
    with np.load(diff_path) as diff_npz:
      self.diff_grid = torch.tensor(diff_npz['arr_0']).float()
    with np.load(label_path) as label_npz:
      labels = label_npz['arr_0']
    self.phase = torch.tensor(labels[:, 0]).float()
    self.amp = torch.tensor(labels[:, 1]).float()
    if self.diff_grid.shape[0] != self.phase.shape[0]:
      raise ValueError(
          f"diffraction grids ({self.diff_grid.shape[0]}) and labels "
          f"({self.phase.shape[0]}) have different numbers of samples")

  def __len__(self):
    """Number of samples (first axis of the diffraction-grid array)."""
    return self.diff_grid.shape[0]

  def __getitem__(self, i):
    """Return the ``(diff_grid, phase, amp)`` tensors for sample ``i``."""
    return (self.diff_grid[i], self.phase[i], self.amp[i])

In [None]:
# Training set (default dataset paths) and a shuffling loader over it.
BATCH_SIZE = 2
trn_ds = SegData()
trn_dl = DataLoader(trn_ds, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Binary cross-entropy averaged over all elements. nn.BCELoss expects its
# inputs to already be probabilities in [0, 1] (it applies no sigmoid);
# NOTE(review): confirm ViTGenerator's output is bounded accordingly.
lossfn = nn.BCELoss(reduction='mean')


def ModelLoss(preds1, targets1):
  """Return the mean BCE between ``preds1`` and same-shaped ``targets1``."""
  return lossfn(preds1, targets1)

In [None]:
# Optimisation hyperparameters.
LR = 0.00013
step_size = 8000              # CyclicLR step_size_up: scheduler steps per half-cycle
ADAMW_BETAS = (0.59418, 0.8699)

ViTModel = ViT_Model.ViTGenerator().to(device)
criterion = lossfn  # kept for reference; the training loop calls ModelLoss directly

optimizer = optim.AdamW(ViTModel.parameters(), lr=LR, betas=ADAMW_BETAS)
# Learning rate oscillates between LR/10 and LR; 'triangular2' halves the
# amplitude each cycle. Momentum cycling is off (AdamW uses betas instead).
scheduler = optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=LR / 10,
    max_lr=LR,
    step_size_up=step_size,
    mode='triangular2',
    cycle_momentum=False,
)

In [None]:
# Per-sample input shape without the batch dimension; presumably
# (channels, height, width) — TODO confirm against ViT_Model.
SUMMARY_INPUT_SHAPE = (1, 1775, 1775)
summary(ViTModel, SUMMARY_INPUT_SHAPE)

In [None]:
num_epochs = 100
PRINT_EVERY = 1   # report the mean training loss every N epochs
SAVE_PATH = None  # set to e.g. 'models/model_name.pth' to save the trained weights

for epoch in range(num_epochs):
  # Set training mode once per epoch (the evaluation cell switches to eval mode,
  # and this cell may be re-run afterwards).
  ViTModel.train()
  running_loss = 0.0
  n_batches = 0
  for data in trn_dl:
    # The DataLoader yields a list [diff, phase, amp]; a list has no `.to`,
    # so each tensor is moved to the device individually.
    diff, phase, amp = (t.to(device) for t in data)
    amp_pred = ViTModel(diff)

    # Only the amplitude prediction is supervised; `phase` is loaded but unused.
    loss = ModelLoss(amp_pred, amp)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # CyclicLR is stepped per batch, so step_size_up counts iterations, not epochs.
    scheduler.step()

    running_loss += loss.item()
    n_batches += 1

  if (epoch + 1) % PRINT_EVERY == 0:
    print("Epoch: ", (epoch+1), " Training Loss: ", round(running_loss / max(n_batches, 1), 5))

if SAVE_PATH:
  torch.save(ViTModel.state_dict(), SAVE_PATH)

In [None]:
# Predict on one batch drawn explicitly here instead of reusing whatever
# `diff`/`amp` the training loop left behind. `amp` is rebound alongside `diff`
# so the plotting cell compares predictions with their matching targets.
ViTModel.eval()
diff, phase, amp = (t.to(device) for t in next(iter(trn_dl)))
with torch.no_grad():  # inference only: don't build an autograd graph
  amp_pred = ViTModel(diff)

In [None]:
N_SHOW = 4      # columns in the comparison grid
IMG_SIDE = 650  # assumes each amplitude map holds IMG_SIDE*IMG_SIDE values — TODO confirm


def plot_amp_comparison(preds, targets, n_show=N_SHOW, side=IMG_SIDE):
  """Plot predicted (top row) vs reference (bottom row) amplitude maps.

  Only ``min(n_show, len(preds), len(targets))`` samples are drawn, so a batch
  smaller than ``n_show`` (e.g. batch_size=2) does not raise IndexError;
  leftover panels are hidden.

  Args:
    preds: indexable batch of predicted tensors, one sample per index.
    targets: indexable batch of reference tensors aligned with ``preds``.
    n_show: number of columns in the grid.
    side: each sample is reshaped to ``(side, side)`` for display.

  Returns:
    ``(fig, axes)`` with ``axes`` shaped ``(2, n_show)``.
  """
  fig, axes = plt.subplots(2, n_show, figsize=(12, 12), facecolor='white', squeeze=False)
  axes[0, 0].set_ylabel('ViT', fontsize=12.0)
  axes[1, 0].set_ylabel('E-Pie (300 Iterations)', fontsize=12.0)

  n_avail = min(n_show, len(preds), len(targets))
  for col in range(n_show):
    if col < n_avail:
      axes[0, col].imshow(preds[col].detach().cpu().numpy().reshape((side, side)))
      axes[1, col].imshow(targets[col].detach().cpu().numpy().reshape((side, side)))
    else:
      axes[0, col].axis('off')
      axes[1, col].axis('off')
  return fig, axes


f, ax = plot_amp_comparison(amp_pred, amp)
plt.show()