In [1]:
from __future__ import print_function
import argparse
import datetime
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
from IPython.display import clear_output
from tqdm import tqdm
import cv2

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision import transforms
import torchvision.utils as vutils
from torchvision.utils import save_image
from torch.autograd import Variable
import matplotlib.colors as mat_color
from torch.utils.data import SubsetRandomSampler

In [2]:
for func in [
    lambda: os.mkdir(os.path.join('.', 'GANGEN')),
    lambda: os.mkdir(os.path.join('.', 'GANGEN/Covid')),
    lambda: os.mkdir(os.path.join('.', 'GANGEN/Normal')),
    lambda: os.mkdir(os.path.join('.', 'GANGEN/Viral Pneumonia'))]:  # create directories
    try:
        func()
    except Exception as error:
        print(error)
        continue

[Errno 17] File exists: './GANGEN'
[Errno 17] File exists: './GANGEN/Covid'
[Errno 17] File exists: './GANGEN/Normal'
[Errno 17] File exists: './GANGEN/Viral Pneumonia'


In [3]:
select_epoch = 12
generator_path = "GANAug/model/ACGAN/G_epoch_" + str(select_epoch) + ".pth"
discriminator_path = "GANAug/model/ACGAN/D_epoch_" + str(select_epoch) + ".pth"
no_norm = mat_color.Normalize(vmin=0, vmax=255, clip=False)

In [4]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class Generator(nn.Module):

    def __init__(self, nz, ngf, nc):

        super(Generator, self).__init__()
        self.ReLU = nn.ReLU(True)
        self.Tanh = nn.Tanh()
        self.conv1 = nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False)
        self.BatchNorm1 = nn.BatchNorm2d(ngf * 8)

        self.conv2 = nn.ConvTranspose2d(ngf * 8, ngf * 8, 4, 2, 1, bias=False)
        self.BatchNorm2 = nn.BatchNorm2d(ngf * 8)

        self.conv3 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False)
        self.BatchNorm3 = nn.BatchNorm2d(ngf * 4)

        self.conv4 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False)
        self.BatchNorm4 = nn.BatchNorm2d(ngf * 2)
        
        self.conv5 = nn.ConvTranspose2d(ngf * 2, ngf * 1, 4, 2, 1, bias=False)
        self.BatchNorm5 = nn.BatchNorm2d(ngf * 1)

        self.conv6 = nn.ConvTranspose2d(ngf * 1, nc, 4, 2, 1, bias=False)

        self.apply(weights_init)


    def forward(self, input):
        x = self.conv1(input)
        x = self.BatchNorm1(x)
        x = self.ReLU(x)
        
        x = self.conv2(x)
        x = self.BatchNorm2(x)
        x = self.ReLU(x)

        x = self.conv3(x)
        x = self.BatchNorm3(x)
        x = self.ReLU(x)

        x = self.conv4(x)
        x = self.BatchNorm4(x)
        x = self.ReLU(x)

        x = self.conv5(x)
        x = self.BatchNorm5(x)
        x = self.ReLU(x)

        x = self.conv6(x)
        output = self.Tanh(x)
        return output

class Discriminator(nn.Module):

    def __init__(self, ndf, nc, nb_label):

        super(Discriminator, self).__init__()
        self.LeakyReLU = nn.LeakyReLU(0.2, inplace=True)
        self.DropOut1 = nn.Dropout(p=0.5)
        self.DropOut2 = nn.Dropout(p=0.25)
        self.conv1 = nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)
        self.BatchNorm1 = nn.BatchNorm2d(ndf)
        self.conv2 = nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)
        self.BatchNorm2 = nn.BatchNorm2d(ndf * 2)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)
        self.BatchNorm3 = nn.BatchNorm2d(ndf * 4)
        self.conv4_1 = nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False)
        self.BatchNorm4_1 = nn.BatchNorm2d(ndf * 4)
        self.conv4_2 = nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)
        self.BatchNorm4_2 = nn.BatchNorm2d(ndf * 8)
        self.conv4_3 = nn.Conv2d(ndf * 8, ndf * 8, 4, 2, 1, bias=False)
        self.BatchNorm4_3 = nn.BatchNorm2d(ndf * 8)
        self.conv5 = nn.Conv2d(ndf * 8, ndf * 1, 4, 1, 0, bias=False)
        self.disc_linear = nn.Linear(ndf * 1, 1)
        self.aux_linear = nn.Linear(ndf * 1, nb_label)
        self.softmax = nn.LogSoftmax()
        self.sigmoid = nn.Sigmoid()
        self.ndf = ndf
        self.apply(weights_init)

    def forward(self, input):

        x = self.conv1(input)
        x = self.BatchNorm1(x)
        x = self.LeakyReLU(x)
        x = self.DropOut2(x)

        x = self.conv2(x)
        x = self.BatchNorm2(x)
        x = self.LeakyReLU(x)
        x = self.DropOut2(x)

        x = self.conv3(x)
        x = self.BatchNorm3(x)
        x = self.LeakyReLU(x)
        x = self.DropOut2(x)

        x = self.conv4_2(x)
        x = self.BatchNorm4_2(x)
        x = self.LeakyReLU(x)
        x = self.DropOut2(x)
        x = self.conv4_3(x)
        x = self.BatchNorm4_3(x)
        x = self.LeakyReLU(x)
        x = self.DropOut2(x)

        x = self.conv5(x)
        x = x.view(-1, self.ndf * 1)
        c = self.aux_linear(x)
        c = self.softmax(c)
        s = self.disc_linear(x)
        s = self.sigmoid(s)
        return s, c

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
imageSize = 128

