In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.utils import save_image
import numpy as np
from model import Generator, ResUnetGenerator
from dataset import CustomDataset
import scipy
from skimage.metrics import structural_similarity as ssim
from math import log10
import pytorch_ssim.pytorch_ssim as pytorch_ssim
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

In [2]:
class Metrics:
    """
        Implement metrics for evaluating the results
        - PSNR (Peak Signal-to-Noise Ratio)
        - NMAE
        - SSIM
    """
    def __init__(self):
        pass

    def calculate_ssim(self, image1, image2):
        image1, image2 = denorm(image1), denorm(image2)
        ssim_value = pytorch_ssim.ssim(image1, image2)
        return ssim_value

    def calculate_psnr(self, image1, image2):
        image1, image2 = denorm(image1), denorm(image2)
        mse = np.mean(np.mean(np.array(image1) - np.array(image2)) ** 2)
        if(mse == 0):  # MSE is zero means no noise is present in the signal. Therefore PSNR have no importance.
            return 100
        max_pixel = 1
        psnr = 20 * log10(max_pixel / np.sqrt(mse))
        return psnr
    
    def calculate_nmae(self, image1, image2):
        image1, image2 = denorm(image1), denorm(image2)
        # Flatten the 3D images to 1D arrays
        flat_image1 = np.array(image1).flatten()
        flat_image2 = np.array(image2).flatten()
        
        # Calculate the mean absolute error
        abs_error = np.abs(flat_image1 - flat_image2)
        mean_abs_error = np.mean(abs_error)
        
        # Calculate the range of the pixel values
        pixel_range = np.max(flat_image1) - np.min(flat_image1)
        
        # Calculate the normalized mean absolute error
        nmae = mean_abs_error / pixel_range
        
        return nmae
    
    def calculate_lpips(self, image1, image2):
        lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
        return lpips(image1, image2)

In [28]:
class CFG:
    data_dir = '/home/han/MRI_DATA/BraTS2020 StarGANs/image_2D/test'
    source_contrast = 't2' # pd, mra, t1, t2  # flair, t1ce, t1, t2
    # contrast_list = ['mra', 'pd', 't1', 't2']
    contrast_list = ['flair', 't1ce', 't1', 't2']
    transform = []
    transform.append(T.ToTensor())
    transform.append(T.Resize(256))
    transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = T.Compose(transform)
    generator_dir = 'stargan_both/models/200000-G.ckpt'
    g_conv_dim = 64
    c_dim = 4
    repeat_num = 6
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cpu')

In [29]:
def label2onehot(labels, dim):
    """Convert label indices to one-hot vectors."""
    batch_size = labels.size(0)
    out = torch.zeros(batch_size, dim)
    out[np.arange(batch_size), labels.long()] = 1
    return out

def create_labels(c_org, c_dim=4):
    """Generate target domain labels for debugging and testing."""
    c_trg_list = []
    for i in range(c_dim):
        c_trg = label2onehot(torch.ones(c_org.size(0))*i, c_dim)
        c_trg_list.append(c_trg.to(CFG.device))
    return c_trg_list

def denorm(x):
    """Convert the range from [-1, 1] to [0, 1]."""
    out = (x + 1) / 2
    return out.clamp_(0, 1)


In [30]:
dataset = CustomDataset(CFG.data_dir, CFG.source_contrast, CFG.contrast_list, CFG.transform)
data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=1)
metrics = Metrics()

In [31]:
generator = Generator(CFG.g_conv_dim, CFG.c_dim * 2 + 2, CFG.repeat_num)
generator.load_state_dict(torch.load(CFG.generator_dir, map_location=lambda storage, loc: storage))

<All keys matched successfully>

