In [1]:
import importlib
import torch
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import torchvision.transforms as transforms
import os
from data import BraTSDataset
import torchvision.utils as vutils
from PIL import Image

In [2]:
from gan.UNetGenerator import UNetGenerator
import config.unetConfig as cfg

In [3]:
image_size = 64

### Store T1/T1ce real images

In [4]:
dataset_root = "dataset"
t1_train_data = "data/MICCAI_BraTS2020/train/t1"
store_path = "images/t1/real"
image_paths = [
    os.path.join(t1_train_data, impath) for impath in os.listdir(t1_train_data)
]
batch_size = len(image_paths)
tf = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    ])
dataset = BraTSDataset(image_paths, tf)

dataset.save(store_path)

In [5]:
# Model image mapping

# wgan_18_35_30_11       -> wgan_gp_t1
# psnr_threshold = 17.5
# ssim_threshold = 0.65

# wgan_t1ce_21_21_30_11  -> wgan_gp_t1ce
# psnr_threshold = 17.5
# ssim_threshold = 0.65

# wgan_t1ce_22_14_30_11  -> wgan_t1ce
# psnr_threshold = 14.5
# ssim_threshold = 0.35

# wgan_20_22_29_11       -> wgan_t1
# psnr_threshold = 14.5
# ssim_threshold = 0.35

# dcgan_00_07_01_12      -> dcgan_t1
# psnr_threshold = 16.5
# ssim_threshold = 0.55

# dcgan_t1ce_08_17_02_12 -> dcgan_t1ce
# psnr_threshold = 18.5
# ssim_threshold = 0.55

# unetgan/norm_winit/unetgan_T1_Gstate__1000epochs_128batch_G_128z_32feat_0.0002lr_D_3lvl_32feat_0.0002lr_0.2lRelu -> images/t1/unetgan_norm_winit

# unetgan/norm_winit/unetgan_T1CE_Gstate__1000epochs_128batch_G_128z_32feat_0.0002lr_D_3lvl_32feat_0.0002lr_0.2lRelu -> images/t1ce/unetgan_norm_winit

# unetgan/ortho_winit/unetgan_T1_Gstate__1000epochs_128batch_G_128z_32feat_0.0002lr_D_3lvl_32feat_0.0002lr_0.2lRelu -> images/t1/unetgan_ortho_winit

# unetgan/ortho_winit/unetgan_T1CE_Gstate__1000epochs_128batch_G_128z_32feat_0.0002lr_D_3lvl_32feat_0.0002lr_0.2lRelu -> images/t1ce/unetgan_ortho_winit

# CHANGE THESE BEFORE RUNNING

model_folder = "models/unetgan/ortho_winit/unetgan_T1CE_Gstate__1000epochs_128batch_G_128z_32feat_0.0002lr_D_3lvl_32feat_0.0002lr_0.2lRelu"
image_folder = "images/t1ce/unetgan/ortho"
# psnr_threshold = 18.5
# ssim_threshold = 0.55

In [6]:
def compute_psnr(real_batch: np.ndarray, fake_batch: np.ndarray) -> float:
    b_size = real_batch.shape[0]
    psnr_val = 0.0
    # fake_batch will only have one batch
    for i in range(b_size):
        psnr_val += psnr(
            real_batch[i, :, :, :].transpose(1, 2, 0),
            fake_batch[0, :, :, :].transpose(1, 2, 0),
            data_range=1.0,
        )
    return psnr_val / b_size

In [7]:
def compute_ssim(real_batch: np.ndarray, fake_batch: np.ndarray) -> float:
    b_size = real_batch.shape[0]
    ssim_val = 0.0
    # fake_batch will only have one batch
    for i in range(b_size):
        ssim_val += ssim(
            real_batch[i, :, :, :],
            fake_batch[0, :, :, :],
            channel_axis=0,
            data_range=1.0,
        )
    return ssim_val / b_size

In [8]:
# model_library = ".".join(model_folder.split("/"))
# model = importlib.import_module(f"{model_library}.model")

In [9]:
# model_g = model.Generator().to(model.device)
# model_g.load_state_dict(torch.load(f"{model_folder}/generator.pth"))
# model_g.eval()
# print(model_g)

In [10]:
netG = UNetGenerator(cfg.LATENT_SZ, cfg.NGF, cfg.NGC).to(cfg.DEVICE)
netG.load_state_dict(torch.load(f"{model_folder}.pth"))
netG.eval()
print(netG)

