In [1]:
# code adapted from https://github.com/eriklindernoren/PyTorch-GAN
# `%cd` is called without an argument; the recorded output shows the working
# directory becoming /home/kacper.
# NOTE(review): the later `from DeepLearning.Project3 import data` import and the
# relative output paths (experiments_gan/, images/wgan/) presumably depend on this
# working directory -- confirm before running on another machine.
%cd

/home/kacper


In [2]:
from torch import nn
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
from torchvision.utils import save_image

# Whether a CUDA device is usable; read later when placing the models and
# choosing the tensor type.
cuda = torch.cuda.is_available()

class Options:
    """Run settings for the WGAN experiment, stored as instance attributes.

    ``Options()`` yields the defaults below. Individual settings can be
    overridden by keyword, e.g. ``Options(n_epochs=5, sample_interval=50)``,
    so a short run does not require editing the class.

    Args:
        **overrides: replacement values keyed by option name.

    Raises:
        TypeError: if a keyword does not match one of the known option names.
    """

    def __init__(self, **overrides):
        self.n_epochs = 200         # number of passes of the training loop over the dataloader
        self.batch_size = 64        # NOTE(review): not read in this notebook's code; the dataloader comes from a project module
        self.lr = 0.0002            # learning rate given to both RMSprop optimizers
        self.b1 = 0.5               # NOTE(review): presumably Adam beta1; unused by the RMSprop optimizers here
        self.b2 = 0.999             # NOTE(review): presumably Adam beta2; unused by the RMSprop optimizers here
        self.n_cpu = 8              # NOTE(review): not read in this notebook's code
        self.latent_dim = 100       # width of the generator's noise input
        self.img_size = 64          # image height and width used to build img_shape
        self.channels = 3           # image channel count used to build img_shape
        self.n_critic = 5           # the generator is trained on batches whose index is a multiple of this
        self.clip_value = 0.01      # critic parameters are clamped to [-clip_value, clip_value]
        self.sample_interval = 400  # a sample image grid is saved when batches_done is a multiple of this

        # Reject misspelled names instead of silently creating a new attribute.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError("unknown option(s): " + ", ".join(unknown))
        vars(self).update(overrides)

# Default run settings, plus the (channels, height, width) shape derived from them.
opt = Options()
img_shape = (opt.channels,) + (opt.img_size,) * 2

