In [1]:
!git clone https://github.com/pfnet-research/chainer-pix2pix.git

Cloning into 'chainer-pix2pix'...
remote: Counting objects: 50, done.[K
remote: Total 50 (delta 0), reused 0 (delta 0), pack-reused 49[K
Unpacking objects: 100% (50/50), done.


In [1]:
!apt-get -qq -y install libcusparse8.0 libnvrtc8.0 libnvtoolsext1 > /dev/null
!ln -snf /usr/lib/x86_64-linux-gnu/libnvrtc-builtins.so.8.0 /usr/lib/x86_64-linux-gnu/libnvrtc-builtins.so
!pip -q install https://github.com/kmaehashi/chainer-colab/releases/download/2018-02-06/cupy_cuda80-4.0.0b3-cp36-cp36m-linux_x86_64.whl
!pip -q install 'chainer==4.0.0b3'
!apt-get -qq -y install xvfb freeglut3-dev ffmpeg> /dev/null
!pip -q install chainerrl
!pip -q install gym
!pip -q install pyglet
!pip -q install pyopengl
!pip -q install pyvirtualdisplay
!pip install scikit-image

Extracting templates from packages: 100%


# mnistで確認

In [39]:
#!/usr/bin/env python

from __future__ import print_function

import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions


# Network definition
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=1,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--noplot', dest='plot', action='store_false',
                        help='Disable PlotReport extension')
    args = parser.parse_args(args=[])

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = L.Classifier(MLP(args.unit, 10))
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir
    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()


if __name__ == '__main__':
    main()

GPU: 0
# unit: 1000
# Minibatch-size: 100
# epoch: 20



Downloading from http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz...


epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
[J     total [..................................................]  0.83%
this epoch [########..........................................] 16.67%
       100 iter, 0 epoch / 20 epochs
       inf iters/sec. Estimated time to finish: 0:00:00.
[4A[J     total [..................................................]  1.67%
this epoch [################..................................] 33.33%
       200 iter, 0 epoch / 20 epochs
    175.78 iters/sec. Estimated time to finish: 0:01:07.128553.
[4A[J     total [#.................................................]  2.50%
this epoch [#########################.........................] 50.00%
       300 iter, 0 epoch / 20 epochs
    176.27 iters/sec. Estimated time to finish: 0:01:06.376843.
[4A[J     total [#.................................................]  3.33%
this epoch [#################################.................] 66.67%
       400 i

[J     total [###########.......................................] 22.50%
this epoch [#########################.........................] 50.00%
      2700 iter, 4 epoch / 20 epochs
    146.15 iters/sec. Estimated time to finish: 0:01:03.634723.
[4A[J     total [###########.......................................] 23.33%
this epoch [#################################.................] 66.67%
      2800 iter, 4 epoch / 20 epochs
    147.04 iters/sec. Estimated time to finish: 0:01:02.566602.
[4A[J     total [############......................................] 24.17%
this epoch [#########################################.........] 83.33%
      2900 iter, 4 epoch / 20 epochs
    147.87 iters/sec. Estimated time to finish: 0:01:01.538462.
[4A[J5           0.0290892   0.0758302             0.990582       0.9808                    29.1014       
[J     total [############......................................] 25.00%
this epoch [..................................................]  0.00%


[J     total [######################............................] 44.17%
this epoch [#########################################.........] 83.33%
      5300 iter, 8 epoch / 20 epochs
    146.81 iters/sec. Estimated time to finish: 0:00:45.636490.
[4A[J9           0.0176098   0.0956173             0.994515       0.978                     45.564        
[J     total [######################............................] 45.00%
this epoch [..................................................]  0.00%
      5400 iter, 9 epoch / 20 epochs
    144.54 iters/sec. Estimated time to finish: 0:00:45.661141.
[4A[J     total [######################............................] 45.83%
this epoch [########..........................................] 16.67%
      5500 iter, 9 epoch / 20 epochs
    144.98 iters/sec. Estimated time to finish: 0:00:44.834738.
[4A[J     total [#######################...........................] 46.67%
this epoch [################..................................] 33.33%


[J     total [################################..................] 65.83%
this epoch [########..........................................] 16.67%
      7900 iter, 13 epoch / 20 epochs
    145.14 iters/sec. Estimated time to finish: 0:00:28.249338.
[4A[J     total [#################################.................] 66.67%
this epoch [################..................................] 33.33%
      8000 iter, 13 epoch / 20 epochs
    145.45 iters/sec. Estimated time to finish: 0:00:27.501333.
[4A[J     total [#################################.................] 67.50%
this epoch [#########################.........................] 50.00%
      8100 iter, 13 epoch / 20 epochs
    145.74 iters/sec. Estimated time to finish: 0:00:26.759661.
[4A[J     total [##################################................] 68.33%
this epoch [#################################.................] 66.67%
      8200 iter, 13 epoch / 20 epochs
    146.02 iters/sec. Estimated time to finish: 0:00:26.023783.


[J     total [###########################################.......] 87.50%
this epoch [#########################.........................] 50.00%
     10500 iter, 17 epoch / 20 epochs
    144.64 iters/sec. Estimated time to finish: 0:00:10.370764.
[4A[J     total [############################################......] 88.33%
this epoch [#################################.................] 66.67%
     10600 iter, 17 epoch / 20 epochs
    146.63 iters/sec. Estimated time to finish: 0:00:09.547604.
[4A[J     total [############################################......] 89.17%
this epoch [#########################################.........] 83.33%
     10700 iter, 17 epoch / 20 epochs
    146.63 iters/sec. Estimated time to finish: 0:00:08.865591.
[4A[J18          0.00792945  0.107475              0.997616       0.9821                    82.6932       
[J     total [#############################################.....] 90.00%
this epoch [..................................................]  0.0

In [0]:
#gpuに乗ってることを確認

In [97]:
from chainer import cuda
cuda.available


True

In [98]:
!pip list | grep cupy


cupy-cuda80 (4.0.0b3)


In [61]:
!pip install --upgrade pandas

Requirement already up-to-date: pandas in /usr/local/lib/python3.6/dist-packages
Collecting pytz>=2011k (from pandas)
  Downloading pytz-2018.3-py2.py3-none-any.whl (509kB)
[K    100% |████████████████████████████████| 512kB 2.0MB/s 
[?25hCollecting numpy>=1.9.0 (from pandas)
  Downloading numpy-1.14.1-cp36-cp36m-manylinux1_x86_64.whl (12.2MB)
[K    100% |████████████████████████████████| 12.2MB 116kB/s 
[?25hCollecting python-dateutil>=2 (from pandas)
  Downloading python_dateutil-2.6.1-py2.py3-none-any.whl (194kB)
[K    100% |████████████████████████████████| 194kB 6.1MB/s 
[?25hRequirement already up-to-date: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2->pandas)
Installing collected packages: pytz, numpy, python-dateutil
  Found existing installation: pytz 2016.7
    Uninstalling pytz-2016.7:
      Successfully uninstalled pytz-2016.7
  Found existing installation: numpy 1.14.0
    Uninstalling numpy-1.14.0:
      Successfully uninstalled numpy-1

In [0]:
!mkdir dataset

In [17]:
!wget https://www.dropbox.com/s/ak394rqfx00r4cr/facades.zip

--2018-03-06 06:58:13--  https://www.dropbox.com/s/ak394rqfx00r4cr/facades.zip
Resolving www.dropbox.com (www.dropbox.com)... 162.125.6.1, 2620:100:601c:1::a27d:601
Connecting to www.dropbox.com (www.dropbox.com)|162.125.6.1|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://dl.dropboxusercontent.com/content_link/P7Jxo93wFblBYzZuZpRYStbgwfGofj2I1S67uo7o9sJRShXmmy5d0ee9JniJek9O/file [following]
--2018-03-06 06:58:13--  https://dl.dropboxusercontent.com/content_link/P7Jxo93wFblBYzZuZpRYStbgwfGofj2I1S67uo7o9sJRShXmmy5d0ee9JniJek9O/file
Resolving dl.dropboxusercontent.com (dl.dropboxusercontent.com)... 162.125.6.6, 2620:100:601c:6::a27d:606
Connecting to dl.dropboxusercontent.com (dl.dropboxusercontent.com)|162.125.6.6|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9371662 (8.9M) [application/zip]
Saving to: ‘facades.zip’


2018-03-06 06:58:14 (75.0 MB/s) - ‘facades.zip’ saved [9371662/9371662]



In [19]:
!unzip facades.zip

Archive:  facades.zip
   creating: facades/
   creating: facades/test/
   creating: facades/train/
   creating: facades/test/b/
  inflating: facades/test/b/cmp_b0328.jpg  
  inflating: facades/test/b/cmp_b0327.jpg  
  inflating: facades/test/b/cmp_b0326.jpg  
  inflating: facades/test/b/cmp_b0325.jpg  
  inflating: facades/test/b/cmp_b0324.jpg  
  inflating: facades/test/b/cmp_b0323.jpg  
  inflating: facades/test/b/cmp_b0322.jpg  
  inflating: facades/test/b/cmp_b0321.jpg  
  inflating: facades/test/b/cmp_b0320.jpg  
  inflating: facades/test/b/cmp_b0319.jpg  
  inflating: facades/test/b/cmp_b0318.jpg  
  inflating: facades/test/b/cmp_b0317.jpg  
  inflating: facades/test/b/cmp_b0316.jpg  
  inflating: facades/test/b/cmp_b0315.jpg  
  inflating: facades/test/b/cmp_b0314.jpg  
  inflating: facades/test/b/cmp_b0313.jpg  
  inflating: facades/test/b/cmp_b0312.jpg  
  inflating: facades/test/b/cmp_b0311.jpg  
  inflating: facades/test/b/cmp_b0310.jpg  
  inflating:

In [2]:
from __future__ import print_function
import os
import argparse
from os.path import join
import numpy as np
from os import listdir
from os.path import join
from scipy.misc import imread, imresize, imsave

import chainer
import chainer.links as L
from chainer import cuda
import chainer.functions as F
from chainer import training
from chainer.training import extensions
from chainer import serializers
from chainer.utils import force_array
from chainer import optimizers, cuda, serializers
from chainer import Variable
from chainer.dataset import dataset_mixin

  util.experimental('cupy.core.fusion')


In [0]:
class DatasetFromFolder(dataset_mixin.DatasetMixin):
    def __init__(self, image_dir):
        self.a_path = join(image_dir, "a")
        self.b_path = join(image_dir, "b")
        self.image_filenames = [x for x in listdir(self.a_path) if is_image_file(x)]

    def __getitem__(self, index):
        # Load Image
        input = load_img(join(self.a_path, self.image_filenames[index]))
        target = load_img(join(self.b_path, self.image_filenames[index]))
        return input, target

    def __len__(self):
        return len(self.image_filenames)
      
def get_training_set(root_dir):
    train_dir = join(root_dir, "train")

    return DatasetFromFolder(train_dir)

def get_test_set(root_dir):
    test_dir = join(root_dir, "test")

    return DatasetFromFolder(test_dir)

In [0]:
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

def load_img(filepath):
    img = imread(filepath)
    if len(img.shape) < 3:
        img = np.expand_dims(img, axis=2)
        img = np.repeat(img, 3, axis=2)
    img = imresize(img, (256, 256))
    # img = np.transpose(img, (2, 0, 1))
    return img
  
def save_img(img, filename):
    img = deprocess_img(img)
    img = img.numpy()
    img *= 255.0
    img = img.clip(0, 255)
    img = np.transpose(img, (1, 2, 0))
    img = imresize(img, (250, 200, 3))
    img = img.astype(np.uint8)
    imsave(filename, img)
    print("Image saved as {}".format(filename))

    
def preprocess_img(img):
    # [0,255] image to [0,1]
    min = img.min()
    max = img.max()
    img.add_(-min).mul_(1.0 / (max - min))
    # RGB to BGR
    # [0,1] to [-1,1]
    img = img.mul_(2).add_(-1)
    # check that input is in expected range
    assert img.max() <= 1, 'badly scaled inputs'
    assert img.min() >= -1, "badly scaled inputs"
    return img

def deprocess_img(img):
    # BGR to RGB
    # [-1,1] to [0,1]
    img = img.add_(1).div_(2)
    return img

In [0]:
class EncoderDecoder(chainer.Chain):
    def __init__(self, input_nc, output_nc, ngf):
      super(EncoderDecoder, self).__init__()
      with self.init_scope():
        #input_nc=3, output_nc=3
        #ngf=64
        #Convolution2D(self, in_channels, out_channels, ksize=None, stride=1, pad=0, nobias=False, initialW=None, initial_bias=None)
        self.conv1 = L.Convolution2D(input_nc, ngf, 4, 2, 1)
        self.conv2 = L.Convolution2D(ngf, ngf * 2, 4, 2, 1)
        self.conv3 = L.Convolution2D(ngf * 2, ngf * 4, 4, 2, 1)
        self.conv4 = L.Convolution2D(ngf * 4, ngf * 8, 4, 2, 1)
        self.conv5 = L.Convolution2D(ngf * 8, ngf * 8, 4, 2, 1)
        self.conv6 = L.Convolution2D(ngf * 8, ngf * 8, 4, 2, 1)
        self.conv7 = L.Convolution2D(ngf * 8, ngf * 8, 4, 2, 1)
        self.conv8 = L.Convolution2D(ngf * 8, ngf * 8, 4, 2, 1)
        self.batch_norm = L.BatchNormalization(ngf)
        self.batch_norm2 = L.BatchNormalization(ngf * 2)
        self.batch_norm4 = L.BatchNormalization(ngf * 4)
        self.batch_norm8 = L.BatchNormalization(ngf * 8)
        self.dconv1 = L.Deconvolution2D(ngf * 8, ngf * 8, 4, 2, 1)
        self.dconv2 = L.Deconvolution2D(ngf * 8 * 2, ngf * 8, 4, 2, 1)
        self.dconv3 = L.Deconvolution2D(ngf * 8 * 2, ngf * 8, 4, 2, 1)
        self.dconv4 = L.Deconvolution2D(ngf * 8 * 2, ngf * 8, 4, 2, 1)
        self.dconv5 = L.Deconvolution2D(ngf * 8 * 2, ngf * 4, 4, 2, 1)
        self.dconv6 = L.Deconvolution2D(ngf * 4 * 2, ngf * 2, 4, 2, 1)
        self.dconv7 = L.Deconvolution2D(ngf * 2 * 2, ngf, 4, 2, 1)
        self.dconv8 = L.Deconvolution2D(ngf * 2, output_nc, 4, 2, 1)

    def __call__(self, input_encdec):
        # Encoder
        # Convolution layers:
        # input_encdec is (nc) x 256 x 256
        e1 = self.conv1(input_encdec)
        # state size is (ngf) x 128 x 128
        e2 = self.batch_norm2(self.conv2(F.leaky_relu(e1)))
        # state size is (ngf x 2) x 64 x 64
        e3 = self.batch_norm4(self.conv3(F.leaky_relu(e2)))
        # state size is (ngf x 4) x 32 x 32
        e4 = self.batch_norm8(self.conv4(F.leaky_relu(e3)))
        # state size is (ngf x 8) x 16 x 16
        e5 = self.batch_norm8(self.conv5(F.leaky_relu(e4)))
        # state size is (ngf x 8) x 8 x 8
        e6 = self.batch_norm8(self.conv6(F.leaky_relu(e5)))
        # state size is (ngf x 8) x 4 x 4
        e7 = self.batch_norm8(self.conv7(F.leaky_relu(e6)))
        # state size is (ngf x 8) x 2 x 2
        # No batch norm on output of Encoder
        e8 = self.conv8(F.leaky_relu(e7))

        # Decoder
        # Deconvolution layers:
        # state size is (ngf x 8) x 1 x 1
        d1_ = F.dropout(self.batch_norm8(self.dconv1(F.relu(e8))))
        # state size is (ngf x 8) x 2 x 2
        d1 = F.concat((d1_, e7), axis=1)
        d2_ = F.dropout(self.batch_norm8(self.dconv2(F.relu(d1))))
        # state size is (ngf x 8) x 4 x 4
        d2 = F.concat((d2_, e6), axis=1)
        d3_ = F.dropout(self.batch_norm8(self.dconv3(F.relu(d2))))
        # state size is (ngf x 8) x 8 x 8
        d3 = F.concat((d3_, e5), axis=1)
        d4_ = self.batch_norm8(self.dconv4(F.relu(d3)))
        # state size is (ngf x 8) x 16 x 16
        d4 = F.concat((d4_, e4), axis=1)
        d5_ = self.batch_norm4(self.dconv5(F.relu(d4)))
        # state size is (ngf x 4) x 32 x 32
        d5 = F.concat((d5_, e3), axis=1)
        d6_ = self.batch_norm2(self.dconv6(F.relu(d5)))
        # state size is (ngf x 2) x 64 x 64
        d6 = F.concat((d6_, e2), axis=1)
        d7_ = self.batch_norm(self.dconv7(F.relu(d6)))
        # state size is (ngf) x 128 x 128
        d7 = F.concat((d7_, e1), axis=1)
        d8 = self.dconv8(F.relu(d7))
        # state size is (nc) x 256 x 256
        output = F.tanh(d8)
        return output


class Discriminator(chainer.Chain):
    def __init__(self, input_nc, output_nc, ngf):
      super(Discriminator, self).__init__()
      with self.init_scope():
        #input_nc=3, output_nc=3
        #ngf=64
        #Convolution2D(self, in_channels, out_channels, ksize=None, stride=1, pad=0, nobias=False, initialW=None, initial_bias=None)
        self.disconv1 = L.Convolution2D(input_nc + output_nc, ngf, 4, 2, 1)
        self.disconv2 = L.Convolution2D(ngf, ngf * 2, 4, 2, 1)
        self.disconv3 = L.Convolution2D(ngf * 2, ngf * 4, 4, 2, 1)
        self.disconv4 = L.Convolution2D(ngf * 4, ngf * 8, 4, 1, 1)
        self.disconv5 = L.Convolution2D(ngf * 8, 1, 4, 1, 1)
        self.batch_norm2 = L.BatchNormalization(ngf * 2)
        self.batch_norm4 = L.BatchNormalization(ngf * 4)
        self.batch_norm8 = L.BatchNormalization(ngf * 8)

    def __call__(self, input_disc):
        #input_nc=3, output_nc=3
        #ngf=64
        #Convolution2D(self, in_channels, out_channels, ksize=None, stride=1, pad=0, nobias=False, initialW=None, initial_bias=None)
        # input_disc is (nc x 2) x 256 x 256
        h1 = self.disconv1(input_disc)
        # state size is (ndf) x 128 x 128
        h2 = self.batch_norm2(self.disconv2(F.leaky_relu(h1)))
        # state size is (ndf x 2) x 64 x 64
        h3 = self.batch_norm4(self.disconv3(F.leaky_relu(h2)))
        # state size is (ndf x 4) x 32 x 32
        h4 = self.batch_norm8(self.disconv4(F.leaky_relu(h3)))
        # state size is (ndf x 8) x 31 x 31
        h5 = self.disconv5(F.leaky_relu(h4))
        # state size is (ndf) x 30 x 30, corresponds to 70 x 70 receptive
        output = F.sigmoid(h5)
        
        # print(output)
#         print(output.shape)
        return output

In [35]:
parser = argparse.ArgumentParser(description='chainer implementation of pix2pix')
parser.add_argument('--batchsize', '-b', type=int, default=1, help='Number of images in each mini-batch')
parser.add_argument('--epoch', '-e', type=int, default=200, help='Number of sweeps over the dataset to train')
# parser.add_argument('--gpu', '-g', type=int, default=-1, help='GPU ID (negative value indicates CPU)')
parser.add_argument('--gpu', '-g', type=int, default=0, help='GPU ID (negative value indicates CPU)') #gpu
parser.add_argument('--dataset', '-i', default='facades', help='Directory of image files.')
parser.add_argument('--out', '-o', default='result', help='Directory to output the result')
parser.add_argument('--resume', '-r', default='', help='Resume the training from snapshot')
parser.add_argument('--seed', type=int, default=0, help='Random seed')
parser.add_argument('--snapshot_interval', type=int, default=1000, help='Interval of snapshot')
parser.add_argument('--display_interval', type=int, default=100, help='Interval of displaying log to console')

parser.add_argument('--input_nc', type=int, default=3, help='input image channels')
parser.add_argument('--output_nc', type=int, default=3, help='output image channels')
parser.add_argument('--ngf', type=int, default=64, help='generator filters in first conv layer')
parser.add_argument('--ndf', type=int, default=64, help='discriminator filters in first conv layer')
args = parser.parse_args(args=[])

print('GPU: {}'.format(args.gpu))
print('# Minibatch-size: {}'.format(args.batchsize))
print('# epoch: {}'.format(args.epoch))

print('===> Loading datasets')
root_path = "./"
train_set = get_training_set(root_path + args.dataset)
test_set = get_test_set(root_path + args.dataset)


GPU: 0
# Minibatch-size: 1
# epoch: 200
===> Loading datasets


In [0]:
encoderdecoder_model = EncoderDecoder(args.input_nc, args.output_nc, args.ngf)
discriminator_model = Discriminator(args.input_nc, args.output_nc, args.ngf)

In [37]:
if args.gpu >= 0:
    print("use gpu")
    chainer.backends.cuda.get_device_from_id(args.gpu).use()
    encoderdecoder_model.to_gpu()
    discriminator_model.to_gpu()

use gpu


In [0]:
optimizer_encoderdecoder = chainer.optimizers.Adam(alpha=0.0002, beta1=0.5)
optimizer_encoderdecoder.setup(encoderdecoder_model)
optimizer_discriminator = chainer.optimizers.Adam(alpha=0.0002, beta1=0.5)
optimizer_discriminator.setup(discriminator_model)

In [0]:
if args.gpu >= 0:
    xp = cuda.cupy
    label = xp.random.randn(args.batchsize)
    real_label = Variable(xp.ones((1,1,30,30), dtype=xp.float32))
    fake_label = Variable(xp.zeros((1,1,30,30), dtype=xp.float32))
else:
    label = np.random.randn(args.batchsize)
    real_label = Variable(np.ones((1,1,30,30), dtype=np.float32))
    fake_label = Variable(np.zeros((1,1,30,30), dtype=np.float32))

# loss

In [0]:
def loss_criterion(output, label, lam1=100, lam2=1):
    print("===loss_criterion===")
    loss = (F.mean_absolute_error(output, label))
    return loss

In [0]:
def loss_criterion_l1(y_out, t_out, lam1=100, lam2=1):
    print("===loss_criterion_l1===")
    # batchsize,_,w,h = list(y_out.data.shape)
    batchsize,_,w,h = y_out.data.shape
    loss_rec = (F.mean_absolute_error(y_out, t_out))
    print("loss_rec:", loss_rec)
    loss_adv = lam2*F.sum(F.softplus(-y_out)) / batchsize / w / h
    print("loss_adv:", loss_adv)
    loss = loss_rec + loss_adv
    return loss

In [0]:
def loss_dis(y_in, y_out):
    batchsize,_,w,h = y_in.data.shape    
    print(F.sum(F.softplus(-y_in)))
    L1 = F.sum(F.softplus(-y_in)) / batchsize / w / h
    L2 = F.sum(F.softplus(y_out)) / batchsize / w / h
    loss = L1 + L2
    # print("loss_dis:", loss)
    return loss

In [50]:
for iteration, batch in enumerate(train_set, 1):
  ############################
  # (1) Update D network: maximize log(D(x,y)) + log(1 - D(x,G(x)))
  ###########################
  # train with real
  discriminator_model.zerograds()
  
  if args.gpu >= 0:
    real_A, real_B = xp.asarray(batch[0] / 255.0, dtype=xp.float32) , xp.asarray(batch[1] / 255.0, dtype=xp.float32)
  else:
      real_A, real_B = np.asarray(batch[0], dtype=np.float32) / 255.0, np.asarray(batch[1], dtype=np.float32) / 255.0
  real_A = real_A.transpose(2, 0, 1) #チャンネルを前に出す
  real_B = real_B.transpose(2, 0, 1)
  real_A = real_A.reshape(1,3,256,256)
  real_B = real_B.reshape(1,3,256,256)  
  #
  
#   output = discriminator_model(F.concat((real_A, real_B), axis=1))
#   label = (real_label)
# #   print("output", output)
#   print("output", output.shape)
  
#   # print("=> err_d_real")
#   err_d_real = loss_dis(output, label)
#   print("err_d_real", err_d_real)
#   err_d_real.backward()

  fake_b = encoderdecoder_model(real_A)
  y_fake = discriminator_model(F.concat((real_A, fake_b), axis=1))
  y_real = discriminator_model(F.concat((real_A, real_B), axis=1))
  
  print("y_fake",y_fake[0][0][0][:3])
  print("y_real",y_real[0][0][0][:3])
  print("y_fake",y_fake.shape)
  print(loss_criterion_l1(fake_b, real_B))
  
#   label = (fake_label)
#   # print("=> err_d_fake")
#   err_d_fake = loss_dis(output, label)
#   err_d_fake.backward()
#   err_d = (err_d_real + err_d_fake) / 2.0
#   # print("err_d_real: ", err_d_real)
#   # print("err_d_fake: ", err_d_fake)
#   # print("(err_d_real + err_d_fake) / 2.0 ", (err_d_real + err_d_fake) / 2.0)
#   optimizer_discriminator.update()
  ############################
  # (2) Update G network: maximize log(D(x,G(x))) + L1(y,G(x))
  ###########################
#   output = discriminator_model(F.concat((real_A, fake_b), axis=1))
#   label = (real_label)
#   err_g = loss_criterion(output, label) + loss_criterion_l1(fake_b, real_B)
#   err_g.backward()
#   optimizer_encoderdecoder.update()

  print("===> Epoch[{}]({}/{}): Loss_D: {} Loss_G: {} ".format(
      args.epoch, iteration, len(train_set), err_d.data, err_g.data))

#   if args.snapshot_interval % args.epoch == 0:
#       serializers.save_npz("encoderdecoder_model_"+str(args.epoch), encoderdecoder_model)
#       serializers.save_npz("discriminator_model_"+str(args.epoch), discriminator_model)

  if issubdtype(ts, int):
  elif issubdtype(type(size), float):


y_fake variable([0.46424884 0.44736555 0.64248323])
y_real variable([0.33556655 0.4039415  0.4807378 ])
y_fake (1, 1, 30, 30)
===> Epoch[200](1/200): Loss_D: nan Loss_G: nan 
y_fake variable([0.28270423 0.38378084 0.33759063])
y_real variable([0.5495077  0.44476414 0.3948666 ])
y_fake (1, 1, 30, 30)
===> Epoch[200](2/200): Loss_D: nan Loss_G: nan 
y_fake variable([0.36939508 0.2787093  0.3760412 ])
y_real variable([0.58014387 0.6282082  0.53632915])
y_fake (1, 1, 30, 30)
===> Epoch[200](3/200): Loss_D: nan Loss_G: nan 
y_fake variable([0.3643916 0.2657321 0.4170055])
y_real variable([0.42573005 0.58259    0.28723422])
y_fake (1, 1, 30, 30)
===> Epoch[200](4/200): Loss_D: nan Loss_G: nan 
y_fake variable([0.48874155 0.45009148 0.45325002])
y_real variable([0.3835166  0.4756761  0.23097081])
y_fake (1, 1, 30, 30)
===> Epoch[200](5/200): Loss_D: nan Loss_G: nan 
y_fake variable([0.50521487 0.47832873 0.37160125])
y_real variable([0.693871  0.4836356 0.3912316])
y_fake (1, 1, 30, 30)
===> 

KeyboardInterrupt: ignored