In [1]:
# Restrict this kernel to the first GPU. NOTE(review): this normally has to run
# before anything initialises CUDA, so keep it as the first executed cell.
%env CUDA_VISIBLE_DEVICES=0
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is a debugging aid (synchronous kernel launches give
# accurate stack traces) and slows training; consider dropping it for real runs.
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_VISIBLE_DEVICES=0
env: CUDA_LAUNCH_BLOCKING=1


In [2]:
import math
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn import utils
from torch.utils.data import DataLoader

from torchvision.transforms import Compose, ToTensor, Resize, RandomHorizontalFlip

from data_utils import CelebaDataset, BatchCollate
from models import Generator, ConcatGenerator, ProjectionDiscriminator

In [3]:
from modules import Autoencoder

### Check Generator

In [4]:
# Smoke test: build a Generator (imported from models) with its default constructor arguments.
gen = Generator()

In [5]:
# Dummy inputs for the generator check: 4 random vectors of width 256 and
# 4 integer labels drawn from {0, 1, 2, 3}.
x = torch.rand(4, 256)
y = torch.randint(0, 4, (4,))
# One-hot alternative for the labels (not used with the plain Generator here):
# y = nn.functional.one_hot(torch.randint(0, 4, (4,)))

In [6]:
# Forward pass only; the output is displayed to confirm the call runs end to end.
gen(x, y)

tensor([[0.7402, 0.5766, 1.4744,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 2.9183, 0.0000,  ..., 0.4084, 0.0000, 2.3889],
        [0.0000, 0.0000, 0.0000,  ..., 2.8478, 0.0000, 0.0000],
        [1.6395, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)

### Check Discriminator

In [7]:
# Dummy inputs for the discriminator check: 4 random vectors of width 256 and
# 4 integer labels drawn from {0, 1}.
x = torch.rand(4, 256)
y = torch.randint(0, 2, (4,))

In [8]:
# Forward pass through a freshly initialised discriminator; the displayed output holds one score per sample.
# Note the labels above only cover classes {0, 1} although the model is built for 4 classes.
ProjectionDiscriminator(num_classes=4)(x, y)

tensor([[ 0.2372],
        [ 0.1720],
        [-0.4609],
        [ 0.0255]], grad_fn=<AddBackward0>)

### Dataset

In [9]:
# transforms = Compose([
#     Resize((64, 64)),
#     RandomHorizontalFlip(),
#     ToTensor()
# ])

# features = ['Black_Hair']
# dataset = CelebaDataset(usecols=features,
#                         attr_file='CelebA/list_attr_celeba.txt', 
#                         dataset_location='CelebA/img_align_celeba/', 
#                         transforms=transforms)

# image, label = dataset[np.random.choice(len(dataset))]
# print('\n'.join([name + f": {val}" for name, val in zip(features, dataset.one2multi[label])]))
# plt.imshow(image.permute(1, 2, 0))
# plt.axis('off');

# image, label = dataset[np.random.choice(len(dataset))]
# print('\n'.join([name + f": {val}" for name, val in zip(features, dataset.one2multi[label])]))
# plt.imshow(image.permute(1, 2, 0))
# plt.axis('off');

### Training loop

In [10]:
def sample_target_labels(y_source, num_classes=2):
    """Draw a uniformly random target class for every entry of ``y_source``.

    PARAMS:
        y_source (Tensor or array): source labels; only its ``shape`` (and, for
            tensors, its ``device``) is used, the values are ignored
        num_classes (int): labels are sampled from ``[0, num_classes)``

    RETURNS:
        Tensor of dtype int64 with the same shape as ``y_source``. For tensor inputs
        it lives on the same device as ``y_source``, so no extra host-to-device copy
        is needed when the source labels are already on the GPU.
    """
    # Non-tensor inputs (e.g. numpy arrays) have no `.device`; None selects the default device.
    device = getattr(y_source, 'device', None)
    return torch.randint(0, num_classes, y_source.shape, device=device)

In [11]:
class DisLoss(nn.Module):
    """Discriminator-side adversarial loss, selected by name.

    Supported ``loss_type`` values (``D(.)`` are raw, unbounded discriminator scores):
        'hinge': mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))
        'dcgan': mean(softplus(-D(real))) + mean(softplus(D(fake)))
        'lsgan': 0.5 * (mean((D(real) - 1)^2) + mean(D(fake)^2))

    RAISES:
        ValueError: if ``loss_type`` is not one of the names above
    """

    def __init__(self, loss_type):
        super(DisLoss, self).__init__()
        self.loss_type = loss_type

        if loss_type == 'hinge':
            self.criterion = self.hinge_loss
        elif loss_type == 'dcgan':
            self.criterion = self.dcgan_loss
        elif loss_type == 'lsgan':
            self.criterion = self.lsgan_loss
        else:
            raise ValueError("Not supported! Choose from ['hinge', 'dcgan', 'lsgan'].")

        # Backward-compatible alias for the original (misspelled) attribute name.
        self.critetion = self.criterion

    def hinge_loss(self, dis_real, dis_fake):
        """Hinge loss: penalise real scores below +1 and fake scores above -1."""
        return (1. - dis_real).relu().mean() + (1. + dis_fake).relu().mean()

    def dcgan_loss(self, dis_real, dis_fake):
        """Logistic (DCGAN) loss computed directly on logits.

        softplus(-x) == -log(sigmoid(x)) and softplus(x) == -log(1 - sigmoid(x)), so this
        is binary cross-entropy with targets 1 for real and 0 for fake, in a numerically
        stable form. It tends to 0 as the discriminator becomes confidently correct.
        """
        return F.softplus(-dis_real).mean() + F.softplus(dis_fake).mean()

    def lsgan_loss(self, dis_real, dis_fake):
        """Least-squares loss pulling real scores towards 1 and fake scores towards 0."""
        return 0.5 * ((dis_real - 1).pow(2).mean() + dis_fake.pow(2).mean())

    def forward(self, dis_real, dis_fake):
        """
        PARAMS:
            dis_real (Tensor): discriminator scores for real data
            dis_fake (Tensor): discriminator scores for fake outputs of generator

        RETURNS:
            Scalar tensor with the selected discriminator loss.
        """
        return self.criterion(dis_real, dis_fake)
    
    
class GenLoss(nn.Module):
    """Generator-side adversarial loss, selected by name.

    Supported ``loss_type`` values (``D(fake)`` are raw discriminator scores):
        'hinge': -mean(D(fake))
        'dcgan': mean(softplus(-D(fake)))   (non-saturating logistic loss)
        'lsgan': 0.5 * mean((D(fake) - 1)^2)

    RAISES:
        ValueError: if ``loss_type`` is not one of the names above
    """

    def __init__(self, loss_type):
        super(GenLoss, self).__init__()
        self.loss_type = loss_type

        if loss_type == 'hinge':
            self.criterion = self.hinge_loss
        elif loss_type == 'dcgan':
            self.criterion = self.dcgan_loss
        elif loss_type == 'lsgan':
            self.criterion = self.lsgan_loss
        else:
            raise ValueError("Not supported! Choose from ['hinge', 'dcgan', 'lsgan'].")

        # Backward-compatible alias for the original (misspelled) attribute name.
        self.critetion = self.criterion

    def hinge_loss(self, dis_fake):
        """Hinge generator loss: raise the discriminator score of generated samples."""
        return -dis_fake.mean()

    def dcgan_loss(self, dis_fake):
        """Non-saturating logistic loss on logits.

        softplus(-x) == -log(sigmoid(x)), i.e. binary cross-entropy of the fake scores
        against the target 1; it tends to 0 as the discriminator is fooled.
        """
        return F.softplus(-dis_fake).mean()

    def lsgan_loss(self, dis_fake):
        """Least-squares loss pulling fake scores towards 1."""
        return 0.5 * (dis_fake - 1).pow(2).mean()

    def forward(self, dis_fake):
        """
        PARAMS:
            dis_fake (Tensor): discriminator scores for fake outputs of generator

        RETURNS:
            Scalar tensor with the selected generator loss.
        """
        return self.criterion(dis_fake)

### Default Dataset with images 

In [12]:
# Mini-batch size passed to every DataLoader built in this notebook.
# (args['batch_size'] further down is a separate value that happens to also be 64.)
BATCH_SIZE = 64

In [13]:
# Image preprocessing: resize to 64x64, random horizontal flip as augmentation,
# then conversion to a tensor.
transforms = Compose([
    Resize((64, 64)),
    RandomHorizontalFlip(),
    ToTensor()
])

In [54]:
# Attribute column(s) requested from the CelebA attribute file.
features = ['Black_Hair']
# Alternative single-attribute setup:
# features = ['Male']
# CelebaDataset comes from data_utils; paths are relative to the notebook's working directory.
dataset = CelebaDataset(usecols=features,
                        attr_file='CelebA/list_attr_celeba.txt', 
                        dataset_location='CelebA/img_align_celeba/', 
                        transforms=transforms)

# NOTE(review): BatchCollate is imported from data_utils at the top of the notebook but is
# redefined further down (Embedding Dataset section); which one is used here depends on the
# cell execution order. Confirm the intended one after Restart & Run All.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4,
                        collate_fn=BatchCollate(), shuffle=True)

In [17]:
# Models, pretrained autoencoder, optimizers and losses for the image-based training loop.
num_classes = 2

# hidden_size=128 presumably equals the flattened latent size (32 channels x 2 x 2 = 128,
# see the `.view(bs, -1, 2, 2)` calls in the loop) — TODO confirm against models.py.
G = Generator(hidden_size=128, num_classes=num_classes).cuda()
D = ProjectionDiscriminator(hidden_size=128, num_classes=num_classes).cuda()

# Autoencoder configuration. Only 'width', 'latent_width', 'depth', 'latent', 'colors' and
# 'device' are read in this cell; the remaining keys (including 'dataset': 'MNIST') are not
# used here. NOTE(review): looks copied from the autoencoder training script — confirm.
args =  {'dataset': 'MNIST',
         'eval_each': 10,
         'epochs': 101,
         'log_dir': 'CelebA64_256_v2/',
         'device': 'cuda:0',
         'weight_decay': 1e-05,
         'depth': 16,
         'gamma': 0.2,
         'lmbda': 0.5,
         'batch_norm': False,
         'batch_size': 64,
         'colors': 3,
         'latent_width': 4, # Bottleneck HW
         'width': 128, # With latent_width=4 this gives scales = log2(128 // 4) = 5 below
         'latent': 32, # Bottleneck channels
         'n_classes': 10,
         'advdepth': 16,
         'lr': 0.0001}

# Number of scales handed to the Autoencoder: log2(width / latent_width).
scales = int(round(math.log(args['width'] // args['latent_width'], 2)))
# The autoencoder is only used for inference, hence eval mode.
ae = Autoencoder(scales=scales,depth=args['depth'],latent=args['latent'],colors=args['colors']).to(args['device']).eval()

# NOTE(review): map_location is hard-coded here while the later setup cells use args['device'].
ae.load_state_dict(torch.load('acai_64.pt', map_location='cuda:0'))

# Adam with beta1 = 0 for both networks.
optimizer_g = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0, 0.9))
optimizer_d = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0, 0.9))
# Same optimizers with Adam's default learning rate:
# optimizer_g = torch.optim.Adam(G.parameters(), betas=(0, 0.9))
# optimizer_d = torch.optim.Adam(D.parameters(), betas=(0, 0.9))

# Adversarial objective; must be one of the names accepted by DisLoss / GenLoss.
loss_type = 'hinge'
criterion_d = DisLoss(loss_type)
criterion_g = GenLoss(loss_type)

In [None]:
def _endless_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass whenever one is exhausted.

    A single iterator is kept alive per pass. Calling ``next(iter(loader))`` for every batch
    would build a new iterator (and, with ``num_workers > 0``, a new pool of worker
    processes) each time and only ever return the first batch of a fresh shuffle.
    Incomplete trailing batches are skipped whenever the loader has more than one batch,
    so every training step sees ``loader.batch_size`` samples.
    """
    while True:
        for batch in loader:
            if len(loader) > 1 and batch[0].size(0) != loader.batch_size:
                continue
            yield batch


def _encode(images):
    """Return the autoencoder latents of ``images`` flattened to (batch, -1), without gradients."""
    with torch.no_grad():
        return ae.encoder(images).view(images.size(0), -1)


N_PREVIEW = 8  # maximum number of samples shown per row in the preview figure

batches = _endless_batches(dataloader)
iteration = 0
for epoch in range(10):
    for _ in tqdm_notebook(range(len(dataloader))):
        # =================================================================== #
        #                        1. Train Discriminator                       #
        # =================================================================== #

        # Real branch: latents of real images scored with their true labels.
        images, y_real = next(batches)
        latent_real = _encode(images.cuda())
        dis_real = D(latent_real, y_real.cuda())

        # Fake branch: a fresh batch translated to randomly drawn target labels.
        images, y_real = next(batches)
        y_fake = sample_target_labels(y_real, num_classes).cuda()
        latent_real = _encode(images.cuda())
        # The generator is not updated in this step, so no graph is built for it.
        with torch.no_grad():
            latent_fake = G(latent_real, y_fake)
        dis_fake = D(latent_fake, y_fake)

        d_loss = criterion_d(dis_real, dis_fake)

        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # =================================================================== #
        #                         2. Train Generator                          #
        # =================================================================== #
        images, y_real = next(batches)
        y_fake = sample_target_labels(y_real, num_classes)
        images, y_real, y_fake = images.cuda(), y_real.cuda(), y_fake.cuda()
        latent_real = _encode(images)

        latent_fake = G(latent_real, y_fake)
        dis_fake = D(latent_fake, y_fake)

        # Cycle consistency: translating back to the source label should recover the input.
        # Reuses latent_fake instead of running G(latent_real, y_fake) a second time.
        latent_cyclic = G(latent_fake, y_real)
        cyclic_loss = F.l1_loss(latent_cyclic, latent_real)

        # latent_id =  G(latent_real, y_real)
        # identity_loss = F.l1_loss(latent_id, latent_real)

        g_loss = criterion_g(dis_fake) + cyclic_loss

        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        if iteration % 50 == 0:
            print("Step: {} | D_loss: {:.3f} | G_loss: {:.3f} | Cyclic loss: {:.3f}".format(
                iteration, d_loss.item(), g_loss.item(), cyclic_loss.item()))

        # =================================================================== #
        #            3. Preview: inputs (top) vs. translations (bottom)       #
        # =================================================================== #
        if iteration % 100 == 0:
            bs = images.size(0)
            # Eval mode for the preview, as in the embedding-based loops of this notebook.
            G.eval()
            with torch.no_grad():
                preview_fake = G(latent_real, y_fake)
                x_decoded = ae.decoder(preview_fake.view(bs, -1, 2, 2)).cpu().permute(0, 2, 3, 1)
            G.train()

            shown = images.detach().cpu().permute(0, 2, 3, 1)
            target = np.array([dataset.one2multi[a.item()] for a in y_fake])
            target[target == -1] = 0

            n_cols = min(N_PREVIEW, bs)
            # squeeze=False keeps `ax` two-dimensional even for a single column.
            fig, ax = plt.subplots(nrows=2, ncols=n_cols, figsize=(20, 5), squeeze=False)
            for j in range(n_cols):
                title = '\n'.join([attr + f": {val}" for attr, val in zip(dataset.usecols, target[j])])
                ax[0][j].set_title(title, fontsize=14)
                ax[0][j].imshow(shown[j], aspect='auto')
                ax[0][j].axis('off')
                ax[1][j].imshow(x_decoded[j], aspect='auto')
                ax[1][j].axis('off')
            plt.tight_layout(pad=0)
            plt.show()
            plt.close(fig)

        iteration += 1

# Release the live DataLoader iterator (and its worker processes).
batches.close()

### Embedding Dataset

In [15]:
class EmbeddingDataset(torch.utils.data.Dataset):
    """Dataset over precomputed embedding rows whose last two entries are attribute flags.

    Every distinct pair of trailing values is mapped to an integer class id (in the sorted
    order produced by ``np.unique``). ``multi2one`` maps a flag tuple to its class id and
    ``one2multi`` maps a class id back to the flag list.
    """

    def __init__(self, emb):
        self.embedding = emb

        # The trailing two entries of each row carry the attribute flags.
        attr_rows = np.array([row[-2:].numpy() for row in emb])
        distinct = np.unique(attr_rows, axis=0)
        self.num_classes = distinct.shape[0]

        self.multi2one = {}
        self.one2multi = {}
        for class_id, combo in enumerate(distinct):
            self.multi2one[tuple(combo)] = class_id
            self.one2multi[class_id] = combo.tolist()

        def to_class_id(attr_row):
            return self.multi2one[tuple(attr_row)]

        # One integer class id per embedding row.
        self.labels = np.apply_along_axis(to_class_id, 1, attr_rows)

    def __len__(self):
        return len(self.embedding)

    def __getitem__(self, idx):
        """Return ``(row without its two flag entries, class id)``.

        Rows belonging to the last class id are never returned: a uniformly random
        replacement index is drawn until a row of another class is found.
        """
        rejected = self.num_classes - 1
        row, class_id = self.embedding[idx], self.labels[idx]
        while class_id == rejected:
            idx = np.random.choice(len(self))
            row, class_id = self.embedding[idx], self.labels[idx]
        return (row[:-2], class_id)


class BatchCollate(object):
    """Collate ``(tensor, label)`` samples into a stacked tensor and a label tensor."""

    def __call__(self, batch):
        samples, targets = [], []
        for sample, target in batch:
            samples.append(sample)
            targets.append(target)
        # Labels go through numpy so that numpy scalar labels end up in a single tensor.
        return torch.stack(samples), torch.from_numpy(np.array(targets))

In [16]:
# Attribute names used for the preview titles in the embedding-based loops.
# Presumably these correspond to the two trailing flag entries of each embedding row — TODO confirm.
features = ['Black_Hair', 'Blond_Hair']
# NOTE(review): torch.load unpickles the file; only load embedding files from a trusted source.
train_dataset = EmbeddingDataset(torch.load('Embedding_CelebA64.pth'))
# num_workers=0: batches are produced in the main process.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, 
                          num_workers=0, shuffle=True,
                          collate_fn=BatchCollate())

In [17]:
# Models, pretrained autoencoder, optimizers and losses for the embedding-based training loop.
DEVICE = 'cuda:0'
num_classes = 2
# hidden_size=128 presumably equals the flattened latent size (32 channels x 2 x 2 = 128,
# see the `.view(bs, -1, 2, 2)` calls in the loop) — TODO confirm against models.py.
G = Generator(hidden_size=128, num_classes=num_classes).to(DEVICE)
D = ProjectionDiscriminator(hidden_size=128, num_classes=num_classes).to(DEVICE)

# Autoencoder configuration. Only 'width', 'latent_width', 'depth', 'latent', 'colors' and
# 'device' are read in this cell; the remaining keys (including 'dataset': 'MNIST') are not
# used here. NOTE(review): looks copied from the autoencoder training script — confirm.
args =  {'dataset': 'MNIST',
         'eval_each': 10,
         'epochs': 101,
         'log_dir': 'CelebA64_256_v2/',
         'device': DEVICE,
         'weight_decay': 1e-05,
         'depth': 16,
         'gamma': 0.2,
         'lmbda': 0.5,
         'batch_norm': False,
         'batch_size': 64,
         'colors': 3,
         'latent_width': 4, # Bottleneck HW
         'width': 128, # With latent_width=4 this gives scales = log2(128 // 4) = 5 below
         'latent': 32, # Bottleneck channels
         'n_classes': 10,
         'advdepth': 16,
         'lr': 0.0001}

# Number of scales handed to the Autoencoder: log2(width / latent_width).
scales = int(round(math.log(args['width'] // args['latent_width'], 2)))
# The autoencoder is only used for decoding previews, hence eval mode.
ae = Autoencoder(scales=scales,depth=args['depth'],latent=args['latent'],colors=args['colors']).to(args['device']).eval()

ae.load_state_dict(torch.load('acai_64.pt', map_location=args['device']))

# Adam with beta1 = 0 for both networks.
optimizer_g = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0, 0.9))
optimizer_d = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0, 0.9))

# Adversarial objective; must be one of the names accepted by DisLoss / GenLoss.
loss_type = 'hinge'

criterion_d = DisLoss(loss_type)
criterion_g = GenLoss(loss_type)

G.train();
D.train();

# Only referenced by the commented-out BCE alternative in the training loop below.
BCE = nn.BCEWithLogitsLoss()

In [None]:
def _endless_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass whenever one is exhausted.

    A single iterator is kept alive per pass; ``next(iter(loader))`` would rebuild the
    iterator for every batch and only ever return the first batch of a fresh shuffle.
    Incomplete trailing batches are skipped whenever the loader has more than one batch,
    so every training step sees ``loader.batch_size`` samples.
    """
    while True:
        for batch in loader:
            if len(loader) > 1 and batch[0].size(0) != loader.batch_size:
                continue
            yield batch


def _to_unit_range(x):
    """Min-max rescale a tensor to [0, 1] for display."""
    return (x - x.min()) / (x.max() - x.min())


N_PREVIEW = 8  # maximum number of samples shown per row in the preview figure

batches = _endless_batches(train_loader)
iteration = 0
# Discriminator scores collected for the histograms below. They are stored detached and on
# the CPU so that neither autograd graphs nor GPU memory are kept alive across iterations.
dis_fakes = []
dis_reals = []
for epoch in range(10):
    for _ in tqdm_notebook(range(len(train_loader))):
        # =================================================================== #
        #                        1. Train Discriminator                       #
        # =================================================================== #

        # Real branch: real latents scored with their true labels.
        latent_real, y_real = next(batches)
        dis_real = D(latent_real.to(DEVICE), y_real.to(DEVICE))

        # Fake branch: a fresh batch translated to randomly drawn target labels.
        latent_real, y_real = next(batches)
        y_fake = sample_target_labels(y_real, num_classes).to(DEVICE)
        latent_real = latent_real.to(DEVICE)
        # The generator is not updated in this step, so no graph is built for it.
        with torch.no_grad():
            latent_fake = G(latent_real, y_fake)
        dis_fake = D(latent_fake, y_fake)

        dis_reals.append(dis_real.detach().cpu())
        dis_fakes.append(dis_fake.detach().cpu())

        d_loss = criterion_d(dis_real, dis_fake)
        # d_loss = BCE(dis_real, torch.ones_like(dis_real)) + \
        #          BCE(dis_fake, torch.zeros_like(dis_fake))

        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # =================================================================== #
        #                         2. Train Generator                          #
        # =================================================================== #
        latent_real, y_real = next(batches)
        y_fake = sample_target_labels(y_real, num_classes)
        latent_real, y_real, y_fake = latent_real.to(DEVICE), y_real.to(DEVICE), y_fake.to(DEVICE)

        latent_fake = G(latent_real, y_fake)
        dis_fake = D(latent_fake, y_fake)

        # Cycle consistency: translating back to the source label should recover the input.
        latent_cyclic = G(latent_fake, y_real)
        cyclic_loss = F.l1_loss(latent_cyclic, latent_real)

        dis_fakes.append(dis_fake.detach().cpu())

        # latent_id =  G(latent_real, y_real)
        # identity_loss = F.l1_loss(latent_id, latent_real)

        g_loss = criterion_g(dis_fake) + cyclic_loss
        # g_loss = BCE(dis_fake, torch.ones_like(dis_fake)) + cyclic_loss

        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        if iteration % 50 == 0:
            print("Step: {} | D_loss: {:.3f} | G_loss: {:.3f} | Cyclic loss: {:.3f}".format(
                iteration, d_loss.item(), g_loss.item(), cyclic_loss.item()))

        # =================================================================== #
        #       3. Preview: decoded inputs (top) vs. translations (bottom)    #
        # =================================================================== #
        if iteration % 100 == 0:
            # Batch size of the batch that is actually previewed (the generator-step batch).
            bs = latent_real.size(0)
            G.eval()
            with torch.no_grad():
                latent_fake = G(latent_real, y_fake)
                x_decoded = ae.decoder(latent_fake.view(bs, -1, 2, 2)).cpu().permute(0, 2, 3, 1)
                x_decoded = _to_unit_range(x_decoded)
                images = ae.decoder(latent_real.view(bs, -1, 2, 2)).cpu().permute(0, 2, 3, 1)
                images = _to_unit_range(images)
            G.train()

            target = np.array([train_dataset.one2multi[a.item()] for a in y_fake]).astype(int)
            target[target == -1] = 0

            n_cols = min(N_PREVIEW, bs)
            # squeeze=False keeps `ax` two-dimensional even for a single column.
            fig, ax = plt.subplots(nrows=2, ncols=n_cols, figsize=(20, 5), squeeze=False)
            for j in range(n_cols):
                title = '\n'.join([attr + f": {val}" for attr, val in zip(features, target[j])])
                ax[0][j].set_title(title, fontsize=14)
                ax[0][j].imshow(images[j], aspect='auto')
                ax[0][j].axis('off')
                ax[1][j].imshow(x_decoded[j], aspect='auto')
                ax[1][j].axis('off')
            plt.tight_layout(pad=0)
            plt.show()
            plt.close(fig)

        iteration += 1

# Release the live DataLoader iterator.
batches.close()

In [None]:
# Histogram of the discriminator scores collected during training, real vs. generated.
def _flat_scores(score_batches):
    """Concatenate per-batch score tensors into one flat numpy array.

    ``reshape(-1)`` (rather than ``squeeze()``) keeps single-sample batches one-dimensional,
    which ``torch.cat`` requires.
    """
    return torch.cat([s.detach().cpu().reshape(-1) for s in score_batches], dim=0).numpy()


fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(_flat_scores(dis_reals), bins=1000, histtype='step', color='darkviolet', label='Real')
ax.hist(_flat_scores(dis_fakes), bins=1000, histtype='step', color='g', label='Fake')
ax.set_xlabel('Discriminator score', fontsize=14)
ax.set_ylabel('Count', fontsize=14)
ax.tick_params(axis='x', labelsize=14)
ax.legend(fontsize=14);

In [None]:
# NOTE(review): this cell repeats the real-vs-fake score histogram from the cell directly
# above using the same `dis_reals` / `dis_fakes` lists; it looks like a leftover copy and
# is a candidate for removal.
plt.figure(figsize=(10, 6))
# Scores of real samples, flattened across all recorded batches.
plt.hist(torch.cat([a.detach().cpu().squeeze() for a in dis_reals], dim=0), 
         bins=1000, histtype='step', color='darkviolet', label='Real');
plt.xticks(fontsize=14)
# Scores of generated samples.
plt.hist(torch.cat([a.detach().cpu().squeeze() for a in dis_fakes], dim=0), 
         bins=1000, histtype='step', color='g', label='Fake');
plt.legend(fontsize=14)
plt.xticks(fontsize=14);

### One-hot

In [38]:
# Models, pretrained autoencoder, optimizers and losses for the one-hot (ConcatGenerator) run.
DEVICE = 'cuda:0'
num_classes = 2
# The generator here receives one-hot label tensors (see the training loop below), while the
# projection discriminator still receives integer labels.
G = ConcatGenerator(hidden_size=128, num_classes=num_classes).to(DEVICE)
D = ProjectionDiscriminator(hidden_size=128, num_classes=num_classes).to(DEVICE)

# Autoencoder configuration. Only 'width', 'latent_width', 'depth', 'latent', 'colors' and
# 'device' are read in this cell; the remaining keys (including 'dataset': 'MNIST') are not
# used here. NOTE(review): looks copied from the autoencoder training script — confirm.
args =  {'dataset': 'MNIST',
         'eval_each': 10,
         'epochs': 101,
         'log_dir': 'CelebA64_256_v2/',
         'device': DEVICE,
         'weight_decay': 1e-05,
         'depth': 16,
         'gamma': 0.2,
         'lmbda': 0.5,
         'batch_norm': False,
         'batch_size': 64,
         'colors': 3,
         'latent_width': 4, # Bottleneck HW
         'width': 128, # With latent_width=4 this gives scales = log2(128 // 4) = 5 below
         'latent': 32, # Bottleneck channels
         'n_classes': 10,
         'advdepth': 16,
         'lr': 0.0001}

# Number of scales handed to the Autoencoder: log2(width / latent_width).
scales = int(round(math.log(args['width'] // args['latent_width'], 2)))
# The autoencoder is only used for decoding previews, hence eval mode.
ae = Autoencoder(scales=scales,depth=args['depth'],latent=args['latent'],colors=args['colors']).to(args['device']).eval()

ae.load_state_dict(torch.load('acai_64.pt', map_location=args['device']))

# Adam with beta1 = 0 for both networks.
optimizer_g = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0, 0.9))
optimizer_d = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0, 0.9))

