In [1]:
!pip install opencv-python torchsummary scikit-learn torchviz utils

Looking in indexes: http://mirrors.aliyun.com/pypi/simple


In [2]:
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import utils
from torch.nn.functional import one_hot
from torchvision.utils import save_image
from torch.autograd import Variable
from torch.utils.data import SubsetRandomSampler

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(torch.cuda.get_arch_list(), device)

['sm_37', 'sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86'] cuda


In [4]:
from __future__ import print_function
import argparse
import datetime
import matplotlib.pyplot as plt
from IPython.display import HTML
from IPython.display import clear_output
from tqdm import tqdm
import os
from tqdm import notebook
from sklearn.metrics import classification_report, confusion_matrix
import gc
import matplotlib.colors as mat_color
import numpy as np
from PIL import Image
import random
import cv2
from torchvision.datasets import ImageNet, ImageFolder
from torch.utils.data import DataLoader
from torchsummary import summary
from torchviz import make_dot

In [5]:
def initialize_weights(net):
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

class Generator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    def __init__(self, input_dim=100, output_dim=1, input_size=32, class_num=10):
        super(Generator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size
        self.class_num = class_num

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim + self.class_num, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Tanh(),
        )
        initialize_weights(self)

    def forward(self, input, label):
        x = torch.cat([input, label], 1)
        x = self.fc(x)
        x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
        x = self.deconv(x)

        return x

    
