In [8]:
from src.model import DCGAN

import torch
import torchvision
import torchvision.transforms as transforms

In [5]:
# --- Reproducibility ---------------------------------------------------------
# Seed torch's global RNG before the model is constructed and before
# fixed_noise is drawn, so that a fresh "Restart & Run All" starts from the
# same random state every time (any random initialisation performed while
# building DCGAN, and the fixed_noise batch below, become repeatable).
SEED = 42
torch.manual_seed(SEED)

# --- Architecture hyperparameters --------------------------------------------
HIDDEN_DIM = 100  # channel count of the latent noise tensors (see fixed_noise below)
FMAP_DIM = 64     # passed to DCGAN as feature_map_dim
N_CHANNELS = 3    # passed to DCGAN as n_channels (presumably image channels — confirm in src.model)

# --- Optimisation hyperparameters --------------------------------------------
LEARNING_RATE = 3e-4
ADAM_BETAS = (0.5, 0.999)
LR_GAMMA = 0.99   # multiplicative decay applied each time a scheduler is stepped

model = DCGAN(
    hidden_dim=HIDDEN_DIM,
    feature_map_dim=FMAP_DIM,
    n_channels=N_CHANNELS
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss functions: binary cross-entropy for both networks.
generator_loss = torch.nn.BCELoss()
discriminator_loss = torch.nn.BCELoss()

# Optimizers and learning-rate schedulers, one pair per network.
# The trainable parameters are materialised as lists (not one-shot `filter`
# iterators) so the names stay usable after the optimizers have consumed them.
# To disable the schedulers, delete every line containing lr_scheduler.
# Note: a scheduler only changes the learning rate when its .step() is called.
generator_trainable_params = [p for p in model.generator.parameters() if p.requires_grad]
discriminator_trainable_params = [p for p in model.discriminator.parameters() if p.requires_grad]
generator_optimizer = torch.optim.Adam(generator_trainable_params, lr=LEARNING_RATE, betas=ADAM_BETAS)
discriminator_optimizer = torch.optim.Adam(discriminator_trainable_params, lr=LEARNING_RATE, betas=ADAM_BETAS)
generator_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(generator_optimizer, gamma=LR_GAMMA)
discriminator_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(discriminator_optimizer, gamma=LR_GAMMA)

print(f"Number of parameters in the model: {sum(p.numel() for p in model.parameters())}")
print(f"Number of parameters in the generator: {sum(p.numel() for p in model.generator.parameters())}")
print(f"Number of parameters in the discriminator: {sum(p.numel() for p in model.discriminator.parameters())}")

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, HIDDEN_DIM, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

Number of parameters in the model: 6342272
Number of parameters in the generator: 3576704
Number of parameters in the discriminator: 2765568


In [12]:
from src.dataset import get_dataloaders

# Preprocessing applied to every image, in order:
#   1. Resize(64)      - resize with target size 64
#   2. CenterCrop(64)  - crop the central 64x64 region
#   3. ToTensor()      - convert the image to a tensor
#   4. Normalize       - per-channel (x - 0.5) / 0.5 for the three channels
transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# get_dataloaders' result is indexed by split name; only "train" is used here.
dataloader = get_dataloaders("data/cats", transform=transform)["train"]

# Pull a single batch as a sanity check and display the shape of its first element.
batch = next(iter(dataloader))
batch[0].shape

torch.Size([3, 64, 64])

In [13]:
# Lists to keep track of progress
img_list = []   # image grids of the generator's output on fixed_noise
G_losses = []   # generator loss, one entry per iteration
D_losses = []   # discriminator loss (real + fake), one entry per iteration
iters = 0       # global iteration counter across all epochs

num_epochs = 50
LOG_EVERY = 50        # print training stats every LOG_EVERY batches within an epoch
SNAPSHOT_EVERY = 500  # store a fixed_noise sample grid every SNAPSHOT_EVERY iterations

# NOTE(review): the output saved with this cell is a RuntimeError raised by a
# transposed convolution that received a [64, 3, 64, 64] input where 100
# channels were expected. No call below visibly feeds an image batch to the
# generator, and the cell execution counts are out of order, so the traceback
# may come from stale kernel state - TODO: Restart & Run All and confirm; if it
# persists, check what DCGAN.forward does with its input.

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        discriminator_optimizer.zero_grad()
        real_cpu = data.to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        output = model.discriminator(real_cpu).view(-1)
        error_disc_real = discriminator_loss(output, label)
        error_disc_real.backward()
        D_x = output.mean().item()  # mean discriminator output on the real batch

        ## Train with all-fake batch
        noise = torch.randn(batch_size, HIDDEN_DIM, 1, 1, device=device)
        fake = model(noise)
        label.fill_(fake_label)
        # detach() keeps this backward pass from propagating into the generator
        output = model.discriminator(fake.detach()).view(-1)
        error_disc_fake = discriminator_loss(output, label)
        error_disc_fake.backward()
        D_G_z1 = output.mean().item()  # mean D output on fakes, before the D update

        # Gradients from the real and fake passes have accumulated; apply them together.
        error_disc = error_disc_real + error_disc_fake
        discriminator_optimizer.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        generator_optimizer.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Re-score the same fake batch with the just-updated discriminator,
        # this time without detach() so gradients reach the generator.
        output = model.discriminator(fake).view(-1)
        # Calculate G's loss based on this output
        error_gen = generator_loss(output, label)
        # Calculate gradients for G
        error_gen.backward()
        D_G_z2 = output.mean().item()  # mean D output on fakes, after the D update
        # Update G
        generator_optimizer.step()

        # Output training stats
        if i % LOG_EVERY == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     error_disc.item(), error_gen.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(error_gen.item())
        D_losses.append(error_disc.item())

        # Check how the generator is doing by saving G's output on fixed_noise,
        # both periodically and on the very last batch of the final epoch.
        is_final_batch = (epoch == num_epochs - 1) and (i == len(dataloader) - 1)
        if (iters % SNAPSHOT_EVERY == 0) or is_final_batch:
            with torch.no_grad():
                # Separate name so the training batch `fake` is not shadowed.
                fixed_fake = model(fixed_noise).cpu()
            img_list.append(torchvision.utils.make_grid(fixed_fake, padding=2, normalize=True))

        iters += 1

    # Decay both learning rates once per epoch, after that epoch's optimizer
    # steps. Without these calls the schedulers never change the learning rate.
    generator_lr_scheduler.step()
    discriminator_lr_scheduler.step()

Starting Training Loop...


RuntimeError: Given transposed=1, weight of size [100, 512, 4, 4], expected input[64, 3, 64, 64] to have 100 channels, but got 3 channels instead