# Imports

In [12]:
from torchvision import transforms as T
from torch.utils.data.dataset import Dataset
from torch.utils import data
from PIL import Image
import os
import torch
from pickle import HIGHEST_PROTOCOL

# Useful Constants

In [4]:
root_dir = '/Users/gbotev/Downloads/memes'
new_size = 64

# Define `CustomDataset`

In [5]:
class CustomDataset(Dataset):
    """Custom dataset."""

    def __init__(self, root_dir, transforms=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transforms = transforms
        self.img_names = [name for name in os.listdir(root_dir) if os.path.isfile(os.path.join(self.root_dir, name))]
        self.num_imgs = len(self.img_names)

    def __len__(self):
        return self.num_imgs

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.root_dir,
                                      self.img_names[index]))
        if self.transforms:
            img = self.transforms(img)
        return img

# Calculate Normalization and Initialize `CustomDataset`

In [6]:
dataset = CustomDataset(root_dir,
                        T.Compose([T.Resize(new_size),
                                   T.ToTensor()]))
means = []
stds = []
for img in dataset:
    means.append(torch.mean(img))
    stds.append(torch.std(img))
mean = torch.mean(torch.tensor(means))
std = torch.mean(torch.tensor(stds))
print(f'Mean: {mean}\n Std: {std}')
norm = T.Normalize(mean=mean, std=std)

Mean: 0.6056310534477234
 Std: 0.2573034465312958


In [7]:
dataset = CustomDataset(root_dir,
                        T.Compose([T.Resize(new_size),
                                   T.ToTensor(),
                                   norm]))

Sanity check to make sure we have 3,326 images.

In [8]:
len(dataset)

3326

# Network Architecture
Taken from:
https://github.com/pytorch/ignite/blob/master/examples/gan/dcgan.py

In [9]:
class Net(torch.nn.Module):
    """ A base class for both generator and the discriminator.
    Provides a common weight initialization scheme.
    """
    def weights_init(self):
        for m in self.modules():
            classname = m.__class__.__name__
            if "Conv" in classname:
                m.weight.data.normal_(0.0, 0.02)
            elif "BatchNorm" in classname:
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)

    def forward(self, x):
        return x

