In [10]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import make_grid, save_image


from matplotlib import pyplot as plt

class Generator(nn.Module):
    def __init__(self, latent_size):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_size, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    
    def forward(self, x):
        x = self.model(x)
        return x.view(-1, 1, 28, 28)   # Bx1x28x28

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1024),   #784 is MNIST input feature dimension
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 1),      # single value stating fake or real
            nn.Sigmoid()            # value between 0 and 1
        )
    def forward(self, x):
        x = x.view(-1, 784)         # reshape input same as torch.flatten(x, start_dim=1)
        return self.model(x)

In [16]:
# torch parameters
SEED = 60            # reproducability
# NN Parameters
EPOCHS = 50          # number of epochs
LR = 0.001           # learning rate
MOMENTUM = 0.9       # momentum for the optimizer
WEIGHT_DECAY = 1e-5  # weight decay for the optimizer
GAMMA = 0.1          # learning rate schedular
BATCH_SIZE = 256     # number of images to load per iteration
# GAN parameters
SAMPLE_SIZE = 64     # number of fake images to sample
LATENT_SIZE = 128    # size of latent or noise vector
DISC_STEPS  = 1      # number of steps to apply to the discriminator

# manual seed to reproduce same results
torch.manual_seed(SEED)

# DOWNLOADING AND LOADING MNIST DATASET 
mnist_folder= '/home/atmis/Documents/DATASETS/mnist'

# normalize each image and set the pixel values between -1 and 1
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

to_pil_image = transforms.ToPILImage()

# download the dataset if not already downloaded and set necessery transforms
tr_dataset   = MNIST(mnist_folder, train=True, download=True, transform=img_transform)
# prepare loader for the training dataset
train_loader = torch.utils.data.DataLoader(tr_dataset, batch_size=BATCH_SIZE, shuffle=True)

# determine where to run the code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

# create the generator network and move to gpu if available
gen_net = Generator(LATENT_SIZE).to(device)
# create the discriminator and move to gpu if available
disc_net = Discriminator().to(device)

# specify the loss to be used
# Binary Cross Entropy Loss
loss_fn = nn.BCELoss()
# specify the optimizer for generator network
optimizer_gen = optim.SGD(gen_net.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# specify the optimizer for discriminator network
optimizer_disc = optim.Adam(disc_net.parameters(), lr=LR)

# function to create the noise vector
def create_noise():
    return torch.randn(SAMPLE_SIZE, LATENT_SIZE).to(device)


def train_discriminator(real_x, fake_x):
    B = real_x.size(0)                         # batch size
    real_y = torch.ones(B, 1).to(device)       # create labels of 1
    fake_y = torch.zeros(SAMPLE_SIZE, 1).to(device)      # create labels of 0
    
    optimizer_disc.zero_grad()
    # forward pass the real data
    output_real = disc_net(real_x)
    # estimate loss         
    loss_real = loss_fn(output_real, real_y)
    # forward pass the fake data
    output_fake = disc_net(fake_x)
    # estimate loss
    loss_fake = loss_fn(output_fake, fake_y)
    # accumulate gradients for both passes 
    loss_real.backward()
    loss_fake.backward()
    # update weights of discriminator
    optimizer_disc.step()
    return loss_real + loss_fake

# update generator weights using the gradients of discriminator
def train_generator(fake_x):
    B = fake_x.size(0)                       # batch size
    real_y = torch.ones(B, 1).to(device)     # create labels of 1

    optimizer_gen.zero_grad()
    # forward pass the fake data on discriminator
    output = disc_net(fake_x)
    # determine how far we are from real label
    loss = loss_fn(output, real_y)
    # calculate gradients
    loss.backward()
    # update generator weights
    optimizer_gen.step()
    return loss


In [17]:
# to save the images generated by the generator
def save_generator_image(image, path):
    save_image(image, path)

In [18]:
losses_g = [] # store generator loss after each epoch
losses_d = [] # store discriminator loss after each epoch
images = []   # store images generatd by the generator


# create the noise vector
noise = create_noise()
# put the networks in training mode
gen_net.train()
disc_net.train()

for epoch in range(EPOCHS):
    loss_g = 0.0
    loss_d = 0.0
    for idx, (img, _) in enumerate(train_loader):
        # move image to gpu if exists
        img = img.to(device)
        # get batch size
        b_size = len(img)
        # run the discriminator for DISC_STEPS number of steps
        for _ in range(DISC_STEPS):
            fake_x = gen_net(create_noise()).detach()
            real_x = img
            # train the discriminator network
            loss_d += train_discriminator(real_x, fake_x)
        fake_x = gen_net(create_noise())
        # train the generator network
        loss_g += train_generator(fake_x)
    # create the final fake image for the epoch
    generated_img = gen_net(noise).cpu().detach()
    # make the images as grid
    generated_img = make_grid(generated_img)
    # save the generated torch tensor models to disk
    save_generator_image(generated_img, f"gan_outputs/gen_img{epoch}.png")
    images.append(generated_img)
    epoch_loss_g = loss_g / idx # total generator loss for the epoch
    epoch_loss_d = loss_d / idx # total discriminator loss for the epoch
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)
    
    print(f"Epoch {epoch} of {EPOCHS}")
    print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")

print('DONE TRAINING')
torch.save(gen_net.state_dict(), 'gan_outputs/generator.pth')
torch.save(disc_net.state_dict(), 'gan_outputs/discriminator.pth')

ValueError: Target and input must have the same number of elements. target nelement (256) != input nelement (64)