In [1]:
import os
import numpy as np
import math
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

from torch.autograd import Variable
import torch.nn.functional as F
from torchvision.utils import save_image


# Output folder for the sampled image grids (relative to the kernel's working directory).
os.makedirs('cgan_images', exist_ok=True)

# --- Hyperparameters ---
num_epochs = 200
batch_size = 64
lr = 0.0002
b1 = 0.5               # Adam beta1
b2 = 0.999             # Adam beta2
z_dims = 100           # length of the generator's noise vector
num_classes = 10
img_size = 32          # images are resized to img_size x img_size
channels = 1
sample_interval = 400  # interval (in batches) between image sampling
seed = 42              # fixed RNG seed so a Restart & Run All is reproducible

img_shape = (channels, img_size, img_size)

# Seed every RNG the notebook draws from: numpy (noise vectors, sampled labels)
# and torch (weight init, DataLoader shuffling, dropout).
# NOTE: GPU kernels may still be nondeterministic; this fixes the seeds only.
np.random.seed(seed)
torch.manual_seed(seed)

cuda = torch.cuda.is_available()

class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image.

    The class label is embedded with a learnable ``nn.Embedding`` of width
    ``n_classes`` and concatenated with the noise vector before the MLP.
    The final ``Tanh`` keeps outputs in [-1, 1], the same range produced by
    the ``Normalize([0.5], [0.5])`` transform applied to the real images.

    Args:
        latent_dim: Length of the noise vector. Defaults to the global ``z_dims``.
        n_classes: Number of conditioning classes. Defaults to ``num_classes``.
        shape: Output image shape ``(C, H, W)``. Defaults to ``img_shape``.
    """

    def __init__(self, latent_dim=None, n_classes=None, shape=None):
        super(Generator, self).__init__()
        # Fall back to the notebook's config globals so `Generator()` still works.
        latent_dim = z_dims if latent_dim is None else latent_dim
        n_classes = num_classes if n_classes is None else n_classes
        self.img_shape = tuple(img_shape if shape is None else shape)

        self.label_emb = nn.Embedding(n_classes, n_classes)

        def block(in_feat, out_feat, normalize=True):
            """Return [Linear, (BatchNorm1d), LeakyReLU] for one hidden layer."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BatchNorm1d's second positional parameter is `eps`, not
                # `momentum`: `BatchNorm1d(n, 0.8)` would set eps=0.8. Use the
                # defaults instead.
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Layer order and indices are unchanged, so existing state_dict keys still load.
        self.main = nn.Sequential(
            *block(latent_dim + n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),  # e.g. 1 x 32 x 32
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: Float tensor of shape ``(batch, latent_dim)``.
            labels: Long tensor of shape ``(batch,)`` with values in ``[0, n_classes)``.

        Returns:
            Tensor of shape ``(batch, *shape)`` with values in [-1, 1].
            In training mode BatchNorm1d needs ``batch > 1``.
        """
        # Fixed concatenation order: label embedding first, then the noise.
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.main(gen_input)
        return img.view(img.size(0), *self.img_shape)


class Discriminator(nn.Module):
    """Conditional MLP critic that scores (image, label) pairs.

    The label is embedded (width ``n_classes``) and concatenated with the
    flattened image. The output is an unbounded score of shape ``(batch, 1)``.
    There is no sigmoid, so pair it with a regression loss such as MSELoss
    (least-squares GAN).

    Args:
        n_classes: Number of conditioning classes. Defaults to the global ``num_classes``.
        shape: Input image shape ``(C, H, W)``. Defaults to ``img_shape``.
    """

    def __init__(self, n_classes=None, shape=None):
        super(Discriminator, self).__init__()
        # Fall back to the notebook's config globals so `Discriminator()` still works.
        n_classes = num_classes if n_classes is None else n_classes
        shape = img_shape if shape is None else shape

        self.label_emb = nn.Embedding(n_classes, n_classes)

        self.main = nn.Sequential(
            nn.Linear(n_classes + int(np.prod(shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        """Score a batch of images conditioned on their labels.

        Args:
            img: Tensor of shape ``(batch, *shape)``; it need not be contiguous.
            labels: Long tensor of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with raw (unbounded) scores.
        """
        # reshape (not view) so non-contiguous image batches are accepted.
        flat = img.reshape(img.size(0), -1)
        disc_input = torch.cat((flat, self.label_emb(labels)), -1)
        return self.main(disc_input)


# Least-squares GAN objective: scores are regressed onto 1.0 (real) / 0.0 (fake)
# targets with mean-squared error. The (misspelled) name `critertion` is kept
# because the training cell refers to it.
critertion = torch.nn.MSELoss()

# Instantiate the two adversaries (generator first, then discriminator) and
# move them, along with the loss, to the GPU when one is available.
gen, disc = Generator(), Discriminator()
if cuda:
    for _module in (gen, disc, critertion):
        _module.cuda()

# MNIST training split: resize to img_size, convert to a tensor in [0, 1],
# then shift/scale to [-1, 1] to match the generator's Tanh output.
mnist_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
mnist_train = datasets.MNIST('data', train=True, download=True, transform=mnist_transform)
dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

# One Adam optimizer per network, both with the same learning rate and betas.
adam_kwargs = dict(lr=lr, betas=(b1, b2))
optim_gen = optim.Adam(gen.parameters(), **adam_kwargs)
optim_disc = optim.Adam(disc.parameters(), **adam_kwargs)

# Legacy tensor-type aliases used to build tensors on the active device.
if cuda:
    FloatTensor, LongTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    FloatTensor, LongTensor = torch.FloatTensor, torch.LongTensor

def sample_image(n_row, batches_done):
    """Save an ``n_row x n_row`` grid of generated digits to ``cgan_images/<batches_done>.png``.

    Every row is conditioned on the labels ``0, 1, ..., n_row - 1``, so each
    column shows a single class. Labels wrap modulo ``num_classes``, so a grid
    wider than the number of classes no longer indexes past the label embedding.

    Args:
        n_row: Number of rows (and columns) in the grid.
        batches_done: Global batch counter, used as the output file name.
    """
    # Sampling needs no gradients; this avoids building an unused autograd graph.
    # NOTE(review): `gen` stays in training mode, so BatchNorm normalizes with this
    # batch's statistics and still updates its running averages.
    with torch.no_grad():
        z = FloatTensor(np.random.normal(0, 1, (n_row ** 2, z_dims)))
        labels = np.array([col % num_classes for _ in range(n_row) for col in range(n_row)])
        labels = LongTensor(labels)
        gen_imgs = gen(z, labels)
    # normalize=True min-max rescales the Tanh output into [0, 1] for the PNG.
    save_image(gen_imgs, "cgan_images/%d.png" % batches_done, nrow=n_row, normalize=True)



## Training ##
# Least-squares conditional GAN. Each step first updates the generator, then
# updates the discriminator on real images and on the same (detached) fake batch.
log_interval = 100  # batches between loss printouts

for epoch in range(num_epochs):

    for i, (imgs, labels) in enumerate(dataloader):

        # Size of this batch; the last batch of an epoch may be smaller. Use a
        # separate name so the `batch_size` hyperparameter is not overwritten.
        cur_batch_size = imgs.shape[0]

        # Regression targets for the least-squares loss (no grad required).
        valid = FloatTensor(cur_batch_size, 1).fill_(1.0)
        fake = FloatTensor(cur_batch_size, 1).fill_(0.0)

        real_imgs = imgs.type(FloatTensor)
        labels = labels.type(LongTensor)

        ## Train generator: push disc's score for generated (image, label) pairs toward "real"

        optim_gen.zero_grad()

        z = FloatTensor(np.random.normal(0, 1, (cur_batch_size, z_dims)))
        gen_labels = LongTensor(np.random.randint(0, num_classes, cur_batch_size))

        gen_imgs = gen(z, gen_labels)

        validity = disc(gen_imgs, gen_labels)
        g_loss = critertion(validity, valid)

        g_loss.backward()
        optim_gen.step()

        ## Train discriminator

        # Also clears the gradients that g_loss.backward() left in disc's parameters.
        optim_disc.zero_grad()

        validity_real = disc(real_imgs, labels)
        d_real_loss = critertion(validity_real, valid)

        # detach(): the discriminator update must not backpropagate into gen.
        validity_fake = disc(gen_imgs.detach(), gen_labels)
        d_fake_loss = critertion(validity_fake, fake)

        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optim_disc.step()

        if i % log_interval == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, num_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            # One column per class.
            sample_image(n_row=num_classes, batches_done=batches_done)
    


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


[Epoch 0/200] [Batch 0/938] [D loss: 0.560840] [G loss: 1.041603]
[Epoch 0/200] [Batch 100/938] [D loss: 0.183681] [G loss: 0.238295]
[Epoch 0/200] [Batch 200/938] [D loss: 0.108711] [G loss: 0.447054]
[Epoch 0/200] [Batch 300/938] [D loss: 0.095283] [G loss: 0.526289]
[Epoch 0/200] [Batch 400/938] [D loss: 0.152627] [G loss: 0.291588]
[Epoch 0/200] [Batch 500/938] [D loss: 0.073992] [G loss: 0.586982]
[Epoch 0/200] [Batch 600/938] [D loss: 0.096513] [G loss: 0.519078]
[Epoch 0/200] [Batch 700/938] [D loss: 0.093237] [G loss: 0.633268]
[Epoch 0/200] [Batch 800/938] [D loss: 0.118070] [G loss: 0.610693]
[Epoch 0/200] [Batch 900/938] [D loss: 0.106240] [G loss: 0.426255]
[Epoch 1/200] [Batch 0/938] [D loss: 0.131168] [G loss: 0.365451]
[Epoch 1/200] [Batch 100/938] [D loss: 0.097327] [G loss: 0.615146]
[Epoch 1/200] [Batch 200/938] [D loss: 0.144526] [G loss: 1.197335]
[Epoch 1/200] [Batch 300/938] [D loss: 0.079483] [G loss: 0.624254]
[Epoch 1/200] [Batch 400/938] [D loss: 0.079704] [G 

KeyboardInterrupt: ignored

In [14]:
# List the files in the current working directory (the sample grids written by
# `sample_image` when the kernel is inside cgan_images).
# NOTE(review): this cell ran as In [14], after the `cd cgan_images` cell (In [3])
# that appears below it. Run top-to-bottom from the notebook root, it lists the
# root directory instead; `!ls cgan_images` would work from there.
!ls

0.png	   19600.png  29200.png  39200.png  48400.png  58400.png  6800.png
10000.png  20000.png  29600.png  39600.png  48800.png  58800.png  68400.png
10400.png  2000.png   30000.png  40000.png  49200.png  59200.png  68800.png
10800.png  20400.png  30400.png  4000.png   49600.png  59600.png  69200.png
11200.png  20800.png  30800.png  400.png    50000.png  60000.png  69600.png
11600.png  21200.png  31200.png  40400.png  50400.png  6000.png   70000.png
12000.png  21600.png  31600.png  40800.png  50800.png  60400.png  70400.png
1200.png   22000.png  32000.png  41200.png  51200.png  60800.png  70800.png
12400.png  22400.png  3200.png	 41600.png  51600.png  61200.png  71200.png
12800.png  22800.png  32400.png  42000.png  52000.png  61600.png  71600.png
13200.png  23200.png  32800.png  42400.png  5200.png   62000.png  72000.png
13600.png  23600.png  33200.png  42800.png  52400.png  62400.png  7200.png
14000.png  24000.png  33600.png  43200.png  52800.png  62800.png  72400.png
14400.png  2400.p

In [3]:
# Move the kernel's working directory into the sample folder (IPython automagic
# for `%cd`); the `!cp *.png` cell below relies on this relative glob.
# NOTE(review): not idempotent — re-running it while already inside cgan_images
# looks for cgan_images/cgan_images.
cd cgan_images

/content/cgan_images


In [6]:
# Colab only: mount Google Drive at /content/drive so the generated samples can
# be copied into it.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [16]:
# Copy every sample grid into the mounted Google Drive folder.
# Assumes the working directory is cgan_images (the `*.png` glob is relative).
# `mkdir -p` makes sure the destination exists; otherwise `cp` with several
# source files fails because the target is not a directory.
!mkdir -p /content/drive/MyDrive/cgan_images && cp *.png /content/drive/MyDrive/cgan_images/