In [None]:
import sys

# Directory added to the import search path so project-local modules can be
# imported (presumably where `data.py` lives -- confirm against the Drive layout).
PROJECT_DIR = "/content/drive/MyDrive/SNU"

# Guarded so that re-running this cell does not keep appending duplicate entries.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

In [None]:
import pickle
from __future__ import print_function
import itertools
import math

import torch
from torch import optim
import torchvision
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.utils as vutils
import torchvision.transforms as transforms
from torch.autograd import Variable

import os
import numpy as np
from data import get_trn_loader, get_val_loader, get_test_loader

In [None]:
# Choose the compute device once: CUDA when torch reports a usable GPU, CPU otherwise.
is_cuda = torch.cuda.is_available()
device = torch.device('cuda') if is_cuda else torch.device('cpu')

In [None]:
# Build the train / validation / test loaders with the helpers imported from `data`.
trn_loader = get_trn_loader()
val_loader = get_val_loader()
test_loader = get_test_loader()

  cpuset_checked))


In [None]:
def toYUV(rgb):
    """Convert a single RGB image tensor to YUV.

    Args:
        rgb: ``torch.Tensor`` indexed as ``[channel, row, col]`` with channels
            ordered R, G, B. It is converted with ``.numpy()``, so it must be a
            CPU tensor that does not require grad.

    Returns:
        ``torch.Tensor`` of shape ``(3, H, W)`` with channels ordered Y, U, V.
        The spatial size follows the input instead of being fixed to 64 x 64.
    """
    rgb = rgb.numpy()
    R, G, B = rgb[0, :, :], rgb[1, :, :], rgb[2, :, :]
    Y = 0.299 * R + 0.587 * G + 0.114 * B
    # U is the blue-difference channel: the 0.436 weight belongs to B.
    # (It was previously applied to G, which dropped blue from U entirely and
    # made `toRGB` unable to invert this transform.)
    U = -0.147 * R + -0.289 * G + 0.436 * B
    V = 0.615 * R + -0.515 * G - 0.100 * B
    # Stacking three (H, W) planes already yields (3, H, W); no reshape needed.
    return torch.from_numpy(np.asarray([Y, U, V]))
    
def toRGB(yuv, batchsize):
    """Convert a batch of YUV images to RGB.

    Args:
        yuv: batch iterated image by image; each image is indexed as
            ``[channel, row, col]`` with channels ordered Y, U, V.
        batchsize: number of images, used for the final reshape.

    Returns:
        numpy array of shape ``(batchsize, 3, 64, 64)`` ordered R, G, B.
        Values are not clipped.
    """
    def _convert(image):
        # Split one image into its three planes and apply the inverse transform.
        luma, u_plane, v_plane = image[0, :, :], image[1, :, :], image[2, :, :]
        red = luma + 1.140 * v_plane
        green = luma + (-0.395 * u_plane) + (-0.581 * v_plane)
        blue = luma + 2.032 * u_plane
        return [red, green, blue]

    converted = [_convert(image) for image in yuv]
    return np.asarray(converted).reshape(batchsize, 3, 64, 64)

In [None]:
def extractGray(batchSize, yuv):
    """Collect channel 0 of every image in ``yuv``.

    Args:
        batchSize: number of images, used for the final reshape.
        yuv: batch iterated image by image; ``image[0]`` is taken from each.

    Returns:
        numpy array of shape ``(batchSize, 1, 64, 64)``.
    """
    first_channels = [image[0] for image in yuv]
    return np.asarray(first_channels).reshape(batchSize, 1, 64, 64)

# Discriminator

In [None]:
class _netD(nn.Module):
    def __init__(self):
        super(_netD, self).__init__()
        self.cnn = nn.Sequential(
            # 3 x 64 x 64
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # 64 x 32 x 32
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # 128 x 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # 256 x 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)

            # 512 x 4 x 4
        )
        self.fc = nn.Sequential(
            nn.Linear(512 * 4 * 4, 1),
            nn.Sigmoid()
        )
    def forward(self, input):
        # input is real or fake colored image
        x = self.cnn(input)
        x = x.view(x.size(0), 512 * 4 * 4) # flatten it
        output = self.fc(x)
        return output.view(-1,1).squeeze(1)

