In [1]:
import os
from math import log10

import numpy as np
import scipy
import torch
from skimage.metrics import structural_similarity as ssim
from torch.utils.data import DataLoader
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchvision import transforms as T
from torchvision.utils import save_image

import pytorch_ssim.pytorch_ssim as pytorch_ssim
from dataset import CustomDataset
from model import Generator, ResUnetGenerator, ResUnet

In [2]:
class Metrics:
    """Image-quality metrics for comparing generated images against ground truth.

    - PSNR (Peak Signal-to-Noise Ratio)
    - NMAE (Normalized Mean Absolute Error)
    - SSIM (Structural Similarity)
    - LPIPS (Learned Perceptual Image Patch Similarity)

    PSNR, NMAE and SSIM take images in the generator's [-1, 1] range and map
    them to [0, 1] with `denorm` before scoring. Inputs must be CPU tensors
    (they are converted with `np.array`).
    """
    def __init__(self):
        # The LPIPS network is built lazily on first use because it loads VGG weights.
        self._lpips = None

    def calculate_ssim(self, image1, image2):
        """SSIM between two images/batches in [-1, 1], computed by `pytorch_ssim.ssim` on [0, 1] data.

        Returns whatever `pytorch_ssim.ssim` returns (presumably a scalar tensor — confirm).
        """
        image1, image2 = denorm(image1), denorm(image2)
        ssim_value = pytorch_ssim.ssim(image1, image2)
        return ssim_value

    def calculate_psnr(self, image1, image2):
        """PSNR in dB between two images/batches given in [-1, 1].

        Both inputs are mapped to [0, 1], so the peak value is 1. The MSE is
        averaged over every element (i.e. over the whole batch at once).
        Returns 100 when the inputs are identical (MSE == 0) instead of infinity.
        """
        image1, image2 = denorm(image1), denorm(image2)
        # Average the per-element squared error; squaring the mean difference instead
        # would let positive and negative errors cancel and inflate the PSNR.
        mse = np.mean((np.array(image1) - np.array(image2)) ** 2)
        if mse == 0:  # identical inputs: PSNR is unbounded, report a sentinel value
            return 100
        max_pixel = 1
        psnr = 20 * log10(max_pixel / np.sqrt(mse))
        return psnr

    def calculate_nmae(self, image1, image2):
        """Mean absolute error between two images/batches in [-1, 1], normalised by image1's value range.

        Both inputs are mapped to [0, 1] first. If image1 is constant its range
        is 0 and numpy yields inf/nan for the result.
        """
        image1, image2 = denorm(image1), denorm(image2)
        flat_image1 = np.array(image1).flatten()
        flat_image2 = np.array(image2).flatten()

        mean_abs_error = np.mean(np.abs(flat_image1 - flat_image2))
        # The range of the first argument (the ground truth at the call sites) is the normaliser.
        pixel_range = np.max(flat_image1) - np.min(flat_image1)
        return mean_abs_error / pixel_range

    def calculate_lpips(self, image1, image2):
        """LPIPS (VGG backbone) between two image batches; lower means more similar.

        Inputs are passed through unchanged (not denormalised).
        NOTE(review): torchmetrics' default expects inputs in [-1, 1] — confirm for the installed version.
        """
        # Reuse one metric instance; rebuilding it per call reloads the VGG network every time.
        lpips = getattr(self, '_lpips', None)
        if lpips is None:
            lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
            self._lpips = lpips
        return lpips(image1, image2)

In [43]:
def label2onehot(labels, dim):
    """Turn a 1-D tensor of class indices into a (batch, dim) one-hot float tensor.

    `labels` may hold floats; they are cast with `.long()` before indexing.
    The result is allocated with `torch.zeros` on the default device.
    """
    n_rows = labels.size(0)
    onehot = torch.zeros(n_rows, dim)
    row_idx = torch.arange(n_rows)
    onehot[row_idx, labels.long()] = 1
    return onehot

def create_labels(c_org, c_dim=4):
    """Build one target-label batch per domain index 0..c_dim-1.

    Only the batch size of `c_org` is used. Each entry is a (batch, c_dim)
    one-hot tensor for that domain, moved to `CFG.device`.
    """
    batch = c_org.size(0)
    return [label2onehot(torch.ones(batch) * domain, c_dim).to(CFG.device)
            for domain in range(c_dim)]

def denorm(x):
    """Map values from [-1, 1] to [0, 1], clamping anything outside that range.

    A new tensor is returned; `x` itself is not modified.
    """
    scaled = (x + 1) / 2
    scaled.clamp_(min=0, max=1)
    return scaled


