In [1]:
import argparse
import itertools

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import hiddenlayer as hl
import graphviz
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [3]:
import glob
import random
import os
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

In [4]:
import visdom
import numpy as np
vis = visdom.Visdom()
vis.text('Hello, world!')
vis.image(np.ones((3, 10, 10)))



'window_3753268d5a6c5a'

In [13]:
import random
import time
import datetime
import sys

from torch.autograd import Variable
import torch
from visdom import Visdom
import numpy as np

def tensor2image(tensor):
    image = 127.5*(tensor[0].cpu().float().numpy() + 1.0)
    if image.shape[0] == 1:
        image = np.tile(image, (3,1,1))
    return image.astype(np.uint8)

class Logger():
    def __init__(self, n_epochs, batches_epoch):
        self.viz = Visdom()
        self.n_epochs = n_epochs
        self.batches_epoch = batches_epoch
        self.epoch = 1
        self.batch = 1
        self.prev_time = time.time()
        self.mean_period = 0
        self.losses = {}
        self.loss_windows = {}
        self.image_windows = {}


    def log(self, losses=None, images=None):
        self.mean_period += (time.time() - self.prev_time)
        self.prev_time = time.time()

        sys.stdout.write('\rEpoch %03d/%03d [%04d/%04d] -- ' % (self.epoch, self.n_epochs, self.batch, self.batches_epoch))

        for i, loss_name in enumerate(losses.keys()):
            if loss_name not in self.losses:
                self.losses[loss_name] = losses[loss_name].data[0]
            else:
                self.losses[loss_name] += losses[loss_name].data[0]

            if (i+1) == len(losses.keys()):
                sys.stdout.write('%s: %.4f -- ' % (loss_name, self.losses[loss_name]/self.batch))
            else:
                sys.stdout.write('%s: %.4f | ' % (loss_name, self.losses[loss_name]/self.batch))

        batches_done = self.batches_epoch*(self.epoch - 1) + self.batch
        batches_left = self.batches_epoch*(self.n_epochs - self.epoch) + self.batches_epoch - self.batch 
        sys.stdout.write('ETA: %s' % (datetime.timedelta(seconds=batches_left*self.mean_period/batches_done)))

        # Draw images
        for image_name, tensor in images.items():
            if image_name not in self.image_windows:
                self.image_windows[image_name] = self.viz.image(tensor2image(tensor.data), opts={'title':image_name})
            else:
                self.viz.image(tensor2image(tensor.data), win=self.image_windows[image_name], opts={'title':image_name})

        # End of epoch
        if (self.batch % self.batches_epoch) == 0:
            # Plot losses
            for loss_name, loss in self.losses.items():
                if loss_name not in self.loss_windows:
                    self.loss_windows[loss_name] = self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]), 
                                                                    opts={'xlabel': 'epochs', 'ylabel': loss_name, 'title': loss_name})
                else:
                    self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]), win=self.loss_windows[loss_name], update='append')
                # Reset losses for next epoch
                self.losses[loss_name] = 0.0

            self.epoch += 1
            self.batch = 1
            sys.stdout.write('\n')
        else:
            self.batch += 1
            
class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, mode='train'):
        transforms_ = [ transforms.Resize(int(143), Image.BICUBIC), 
                transforms.RandomCrop(128), 
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) 
              ]
        
        self.transform = transforms.Compose(transforms_)        
        #content image source
        self.X = sorted(glob.glob(os.path.join(root, f'{mode}Content', '*')))
        
        self.Y = []
        #style image source(s)
        style_sources = sorted(glob.glob(os.path.join(root, f'{mode}Styles', '*')))
        for label,style in enumerate(style_sources):
            temp = [(label,x) for x in sorted(glob.glob(style_sources[label]+"/*"))]
            self.Y.extend(temp)
        
        
    def __getitem__(self,index):
        output = {}
        output['content'] = self.transform(Image.open(self.X[index % len(self.X)]))
        
        #select style
        selection = self.Y[random.randint(0, len(self.Y) - 1)]
        output['style'] = self.transform(Image.open(selection[1]))
        output['style_label'] = selection[0]
    
        return output
    
    def __len__(self):
        return max(len(self.X), len(self.Y))
        
        
        
class ReplayBuffer():
    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0,1) > 0.5:
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))

class LambdaLR():
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant(m.bias.data, 0.0)
        

