In [1]:
import os
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets, utils
from torch.autograd import Variable
from random import random

%matplotlib inline

In [2]:
# Side length, in pixels, that frames are resized and center-cropped to.
IMG_SIZE = 160
# Number of channels in the latent noise tensor fed to the generator.
Z_DIMENSION = 32
# Frames per DataLoader batch; also the batch size of the noise tensors.
BATCH_SIZE = 36
# Channels per frame: generator output channels / discriminator input channels.
NUM_CHANNELS = 3
# Inclusive upper bound of the epoch counter used by the training loop.
NUM_ITERATIONS = 25

In [3]:
# NOTE(review): presumably fetches the ffmpeg binary that imageio's 'ffmpeg'
# reader (used by the dataset below) depends on. Newer imageio releases are
# reported to have dropped this helper in favour of the separate
# imageio-ffmpeg package -- confirm against the installed imageio version.
imageio.plugins.ffmpeg.download()

In [4]:
class DoomFrameDataset(torch.utils.data.Dataset):
    """In-memory dataset of frames sampled from a single video file.

    Frames at indices 0, frame_stride, 2 * frame_stride, ... are read in the
    constructor and stored as PIL images. The optional transform is applied
    every time an item is fetched, not at load time.
    """

    def __init__(self, video_location, transform=None, frame_stride=4):
        """
        video_location (string): Path to the video file
        transform (function, optional): Transforms to apply
        frame_stride (int, optional): Keep one frame out of every
            ``frame_stride`` frames. The default of 4 reproduces the
            previously hard-coded sampling.

        Raises ValueError if ``frame_stride`` is smaller than 1.
        """
        if frame_stride < 1:
            raise ValueError("frame_stride must be a positive integer")

        # The reader is only needed while frames are extracted. It is closed
        # afterwards instead of being stored on the instance, so the dataset
        # does not hold the video open for its whole lifetime or carry the
        # reader into DataLoader worker processes.
        reader = imageio.get_reader(video_location, 'ffmpeg')
        try:
            self.frames = [PIL.Image.fromarray(reader.get_data(idx))
                           for idx in range(0, len(reader), frame_stride)]
        finally:
            reader.close()
        self.transform = transform

    def __len__(self):
        """Number of frames kept in memory."""
        return len(self.frames)

    def __getitem__(self, idx):
        """Return frame ``idx``, passed through ``transform`` when one is set."""
        image = self.frames[idx]

        # Compare against None explicitly: a callable's truthiness is not a
        # reliable "was a transform supplied" signal.
        if self.transform is not None:
            image = self.transform(image)

        return image

