# S-GAN sample test

In [1]:
import argparse
from loss import sganloss
import os
import numpy as np
from dataloader import *
import math

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
# Whether a CUDA device is usable; is_available() already returns a bool,
# so no conditional expression is needed.
cuda = torch.cuda.is_available()
cuda

False

In [3]:
def weights_init_normal(m):
    """Re-initialise a single layer in place; meant to be passed to ``Module.apply``.

    The layer is selected by its class name:

    * a name containing ``"Conv"`` gets weights drawn from N(0.0, 0.02);
    * a name containing ``"BatchNorm"`` gets weights drawn from N(1.0, 0.02)
      and a bias filled with 0.0.

    Any other module is left untouched. Returns ``None``.
    """
    init = torch.nn.init
    layer_kind = type(m).__name__

    if "Conv" in layer_kind:
        init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm" in layer_kind:
        init.normal_(m.weight.data, mean=1.0, std=0.02)
        init.constant_(m.bias.data, val=0.0)

class IdentityPadding(nn.Module):
    """Parameter-free shortcut: append zero channels, then subsample spatially.

    ``forward`` pads ``out_channels - in_channels`` zero-filled channels after
    the existing ones and applies a kernel-size-1 max-pool with the given
    stride.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        # Number of zero-filled channels appended by forward().
        self.add_channels = out_channels - in_channels
        # With kernel size 1 the pool keeps every stride-th position unchanged.
        self.pooling = nn.MaxPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        # F.pad takes (before, after) pairs starting from the last dimension;
        # the third pair targets the third-from-last dimension, which is the
        # channel axis for an (N, C, H, W) input.
        channel_padding = (0, 0, 0, 0, 0, self.add_channels)
        return self.pooling(F.pad(x, channel_padding))

class ResidualBlock(nn.Module):
    """Residual unit: conv-BN-ReLU, conv-BN, add the shortcut, final ReLU.

    Both convolutions are 3x3 with padding 1 and no bias; only the first one
    uses ``stride``. With ``down_sample=True`` the shortcut goes through
    ``IdentityPadding(in_channels, out_channels, stride)``; otherwise the
    input itself is added back.
    """

    def __init__(self, in_channels, out_channels, stride=1, down_sample=False):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.stride = stride

        # None means "add the unmodified input" in forward().
        self.down_sample = (
            IdentityPadding(in_channels, out_channels, stride) if down_sample else None
        )

    def forward(self, x):
        # Main branch.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # Shortcut branch.
        shortcut = x if self.down_sample is None else self.down_sample(x)

        # In-place add on the fresh bn2 output; x itself is not modified.
        residual += shortcut
        return self.relu(residual)

class Generator(nn.Module):
    """Two-stage generator: predict an occlusion mask, then complete the face.

    Sub-networks:

    * ``FaceOcclusion_1`` - encoder / residual / decoder stack producing a
      64-channel feature map at the input resolution.
    * ``FaceOcclusion_2`` - 7x7 conv + sigmoid turning those features into a
      single-channel mask with values in (0, 1).
    * ``FaceCompletion`` - encoder / decoder ending in a 3-channel tanh image.

    ``forward(x)`` returns
    ``(out_predictedM, out_InvertedM, out_synth, out_final)``.

    Assumes ``x`` is ``(N, 3, H, W)`` with ``H`` and ``W`` divisible by 8 so
    the stride-2 down/up-sampling stacks return to the input size - TODO
    confirm against the dataloader (the notebook uses 128x128 images).
    """

    def __init__(self):
        super().__init__()
        self.FaceOcclusion_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            # ----- downsample x4
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(),
            # ----- residual bottleneck
            ResidualBlock(256, 256),
            ResidualBlock(256, 256),
            ResidualBlock(256, 256),
            # ----- upsample x4
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
        )
        self.FaceOcclusion_2 = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=7, stride=1, padding=3),
            nn.Sigmoid(),
        )

        self.FaceCompletion = nn.Sequential(
            # ----- downsample x8
            # num_features now matches the preceding conv's output channels
            # (the first two norms used to be declared with 512). These norms
            # have no affine parameters, so the state_dict is unaffected.
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.ReLU(),
            # ----- upsample x8
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            # ----- to RGB
            nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run both stages on ``x``.

        Returns:
            tuple: ``(out_predictedM, out_InvertedM, out_synth, out_final)``
            where ``out_predictedM`` is the sigmoid mask, ``out_InvertedM`` is
            ``1 - out_predictedM``, ``out_synth`` is the completion network's
            raw output and
            ``out_final = x * out_predictedM + out_InvertedM * out_synth``.
        """
        # Occlusion-aware module.
        features = self.FaceOcclusion_1(x)
        out_predictedM = self.FaceOcclusion_2(features)

        # Derive the complement from the mask itself so it always shares the
        # mask's shape, dtype and device (a fixed CPU 1x1x128x128 tensor of
        # ones breaks on CUDA and on any other image size).
        out_InvertedM = 1.0 - out_predictedM

        # The mask is applied per pixel: the single mask channel broadcasts
        # over the feature/image channels. torch.matmul would instead compute
        # a matrix product of each HxW plane with the mask plane.
        out_oa = features * out_predictedM

        # Face completion module.
        out_synth = self.FaceCompletion(out_oa)
        out_fc = out_InvertedM * out_synth
        out_filter = x * out_predictedM
        out_final = out_filter + out_fc

        return out_predictedM, out_InvertedM, out_synth, out_final