In [14]:
#https://github.com/jxgu1016/Total_Variation_Loss.pytorch/blob/master/TVLoss.py
class TVLoss(nn.Module):
    def __init__(self,TVLoss_weight=tv_strength or None):
        super(TVLoss,self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self,x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self._tensor_size(x[:,:,1:,:])
        count_w = self._tensor_size(x[:,:,:,1:])
        h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
        w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
        return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size

    def _tensor_size(self,t):
        return t.size()[1]*t.size()[2]*t.size()[3]
    
def label2tensor(label,tensor):
    for i in range(label.size(0)):
        tensor[i].fill_(label[i])
    return tensor

In [15]:
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
    def forward(self, x):
        return x
    
    
class ResidualBlock(nn.Module):
    def __init__(self,in_features):
        super(ResidualBlock,self).__init__()
        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x) #skip connection    

class Encoder(nn.Module):    
    def __init__(self, in_nc, ngf=64):
        super(Encoder, self).__init__()
        
        #Inital Conv Block
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(in_nc, ngf, 7),
                    nn.InstanceNorm2d(ngf),
                    nn.ReLU(inplace=True) ]
        
        in_features = ngf
        out_features = in_features *2
        
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)    
            ]
            
            in_features = out_features
            out_features = in_features * 2
            
        self.model = nn.Sequential(*model)
        
    def forward(self,x):
        #Return batch w/ encoded content picture
        return [self.model(x['content']),
               x['style_label']]
    
class Transformer(nn.Module):
    def __init__(self,n_styles, ngf,auto_id=True):
        super(Transformer, self).__init__()
        
        #nclasses = input_nclasses
        self.t = nn.ModuleList([ResidualBlock(ngf*4) for i in range(n_styles)])
        if auto_id:
            self.t.append(Identity())
        #self.i = Identity()
                
    def forward(self,x):
        #x0 is content, x1 is label 
        label = x[1][0]
#         print(label)
#         print(len(label))
#         print(label.shape)
        mix = sum([self.t[i](x[0])*v for (i,v) in enumerate(label)])
        #return content transformed by style specific residual block 
        return mix
        
class Decoder(nn.Module):
    def __init__(self, out_nc, ngf, n_residual_blocks=9):
        super(Decoder, self).__init__()
        
        in_features = ngf * 4
        out_features = in_features//2
        
        model = []
        #ResBlockLand
        
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        for _ in range(2):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, out_nc, 7),
                    nn.Tanh() ]
        
        self.model = nn.Sequential(*model)
        
    def forward(self,x):
        return self.model(x)
        
class Generator(nn.Module):
    def __init__(self,in_nc,out_nc,n_styles,ngf):
        super(Generator, self).__init__()
        
        self.encoder = Encoder(in_nc,ngf)
        self.transformer = Transformer(n_styles,ngf)
        self.decoder = Decoder(out_nc,ngf)
        
    def forward(self,x):
        #Pass generator batch of {content=,style=?,style_label=}
        print('generator shape/style in: ',x['content'].shape, x['style_label'])
        
        e = self.encoder(x)
        print(e[0].shape,e[1])
        t = self.transformer(e)
        print(t.shape)
        d = self.decoder(t)
        print(d.shape)
        return d
    