# Generator

In [None]:
class _netG(nn.Module):
    def __init__(self):
        super(_netG, self).__init__()

        self.fc = nn.Linear(100, 1 * 64 * 64)
        self.conv1 = nn.Conv2d(2, 130, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(130)

        self.conv2 = nn.Conv2d(132, 66, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(66)

        self.conv3 = nn.Conv2d(68, 65, 3, 1, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(65)

        self.conv4 = nn.Conv2d(66, 65, 3, 1, 1, bias=False)
        self.bn4 = nn.BatchNorm2d(65)

        self.conv5 = nn.Conv2d(66, 33, 3, 1, 1, bias=False)
        self.bn5 = nn.BatchNorm2d(33)

        self.conv6 = nn.Conv2d(34, 2, 3, 1, 1, bias=False)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, input, noise_pure):
        # input is grayscale image(Y of YUV), noise is random sampled noise
        noise = self.fc(noise_pure)
        noise = noise.view(noise.size(0), 1, 64, 64)

        # 2 x 64 x 64
        x = self.conv1(torch.cat([input, noise], dim=1))
        x = self.bn1(x)
        x = self.relu(x)

        # 130 x 64 x 64
        input2 = torch.cat([input, x ,noise], dim=1)
        # 132 x 64 x 64
        x = self.conv2(input2)
        x = self.bn2(x)
        x = self.relu(x)

        # 66 x 64 x 64
        input3 = torch.cat([input, x, noise], dim=1)
        # 68 x 64 x 64
        x = self.conv3(input3)
        x = self.bn3(x)
        x = self.relu(x)

        # 65 x 64 x 64
        input4 = torch.cat([input, x], dim=1)
        # 66 x 64 x 64
        x = self.conv4(input4)
        x = self.bn4(x)
        x = self.relu(x)

        # 65 x 64 x 64
        input5 = torch.cat([input, x], dim=1)
        # 66 x 64 x 64
        x = self.conv5(input5)
        x = self.bn5(x)
        x = self.relu(x)

        # 33 x 64 x 64
        input6 = torch.cat([input, x], dim=1)
        # 34 x 64 x 64
        x = self.conv6(input6)

        output = torch.cat([input, x], dim=1)
        return output

In [None]:
def weights_init(m):
    """Re-initialise one module in place; intended for ``net.apply(weights_init)``.

    Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02).
    Modules whose class name contains 'BatchNorm' get weights drawn from
    N(1, 0.02) and a zero bias. Every other module is left untouched.
    """
    layer_name = m.__class__.__name__
    if 'Conv' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
# Instantiate both networks on `device` (chosen earlier from
# torch.cuda.is_available()) rather than hard-coding `.cuda()`, so the CPU
# fallback selected there is actually honoured. On a GPU machine this is
# equivalent to the previous `.cuda()` calls.
netG = _netG().to(device)
netG.apply(weights_init)  # N(0, 0.02) conv weights, N(1, 0.02) batch-norm weights
print(netG)

netD = _netD().to(device)
netD.apply(weights_init)
print(netD)

_netG(
  (fc): Linear(in_features=100, out_features=4096, bias=True)
  (conv1): Conv2d(2, 130, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(130, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(132, 66, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(66, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(68, 65, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn3): BatchNorm2d(65, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(66, 65, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn4): BatchNorm2d(65, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(66, 33, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn5): BatchNorm2d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv6): Conv2d(34, 2, kernel_size=(3

In [None]:
# Binary cross-entropy on the discriminator's sigmoid output.
# Everything below is placed on `device` (chosen earlier) instead of
# hard-coding `.cuda()`, so the CPU fallback is honoured.
criterion = nn.BCELoss().to(device)
batchSize = 32

# Reusable buffers. The training loop resizes and overwrites them on every
# iteration, so they are allocated uninitialised, directly on the target
# device (no host allocation followed by a copy).
input = torch.empty(batchSize, 3, 64, 64, dtype=torch.float32, device=device)
noise = torch.empty(batchSize, 100, dtype=torch.float32, device=device)

label = torch.empty(batchSize, dtype=torch.float32, device=device)
real_label = 1  # BCE target for real images
fake_label = 0  # BCE target for generated images

# Same Adam settings for both networks.
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Per-epoch loss history; the training loop appends to these and pickles them.
result_dict = {}
loss_D, loss_G = [], []

In [None]:
outf = '/content/drive/MyDrive/SNU/CNN_distortion/result'

# Every artifact below (sample grids, loss history, checkpoints) is written
# under `outf`, so that is the directory that must exist. The earlier code
# created an unrelated 'results/' directory instead.
os.makedirs(outf, exist_ok=True)


def _save_fake_grid(fake_yuv, n_images, path):
    """Convert generated YUV images to RGB and write them to `path` as one image grid."""
    rgb = toRGB(fake_yuv.detach().cpu().numpy(), n_images)
    vutils.save_image(torch.from_numpy(rgb), path)


for epoch in range(1, 300):
    for i, (data, _) in enumerate(trn_loader):
        data = data.to(device)
        batchSize = len(data)
        # Channel 0 of every sample is the generator's grayscale input. Slicing
        # on-device yields the same (batch, 1, H, W) tensor that the former
        # GPU -> numpy -> GPU round trip through extractGray produced.
        grayv = data[:, 0:1]

        #############
        # D!        #
        #############
        netD.zero_grad()

        # --- real images: target = real_label ---
        input.resize_as_(data).copy_(data)
        label.resize_(batchSize).fill_(real_label)

        output = netD(input)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.detach().mean()

        # --- fake images: target = fake_label ---
        noise.resize_(batchSize, 100).uniform_(0, 1)
        fake = netG(grayv, noise)

        # detach() keeps this backward pass from reaching the generator.
        output = netD(fake.detach())
        label.fill_(fake_label)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.detach().mean()

        errD = errD_real + errD_fake
        optimizerD.step()

        ##############
        # G!         #
        ##############
        netG.zero_grad()
        # The generator is trained to make D label its output as real.
        label.fill_(real_label)
        output = netD(fake)

        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.detach().mean()
        optimizerG.step()

        # Periodic sample grid within the epoch.
        if (i + 1) % 100 == 0:
            _save_fake_grid(
                fake, batchSize,
                '%s/fake_samples_epoch_%s.png' % (outf, str(epoch) + " " + str(i + 1)))

    # End of epoch: report, save a sample grid, the loss history and checkpoints.
    print(epoch)
    print(errD.data, errG.data)
    _save_fake_grid(fake, batchSize, '%s/fake_samples_epoch_%s.png' % (outf, epoch))

    # Store detached CPU copies so the pickled history can be loaded on a
    # machine without a GPU and does not keep device memory referenced.
    loss_D.append(errD.detach().cpu())
    loss_G.append(errG.detach().cpu())
    result_dict = {"loss_D": loss_D, "loss_G": loss_G}
    # Context manager guarantees the file is flushed and closed every epoch.
    with open("{}/result_dict.p".format(outf), "wb") as fh:
        pickle.dump(result_dict, fh)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG.pth' % (outf))
    torch.save(netD.state_dict(), '%s/netD.pth' % (outf))

  cpuset_checked))


1
tensor(0.2448, device='cuda:0') tensor(2.5188, device='cuda:0')
2
tensor(4.0196, device='cuda:0') tensor(6.0146, device='cuda:0')
3
tensor(2.1774, device='cuda:0') tensor(5.0486, device='cuda:0')
4
tensor(0.0406, device='cuda:0') tensor(5.6102, device='cuda:0')
5
tensor(0.9352, device='cuda:0') tensor(2.8675, device='cuda:0')
6
tensor(0.1135, device='cuda:0') tensor(3.6036, device='cuda:0')
