In [1]:
import glob
import os
import time

# Order CUDA devices by PCI bus ID and expose only device '1' to this process.
# Both variables are set before `import torch` further down, presumably so they
# are in place before CUDA is initialised.
# NOTE(review): the GPU index is machine-specific — confirm it on the target host.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 
os.environ['CUDA_VISIBLE_DEVICES']='1'

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import h5py
import imageio

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# from torchvision import datasets, transforms
# from torch.autograd import Variable
import torch.autograd as autograd
import torchvision.utils as vutils

from torch.utils.data import DataLoader, TensorDataset

In [2]:
def load_data():
    """Load the train/test/valid image and label arrays from the working directory.

    Reads the six ``camelyonpatch_level_2_split_{train,test,valid}_{x,y}.h5``
    files relative to the current working directory.

    Images: dataset ``'x'`` is sliced with ``[:, 16:80, 16:80]`` (a 64x64 window
    on axes 1 and 2) and rescaled with ``(v - 127.5) / 127.5``.
    NOTE(review): this maps to [-1, 1] assuming the stored pixels are 0..255 —
    confirm against the dataset.

    Labels: dataset ``'y'`` is read in full and reshaped to a column ``(-1, 1)``.

    Every file is opened read-only inside a ``with`` block, so the HDF5 handle
    is closed as soon as its data has been copied into memory.

    Returns:
        tuple: ``(x_train, y_train, x_test, y_test, x_valid, y_valid)``.
    """

    def _read_images(split):
        # Slice while the file is open so only the cropped window is read.
        with h5py.File('camelyonpatch_level_2_split_%s_x.h5' % split, 'r') as f:
            patches = f['x'][:, 16:80, 16:80]
        return (patches - 127.5) / 127.5

    def _read_labels(split):
        with h5py.File('camelyonpatch_level_2_split_%s_y.h5' % split, 'r') as f:
            return f['y'][:].reshape(-1, 1)

    x_train, y_train = _read_images('train'), _read_labels('train')
    x_test, y_test = _read_images('test'), _read_labels('test')
    x_valid, y_valid = _read_images('valid'), _read_labels('valid')

    return x_train, y_train, x_test, y_test, x_valid, y_valid

def plot_samples(samples, folder=None, epoch=None, i=None):
    """Draw a square grid of images and optionally save it under ``results/``.

    Args:
        samples: Array indexed by image along axis 0. Values are rescaled with
            ``0.5 * samples + 0.5`` before display.
            NOTE(review): this assumes inputs in [-1, 1] — confirm with callers.
            Only the first ``floor(sqrt(N)) ** 2`` images are drawn.
        folder: If truthy, the figure is written to
            ``results/<folder>/samples/epoch_<epoch>[_<i>].png`` and then closed.
            Missing directories are created.
        epoch: Epoch number used in the file name. Must be an integer whenever
            ``folder`` is given, because it is formatted with ``%d``.
        i: Optional step index appended to the file name. ``0`` is a valid step
            and yields ``epoch_<epoch>_0.png``.

    Returns:
        None. When ``folder`` is falsy the figure stays open for inline display.
    """
    side = int(np.sqrt(samples.shape[0]))

    images = 0.5 * samples + 0.5

    fig = plt.figure(figsize=(10, 10))
    for idx in range(side * side):
        ax = fig.add_subplot(side, side, idx + 1)
        ax.imshow(images[idx], interpolation='nearest')
        ax.axis('off')
        ax.set_aspect('equal')
    fig.subplots_adjust(wspace=.004, hspace=.02)

    if folder:
        path = 'results/' + folder + '/samples'
        # Creates every missing level at once and tolerates existing directories.
        os.makedirs(path, exist_ok=True)
        # Compare against None so that step 0 still gets its own suffix.
        step = '' if i is None else '_' + str(i)
        fig.savefig(path + '/epoch_%d%s.png' % (epoch, step))
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)


In [3]:
%cd "~/pathology_gan"

# Read all six arrays (images and labels for the train/test/valid splits).
x_train, y_train, x_test, y_test, x_valid, y_valid = load_data()