class Discriminator(nn.Module):
    def __init__(self, input_nc, n_styles, ndf=64):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        model = [   nn.Conv2d(in_nc, 64, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(64, 128, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(128), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(256, 512, 4, padding=1),
                    nn.InstanceNorm2d(512), 
                    nn.LeakyReLU(0.2, inplace=True) ]
        
        
        self.model = nn.Sequential(*model)

        # FCN classification layer-
        self.fldiscriminator = nn.Conv2d(512, 1, 4, padding=1)
        
        # aux class layer
        self.aux_clf = nn.Conv2d(512, n_styles, 4, padding = 4)

    def forward(self, x):
        base =  self.model(x)
        print('base shape: ',base.shape)
        discrim = self.fldiscriminator(base)
        print('init discrim shape: ',discrim.shape)
        # Average pooling and flatten
        discrim=F.avg_pool2d(discrim, discrim.size()[2:]).view(discrim.size()[0], -1) 
        print('preview discrim: ',discrim)
        print('preview discrim shape: ',discrim.size())
        discrim = discrim.view(-1)
        print('discrim shape: ',discrim.size())
        print('discrim: ',discrim)
        clf = self.aux_clf(base).transpose_(1,3)
        print('clf shape: ',clf.shape)
        print('clf transpose: ',clf.transpose_(1,3).shape)
        
        return [discrim,clf.transpose_(1,3)]
            

In [None]:
#TRAIN OPTIONS FROM GATED GAN
epoch = 0
n_epochs = 10 #default = 200
batchSize = 1
dataroot = './photo2fourcollection'
batchSize = 1
loadSize = 143
fineSize = 128
ngf = 64
ndf = 64    
in_nc = 3 
out_nc = 3 
niter = 100  
niter_decay = 100 
lr = 0.0002 
beta1 = 0.5 
#ntrain = math.huge 
flip = 1  
display_id = 10 
display_winsize = 128 
display_freq = 25 
gpu = 1 
name = ''   
which_direction = 'AtoB'
phase = 'train'
nThreads = 2
save_epoch_freq = 1
save_latest_freq = 5000 
print_freq = 50
save_display_feq = 2500
continue_train = 0
serial_batches = 0
checkpoints_dir = './checkpoints'
cudnn = 1
which_model_netD = 'basic'
which_model_netG = 'auto_gated_resnet_6blocks'
norm = 'instance'
n_layers_D = 3
lambda_A = 10.0
lambda_B = 10.0
model = 'gated_gan'
use_lsgan = 1
align_data = 0
pool_size = 50
resize_or_crop = 'resize_and_crop'
autoencoder_constrain = 10 
n_styles = 4
test_data_path = ''
decay_epoch=1
cuda=False
tv_strength=1e-6


In [17]:
dataloader = DataLoader(ImageDataset('./photo2fourcollection'), 
                        batch_size=1, shuffle=True)
batch = next(iter(dataloader))


In [18]:
generator = Generator(in_nc, out_nc, n_styles, ngf)
discriminator= Discriminator(in_nc,n_styles, ndf)

In [27]:
#Losses Init
use_lsgan=True
if use_lsgan:
    criterion_GAN = nn.MSELoss()
else: 
    criterion_GAN = nn.BCELoss()
    
    
criterion_ACGAN = nn.CrossEntropyLoss(weight=None)
criterion_Rec = nn.L1Loss()
criterion_Enc = nn.MSELoss()
criterion_TV = TVLoss(TVLoss_weight=tv_strength)




In [28]:
#Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(generator.parameters(),
                                lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), 
                               lr=lr, betas=(0.5, 0.999))


lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch,decay_epoch).step)
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=LambdaLR(n_epochs,epoch, decay_epoch).step)



In [43]:
#Set vars for training
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
input_A = Tensor(batchSize, in_nc, fineSize, fineSize)
input_B = Tensor(batchSize, out_nc, fineSize, fineSize)
target_real = Variable(Tensor(batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(batchSize).fill_(0.0), requires_grad=False)

D_A_size = discriminator(real_style)[0].size()  
D_AC_size = discriminator(real_style)[1].size()

class_label_B = torch.Tensor(D_AC_size[0],D_AC_size[1],D_AC_size[2]).long()

autoflag_OHE = torch.Tensor(1,n_styles+1).fill_(0).long()
autoflag_OHE[0][-1] = 1

fake_label = torch.Tensor(D_A_size).fill_(0.0)
real_label = torch.Tensor(D_A_size).fill_(0.9) 

rec_A_AE = torch.Tensor(batchSize,in_nc,fineSize,fineSize)

fake_buffer = ReplayBuffer()

##INIT WEIGHTS
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

base shape:  torch.Size([1, 512, 15, 15])
init discrim shape:  torch.Size([1, 1, 14, 14])
preview discrim:  tensor([[0.3353]], grad_fn=<ViewBackward>)
preview discrim shape:  torch.Size([1, 1])
discrim shape:  torch.Size([1])
discrim:  tensor([0.3353], grad_fn=<ViewBackward>)
clf shape:  torch.Size([1, 20, 20, 4])
clf transpose:  torch.Size([1, 4, 20, 20])
base shape:  torch.Size([1, 512, 15, 15])
init discrim shape:  torch.Size([1, 1, 14, 14])
preview discrim:  tensor([[0.3353]], grad_fn=<ViewBackward>)
preview discrim shape:  torch.Size([1, 1])
discrim shape:  torch.Size([1])
discrim:  tensor([0.3353], grad_fn=<ViewBackward>)
clf shape:  torch.Size([1, 20, 20, 4])
clf transpose:  torch.Size([1, 4, 20, 20])


Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace)
  )
  (fldiscriminator): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (aux_clf): Conv2d(512, 4, kernel_size=(4, 4), stride=(1, 1), padding=(4, 4))
)