In [4]:

class weight():
    """Scalar coefficients for combining the generator loss terms.

    In this notebook's generator-loss cell ``lam1``..``lam6`` scale, in order,
    the perceptual, style, pixel, smooth, mask L2-norm and adversarial terms,
    while ``alpha`` and ``beta`` are forwarded to ``pixel_loss``.
    """

    def __init__(self):
        # Per-term multipliers, lam1 through lam6.
        (self.lam1, self.lam2, self.lam3,
         self.lam4, self.lam5, self.lam6) = (0.1, 0.2, 0.2, 0.2, 0.1, 0.2)
        # Extra coefficients handed to pixel_loss.
        self.alpha = self.beta = 0.5


w = weight()

In [5]:
class Discriminator(nn.Module):
    """Convolutional discriminator with an adversarial head and a 10-way head.

    ``discriminator_block`` is six 4x4 stride-2 convolutions (3 -> 2048
    channels), each followed by LeakyReLU, so every spatial side is divided
    by 64. On a 128x128 input that leaves a 2x2 map, giving:

    * ``validity`` - ``(N, 1, 2, 2)`` sigmoid scores from ``adv_layer``;
    * ``label``    - ``(N, 10, 1, 1)`` softmax scores from ``attr_layer``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.discriminator_block = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(1024, 2048, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU()
        )

        self.adv_layer = nn.Sequential(
            nn.Conv2d(2048, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )
        # This head performs face recognition rather than attribute
        # classification. The softmax runs over the channel axis; stating
        # dim=1 explicitly replaces the deprecated implicit-dim behaviour,
        # which selects the same axis for 4-D inputs.
        self.attr_layer = nn.Sequential(
            nn.Conv2d(2048, 10, kernel_size=2, stride=1, padding=0),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return ``(validity, label)`` for a batch of images ``x``."""
        out = self.discriminator_block(x)
        validity = self.adv_layer(out)
        label = self.attr_layer(out)

        return validity, label

# Loss functions used by the training cells below.
adversarial_loss = nn.BCELoss()
attribute_loss = nn.MSELoss()  # attribute loss used for the discriminator
# NOTE(review): the training cells call attribute_loss with integer class
# labels as the target - confirm MSELoss is the intended criterion.

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Move the networks and the loss modules to the GPU when one is available.
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    attribute_loss.cuda()

# Initialize weights
# apply() returns the module, so the last line is echoed as the cell output.
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