# 'hinge', 'dcgan' or 'lsgan' select DisLoss / GenLoss; 'bce' makes the training loop use
# the BCEWithLogitsLoss instance below instead.
loss_type = 'hinge'

if loss_type != 'bce':
    criterion_d = DisLoss(loss_type)
    criterion_g = GenLoss(loss_type)

G.train();
D.train();

BCE = nn.BCEWithLogitsLoss()

In [None]:
def _endless_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass whenever one is exhausted.

    A single iterator is kept alive per pass; ``next(iter(loader))`` would rebuild the
    iterator for every batch and only ever return the first batch of a fresh shuffle.
    Incomplete trailing batches are skipped whenever the loader has more than one batch,
    so every training step sees ``loader.batch_size`` samples.
    """
    while True:
        for batch in loader:
            if len(loader) > 1 and batch[0].size(0) != loader.batch_size:
                continue
            yield batch


def _to_unit_range(x):
    """Min-max rescale a tensor to [0, 1] for display."""
    return (x - x.min()) / (x.max() - x.min())


N_CRITIC = 5   # discriminator updates per generator update
N_PREVIEW = 8  # maximum number of samples shown per row in the preview figure

batches = _endless_batches(train_loader)
iteration = 0
# Discriminator scores collected for the histogram below. They are stored detached and on
# the CPU so that neither autograd graphs nor GPU memory are kept alive across iterations.
dis_fakes = []
dis_reals = []
for epoch in range(10):
    for _ in tqdm_notebook(range(len(train_loader))):
        # =================================================================== #
        #                        1. Train Discriminator                       #
        # =================================================================== #
        for _critic_step in range(N_CRITIC):
            # Real branch: real latents scored with their true (integer) labels.
            latent_real, y_real = next(batches)
            dis_real = D(latent_real.to(DEVICE), y_real.to(DEVICE))

            # Fake branch: a fresh batch translated to randomly drawn target labels.
            # The generator is conditioned on one-hot labels, the discriminator on integers.
            latent_real, y_real = next(batches)
            y_fake = sample_target_labels(y_real, num_classes).to(DEVICE)
            cond_fake = F.one_hot(y_fake, num_classes)
            # The generator is not updated in this step, so no graph is built for it.
            with torch.no_grad():
                latent_fake = G(latent_real.to(DEVICE), cond_fake)
            dis_fake = D(latent_fake, y_fake)

            dis_fakes.append(dis_fake.detach().cpu())

            if loss_type == 'bce':
                d_loss = BCE(dis_real, torch.ones_like(dis_real)) + \
                         BCE(dis_fake, torch.zeros_like(dis_fake))
            else:
                d_loss = criterion_d(dis_real, dis_fake)

            optimizer_d.zero_grad()
            d_loss.backward()
            optimizer_d.step()

        # One set of real scores per outer iteration (from the last critic step).
        dis_reals.append(dis_real.detach().cpu())

        # =================================================================== #
        #                         2. Train Generator                          #
        # =================================================================== #
        latent_real, y_real = next(batches)
        y_fake = sample_target_labels(y_real, num_classes)
        latent_real, y_real, y_fake = latent_real.to(DEVICE), y_real.to(DEVICE), y_fake.to(DEVICE)
        cond_true = F.one_hot(y_real, num_classes)
        cond_fake = F.one_hot(y_fake, num_classes)

        latent_fake = G(latent_real, cond_fake)
        dis_fake = D(latent_fake, y_fake)

        # Cycle consistency: translating back to the source label should recover the input.
        latent_cyclic = G(latent_fake, cond_true)
        cyclic_loss = F.l1_loss(latent_cyclic, latent_real)

        dis_fakes.append(dis_fake.detach().cpu())

        # latent_id =  G(latent_real, y_real)
        # identity_loss = F.l1_loss(latent_id, latent_real)

        if loss_type == 'bce':
            g_loss = BCE(dis_fake, torch.ones_like(dis_fake)) + cyclic_loss
        else:
            g_loss = criterion_g(dis_fake) + cyclic_loss

        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        if iteration % 50 == 0:
            print("Step: {} | D_loss: {:.3f} | G_loss: {:.3f} | Cyclic loss: {:.3f}".format(
                iteration, d_loss.item(), g_loss.item(), cyclic_loss.item()))

        # =================================================================== #
        #       3. Preview: decoded inputs (top) vs. translations (bottom)    #
        # =================================================================== #
        if iteration % 100 == 0:
            # Batch size of the batch that is actually previewed (the generator-step batch).
            bs = latent_real.size(0)
            G.eval()
            with torch.no_grad():
                latent_fake = G(latent_real, cond_fake)
                x_decoded = ae.decoder(latent_fake.view(bs, -1, 2, 2)).cpu().permute(0, 2, 3, 1)
                x_decoded = _to_unit_range(x_decoded)
                images = ae.decoder(latent_real.view(bs, -1, 2, 2)).cpu().permute(0, 2, 3, 1)
                images = _to_unit_range(images)
            G.train()

            target = np.array([train_dataset.one2multi[a.item()] for a in y_fake]).astype(int)
            target[target == -1] = 0

            n_cols = min(N_PREVIEW, bs)
            # squeeze=False keeps `ax` two-dimensional even for a single column.
            fig, ax = plt.subplots(nrows=2, ncols=n_cols, figsize=(20, 5), squeeze=False)
            for j in range(n_cols):
                title = '\n'.join([attr + f": {val}" for attr, val in zip(features, target[j])])
                ax[0][j].set_title(str(title), fontsize=14)
                ax[0][j].imshow(images[j], aspect='auto')
                ax[0][j].axis('off')
                ax[1][j].imshow(x_decoded[j], aspect='auto')
                ax[1][j].axis('off')
            plt.tight_layout(pad=0)
            plt.show()
            plt.close(fig)

        iteration += 1

# Release the live DataLoader iterator.
batches.close()

In [None]:
# Histogram of the discriminator scores collected during the one-hot run, real vs. generated.
def _flat_scores(score_batches):
    """Concatenate per-batch score tensors into one flat numpy array.

    ``reshape(-1)`` (rather than ``squeeze()``) keeps single-sample batches one-dimensional,
    which ``torch.cat`` requires.
    """
    return torch.cat([s.detach().cpu().reshape(-1) for s in score_batches], dim=0).numpy()


fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(_flat_scores(dis_reals), bins=1000, histtype='step', color='darkviolet', label='Real')
ax.hist(_flat_scores(dis_fakes), bins=1000, histtype='step', color='g', label='Fake')
ax.set_xlabel('Discriminator score', fontsize=14)
ax.set_ylabel('Count', fontsize=14)
ax.tick_params(axis='x', labelsize=14)
ax.legend(fontsize=14);