In [44]:
real_content = Variable(input_A.copy_(batch['content']))
real_style = Variable(input_B.copy_(batch['style']))
style_label = batch['style_label']
style_OHE = F.one_hot(style_label,n_styles).long()


In [48]:
logger = Logger(n_epochs, len(dataloader))




In [49]:
###FUCKIN TRAIN TIME BABY
for epoch in range(epoch,n_epochs):
    for i, batch in enumerate(dataloader):
        real_content = Variable(input_A.copy_(batch['content']))
        real_style = Variable(input_B.copy_(batch['style']))
        style_label = batch['style_label']
        style_OHE = F.one_hot(style_label,n_styles).long()
        
        #### GENERATOR FORWARD
        optimizer_G.zero_grad()
        #Auto Encoder Reconstruction Loss
        rec = generator({
            'content':real_content,
            'style_label': autoflag_OHE # 1-n_styles +1 ((nstyles))
        })
        errRec = criterion_Rec(rec,real_content)

        #Gan Loss
        genfake = generator({
            'content':real_content,
            'style_label': style_OHE
        })
        output = discriminator(genfake)
        errG = criterion_GAN(output[0],torch.tensor([0.]))

        #https://github.com/pytorch/pytorch/issues/29
        #Aux Class Loss
        errG_AC = criterion_ACGAN(output[1].transpose_(1,3),class_label_B)

        tvloss = criterion_TV(genfake)
        errG_total = errRec*autoencoder_constrain + errG_AC*lambda_A + errG + tvloss

        errG_total.backward()
        optimizer_G.step()
        
        #### DISCRIMINATOR FORWARD
        optimizer_D.zero_grad()

        #Real Loss
        output = discriminator(real_style)
        errD_real = criterion_GAN(output[0],torch.tensor([.9]))

        errD_real_class = criterion_ACGAN(output[1].transpose(1,3),class_label_B)

        #Fake Loss
        fake = fake_buffer.push_and_pop(genfake)
        out_real, out_class = discriminator(fake)
        errD_fake = criterion_GAN(out_real,torch.tensor([0.]))

        errD = ((errD_real+errD_fake)/2.0)+errD_real_class
        errD.backward()

        optimizer_D.step()
        
        #Progress report (port 8097)
        logger.log({'loss_G': errG_total, 'loss_G_AE': errRec, 'loss_G_GAN': errG,
                    'loss_G_AC': errG_AC, 'loss_D': errD, 'errD_fake':errD_fake,
                   'errD_real':errD_real, 'errD_class': errD_real_class}, 
                    images={'content': real_content, 'style': real_style, 'transfer': genfake})
    
    ##update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    
    #Save model
    torch.save(netG.state_dict(), 'output/netG.pth')
    torch.save(netD.state_dict(), 'output/netD.pth')

        


generator shape/style in:  torch.Size([1, 3, 128, 128]) tensor([[0, 0, 0, 0, 1]])
torch.Size([1, 256, 32, 32]) tensor([[0, 0, 0, 0, 1]])
torch.Size([1, 256, 32, 32])
torch.Size([1, 3, 128, 128])
generator shape/style in:  torch.Size([1, 3, 128, 128]) tensor([[0, 1, 0, 0]])
torch.Size([1, 256, 32, 32]) tensor([[0, 1, 0, 0]])
torch.Size([1, 256, 32, 32])
torch.Size([1, 3, 128, 128])
base shape:  torch.Size([1, 512, 15, 15])
init discrim shape:  torch.Size([1, 1, 14, 14])
preview discrim:  tensor([[0.5494]], grad_fn=<ViewBackward>)
preview discrim shape:  torch.Size([1, 1])
discrim shape:  torch.Size([1])
discrim:  tensor([0.5494], grad_fn=<ViewBackward>)
clf shape:  torch.Size([1, 20, 20, 4])
clf transpose:  torch.Size([1, 4, 20, 20])
base shape:  torch.Size([1, 512, 15, 15])
init discrim shape:  torch.Size([1, 1, 14, 14])
preview discrim:  tensor([[1.2272]], grad_fn=<ViewBackward>)
preview discrim shape:  torch.Size([1, 1])
discrim shape:  torch.Size([1])
discrim:  tensor([1.2272], grad

IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number