# Number of channels in the training images. For color images this is 3
# nc = 3
nc = 1

# Size of z latent vector (i.e. size of generator input)
# nz = 300 # v1
# nz = 400 # v2
nz = 512 # v3 & v4

# Size of feature maps in generator
ngf = 128

# Size of feature maps in discriminator
ndf = 128

# No of labels
nb_label = 3

generator = Generator(nz, ngf, nc).to(device)
discriminator = Discriminator(ndf, nc, nb_label).to(device)

In [6]:
generator.load_state_dict(torch.load(generator_path))
discriminator.load_state_dict(torch.load(discriminator_path))

<All keys matched successfully>

In [7]:
generator

Generator(
  (ReLU): ReLU(inplace=True)
  (Tanh): Tanh()
  (conv1): ConvTranspose2d(512, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (BatchNorm1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): ConvTranspose2d(1024, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=

In [8]:
discriminator

Discriminator(
  (LeakyReLU): LeakyReLU(negative_slope=0.2, inplace=True)
  (DropOut1): Dropout(p=0.5, inplace=False)
  (DropOut2): Dropout(p=0.25, inplace=False)
  (conv1): Conv2d(1, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4_1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm4_1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4_2): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), b

In [9]:
label_dict = {
    0: "Covid",
    1: "Normal",
    2: "Viral Pneumonia"
}

label_count = {
    0: 0,
    1: 0,
    2: 0
}

In [10]:
def process_abnormal(c):
    cc = np.array(torch.argmax(c, dim=1).detach().cpu())
    size = len(c)
    count = 0
    for i in range(size):
        if max(c[i]) <= np.log(0.99):
            cc[i] = -1
            count += 1
    print(count/size)
    return cc

In [11]:
def generate_and_save(generator, discriminator, batch_size, path):
    # classes (list): List of the class names sorted alphabetically.
    test_noise = torch.randn(batch_size, nz, 1, 1, device=device)
    test_fake = generator(test_noise)
    now = datetime.datetime.now()
    s, c = discriminator(test_fake)
    c = process_abnormal(c)
    for i in range(batch_size):
        if c[i] == -1:
            print("too bad, man")
            continue
        label = label_dict[c[i]]
        label_count[c[i]] += 1
        image = test_fake[i].detach().cpu()
        image_show = np.transpose(test_fake[i].detach().cpu(), (1, 2, 0))
        # plt.imshow(image_show, 'gray')
        print(os.path.join(os.path.join(path, label), "CACGAN_" + str(label_count[c[i]]) + ".jpg"))
        vutils.save_image(image.data, os.path.join(os.path.join(path, label), "CACGAN_" + str(label_count[c[i]]) + "_" + now.strftime("%Y-%m-%d_%H-%M-%S") + ".png"),
                         normalize=True)

In [12]:
total_num = 251 * 5
variety = 50 * 5
batch_size = round(total_num/variety)

In [13]:
for i in range(variety):
    generate_and_save(generator, discriminator, batch_size=batch_size, path=os.path.join('.', 'GANGEN'))

  c = self.softmax(c)


0.4
too bad, man
./GANGEN/Covid/CACGAN_1.jpg
./GANGEN/Covid/CACGAN_2.jpg
./GANGEN/Covid/CACGAN_3.jpg
too bad, man
0.8
too bad, man
too bad, man
too bad, man
too bad, man
./GANGEN/Covid/CACGAN_4.jpg
0.4
./GANGEN/Covid/CACGAN_5.jpg
./GANGEN/Covid/CACGAN_6.jpg
./GANGEN/Covid/CACGAN_7.jpg
too bad, man
too bad, man
0.2
./GANGEN/Covid/CACGAN_8.jpg
too bad, man
./GANGEN/Covid/CACGAN_9.jpg
./GANGEN/Covid/CACGAN_10.jpg
./GANGEN/Covid/CACGAN_11.jpg
0.2
./GANGEN/Covid/CACGAN_12.jpg
./GANGEN/Covid/CACGAN_13.jpg
too bad, man
./GANGEN/Covid/CACGAN_14.jpg
./GANGEN/Covid/CACGAN_15.jpg
0.2
./GANGEN/Covid/CACGAN_16.jpg
./GANGEN/Covid/CACGAN_17.jpg
./GANGEN/Covid/CACGAN_18.jpg
./GANGEN/Covid/CACGAN_19.jpg
too bad, man
0.4
too bad, man
./GANGEN/Covid/CACGAN_20.jpg
too bad, man
./GANGEN/Covid/CACGAN_21.jpg
./GANGEN/Covid/CACGAN_22.jpg
0.4
too bad, man
too bad, man
./GANGEN/Covid/CACGAN_23.jpg
./GANGEN/Covid/CACGAN_24.jpg
./GANGEN/Covid/CACGAN_25.jpg
0.4
too bad, man
./GANGEN/Covid/CACGAN_26.jpg
./GANGEN/Co