In [1]:
import importlib
import torch
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import torchvision.transforms as transforms
import os
from data import BraTSDataset
import torchvision.utils as vutils
from PIL import Image

In [11]:
# Model image mapping.
# Each key is the image-folder name (under images/) that a trained run's samples
# are written to; "run" is that run's folder under models/, and the thresholds are
# the minimum mean PSNR / SSIM a generated image must exceed to be kept.
MODEL_CONFIGS = {
    "wgan_gp_t1":   {"run": "wgan_18_35_30_11",       "psnr_threshold": 17.5, "ssim_threshold": 0.65},
    "wgan_gp_t1ce": {"run": "wgan_t1ce_21_21_30_11",  "psnr_threshold": 17.5, "ssim_threshold": 0.65},
    "wgan_t1ce":    {"run": "wgan_t1ce_22_14_30_11",  "psnr_threshold": 14.5, "ssim_threshold": 0.35},
    "wgan_t1":      {"run": "wgan_20_22_29_11",       "psnr_threshold": 14.5, "ssim_threshold": 0.35},
    "dcgan_t1":     {"run": "dcgan_00_07_01_12",      "psnr_threshold": 16.5, "ssim_threshold": 0.55},
    "dcgan_t1ce":   {"run": "dcgan_t1ce_08_17_02_12", "psnr_threshold": 18.5, "ssim_threshold": 0.55},
}

# Change this single key to evaluate a different trained model; everything below
# is derived from it so the folders and thresholds cannot drift out of sync.
selected_model = "dcgan_t1ce"

selected_config = MODEL_CONFIGS[selected_model]
model_folder = f"models/{selected_config['run']}"
image_folder = f"images/{selected_model}"
psnr_threshold = selected_config["psnr_threshold"]
ssim_threshold = selected_config["ssim_threshold"]

In [3]:
def compute_psnr(real_batch: np.ndarray, fake_batch: np.ndarray) -> float:
    """Average PSNR of one generated image against every image in a real batch.

    Args:
        real_batch: 4-D array; every entry along axis 0 is scored.
        fake_batch: 4-D array; only ``fake_batch[0]`` is used as the candidate.

    Returns:
        The mean of skimage's PSNR (``data_range=1.0``) over all real images.
        An empty ``real_batch`` raises ZeroDivisionError.
    """
    count = real_batch.shape[0]
    # Each 3-D image is transposed (1, 2, 0) before scoring -- presumably
    # (C, H, W) -> (H, W, C); TODO confirm the intended axis layout.
    scores = (
        psnr(
            real_batch[idx, :, :, :].transpose(1, 2, 0),
            fake_batch[0, :, :, :].transpose(1, 2, 0),
            data_range=1.0,
        )
        for idx in range(count)
    )
    # Start at 0.0 so summation order/type matches a plain running float total.
    return sum(scores, 0.0) / count

In [4]:
def compute_ssim(real_batch: np.ndarray, fake_batch: np.ndarray) -> float:
    """Average SSIM of one generated image against every image in a real batch.

    Args:
        real_batch: 4-D array; every entry along axis 0 is scored.
        fake_batch: 4-D array; only ``fake_batch[0]`` is used as the candidate.

    Returns:
        The mean of skimage's SSIM (``channel_axis=0``, ``data_range=1.0``) over
        all real images. An empty ``real_batch`` raises ZeroDivisionError.
    """
    count = real_batch.shape[0]
    # Axis 0 of each 3-D image is passed to skimage as the channel axis.
    scores = (
        ssim(
            real_batch[idx, :, :, :],
            fake_batch[0, :, :, :],
            channel_axis=0,
            data_range=1.0,
        )
        for idx in range(count)
    )
    # Start at 0.0 so summation order/type matches a plain running float total.
    return sum(scores, 0.0) / count

In [5]:
# Convert the run folder path (e.g. "models/<run>") into a dotted package path and
# import that run's `model` module; later cells read Generator, device, image_size,
# latent_size and num_workers from it.
model_library = model_folder.replace("/", ".")
model = importlib.import_module(f"{model_library}.model")

In [6]:
# Build the generator on the run's device and load its trained weights.
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint)
# onto model.device, so the notebook also works on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (newer torch versions support weights_only=True for this).
model_g = model.Generator().to(model.device)
model_g.load_state_dict(
    torch.load(f"{model_folder}/generator.pth", map_location=model.device)
)
# Inference mode: BatchNorm layers use their running statistics.
model_g.eval()
print(model_g)

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(128, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)