UNetGenerator(
  (main): Sequential(
    (0): ConvTranspose2d(128, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

In [11]:
tf = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    ])

num_iter_per_image = 100
for i in range(len(image_paths)):
    real = dataset.__getitem__(i).numpy()[None, ...]
    best_psnr_val = float("-inf")
    best_ssim_val = None
    best_image = None
    for j in range(num_iter_per_image):

        noise = torch.randn(1, cfg.LATENT_SZ, 1, 1, device=cfg.DEVICE)
        fake_image = netG(noise).detach().cpu().numpy()
        fake_image = fake_image[0][0]
        fake_image = np.uint8(fake_image * 255)
        fake_image = Image.fromarray(fake_image, mode="L")
        fake_image = tf(fake_image)
        fake_image = fake_image.numpy()[None, ...]

        psnr_val = compute_psnr(real, fake_image)
        if psnr_val > best_psnr_val:
            best_psnr_val = psnr_val
            best_image = fake_image.copy()
            best_ssim_val = compute_ssim(real, fake_image)
    print(f"Best image PSNR: {best_psnr_val}")
    print(f"Best image SSIM: {best_ssim_val}")
    best_image = torch.Tensor(best_image)
    vutils.save_image(best_image, f"{image_folder}/{i}.png")
    print(f"Saved {i+1} image")

Best image PSNR: 20.52035389429495
Best image SSIM: 0.7478031516075134
Saved 1 image
Best image PSNR: 20.336846836620627
Best image SSIM: 0.7290620803833008
Saved 2 image
Best image PSNR: 19.213180944073546
Best image SSIM: 0.6682518720626831
Saved 3 image
Best image PSNR: 20.156682372025408
Best image SSIM: 0.6764939427375793
Saved 4 image
Best image PSNR: 19.65086779707654
Best image SSIM: 0.701614260673523
Saved 5 image
Best image PSNR: 19.4507252338136
Best image SSIM: 0.6992819905281067
Saved 6 image
Best image PSNR: 22.90961629370592
Best image SSIM: 0.708570659160614
Saved 7 image
Best image PSNR: 19.56861252371147
Best image SSIM: 0.684928834438324
Saved 8 image
Best image PSNR: 19.916664372322213
Best image SSIM: 0.682972252368927
Saved 9 image
Best image PSNR: 20.496867480809932
Best image SSIM: 0.7621201276779175
Saved 10 image
Best image PSNR: 19.309794253499867
Best image SSIM: 0.6928215622901917
Saved 11 image
Best image PSNR: 21.17132286403715
Best image SSIM: 0.76337027

In [None]:
# num_iter_per_image = 100
# for i in range(len(image_paths)):
#     best_psnr_val = float("-inf")
#     best_ssim_val = None
#     best_image = None
#     real = dataset.__getitem__(i).numpy()[None, ...]
#     for j in range(num_iter_per_image):
#         noise = torch.randn(1, model.latent_size, 1, 1, device=model.device)
#         fake_image = model_g(noise).detach().cpu().numpy()
#         psnr_val = compute_psnr(real, fake_image)
#         if psnr_val > best_psnr_val:
#             best_psnr_val = psnr_val
#             best_image = fake_image
#             best_ssim_val = compute_ssim(real, fake_image)
#     print(f"Best image PSNR: {best_psnr_val}")
#     print(f"Best image SSIM: {best_ssim_val}")
#     best_image = torch.Tensor(best_image)
#     vutils.save_image(best_image, f"{image_folder}/{i}.png")
#     print(f"Saved {i+1} image")

In [None]:
# fake = np.zeros((batch_size, 1, model.image_size, model.image_size))
# # Don't stop until required number of images have been saved
# total_num_images = batch_size
# num_images = 0
# while(True):
#     noise = torch.randn(1, model.latent_size, 1, 1, device=model.device)
#     fake_image = model_g(noise).detach().cpu().numpy()
#     psnr_val = compute_psnr(real, fake_image)
#     ssim_val = compute_ssim(real, fake_image)
#     if psnr_val > psnr_threshold and ssim_val > ssim_threshold:
#         fake[num_images] = fake_image
#         num_images += 1
#         print(f"Saved image with psnr: {psnr_val}, ssim: {ssim_val}, total: {num_images}")
#         # Save fake_image
#         # Stop if we have generated equal number of fake images
#         # as compared to real images
#         if num_images >= total_num_images:
#             break

In [None]:
# real = torch.Tensor(real)
# fake = torch.Tensor(fake)
# num_images = real.shape[0]
# print(f"Saving {num_images} images")
# # Save images
# for i in range(num_images):
#     # Save real
#     vutils.save_image(real[i], f"{image_folder}/real/{i+1}.png")

#     # Save fake
#     vutils.save_image(fake[i], f"{image_folder}/fake/{i+1}.png")