In [None]:
import torch
import torchvision
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
# import the dataset and loader from data_utils.py
import data_utils
image_folder_path = 'path/to/image/folder'
desired_image_size = (64, 64)

# Mini-batch size. 128 is the value used in the DCGAN paper; a batch of 1 gives very
# noisy gradient estimates (and, if the models use BatchNorm, degenerate batch statistics).
batch_size = 128
# create a dataset so that dataset[i] returns the ith image
dataset = data_utils.Dataset(image_folder_path, desired_image_size)
# make a dataloader that returns the images as batches for parallel processing;
# shuffle=True so every epoch visits the images in a fresh random order
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Build the two networks defined in the project-local models.py module.
import models

# Tuple construction evaluates left to right: Generator is built first, then Discriminator.
generator, discriminator = models.Generator(), models.Discriminator()

In [None]:
# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Move both networks onto that device (Module.to returns the module itself).
generator, discriminator = generator.to(device), discriminator.to(device)

In [None]:
# Initialize the loss function.
# BCELoss expects probabilities in [0, 1], so the discriminator presumably ends in a
# sigmoid — confirm in models.py.
criterion = torch.nn.BCELoss()

# Length of the generator's latent input vector
latent_dim = 100

# Create batch of latent vectors that we will use to visualize the progression of the generator
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Set the learning rate. 0.1 is far too large for Adam-family optimizers on a GAN and
# typically makes training diverge; 2e-4 with beta1 = 0.5 is the DCGAN setting.
lr = 2e-4
beta1 = 0.5

# Setup optimizers for both generator and discriminator
optimizerD = torch.optim.AdamW(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.AdamW(generator.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# functions that save and load the model and optimizer
save_to = './checkpoints/model.pt'
def save(path, generator, discriminator, optimizerG, optimizerD):
    """Write both networks' weights and both optimizers' states to one checkpoint.

    Args:
        path: destination passed to ``torch.save`` (a filesystem path or a
            writable file-like object). For filesystem paths, missing parent
            directories are created first.
        generator: generator module whose ``state_dict`` is stored.
        discriminator: discriminator module whose ``state_dict`` is stored.
        optimizerG: optimizer of the generator.
        optimizerD: optimizer of the discriminator.
    """
    import os

    # torch.save does not create missing directories (e.g. './checkpoints'), so make the
    # parent directory first. File-like targets are passed straight through.
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(path))
        if parent:
            os.makedirs(parent, exist_ok=True)

    torch.save(
        {
            'generator_weights' : generator.state_dict(),
            'discriminator_weights' : discriminator.state_dict(),
            'generator_optimizer_weights' : optimizerG.state_dict(),
            'discriminator_optimizer_weights' : optimizerD.state_dict(),
        },
        path
    )

def load(path, device=None):
    """Rebuild the GAN and its optimizers from a checkpoint written by ``save``.

    Args:
        path: checkpoint location accepted by ``torch.load``.
        device: device for the models and optimizer state. Defaults to CUDA
            when available, otherwise CPU.

    Returns:
        ``(generator, discriminator, optimizerG, optimizerD)`` in that order.

    Note:
        ``torch.load`` unpickles the file; only load checkpoints you trust.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    checkpoint = torch.load(path, map_location=device)

    generator = models.Generator().to(device)
    discriminator = models.Discriminator().to(device)
    generator.load_state_dict(checkpoint['generator_weights'])
    discriminator.load_state_dict(checkpoint['discriminator_weights'])

    # Each optimizer must wrap the parameters of its OWN network (they were swapped
    # before). Use the same class as the training setup (AdamW) so the saved
    # hyperparameters, e.g. weight_decay, keep the same meaning.
    # lr, betas etc. are restored from the checkpoint by load_state_dict, so no
    # dependency on a global lr is needed here.
    optimizerG = torch.optim.AdamW(generator.parameters())
    optimizerD = torch.optim.AdamW(discriminator.parameters())
    optimizerG.load_state_dict(checkpoint['generator_optimizer_weights'])
    optimizerD.load_state_dict(checkpoint['discriminator_optimizer_weights'])

    return generator, discriminator, optimizerG, optimizerD

In [None]:
# create a loop to train the model

num_epochs = 500
# Latent size is taken from fixed_noise, so sampling below always matches it.
latent_dim = fixed_noise.size(1)

# Progress trackers, used for plotting after training
img_list = []
G_losses = []
D_losses = []
iters = 0

generator.train()
discriminator.train()

for epoch in tqdm(range(1, 1+num_epochs)):
    for i, data in enumerate(dataloader, 0):

        ########################################################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        #######################################################
        ## Train with all-real batch
        discriminator.zero_grad()
        # Format batch: accept either a bare image tensor or an (image, label) pair
        real_images = data[0] if isinstance(data, (list, tuple)) else data
        real_images = real_images.to(device)
        b_size = real_images.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D.
        # view(-1) assumes D emits one score per image, e.g. (N, 1, 1, 1) — TODO confirm.
        output = discriminator(real_images).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors (same layout as fixed_noise)
        noise = torch.randn(b_size, latent_dim, 1, 1, device=device)
        # Generate fake image batch with G
        fake = generator(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D; detach so D's loss does not backprop into G
        output = discriminator(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ########################################################
        # (2) Update G network: maximize log(D(G(z)))
        #######################################################
        generator.zero_grad()
        # Fake images are labelled "real" for G's loss (non-saturating objective)
        label.fill_(real_label)
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = discriminator(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G (D's grads from this backward are cleared by zero_grad next iteration)
        optimizerG.step()

        # Output training stats (tqdm.write keeps the progress bar intact)
        if i % 50 == 0:
            tqdm.write(
                f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}] "
                f"Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} "
                f"D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}"
            )

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                samples = generator(fixed_noise).detach().cpu()
            img_list.append(torchvision.utils.make_grid(samples, padding=2, normalize=True))

        iters += 1
In [None]:
# generate images from the model