<a href="https://colab.research.google.com/github/ko4ro/paper_survey_colab/blob/main/cut_jp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Contrastive Unpaired Translation (CUT)


<table>
  <td><a target="_blank" href="https://github.com/taesungp/contrastive-unpaired-translation"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">GitHub</a></td>
  <td><a href="http://taesung.me/ContrastiveUnpairedTranslation/"><img src="https://www.tensorflow.org/images/download_logo_32px.png">Download the original paper</a></td>
</table>

This notebook walks through the paper "[Contrastive Learning for Unpaired Image-to-Image Translation](https://arxiv.org/pdf/2007.1565)" by actually running its code, so that the ideas can be understood hands-on.

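The paper shows how to use [contrastive learning](https://qiita.com/omiita/items/a7429ec42e4eef4b6a4d) effectively for unsupervised (unpaired) image-to-image translation. The key point is that contrastive learning is applied per patch and at multiple layers, rather than to the image as a whole.
Negative patches are sampled from within the input image itself rather than from other images, with the aim of maximizing the mutual information between corresponding patches of the input and the output. The authors demonstrate improved performance over prior methods as well as shorter training time. The method can also be extended so that it trains even when each "domain" consists of a single image.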
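
The patchwise contrastive objective described above can be illustrated with a small InfoNCE computation. This is only a sketch, not the repository's implementation: it uses plain Python lists instead of PyTorch tensors, and the function names (`normalize`, `dot`, `patch_nce_loss`), the toy feature vectors, and the temperature of 0.07 are assumptions chosen for illustration. For each query patch taken from the translated image, the patch at the same location in the input image acts as the positive, and the remaining patches of the same input image act as negatives.

```python
import math


def normalize(v):
    """Scale a feature vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a, b):
    """Inner product of two equally sized vectors."""
    return sum(x * y for x, y in zip(a, b))


def patch_nce_loss(query_feats, key_feats, temperature=0.07):
    """Average InfoNCE loss over patches (illustrative, single layer).

    query_feats[i]: feature of patch i in the translated image.
    key_feats[i]:   feature of patch i in the input image.
    For query i, key i is the positive; every other key is a negative
    taken from the same input image.
    """
    queries = [normalize(q) for q in query_feats]
    keys = [normalize(k) for k in key_feats]
    total = 0.0
    for i, q in enumerate(queries):
        logits = [dot(q, k) / temperature for k in keys]
        # Cross-entropy with the positive (index i) as the target class.
        max_logit = max(logits)  # subtract the max for numerical stability
        log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
        total += log_sum_exp - logits[i]
    return total / len(queries)


# Toy example: three patches with 2-D features (hypothetical values).
input_patches = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
aligned_output = [[0.9, 0.1], [0.1, 0.9], [-0.9, 0.1]]  # matches patch by patch
shuffled_output = [aligned_output[1], aligned_output[2], aligned_output[0]]

loss_aligned = patch_nce_loss(aligned_output, input_patches)
loss_shuffled = patch_nce_loss(shuffled_output, input_patches)
print(f"aligned: {loss_aligned:.4f}, shuffled: {loss_shuffled:.4f}")
```

The loss is low when each output patch resembles the input patch at the same location and high when the correspondence is broken, which is the signal that encourages the generator to preserve content. In the method itself this kind of term is summed over several encoder layers.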

![cut_horse2zebra](https://raw.githubusercontent.com/ko4ro/paper_survey_colab/main/asset/figures/cut_horse2zebra.jpeg)

## Clone the Git repository

In [1]:
!git clone -b feature/colab https://github.com/ko4ro/contrastive-unpaired-translation.git cut
!pip install -r ./cut/requirements.txt

fatal: destination path 'cut' already exists and is not an empty directory.


## Prepare the dataset

In [2]:
dataset_name = "grumpifycat" #@param ["ae_photos","apple2orange", "summer2winter_yosemite", "horse2zebra", "monet2photo", "cezanne2photo", "ukiyoe2photo", "vangogh2photo", "maps", "cityscapes", "facades", "iphone2dslr_flower", "mini", "mini_pix2pix", "mini_colorization", "grumpifycat"]

In [3]:
!cd ./cut/ && bash datasets/download_cut_dataset.sh $dataset_name

+ FILE=grumpifycat
+ [[ grumpifycat != \a\e\_\p\h\o\t\o\s ]]
+ [[ grumpifycat != \a\p\p\l\e\2\o\r\a\n\g\e ]]
+ [[ grumpifycat != \s\u\m\m\e\r\2\w\i\n\t\e\r\_\y\o\s\e\m\i\t\e ]]
+ [[ grumpifycat != \h\o\r\s\e\2\z\e\b\r\a ]]
+ [[ grumpifycat != \m\o\n\e\t\2\p\h\o\t\o ]]
+ [[ grumpifycat != \c\e\z\a\n\n\e\2\p\h\o\t\o ]]
+ [[ grumpifycat != \u\k\i\y\o\e\2\p\h\o\t\o ]]
+ [[ grumpifycat != \v\a\n\g\o\g\h\2\p\h\o\t\o ]]
+ [[ grumpifycat != \m\a\p\s ]]
+ [[ grumpifycat != \c\i\t\y\s\c\a\p\e\s ]]
+ [[ grumpifycat != \f\a\c\a\d\e\s ]]
+ [[ grumpifycat != \i\p\h\o\n\e\2\d\s\l\r\_\f\l\o\w\e\r ]]
+ [[ grumpifycat != \m\i\n\i ]]
+ [[ grumpifycat != \m\i\n\i\_\p\i\x\2\p\i\x ]]
+ [[ grumpifycat != \m\i\n\i\_\c\o\l\o\r\i\z\a\t\i\o\n ]]
+ [[ grumpifycat != \g\r\u\m\p\i\f\y\c\a\t ]]
+ [[ grumpifycat == \c\i\t\y\s\c\a\p\e\s ]]
+ echo 'Specified [grumpifycat]'
Specified [grumpifycat]
+ URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/grumpifycat.zip
+ ZIP_FILE=./datasets/grumpifycat.zip

## Imports

In [4]:
import sys
ROOT_PATH = '/content/cut/'
sys.path.append(ROOT_PATH)
print(sys.path)

['', '/content', '/env/python', '/usr/lib/python37.zip', '/usr/lib/python3.7', '/usr/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/dist-packages', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.7/dist-packages/IPython/extensions', '/root/.ipython', '/content/cut/']


In [5]:
import time
import torch
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer

## Prepare the training parameters

In [6]:
opt = TrainOptions().parse()  # get training options

----------------- Options ---------------
                 CUT_mode: CUT                           
               batch_size: 1                             
                    beta1: 0.5                           
                    beta2: 0.999                         
          checkpoints_dir: ./checkpoints                 
           continue_train: False                         
                crop_size: 256                           
                 dataroot: placeholder                   
             dataset_mode: unaligned                     
                direction: AtoB                          
              display_env: main                          
             display_freq: 400                           
               display_id: None                          
            display_ncols: 4                             
             display_port: 8097                          
           display_server: http://localhost              
          display_winsize: 256

In [7]:
opt.dataroot = f'/content/cut/datasets/{dataset_name}'
opt.name = f'{dataset_name}_CUT'
opt.n_epochs = 5
opt.n_epochs_decay = 5
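
With `n_epochs = 5` and `n_epochs_decay = 5`, training runs for 10 epochs: the learning rate is held constant at first and then decays. The sketch below assumes a linear decay rule of the form `1 - max(0, completed_epochs + epoch_count - n_epochs) / (n_epochs_decay + 1)`; this rule is not shown anywhere in this notebook and is an assumption, although it is consistent with the learning rates printed in the training log (0.0002000 after epochs 1 to 4, 0.0001667 after epoch 5). The function name `linear_decay_lr` is hypothetical, and the base rate of 0.0002 is taken from that log.

```python
def linear_decay_lr(base_lr, completed_epochs, epoch_count, n_epochs, n_epochs_decay):
    """Learning rate after `completed_epochs` scheduler steps (assumed linear policy)."""
    overshoot = max(0, completed_epochs + epoch_count - n_epochs)
    return base_lr * (1.0 - overshoot / float(n_epochs_decay + 1))


# Values after each of the 10 epochs configured above.
schedule = [
    linear_decay_lr(0.0002, e, epoch_count=1, n_epochs=5, n_epochs_decay=5)
    for e in range(1, 11)
]
for epoch, lr in enumerate(schedule, start=1):
    print(f"after epoch {epoch:2d}: learning rate = {lr:.7f}")
```

Under this assumption the rate reaches zero only after the final epoch has finished, so every training step still uses a positive learning rate.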

In [8]:
dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
dataset_size = len(dataset)    # get the number of images in the dataset.

dataset [UnalignedDataset] was created




In [9]:
model = create_model(opt)      # create a model given opt.model and other options
print('The number of training images = %d' % dataset_size)


model [CUTModel] was created
The number of training images = 214
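
The count of 214 is the length of the unaligned dataset, which draws from two unpaired image folders. The sketch below shows one way such a dataset can pair items. It assumes, without this notebook confirming it, that the length is the size of the larger domain, that domain A is indexed cyclically, and that domain B is sampled at random unless a serial option is set. `ToyUnalignedDataset`, its arguments, and the string items are hypothetical stand-ins, not the repository's class.

```python
import random


class ToyUnalignedDataset:
    """Hypothetical stand-in that pairs items from two unpaired domains."""

    def __init__(self, items_a, items_b, serial_batches=False, seed=0):
        self.items_a = list(items_a)
        self.items_b = list(items_b)
        self.serial_batches = serial_batches
        self.rng = random.Random(seed)

    def __len__(self):
        # Assumed: one "epoch" covers the larger of the two domains.
        return max(len(self.items_a), len(self.items_b))

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        # Domain A wraps around when it is the smaller domain.
        a = self.items_a[index % len(self.items_a)]
        if self.serial_batches:
            b = self.items_b[index % len(self.items_b)]
        else:
            b = self.rng.choice(self.items_b)  # no fixed A-to-B pairing
        return {"A": a, "B": b}


toy = ToyUnalignedDataset(
    ["a0", "a1", "a2"], ["b0", "b1", "b2", "b3", "b4"], serial_batches=True
)
print(len(toy))
print([toy[i] for i in range(len(toy))])
```

The dictionary keys `"A"` and `"B"` mirror the `data["A"]` access used in the training loop of this notebook.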


In [10]:
visualizer = Visualizer(opt)   # create a visualizer that display/save images and plots

Setting up a new session...


create web directory ./checkpoints/grumpifycat_CUT/web...


In [11]:
opt.visualizer = visualizer
total_iters = 0
optimize_time = 0.1
times = []
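
`optimize_time` starts at 0.1 and is updated in the training loop as `step / batch_size * 0.005 + 0.995 * optimize_time`, which is an exponential moving average with a very small weight on the newest measurement. The sketch below replays that update with a constant, hypothetical step cost of 1.0 second to show why the `time` column in the training log climbs slowly instead of reporting the true step time right away. The helper name `update_optimize_time` and the 1.0 second figure are assumptions for illustration.

```python
def update_optimize_time(previous, step_seconds, batch_size=1, weight=0.005):
    """One update of the running average used in the training loop."""
    return step_seconds / batch_size * weight + (1.0 - weight) * previous


optimize_time = 0.1       # initial value set in the cell above
true_step_seconds = 1.0   # hypothetical constant cost per optimization step
history = []
for _ in range(1000):
    optimize_time = update_optimize_time(optimize_time, true_step_seconds)
    history.append(optimize_time)

for n in (100, 200, 1000):
    print(f"after {n:4d} steps: {history[n - 1]:.3f}")
```

Because each step contributes only 0.5% of the average, the reported value needs on the order of a thousand iterations to approach the actual per-step time, so early `time` readings understate the real cost.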

In [None]:
for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
    epoch_start_time = time.time()  # timer for entire epoch
    iter_data_time = time.time()    # timer for data loading per iteration
    epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
    visualizer.reset()              # reset the visualizer: make sure it saves the results to HTML at least once every epoch

    dataset.set_epoch(epoch)
    for i, data in enumerate(dataset):  # inner loop within one epoch
        iter_start_time = time.time()  # timer for computation per iteration
        if total_iters % opt.print_freq == 0:
            t_data = iter_start_time - iter_data_time

        batch_size = data["A"].size(0)
        total_iters += batch_size
        epoch_iter += batch_size
        if len(opt.gpu_ids) > 0:
            torch.cuda.synchronize()
        optimize_start_time = time.time()
        if epoch == opt.epoch_count and i == 0:
            model.data_dependent_initialize(data)
            model.setup(opt)               # regular setup: load and print networks; create schedulers
            model.parallelize()
        model.set_input(data)  # unpack data from dataset and apply preprocessing
        model.optimize_parameters()   # calculate loss functions, get gradients, update network weights
        if len(opt.gpu_ids) > 0:
            torch.cuda.synchronize()
        optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time

        if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file
            save_result = total_iters % opt.update_html_freq == 0
            model.compute_visuals()
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

        if total_iters % opt.print_freq == 0:    # print training losses and save logging information to the disk
            losses = model.get_current_losses()
            visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data)
            if opt.display_id is None or opt.display_id > 0:
                visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

        if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
            print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
            print(opt.name)  # it's useful to occasionally show the experiment name on console
            save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
            model.save_networks(save_suffix)

        iter_data_time = time.time()

    if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
        model.save_networks('latest')
        model.save_networks(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
    model.update_learning_rate()                     # update learning rates at the end of every epoch.

  "Argument interpolation should be of type InterpolationMode instead of int. "


---------- Networks initialized -------------
[Network G] Total number of parameters : 11.378 M
[Network F] Total number of parameters : 0.560 M
[Network D] Total number of parameters : 2.765 M
-----------------------------------------------
(epoch: 1, iters: 100, time: 0.494, data: 0.229) G_GAN: 0.295 D_real: 0.197 D_fake: 0.238 G: 3.924 NCE: 3.682 NCE_Y: 3.575 
(epoch: 1, iters: 200, time: 0.697, data: 0.002) G_GAN: 0.282 D_real: 0.271 D_fake: 0.196 G: 3.159 NCE: 2.979 NCE_Y: 2.775 
End of epoch 1 / 10 	 Time Taken: 228 sec
learning rate = 0.0002000


(epoch: 2, iters: 86, time: 0.819, data: 0.005) G_GAN: 0.361 D_real: 0.137 D_fake: 0.283 G: 2.806 NCE: 2.578 NCE_Y: 2.312 
(epoch: 2, iters: 186, time: 0.894, data: 0.002) G_GAN: 0.348 D_real: 0.297 D_fake: 0.147 G: 2.515 NCE: 2.268 NCE_Y: 2.066 
End of epoch 2 / 10 	 Time Taken: 217 sec
learning rate = 0.0002000


(epoch: 3, iters: 72, time: 0.939, data: 0.002) G_GAN: 0.407 D_real: 0.101 D_fake: 0.195 G: 2.200 NCE: 1.867 NCE_Y: 1.720 
(epoch: 3, iters: 172, time: 0.966, data: 0.002) G_GAN: 0.498 D_real: 0.262 D_fake: 0.097 G: 2.003 NCE: 1.502 NCE_Y: 1.507 
End of epoch 3 / 10 	 Time Taken: 216 sec
learning rate = 0.0002000


(epoch: 4, iters: 58, time: 0.983, data: 0.002) G_GAN: 0.510 D_real: 0.122 D_fake: 0.158 G: 1.936 NCE: 1.437 NCE_Y: 1.414 
(epoch: 4, iters: 158, time: 0.993, data: 0.002) G_GAN: 0.206 D_real: 0.270 D_fake: 0.302 G: 1.868 NCE: 1.501 NCE_Y: 1.823 
End of epoch 4 / 10 	 Time Taken: 217 sec
learning rate = 0.0002000


(epoch: 5, iters: 44, time: 0.999, data: 0.003) G_GAN: 0.311 D_real: 0.146 D_fake: 0.348 G: 1.677 NCE: 1.369 NCE_Y: 1.364 
(epoch: 5, iters: 144, time: 1.003, data: 0.002) G_GAN: 0.518 D_real: 0.248 D_fake: 0.153 G: 1.976 NCE: 1.560 NCE_Y: 1.356 
saving the model at the end of epoch 5, iters 1070
End of epoch 5 / 10 	 Time Taken: 217 sec
learning rate = 0.0001667


(epoch: 6, iters: 30, time: 1.005, data: 0.002) G_GAN: 0.430 D_real: 0.011 D_fake: 0.463 G: 3.011 NCE: 2.526 NCE_Y: 2.636 
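
The `iters` values in the log (100, 200, 86, 186, 72, ...) look irregular because losses are printed whenever the global counter `total_iters` hits a multiple of the print frequency, while the number shown is the per-epoch counter `epoch_iter`. The sketch below replays only the counters of the training loop to make that arithmetic explicit. The dataset size of 214 and batch size of 1 come from the outputs above; `print_freq = 100` is an assumption inferred from the spacing of the logged lines.

```python
dataset_size = 214   # "The number of training images = 214"
print_freq = 100     # assumed from the spacing of the logged lines
batch_size = 1

total_iters = 0
logged = []          # (epoch, epoch_iter) pairs at which losses are printed
total_after_epoch_5 = None
for epoch in range(1, 7):
    epoch_iter = 0   # reset at the start of every epoch, as in the loop above
    for _ in range(dataset_size):
        total_iters += batch_size
        epoch_iter += batch_size
        if total_iters % print_freq == 0:
            logged.append((epoch, epoch_iter))
    if epoch == 5:
        total_after_epoch_5 = total_iters

print(logged)
print(total_after_epoch_5)
```

Since 214 is not a multiple of 100, the print points drift by 14 iterations per epoch, and the total after five epochs matches the `iters 1070` reported when the epoch 5 checkpoint is saved.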
