In [None]:
import torchvision
from tqdm.notebook import tqdm
from prdc import compute_prdc
import numpy as np
import torch
from pathlib import Path
import sys
# A Jupyter kernel does not define __file__, so fall back to the kernel's working
# directory. This assumes the kernel was started in the folder that contains this
# notebook (the usual Jupyter behaviour) -- TODO confirm for other launchers.
try:
    current_file = Path(__file__)
except NameError:
    # Only the parents of this path are used below; the file name is a placeholder.
    current_file = Path.cwd() / 'notebook.ipynb'
root_dir = current_file.parent.parent.parent
# Make the project packages importable. Entries already on sys.path are skipped so
# that re-running this cell does not keep growing the list.
for _extra_path in (str(root_dir / 'InnerEye-Generative'), str(root_dir)):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
from locations import DATASETPATH
DATASETPATH = Path(DATASETPATH)
from loaders.prostate_loader import Prostate2DSimpleDataset
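# calculate_frechet_distance is not bundled with this notebook. Obtain it from
# https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
# and make it importable (for example by uncommenting the line below); the
# evaluation cells call it and fail with a NameError otherwise.
# from metrics.FID import calculate_frechet_distance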

In [None]:
# Feature extractor: torchvision's VGG-11 requested with pretrained=True, put into
# eval mode and moved to the GPU in one chained expression.
# NOTE(review): the plot titles later in this notebook read "Randomly initialised VGG";
# confirm whether pretrained=True is the intended setting.
model = torchvision.models.vgg11(pretrained=True, progress=True).eval().to('cuda')

In [None]:
# `Metrics` collects, for each type of image corruption, the [mean, std] of every metric.
# The same numbers are also printed directly after they are computed.
#
# For each type of corruption the metrics are computed k_iterations times (preset to 5),
# each time on a fresh random split of the dataset into two halves.
Metrics = {}
k_iterations = 5
# passed to compute_prdc as its third positional argument in the evaluation cells
k_nearest_neigh = 5
# csv handed to Prostate2DSimpleDataset as its first argument
path = DATASETPATH / '2D' / 'dataset.csv'

In [None]:
# Baseline experiment: both halves of a random split are built without any transform,
# so these numbers serve as the x = 0 reference in the plots further down.
# calculate_frechet_distance has to be made available as described in the import cell.
metrics = {'precision':[], 'recall':[],'density':[], 'coverage':[], 'FID':[]}
for _ in range(k_iterations):

    # Two dataset objects over the same csv; each keeps one half of a shuffled index.
    ds1 = Prostate2DSimpleDataset(path, None, input_channels=3)
    ds2 = Prostate2DSimpleDataset(path, None, input_channels=3)
    idx = np.arange(len(ds1))
    np.random.shuffle(idx)
    ds1.df = ds1.df.loc[ds1.df.index[idx[:len(ds1) // 2]]]
    ds2.df = ds2.df.loc[ds2.df.index[idx[len(ds2) // 2:]]]

    dls = [torch.utils.data.DataLoader(ds, batch_size=16, shuffle=False,
                                       drop_last=False, num_workers=8)
           for ds in [ds1, ds2]]

    # One numpy array of model outputs per half, batches concatenated along axis 0.
    preds = []
    with torch.no_grad():
        for dl in dls:
            _preds = torch.cat([model(item.to('cuda')).cpu() for item in dl], 0)
            preds.append(_preds.numpy())
    _metrics = compute_prdc(preds[0], preds[1], k_nearest_neigh)

    # Mean vector and covariance matrix of each output set feed the Frechet distance.
    mu = [np.mean(pred, axis=0) for pred in preds]
    sigma = [np.cov(pred, rowvar=False) for pred in preds]
    _metrics['FID'] = calculate_frechet_distance(mu[0], sigma[0], mu[1], sigma[1])
    for key, value in _metrics.items():
        metrics[key].append(value)

# Summarise the repetitions as [mean, std] per metric.
Metrics['data_vs_data'] = {}
for key, values in metrics.items():
    summary = [np.mean(values), np.std(values)]
    print(key, *summary)
    Metrics['data_vs_data'][key] = summary


In [None]:
def noise(img):
    """Return uniform random values in [0, 1) with the same shape as ``img``.

    Only 2-D and 3-D inputs are accepted; any other rank raises ValueError.
    The content of ``img`` is ignored, only its shape is used.
    """
    dims = tuple(img.shape)
    if len(dims) not in (2, 3):
        raise ValueError
    return np.random.rand(*dims)
# Real images vs. pure noise: the second half of each random split is built with
# transforms=noise, the first half without any transform.
# calculate_frechet_distance has to be made available as described in the import cell.
Metrics['data_vs_noise'] = {}

# Read the same csv as the baseline experiment so the numbers are comparable with
# Metrics['data_vs_data']. k_iterations and k_nearest_neigh are deliberately not
# re-assigned here: they come from the configuration cell.
path = DATASETPATH / '2D' / 'dataset.csv'
metrics = {'precision':[], 'recall':[],'density':[], 'coverage':[], 'FID':[]}
for _ in range(k_iterations):

    ds1 = Prostate2DSimpleDataset(path, None, input_channels=3)
    ds2 = Prostate2DSimpleDataset(path, None, input_channels=3, transforms=noise)
    # each dataset keeps one half of a shuffled index
    idx = np.arange(len(ds1))
    np.random.shuffle(idx)
    ds1.df = ds1.df.loc[ds1.df.index[idx[:len(ds1) // 2]]]
    ds2.df = ds2.df.loc[ds2.df.index[idx[len(ds2) // 2:]]]
    dls = [torch.utils.data.DataLoader(ds,
                                       batch_size=16,
                                       shuffle=False,
                                       drop_last=False,
                                       num_workers=8) for ds in [ds1, ds2]]

    # One numpy array of model outputs per half, batches concatenated along axis 0.
    preds = []
    with torch.no_grad():
        for dl in dls:
            _preds = torch.cat([model(item.to('cuda')).cpu() for item in dl], 0)
            preds.append(_preds.numpy())
    _metrics = compute_prdc(preds[0], preds[1], k_nearest_neigh)

    # Mean vector and covariance matrix of each output set feed the Frechet distance.
    mu = [np.mean(pred, axis=0) for pred in preds]
    sigma = [np.cov(pred, rowvar=False) for pred in preds]
    FID = calculate_frechet_distance(mu[0], sigma[0], mu[1], sigma[1])
    _metrics['FID'] = FID
    for key in _metrics:
        metrics[key].append(_metrics[key])

# Summarise the repetitions as [mean, std] per metric (each computed once).
for key, values in metrics.items():
    summary = [np.mean(values), np.std(values)]
    print(key, *summary)
    Metrics['data_vs_noise'][key] = summary


In [None]:
# gaussian_filter lives in scipy.ndimage; the scipy.ndimage.filters namespace is deprecated.
from scipy.ndimage import gaussian_filter

# Read the same csv as the baseline experiment so the numbers are comparable with
# Metrics['data_vs_data']. k_iterations and k_nearest_neigh are deliberately not
# re-assigned here: they come from the configuration cell.
# calculate_frechet_distance has to be made available as described in the import cell.
path = DATASETPATH / '2D' / 'dataset.csv'
Gaussian_blur_metrics = {}
# `Sigma` (capital) is the blur level; lower-case `sigma` below holds covariance matrices.
for Sigma in [.25,.5,.75,1,2]:
    metrics = {'precision':[], 'recall':[],'density':[], 'coverage':[], 'FID':[]}

    for _ in range(k_iterations):

        ds1 = Prostate2DSimpleDataset(path, None, input_channels=3)
        ds2 = Prostate2DSimpleDataset(path, None, input_channels=3, transforms=gaussian_filter, transforms_args={'sigma':Sigma})
        # each dataset keeps one half of a shuffled index
        idx = np.arange(len(ds1))
        np.random.shuffle(idx)
        ds1.df = ds1.df.loc[ds1.df.index[idx[:len(ds1) // 2]]]
        ds2.df = ds2.df.loc[ds2.df.index[idx[len(ds2) // 2:]]]
        dls = [torch.utils.data.DataLoader(ds,
                                           batch_size=16,
                                           shuffle=False,
                                           drop_last=False,
                                           num_workers=8) for ds in [ds1, ds2]]

        # One numpy array of model outputs per half, batches concatenated along axis 0.
        # `preds` from the last repetition is reused by the t-SNE cell that follows.
        preds = []
        with torch.no_grad():
            for dl in dls:
                _preds = torch.cat([model(item.to('cuda')).cpu() for item in dl], 0)
                preds.append(_preds.numpy())
        _metrics = compute_prdc(preds[0], preds[1], k_nearest_neigh)

        # Mean vector and covariance matrix of each output set feed the Frechet distance.
        mu = [np.mean(pred, axis=0) for pred in preds]
        sigma = [np.cov(pred, rowvar=False) for pred in preds]
        FID = calculate_frechet_distance(mu[0], sigma[0], mu[1], sigma[1])
        _metrics['FID'] = FID
        for key in _metrics:
            metrics[key].append(_metrics[key])

    # Summarise the repetitions for this blur level as [mean, std] per metric.
    Gaussian_blur_metrics[Sigma] = {}
    for key, values in metrics.items():
        summary = [np.mean(values), np.std(values)]
        print(key, *summary)
        Gaussian_blur_metrics[Sigma][key] = summary
Metrics['gaussian_blur'] = Gaussian_blur_metrics

In [None]:
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# 2-D t-SNE embedding of whatever `preds` currently holds (when the notebook is run
# top to bottom: the last repetition of the preceding evaluation cell).
tsne = TSNE(perplexity = 5)
stacked_preds = np.concatenate(preds, 0)
emb2d = tsne.fit_transform(stacked_preds)
# rows [0, n) of the embedding come from preds[0], the remaining rows from preds[1]
n = preds[0].shape[0]
for group in (emb2d[:n], emb2d[n:]):
    plt.scatter(group[:, 0], group[:, 1], alpha=.5)
plt.show()

In [None]:
# Same embedding with the draw order reversed: the preds[1] points are drawn first
# (colour C1) and the preds[0] points (colour C0) are drawn afterwards.
first_group, second_group = emb2d[:n], emb2d[n:]
plt.scatter(second_group[:, 0], second_group[:, 1], alpha=.5, c='C1')
plt.scatter(first_group[:, 0], first_group[:, 1], alpha=.5, c='C0')
plt.show()

In [None]:
def multiplicative_noise(img, sigma=1):
    """Multiply ``img`` element-wise by ``1 + sigma * N(0, 1)`` and clip to [0, 1].

    Parameters
    ----------
    img : array with a ``shape`` attribute
        Image of any rank (2-D and 3-D both work); one noise value is drawn per element.
    sigma : float
        Scale of the standard-normal noise term; 0 leaves the (clipped) image unchanged.

    Returns
    -------
    numpy.ndarray with the same shape as ``img`` and values in [0, 1].
    """
    # Draw noise for whatever shape the image has instead of assuming three axes.
    gain = 1 + np.random.randn(*img.shape) * sigma
    return np.clip(img * gain, 0, 1)
# Visual check on dataset item 0: first without a transform, then with
# multiplicative noise of increasing weight.
# (transforms=None in the first call, so transforms_args is presumably unused there -- TODO confirm)
ds2 = Prostate2DSimpleDataset(path, None, input_channels=1, transforms=None, transforms_args={'sigma':.5})
reference_img = ds2[0].squeeze()
plt.imshow(reference_img, cmap='gray')
plt.show()
for sigma in [.1,.25,.5,1]:
    ds2 = Prostate2DSimpleDataset(path, None, input_channels=1,
                                  transforms=multiplicative_noise,
                                  transforms_args=dict(sigma=sigma))
    noisy_img = ds2[0].squeeze()
    plt.imshow(noisy_img, cmap='gray')
    plt.show()

In [None]:
# Read the same csv as the baseline experiment so the numbers are comparable with
# Metrics['data_vs_data']. k_iterations and k_nearest_neigh are deliberately not
# re-assigned here: they come from the configuration cell.
# calculate_frechet_distance has to be made available as described in the import cell.
path = DATASETPATH / '2D' / 'dataset.csv'
Mult_noise_metrics = {}
# `Sigma` (capital) is the noise weight; lower-case `sigma` below holds covariance matrices.
for Sigma in [.1,.25,.5,.75,1]:
    metrics = {'precision':[], 'recall':[],'density':[], 'coverage':[], 'FID':[]}

    for _ in range(k_iterations):

        ds1 = Prostate2DSimpleDataset(path, None, input_channels=3)
        ds2 = Prostate2DSimpleDataset(path, None, input_channels=3, transforms=multiplicative_noise, transforms_args={'sigma':Sigma})
        # each dataset keeps one half of a shuffled index
        idx = np.arange(len(ds1))
        np.random.shuffle(idx)
        ds1.df = ds1.df.loc[ds1.df.index[idx[:len(ds1) // 2]]]
        ds2.df = ds2.df.loc[ds2.df.index[idx[len(ds2) // 2:]]]

        dls = [torch.utils.data.DataLoader(ds,
                                           batch_size=16,
                                           shuffle=False,
                                           drop_last=False,
                                           num_workers=8) for ds in [ds1, ds2]]

        # One numpy array of model outputs per half, batches concatenated along axis 0.
        preds = []
        with torch.no_grad():
            for dl in dls:
                _preds = torch.cat([model(item.to('cuda')).cpu() for item in dl], 0)
                preds.append(_preds.numpy())
        _metrics = compute_prdc(preds[0], preds[1], k_nearest_neigh)

        # Mean vector and covariance matrix of each output set feed the Frechet distance.
        mu = [np.mean(pred, axis=0) for pred in preds]
        sigma = [np.cov(pred, rowvar=False) for pred in preds]
        FID = calculate_frechet_distance(mu[0], sigma[0], mu[1], sigma[1])
        _metrics['FID'] = FID
        for key in _metrics:
            metrics[key].append(_metrics[key])

    # Summarise the repetitions for this noise weight as [mean, std] per metric.
    Mult_noise_metrics[Sigma] = {}
    for key, values in metrics.items():
        summary = [np.mean(values), np.std(values)]
        print(key, *summary)
        Mult_noise_metrics[Sigma][key] = summary
Metrics['mult_noise'] = Mult_noise_metrics

In [None]:
from skimage.transform import swirl

def swirl_img(img, rotation=0, strength=1, radius=120):
    """Apply skimage's ``swirl`` to an image whose first axis must be moved last.

    Axis 0 of ``img`` is moved to the last position before calling ``swirl`` and
    moved back to the front afterwards, so the result has the same axis order as
    the input. ``rotation``, ``strength`` and ``radius`` are forwarded unchanged.
    """
    moved_last = np.moveaxis(img, 0, -1)
    swirled = swirl(moved_last, rotation=rotation, strength=strength, radius=radius)
    return np.moveaxis(swirled, -1, 0)


# Visual check on dataset item 0: first without a transform, then swirled with
# increasing strength (the loop variable is called `sigma` but is passed as `strength`).
# (transforms=None in the first call, so transforms_args is presumably unused there -- TODO confirm)
ds2 = Prostate2DSimpleDataset(path, None, input_channels=1, transforms=None, transforms_args={'sigma':.5})
reference_img = ds2[0].squeeze()
plt.imshow(reference_img, cmap='gray')
plt.show()
for sigma in [.1,.25,.5,1]:
    swirl_args = dict(rotation=0, strength=sigma, radius=120)
    ds2 = Prostate2DSimpleDataset(path, None, input_channels=1,
                                  transforms=swirl_img, transforms_args=swirl_args)
    swirled_img = ds2[0].squeeze()
    plt.imshow(swirled_img, cmap='gray')
    plt.show()

In [None]:
# Read the same csv as the baseline experiment so the numbers are comparable with
# Metrics['data_vs_data']. k_iterations and k_nearest_neigh are deliberately not
# re-assigned here: they come from the configuration cell.
# calculate_frechet_distance has to be made available as described in the import cell.
path = DATASETPATH / '2D' / 'dataset.csv'
swirl_metrics = {}
# `Sigma` (capital) is the swirl strength; lower-case `sigma` below holds covariance matrices.
for Sigma in [.1,.25,.5,.75,1, 2.5, 5,7.5,10]:
    metrics = {'precision':[], 'recall':[],'density':[], 'coverage':[], 'FID':[]}

    for _ in range(k_iterations):

        ds1 = Prostate2DSimpleDataset(path, None, input_channels=3)
        ds2 = Prostate2DSimpleDataset(path, None, input_channels=3, transforms=swirl_img, transforms_args={'strength':Sigma})
        # each dataset keeps one half of a shuffled index
        idx = np.arange(len(ds1))
        np.random.shuffle(idx)
        ds1.df = ds1.df.loc[ds1.df.index[idx[:len(ds1) // 2]]]
        ds2.df = ds2.df.loc[ds2.df.index[idx[len(ds2) // 2:]]]

        dls = [torch.utils.data.DataLoader(ds,
                                           batch_size=16,
                                           shuffle=False,
                                           drop_last=False,
                                           num_workers=8) for ds in [ds1, ds2]]

        # One numpy array of model outputs per half, batches concatenated along axis 0.
        preds = []
        with torch.no_grad():
            for dl in dls:
                _preds = torch.cat([model(item.to('cuda')).cpu() for item in dl], 0)
                preds.append(_preds.numpy())
        _metrics = compute_prdc(preds[0], preds[1], k_nearest_neigh)

        # Mean vector and covariance matrix of each output set feed the Frechet distance.
        mu = [np.mean(pred, axis=0) for pred in preds]
        sigma = [np.cov(pred, rowvar=False) for pred in preds]
        FID = calculate_frechet_distance(mu[0], sigma[0], mu[1], sigma[1])
        _metrics['FID'] = FID
        for key in _metrics:
            metrics[key].append(_metrics[key])

    # Summarise the repetitions for this swirl strength as [mean, std] per metric.
    swirl_metrics[Sigma] = {}
    for key, values in metrics.items():
        summary = [np.mean(values), np.std(values)]
        print(key, *summary)
        swirl_metrics[Sigma][key] = summary
Metrics['swirl'] = swirl_metrics

In [None]:
# Overview of what has been collected: the experiment names first, then the keys
# stored under each experiment.
print(list(Metrics))
for collected in Metrics.values():
    print(list(collected))

In [None]:
# [mean, std] of precision over the data-vs-data repetitions
Metrics['data_vs_data']['precision']

In [None]:
# PLOT
from pylab import figure, show, legend, ylabel
def plot_metrics(metrics, base_metrics, Title='', xaxis='', figsize=(10,6), savefig=False):
    """Plot the mean of each metric, with a +/- one std band, against corruption level.

    precision, recall, density and coverage are drawn on the left-hand axis; FID is
    drawn on a second axes overlaid on the first, with its ticks on the right.

    Parameters
    ----------
    metrics : dict
        ``{level: {metric_name: [mean, std]}}``. The keys are used as x positions,
        in the dict's iteration order.
    base_metrics : dict
        ``{metric_name: [mean, std]}``; plotted at x = 0 in front of the levels.
    Title : str
        Figure title.
    xaxis : str
        Label of the x axis.
    figsize : tuple
        Figure size forwarded to matplotlib.
    savefig : bool, str or path-like
        Falsy: the figure is not saved. ``True`` (or any other truthy value that is
        not a path): saved as ``'fig'``, as before. A string or path-like value is
        used as the output path, which lets successive calls write distinct files
        instead of overwriting ``'fig'``.
    """
    import os  # local import: only needed to recognise path-like `savefig` values

    left_axis_metrics = ['precision', 'recall', 'density', 'coverage']
    xs = [0] + list(metrics)

    def _mean_and_std(name):
        """Return (means, stds) as arrays for one metric, baseline value first."""
        mean = np.array([base_metrics[name][0]] + [metrics[key][name][0] for key in metrics])
        std = np.array([base_metrics[name][1]] + [metrics[key][name][1] for key in metrics])
        return mean, std

    # Compute every series once; it is used for both the band and the line.
    stats = {name: _mean_and_std(name) for name in left_axis_metrics + ['FID']}

    fig1 = plt.figure(figsize=figsize)
    ax1 = fig1.add_subplot(111)

    # Bands first (all of them), lines afterwards: same call order as before so the
    # colours picked from the axes' cycles do not change.
    for name in left_axis_metrics:
        mean, std = stats[name]
        ax1.fill_between(xs, mean - std, mean + std, alpha=.5)

    # Second axes sharing the x axis with ax1, frameless, ticks and label on the right.
    ax2 = fig1.add_subplot(111, sharex=ax1, frameon=False)
    mean, std = stats['FID']
    ax2.fill_between(xs, mean - std, mean + std, alpha=.5, color='C4')
    ax2.yaxis.tick_right()
    ax2.yaxis.set_label_position("right")

    lines = []
    for name in left_axis_metrics:
        lines += ax1.plot(xs, stats[name][0])
    lines += ax2.plot(xs, stats['FID'][0], c='C4')

    # The lines live on two different axes, so the legend is built from the handles.
    ax2.legend(lines, left_axis_metrics + ['FID (RHS)'])
    ax2.set_xlabel(xaxis)
    ax2.set_title(Title)
    if savefig:
        fig1.tight_layout()
        target = savefig if isinstance(savefig, (str, os.PathLike)) else 'fig'
        fig1.savefig(target)
    plt.show()

In [None]:

# names of the experiments stored in Metrics so far
list(Metrics)

In [None]:
# Metric curves for the Gaussian-blur experiment, with the data-vs-data result at x = 0.
# NOTE(review): the title says "Randomly initialised" while the model cell passes
# pretrained=True; savefig=True writes to 'fig', the same file used by every other plot call.
plot_metrics(
    metrics=Metrics['gaussian_blur'],
    base_metrics=Metrics['data_vs_data'],
    Title="Randomly initialised VGG metrics\nwith increasing Gaussian blur",
    xaxis='sigma: data = Gaussian_blur(img, mu=0, sigma=sigma)',
    figsize=(8, 5),
    savefig=True,
)


In [None]:
# Metric curves for the multiplicative-noise experiment, with the data-vs-data result at x = 0.
# NOTE(review): the title says "Randomly initialised" while the model cell passes
# pretrained=True; savefig=True writes to 'fig', the same file used by every other plot call.
plot_metrics(
    metrics=Metrics['mult_noise'],
    base_metrics=Metrics['data_vs_data'],
    Title="Randomly initialised VGG metrics\nwith increasing multiplicative noise",
    xaxis='w: data = img * (1 + w * N(0,1))',
    figsize=(8, 5),
    savefig=True,
)

In [None]:
# Keep only the swirl strengths up to 2.5 so the low-strength range can be plotted on its own.
_metrics = {el: Metrics['swirl'][el] for el in [0.1, 0.25, 0.5, 0.75, 1, 2.5]}

In [None]:
# Two swirl plots with identical labels: first the low-strength subset held in
# `_metrics`, then the full range of strengths.
# NOTE(review): savefig=True writes to 'fig' on both calls, so the second plot
# replaces the first on disk.
swirl_title = "Randomly initialised VGG metrics\nwith increasing swirl"
swirl_xaxis = 's: data = swirl(img, strength=s, radius=120)'
for swirl_results in (_metrics, Metrics['swirl']):
    plot_metrics(swirl_results, Metrics['data_vs_data'], swirl_title, swirl_xaxis,
                 figsize=(8,5), savefig=True)

In [None]:
# One example image (dataset item 0) per Gaussian-blur level that was evaluated.
Dict = Metrics['gaussian_blur']
# squeeze=False always returns a 2-D array of axes, so a single level works too.
fig, axs = plt.subplots(1, len(Dict), figsize=(12,10), squeeze=False)
for ax, sigma in zip(axs[0], Dict):
    ds2 = Prostate2DSimpleDataset(path, None, input_channels=1, transforms=gaussian_filter, transforms_args={'sigma':sigma})
    ax.imshow(ds2[0].squeeze(), cmap='gray')
    ax.set_title('sigma: {}'.format(sigma))
fig.tight_layout()
# Experiment-specific file name: the other example-image cells save to 'imgs' / 'imgs.png'
# and would otherwise overwrite this figure.
fig.savefig('imgs_gaussian_blur.png')
plt.show()

In [None]:
# One example image (dataset item 0) per multiplicative-noise weight that was evaluated.
Dict = Metrics['mult_noise']
# squeeze=False always returns a 2-D array of axes, so a single level works too.
fig, axs = plt.subplots(1, len(Dict), figsize=(12,10), squeeze=False)
for ax, sigma in zip(axs[0], Dict):
    ds2 = Prostate2DSimpleDataset(path, None, input_channels=1, transforms=multiplicative_noise, transforms_args={'sigma':sigma})
    ax.imshow(ds2[0].squeeze(), cmap='gray')
    ax.set_title('w: {}'.format(sigma))
fig.tight_layout()
# Experiment-specific file name: the other example-image cells save to 'imgs' / 'imgs.png'
# and would otherwise overwrite this figure.
fig.savefig('imgs_mult_noise.png')
plt.show()

In [None]:
# One example image (dataset item 0) per swirl strength that was evaluated,
# laid out in a single row and written to 'imgs.png'.
Dict = Metrics['swirl']
fig, axs = plt.subplots(1, len(Dict), figsize=(16,6))
for ax, strength in zip(axs, Dict):
    ds2 = Prostate2DSimpleDataset(path, None, input_channels=1,
                                  transforms=swirl_img,
                                  transforms_args=dict(strength=strength))
    example_img = ds2[0].squeeze()
    ax.imshow(example_img, cmap='gray')
    ax.set_title(f's: {strength}')
plt.tight_layout()
plt.savefig('imgs.png')
plt.show()