In [7]:
dataset_root = "dataset"  # NOTE(review): not used by any cell shown here
# Modality of the real slices the generated images are scored against.
# NOTE(review): if model_folder is a t1ce run, this probably should be "t1ce" --
# confirm which real set the thresholds were tuned against.
modality = "t1"
t1_train_data = f"data/MICCAI_BraTS2020/train/{modality}"
# Sort so the file list does not depend on filesystem directory order.
image_paths = [
    os.path.join(t1_train_data, impath) for impath in sorted(os.listdir(t1_train_data))
]
if not image_paths:
    # DataLoader would otherwise fail with an opaque batch_size=0 error.
    raise ValueError(f"No images found in {t1_train_data}")
# The entire real set is loaded as a single batch so each generated image is
# scored against every real image.
batch_size = len(image_paths)
tf = transforms.Compose([
    transforms.Resize(model.image_size),
    transforms.CenterCrop(model.image_size),
    transforms.ToTensor(),
    ])
dataset = BraTSDataset(image_paths, tf)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=model.num_workers)

In [8]:
# Take the one batch that holds the whole real set and view it as a numpy array
# for the skimage metrics.
real = np.asarray(next(iter(dataloader)))

In [9]:
# Rejection-sample the generator: draw one latent vector at a time and keep the
# generated image only if its mean PSNR and mean SSIM against the full real batch
# both exceed the configured thresholds. Stop once as many images as there are
# real images have been kept.
fake = np.zeros((batch_size, 1, model.image_size, model.image_size))
total_num_images = batch_size
num_images = 0
# Cap on generator draws so unreachable thresholds raise instead of looping
# forever; raise it if a strict configuration legitimately needs more.
max_attempts = 100_000
attempts = 0
# Pure inference: no autograd graph is needed for the sampled images.
with torch.no_grad():
    while num_images < total_num_images:
        if attempts >= max_attempts:
            raise RuntimeError(
                f"Only {num_images}/{total_num_images} images passed the thresholds "
                f"(psnr > {psnr_threshold}, ssim > {ssim_threshold}) "
                f"after {attempts} attempts"
            )
        attempts += 1
        noise = torch.randn(1, model.latent_size, 1, 1, device=model.device)
        fake_image = model_g(noise).detach().cpu().numpy()
        psnr_val = compute_psnr(real, fake_image)
        ssim_val = compute_ssim(real, fake_image)
        if psnr_val > psnr_threshold and ssim_val > ssim_threshold:
            # fake_image holds a single sample; store that sample in the next slot.
            fake[num_images] = fake_image[0]
            num_images += 1
            print(f"Saved image with psnr: {psnr_val}, ssim: {ssim_val}, total: {num_images}")



Saved image with psnr: 18.60107085393825, ssim: 0.6477680348445406, total: 1
Saved image with psnr: 18.818984575974753, ssim: 0.6236392321786907, total: 2
Saved image with psnr: 19.01415245854079, ssim: 0.6362212327760732, total: 3
Saved image with psnr: 18.64989313915073, ssim: 0.6005953881475661, total: 4
Saved image with psnr: 19.149990691306897, ssim: 0.6383747230700361, total: 5
Saved image with psnr: 18.77033563485183, ssim: 0.6166030783155746, total: 6
Saved image with psnr: 18.99934208481603, ssim: 0.6324125589717048, total: 7
Saved image with psnr: 18.701346498947927, ssim: 0.6318194910601226, total: 8
Saved image with psnr: 18.70241098731378, ssim: 0.6054036696428852, total: 9
Saved image with psnr: 19.007480657151287, ssim: 0.66401046433746, total: 10
Saved image with psnr: 19.526741464303992, ssim: 0.6414333496313431, total: 11
Saved image with psnr: 18.81886513543014, ssim: 0.6171247880794815, total: 12
Saved image with psnr: 19.09065759917187, ssim: 0.6472683671690261, to

In [12]:
# Write each real image and its accepted generated counterpart as numbered PNGs
# under <image_folder>/real and <image_folder>/fake. New names are used so the
# numpy `real` / `fake` arrays from earlier cells are left untouched.
real_images = torch.as_tensor(real, dtype=torch.float32)
fake_images = torch.as_tensor(fake, dtype=torch.float32)
num_images = real_images.shape[0]
print(f"Saving {num_images} images")
real_dir = os.path.join(image_folder, "real")
fake_dir = os.path.join(image_folder, "fake")
# save_image does not create parent directories.
os.makedirs(real_dir, exist_ok=True)
os.makedirs(fake_dir, exist_ok=True)
# NOTE(review): the generator ends in Tanh, so fakes may lie in [-1, 1], while the
# reals come from ToTensor in [0, 1]; save_image clamps to [0, 1]. Confirm whether
# the fakes need (x + 1) / 2 rescaling (unless BraTSDataset/training differ).
for i in range(num_images):
    vutils.save_image(real_images[i], os.path.join(real_dir, f"{i+1}.png"))
    vutils.save_image(fake_images[i], os.path.join(fake_dir, f"{i+1}.png"))

Saving 369 images
