In [1]:
import torch 
from torch import nn,optim

### Discriminator has no fc layers
class Discriminator(nn.Module):

    def __init__(self, channels_img, features_d):
        super(Discriminator,self).__init__()

        # Input shape : N x channels_img x 64 x 64
        self.disc = nn.Sequential(
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # 32x32
            nn.LeakyReLU(0.2),

            self._block(features_d, features_d*2, 4, 2, 1), # 16x16
            self._block(features_d*2, features_d*4, 4, 2, 1), # 8x8
            self._block(features_d*4, features_d*8, 4, 2, 1), # 4x4

            nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1x1
            nn.Sigmoid()
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):

        return nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self,x):
        return self.disc(x)


class Generator(nn.Module):

    def __init__(self, z_dim, channels_img, features_g):
        super(Generator,self).__init__()

        self.gen = nn.Sequential(

            self._block(z_dim, features_g*16, 4, 1, 0), # 4x4
            self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8
            self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16
            self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32

            nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1), # 64x64
            nn.Tanh()

        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):

        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self,x):
        return self.gen(x)

def initialize_weights(model):
    for m in model.modules() :
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

def test():
    N, in_channels, H, W = 8, 3, 64, 64
    z_dim = 100 
    x = torch.randn((N, in_channels, H, W))

    disc = Discriminator(in_channels, 8)
    initialize_weights(disc)
    assert disc(x).shape == (N, 1, 1, 1)

    gen = Generator(z_dim, in_channels, 8)
    initialize_weights(gen)

    z = torch.randn((N, z_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W)


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import torchvision 
from torchvision import datasets,transforms
from torch.utils.data import DataLoader 
from torch.utils.tensorboard import SummaryWriter 

BATCH_SIZE = 128
LR = 0.0002
CHANNELS_IMG = 1
FEATURES_D = 64
Z_DIM = 100
FEATURES_G = 64
IMAGE_SIZE = 64
step = 0


dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

t = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)])
])

# data = datasets.ImageFolder('/users/gursi/desktop/celeb_dataset', transform=t)
# loader = DataLoader(dataset=data, batch_size=BATCH_SIZE, shuffle=True)

data = datasets.MNIST('./data', transform=t, download=True, train=True)
loader = DataLoader(dataset=data, batch_size=BATCH_SIZE, shuffle=True)

disc = Discriminator(CHANNELS_IMG, FEATURES_D)
disc.load_state_dict(torch.load('/content/drive/MyDrive/DCGAN/disc267.pt', map_location=dev))
disc.to(dev)
# initialize_weights(disc)
opt_disc = optim.Adam(disc.parameters(), lr = LR, betas=(0.5, 0.999))

gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_G)
gen.load_state_dict(torch.load('/content/drive/MyDrive/DCGAN/gen267.pt', map_location=dev))
gen.to(dev)
# initialize_weights(gen)
opt_gen = optim.Adam(gen.parameters(), lr = LR, betas=(0.5, 0.999))

criterion = nn.BCELoss()

fixed_noise = torch.randn((32, Z_DIM, 1, 1)).to(dev)
writer_real = SummaryWriter(log_dir='runs/real')
writer_fake = SummaryWriter(log_dir='runs/fake')

gen.train()
disc.train()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
)

In [4]:
EPOCHS = 100
for epoch in range(EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(dev)
        noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(dev)

        # Train disc 
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))

        fake = gen(noise)
        disc_fake = disc(fake).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        
        loss_D = (loss_disc_fake + loss_disc_real) / 2
        disc.zero_grad()
        loss_D.backward(retain_graph = True)
        opt_disc.step()

        # train gen
        output = disc(fake).reshape(-1)
        loss_G = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_G.backward()
        opt_gen.step()

        if batch_idx % 20 == 0:
            print(
                f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_D:.4f}, loss G: {loss_G:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

    torch.save(disc.state_dict(), '/content/drive/MyDrive/DCGAN/disc{step}.pt')
    torch.save(gen.state_dict(), '/content/drive/MyDrive/DCGAN/gen{step}.pt')

Epoch [0/100] Batch 0/469                   Loss D: 0.1179, loss G: 0.2055
Epoch [0/100] Batch 20/469                   Loss D: 0.1283, loss G: 3.2908
Epoch [0/100] Batch 40/469                   Loss D: 0.1022, loss G: 2.6566
Epoch [0/100] Batch 60/469                   Loss D: 0.7539, loss G: 1.2973
Epoch [0/100] Batch 80/469                   Loss D: 0.1047, loss G: 2.8591
Epoch [0/100] Batch 100/469                   Loss D: 0.1220, loss G: 2.9924
Epoch [0/100] Batch 120/469                   Loss D: 0.0916, loss G: 2.9417
Epoch [0/100] Batch 140/469                   Loss D: 0.1642, loss G: 4.1294
Epoch [0/100] Batch 160/469                   Loss D: 0.1003, loss G: 3.2563
Epoch [0/100] Batch 180/469                   Loss D: 0.1002, loss G: 2.7847
Epoch [0/100] Batch 200/469                   Loss D: 0.2226, loss G: 3.3296
Epoch [0/100] Batch 220/469                   Loss D: 0.0759, loss G: 2.9488
Epoch [0/100] Batch 240/469                   Loss D: 0.0650, loss G: 3.3634
Epoch

KeyboardInterrupt: ignored

In [6]:
torch.save(disc.state_dict(), f'/content/drive/MyDrive/DCGAN/final_disc.pt')
torch.save(gen.state_dict(), f'/content/drive/MyDrive/DCGAN/final_gen.pt')