class Discriminator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    def __init__(self, input_dim=1, output_dim=1, input_size=32, class_num=10):
        super(Discriminator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size
        self.class_num = class_num

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
        )
        self.dc = nn.Sequential(
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        self.cl = nn.Sequential(
            nn.Linear(1024, self.class_num),
        )
        initialize_weights(self)

    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
        x = self.fc1(x)
        d = self.dc(x)
        c = self.cl(x)

        return d, c


In [6]:
base_path = './data'
base_folder = "Covid-19 Image Dataset"
classic_folder = 'Coivd-19_Classic'
synthetic_folder = 'Coivd-19_Synthetic'
data_dir = os.path.join(base_path, classic_folder)

In [7]:
img_size = 512
batch_size = 16
train_path = os.path.join(data_dir, "train")
test_path = os.path.join(data_dir, "test")

In [8]:
labels = os.listdir(train_path)
print(labels)
no_norm = mat_color.Normalize(vmin=0, vmax=255, clip=False)

['Covid', 'Normal', 'Viral Pneumonia']


In [9]:
# Number of training epochs
num_epochs = 100

#All images will be resized to this size using a transformer.
#image_size = 64
imageSize = 512

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 128

# Size of feature maps in generator
ngf = imageSize

# Size of feature maps in discriminator
ndf = imageSize

# No of labels
nb_label = len(labels)

# Learning rate for optimizers
lr = 0.002
lr_d = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Beta2 hyperparam for Adam optimizers
beta2 = 0.999

real_label = 1.
fake_label = 0.
# Input to generator
# fixed_noise = torch.randn(64, nz, 1, 1, device=device) #batch of 64
# Define Loss function
s_criterion = nn.BCELoss().to(device) #For synthesizing
c_criterion = nn.CrossEntropyLoss().to(device) #For classification

# input = torch.FloatTensor(batch_size, nc, imageSize, imageSize).to(device)
# noise = torch.FloatTensor(batch_size, nz, 1, 1).to(device)
# fixed_noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1).to(device)
s_label = torch.FloatTensor(batch_size).to(device)
c_label = torch.FloatTensor(batch_size, 3).to(device)
# s_label = torch.FloatTensor(batch_size).to(device)
# c_label = torch.LongTensor(batch_size).to(device)

# input = Variable(input)
# s_label = Variable(s_label)
# c_label = Variable(c_label)
print(s_label.shape)
print(c_label.shape)
print(s_label)
print(c_label)
# noise = Variable(noise)
# fixed_noise = Variable(fixed_noise)
# fixed_noise_ = np.random.normal(0, 1, (batch_size, nz))
# random_label = np.random.randint(0, nb_label, batch_size)
# #print('fixed label:{}'.format(random_label))
# random_onehot = np.zeros((batch_size, nb_label))
# random_onehot[np.arange(batch_size), random_label] = 1
# fixed_noise_[np.arange(batch_size), :nb_label] = random_onehot[np.arange(batch_size)]


# fixed_noise_ = (torch.from_numpy(fixed_noise_))
# fixed_noise_ = fixed_noise_.resize_(batch_size, nz, 1, 1)
# fixed_noise.data.copy_(fixed_noise_)

torch.Size([16])
torch.Size([16, 3])
tensor([4.2706e-20, 4.5717e-41, 7.7899e-24, 3.0737e-41, 4.4842e-44, 0.0000e+00,
        8.9683e-44, 0.0000e+00, 7.8375e-24, 3.0737e-41, 2.2421e-44, 0.0000e+00,
               nan, 0.0000e+00, 2.6800e+20, 1.7288e+28], device='cuda:0')
tensor([[ 0.0000e+00,  0.0000e+00,  7.4351e-25],
        [ 3.0737e-41, -5.2732e-14,  0.0000e+00],
        [ 3.6013e-43,  0.0000e+00,  1.6120e-23],
        [ 3.0737e-41,  4.2706e-20,  4.5717e-41],
        [ 2.1943e-38,  0.0000e+00,  3.1529e-43],
        [ 0.0000e+00,  1.6120e-23,  3.0737e-41],
        [ 4.2706e-20,  4.5717e-41,  6.4423e-30],
        [ 0.0000e+00,  2.7045e-43,  0.0000e+00],
        [ 1.6120e-23,  3.0737e-41,  4.2706e-20],
        [ 4.5717e-41, -2.6384e+26,  0.0000e+00],
        [ 2.2561e-43,  0.0000e+00,  1.6120e-23],
        [ 3.0737e-41,  4.2706e-20,  4.5717e-41],
        [-1.9493e-13,  0.0000e+00,  1.8077e-43],
        [ 0.0000e+00,  1.6120e-23,  3.0737e-41],
        [ 4.2706e-20,  4.5717e-41, -9.1864e

In [10]:
for func in [
    lambda: os.mkdir(os.path.join('.', 'GANAug')),
    lambda: os.mkdir(os.path.join('.', 'GANAug/model')),
    lambda: os.mkdir(os.path.join('.', 'GANAug/plots')),
    lambda: os.mkdir(os.path.join('.', 'GANAug/model/ACGAN')),
    lambda: os.mkdir(os.path.join('.', 'GANAug/plots/ACGAN')),
    lambda: os.mkdir(os.path.join('.', 'GANAug/output_images')),
    lambda: os.mkdir(os.path.join('.', 'GANAug/output_images/ACGAN'))]:  # create directories
    try:
        func()
    except Exception as error:
        print(error)
        continue

[Errno 17] File exists: './GANAug'
[Errno 17] File exists: './GANAug/model'
[Errno 17] File exists: './GANAug/plots'
[Errno 17] File exists: './GANAug/model/ACGAN'
[Errno 17] File exists: './GANAug/plots/ACGAN'
[Errno 17] File exists: './GANAug/output_images'
[Errno 17] File exists: './GANAug/output_images/ACGAN'


In [11]:
METRIC_FIELDS = [
    'train.D_x',
    'train.D_G_z1',
    'train.D_G_z2',
    'train.G_losses',
    'train.D_losses',
]
metrics = {field: list() for field in METRIC_FIELDS}

In [12]:
def load_dataset(train_dir=train_path, test_dir=test_path):
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    train_loader = DataLoader(train_data, batch_size, shuffle=True, num_workers=1)

    test_data = datasets.ImageFolder(test_dir ,transform=transform)
    test_loader = DataLoader(test_data, batch_size, shuffle=True, num_workers=1)

    return train_loader, test_loader, train_data, test_data

In [13]:
train_loader, test_loader, train_data, test_data = load_dataset()

In [14]:
generator = Generator(input_dim=nz, output_dim=nc, input_size=ngf, class_num=nb_label).to(device)
discriminator = Discriminator(input_dim=nc, output_dim=1, input_size=ndf, class_num=nb_label).to(device)

# setup optimizer
optimizerD = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))

In [15]:
g_output = generator(torch.rand((batch_size, nz)).to(device), torch.ones((batch_size, nb_label)).to(device))
g_output.shape

torch.Size([16, 3, 512, 512])

In [16]:
s_output, c_output = discriminator(torch.rand(g_output.shape).to(device))
s_output.shape

torch.Size([16, 1])

In [17]:
c_output.shape

torch.Size([16, 3])

In [18]:
s_output

tensor([[0.5457],
        [0.3780],
        [0.4582],
        [0.5439],
        [0.3552],
        [0.3951],
        [0.4220],
        [0.5855],
        [0.3634],
        [0.5419],
        [0.3884],
        [0.3118],
        [0.3516],
        [0.3942],
        [0.3329],
        [0.4155]], device='cuda:0', grad_fn=<SigmoidBackward0>)

In [19]:
c_output

tensor([[-0.2222,  0.2849,  0.0235],
        [-0.3323,  0.3322,  0.7577],
        [-0.1640,  0.9260, -0.1262],
        [ 0.6716, -0.2795,  0.2536],
        [ 0.4336,  0.0124, -0.2183],
        [-0.5017,  0.4373,  0.9848],
        [-0.0200,  0.1337, -0.2580],
        [ 0.3637,  0.0554, -0.1926],
        [-0.3055,  0.2880, -0.2432],
        [-0.1434,  0.4411,  0.1116],
        [ 0.6108, -0.7236,  0.3336],
        [ 0.4692, -0.1671, -0.2834],
        [-0.7962,  0.8066,  0.2995],
        [-0.8408, -0.2449, -0.3654],
        [ 0.1333,  0.0596, -0.6922],
        [ 0.4729,  0.1008,  0.2158]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [20]:
# exit()

In [21]:
!pip install torchviz

Looking in indexes: http://mirrors.aliyun.com/pypi/simple


In [22]:
print(generator)
print(discriminator)

Generator(
  (fc): Sequential(
    (0): Linear(in_features=131, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1024, out_features=2097152, bias=True)
    (4): BatchNorm1d(2097152, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (deconv): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): Tanh()
  )
)
Discriminator(
  (conv): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_run

In [23]:
SHOW = False
if SHOW:
    summary(generator, (nz, ngf, nc), batch_size=1, device=device)

In [24]:
if SHOW:
    summary(discriminator, (nc, ndf, ndf), batch_size=1, device=device)

In [25]:
SHOW_IMG = False
if SHOW_IMG:
    def modeltorchviz(model,input2):
        y = model(input2.cuda())    # 获取网络的预测值
        MyConvNetVis = make_dot(y, params=dict(list(model.named_parameters()) + [('x', input2)]))
        MyConvNetVis.format = "png"
        # 指定文件生成的文件夹
        MyConvNetVis.directory = "images"
        # 生成文件
        MyConvNetVis.view() 

In [26]:
if SHOW_IMG:
    modeltorchviz(generator, torch.randn(1, nz, ngf, nc).requires_grad_(True))

In [27]:
if SHOW_IMG:
    modeltorchviz(discriminator, torch.randn(1, nc, ndf, ndf).requires_grad_(True))

In [28]:
def test(predict, labels):
    correct = 0
    pred = predict.data.max(1)[1]
    correct = pred.eq(labels.data).cpu().sum()
    return correct, len(labels.data)

In [31]:
for epoch in range(2):
    for i, data in enumerate(tqdm(train_loader, 0)):
        ###########################
        # (1) Update D network
        ###########################
        # train with real
        discriminator.zero_grad()
        img, label = data
        batch_size = img.size(0)
        with torch.no_grad():
            img, label = img.to(device), label.to(device)
#             input.resize_(img.size()).copy_(img)
            print("img.shape", img.shape)
            print("label_ori", label)
            print("label_ori.shape", label.shape)
            label = one_hot(label)
            print("label", label)
            print("s_label.shape", s_label.shape)
            print("c_label.shape", c_label.shape)
            print("real_label", real_label)
            s_label.resize_(batch_size).fill_(real_label)
            c_label.resize_(batch_size, 3).copy_(label)
            print("s_label.shape", s_label.shape)
            print("c_label.shape", c_label.shape)
        s_output, c_output = discriminator(img)
        print("s_output.shape", s_output.shape)
        print("s_label.shape", s_label.shape)
        print("="*80)
        print("s_output", s_output)
        print("s_label", s_label)
        print("="*80)
        print("c_label", c_label)
        print("c_output", c_output)
        print("="*80)
        s_errD_real = s_criterion(s_output[:,0], s_label.resize_(batch_size))
#         c_output = torch.argmax(c_output, dim=1)
#         c_label = torch.argmax(c_label, dim=1)
        print("c_label.shape", c_label.shape)
        print("c_output.shape", c_output.shape)
        print("="*80)
        print("c_label", c_label)
        print("c_output", c_output)
        print("="*80)
        print(c_output.float())
        print(c_label.float())
        c_errD_real = c_criterion(c_output.float(), c_label.float())
        errD_real = s_errD_real + c_errD_real
        errD_real.backward()
        D_x = s_output.data.mean()
        
        correct, length = test(c_output, c_label)

        # train with fake
        with torch.no_grad():
            noise.resize_(batch_size, nz, 1, 1)
            noise.normal_(0, 1)

        label = np.random.randint(0, nb_label, batch_size)
        noise_ = np.random.normal(0, 1, (batch_size, nz))
        label_onehot = np.zeros((batch_size, nb_label))
        label_onehot[np.arange(batch_size), label] = 1
        noise_[np.arange(batch_size), :nb_label] = label_onehot[np.arange(batch_size)]
        
        noise_ = (torch.from_numpy(noise_))
        noise_ = noise_.resize_(batch_size, nz, 1, 1)
        noise.data.copy_(noise_)

        c_label.data.resize_(batch_size).copy_(torch.from_numpy(label))

        fake = generator(noise)
        s_label.data.fill_(fake_label)
        s_output,c_output = discriminator(fake.detach())
        s_errD_fake = s_criterion(s_output, s_label)
        c_errD_fake = c_criterion(c_output, c_label)
        errD_fake = s_errD_fake + c_errD_fake

        errD_fake.backward()
        D_G_z1 = s_output.data.mean()
        errD = s_errD_real + s_errD_fake
        optimizerD.step()

        ###########################
        # (2) Update G network
        ###########################
        generator.zero_grad()
        s_label.data.fill_(real_label)  # fake labels are real for generator cost
        s_output,c_output = discriminator(fake)
        s_errG = s_criterion(s_output, s_label)
        c_errG = c_criterion(c_output, c_label)
        
        errG = s_errG + c_errG
        errG.backward()
        D_G_z2 = s_output.data.mean()
        optimizerG.step()
        metrics['train.G_losses'].append(errG.item())
        metrics['train.D_losses'].append(errD.item())

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f, Accuracy: %.4f / %.4f = %.4f'
              % (epoch, num_epochs, i, len(train_loader),
                 errD.data, errG.data, D_x, D_G_z1, D_G_z2,
                 correct, length, 100.* correct / length))
        if i % 100 == 0:
            vutils.save_image(img,
                    '%s/real_samples.jpg' % './augGAN/output_images/ACGAN', normalize=True)
            #fake = netG(fixed_cat)
            fake = generator(fixed_noise)
            vutils.save_image(fake.data,
                    '%s/fake_samples_epoch_%03d.jpg' % ('./augGAN/output_images/ACGAN', epoch), normalize=True)

    # do checkpointing
    #torch.save(generator.state_dict(), '%s/netG_epoch_%d.pth' % (os.path.join('.', 'augGAN/model/ACGAN'), epoch))
    #torch.save(discriminator.state_dict(), '%s/netD_epoch_%d.pth' % (os.path.join('.', 'augGAN/model/ACGAN'), epoch))

  0%|                                                                                                | 0/189 [00:00<?, ?it/s]

img.shape torch.Size([16, 3, 512, 512])
label_ori tensor([0, 0, 0, 2, 1, 1, 1, 1, 2, 0, 0, 0, 1, 2, 2, 0], device='cuda:0')
label_ori.shape torch.Size([16])
label tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [0, 0, 1],
        [1, 0, 0]], device='cuda:0')
s_label.shape torch.Size([16])
c_label.shape torch.Size([16, 3])
real_label 1.0
s_label.shape torch.Size([16])
c_label.shape torch.Size([16, 3])
s_output.shape torch.Size([16, 1])
s_label.shape torch.Size([16])
s_output tensor([[0.3702],
        [0.4009],
        [0.5535],
        [0.4673],
        [0.5046],
        [0.2711],
        [0.4713],
        [0.5688],
        [0.3739],
        [0.1712],
        [0.3621],
        [0.4306],
        [0.5800],
        [0.4749],
        [0.4907],
        [0.3888]], device

  0%|                                                                                                | 0/189 [00:02<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 8.00 GiB (GPU 0; 23.69 GiB total capacity; 18.82 GiB already allocated; 2.69 GiB free; 18.88 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
def test2(generator, discriminator, num_epochs, metrics, loader):
    print('Testing Block.........')
    now = datetime.datetime.now()
    #g_losses = metrics['train.G_losses'][-1]
    #d_losses = metrics['train.D_losses'][-1]
    path='augGAN/output_images/ACGAN'
    try:
      os.mkdir(os.path.join('.', path))
    except Exception as error:
      print(error)

    real_batch = next(iter(loader))
    
    test_img_list = []
    test_noise = torch.randn(batch_size, nz, 1, 1, device=device)
    test_fake = generator(test_noise).detach().cpu()
    test_img_list.append(vutils.make_grid(test_fake, padding=2, normalize=True))

    fig = plt.figure(figsize=(15,15))
    ax1 = plt.subplot(1,2,1)
    ax1 = plt.axis("off")
    ax1 = plt.title("Real Images")
    ax1 = plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

    ax2 = plt.subplot(1,2,2)
    ax2 = plt.axis("off")
    ax2 = plt.title("Fake Images")
    ax2 = plt.imshow(np.transpose(test_img_list[-1],(1,2,0)))
    #ax2 = plt.show()
    #fig.savefig('%s/image_%.3f_%.3f_%d_%s.png' %
    #                (path, g_losses, d_losses, num_epochs, now.strftime("%Y-%m-%d_%H:%M:%S")))

In [None]:
def plot(name, train_epoch, values, path, save):
    clear_output(wait=True)
    plt.close('all')
    fig = plt.figure()
    fig = plt.ion()
    fig = plt.subplot(1, 1, 1)
    fig = plt.title('epoch: %s -> %s: %s' % (train_epoch, name, values[-1]))
    fig = plt.ylabel(name)
    fig = plt.xlabel('train_set')
    fig = plt.plot(values)
    fig = plt.grid()
    get_fig = plt.gcf()
    fig = plt.draw()  # draw the plot
    fig = plt.pause(1)  # show it for 1 second
    if save:
        now = datetime.datetime.now()
        get_fig.savefig('%s/%s_%.3f_%d_%s.png' %
                        (path, name, train_epoch, values[-1], now.strftime("%Y-%m-%d_%H:%M:%S")))

In [None]:
def save_model(generator, discriminator, gen_optimizer, dis_optimizer, metrics, num_epochs):
    now = datetime.datetime.now()
    g_losses = metrics['train.G_losses'][-1]
    d_losses = metrics['train.D_losses'][-1]
    name = "%+.3f_%+.3f_%d_%s.dat" % (g_losses, d_losses, num_epochs, now.strftime("%Y-%m-%d_%H:%M:%S"))
    # fname = os.path.join('.', 'augGAN/model', name)
    # states = {
    #         'state_dict_generator': generator.state_dict(),
    #         'state_dict_discriminator': discriminator.state_dict(),
    #         'gen_optimizer': gen_optimizer.state_dict(),
    #         'dis_optimizer': dis_optimizer.state_dict(),
    #         'metrics': metrics,
    #         'train_epoch': num_epochs,
    #         'date': now.strftime("%Y-%m-%d_%H:%M:%S"),
    # }
    # torch.save(states, fname)
    path='augGAN/plots/ACGAN/train_%+.3f_%+.3f_%s'% (g_losses, d_losses, now.strftime("%Y-%m-%d_%H:%M:%S"))
    try:
        os.mkdir(os.path.join('.', path))
    except Exception as error:
        print(error)

    plot('G_losses', num_epochs, metrics['train.G_losses'], path, True)
    plot('D_losses', num_epochs, metrics['train.D_losses'], path, True)
    plot('D_x', num_epochs, metrics['train.D_x'], path, True)
    plot('D_G_z1', num_epochs, metrics['train.D_G_z1'], path, True)
    plot('D_G_z2', num_epochs, metrics['train.D_G_z2'], path, True)

In [None]:
test2(generator, discriminator, num_epochs, metrics, train_loader)

In [None]:
save_model(generator, discriminator, optimizerG, optimizerD, metrics, num_epochs)

In [None]:
test_batch = 16
test_fake = 1

if test_fake:
    #check for fake image
    test_img_list = []
    test_noise = torch.randn(test_batch, nz, 1, 1, device=device)
    test_img = generator(test_noise)#.detach().cpu()

else:
    #check for real image
    test_loader = torch.utils.data.DataLoader(train_set, batch_size=test_batch,
                                            shuffle=True)
    data = next(iter(test_loader))
    test_noise, test_class_lable = data
    test_img.data.resize_(test_noise.size()).copy_(test_noise)
    #print(data[0].size())
    print('class label for real', test_class_lable)

s_output,c_label_op = discriminator(test_img.detach().to(device))
print('Discriminator s o/p', s_output)
print('Discriminator c o/p', c_label_op)

# label = np.random.randint(0, nb_label, batch_size)
# c_label.data.resize_(batch_size).copy_(torch.from_numpy(label))
# print(c_label)

test_img = test_img.detach().cpu()
test_img_list.append(vutils.make_grid(test_img, padding=2, normalize=True))
plt.imshow(np.transpose(test_img_list[-1],(1,2,0)))