In [19]:
import glob
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
from PIL import Image
from PIL import ImageOps
import random

# Hyper-parameters
latent_size = 100
hidden_size = 256
image_size = 3*64*64
num_epochs = 25
batch_size = 6
discriminator_iteration = 1
generator_iteration = 1
sample_dir = './Sample Images'

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [20]:
def convertbg(img):
    img = img.convert("RGBA")
    datas = img.getdata()
 
    newData = []
 
    for item in datas:
        if item[0] == 255 and item[1] == 255 and item[2] == 255:
            newData.append((255, 255, 255, 0))
        else:
            newData.append(item)
 
    img.putdata(newData)
    img = img.convert('RGB')
    return img

In [45]:
normalization_stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) 

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((64,64)),
    transforms.Normalize(*normalization_stats)
])

spritesList = list()
versions = ['black-white','emerald','platinum']
for ver in versions:
    images = glob.glob("./Pokemon/main-sprites/"+ ver +"/*.png")
    for image in images:
        with open(image, 'rb') as file:
            img = Image.open(file)
            mirror_img = ImageOps.mirror(img)
            img = img.convert('RGB')
            mirror_img = mirror_img.convert('RGB')
            if ver == 'crystal' or ver == 'yellow':
                img = convertbg(img)
                mirror_img = convertbg(img)
            img = transform(img)
            mirror_img = transform(mirror_img)
            spritesList.append(img)
            spritesList.append(mirror_img)

print(len(spritesList))

6948


In [46]:
#Do colour transformation for the images
color_transform = transforms.Compose([
    transforms.ColorJitter(0.5, 0.5, 0.5)
])

color_sprites = list()
for idx, (image) in enumerate(spritesList):
    c_image = color_transform(image)
    color_sprites.append(c_image)
    
spritesList = spritesList + color_sprites
random.shuffle(spritesList)

#Gaussian Blur the images as well
g_transform = transforms.Compose([
    transforms.GaussianBlur(3)
])

blur_sprites = list()
for idx, (image) in enumerate(spritesList):
    g_image = g_transform(image)
    blur_sprites.append(g_image)
    
spritesList = spritesList + blur_sprites
random.shuffle(spritesList)

print(len(spritesList))

27792


In [23]:
# Discriminator - DCGAN implementation
# input noise dimension
nz = 100
# number of generator filters
ngf = 64
#number of discriminator filters
ndf = 64
#number of channels
nc=3

D = nn.Sequential(
    # Input is 3 x 64 x 64
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
    # Layer Output: 64 x 32 x 32
    
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # Layer Output: 128 x 16 x 16
    
    nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # Layer Output: 128 x 8 x 8
    
    nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # Layer Output: 128 x 4 x 4
    
    # With a 4x4, we can condense the channels into a 1 x 1 x 1 to produce output
    nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=0, bias=False),
    nn.Sigmoid()
)