In [44]:
# Shared metric calculator: the evaluation cells call metrics.calculate_ssim / calculate_psnr /
# calculate_nmae on it, so avoid rebinding the name `metrics` to anything else.
metrics = Metrics()

In [45]:
class CFG:
    """Evaluation configuration shared by the notebook cells: data paths, transforms, model and loader settings."""
    # NOTE(review): absolute, machine-specific paths. The evaluation cells derive the dataset
    # name with `data_dir.split('/')[4]`, so the directory depth must stay unchanged.
    brats_data_dir = '/home/han/MRI_DATA/BraTS2020 StarGANs/image_2D/test'
    ixi_data_dir = '/home/han/MRI_DATA/IXI StarGANs/image_2D/test'
    ixi_contrast_list = ['mra', 'pd', 't1', 't2']
    brats_contrast_list = ['flair', 't1', 't1ce', 't2']
    # ToTensor, resize to 256x256, then Normalize(0.5, 0.5) maps [0, 1] to [-1, 1] (the inverse of `denorm`).
    transform = T.Compose([
        T.ToTensor(),
        T.Resize((256, 256)),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    generator_dir = 'resunet_new_loss_both/models/200000-G.ckpt'
    g_conv_dim = 64
    c_dim = 4  # number of contrasts per dataset
    repeat_num = 6
    # Fall back to the CPU so the notebook still runs on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 4
    num_workers = 2

In [46]:
# Joint two-dataset generator: the label input is [BraTS one-hot | IXI one-hot | 2-dim dataset mask],
# hence a label size of c_dim * 2 + 2.
generator = ResUnetGenerator(CFG.g_conv_dim, CFG.c_dim * 2 + 2, CFG.repeat_num)
state_dict = torch.load(CFG.generator_dir, map_location=torch.device('cpu'))  # load all tensors onto the CPU
generator.load_state_dict(state_dict)

<All keys matched successfully>

In [47]:
# Evaluate the joint generator on both datasets: every source contrast -> every target contrast.
generator.to(CFG.device)
metrics_scores = {}
for contrast_list, data_dir in zip([CFG.ixi_contrast_list, CFG.brats_contrast_list], [CFG.ixi_data_dir, CFG.brats_data_dir]):
    # Dataset key, e.g. 'IXI StarGANs'; relies on the fixed directory depth of the configured paths.
    dataset_name = data_dir.split('/')[4]
    is_ixi = contrast_list == CFG.ixi_contrast_list
    # Qualitative grids are written into a folder named after the dataset.
    os.makedirs(dataset_name, exist_ok=True)
    metrics_scores[dataset_name] = {}
    for source_contrast in contrast_list:
        pair_scores = {}
        metrics_scores[dataset_name][source_contrast] = pair_scores
        dataset = CustomDataset(data_dir, source_contrast, contrast_list, CFG.transform)
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=CFG.batch_size,
                                 shuffle=False,
                                 num_workers=CFG.num_workers)
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                (x_real, c_org, path) = data['source']
                x_real = x_real.to(CFG.device)
                c_org = c_org.to(CFG.device)
                batch = x_real.size(0)

                # Joint label layout: [BraTS one-hot (c_dim) | IXI one-hot (c_dim) | dataset mask (2)].
                # The mask is [0, 1] for IXI and [1, 0] for BraTS; the other dataset's one-hot is zeroed.
                c_fixed_list = create_labels(c_org, CFG.c_dim)
                zero_labels = torch.zeros(batch, CFG.c_dim).to(CFG.device)
                if is_ixi:
                    mask = label2onehot(torch.ones(batch), 2).to(CFG.device)
                    c_trg_list = [torch.cat([zero_labels, c_fixed, mask], dim=1) for c_fixed in c_fixed_list]
                else:
                    mask = label2onehot(torch.zeros(batch), 2).to(CFG.device)
                    c_trg_list = [torch.cat([c_fixed, zero_labels, mask], dim=1) for c_fixed in c_fixed_list]

                x_fake_list = [generator(x_real, c_trg) for c_trg in c_trg_list]
                target_list = [data['target'][contrast][0] for contrast in contrast_list]
                # Label index k is scored against contrast_list[k]; assumes the dataset's label
                # order matches contrast_list (TODO confirm in CustomDataset).
                assert len(x_fake_list) == len(contrast_list), 'expected one generated image per target contrast'

                for target_contrast, target, x_fake in zip(contrast_list, target_list, x_fake_list):
                    target_cpu, fake_cpu = target.data.cpu(), x_fake.data.cpu()
                    scores = pair_scores.setdefault(target_contrast, {'ssim': [], 'psnr': [], 'nmae': []})
                    scores['ssim'].append(metrics.calculate_ssim(target_cpu, fake_cpu))
                    scores['psnr'].append(metrics.calculate_psnr(target_cpu, fake_cpu))
                    scores['nmae'].append(metrics.calculate_nmae(target_cpu, fake_cpu))

                # Save a grid (generated row above ground-truth row) every 50 IXI / 100 BraTS batches.
                if i % (50 if is_ixi else 100) == 0:
                    print('IXI: ' if is_ixi else 'BraTS2020: ', contrast_list)
                    fake_row = torch.cat(x_fake_list, dim=3).data.cpu()
                    target_row = torch.cat(target_list, dim=3).data.cpu()
                    x_concat = denorm(torch.cat([fake_row, target_row], dim=2))
                    # Include the source contrast so grids from different sources do not overwrite each other.
                    save_image(x_concat, f"{dataset_name}/{source_contrast}_fake{i}.png")



IXI:  ['mra', 'pd', 't1', 't2']
IXI:  ['mra', 'pd', 't1', 't2']




IXI:  ['mra', 'pd', 't1', 't2']
IXI:  ['mra', 'pd', 't1', 't2']




IXI:  ['mra', 'pd', 't1', 't2']
IXI:  ['mra', 'pd', 't1', 't2']




IXI:  ['mra', 'pd', 't1', 't2']
IXI:  ['mra', 'pd', 't1', 't2']




BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']




BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']




BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']




BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']
BraTS2020:  ['flair', 't1', 't1ce', 't2']


In [54]:
# Mean score per (source -> target) translation pair on IXI.
# Named `metric_name` (not `metrics`) so the Metrics() instance used by the evaluation cells stays intact.
metric_name = 'nmae'
ixi_pairs = [
    ('t2', 't1'), ('t2', 'pd'), ('t2', 'mra'),
    ('t1', 't2'), ('t1', 'pd'), ('t1', 'mra'),
    ('pd', 't1'), ('pd', 't2'), ('pd', 'mra'),
    ('mra', 't1'), ('mra', 't2'), ('mra', 'pd'),
]
for src, dst in ixi_pairs:
    print(np.mean(metrics_scores['IXI StarGANs'][src][dst][metric_name]))

0.069913305
0.03236283
0.029810015
0.051985413
0.059531458
0.028398063
0.061080262
0.028997628
0.031060133
0.04598265
0.04607626
0.05112156


In [60]:
# Single-dataset StarGAN generator: it takes plain c_dim one-hot labels (no dataset mask).
# NOTE(review): CFG.generator_dir is the ResUnet checkpoint set in the CFG cell unless another cell
# (e.g. the one assigning 'stargan_ixi/...') changed it first — confirm the intended checkpoint.
generator = Generator(CFG.g_conv_dim, CFG.c_dim, CFG.repeat_num)
state_dict = torch.load(CFG.generator_dir, map_location=torch.device('cpu'))  # load all tensors onto the CPU
generator.load_state_dict(state_dict)

<All keys matched successfully>

In [61]:
# Evaluate the single-dataset generator on BraTS2020 only (no qualitative images are saved here).
generator.to(CFG.device)
metrics_scores = {}
for contrast_list, data_dir in zip([CFG.ixi_contrast_list, CFG.brats_contrast_list], [CFG.ixi_data_dir, CFG.brats_data_dir]):
    dataset_name = data_dir.split('/')[4]
    if dataset_name != 'BraTS2020 StarGANs':
        continue
    metrics_scores[dataset_name] = {}
    for source_contrast in contrast_list:
        pair_scores = {}
        metrics_scores[dataset_name][source_contrast] = pair_scores
        dataset = CustomDataset(data_dir, source_contrast, contrast_list, CFG.transform)
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=CFG.batch_size,
                                 shuffle=False,
                                 num_workers=CFG.num_workers)
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                (x_real, c_org, path) = data['source']
                x_real = x_real.to(CFG.device)
                # Plain c_dim one-hot labels (no dataset mask), matching a Generator built with CFG.c_dim.
                c_trg_list = create_labels(c_org, CFG.c_dim)

                x_fake_list = [generator(x_real, c_trg) for c_trg in c_trg_list]
                target_list = [data['target'][contrast][0].data.cpu() for contrast in contrast_list]
                # Label index k is scored against contrast_list[k]; assumes matching order (TODO confirm).
                assert len(x_fake_list) == len(contrast_list), 'expected one generated image per target contrast'

                for target_contrast, target, x_fake in zip(contrast_list, target_list, x_fake_list):
                    fake_cpu = x_fake.data.cpu()
                    scores = pair_scores.setdefault(target_contrast, {'ssim': [], 'psnr': [], 'nmae': []})
                    scores['ssim'].append(metrics.calculate_ssim(target, fake_cpu))
                    scores['psnr'].append(metrics.calculate_psnr(target, fake_cpu))
                    scores['nmae'].append(metrics.calculate_nmae(target, fake_cpu))



