In [None]:
from google.colab import drive
# Mount Google Drive at Colab's conventional location; later cells navigate
# into the project folder beneath this mount point.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

In [None]:
cd drive/My \Drive/Acad/ADS/Project/

In [None]:
'''
Notebook to generate super-resolution samples with a trained diffusion model and evaluate them (MSE) against the ground-truth labels.
'''

import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from modules_conditional import UNet, Diffusion

# Inference configuration: how many test images to super-resolve, how many are
# passed to a single diffusion.sample call, and the square image size.
N_SAMPLES = 20
BATCH_SIZE = 5
IMG_SIZE = 128
SEED = 0
assert N_SAMPLES % BATCH_SIZE == 0, "N_SAMPLES must be a multiple of BATCH_SIZE"

# Reverse diffusion is stochastic; seed the RNGs so reruns reproduce the same samples.
torch.manual_seed(SEED)
np.random.seed(SEED)

device = "cuda"
model = UNet().to(device)
# map_location keeps loading valid regardless of the device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load('./Weights/Diff_ckpt_1.pt', map_location=device))
# Inference mode (disables dropout / batch-norm updates if UNet uses them).
model.eval()
diffusion = Diffusion(img_size=IMG_SIZE, device=device)

# Low-resolution inputs grouped into batches: (N_SAMPLES // BATCH_SIZE, BATCH_SIZE, 3, H, W).
x_test = (
    np.load('./Data/test_images.npy')
    .astype(np.float32)[:N_SAMPLES]
    .reshape(-1, BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
)

out = []
with torch.no_grad():  # no gradients are needed for sampling; saves GPU memory
    for batch in tqdm(x_test):
        data = torch.from_numpy(batch).to(device)
        recon = diffusion.sample(model, n=data.shape[0], lat=data)
        # assumes sample returns one single-channel image per input — TODO confirm
        out.append(recon.cpu().detach().numpy().reshape(BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE))
dataSR = np.asarray(out)
print(dataSR.shape)

# Flatten the batched arrays to one image per row. The label count is derived
# from the generated output so it always matches however many samples were made.
x_out = dataSR.astype(np.float32).reshape(-1, 1, 128, 128)  # SR
x = np.load('./Data/test_labels.npy').astype(np.float32)[:len(x_out)].reshape(-1, 1, 128, 128)  # HR
# reshape(-1, ...) is idempotent, so re-running this cell is safe.
x_test = x_test.reshape(-1, 3, 128, 128)  # LR

# nn.MSELoss silently broadcasts mismatched shapes (only a warning); fail loudly instead.
assert x_out.shape == x.shape, f"SR shape {x_out.shape} != HR shape {x.shape}"
assert len(x_test) == len(x_out), "LR inputs and SR outputs are misaligned"

print("Metrics:")
criteria = nn.MSELoss()
# Per-sample MSE as plain Python floats (not 0-d tensors).
losses = [
    criteria(torch.from_numpy(sr), torch.from_numpy(hr)).item()
    for sr, hr in zip(x_out, x)
]
print("Average MSE super resolution samples: " + str('%.5f' % np.average(losses)))

In [None]:
# Visualize samples
import cv2
# Aliases for the low-resolution inputs (LR) and high-resolution ground truth (HR).
dataLR, dataHR = x_test, x

def rescale_img(img):
    """Min-max normalise an image to [0, 1], independently for each channel.

    Parameters
    ----------
    img : np.ndarray
        Image whose first two axes are reduced over (the spatial axes for the
        (H, W, C) images passed in this notebook). Each entry along the
        remaining axis is scaled separately.

    Returns
    -------
    np.ndarray
        Array of the same shape as ``img``. Channels whose values are all equal
        (zero range) are mapped to 0 instead of producing NaN from 0/0.
    """
    min_vals = np.min(img, axis=(0, 1), keepdims=True)
    max_vals = np.max(img, axis=(0, 1), keepdims=True)
    value_range = max_vals - min_vals
    # Guard constant channels: a zero range would otherwise divide 0 by 0.
    safe_range = np.where(value_range > 0, value_range, 1)
    return (img - min_vals) / safe_range

import os

SAMPLES_DIR = './Results/Samples'
# savefig does not create missing directories.
os.makedirs(SAMPLES_DIR, exist_ok=True)

# Save an LR / SR / HR comparison figure for every evaluated sample. Iterating
# over len(x_out) instead of a hard-coded 20 avoids an IndexError when fewer
# samples are generated upstream.
for i in range(len(x_out)):
    f, axarr = plt.subplots(nrows=1, ncols=3, figsize=(16, 3))
    # LR input is stored channels-first; imshow expects (H, W, C).
    axarr[0].imshow(rescale_img(x_test[i].transpose(1, 2, 0)))
    axarr[0].set_title('Low Resolution Image (Input)')
    axarr[1].imshow(x_out[i][0])
    axarr[1].set_title('Model Output Labels')
    axarr[2].imshow(x[i][0])
    axarr[2].set_title('Ground Truth Labels')
    f.savefig(os.path.join(SAMPLES_DIR, 'Sample' + str(i + 1) + '.png'), format='png', dpi=300)
    # Close this specific figure so open figures don't accumulate across iterations.
    plt.close(f)