In [32]:
%matplotlib inline
generator.to(CFG.device)
with torch.no_grad():
    ssim = []
    psnr = []
    nmae = []
    for i, data in enumerate(data_loader):
        (x_real, c_org, path) = data['source']
        x_real = x_real.to(CFG.device)
        c_org = c_org.to(CFG.device)

        # HERE
        # c_ixi_list = create_labels(c_org, CFG.c_dim)
        # zero_brats2020 = torch.zeros(x_real.size(0), CFG.c_dim).to(CFG.device)  
        # mask_ixi = label2onehot(torch.ones(x_real.size(0)), 2).to(CFG.device)
        c_brats2020_list = create_labels(c_org, CFG.c_dim)
        zero_ixi = torch.zeros(x_real.size(0), CFG.c_dim).to(CFG.device)             # Zero vector for XIX.
        mask_brats2020 = label2onehot(torch.zeros(x_real.size(0)), 2).to(CFG.device)  # Mask vector: [1, 0].

        
        x_fake_list = []
        target_list = []
        for j, c_fixed in enumerate(c_brats2020_list):
            c_trg = torch.cat([c_fixed, zero_ixi, mask_brats2020], dim=1)
            # c_trg = torch.cat([zero_brats2020, c_fixed, mask_ixi], dim=1)
            x_fake = generator(x_real, c_trg)
            x_fake_list.append(x_fake)
        for j in CFG.contrast_list:
            target = data['target'][j][0]
            target_list.append(target)
        for j in range(len(CFG.contrast_list)):
            ssim.append(metrics.calculate_ssim(target_list[j], x_fake_list[j]))
            psnr.append(metrics.calculate_psnr(target_list[j], x_fake_list[j]))
            nmae.append(metrics.calculate_nmae(target_list[j], x_fake_list[j]))

        x_concat = torch.cat(x_fake_list, dim=3).data.cpu()
        x_concat = torch.cat([x_concat, torch.cat(target_list, dim=3).data.cpu()], dim=2)
        x_concat = (x_concat + 1) / 2
        x_concat = x_concat.clamp_(0, 1)
        save_image(x_concat, f'eval_ixi_stargan/fake{i}.png')
    print("SSIM: ", np.array(ssim).mean())
    print("PSNR: ", np.array(psnr).mean())
    print("NMAE: ", np.array(nmae).mean())

# StarGAN
# IXI source MRA
# SSIM:  0.75208205
# PSNR:  38.86268557462074
# NMAE:  0.037938464

# IXI source = PD
# SSIM:  0.80015516
# PSNR:  39.29677431441381
# NMAE:  0.03180834

# IXI source = T1
# SSIM:  0.7346829
# PSNR:  39.814971066299115
# NMAE:  0.038296565

# IXI source = T2
# SSIM:  0.799478
# PSNR:  40.696188551363484
# NMAE:  0.031956453

# BraTS2020 source = FLAIR
# SSIM:  0.84171426
# PSNR:  33.61483340884243
# NMAE:  0.058331743

# BraTS2020 source = T1CE
# SSIM:  0.8246284
# PSNR:  28.547625551237427
# NMAE:  0.06697038

# BraTS2020 source = T1
# SSIM:  0.8344398
# PSNR:  26.675615710097293
# NMAE:  0.0704227

# BraTS2020 source = T2
# SSIM:  0.8269338
# PSNR:  30.4808867340839
# NMAE:  0.06497672

# ResUnet
# IXI source = MRA
# SSIM:  0.73379254
# PSNR:  40.97990127919558
# NMAE:  0.034802567

# IXI source = PD
# SSIM:  0.77352405
# PSNR:  41.696309792660564
# NMAE:  0.030027147

# IXI source = T1
# SSIM:  0.7168928
# PSNR:  41.60366481777639
# NMAE:  0.034858994

# IXI source = T2
# SSIM:  0.7844148
# PSNR:  39.04168398485751
# NMAE:  0.031735655

# BraTS2020 source = FLAIR
# SSIM:  0.79582345
# PSNR:  32.78408054369021
# NMAE:  0.057049632

# BraTS2020 source = T1CE
# SSIM:  0.7823191
# PSNR:  29.32913845714312
# NMAE:  0.06274834

# BraTS2020 source = T1
# SSIM:  0.7845728
# PSNR:  28.776334321491472
# NMAE:  0.06536168

# BraTS2020 source = T2
# SSIM:  0.77713275
# PSNR:  30.846138165699482
# NMAE:  0.064423576



SSIM:  0.8269338
PSNR:  30.4808867340839
NMAE:  0.06497672


In [None]:
import torch
_ = torch.manual_seed(123)
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
# LPIPS needs the images to be in the [-1, 1] range.
img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
loss = lpips(img1, img2)



In [None]:
x = loss.detach().numpy()
x

array(0.3493258, dtype=float32)

In [21]:
!tensorboard dev upload --logdir resunet_both/logs \
    --name "(optional) My latest experiment" \
    --description "(optional) Simple comparison of several hyperparameters"

2023-05-25 00:05:42.508979: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-25 00:05:42.541377: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-25 00:05:43.455689: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-05-25 00:05:43.475630: I tensorflow/comp