In [65]:
# Mean SSIM per (source -> target) translation pair on IXI.
# Named `metric_name` (not `metrics`) so the Metrics() instance used by the evaluation cells stays intact.
# NOTE(review): the evaluation cell above resets `metrics_scores` and fills only 'BraTS2020 StarGANs',
# so 'IXI StarGANs' must come from another run; the saved output (~0.04) also looks like NMAE, not SSIM.
metric_name = 'ssim'
ixi_pairs = [
    ('t2', 't1'), ('t2', 'pd'), ('t2', 'mra'),
    ('t1', 't2'), ('t1', 'pd'), ('t1', 'mra'),
    ('pd', 't1'), ('pd', 't2'), ('pd', 'mra'),
    ('mra', 't1'), ('mra', 't2'), ('mra', 'pd'),
]
for src, dst in ixi_pairs:
    print(np.mean(metrics_scores['IXI StarGANs'][src][dst][metric_name]))

0.03876209
0.041596036
0.034268912
0.043283213
0.035996106
0.04559305
0.042630948
0.04008158
0.056440245
0.045357995
0.038583476
0.048108928


In [None]:
# Load the single-dataset StarGAN checkpoint (presumably IXI-trained, per its path) and select the IXI test split.
CFG.generator_dir = 'stargan_ixi/models/200000-G.ckpt'
# Built with plain c_dim labels: the evaluation cell below feeds create_labels(c_org, CFG.c_dim)
# one-hots without a dataset mask, so a c_dim * 2 + 2 label input would not match them.
# NOTE(review): confirm this checkpoint was trained with c_dim labels.
generator = Generator(CFG.g_conv_dim, CFG.c_dim, CFG.repeat_num)
generator.load_state_dict(torch.load(CFG.generator_dir, map_location=lambda storage, loc: storage))
data_dir = CFG.ixi_data_dir
contrast_list = CFG.ixi_contrast_list

