In [1]:
import argparse
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys
import cv2

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

from pix2pix_models import *
from pix2pix_datasets import *

import torch.nn as nn
import torch.nn.functional as F
import torch

from data import *
from utils.augmentations import SSDAugmentation
from layers.modules import MultiBoxLoss
from ssd import build_ssd

import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torch.utils.data as data

In [2]:
# SSD
DATASET_ROOT = "/media/arg_ws3/5E703E3A703E18EB/data/argbot"
DATASET_NAME = "person_mask"
cfg = argbot
BASE_NET = "./weights/vgg16_reducedfc.pth"
DATA_DETECTION = ARGBOTDetection
BATCH_SIZE = 1
PRETRAINED_MODEL = "/home/arg_ws3/ssd.pytorch/weights/person/person_20000.pth"
#PRETRAINED_MODEL = None
PRETRAINED_EPOCH = 0
SAVE_MODEL_EPOCH = 1
START_ITER = 0
NUM_WORKERS = 8
EPOCH = 300
adjust_lr_epoch = [60, 80, 150]
CUDA = True
LR = 1e-4
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
GAMMA = 0.1
VISDOM = False

# GAN
b1 = 0.5
b2 = 0.999
lr = 0.002
decay_epoch = 100
n_cpu = 8
sample_interval = 200
channels = 3
Tensor = torch.cuda.FloatTensor if CUDA else torch.FloatTensor

In [3]:
if torch.cuda.is_available():
    if not CUDA:
        print("WTF are u wasting your CUDA device?")
    else:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

In [4]:
def str2bool(v):
    return v.lower() in ("yes", "true", "t", "1")

# Initial model weights & bias
def xavier(param):
    init.xavier_uniform(param)
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data)
        m.bias.data.zero_()

# Adjust learning rate during training
def adjust_learning_rate(optimizer, gamma, step):
    """Sets the learning rate to the initial LR decayed by 10 at every
        specified step
    # Adapted from PyTorch Imagenet example:
    # https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """
    #lr = LR * (gamma ** (step))
    lr = LR * (gamma)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        print("Change learning rate to: ", lr)

In [5]:
ssd_pretrained = build_ssd('train', cfg['min_dim'], 2)
if CUDA:
    net = torch.nn.DataParallel(ssd_pretrained)
    cudnn.benchmark = True
    
if PRETRAINED_MODEL is not None: # Use SSD pretrained model
    print('Resuming training, loading {}...'.format(PRETRAINED_MODEL))
    ssd_pretrained.load_weights(PRETRAINED_MODEL)

Resuming training, loading /home/arg_ws3/ssd.pytorch/weights/person/person_20000.pth...
Loading weights into state dict...
Finished!


  self.priors = Variable(self.priorbox.forward(), volatile=True)
  init.constant(self.weight,self.gamma)


In [6]:
# Delcare SSD Network
ssd_net = build_ssd('train', cfg['min_dim'], 2)
net = ssd_net
if CUDA:
    net = torch.nn.DataParallel(ssd_net)
    cudnn.benchmark = True
SAME_CLASS = True
if PRETRAINED_MODEL is not None: # Use SSD pretrained model
    if SAME_CLASS:
        print('Resuming training, loading {}...'.format(PRETRAINED_MODEL))
        ssd_net.load_weights(PRETRAINED_MODEL)
    else:
        print('Load pretrained model with different classes')
        ssd_pretrained = build_ssd('train', cfg['min_dim'], 2)
        ssd_pretrained.load_weights(PRETRAINED_MODEL)
        ssd_net.vgg = ssd_pretrained.vgg
        ssd_net.extras = ssd_pretrained.extras
        ssd_net.loc = ssd_pretrained.loc
        ssd_net.conf.apply(weights_init)
else:
    print('Initializing weights...')
    vgg_weights = torch.load(BASE_NET) # load vgg pretrained model
    ssd_net.vgg.load_state_dict(vgg_weights)
    ssd_net.extras.apply(weights_init) # Initial SSD model weights & bias
    ssd_net.loc.apply(weights_init)
    ssd_net.conf.apply(weights_init)

optimizer_SSD = optim.SGD(net.parameters(), lr=LR, momentum=MOMENTUM,
                weight_decay=WEIGHT_DECAY)
criterion = MultiBoxLoss(BATCH_SIZE ,2, 0.5, True, 0, True, 3, 0.5,
                False, CUDA)
net.train()
loc_loss = 0
conf_loss = 0
epoch = 0

