In [1]:
import os
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets, utils
from torch.autograd import Variable

%matplotlib inline

In [2]:
# Target size passed to the Resize and CenterCrop transforms, so every frame
# ends up IMG_SIZE x IMG_SIZE pixels.
IMG_SIZE = 64
# Number of channels of the latent noise tensor fed to the generator's first
# transposed convolution.
Z_DIMENSION = 32
# Batch size used by the DataLoader and by the noise tensors.
BATCH_SIZE = 64
# Image channels: produced by the generator, consumed by the discriminator,
# and matching the three-channel Normalize transform.
NUM_CHANNELS = 3
# Number of full passes over doom_loader performed by the training loop
# (i.e. the epoch count, despite the name).
NUM_ITERATIONS = 25

In [3]:
class DoomFrameDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves the individual frames of a video file.

    Frames are fetched on demand from an imageio ffmpeg reader via
    ``get_data(idx)`` and optionally passed through ``transform``.
    """

    def __init__(self, video_location, transform=None):
        """
        video_location (string): Path to the video file
        transform (function, optional): Transforms to apply
        """
        self.video_reader = imageio.get_reader(video_location,  'ffmpeg')
        self.transform = transform
        # Frame count, resolved lazily by __len__ and then cached, because
        # counting frames may require scanning the whole file.
        self._num_frames = None

    def __len__(self):
        """Return the number of frames in the video as an ``int``.

        Recent imageio ffmpeg readers report an infinite length (which makes
        the built-in ``len(reader)`` raise), so when the reported length is
        infinite the frames are counted explicitly with ``count_frames()``.
        """
        if self._num_frames is None:
            length = self.video_reader.get_length()
            if length == float('inf'):
                length = self.video_reader.count_frames()
            self._num_frames = int(length)
        return self._num_frames

    def __getitem__(self, idx):
        """Return frame ``idx``, transformed if a transform was supplied."""
        frame = self.video_reader.get_data(idx)

        if self.transform:
            frame = self.transform(frame)

        return frame

# Per-frame preprocessing: ndarray -> PIL image, resize, centre-crop to an
# IMG_SIZE square, convert to a tensor, then apply (x - 0.5) / 0.5 to each of
# the three channels.
img_trans = transforms.Compose(
    [
        PIL.Image.fromarray,
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Dataset over the gameplay recording, one item per video frame.
doom_data = DoomFrameDataset("data/doom_gameplay.mp4", transform=img_trans)

# Shuffled mini-batches, loaded in the main process (num_workers=0).
doom_loader = torch.utils.data.DataLoader(
    doom_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

In [4]:
class DoomGenerator(nn.Module):
    """DCGAN-style generator built from five transposed convolutions.

    A latent tensor of shape ``(N, z_dim, 1, 1)`` is expanded to 4x4 by the
    first layer (kernel 4, stride 1, no padding) and then doubled in size by
    each of the four following layers (kernel 4, stride 2, padding 1), giving
    an ``(N, num_channels, 64, 64)`` output squashed to ``[-1, 1]`` by tanh.

    Args:
        hidden_units: Base channel width; layers use 8x, 4x, 2x and 1x this.
        z_dim: Channels of the latent input. Defaults to the notebook-level
            ``Z_DIMENSION`` when omitted.
        num_channels: Channels of the generated image. Defaults to the
            notebook-level ``NUM_CHANNELS`` when omitted.
    """

    def __init__(self, hidden_units = 128, z_dim=None, num_channels=None):
        super(DoomGenerator, self).__init__()
        # Resolve the defaults at call time so the class can also be built
        # with explicit sizes, independent of the notebook constants.
        if z_dim is None:
            z_dim = Z_DIMENSION
        if num_channels is None:
            num_channels = NUM_CHANNELS

        self.deconv1 = nn.ConvTranspose2d(z_dim, hidden_units*8, 4, 1, 0)
        self.deconv1_bn = nn.BatchNorm2d(hidden_units*8)
        self.deconv2 = nn.ConvTranspose2d(hidden_units*8, hidden_units*4, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(hidden_units*4)
        self.deconv3 = nn.ConvTranspose2d(hidden_units*4, hidden_units*2, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(hidden_units*2)
        self.deconv4 = nn.ConvTranspose2d(hidden_units*2, hidden_units, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm2d(hidden_units)
        self.deconv5 = nn.ConvTranspose2d(hidden_units, num_channels, 4, 2, 1)

    def weight_init(self, mean, std):
        """Re-initialise every (transposed) convolution of this module.

        Weights are drawn from ``Normal(mean, std)`` and biases are zeroed;
        batch-norm layers are left untouched. The logic lives here (rather
        than in an external helper) so the method works on its own.
        """
        for layer in self.children():
            if isinstance(layer, (nn.ConvTranspose2d, nn.Conv2d)):
                layer.weight.data.normal_(mean, std)
                if layer.bias is not None:
                    layer.bias.data.zero_()

    def forward(self, z_data):
        """Map latent tensors ``(N, z_dim, 1, 1)`` to images in ``[-1, 1]``."""
        x = F.relu(self.deconv1_bn(self.deconv1(z_data)))
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = F.relu(self.deconv4_bn(self.deconv4(x)))
        # torch.tanh replaces the deprecated F.tanh; the result is identical.
        x = torch.tanh(self.deconv5(x))

        return x

# Build the generator with its default width (hidden_units=128).
doom_generator = DoomGenerator()

In [5]:
# One batch of standard-normal latent tensors, shaped (BATCH_SIZE, Z_DIMENSION, 1, 1).
noise = torch.FloatTensor(BATCH_SIZE, Z_DIMENSION, 1, 1).normal_(0, 1)

# Shape check: run the untrained generator on the batch and display the output size.
doom_generator(Variable(noise)).size()

torch.Size([64, 3, 64, 64])

In [6]:
class DoomDiscriminator(nn.Module):
    """DCGAN-style discriminator built from five convolutions.

    Four stride-2 convolutions (kernel 4, padding 1) halve the spatial size
    each time, and a final kernel-4, stride-1, unpadded convolution followed
    by a sigmoid reduces a 64x64 image to a single probability, so an
    ``(N, num_channels, 64, 64)`` input yields an ``(N, 1, 1, 1)`` output.

    Args:
        hidden_units: Base channel width; layers use 1x, 2x, 4x and 8x this.
        num_channels: Channels of the input image. Defaults to the
            notebook-level ``NUM_CHANNELS`` when omitted.
    """

    def __init__(self, hidden_units=128, num_channels=None):
        super(DoomDiscriminator, self).__init__()
        # Resolve the default at call time so the class can also be built
        # with an explicit size, independent of the notebook constant.
        if num_channels is None:
            num_channels = NUM_CHANNELS

        self.conv1 = nn.Conv2d(num_channels, hidden_units, 4, 2, 1)
        self.conv2 = nn.Conv2d(hidden_units, hidden_units*2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(hidden_units*2)
        self.conv3 = nn.Conv2d(hidden_units*2, hidden_units*4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(hidden_units*4)
        self.conv4 = nn.Conv2d(hidden_units*4, hidden_units*8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm2d(hidden_units*8)
        self.conv5 = nn.Conv2d(hidden_units*8, 1, 4, 1, 0)

    def weight_init(self, mean, std):
        """Re-initialise every convolution of this module.

        Weights are drawn from ``Normal(mean, std)`` and biases are zeroed;
        batch-norm layers are left untouched. The logic lives here (rather
        than in an external helper) so the method works on its own.
        """
        for layer in self.children():
            if isinstance(layer, (nn.ConvTranspose2d, nn.Conv2d)):
                layer.weight.data.normal_(mean, std)
                if layer.bias is not None:
                    layer.bias.data.zero_()

    def forward(self, image):
        """Return the per-image probability of being real, shaped (N, 1, 1, 1)."""
        # No batch norm on the first layer; every hidden layer uses
        # LeakyReLU with a negative slope of 0.2.
        x = F.leaky_relu(self.conv1(image), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        # torch.sigmoid replaces the deprecated F.sigmoid; the result is identical.
        x = torch.sigmoid(self.conv5(x))

        return x

# Build the discriminator with its default width (hidden_units=128).
doom_discriminator = DoomDiscriminator()

In [7]:
# Smoke test: score the first dataset frame with the untrained discriminator.
# unsqueeze(0) adds the leading batch dimension the network expects.
doom_discriminator(Variable(doom_data[0].unsqueeze(0)))

Variable containing:
(0 ,0 ,.,.) = 
  0.5833
[torch.FloatTensor of size 1x1x1x1]

In [9]:
# End-to-end shape check: generate a batch from `noise`, score it with the
# discriminator, and display the size of the result.
doom_discriminator(doom_generator(Variable(noise))).size()

torch.Size([64, 1, 1, 1])

In [10]:
# Scratch buffer for latent vectors; left uninitialised here because the
# training loop refills it in place with normal_(0, 1) on every iteration.
noise = torch.FloatTensor(BATCH_SIZE, Z_DIMENSION, 1, 1)
# A single fixed latent batch, reused for the periodic sample images so they
# are comparable across training.
fixed_noise = Variable(torch.FloatTensor(*noise.size()).normal_(0, 1))

# One Adam optimizer per network, both with the same settings.
disc_opt = optim.Adam(
    doom_discriminator.parameters(),
    lr=.0005,
    betas=(.5, 0.999),
)
gen_opt = optim.Adam(
    doom_generator.parameters(),
    lr=.0005,
    betas=(.5, 0.999),
)

# The discriminator ends in a sigmoid, so plain binary cross-entropy applies.
binary_cross_entropy_loss = nn.BCELoss()

In [None]:
# Sample grids and checkpoints are written to this directory; create it up
# front so the first save does not fail on a fresh checkout.
os.makedirs('gan_files', exist_ok=True)

for epoch in range(NUM_ITERATIONS):
    for i, data in enumerate(doom_loader):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        doom_discriminator.zero_grad()
        real_label = torch.ones(data.size(0))

        # view(-1) instead of squeeze(): squeeze() would collapse a final
        # batch of one sample to a 0-dim tensor that no longer matches the
        # 1-element label.
        output = doom_discriminator(data).view(-1)
        disc_real_loss = binary_cross_entropy_loss(output, real_label)
        disc_real_loss.backward()

        # train with fake
        noise.normal_(0, 1)
        fake = doom_generator(noise)
        # detach() keeps the discriminator update from back-propagating into
        # the generator.
        output = doom_discriminator(fake.detach()).view(-1)
        fake_label = torch.zeros(output.size(0))
        disc_fake_loss = binary_cross_entropy_loss(output, fake_label)
        disc_fake_loss.backward()
        disc_loss = disc_real_loss + disc_fake_loss
        disc_opt.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        doom_generator.zero_grad()
        output = doom_discriminator(fake).view(-1)
        # The generator is rewarded when its fakes are classified as real.
        real_label = torch.ones(output.size(0))

        gen_loss = binary_cross_entropy_loss(output, real_label)
        gen_loss.backward()
        gen_opt.step()

        # .item() extracts the Python float from a 0-dim loss tensor;
        # indexing with .data[0] fails on PyTorch >= 0.4.
        print(f"""[{epoch}/{NUM_ITERATIONS}] :: [{i}/{len(doom_loader)}]
        Disc_Real_Loss: {disc_real_loss.item()} Disc_Fake_Loss: {disc_fake_loss.item()} Gen_Loss: {gen_loss.item()}""")
        if i % 100 == 0:
            utils.save_image(data,
                    'gan_files/real_samples.png',
                    normalize=True)
            # Sampling needs no gradients, and uses its own name so the
            # training batch in `fake` is not overwritten.
            with torch.no_grad():
                fixed_fake = doom_generator(fixed_noise)
            utils.save_image(fixed_fake,
                    f'gan_files/fake_samples_epoch_{epoch}.png',
                    normalize=True)

    # do checkpointing
    torch.save(doom_generator.state_dict(), f'gan_files/gen_epoch_{epoch}.pth')
    torch.save(doom_discriminator.state_dict(), f'gan_files/disc_epoch_{epoch}.pth')

[0/25] :: [0/195]
        Disc_Real_Loss: 0.005009102635085583 Disc_Fake_Loss: 0.019049787893891335 Gen_Loss: 6.314743518829346
[0/25] :: [1/195]
        Disc_Real_Loss: 0.013813783414661884 Disc_Fake_Loss: 7.063241004943848 Gen_Loss: 5.036657810211182
[0/25] :: [2/195]
        Disc_Real_Loss: 0.3938445746898651 Disc_Fake_Loss: 0.5488819479942322 Gen_Loss: 7.260267734527588
[0/25] :: [3/195]
        Disc_Real_Loss: 0.7492496967315674 Disc_Fake_Loss: 0.030858883634209633 Gen_Loss: 4.969048500061035
[0/25] :: [4/195]
        Disc_Real_Loss: 0.014431076124310493 Disc_Fake_Loss: 0.581393301486969 Gen_Loss: 4.779274940490723
[0/25] :: [5/195]
        Disc_Real_Loss: 0.025245795026421547 Disc_Fake_Loss: 0.05954959988594055 Gen_Loss: 4.616872310638428
[0/25] :: [6/195]
        Disc_Real_Loss: 0.050661876797676086 Disc_Fake_Loss: 0.1159633994102478 Gen_Loss: 3.944753885269165
[0/25] :: [7/195]
        Disc_Real_Loss: 0.011134237051010132 Disc_Fake_Loss: 0.4208289086818695 Gen_Loss: 4.964893817

        Disc_Real_Loss: 0.5679936408996582 Disc_Fake_Loss: 0.551529586315155 Gen_Loss: 2.7228152751922607
[0/25] :: [65/195]
        Disc_Real_Loss: 0.7279313802719116 Disc_Fake_Loss: 0.26267021894454956 Gen_Loss: 0.8468409776687622
[0/25] :: [66/195]
        Disc_Real_Loss: 0.14282415807247162 Disc_Fake_Loss: 1.8172409534454346 Gen_Loss: 3.214627265930176
[0/25] :: [67/195]
        Disc_Real_Loss: 2.3422703742980957 Disc_Fake_Loss: 0.08540057390928268 Gen_Loss: 1.9127916097640991
[0/25] :: [68/195]
        Disc_Real_Loss: 1.2554121017456055 Disc_Fake_Loss: 0.23947599530220032 Gen_Loss: 0.8889126777648926
[0/25] :: [69/195]
        Disc_Real_Loss: 0.45191413164138794 Disc_Fake_Loss: 0.7466612458229065 Gen_Loss: 0.6345129609107971
[0/25] :: [70/195]
        Disc_Real_Loss: 0.27015790343284607 Disc_Fake_Loss: 0.8908170461654663 Gen_Loss: 0.8621242642402649
[0/25] :: [71/195]
        Disc_Real_Loss: 0.29910680651664734 Disc_Fake_Loss: 0.6630015969276428 Gen_Loss: 1.2329955101013184
[0/25]