In [3]:
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    The input width is read from the module-level ``opt.latent_dim`` and the
    output is reshaped to the module-level ``img_shape``. The final ``Tanh``
    keeps every output value within [-1, 1].
    """

    def __init__(self):
        super().__init__()

        def block(in_feat, out_feat, normalize=True):
            """Return a list: Linear, optional BatchNorm1d, then LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BatchNorm1d's second positional parameter is `eps`, so the
                # previous positional 0.8 inflated the variance epsilon rather
                # than setting the momentum. Pass it by keyword instead.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # No normalization directly on the first hidden layer.
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            # One output unit per image element.
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: tensor whose last dimension is ``opt.latent_dim``; the first
               dimension is treated as the batch.

        Returns:
            Tensor of shape ``(z.shape[0], *img_shape)``.
        """
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img


class Discriminator(nn.Module):
    """Fully connected critic: flattens each image and emits one unbounded score.

    The input width is the product of the module-level ``img_shape``; there is
    no final activation on the single output unit.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        hidden_widths = (512, 256)
        stack = []
        width = int(np.prod(img_shape))
        # Linear + LeakyReLU(0.2) for each hidden width, created in order.
        for next_width in hidden_widths:
            stack.append(nn.Linear(width, next_width))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            width = next_width
        # Single-unit output layer.
        stack.append(nn.Linear(width, 1))

        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Score a batch of images; returns a tensor of shape (batch, 1)."""
        batch = img.shape[0]
        return self.model(img.view(batch, -1))


# Build both networks (generator first, then the critic)
generator, discriminator = Generator(), Discriminator()

# Move the networks to the GPU when one was detected (.cuda() returns the module itself)
if cuda:
    generator, discriminator = generator.cuda(), discriminator.cuda()

# One RMSprop optimizer per network, both at the configured learning rate
optimizer_G, optimizer_D = (
    torch.optim.RMSprop(net.parameters(), lr=opt.lr)
    for net in (generator, discriminator)
)

# Float tensor type matching the device the networks live on
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

In [4]:
# NOTE(review): `DeepLearning.Project3` is a project-local package; this import
# presumably resolves only because the first cell changed the working directory
# with `%cd` -- confirm.
from DeepLearning.Project3 import data
# The training loop iterates this object and unpacks every batch as `(imgs, _)`.
# NOTE(review): the batch size and preprocessing are decided inside the project
# helper and are not visible here.
dataloader = data.load_dataloader_preprocess()

In [5]:
# ----------
#  Training
# ----------
# WGAN loop with weight clipping: the critic is updated on every batch and the
# generator on batches whose index is a multiple of `opt.n_critic`.
# Losses are logged every 10 batches to a JSON array that is closed in a
# `finally` clause, so the file stays valid JSON even if the run is interrupted.

import datetime
import json
import os

# Make sure the output locations exist before the first write.
os.makedirs('experiments_gan', exist_ok=True)
os.makedirs('images/wgan', exist_ok=True)

batches_done = 0
with open('experiments_gan/wgan_trajectory.json', 'w') as f:
    f.write("[")
    wrote_record = False  # controls the separator so no trailing comma is emitted
    try:
        for epoch in range(opt.n_epochs):
            print(f"Epoch: {epoch}, time: {datetime.datetime.now().isoformat()}")

            for i, (imgs, _) in enumerate(dataloader):

                # Configure input (cast/move the batch to the working tensor type)
                real_imgs = imgs.type(Tensor)

                # ---------------------
                #  Train Discriminator
                # ---------------------

                optimizer_D.zero_grad()

                # Sample standard-normal noise directly on the batch's device
                z = torch.randn(
                    imgs.shape[0], opt.latent_dim,
                    dtype=real_imgs.dtype, device=real_imgs.device,
                )

                # Generate a batch of images; detached so only the critic gets gradients
                fake_imgs = generator(z).detach()
                # Adversarial loss
                loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

                loss_D.backward()
                optimizer_D.step()

                # Clip weights of discriminator
                with torch.no_grad():
                    for p in discriminator.parameters():
                        p.clamp_(-opt.clip_value, opt.clip_value)

                # Train the generator every n_critic iterations
                # (i == 0 always qualifies, so loss_G and gen_imgs exist before first use)
                if i % opt.n_critic == 0:

                    # -----------------
                    #  Train Generator
                    # -----------------

                    optimizer_G.zero_grad()

                    # Generate a batch of images
                    gen_imgs = generator(z)
                    # Adversarial loss
                    loss_G = -torch.mean(discriminator(gen_imgs))

                    loss_G.backward()
                    optimizer_G.step()

                if i % 10 == 0:
                    # "G loss" is the value from the most recent generator step.
                    record = {
                        "Epoch": epoch,
                        "Batch": i,
                        "D loss": loss_D.item(),
                        "G loss": loss_G.item(),
                    }
                    f.write((",\n" if wrote_record else "\n") + json.dumps(record))
                    f.flush()  # keep the on-disk log current, as the per-write reopen used to
                    wrote_record = True

                batches_done = epoch * len(dataloader) + i
                if batches_done % opt.sample_interval == 0:
                    # gen_imgs is the batch produced at the most recent generator step.
                    save_image(gen_imgs.detach()[:25], "images/wgan/%d.png" % batches_done, nrow=5, normalize=True)

            # Checkpoint both networks at the end of every epoch. torch.save on a
            # module pickles the whole object, so loading needs these classes importable.
            torch.save(generator, f"experiments_gan/wgan_generator_epoch_{epoch}.pt")
            torch.save(discriminator, f"experiments_gan/wgan_discriminator_epoch_{epoch}.pt")
    finally:
        # Always terminate the JSON array, including on KeyboardInterrupt.
        f.write("\n]\n")

Epoch: 0, time: 2023-06-04T10:43:54.572022
Epoch: 1, time: 2023-06-04T10:48:47.562740
Epoch: 2, time: 2023-06-04T10:53:37.797565
Epoch: 3, time: 2023-06-04T10:58:27.017636
Epoch: 4, time: 2023-06-04T11:03:16.414734
Epoch: 5, time: 2023-06-04T11:08:06.747023
Epoch: 6, time: 2023-06-04T11:13:01.110881
Epoch: 7, time: 2023-06-04T11:17:59.457947
Epoch: 8, time: 2023-06-04T11:22:47.660494
Epoch: 9, time: 2023-06-04T11:27:44.313037
Epoch: 10, time: 2023-06-04T11:32:35.828373
Epoch: 11, time: 2023-06-04T11:37:24.471322
Epoch: 12, time: 2023-06-04T11:42:14.870167
Epoch: 13, time: 2023-06-04T11:47:12.888645
Epoch: 14, time: 2023-06-04T11:52:02.675512
Epoch: 15, time: 2023-06-04T11:56:50.280424
Epoch: 16, time: 2023-06-04T12:01:40.497159
Epoch: 17, time: 2023-06-04T12:06:29.790248
Epoch: 18, time: 2023-06-04T12:11:16.981282
Epoch: 19, time: 2023-06-04T12:16:04.249803
Epoch: 20, time: 2023-06-04T12:20:50.438764
Epoch: 21, time: 2023-06-04T12:25:39.404390
Epoch: 22, time: 2023-06-04T12:30:26.40507

KeyboardInterrupt: 