Resuming training, loading /home/arg_ws3/ssd.pytorch/weights/person/person_20000.pth...
Loading weights into state dict...
Finished!


# GAN

In [7]:
dataset_name = "gan_test_pre"
os.makedirs('images/%s' % dataset_name, exist_ok=True)
os.makedirs('saved_models/%s' % dataset_name, exist_ok=True)
cuda = True if torch.cuda.is_available() else False

In [8]:
# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()
criterion_bbx = torch.nn.L1Loss()

In [9]:
img_height = 256
img_width = 256

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Calculate output of image discriminator (PatchGAN)
patch = (1, img_height//2**4, img_width//2**4)

# Initialize generator and discriminator
generator = GeneratorUNet(in_channels=3, out_channels=3)
discriminator = Discriminator(in_channels=3)

In [10]:
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_GAN.cuda()
    criterion_pixelwise.cuda()
    criterion_bbx.cuda()
epoch = 0
if True:
    # Load pretrained models
    generator.load_state_dict(torch.load('/home/arg_ws3/ssd.pytorch/saved_models/gan_test/generator_14.pth'))
    #discriminator.load_state_dict(torch.load('saved_models/%s/discriminator_%d.pth' % (dataset_name, epoch)))
    discriminator.apply(weights_init_normal)
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

In [11]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Configure dataloaders
transforms_A = [ transforms.Resize((img_height, img_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
transforms_B = [ transforms.Resize((img_height, img_width), Image.BICUBIC),
                transforms.ToTensor() ]

inv_normalize = transforms.Compose([ 
                transforms.Normalize(mean = [ 0., 0., 0. ],
                                             std = [ 1/0.5, 1/0.5, 1/0.5 ]),
                transforms.Normalize(mean = [ -0.5, -0.5, -0.5 ],
                                             std = [ 1., 1., 1. ]),
                ])

In [12]:
'''dataloader = DataLoader(ImageDataset(transform_=transforms_A),
                        batch_size=1, shuffle=True, num_workers=n_cpu)

val_dataloader = DataLoader(ImageDataset(transform_=transforms_A, mode='test'),
                            batch_size=1, shuffle=True, num_workers=1)

'''

"dataloader = DataLoader(ImageDataset(transform_=transforms_A),\n                        batch_size=1, shuffle=True, num_workers=n_cpu)\n\nval_dataloader = DataLoader(ImageDataset(transform_=transforms_A, mode='test'),\n                            batch_size=1, shuffle=True, num_workers=1)\n\n"

In [13]:
dataset = DATA_DETECTION(root=DATASET_ROOT, image_sets=['person_train'],transform=SSDAugmentation(cfg['min_dim'], MEANS))
data_loader = data.DataLoader(dataset, 1,
                                num_workers=NUM_WORKERS,
                                shuffle=True, collate_fn=detection_collate,
                                pin_memory=True)
batch_iterator = iter(data_loader)
a = next(batch_iterator)
print(a[1])
new_img=F.upsample(a[0], scale_factor=1/1.171875, mode='bilinear')

print(new_img.shape)

[tensor([[0.7024, 0.4157, 0.8204, 0.8196, 0.0000]], device='cpu')]
torch.Size([1, 3, 256, 256])


  "See the documentation of nn.Upsample for details.".format(mode))


In [14]:
'''batch_iterator = iter(dataloader)
a = next(batch_iterator)
print(a[1].shape)
a[0].shape
new_img=F.upsample(a[0], scale_factor=1.171875, mode='bilinear')
new_img = torch.cat((new_img, new_img, new_img), dim=1)
new_img.shape
print(a[1])
targets = [Variable(ann.cuda(), volatile=True) for ann in a[1]]
targets'''

"batch_iterator = iter(dataloader)\na = next(batch_iterator)\nprint(a[1].shape)\na[0].shape\nnew_img=F.upsample(a[0], scale_factor=1.171875, mode='bilinear')\nnew_img = torch.cat((new_img, new_img, new_img), dim=1)\nnew_img.shape\nprint(a[1])\ntargets = [Variable(ann.cuda(), volatile=True) for ann in a[1]]\ntargets"

In [15]:
def sample_images(batches_done):
    """Saves a generated sample from the validation set"""
    imgs, bbx = next(iter(data_loader))
    real = Variable(imgs.type(Tensor))
    real = F.upsample(real, scale_factor=1/1.171875, mode='bilinear')
    fake = generator(real)
    img_sample = torch.cat((real.data, fake.data), -1)
    save_image(img_sample, 'images/%s/%s.png' % (dataset_name, batches_done), nrow=5, normalize=False)

In [16]:
prev_time = time.time()
for epoch in range(0, EPOCH):
    for i, (img, targets) in enumerate(data_loader):
        # Model inputs
        if CUDA:
            img = Variable(img.cuda())
            targets = [Variable(ann.cuda(), volatile=True) for ann in targets]
        
        img = F.upsample(img, scale_factor=1/1.171875, mode='bilinear')
        real_src_img = Variable(img.type(Tensor))

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((real_src_img.size(0), *patch))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_src_img.size(0), *patch))), requires_grad=False)

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # GAN loss
        fake_trg_img = generator(real_src_img)
        ssd_input_img = F.upsample(fake_trg_img, scale_factor=1.171875, mode='bilinear')
        
        #ssd_input_img = torch.cat((ssd_input_img, ssd_input_img, ssd_input_img), dim=1)
        ssd_out = net(ssd_input_img)
        optimizer_SSD.zero_grad()
        #targets = torch.FloatTensor(bbx).cuda()
        #targets = targets.unsqueeze(0)
        #targets = targets.unsqueeze(0)
        loss_l, loss_c = criterion(ssd_out, targets)
        loss = loss_l + loss_c
        loss.backward(retain_graph=True)
        optimizer_SSD.step()
        
        pred_fake = discriminator(fake_trg_img, real_src_img)
        loss_GAN = criterion_GAN(pred_fake, valid)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_trg_img, real_src_img)
        #print(loss_pixel, loss_GAN, loss_bbx)

        # Total loss
        #loss_G = loss_GAN*0.5 + loss_pixel * lambda_pixel * 0.5 + loss
        loss_G = loss

        loss_G.backward(retain_graph=True)

        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        '''optimizer_D.zero_grad()

        # Real loss
        pred_real = discriminator(fake_trg_img.detach(), real_src_img)
        loss_real = criterion_GAN(pred_real, valid)

        # Fake loss
        pred_fake = discriminator(fake_trg_img.detach(), real_src_img)
        loss_fake = criterion_GAN(pred_fake, fake)

        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)

        loss_D.backward()
        optimizer_D.step()'''

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(data_loader) + i
        batches_left = EPOCH * len(data_loader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        '''sys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s [loss: %f]" %
                                                        (epoch, EPOCH,
                                                        i, len(data_loader),
                                                        loss_D.item(), loss_G.item(),
                                                        loss_pixel.item(), loss_GAN.item(),
                                                        time_left, 
                                                        loss))'''
        sys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [G loss: %f] [loss: %f] ETA: %s " %
                                                        (epoch, EPOCH,
                                                        i, len(data_loader),
                                                        loss_G.item(),
                                                        loss, time_left,))

        # If at sample interval save image
        if batches_done % sample_interval == 0:
            #save_image(fake_trg_img.data[:1], 'images/%d.png' % batches_done, nrow=5, normalize=True)
            sample_images(batches_done)


    if SAVE_MODEL_EPOCH != -1 and epoch % SAVE_MODEL_EPOCH == 0:
        # Save model checkpoints
        print('Saving state, Epoch:', epoch)
        torch.save(ssd_net.state_dict(),'saved_models/%s/ssd_%d.pth' % (dataset_name, epoch))
        torch.save(generator.state_dict(), 'saved_models/%s/generator_%d.pth' % (dataset_name, epoch))
        torch.save(discriminator.state_dict(), 'saved_models/%s/discriminator_%d.pth' % (dataset_name, epoch))
torch.save(generator.state_dict(), 'saved_models/%s/generator.pth' % (dataset_name))
torch.save(discriminator.state_dict(), 'saved_models/%s/discriminator.pth' % (dataset_name))

  import sys


[Epoch 0/300] [Batch 7291/7292] [G loss: 3.018299] [loss: 3.018299] ETA: 4 days, 6:58:58.467663 Saving state, Epoch: 0
[Epoch 1/300] [Batch 7291/7292] [G loss: 2.761356] [loss: 2.761356] ETA: 4 days, 6:34:01.050696 Saving state, Epoch: 1
[Epoch 2/300] [Batch 7291/7292] [G loss: 3.828712] [loss: 3.828712] ETA: 4 days, 7:31:51.090922 Saving state, Epoch: 2
[Epoch 3/300] [Batch 2486/7292] [G loss: 4.284740] [loss: 4.284740] ETA: 4 days, 7:09:56.505126 7  

KeyboardInterrupt: 