In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from torchvision.transforms import functional
import matplotlib.pyplot as plt
from model import Generator64, Discriminator64

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def generate(epoch, G, z_dim, nrow=10, ncol=10, samp_dir='data/generated/TINY'):
    """Sample a grid of images from the generator and write it to disk.

    Draws ``nrow * ncol`` latent vectors of shape ``(z_dim, 1, 1)`` from a
    standard normal, passes them through ``G`` and saves the result as
    ``<samp_dir>/epoch_XXX.png`` with ``ncol`` images per row.

    The module's train/eval mode is not touched here; the caller decides
    which mode ``G`` should be sampled in.

    Args:
        epoch: Epoch number used to build the output file name.
        G: Generator module, called on the latent batch.
        z_dim: Channel dimension of the latent noise.
        nrow: Number of rows in the grid.
        ncol: Number of images per row in the grid.
        samp_dir: Output directory; created if it does not exist.
    """
    os.makedirs(samp_dir, exist_ok=True)

    # Sampling is inference only: disabling autograd avoids building (and
    # holding in memory) a graph for nrow*ncol images that is never used.
    with torch.no_grad():
        sample_z = torch.randn(nrow * ncol, z_dim, 1, 1, device=device)
        samples = G(sample_z)

    out_path = os.path.join(samp_dir, 'epoch_%03d.png' % (epoch))
    # save_image's ``nrow`` argument is the number of images per row, i.e. our column count.
    save_image(samples, out_path, nrow=ncol, normalize=True, value_range=(-1, 1))

In [None]:
def train(D, G, train_loader, epochs, batch_size=64, lr=0.0002, z_dim=100, model_dir='model/tiny'):
    """Train a discriminator/generator pair with a least-squares GAN loss.

    Every batch performs one discriminator update followed by one generator
    update. At the end of each epoch the summed per-batch losses are printed
    and a sample grid is written via ``generate``. After the last epoch both
    state dicts are saved under ``model_dir``.

    Args:
        D: Discriminator module, called on image batches.
        G: Generator module, called on noise of shape ``(n, z_dim, 1, 1)``.
        train_loader: Iterable yielding ``(images, labels)``; labels are ignored.
        epochs: Number of passes over ``train_loader``.
        batch_size: Nominal batch size. Kept for backward compatibility; the
            noise batch and the loss normalisation follow the actual size of
            each real batch so a short final batch is handled consistently.
        lr: Learning rate used for both Adam optimizers.
        z_dim: Channel dimension of the latent noise.
        model_dir: Directory receiving ``generator.pth`` and
            ``discriminator.pth``; created if missing.
    """
    D_optimizer = optim.Adam(D.parameters(), lr=lr)
    G_optimizer = optim.Adam(G.parameters(), lr=lr)

    # LSGAN targets: D is pushed towards ``a`` on fakes and ``b`` on reals,
    # while G is trained so that D outputs ``c`` on fakes.
    a = 0
    b = 1
    c = 1

    for epoch in range(1, epochs + 1):
        # Sampling at the end of the previous epoch switches G to eval mode;
        # put it back in training mode so later epochs do not train in eval mode.
        G.train()

        D_running_loss = 0.0
        G_running_loss = 0.0

        for real_img, _ in train_loader:
            real_img = real_img.to(device)

            # Match the fake batch to the real batch (the last one may be short).
            n = real_img.size(0)
            z = torch.randn(n, z_dim, 1, 1, device=device)

            # --------------------
            # update discriminator
            # --------------------
            D_optimizer.zero_grad()

            D_real = D(real_img)
            D_real_loss = torch.sum((D_real - b) ** 2)

            fake_img = G(z)
            # detach(): the discriminator step must not backpropagate into G.
            D_fake = D(fake_img.detach())
            D_fake_loss = torch.sum((D_fake - a) ** 2)

            D_loss = 0.5 * (D_real_loss + D_fake_loss) / n
            D_loss.backward()
            D_optimizer.step()
            D_running_loss += D_loss.item()

            # ----------------
            # update generator
            # ----------------
            G_optimizer.zero_grad()

            # Reuse the fakes from above: only D's parameters changed since
            # they were produced, so G's graph is still valid and a second
            # generator forward pass is unnecessary. D is re-evaluated because
            # it has just been updated.
            D_fake = D(fake_img)

            G_loss = 0.5 * torch.sum((D_fake - c) ** 2) / n
            G_loss.backward()
            G_optimizer.step()
            G_running_loss += G_loss.item()

        # Reported values are sums of the per-batch losses over the epoch.
        print(f"Epoch {epoch:03d}/{epochs} | D Loss: {D_running_loss:.4f} | G Loss: {G_running_loss:.4f}")

        # generate image (inference only, so no autograd bookkeeping)
        G.eval()
        with torch.no_grad():
            generate(epoch, G, z_dim)

    # final models
    os.makedirs(model_dir, exist_ok=True)
    torch.save(G.state_dict(), os.path.join(model_dir, 'generator.pth'))
    torch.save(D.state_dict(), os.path.join(model_dir, 'discriminator.pth'))
    print("models saved")

In [None]:
# Hyperparameters and data location for this run.
z_dim = 100
batch_size = 64
epochs = 500
lr = 0.0002
path_to_data = 'data/tiny-imagenet-200/train'

# Preprocessing: convert to tensors, then normalise each of the three
# channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Image-folder dataset wrapped in a shuffling loader.
dataset = datasets.ImageFolder(root=path_to_data, transform=transform)
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Build the networks on the selected device (discriminator first, then generator).
D = Discriminator64().to(device)
G = Generator64(nz=z_dim).to(device)

# Launch training; model_dir keeps the default defined by train().
train(D, G, train_loader, epochs, batch_size=batch_size, lr=lr, z_dim=z_dim)