In [None]:
# Evaluate the single-dataset generator on `data_dir` and save qualitative grids
# (source image first, then one column per target contrast).
generator.to(CFG.device)
dataset_name = data_dir.split('/')[4]
# The previous evaluation cell resets `metrics_scores`, so create this dataset's entry on demand.
dataset_scores = metrics_scores.setdefault(dataset_name, {})
os.makedirs(dataset_name, exist_ok=True)
for source_contrast in contrast_list:
    pair_scores = {}
    dataset_scores[source_contrast] = pair_scores
    dataset = CustomDataset(data_dir, source_contrast, contrast_list, CFG.transform)
    data_loader = DataLoader(dataset=dataset,
                             batch_size=CFG.batch_size,
                             shuffle=False,
                             num_workers=CFG.num_workers)
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            (x_real, c_org, path) = data['source']
            x_real = x_real.to(CFG.device)
            c_trg_list = create_labels(c_org, CFG.c_dim)

            x_fake_list = [generator(x_real, c_trg) for c_trg in c_trg_list]
            target_list = [data['target'][contrast][0].data.cpu() for contrast in contrast_list]
            # Label index k is scored against contrast_list[k]; assumes matching order (TODO confirm).
            assert len(x_fake_list) == len(contrast_list), 'expected one generated image per target contrast'

            # Score only generated-vs-target pairs; the source image is for display only.
            for target_contrast, target, x_fake in zip(contrast_list, target_list, x_fake_list):
                fake_cpu = x_fake.data.cpu()
                scores = pair_scores.setdefault(target_contrast, {'ssim': [], 'psnr': [], 'nmae': []})
                scores['ssim'].append(metrics.calculate_ssim(target, fake_cpu))
                scores['psnr'].append(metrics.calculate_psnr(target, fake_cpu))
                scores['nmae'].append(metrics.calculate_nmae(target, fake_cpu))

            if i % 50 == 0:
                # Top row: source + generated images; bottom row: source + ground truth. All on CPU.
                x_real_cpu = x_real.data.cpu()
                fake_row = torch.cat([x_real_cpu] + [x_fake.data.cpu() for x_fake in x_fake_list], dim=3)
                target_row = torch.cat([x_real_cpu] + target_list, dim=3)
                # Denormalise exactly once from [-1, 1] to [0, 1].
                x_concat = denorm(torch.cat([fake_row, target_row], dim=2))
                save_image(x_concat, f"{dataset_name}/{source_contrast}_single_qualitative_fake{i}.png")