In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [10]:
class Discriminator(nn.Module):
    
    def __init__(self, img_channels, features): # disc features
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(img_channels, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(features, features*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
                        
            nn.Conv2d(features*2, features*4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(features*4, features*8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(features*8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        )
        
    def forward(self, x):
        return self.disc(x)
        

In [3]:
class Generator(nn.Module):
    
    def __init__(self, noise_channels, img_channels, features): # gen features
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.ConvTranspose2d(noise_channels, features * 16, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(features*16),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(features * 16, features*8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(features*8),
            nn.ReLU(True),
                        
            nn.ConvTranspose2d(features*8, features*4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(features*4),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(features*4, features*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(features*2),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(features*2, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )
        
    def forward(self, x):
        return self.gen(x)
        

In [16]:
def initialize_weights(model):
    for m in model.modules():
#         if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
#             nn.init.normal_(m.weight_data, 0.0, 0.02)
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

In [12]:
def test():
    batch, channels, height, width = 8, 3, 64, 64
    features = 8
    noise_dim = 100
    x = torch.randn(batch, channels, height, width)
    disc = Discriminator(channels, features)
    assert disc(x).shape == (batch, 1, 1, 1), 'Failed'
    z = torch.randn(batch, noise_dim, 1, 1)
    gen = Generator(noise_dim, channels, features)
    assert gen(z).shape == (batch, channels, height, width), 'Failed'
test()

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

LEARNING_RATE = 2e-4
IMG_SIZE = 64
NOISE_DIM = 100
FEATURES_DISC = 64
FEATURES_GEN = 64
IMG_CHANNELS = 1
BATCH_SIZE = 64
EPOCHS = 8

In [29]:
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * IMG_CHANNELS, [0.5] * IMG_CHANNELS)
])

dataset = torchvision.datasets.MNIST('datasets/', transform=transform, download=False)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

RuntimeError: Dataset not found. You can use download=True to download it

In [17]:
gen_model = Generator(NOISE_DIM, IMG_CHANNELS, FEATURES_GEN).to(device)
disc_model = Discriminator(IMG_CHANNELS, FEATURES_DISC).to(device)
gen, disc = gen_model, disc_model
initialize_weights(gen_model)
initialize_weights(disc_model)

In [18]:
BETAS = (0.5, 0.999)
opt_gen = optim.Adam(gen_model.parameters(), lr=LEARNING_RATE, betas=BETAS)
opt_disc = optim.Adam(disc_model.parameters(), lr=LEARNING_RATE, betas=BETAS)
criterion = nn.BCELoss()

In [None]:
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f'logs/real')
writer_fake = SummaryWriter(f'logs/fake')
step = 0

In [None]:
gen.train()
disc.train()

In [None]:
for idx, (real, _) in enumerate(dataloader):
    real = real.to(device)
    noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
    fake = gen(noise)
    
    # Train Disc
    disc_real = disc(real).reshape(-1)
    loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    disc_fake = disc(fake.detach()).reshape(-1)
    loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    loss_disc = (loss_disc_real + loss_disc_fake) / 2
    disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()
    
    # Train Gen
    output = disc(fake)
    loss_gen = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()
    
    if batch_idx % 100 == 0:
        print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
              Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")
        
        with torch.no_grad():
            fake = gen(fixed_noise)
            # take out (up to) 32 examples
            img_grid_real = torchvision.utils.make_grid(
                real[:32], normalize=True
            )
            img_grid_fake = torchvision.utils.make_grid(
                fake[:32], normalize=True
            )

            writer_real.add_image("Real", img_grid_real, global_step=step)
            writer_fake.add_image("Fake", img_grid_fake, global_step=step)

        step += 1