Discriminator(
  (discriminator_block): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.01)
    (6): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): LeakyReLU(negative_slope=0.01)
    (8): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): LeakyReLU(negative_slope=0.01)
    (10): Conv2d(1024, 2048, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (11): LeakyReLU(negative_slope=0.01)
  )
  (adv_layer): Sequential(
    (0): Conv2d(2048, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Sigmoid()
  )
  (attr_layer): Sequential(
    (0): Conv2d(2048, 10, kernel_size=(2, 2), stride=(1, 1))
    (1): Softmax(dim=None)

## Test with a sample batch from the DataLoader

In [6]:
def show(img, y, color=False):
    """Preview two channel-first tensors next to each other.

    Each tensor is converted with ``.numpy()``, transposed from
    (C, H, W) to (H, W, C) and drawn into slots 1 and 2 of a 2x2
    subplot grid. ``color`` is accepted but not used.
    """
    first, second = img.numpy(), y.numpy()
    panels = [np.transpose(chw, (1, 2, 0)) for chw in (first, second)]
    for slot, hwc in enumerate(panels, start=1):
        plt.subplot(2, 2, slot)
        plt.imshow(hwc)

# Paired dataset served in shuffled batches of three, loaded in the main
# process (num_workers=0).
OAGan_dataset = OAGandataset(paired=True, folder_numbering=False)
train_dataloader = DataLoader(
    OAGan_dataset,
    batch_size=3,
    shuffle=True,
    num_workers=0,
)

# Pull a single batch to experiment with in the cells below.
dataiter = iter(train_dataloader)
example_batch = next(dataiter)

In [None]:
# Preview sample 0 of the batch: example_batch[0] and example_batch[1] are the
# two image tensors (later unpacked as imgs / imgs_gt), example_batch[2] the labels.
show(example_batch[0][0],example_batch[1][0])
print(example_batch[2][0])

In [None]:
# Preview sample 1 of the batch (image pair from example_batch[0] / [1]) and print its label.
show(example_batch[0][1],example_batch[1][1])
print(example_batch[2][1])

In [None]:
# Preview sample 2 of the batch (image pair from example_batch[0] / [1]) and print its label.
show(example_batch[0][2],example_batch[1][2])
print(example_batch[2][2])

In [7]:
# One Adam optimizer per network, both with lr=1e-4 and betas=(0.5, 0.999).
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

# Tensor-type aliases that resolve to the CUDA variants when `cuda` is set;
# the cells below use them to build targets and to cast the batch.
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

In [8]:
# Unpack the sample batch into its two image tensors and the labels.
imgs,imgs_gt,labels=example_batch

batch_size = imgs.shape[0]

# Adversarial ground truths
# (batch, 1, 2, 2) targets: all ones for "real", all zeros for "fake".
valid = Variable(FloatTensor(batch_size, 1, 2, 2).fill_(1.0), requires_grad=False)
fake = Variable(FloatTensor(batch_size, 1, 2, 2).fill_(0.0), requires_grad=False)
# NOTE(review): generated images get class index 10, while the discriminator's
# attr head has 10 output channels (indices 0-9) - confirm the intended encoding.
fake_attr_gt = Variable(LongTensor(batch_size).fill_(10), requires_grad=False)

# Configure input
real_imgs = Variable(imgs.type(FloatTensor))
labels = Variable(labels.type(LongTensor))

# -----------------
#  Train Generator
# -----------------

optimizer_G.zero_grad()

# Noise-vector sampling from the template this cell was adapted from is
# disabled; the generator is fed the batch images instead.
# z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))

# Generate a batch of images
# gen_imgs = generator(z)
print("real_imgs: ", real_imgs.shape) # x_occ

real_imgs:  torch.Size([3, 3, 128, 128])


In [None]:
def show1(img, y, color=False):
    """Preview two channel-first arrays next to each other.

    Both inputs are handed straight to ``np.transpose`` with no tensor
    conversion, then drawn as (H, W, C) images into slots 1 and 2 of a
    2x2 subplot grid. ``color`` is accepted but not used.
    """
    panels = [np.transpose(chw, (1, 2, 0)) for chw in (img, y)]
    for slot, hwc in enumerate(panels, start=1):
        plt.subplot(2, 2, slot)
        plt.imshow(hwc)
    
#out_synth[0]

In [9]:
# Run the generator on the sample batch and build the loss object from its outputs.
out_predictionM, out_InvertedM, out_synth, out_final = generator(real_imgs)
# sganloss receives one list: final image, predicted mask, inverted mask,
# ground-truth image, synthesized image.
# NOTE(review): the RuntimeError recorded under this cell reports a conv weight
# of size [64, 3, 3, 3] receiving a 1-channel input. No conv in this notebook's
# Generator/Discriminator is 3x3 with 64 outputs, so the failing layer is
# presumably inside sganloss and is being fed one of the 1-channel masks - confirm.
loss = sganloss([out_final,
                out_predictionM,
                out_InvertedM,
                imgs_gt,
                out_synth])
                

#print("gen_imgs: ", gen_imgs.shape)
# Shape report for every generator output and for the ground truth.
print("predictM: ",out_predictionM.shape)
print("out_InvertedM: ", out_InvertedM.shape) 
print("out_synth: ", out_synth.shape)
print("out_final: ", out_final.shape)
print ("imgs_gt:", imgs_gt.shape)

RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[3, 1, 128, 128] to have 3 channels, but got 1 channels instead

In [None]:
# Preview the predicted mask and its inverse for sample 0 (detached from autograd, converted to arrays).
show1(np.array(out_predictionM[0].detach()),np.array(out_InvertedM[0].detach()))

In [None]:
# Preview the synthesized image and the final composite for sample 0.
show1(np.array(out_synth[0].detach()),np.array(out_final[0].detach()))
#print(example_batch[2][2])

In [None]:
# Score the completed images with the discriminator. `validity` is required by
# the adversarial term below; with this call commented out the cell stopped
# with a NameError.
validity, _ = discriminator(out_final)
#print('validity shape: ', validity.shape)
#print('valid shape: ', valid.shape)

# Weighted sum of the generator loss terms (weights come from `w`).
g_loss = 0
g_loss += w.lam1*loss.perceptual_loss(out_synth, out_final, imgs_gt)
g_loss += w.lam2*loss.style_loss(out_synth, out_final, imgs_gt)
g_loss += w.lam3*loss.pixel_loss(out_final, imgs_gt, out_InvertedM, out_predictionM, w.alpha, w.beta)
g_loss += w.lam4*loss.smooth_loss(out_final, imgs_gt, out_predictionM)
g_loss += w.lam5*loss.l2_norm(out_predictionM)
# NOTE(review): an earlier draft of this cell used the module-level
# `adversarial_loss` (nn.BCELoss) here instead of a method on `loss` - confirm
# that sganloss really provides `adversarial_loss`.
g_loss += w.lam6*loss.adversarial_loss(validity,valid)

# Earlier draft of the weight assignment, kept for reference:
#   lam1 -> pixel, lam2 -> smooth, lam3 -> perceptual, lam4 -> style,
#   lam5 -> l2_norm, lam6 -> adversarial_loss(validity, valid)

# The optimisation step is left disabled in this exploratory cell.
#g_loss.backward()
#optimizer_G.step()

In [None]:
# Echo the accumulated generator loss as the cell output.
g_loss

In [None]:
# Single-batch dry run of one complete training step (generator update, then
# discriminator update) on `example_batch`.
imgs, imgs_gt, labels = example_batch

# This cell runs outside the epoch/batch loops, so give the progress message
# and the checkpoint counter concrete values to work with.
epoch, i = 0, 0
dataloader = train_dataloader

batch_size = imgs.shape[0]

# Adversarial ground truths
valid = Variable(FloatTensor(batch_size, 1, 2, 2).fill_(1.0), requires_grad=False)
fake = Variable(FloatTensor(batch_size, 1, 2, 2).fill_(0.0), requires_grad=False)
fake_attr_gt = Variable(LongTensor(batch_size).fill_(10), requires_grad=False)

# Configure input
real_imgs = Variable(imgs.type(FloatTensor))
# Keep the ground truth on the same tensor type as the generator output.
# Assumes imgs_gt is an image tensor like imgs - TODO confirm.
imgs_gt = Variable(imgs_gt.type(FloatTensor))
labels = Variable(labels.type(LongTensor))

# -----------------
#  Train Generator
# -----------------

optimizer_G.zero_grad()

print("real_imgs: ", real_imgs.shape)
out_predictionM, out_InvertedM, out_synth, out_final = generator(real_imgs)
# The completed image is what the discriminator judges and what gets saved.
gen_imgs = out_final
print("gen_imgs: ", gen_imgs.shape)

# Loss measures generator's ability to fool the discriminator
validity, _ = discriminator(gen_imgs)
print('validity shape: ', validity.shape)
print('valid shape: ', valid.shape)

# Start from a fresh value (no leftover g_loss from another cell) and pass the
# ground truth to perceptual_loss, as the generator-loss cell above does.
g_loss = loss.perceptual_loss(out_synth, out_final, imgs_gt) + adversarial_loss(validity, valid)

g_loss.backward()
optimizer_G.step()

# ---------------------
#  Train Discriminator
# ---------------------

optimizer_D.zero_grad()

# d_alpha and d_beta weight the discriminator's two loss terms; their values
# still have to be chosen.
d_alpha = 0.5
d_beta = 0.5

# NOTE(review): attribute_loss is nn.MSELoss but receives integer class labels
# as the target for the attr head's 10-channel softmax output, and
# fake_attr_gt uses index 10, which a 10-way head cannot represent. Settle the
# target encoding before relying on this term.

# Loss for real images
real_pred, real_attr = discriminator(real_imgs)
d_real_loss = d_alpha * adversarial_loss(real_pred, valid) + d_beta * attribute_loss(real_attr, labels)

# Loss for fake images
fake_pred, fake_attr = discriminator(gen_imgs.detach())
d_fake_loss = d_alpha * adversarial_loss(fake_pred, fake) + d_beta * attribute_loss(fake_attr, fake_attr_gt)

# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2

# Calculate discriminator accuracy
pred = np.concatenate([real_attr.data.cpu().numpy(), fake_attr.data.cpu().numpy()], axis=0)
# The attr head is convolutional, so its output is 4-D. Flatten to
# (samples, scores) so argmax yields one class per sample instead of
# broadcasting against gt. Assumes a 1x1 spatial extent (128x128 inputs).
pred = pred.reshape(pred.shape[0], -1)
gt = np.concatenate([labels.data.cpu().numpy(), fake_attr_gt.data.cpu().numpy()], axis=0)
d_acc = np.mean(np.argmax(pred, axis=1) == gt)

# Cast kept from the original cell.
d_loss = d_loss.type(torch.FloatTensor)
d_loss.backward()
optimizer_D.step()

print(
    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
    % (epoch, 100, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
)

batches_done = epoch * len(dataloader) + i
if batches_done % 400 == 0:
    # Make sure the sample directory exists before writing into it.
    os.makedirs("images", exist_ok=True)
    save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)


# net = SiameseNetwork().cuda()
# criterion = ContrastiveLoss()
# optimizer = optim.Adam(net.parameters(),lr = 0.0005 )
# counter = []
# loss_history = []
# iteration_number= 0


---------------

In [None]:
# Paired dataset for the training loop; the unpaired variant is not enabled yet.
paired_dataset = OAGandataset(paired=True, folder_numbering=False)
#unpaired_dataset = OAGandataset(unpaired=True, folder_numbering=False)

train_dataloader_p = DataLoader(paired_dataset,
                                shuffle=True,
                                num_workers=0,
                                batch_size= 5) # batch size still to be decided
# #train_dataloader_up = DataLoader(unpaired_dataset,
#                             shuffle=True,
#                             num_workers=0,
#                             batch_size=30)

# Fresh optimizers for the training run (same settings as the earlier cell).
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

# Tensor-type aliases that resolve to the CUDA variants when `cuda` is set.
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

# NOTE(review): sganloss is built without arguments here, whereas the earlier
# cell passes it a list of generator outputs - confirm which call matches its
# constructor.
loss = sganloss()

In [None]:
# Walk the paired dataloader once, printing each batch's labels and the
# ground-truth image shape.
for i, (imgs, imgs_gt, labels) in enumerate(train_dataloader_p):
    print(labels)
    print(imgs_gt.shape)

In [None]:
# Paired-image training loop. An unpaired variant still has to be written,
# with the losses applied according to the case.

# Make sure the sample directory exists before save_image writes into it.
os.makedirs("images", exist_ok=True)

num_epochs = 100

# d_alpha and d_beta weight the discriminator's two loss terms; their values
# still have to be chosen. They do not change per batch, so set them once.
d_alpha = 0.5
d_beta = 0.5

for epoch in range(num_epochs):
    for i, (imgs, imgs_gt, labels) in enumerate(train_dataloader_p):
        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1, 2, 2).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1, 2, 2).fill_(0.0), requires_grad=False)
        fake_attr_gt = Variable(LongTensor(batch_size).fill_(10), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        # Keep the ground truth on the same tensor type as the generator output.
        # Assumes imgs_gt is an image tensor like imgs - TODO confirm.
        imgs_gt = Variable(imgs_gt.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        print("real_imgs: ", real_imgs.shape)
        out_predictionM, out_InvertedM, out_synth, out_final = generator(real_imgs)
        # The completed image is what the discriminator judges and what gets saved.
        gen_imgs = out_final
        print("gen_imgs: ", gen_imgs.shape)

        # Loss measures generator's ability to fool the discriminator
        validity, _ = discriminator(gen_imgs)
        print('validity shape: ', validity.shape)
        print('valid shape: ', valid.shape)

        # Build the loss from scratch every batch: accumulating with `+=`
        # would chain each batch onto the previous one. perceptual_loss gets
        # the ground truth as third argument, matching its other call in this
        # notebook.
        g_loss = loss.perceptual_loss(out_synth, out_final, imgs_gt) + adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # NOTE(review): attribute_loss is nn.MSELoss but receives integer class
        # labels as the target for the attr head's 10-channel softmax output,
        # and fake_attr_gt uses index 10, which a 10-way head cannot represent.
        # Settle the target encoding before relying on this term.

        # Loss for real images
        real_pred, real_attr = discriminator(real_imgs)
        d_real_loss = d_alpha * adversarial_loss(real_pred, valid) + d_beta * attribute_loss(real_attr, labels)

        # Loss for fake images
        fake_pred, fake_attr = discriminator(gen_imgs.detach())
        d_fake_loss = d_alpha * adversarial_loss(fake_pred, fake) + d_beta * attribute_loss(fake_attr, fake_attr_gt)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Calculate discriminator accuracy
        pred = np.concatenate([real_attr.data.cpu().numpy(), fake_attr.data.cpu().numpy()], axis=0)
        # The attr head is convolutional, so its output is 4-D. Flatten to
        # (samples, scores) so argmax yields one class per sample instead of
        # broadcasting against gt. Assumes a 1x1 spatial extent (128x128 inputs).
        pred = pred.reshape(pred.shape[0], -1)
        gt = np.concatenate([labels.data.cpu().numpy(), fake_attr_gt.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        # Cast kept from the original loop.
        d_loss = d_loss.type(torch.FloatTensor)
        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
            % (epoch, num_epochs, i, len(train_dataloader_p), d_loss.item(), 100 * d_acc, g_loss.item())
        )

        # Save a grid of generated samples every 400 batches.
        batches_done = epoch * len(train_dataloader_p) + i
        if batches_done % 400 == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
