In [None]:
import os

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.set_cmap('jet')
%matplotlib inline

import collections
import numpy as np
import random
import sklearn, sklearn.model_selection
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.utils.data as Data
import torchvision


In [None]:
# Fix every RNG the notebook touches so that Restart & Run All is repeatable.
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Force deterministic cuDNN kernels. Benchmark (autotuning) mode may select
# different algorithms from run to run, so it is switched off as well.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Mini-batch size (images per forward pass) intended for the evaluation DataLoaders.
BATCH_SIZE: int = 64

In [None]:
class CNN(nn.Module):
    """Small convolutional two-class classifier.

    Four 3x3, stride-2, unpadded convolutions (3 -> 64 -> 32 -> 16 -> 8
    channels), each followed by ReLU, feed a linear layer that maps the
    flattened feature map to 2 logits.

    The linear layer expects exactly 440 input features (8 channels x 55
    spatial positions after the last convolution); a 100x200 input, for
    example, reduces to a 5x11 map. NOTE(review): confirm against the actual
    image size produced by the data pipeline.

    Attribute names (conv1..conv4, out) and the layer order inside each stage
    determine the ``state_dict`` keys (``conv1.0.weight`` ... ``out.bias``),
    so they must not change or saved checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self._downsample(3, 64)
        self.conv2 = self._downsample(64, 32)
        self.conv3 = self._downsample(32, 16)
        self.conv4 = self._downsample(16, 8)
        self.out = nn.Linear(440, 2)

    @staticmethod
    def _downsample(in_channels, out_channels):
        """Return a 3x3 stride-2 convolution (no padding) followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=2,
                padding=0,
            ),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the network.

        Returns a tuple ``(logits, features)``: the linear-layer output of
        shape (batch, 2) and the flattened activations of the last
        convolution stage that were fed into it.
        """
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        features = x.view(x.size(0), -1)
        return self.out(features), features



In [None]:
# Build the classifier with freshly initialised weights; the following cell
# overwrites them with the saved checkpoint.
cnn = CNN()


In [None]:
# Load the trained classifier weights into `cnn`.
# map_location='cpu' lets a checkpoint saved from a GPU run load on a
# CPU-only machine. The model is never moved to a GPU in this notebook, and
# its outputs are converted with .numpy(), so CPU is where they belong.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (newer PyTorch versions also offer weights_only=True).
checkpoint = torch.load('./classifier_model.pth', map_location='cpu')
cnn.load_state_dict(checkpoint)


In [None]:
# Switch the network to inference mode. This CNN contains only Conv2d, ReLU
# and Linear layers (no dropout or batch norm), so eval() does not change its
# outputs; it just makes the intent explicit.
cnn.eval()


In [None]:
class SubsetSampler(torch.utils.data.sampler.Sampler):
    """Sampler that yields a fixed list of dataset indices in the given order.

    Nothing is shuffled, so a DataLoader built on it visits exactly
    ``indices``, in sequence, every epoch.

    Args:
        indices: any sized iterable of dataset indices (e.g. a list or a
            1-D numpy integer array).
    """

    def __init__(self, indices):
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        yield from self.indices


In [None]:
# Preprocessing applied to every image loaded by ImageFolder below.
mytransform = torchvision.transforms.Compose(
    [
        # int size: the shorter image side is scaled to 100 px, aspect ratio kept
        torchvision.transforms.Resize(100),
        # PIL image -> C x H x W float tensor (values scaled to [0, 1] for 8-bit images)
        torchvision.transforms.ToTensor(),
    ]
)

In [None]:
# Which generator's test outputs to evaluate: 'cycle_gan' or 'pix2pix'.
# `model` is not assigned anywhere else in this notebook, so a fresh kernel
# would raise NameError here. Default to 'cycle_gan' while still honouring a
# value injected into the namespace beforehand (e.g. a papermill parameter).
model = globals().get('model', 'cycle_gan')

if model == 'cycle_gan':
    thispath = './cyclegan/results/brats2013_cyclegan_'
elif model == 'pix2pix':
    thispath = './cyclegan/results/brats2013_pix2pix_'
else:
    # Fail loudly instead of silently evaluating the pix2pix results.
    raise ValueError(
        "model must be 'cycle_gan' or 'pix2pix', got {!r}".format(model))

In [None]:
# Evaluate the classifier on the generator's test outputs for every training
# ratio of tumor images in the target domain ('0.0' ... '1.0').
#
# Each row appended to `results` holds:
#   0 percent (str)
#   1 mean prediction over all fake_B images (fraction predicted class 1)
#   2 accuracy of the predictions against the path-derived labels
#   3 fraction predicted class 1 among label-0 fake_B images
#   4 fraction predicted class 1 among label-1 fake_B images
#   5 collections.Counter of predicted classes
#   6 mean per-image MAE between fake_B and real_B, label-0 images
#   7 same, label-1 images
# A label group with no images yields NaN (numpy warns on the empty mean).
PERCENTS = [
    '0.0', '0.1', '0.2', '0.3', '0.4', '0.5',
    '0.6', '0.7', '0.8', '0.9', '1.0']

results = []
cnn.eval()

for percent in PERCENTS:
    path = thispath + percent + '/test_latest/'
    test_data_raw = torchvision.datasets.ImageFolder(
        path, transform=mytransform)

    # File paths in dataset index order. Both the label and the
    # fake_B/real_B split are read from substrings of the path.
    img_paths = [img_path for img_path, _ in test_data_raw.imgs]
    labels = np.asarray(['True' in p for p in img_paths])
    fake_b_samples = np.where(['fake_B' in p for p in img_paths])[0]
    real_b_samples = np.where(['real_B' in p for p in img_paths])[0]

    # The pixel error pairs the i-th fake_B with the i-th real_B image, so
    # the subsets must line up one-to-one. Without this check a mismatch
    # would broadcast silently (one image) or fail deep inside numpy.
    if len(fake_b_samples) == 0 or len(fake_b_samples) != len(real_b_samples):
        raise ValueError(
            'Expected equally many (>0) fake_B and real_B images in {}, '
            'got {} and {}'.format(
                path, len(fake_b_samples), len(real_b_samples)))

    fake_b_loader = torch.utils.data.DataLoader(
        dataset=test_data_raw, batch_size=BATCH_SIZE,
        sampler=SubsetSampler(fake_b_samples))
    real_b_loader = torch.utils.data.DataLoader(
        dataset=test_data_raw, batch_size=BATCH_SIZE,
        sampler=SubsetSampler(real_b_samples))

    # Mini-batched inference with no autograd graph. Both loaders have the
    # same length and batch size, so zip keeps fake/real batches paired.
    pred_batches, mae_batches = [], []
    with torch.no_grad():
        for (fake_x, _), (real_x, _) in zip(fake_b_loader, real_b_loader):
            logits, _ = cnn(fake_x)
            pred_batches.append(logits.argmax(dim=1).numpy())
            mae_batches.append(
                np.abs(fake_x.numpy() - real_x.numpy()).mean(axis=(1, 2, 3)))
    # Concatenating 1-D batch results keeps pred_y 1-D even for a single
    # image, where `.squeeze()` would collapse it to a 0-d array.
    pred_y = np.concatenate(pred_batches)
    diff_per_image = np.concatenate(mae_batches)

    fake_labels = labels[fake_b_samples]
    acc = (pred_y == fake_labels).mean()
    dist_0 = pred_y[fake_labels == 0].mean()
    dist_1 = pred_y[fake_labels == 1].mean()
    diff_0 = diff_per_image[fake_labels == 0].mean()
    diff_1 = diff_per_image[fake_labels == 1].mean()

    pred_counts = collections.Counter(pred_y)
    results.append([
        percent, pred_y.mean(), acc, dist_0, dist_1,
        pred_counts, diff_0, diff_1])

    print(
        'Percent:', percent,
        ' Tumors:', pred_y.mean(),
        ' ', pred_counts,
        ' ', diff_0, diff_1)


In [None]:
# Healthy (label-0) holdout images after transformation.
# The stacked bars show the fraction the classifier calls tumor / no tumor
# (results column 3). The black line on the right axis is the mean
# per-image pixel MAE between fake_B and real_B (results column 6).
# `np.float` was removed in NumPy 1.24, so values are converted with the
# builtin `float` straight from the result rows.
percent_labels = [row[0] for row in results]
perc = np.array([row[3] for row in results], dtype=float)
mae = np.array([row[6] for row in results], dtype=float)
x = np.arange(len(results))

fig, ax1 = plt.subplots()
ax1.bar(
    x, 1 - perc, bottom=perc,
    color='forestgreen', label='Predicted without Tumor')
ax1.bar(
    x, perc,
    color='tomato', label='Predicted with Tumor')
ax1.set_ylabel('Percentage of transformed samples')
ax1.set_xlabel('Percentage of tumor images in the target domain during training')
ax1.set_title('Transformed healthy images from the holdout set')
# One tick per evaluated ratio, e.g. '0.1' -> '10%'.
ax1.set_xticks(x)
ax1.set_xticklabels(['{:.0%}'.format(float(p)) for p in percent_labels])
ax1.legend(loc='upper left')

ax2 = ax1.twinx()
ax2.plot(x, mae, 'black', label='Pixel Error (MAE)')
ax2.set_ylim(bottom=0.005, top=0.03)
ax2.set_ylabel('Pixel Error (MAE)')
ax2.legend(loc='lower right')

# Lay out only after all labels and ticks exist.
fig.tight_layout()
plt.show()

In [None]:
# Tumor (label-1) holdout images after transformation.
# The stacked bars show the fraction the classifier calls tumor / no tumor
# (results column 4). The black line on the right axis is the mean
# per-image pixel MAE between fake_B and real_B (results column 7).
# `np.float` was removed in NumPy 1.24, so values are converted with the
# builtin `float` straight from the result rows.
percent_labels = [row[0] for row in results]
perc = np.array([row[4] for row in results], dtype=float)
mae = np.array([row[7] for row in results], dtype=float)
x = np.arange(len(results))

fig, ax1 = plt.subplots()
ax1.bar(
    x, 1 - perc, bottom=perc,
    color='forestgreen', label='Predicted without Tumor')
ax1.bar(
    x, perc,
    color='tomato', label='Predicted with Tumor')
ax1.set_ylabel('Percentage of transformed samples')
ax1.set_xlabel('Percentage of tumor images in the target domain during training')
ax1.set_title('Transformed tumor images from the holdout set')
# One tick per evaluated ratio, e.g. '0.1' -> '10%'.
ax1.set_xticks(x)
ax1.set_xticklabels(['{:.0%}'.format(float(p)) for p in percent_labels])
ax1.legend(loc='upper left')

ax2 = ax1.twinx()
ax2.plot(x, mae, 'black', label='Pixel Error (MAE)')
ax2.set_ylim(bottom=0.005, top=0.03)
ax2.set_ylabel('Pixel Error (MAE)')
ax2.legend(loc='lower right')

# Lay out only after all labels and ticks exist.
fig.tight_layout()
plt.show()