In [2]:
import argparse
import os
import io
import random
import copy
import json
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.utils as vutils
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [3]:
# Seed choice for deterministic results
manualSeed = random.randint(1, 10000)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
print("Random Seed: ", manualSeed)

Random Seed:  3384


In [4]:
# Available Painter Supercategories
# sports, accessory, animal, outdoor, vehicle, person, indoor, appliance, electronic, furniture
# food, kitchen, water, ground, solid, sky, plant, structural, building, textile, window, floor
# ceiling, wall, rawmaterial
painter_type = "animal"
image_size = 128
lo_bound = image_size // 4
up_bound = lo_bound + (image_size // 2)
sample_size = 5000
train_path = 'inpainting/train2014'
dev_path = 'inpainting/val2014'
train_annotation_path='inpainting/annotations/instances_train2014.json'
dev_annotation_path='inpainting/annotations/instances_val2014.json'

In [5]:
with open(train_annotation_path, 'r') as R:
    train_ann = json.loads(R.read())
    
with open(dev_annotation_path, 'r') as V:
    valid_ann = json.loads(V.read())

cat_labels = ['']*91
for i in range(len(train_ann['categories'])):
    cat_labels[train_ann['categories'][i]['id']] = train_ann['categories'][i]['supercategory']

In [6]:
class DataIterator(object):
    """Data Iterator for COCO."""

    def __init__(
        self,
        train_path=train_path,
        dev_path=dev_path,
        train_annotation_path=train_annotation_path,
        dev_annotation_path=dev_annotation_path,
    ):
        """Initialize params."""
        self.train_path = train_path
        self.train_annotation_path = train_annotation_path
        self.dev_path = dev_path
        self.dev_annotation_path = dev_annotation_path
        print('Processing data ...')
        self._get_real_and_fake_images()

    # JMAK - cropping to 128
    def _get_real_and_fake_images(self):            
        """Get real and fake images from path."""
        self.train_dataset = dset.CocoDetection(
            root=self.train_path,
            annFile=self.train_annotation_path,
            transform=transforms.Compose([
                 transforms.CenterCrop(image_size),
                 transforms.ToTensor()])
        )
       
        self.valid_dataset = dset.CocoDetection(
            root=self.dev_path, 
            annFile=self.dev_annotation_path,
            transform=transforms.Compose([
                 transforms.CenterCrop(image_size),
                 transforms.ToTensor()])
        )
                
        # ELDRICK: First, copy over desired number of training and validation. 
        print('Populating training images & captions ...')
        train_images = []
        # There appears to be one image missing for some weird reason.
        try:
            for img, label in self.train_dataset:
                area_index = 0
                largest_area = 0
                try:
                    for j in range(len(label)):
                        if label[j]['area'] > largest_area:
                            largest_area = label[j]['area']
                            area_index = j
                    img_cat_id = label[area_index]['category_id']
                except:
                    pass
                if cat_labels[img_cat_id] == painter_type:
                    train_images.append(img)
                if len(train_images) % 500 == 0:
                    print("Gathered ", len(train_images), " training images so far")
                if len(train_images) == sample_size:
                    break
        except IOError:
            pass
        
        train_images = torch.stack(train_images)
        
        # ELDRICK: Second, changed this to match above to terminate
        print('Populating validation images ...')
        valid_images = []
        try:
            for img, label in self.valid_dataset:
                area_index = 0
                largest_area = 0
                try:
                    for j in range(len(label)):
                        if label[j]['area'] > largest_area:
                            largest_area = label[j]['area']
                            area_index = j
                    img_cat_id = label[area_index]['category_id']
                except:
                    pass
                if cat_labels[img_cat_id] == painter_type:
                    valid_images.append(img)
                if len(valid_images) % 500 == 0:
                    print("Gathered ", len(valid_images), " validation images so far")
                if len(valid_images) == sample_size:
                    break
        except IOError:
            pass
        
        valid_images = torch.stack(valid_images)

        # ELDRICK: Crop out 128 by 128
        # JMAK: Crop out 64x64 
        print('Cropping 64x64 patch for training images ...')
        noisy_train_images = copy.deepcopy(train_images.numpy())
        noisy_train_images[:, :, lo_bound:up_bound, lo_bound:up_bound] = 0
        noisy_train_images = torch.from_numpy(noisy_train_images)

        print('Cropping 64x64 patch for validation images ...')
        noisy_valid_images = copy.deepcopy(valid_images.numpy())
        noisy_valid_images[:, :, lo_bound:up_bound, lo_bound:up_bound] = 0
        noisy_valid_images = torch.from_numpy(noisy_valid_images)
        
        self.train_images = train_images
        self.valid_images = valid_images

        self.noisy_train_images = noisy_train_images
        self.noisy_valid_images = noisy_valid_images
                
        self.num_train = len(train_images)
        self.num_valid = len(valid_images)

    # Return proper sized samples from minibatch - 128x128
    # return a 64 x 64 minibatch - JMAK
    def get_train_minibatch(self, index, batch_size):
        """Return a minibatch of real and fake examples."""
        real_examples = Variable(self.train_images[index: index + batch_size]).cuda()
        fake_examples = Variable(self.noisy_train_images[index: index + batch_size]).cuda()
        return real_examples, real_examples[:, :, lo_bound:up_bound, lo_bound:up_bound], fake_examples

    def get_valid_minibatch(self, index, batch_size):
        """Return a minibatch of real and fake examples."""
        real_examples = Variable(self.valid_images[index: index + batch_size]).cuda()
        fake_examples = Variable(self.noisy_valid_images[index: index + batch_size]).cuda()
        return real_examples, real_examples[:, :, lo_bound:up_bound, lo_bound:up_bound], fake_examples

In [7]:
iterator = DataIterator()
print("Number of Total Class Samples, Train: ", iterator.num_train)
print("Number of Total Class Samples, Valid: ", iterator.num_valid)

Processing data ...
loading annotations into memory...
Done (t=9.91s)
creating index...
index created!
loading annotations into memory...
Done (t=5.63s)
creating index...
index created!
Populating training images & captions ...
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  training images so far
Gathered  0  tra

Gathered  4500  validation images so far
Gathered  4500  validation images so far
Gathered  4500  validation images so far
Gathered  4500  validation images so far
Gathered  5000  validation images so far
Cropping 64x64 patch for training images ...
Cropping 64x64 patch for validation images ...
Number of Total Class Samples, Train:  5000
Number of Total Class Samples, Valid:  5000


In [9]:
fig = plt.figure(figsize=(10, 5))
j = np.random.randint(low=0, high=iterator.num_train)
fig.add_subplot(1, 2, 1)
plt.imshow(iterator.train_images[j].numpy().transpose(1, 2, 0))
plt.axis('off')
fig.add_subplot(1, 2, 2)
plt.imshow(iterator.noisy_train_images[j].numpy().transpose(1, 2, 0))
plt.axis('off')

AttributeError: module 'numpy' has no attribute 'bool_'

<Figure size 720x360 with 0 Axes>

In [10]:
class Generator(nn.Module):
    """Generator module."""

    def __init__(self, start_filter):
        """Initialize generator."""
        super(Generator, self).__init__()
        
        self.main = nn.Sequential(
        # input is (nc) x 128 x 128
        nn.Conv2d(3,start_filter,4,2,1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        
        # state size: (nef) x 64 x 64
        nn.Conv2d(start_filter,start_filter,4,2,1, bias=False),
        nn.BatchNorm2d(start_filter),
        nn.LeakyReLU(0.2, inplace=True),
        
        # state size: (nef) x 32 x 32
        nn.Conv2d(start_filter,start_filter*2,4,2,1, bias=False),
        nn.BatchNorm2d(start_filter*2),
        nn.LeakyReLU(0.2, inplace=True),
        
        # state size: (nef*2) x 16 x 16
        nn.Conv2d(start_filter*2,start_filter*4,4,2,1, bias=False),
        nn.BatchNorm2d(start_filter*4),
        nn.LeakyReLU(0.2, inplace=True),
        
        # state size: (nef*4) x 8 x 8
        nn.Conv2d(start_filter*4,start_filter*8,4,2,1, bias=False),
        nn.BatchNorm2d(start_filter*8),
        nn.LeakyReLU(0.2, inplace=True),
        
        # state size: (nef*8) x 4 x 4
        nn.Conv2d(start_filter*8,4000,4, bias=False),
        
        # state size: (nBottleneck) x 1 x 1
        nn.BatchNorm2d(4000),
        nn.LeakyReLU(0.2, inplace=True),
        
        # input is Bottleneck, going into a convolution
        nn.ConvTranspose2d(4000, start_filter * 8, 4, 1, 0, bias=False),
        nn.BatchNorm2d(start_filter * 8),
        nn.ReLU(True),
        
        # state size. (ngf*8) x 4 x 4
        nn.ConvTranspose2d(start_filter * 8, start_filter * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(start_filter * 4),
        nn.ReLU(True),
        
        # state size. (ngf*4) x 8 x 8
        nn.ConvTranspose2d(start_filter * 4, start_filter * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(start_filter * 2),
        nn.ReLU(True),
        
        # state size. (ngf*2) x 16 x 16
        nn.ConvTranspose2d(start_filter * 2, start_filter, 4, 2, 1, bias=False),
        nn.BatchNorm2d(start_filter),
        nn.ReLU(True),
        
        # state size. (ngf) x 32 x 32
        nn.ConvTranspose2d(start_filter, 3, 4, 2, 1, bias=False),
        nn.Tanh()
        # state size. (nc) x 64 x 64
        )
    
    def forward(self, input):
        output = self.main(input)
        return output

In [11]:
class Discriminator(nn.Module):
    """Discriminator."""

    def __init__(self, start_filter):
        """Initialize params."""
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(3, start_filter, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            # state size. (ndf) x 32 x 32
            nn.Conv2d(start_filter, start_filter * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(start_filter * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(start_filter * 2, start_filter * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(start_filter * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(start_filter * 4, start_filter * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(start_filter * 8),
            nn.LeakyReLU(0.2, inplace=True),
            
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(start_filter * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
        
    def forward(self, input):
        output = self.main(input)
        return output.view(-1, 1)

In [12]:
# Hyperparameters
learning_rate_gen = 1e-3
learning_rate_dis = 1e-3
betas_gen = (0.5, 0.999)
betas_dis = (0.5, 0.999)

generator = Generator(start_filter=64).cuda()
discriminator = Discriminator(start_filter=64).cuda()
optimizer_generator = optim.Adam(generator.parameters(), lr=learning_rate_gen, betas=betas_gen)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=learning_rate_dis, betas=betas_dis)
clamp_lower = -0.03
clamp_upper = 0.03
loss_criterion = nn.MSELoss().cuda()
save_dir = 'inpainting/samples'

In [13]:
def save_plots(epoch, fake_images, real_images, real_examples_full):
    # ELDRICK: Change random to fit your sample size
    j = np.random.randint(low=0, high=iterator.num_valid)
    # ELDRICK: Change minibatch size to fit samples
    real_examples_full, real_examples, fake_images = iterator.get_valid_minibatch(j, 64)
    generator.eval()
    reconstructions = generator(fake_images)
    reconstructions = reconstructions.data.cpu().numpy()
    real = real_examples_full.data.cpu().numpy()
    real_copy = copy.deepcopy(real)
    real_copy[:, :, lo_bound:up_bound, lo_bound:up_bound] = reconstructions
    real_copy = torch.from_numpy(real_copy)
    real = torch.from_numpy(real)
    out_tensor = torch.zeros(1, real_copy.size(1), real_copy.size(2), real_copy.size(3))
    for zz, zzz in zip(real_copy[:10], real[:10]):
        out_tensor = torch.cat([out_tensor, zz.unsqueeze(0)])
        out_tensor = torch.cat([out_tensor, zzz.unsqueeze(0)])
    vutils.save_image(out_tensor[1:], 'inpainting/samples/epoch_s%d_samples.png' % (epoch), normalize=True, scale_each=True, nrow=4)
    generator.train()

In [14]:
criterion = nn.BCELoss()
criterionMSE = nn.MSELoss()
for i in range(500):
    discriminator_losses = []
    generator_losses = []
    for j in range(0, iterator.num_train, 64):        
        ############################
        # (1) Update D network
        ############################
        for disc_updates in range(5):
            real_examples_full, real_examples, fake_images = iterator.get_train_minibatch(j, 64)
            D1 = discriminator(real_examples)
            fake = generator(fake_images)
            D2 = discriminator(fake)
            
            true_labels = torch.ones(D1.size()).cuda()
            fake_labels = torch.zeros(D2.size()).cuda()
            
            # Utilize BCE Loss
            real_d_loss = criterion(D1, true_labels).cuda()
            fake_d_loss = criterion(D2, fake_labels).cuda()
            
            discriminator_loss = real_d_loss + fake_d_loss
            optimizer_discriminator.zero_grad()
            discriminator_loss.backward(retain_graph=True)
            
            optimizer_discriminator.step()
            discriminator_losses.append(discriminator_loss.item())

            # clamp parameters to a cube
            for p in discriminator.parameters():
                p.data.clamp_(clamp_lower, clamp_upper)
        
        ############################
        # (2) Update G network
        ############################
        
        generated_images = generator(fake_images)
        
        # Changing Gen Loss to utilize BCE + L2 MSE Loss
        gen_bce_loss = criterion(D2, true_labels).cuda()
        gen_l2_loss = criterionMSE(fake, real_examples).cuda()
        
        generator_loss = 0.001*gen_bce_loss + 0.999*gen_l2_loss
        
        optimizer_generator.zero_grad()
        generator_loss.backward()
        optimizer_generator.step()
        generator_losses.append(generator_loss.item())

    print('[%d] Loss_D: %f Loss_G: %f' % (i, np.mean(discriminator_losses), np.mean(generator_losses)))
    save_plots(i, fake_images, real_examples, real_examples_full)

[0] Loss_D: 0.375053 Loss_G: 0.053182
[1] Loss_D: 0.777685 Loss_G: 0.035414
[2] Loss_D: 0.663081 Loss_G: 0.033636
[3] Loss_D: 0.062871 Loss_G: 0.034997
[4] Loss_D: 0.005371 Loss_G: 0.034718
[5] Loss_D: 0.004128 Loss_G: 0.034118
[6] Loss_D: 0.003657 Loss_G: 0.033505
[7] Loss_D: 0.869406 Loss_G: 0.029199
[8] Loss_D: 0.564921 Loss_G: 0.029121
[9] Loss_D: 0.061893 Loss_G: 0.031286
[10] Loss_D: 0.003412 Loss_G: 0.030793
[11] Loss_D: 0.003260 Loss_G: 0.030026
[12] Loss_D: 0.002731 Loss_G: 0.029255
[13] Loss_D: 0.002536 Loss_G: 0.028660
[14] Loss_D: 0.881148 Loss_G: 0.023950
[15] Loss_D: 0.452676 Loss_G: 0.024448
[16] Loss_D: 0.746494 Loss_G: 0.022719
[17] Loss_D: 0.480381 Loss_G: 0.023171
[18] Loss_D: 0.659718 Loss_G: 0.021842
[19] Loss_D: 0.667160 Loss_G: 0.021135
[20] Loss_D: 0.460765 Loss_G: 0.020845
[21] Loss_D: 0.504489 Loss_G: 0.020299
[22] Loss_D: 1.033842 Loss_G: 0.017968
[23] Loss_D: 0.334711 Loss_G: 0.019366
[24] Loss_D: 0.515729 Loss_G: 0.018535
[25] Loss_D: 0.543140 Loss_G: 0.017

[208] Loss_D: 0.754919 Loss_G: 0.005901
[209] Loss_D: 0.772170 Loss_G: 0.005834
[210] Loss_D: 0.742017 Loss_G: 0.005863
[211] Loss_D: 0.845156 Loss_G: 0.005677
[212] Loss_D: 0.744390 Loss_G: 0.005862
[213] Loss_D: 0.808225 Loss_G: 0.005703
[214] Loss_D: 0.796912 Loss_G: 0.005675
[215] Loss_D: 0.715586 Loss_G: 0.005910
[216] Loss_D: 0.773026 Loss_G: 0.005751
[217] Loss_D: 0.746871 Loss_G: 0.005793
[218] Loss_D: 0.803493 Loss_G: 0.005704
[219] Loss_D: 0.784419 Loss_G: 0.005661
[220] Loss_D: 0.741913 Loss_G: 0.005835
[221] Loss_D: 0.723441 Loss_G: 0.005829
[222] Loss_D: 0.745060 Loss_G: 0.005783
[223] Loss_D: 0.771099 Loss_G: 0.005791
[224] Loss_D: 0.752143 Loss_G: 0.005816
[225] Loss_D: 0.787639 Loss_G: 0.005711
[226] Loss_D: 0.860509 Loss_G: 0.005559
[227] Loss_D: 0.800591 Loss_G: 0.005617
[228] Loss_D: 0.768600 Loss_G: 0.005663
[229] Loss_D: 0.773716 Loss_G: 0.005710
[230] Loss_D: 0.776678 Loss_G: 0.005674
[231] Loss_D: 0.837821 Loss_G: 0.005492
[232] Loss_D: 0.827596 Loss_G: 0.005488


[413] Loss_D: 0.804616 Loss_G: 0.004672
[414] Loss_D: 0.844554 Loss_G: 0.004576
[415] Loss_D: 0.921915 Loss_G: 0.004482
[416] Loss_D: 0.903950 Loss_G: 0.004426
[417] Loss_D: 0.967204 Loss_G: 0.004270
[418] Loss_D: 0.888462 Loss_G: 0.004482
[419] Loss_D: 0.856518 Loss_G: 0.004531
[420] Loss_D: 0.839662 Loss_G: 0.004558
[421] Loss_D: 0.902093 Loss_G: 0.004439
[422] Loss_D: 0.919930 Loss_G: 0.004389
[423] Loss_D: 0.895115 Loss_G: 0.004421
[424] Loss_D: 0.883272 Loss_G: 0.004457
[425] Loss_D: 0.887609 Loss_G: 0.004430
[426] Loss_D: 0.868685 Loss_G: 0.004491
[427] Loss_D: 0.835095 Loss_G: 0.004595
[428] Loss_D: 0.895646 Loss_G: 0.004467
[429] Loss_D: 0.875826 Loss_G: 0.004477
[430] Loss_D: 0.946381 Loss_G: 0.004306
[431] Loss_D: 0.848941 Loss_G: 0.004544
[432] Loss_D: 0.821720 Loss_G: 0.004543
[433] Loss_D: 0.886616 Loss_G: 0.004443
[434] Loss_D: 0.897580 Loss_G: 0.004369
[435] Loss_D: 0.926124 Loss_G: 0.004324
[436] Loss_D: 0.907780 Loss_G: 0.004364
[437] Loss_D: 0.860598 Loss_G: 0.004428


In [16]:
state = {
    'generator_state': generator.state_dict(),
    'discriminator_state': discriminator.state_dict(),
}
torch.save(state, 'inpainting/painter_animal_state_test')