In [10]:
class Generator(Net):
    """ Generator network.
    Args:
        nf (int): Number of filters in the second-to-last deconv layer
    """
    def __init__(self, z_dim, nf, nc):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(in_channels=z_dim, out_channels=nf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(nf * 8),
            nn.ReLU(inplace=True),
            # state size. (nf*8) x 4 x 4
            nn.ConvTranspose2d(in_channels=nf * 8, out_channels=nf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf * 4),
            nn.ReLU(inplace=True),
            # state size. (nf*4) x 8 x 8
            nn.ConvTranspose2d(in_channels=nf * 4, out_channels=nf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf * 2),
            nn.ReLU(inplace=True),
            # state size. (nf*2) x 16 x 16
            nn.ConvTranspose2d(in_channels=nf * 2, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf),
            nn.ReLU(inplace=True),
            # state size. (nf) x 32 x 32
            nn.ConvTranspose2d(in_channels=nf, out_channels=nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )
        self.weights_init()

    def forward(self, x):
        return self.net(x)

In [11]:
class Discriminator(Net):
    """ Discriminator network.
    Args:
        nf (int): Number of filters in the first conv layer.
    """
    def __init__(self, nc, nf):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(in_channels=nc, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (nf) x 32 x 32
            nn.Conv2d(in_channels=nf, out_channels=nf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (nf*2) x 16 x 16
            nn.Conv2d(in_channels=nf * 2, out_channels=nf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (nf*4) x 8 x 8
            nn.Conv2d(in_channels=nf * 4, out_channels=nf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (nf*8) x 4 x 4
            nn.Conv2d(in_channels=nf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )
        self.weights_init()

    def forward(self, x):
        output = self.net(x)
        return output.view(-1, 1).squeeze(1)

# Training

Data

In [None]:
batch_size = 64
n_workers = 2
loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True)

In [None]:


# netowrks
netG = Generator(z_dim, g_filters, num_channels).to(device)
netD = Discriminator(num_channels, d_filters).to(device)

# criterion
bce = nn.BCELoss()

# optimizers
optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta_1, 0.999))

# load pre-trained models
if saved_G:
    netG.load_state_dict(torch.load(saved_G))

if saved_D:
    netD.load_state_dict(torch.load(saved_D))

# misc
real_labels = torch.ones(batch_size, device=device)
fake_labels = torch.zeros(batch_size, device=device)
fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

def get_noise():
    return torch.randn(batch_size, z_dim, 1, 1, device=device)

# The main function, processing a batch of examples
def step(engine, batch):

    # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
    real, _ = batch
    real = real.to(device)

    # -----------------------------------------------------------
    # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
    netD.zero_grad()

    # train with real
    output = netD(real)
    errD_real = bce(output, real_labels)
    D_x = output.mean().item()

    errD_real.backward()

    # get fake image from generator
    noise = get_noise()
    fake = netG(noise)

    # train with fake
    output = netD(fake.detach())
    errD_fake = bce(output, fake_labels)
    D_G_z1 = output.mean().item()

    errD_fake.backward()

    # gradient update
    errD = errD_real + errD_fake
    optimizerD.step()

    # -----------------------------------------------------------
    # (2) Update G network: maximize log(D(G(z)))
    netG.zero_grad()

    # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
    output = netD(fake)
    errG = bce(output, real_labels)
    D_G_z2 = output.mean().item()

    errG.backward()

    # gradient update
    optimizerG.step()

    return {"errD": errD.item(), "errG": errG.item(), "D_x": D_x, "D_G_z1": D_G_z1, "D_G_z2": D_G_z2}

# ignite objects
trainer = Engine(step)
checkpoint_handler = ModelCheckpoint(output_dir, CKPT_PREFIX, n_saved=10, require_empty=False)
timer = Timer(average=True)

# attach running average metrics
monitoring_metrics = ["errD", "errG", "D_x", "D_G_z1", "D_G_z2"]
RunningAverage(alpha=alpha, output_transform=lambda x: x["errD"]).attach(trainer, "errD")
RunningAverage(alpha=alpha, output_transform=lambda x: x["errG"]).attach(trainer, "errG")
RunningAverage(alpha=alpha, output_transform=lambda x: x["D_x"]).attach(trainer, "D_x")
RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z1"]).attach(trainer, "D_G_z1")
RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z2"]).attach(trainer, "D_G_z2")

# attach progress bar
pbar = ProgressBar()
pbar.attach(trainer, metric_names=monitoring_metrics)

@trainer.on(Events.ITERATION_COMPLETED(every=PRINT_FREQ))
def print_logs(engine):
    fname = os.path.join(output_dir, LOGS_FNAME)
    columns = ["iteration",] + list(engine.state.metrics.keys())
    values = [str(engine.state.iteration),] + [str(round(value, 5)) for value in engine.state.metrics.values()]

    with open(fname, "a") as f:
        if f.tell() == 0:
            print("\t".join(columns), file=f)
        print("\t".join(values), file=f)

    message = "[{epoch}/{max_epoch}][{i}/{max_i}]".format(
        epoch=engine.state.epoch, max_epoch=epochs, i=(engine.state.iteration % len(loader)), max_i=len(loader)
    )
    for name, value in zip(columns, values):
        message += " | {name}: {value}".format(name=name, value=value)

    pbar.log_message(message)

# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED)
def save_fake_example(engine):
    fake = netG(fixed_noise)
    path = os.path.join(output_dir, FAKE_IMG_FNAME.format(engine.state.epoch))
    vutils.save_image(fake.detach(), path, normalize=True)

# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED)
def save_real_example(engine):
    img, y = engine.state.batch
    path = os.path.join(output_dir, REAL_IMG_FNAME.format(engine.state.epoch))
    vutils.save_image(img, path, normalize=True)

# adding handlers using `trainer.add_event_handler` method API
trainer.add_event_handler(
    event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"netG": netG, "netD": netD}
)

# automatically adding handlers via a special `attach` method of `Timer` handler
timer.attach(
    trainer,
    start=Events.EPOCH_STARTED,
    resume=Events.ITERATION_STARTED,
    pause=Events.ITERATION_COMPLETED,
    step=Events.ITERATION_COMPLETED,
)

# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED)
def print_times(engine):
    pbar.log_message("Epoch {} done. Time per batch: {:.3f}[s]".format(engine.state.epoch, timer.value()))
    timer.reset()

# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED)
def create_plots(engine):
    try:
        import matplotlib as mpl

        mpl.use("agg")

        import numpy as np
        import pandas as pd
        import matplotlib.pyplot as plt

    except ImportError:
        warnings.warn("Loss plots will not be generated -- pandas or matplotlib not found")

    else:
        df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter="\t", index_col="iteration")
        _ = df.plot(subplots=True, figsize=(20, 20))
        _ = plt.xlabel("Iteration number")
        fig = plt.gcf()
        path = os.path.join(output_dir, PLOT_FNAME)

        fig.savefig(path)

# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EXCEPTION_RAISED)
def handle_exception(engine, e):
    if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
        engine.terminate()
        warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

        create_plots(engine)
        checkpoint_handler(engine, {"netG_exception": netG, "netD_exception": netD})

    else:
        raise e

# Setup is done. Now let's run the training
trainer.run(loader, epochs)