# Per-frame preprocessing, applied when an item is fetched from the dataset:
# resize, center-crop to IMG_SIZE, convert to a tensor, then normalise every
# channel with mean 0.5 and std 0.5.
img_trans = transforms.Compose([
    transforms.Resize(size=IMG_SIZE),
    transforms.CenterCrop(size=IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Sampled gameplay frames with the preprocessing above attached.
doom_data = DoomFrameDataset(video_location="data/doom_gameplay.mp4", transform=img_trans)

# Shuffled mini-batches of BATCH_SIZE frames, loaded by 4 worker processes.
doom_loader = torch.utils.data.DataLoader(
    dataset=doom_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

In [5]:
class DoomGenerator(nn.Module):
    """Transposed-convolution generator mapping latent noise to an image.

    Input:  tensor with Z_DIMENSION channels (the notebook feeds 1x1 maps).
    Output: tensor with NUM_CHANNELS channels, squashed by tanh.

    For a 1x1 latent input the layer hyper-parameters below give spatial
    sizes 1 -> 4 -> 7 -> 10 -> 20 -> 40 -> 80 -> 160.

    The layer attribute names are the keys of ``state_dict()``; they are kept
    as-is so checkpoints written from this class continue to load.
    """

    def __init__(self, hidden_units = 128):
        """hidden_units: base channel width; layer widths are multiples of it."""
        super(DoomGenerator, self).__init__()
        # Three stride-1, unpadded 4x4 layers (each adds 3 pixels per side length).
        self.deconv1 = nn.ConvTranspose2d(Z_DIMENSION, hidden_units*32, 4, 1, 0)
        self.deconv1_bn = nn.BatchNorm2d(hidden_units*32)
        self.deconv2 = nn.ConvTranspose2d(hidden_units*32, hidden_units*16, 4, 1, 0)
        self.deconv2_bn = nn.BatchNorm2d(hidden_units*16)
        self.deconv3 = nn.ConvTranspose2d(hidden_units*16, hidden_units*8, 4, 1, 0)
        self.deconv3_bn = nn.BatchNorm2d(hidden_units*8)
        # Four stride-2, padding-1 4x4 layers (each doubles the side length).
        self.deconv4 = nn.ConvTranspose2d(hidden_units*8, hidden_units*4, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm2d(hidden_units*4)
        self.deconv5 = nn.ConvTranspose2d(hidden_units*4, hidden_units*2, 4, 2, 1)
        self.deconv5_bn = nn.BatchNorm2d(hidden_units*2)
        self.deconv6 = nn.ConvTranspose2d(hidden_units*2, hidden_units, 4, 2, 1)
        self.deconv6_bn = nn.BatchNorm2d(hidden_units)
        self.deconv7 = nn.ConvTranspose2d(hidden_units, NUM_CHANNELS, 4, 2, 1)

    def weight_init(self, mean, std):
        """Apply ``normal_init(layer, mean, std)`` to every direct child layer."""
        # children() is the public way to walk direct sub-modules; it avoids
        # reaching into the private _modules dict.
        for layer in self.children():
            normal_init(layer, mean, std)

    def forward(self, z_data):
        """Run the latent batch ``z_data`` through the stack and return images."""
        x = F.relu(self.deconv1_bn(self.deconv1(z_data)))
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = F.relu(self.deconv4_bn(self.deconv4(x)))
        x = F.relu(self.deconv5_bn(self.deconv5(x)))
        x = F.relu(self.deconv6_bn(self.deconv6(x)))
        # torch.tanh replaces the deprecated F.tanh; the output stays in [-1, 1].
        x = torch.tanh(self.deconv7(x))

        return x

# Build the generator with the default base width (hidden_units=128).
doom_generator = DoomGenerator()

In [6]:
class DoomDiscriminator(nn.Module):
    """Convolutional discriminator scoring images with a sigmoid output.

    Input:  tensor with NUM_CHANNELS channels.
    Output: one channel per remaining spatial position, squashed to (0, 1).

    For a 160x160 input the layer hyper-parameters below give spatial sizes
    160 -> 80 -> 40 -> 20 -> 10 -> 7 -> 4 -> 1.

    The layer attribute names are the keys of ``state_dict()``; they are kept
    as-is so checkpoints written from this class continue to load.
    """

    def __init__(self, hidden_units=128):
        """hidden_units: base channel width; layer widths are multiples of it."""
        super(DoomDiscriminator, self).__init__()
        # Four stride-2, padding-1 4x4 layers (each halves the side length).
        # The first layer has no batch norm.
        self.conv1 = nn.Conv2d(NUM_CHANNELS, hidden_units, 4, 2, 1)
        self.conv2 = nn.Conv2d(hidden_units, hidden_units*2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(hidden_units*2)
        self.conv3 = nn.Conv2d(hidden_units*2, hidden_units*4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(hidden_units*4)
        self.conv4 = nn.Conv2d(hidden_units*4, hidden_units*8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm2d(hidden_units*8)
        # Three stride-1, unpadded 4x4 layers (each removes 3 pixels per side length).
        self.conv5 = nn.Conv2d(hidden_units*8, hidden_units*16, 4, 1, 0)
        self.conv5_bn = nn.BatchNorm2d(hidden_units*16)
        self.conv6 = nn.Conv2d(hidden_units*16, hidden_units*32, 4, 1, 0)
        self.conv6_bn = nn.BatchNorm2d(hidden_units*32)
        self.conv7 = nn.Conv2d(hidden_units*32, 1, 4, 1, 0)

    def weight_init(self, mean, std):
        """Apply ``normal_init(layer, mean, std)`` to every direct child layer."""
        # children() is the public way to walk direct sub-modules; it avoids
        # reaching into the private _modules dict.
        for layer in self.children():
            normal_init(layer, mean, std)

    def forward(self, image):
        """Score the batch ``image``; returns values in (0, 1)."""
        x = F.leaky_relu(self.conv1(image), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        x = F.leaky_relu(self.conv5_bn(self.conv5(x)), 0.2)
        x = F.leaky_relu(self.conv6_bn(self.conv6(x)), 0.2)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        x = torch.sigmoid(self.conv7(x))

        return x

# Build the discriminator with the default base width (hidden_units=128).
doom_discriminator = DoomDiscriminator()

In [7]:
def normal_init(m, mean, std):
    """Re-initialise a (transposed) 2-D convolution layer in place.

    Weights are redrawn from a normal distribution with the given ``mean``
    and ``std``; the bias, when the layer has one, is set to zero. Any other
    module type is left untouched.

    m: module to initialise
    mean (float): mean of the weight distribution
    std (float): standard deviation of the weight distribution
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        # Layers constructed with bias=False expose bias as None; skip them
        # instead of failing with AttributeError.
        if m.bias is not None:
            m.bias.data.zero_()

In [8]:
# The custom normal(0, 0.02) weight initialisation is left disabled for both
# networks, so they keep the weights their layers were constructed with.
#doom_discriminator.weight_init(0, .02)
# Move the discriminator to the GPU.
doom_discriminator.cuda()

#doom_generator.weight_init(0, .02)
# Move the generator to the GPU; as the last expression of the cell, its
# return value is what the notebook prints as the cell output.
doom_generator.cuda()

DoomGenerator(
  (deconv1): ConvTranspose2d (32, 4096, kernel_size=(4, 4), stride=(1, 1))
  (deconv1_bn): BatchNorm2d(4096, eps=1e-05, momentum=0.1, affine=True)
  (deconv2): ConvTranspose2d (4096, 2048, kernel_size=(4, 4), stride=(1, 1))
  (deconv2_bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True)
  (deconv3): ConvTranspose2d (2048, 1024, kernel_size=(4, 4), stride=(1, 1))
  (deconv3_bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)
  (deconv4): ConvTranspose2d (1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (deconv4_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
  (deconv5): ConvTranspose2d (512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (deconv5_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
  (deconv6): ConvTranspose2d (256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (deconv6_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
  (deconv7): ConvTranspose2d (128, 3, kernel_size=(4,

In [10]:
#doom_generator.load_state_dict(torch.load('gan_files/gen_latest.pth'))
#doom_discriminator.load_state_dict(torch.load('gan_files/disc_latest.pth'))

In [11]:
# Standard deviation of the Gaussian noise added to discriminator inputs;
# the training loop scales it down as the epoch number grows.
added_noise = .1

# Loss shared by the discriminator and generator objectives.
binary_cross_entropy_loss = nn.BCELoss()

# Latent buffer on the GPU; its contents are overwritten with fresh normal
# samples at every training step, so it is left uninitialised here.
noise = torch.FloatTensor(BATCH_SIZE, Z_DIMENSION, 1, 1).cuda()
# Latent batch drawn once and reused for every sample grid that is saved.
fixed_noise = Variable(torch.FloatTensor(BATCH_SIZE, Z_DIMENSION, 1, 1).normal_(0, 1)).cuda()

# One Adam optimiser per network, with identical settings.
disc_opt = optim.Adam(params=doom_discriminator.parameters(), lr=.0001, betas=(.5, 0.999))
gen_opt = optim.Adam(params=doom_generator.parameters(), lr=.0001, betas=(.5, 0.999))

# Halve each learning rate every 10 scheduler steps.
disc_scheduler = optim.lr_scheduler.StepLR(optimizer=disc_opt, step_size=10, gamma=.5)
gen_scheduler = optim.lr_scheduler.StepLR(optimizer=gen_opt, step_size=10, gamma=.5)

In [None]:
# First epoch number of this run. It is 12 rather than 1, i.e. the numbering
# continues from an earlier run. The checkpoint-loading lines earlier in the
# notebook are commented out, so on a fresh kernel the networks are NOT
# restored -- re-enable them (or set this to 1) as appropriate.
START_EPOCH = 12
# Number of full passes over doom_loader per epoch.
REPEATS_PER_EPOCH = 4

# save_image / torch.save do not create missing directories; make sure the
# output folder exists before the first write.
os.makedirs('gan_files', exist_ok=True)

# Sample grid from the fixed latent batch before any update in this run.
fake = doom_generator(fixed_noise)
utils.save_image(fake.data,
        'gan_files/fake_frames_before_training.png',
        normalize = True, nrow = 6)

for epoch in range(START_EPOCH, NUM_ITERATIONS + 1):
    disc_scheduler.step()
    gen_scheduler.step()

    # Std of the noise added to discriminator inputs; it decays by a factor
    # of 0.9 per epoch and is constant within an epoch.
    instance_noise_std = added_noise * .9 ** epoch

    for repeat in range(1, REPEATS_PER_EPOCH + 1):
        for i, data in enumerate(doom_loader, start = 1):

            # Only fires when the run starts at epoch 1 (never with
            # START_EPOCH > 1): dump the first real batch for reference.
            if i == 1 and epoch == 1:
                utils.save_image(data, 'gan_files/real_frames.png', normalize = True, nrow = 6)

            # Noisy ("soft") targets: real in [0.7, 1.2), fake in [0.0, 0.3).
            real_label = Variable(torch.FloatTensor(data.size(0)).uniform_(0.7, 1.2)).cuda()
            fake_label = Variable(torch.FloatTensor(noise.size(0)).uniform_(0.0, 0.3)).cuda()

            # ---- discriminator step ----
            doom_discriminator.zero_grad()

            # train with real
            noisy_input = data + torch.FloatTensor(*data.size()).normal_(0, instance_noise_std)
            inputv = Variable(noisy_input).cuda()

            real_output = doom_discriminator(inputv)

            # train with fake
            noise.normal_(0, 1)
            noisev = Variable(noise).cuda()
            fake = doom_generator(noisev)
            # fake.data is wrapped in a new Variable, so this pass does not
            # backpropagate into the generator.
            noisy_fake = fake.data + torch.FloatTensor(*fake.size()).normal_(0, instance_noise_std).cuda()
            fake_output = doom_discriminator(Variable(noisy_fake).cuda())

            if random() < 0.01:
                # Roughly 1% of steps: swap the target ranges (label flipping).
                disc_real_loss = binary_cross_entropy_loss(real_output.squeeze(), Variable(real_label.data.uniform_(0.0, 0.3)).cuda())
                disc_fake_loss = binary_cross_entropy_loss(fake_output.squeeze(), Variable(fake_label.data.uniform_(0.7, 1.2)).cuda())
            else:
                disc_real_loss = binary_cross_entropy_loss(real_output.squeeze(), real_label)
                disc_fake_loss = binary_cross_entropy_loss(fake_output.squeeze(), fake_label)

            disc_real_loss.backward()
            disc_fake_loss.backward()

            # Combined value; the two parts are backpropagated separately
            # above, so this sum is not used for the update itself.
            disc_loss = disc_real_loss + disc_fake_loss

            disc_opt.step()

            # ---- generator step ----
            doom_generator.zero_grad()
            # Score the un-noised fakes with the just-updated discriminator.
            output = doom_discriminator(fake)
            # The generator is trained against an all-ones target.
            real_label = Variable(torch.FloatTensor(output.size(0)).fill_(1)).cuda()

            gen_loss = binary_cross_entropy_loss(output.squeeze(), real_label)
            gen_loss.backward()
            gen_opt.step()

        # Losses reported are those of the last batch of this pass.
        print(f"""[{epoch}/{NUM_ITERATIONS}] :: [{repeat}/{REPEATS_PER_EPOCH}]
        Disc_Real_Loss: {disc_real_loss.data[0]} Disc_Fake_Loss: {disc_fake_loss.data[0]} Gen_Loss: {gen_loss.data[0]}""")

    # Sample grid from the same fixed latent batch after this epoch.
    fake = doom_generator(fixed_noise)
    utils.save_image(fake.data,
            f'gan_files/fake_frames_after_epoch_{epoch}.png',
            normalize = True, nrow = 6)

    # do checkpointing (overwrites the previous epoch's files)
    torch.save(doom_generator.state_dict(), 'gan_files/gen_latest.pth')
    torch.save(doom_discriminator.state_dict(), 'gan_files/disc_latest.pth')

[12/25] :: [1/4]
        Disc_Real_Loss: 0.4116857945919037 Disc_Fake_Loss: 0.49440664052963257 Gen_Loss: 1.2685251235961914
[12/25] :: [2/4]
        Disc_Real_Loss: 0.40914639830589294 Disc_Fake_Loss: 0.4656500816345215 Gen_Loss: 1.874472737312317
[12/25] :: [3/4]
        Disc_Real_Loss: 0.3832114636898041 Disc_Fake_Loss: 0.4555117189884186 Gen_Loss: 1.9223933219909668
[12/25] :: [4/4]
        Disc_Real_Loss: 0.2888214886188507 Disc_Fake_Loss: 0.5268737077713013 Gen_Loss: 2.0882744789123535
[13/25] :: [1/4]
        Disc_Real_Loss: 0.43545666337013245 Disc_Fake_Loss: 0.4795427918434143 Gen_Loss: 1.274022102355957
[13/25] :: [2/4]
        Disc_Real_Loss: 0.3524535596370697 Disc_Fake_Loss: 0.4227713644504547 Gen_Loss: 2.1557698249816895
[13/25] :: [3/4]
        Disc_Real_Loss: 0.25271275639533997 Disc_Fake_Loss: 0.4680973291397095 Gen_Loss: 1.4142323732376099
[13/25] :: [4/4]
        Disc_Real_Loss: 1.2793246507644653 Disc_Fake_Loss: 0.8484671115875244 Gen_Loss: 0.5126442313194275
[14/25