# Pool the images of every split into one float32 tensor and move the last
# axis to position 1 (channels-first, assuming load_data returns
# channels-last images — TODO confirm).
X = torch.from_numpy(np.moveaxis(np.concatenate([x_train, x_test, x_valid]).astype(np.float32), -1, 1))
# Labels are not used here; the line below is kept for reference only.
# y = torch.from_numpy(np.concatenate([y_train, y_test, y_valid]))

# TensorDataset(X) yields 1-tuples, so each batch is read as `data[0]` downstream.
trainloader = DataLoader(TensorDataset(X), batch_size=128, shuffle=True, num_workers=0, pin_memory=True)

# percent = int(x_train.shape[0]*.01)
# np.random.seed(17)
# idx_small = np.random.choice(range(x_train.shape[0]), percent, replace=False)

# x_train_small = x_train[idx_small]
# y_train_small = y_train[idx_small]
# plt.hist(y_train_small, bins=2)
# print(percent)

/home/aray/pathology_gan


In [4]:
class Generator(nn.Module):
    """Map latent vectors to images with nearest-neighbour upsampling and convolutions.

    A linear layer projects the latent vector to a ``(n_filters, w//4, h//4)``
    feature map. A stack of ``Upsample`` + ``Conv2d`` + ``ReLU`` stages then
    brings it to ``(c, w, h)``, and a final ``Tanh`` bounds the output to [-1, 1].

    Args:
        w: Size of tensor dimension 2 of the output.
        h: Size of tensor dimension 3 of the output.
        c: Number of output channels.
        latent_dim: Length of the latent input vector.
    """

    def __init__(self, w, h, c, latent_dim):
        super(Generator, self).__init__()

        self.w = w
        self.h = h
        self.c = c
        self.latent_dim = latent_dim

        self.n_filters = 128

        self.input = nn.Sequential(
            # Parenthesise the floor divisions. `n * w//4 * h//4` parses as
            # ((n*w)//4*h)//4, which disagrees with the view() in forward()
            # whenever w or h is not a multiple of 4.
            nn.Linear(latent_dim, self.n_filters * (w // 4) * (h // 4)),
            nn.ReLU()
        )

        self.deconv = nn.Sequential(
            nn.Upsample(size=[w//2, h//2], mode='nearest'),
            nn.Conv2d(self.n_filters, self.n_filters//2, 3, stride=1, padding=1),
            nn.ReLU(True),

            # The 4x4 stride-2 convolutions halve the resolution again; the
            # following fixed-size Upsample restores it to (w, h).
            nn.Upsample(size=[w, h], mode='nearest'),
            nn.Conv2d(self.n_filters//2, self.n_filters//4, 4, stride=2, padding=1),
            nn.ReLU(True),

            nn.Upsample(size=[w, h], mode='nearest'),
            nn.Conv2d(self.n_filters//4, self.n_filters//8, 3, stride=1, padding=1),
            nn.ReLU(True),

            nn.Upsample(size=[w, h], mode='nearest'),
            nn.Conv2d(self.n_filters//8, self.n_filters//16, 4, stride=2, padding=1),
            nn.ReLU(True),

            nn.Upsample(size=[w, h], mode='nearest'),
            nn.Conv2d(self.n_filters//16, self.n_filters//32, 3, stride=1, padding=1),
            nn.ReLU(True),

            nn.Conv2d(self.n_filters//32, c, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate images from a batch of latent vectors.

        Args:
            x: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, c, w, h)`` with values in [-1, 1].
        """
        output = self.input(x)
        # Reshape the flat projection into a feature map for the conv stack.
        output = output.view(-1, self.n_filters, self.w//4, self.h//4)
        return self.deconv(output)

class Discriminator(nn.Module):
    """Convolutional critic that returns a scalar score and its flattened features.

    Three ``Conv2d(4x4, stride 2, padding 1)`` stages, each followed by
    ``LeakyReLU(0.2)`` and ``Dropout2d``, feed a single linear output unit.
    The linear layer is sized for ``int(w / 8) * int(h / 8)`` spatial positions.

    Note that the constructor takes its size arguments in the order ``(h, w, c)``.
    """

    def __init__(self, h, w, c):
        super().__init__()

        self.w, self.h, self.c = w, h, c
        base = 32

        # (input channels, output channels, dropout probability) for each stage.
        stage_specs = (
            (c, base, 0.2),
            (base, 2 * base, 0.2),
            (2 * base, 4 * base, 0.5),
        )

        layers = []
        for in_ch, out_ch, drop_p in stage_specs:
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.Dropout2d(p=drop_p),
            ])
        layers.append(nn.Flatten())

        self.main = nn.Sequential(*layers)

        flat_features = 4 * base * int(w / 2**3) * int(h / 2**3)
        self.output = nn.Linear(flat_features, 1, bias=True)

    def forward(self, x):
        """Return ``(score, features)``: the linear output and the flattened conv features."""
        features = self.main(x)
        return self.output(features), features

In [5]:
class WGAN_CT:
    """Trainer for a Wasserstein GAN with gradient penalty.

    The critic loss is ``mean(D(fake)) - mean(D(real)) + lambda_gp * GP``.
    A consistency term is implemented in ``_consistency_term`` but is not part
    of the loss that is optimised.

    Args:
        w, h, c: Image size and channel count, forwarded to ``Discriminator``
            and ``Generator``.
        model_name: Sub-folder of ``results/`` that receives sample grids and
            the training GIF; it is created if missing.
        latent_dim: Length of the generator's latent input vector.
    """

    def __init__(self, w, h, c, model_name, latent_dim=100):
        self.model_name = model_name
        self.latent_dim = latent_dim
        self.lambda_gp = 10
        self.lambda_cp = 2
        self.d_iterations = 5
        self.print_every = 10
        # CT multiplier
        self.M = .1
        # Use the GPU when one is visible; otherwise run on the CPU instead of
        # failing on the first .cuda() call.
        self.use_cuda = torch.cuda.is_available()
        self.D = Discriminator(w, h, c)
        self.G = Generator(w, h, c, latent_dim)

        # Move the modules first so the optimizers are built over the
        # parameters that are actually trained.
        if self.use_cuda:
            self.D = self.D.cuda()
            self.G = self.G.cuda()

        lr = 1e-4
        betas = (.9, .99)

        self.D_opt = optim.Adam(self.D.parameters(), lr=lr, betas=betas)
        self.G_opt = optim.Adam(self.G.parameters(), lr=lr, betas=betas)

        # Also creates 'results/' itself when it does not exist yet.
        os.makedirs('results/'+model_name, exist_ok=True)

    def _to_device(self, tensor):
        """Return ``tensor`` on the GPU when ``use_cuda`` is set, else unchanged."""
        return tensor.cuda() if self.use_cuda else tensor

    @staticmethod
    def _to_uint8_frames(frames):
        """Convert float images in [-1, 1] to uint8 images in [0, 255].

        Values outside [-1, 1] are clipped. Returns a new list and leaves the
        inputs untouched.
        """
        return [
            np.clip(127.5 * np.asarray(frame) + 127.5, 0, 255).astype(np.uint8)
            for frame in frames
        ]

    def train(self, data_loader, epochs, n_samples, save_training_gif=True):
        """Run the training loop and record per-epoch mean losses in ``self.stats``.

        Args:
            data_loader: Iterable of batches; the image tensor is read from
                ``batch[0]``.
            epochs: Number of passes over ``data_loader``.
            n_samples: Dataset size. Only used for the progress output, which
                assumes 128 samples per batch.
            save_training_gif: If true, a fixed set of 72 latents is rendered
                after every epoch and the frames are written to
                ``results/<model_name>/training_<epochs>_epochs.gif``.

        Side effects:
            Prints progress and losses, and saves a 16-sample grid per epoch
            through ``plot_samples``. The ``g_loss``, ``d_loss``,
            ``d_loss_real``, ``d_loss_fake`` and ``gp`` lists in ``self.stats``
            receive one 0-dim CPU tensor per epoch.
        """
        print(n_samples, n_samples//128)
        # One progress dot roughly every 1% of an epoch. Clamp to 1 so that
        # small datasets do not cause a modulo by zero.
        dot_every = max(1, (n_samples//128)//100)

        if save_training_gif:
            # Fix latents to see how image generation improves during training
            fixed_latents = self._to_device(torch.randn((12*6, self.latent_dim)))
            training_progress_images = []

        self.stats = {
            'clf_loss': [],
            'clf_acc': [],
            'clf_loss_val': [],
            'clf_acc_val': [],
            'g_loss': [],
            'd_loss': [],
            'd_loss_real': [],
            'd_loss_fake': [],
            'gp': [],
            'cp': [],
        }

        for epoch in range(epochs):
            g_loss = []
            d_loss = []
            d_loss_fake = []
            d_loss_real = []
            gradient_penalty = []

            for i, data in enumerate(data_loader):
                if i % dot_every == 0:
                    print(".", end="", flush=True)
                X = self._to_device(data[0])

                r, g, gp, loss = self._train_D(X)
                d_loss_real.append(r)
                d_loss_fake.append(g)
                gradient_penalty.append(gp)
                d_loss.append(loss)

                # Only update generator every |d_iterations| iterations
                if i % self.d_iterations == 0:
                    g_loss.append(self._train_G(X.shape[0]))

            g_loss_m = sum(g_loss)/len(g_loss)
            d_loss_m = sum(d_loss)/len(d_loss)
            real_m = sum(d_loss_real)/len(d_loss_real)
            fake_m = sum(d_loss_fake)/len(d_loss_fake)
            gp_m = sum(gradient_penalty)/len(gradient_penalty)
            self.stats['g_loss'].append(g_loss_m)
            self.stats['d_loss'].append(d_loss_m)
            self.stats['d_loss_real'].append(real_m)
            self.stats['d_loss_fake'].append(fake_m)
            self.stats['gp'].append(gp_m)
            print()
            print(epoch)
            print("D loss: %f; real: %f; fake: %f; gp: %f"%(d_loss_m, real_m, fake_m, gp_m))
            print("G loss: %f"%(g_loss_m))

            with torch.no_grad():
                self.G.eval()
                latents = self._to_device(torch.randn((16, self.latent_dim)))
                imgs = self.G(latents).cpu().numpy()
                plot_samples(np.moveaxis(imgs, 1,-1), self.model_name, epoch)
                self.G.train()

            if save_training_gif:
                with torch.no_grad():
                    self.G.eval()
                    img_grid = vutils.make_grid(self.G(fixed_latents).cpu(), nrow=12).numpy()
                    # make_grid returns channels-first; move channels last for imageio.
                    img_grid = np.transpose(img_grid, (1, 2, 0))
                    training_progress_images.append(img_grid)
                    self.G.train()

        if save_training_gif:
            # The frames are a Python list of float arrays in [-1, 1], so each
            # one is rescaled individually; a list cannot be multiplied by a float.
            frames = self._to_uint8_frames(training_progress_images)
            imageio.mimsave('results/'+self.model_name+'/training_{}_epochs.gif'.format(epochs), frames)

    def _train_D(self, data):
        """Run one critic update on a batch of real images.

        Returns:
            tuple of 0-dim CPU tensors:
            ``(-mean(D(real)), mean(D(fake)), gradient_penalty, d_loss)``.
        """
        # _train_G freezes the critic; re-enable its gradients here.
        for p in self.D.parameters():
            p.requires_grad = True

        self.D_opt.zero_grad()

        batch_size = data.shape[0]
        # The critic step only needs gradients for D, so cut the graph at the
        # generator output instead of backpropagating through G as well.
        generated_data = self.sample_generator(batch_size).detach()

        d_real, _ = self.D(data)
        d_generated, _ = self.D(generated_data)

        gradient_penalty = self._gradient_penalty(data, generated_data)

        gen_mean = d_generated.mean()
        real_mean = d_real.mean()
        d_loss = gen_mean - real_mean + self.lambda_gp*gradient_penalty
        d_loss.backward()

        self.D_opt.step()

        return -real_mean.detach().cpu(), gen_mean.detach().cpu(), gradient_penalty.detach().cpu(), d_loss.detach().cpu()

    def _train_G(self, batch_size):
        """Run one generator update and return the loss as a 0-dim CPU tensor."""
        # Freeze the critic so only generator gradients are computed.
        for p in self.D.parameters():
            p.requires_grad = False

        self.G_opt.zero_grad()

        # Get generated data
        generated_data = self.sample_generator(batch_size)

        # Calculate loss and optimize
        d_generated, _ = self.D(generated_data)
        g_loss = -d_generated.mean()
        g_loss.backward()
        self.G_opt.step()

        return g_loss.detach().cpu()

    def _gradient_penalty(self, real_data, generated_data):
        """Return ``mean((||grad_x D(x_hat)||_2 - 1) ** 2)`` over random interpolates.

        ``x_hat = alpha * real + (1 - alpha) * generated`` with one uniform
        ``alpha`` per sample. The result stays differentiable with respect to
        the critic parameters (``create_graph=True``).
        """
        batch_size = real_data.shape[0]

        # One mixing coefficient per sample, broadcast over the remaining axes.
        alpha = torch.rand(batch_size, 1, 1, 1).to(real_data.device)
        interpolated = alpha * real_data + (1 - alpha) * generated_data
        # Make the interpolates a leaf that requires grad, so autograd.grad
        # works whether or not the inputs carry a graph.
        interpolated = interpolated.detach().requires_grad_(True)

        # Calculate critic output of interpolated examples
        dis_interpolated, _ = self.D(interpolated)
        grad_outputs = torch.ones_like(dis_interpolated)

        # Calculate gradients of the critic output with respect to examples
        gradients = autograd.grad(outputs=dis_interpolated, inputs=interpolated,
                                  grad_outputs=grad_outputs, create_graph=True, retain_graph=True)[0]

        # Flatten to easily take the norm per example in the batch
        gradients = gradients.view(batch_size, -1)

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = ((torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12) - 1) ** 2).mean()
        return gradients_norm

    def _consistency_term(self, real_data):
        """Return the mean consistency term for two critic passes on the same batch.

        The two passes differ only through the critic's dropout layers. This
        method is currently not used by ``_train_D``.
        NOTE(review): the published CT-GAN term is clamped at zero
        (``max(0, ...)``); this version is not — confirm before enabling it.
        """
        d1, d_1 = self.D(real_data)
        d2, d_2 = self.D(real_data)

        ct = (d1 - d2).norm(2, dim=1) + 0.1 * (d_1 - d_2).norm(2, dim=1) - self.M
        return ct.mean()

    def sample_generator(self, num_samples):
        """Draw ``num_samples`` standard-normal latents and return ``G(latents)``."""
        latent_samples = torch.randn((num_samples, self.latent_dim), requires_grad=True)
        latent_samples = self._to_device(latent_samples)
        generated_data = self.G(latent_samples)
        return generated_data

    def sample(self, num_samples):
        """Generate samples and return channel 0 as an array of shape ``(num_samples, w, h)``.

        NOTE(review): only the first channel is returned, so multi-channel
        images lose the remaining channels — confirm this is intended.
        """
        with torch.no_grad():
            generated_data = self.sample_generator(num_samples)
        # Keep only the first channel
        return generated_data.detach().cpu().numpy()[:, 0, :, :]

In [None]:
# Build a trainer for 64x64 images with 3 channels; outputs go to results/wgan_gp_torch.
wgan_ct = WGAN_CT(64, 64, 3, 'wgan_gp_torch')
# Train for 200 epochs; X.shape[0] (the sample count) is passed as n_samples.
wgan_ct.train(trainloader, 200, X.shape[0])

327680 2560
.......................................................................................................
0
D loss: -67.520340; real: -117.683998; fake: 15.645548; gp: 3.451820
G loss: -14.996018
.......................................................................................................
1
D loss: -48.885990; real: -100.707268; fake: 29.062805; gp: 2.275841
G loss: -28.434992
.......................................................................................................
2
D loss: -40.586342; real: -70.805710; fake: 13.051620; gp: 1.716778
G loss: -12.749321
.......................................................................................................
3
D loss: -30.321894; real: -44.799644; fake: 2.953605; gp: 1.152424
G loss: -2.371369
.......................................................................................................
4
D loss: -23.309628; real: -35.322124; fake: 4.178138; gp: 0.783428
G loss: -3.904270
.........................