# Generator 
G = nn.Sequential(
    nn.ConvTranspose2d(nz, 128, kernel_size=4, padding=0, stride=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    # Layer output: 256 x 4 x 4
    
    nn.ConvTranspose2d(128, 128, kernel_size=4, padding=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    # Layer output: 128 x 8 x 8
    
    nn.ConvTranspose2d(128, 128, kernel_size=4, padding=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    # Layer output: 64 x 16 x 16
    
    nn.ConvTranspose2d(128, 64, kernel_size=4, padding=1, stride=2, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    # Layer output: 32 x 32 x 32
    
    nn.ConvTranspose2d(64, 3, kernel_size=4, padding=1, stride=2, bias=False),
    nn.Tanh()
)

# Load State Dict
D.load_state_dict(torch.load('D.ckpt'))
G.load_state_dict(torch.load('G.ckpt'))

# Device setting
D = D.to(device)
G = G.to(device)

In [48]:
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.000002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.000002)

In [29]:
def denorm(x):
    # TANH (-1, 1)
    out = (x * 0.5) + 0.5
    return out

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

In [26]:
pokemonTrainingList = torch.stack(spritesList)
pokemonTrainingList = pokemonTrainingList.split(batch_size)
pokemonTrainingList = np.array(pokemonTrainingList)
print(pokemonTrainingList.shape)

(2316,)


  This is separate from the ipykernel package so we can avoid doing imports until
  This is separate from the ipykernel package so we can avoid doing imports until


In [49]:
# Start training
one = torch.FloatTensor([1])
mone = one * -1

one.to(device)
mone.to(device)

total_step = len(pokemonTrainingList)
for epoch in range(num_epochs):
    for i, images in enumerate(pokemonTrainingList):
        # print(images.shape)
        # images = images.reshape(batch_size, -1).to(device)
        images = images.to(device)
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #

        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        
        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0

        for j in range(discriminator_iteration):  
          # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
          # Second term of the loss is always zero since real_labels == 1
          outputs = D(images)
          outputs = outputs.view(batch_size,-1)
          d_loss_real = criterion(outputs, real_labels)
          real_score = outputs
          
          # Compute BCELoss using fake images
          # First term of the loss is always zero since fake_labels == 0
          z = torch.randn(batch_size, nz , 1 , 1).to(device)
          fake_images = G(z)
          outputs = D(fake_images).view(batch_size,-1)
          d_loss_fake = criterion(outputs, fake_labels)
          fake_score = outputs
          
          # Backprop and optimize
          d_loss = d_loss_real + d_loss_fake
          reset_grad()
          # calculates gradient
          d_loss.backward()
          # Update parameters
          d_optimizer.step()
          

        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        ######GAN
        # Compute loss with fake images
        for k in range(generator_iteration):  
          z = torch.randn(batch_size, nz , 1 , 1).to(device)
          fake_images = G(z)
          outputs = D(fake_images).view(batch_size,-1)
          
          # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
          # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
          g_loss = criterion(outputs, real_labels)
          
          # Backprop and optimize
          reset_grad()
          g_loss.backward()
          g_optimizer.step()
        
        
        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    
    
    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 3, 64, 64)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 3, 64, 64)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
    torch.save(G.state_dict(), 'G.ckpt')
    torch.save(D.state_dict(), 'D.ckpt')

# Save the model checkpoints 


Epoch [0/25], Step [100/2316], d_loss: 0.0007, g_loss: 6.6811, D(x): 1.00, D(G(z)): 0.00
Epoch [0/25], Step [200/2316], d_loss: 0.0041, g_loss: 8.9238, D(x): 1.00, D(G(z)): 0.00
Epoch [0/25], Step [300/2316], d_loss: 0.0007, g_loss: 7.3642, D(x): 1.00, D(G(z)): 0.00
Epoch [0/25], Step [400/2316], d_loss: 0.0060, g_loss: 5.7421, D(x): 1.00, D(G(z)): 0.01
Epoch [0/25], Step [500/2316], d_loss: 0.0031, g_loss: 6.5072, D(x): 1.00, D(G(z)): 0.00
Epoch [0/25], Step [600/2316], d_loss: 0.0116, g_loss: 6.6692, D(x): 1.00, D(G(z)): 0.01
Epoch [0/25], Step [700/2316], d_loss: 0.0030, g_loss: 6.5204, D(x): 1.00, D(G(z)): 0.00
Epoch [0/25], Step [800/2316], d_loss: 0.0065, g_loss: 5.6686, D(x): 1.00, D(G(z)): 0.01
Epoch [0/25], Step [900/2316], d_loss: 0.0078, g_loss: 5.1998, D(x): 1.00, D(G(z)): 0.01
Epoch [0/25], Step [1000/2316], d_loss: 0.0038, g_loss: 5.8314, D(x): 1.00, D(G(z)): 0.00
Epoch [0/25], Step [1100/2316], d_loss: 0.0020, g_loss: 10.8248, D(x): 1.00, D(G(z)): 0.00
Epoch [0/25], Step