<a href="https://colab.research.google.com/github/ko4ro/paper_survey_colab/blob/main/cut_jp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Contrastive Unpaired Translation (CUT)


<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://github.com/taesungp/contrastive-unpaired-translation"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">GitHub</a>
  </td>
  <td>
    <a href="http://taesung.me/ContrastiveUnpairedTranslation/"><img src="https://www.tensorflow.org/images/download_logo_32px.png">Download the original paper</a>
  </td>
</table>

This notebook walks through "[Contrastive Learning for Unpaired Image-to-Image Translation](https://arxiv.org/pdf/2007.15651.pdf)", presented at ECCV 2020. The goal is to understand the method by actually running the code.

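The paper shows how to use [contrastive learning](https://qiita.com/omiita/items/a7429ec42e4eef4b6a4d) effectively for unpaired (unsupervised) image-to-image translation. Its key idea is to apply contrastive learning to patches, at multiple layers of the network, rather than to whole images.
Negative patches are also sampled from within the input image itself instead of from other images. The aim is to maximize the mutual information between corresponding input and output patches. Compared with previous methods, the approach improves performance and also shortens training time. It can further be extended to the case where each "domain" consists of a single image.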

![cut_horse2zebra](https://raw.githubusercontent.com/ko4ro/paper_survey_colab/main/asset/figures/cut_horse2zebra.jpeg)
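The patch-wise contrastive objective described above can be sketched in plain Python. This is a simplified illustration and not the repository's implementation. For each sampled location i:

- the output-image feature is the query,
- the input-image feature at the same location is the positive, and
- the input features at the other sampled locations of the same image are the negatives.

The loss is a cross-entropy over these similarities. Several parts of the sketch are assumptions: the function name `patch_nce_loss`, the toy 2-D features, and the temperature `tau = 0.07`. The real model uses learned multi-layer encoder features and PyTorch tensors.

```python
import math


def _l2_normalize(v):
    # Scale a feature vector to unit length so the dot product is a cosine similarity
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def patch_nce_loss(feat_q, feat_k, tau=0.07):
    """Average InfoNCE loss over S sampled patch locations (simplified sketch).

    feat_q[i]: feature of the OUTPUT patch at location i (query)
    feat_k[i]: feature of the INPUT patch at the same location (positive)
    feat_k[j], j != i: other patches of the same input image (negatives)

    Per location: -log( exp(q.k+/tau) / (exp(q.k+/tau) + sum_n exp(q.k-_n/tau)) )
    """
    q = [_l2_normalize(v) for v in feat_q]
    k = [_l2_normalize(v) for v in feat_k]
    total = 0.0
    for i, qi in enumerate(q):
        # logits[0] is the positive pair; the remaining entries are the internal negatives
        logits = [sum(a * b for a, b in zip(qi, k[i])) / tau]
        logits += [sum(a * b for a, b in zip(qi, k[j])) / tau
                   for j in range(len(k)) if j != i]
        # Cross-entropy with target index 0, computed with a stable log-sum-exp
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[0]
    return total / len(q)


aligned = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
swapped = [[0.0, 1.0], [1.0, 0.0], [-1.0, 0.0]]
print(patch_nce_loss(aligned, aligned))  # close to 0: each query matches its own positive
print(patch_nce_loss(swapped, aligned))  # large: queries resemble the wrong input patches
```

When each query is aligned with its positive, the loss is near zero. When the queries resemble the wrong input patches, the loss is large. This signal pushes each output patch to stay associated with its corresponding input patch.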

## git clone & pip install

Clone the git repository and install the required packages.

In [1]:
!git clone  -b feature/colab https://github.com/ko4ro/contrastive-unpaired-translation.git cut
!pip install -r ./cut/requirements.txt

Cloning into 'cut'...
remote: Enumerating objects: 260, done.
remote: Counting objects: 100% (62/62), done.
remote: Compressing objects: 100% (34/34), done.
remote: Total 260 (delta 40), reused 29 (delta 28), pack-reused 198
Receiving objects: 100% (260/260), 17.90 MiB | 12.52 MiB/s, done.
Resolving deltas: 100% (131/131), done.
Collecting dominate>=2.4.0
  Downloading dominate-2.6.0-py2.py3-none-any.whl (29 kB)
Collecting visdom>=0.1.8.8
  Downloading visdom-0.1.8.9.tar.gz (676 kB)
     |████████████████████████████████| 676 kB 6.3 MB/s 
Collecting GPUtil>=1.4.0
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
Collecting jsonpatch
  Downloading jsonpatch-1.32-py2.py3-none-any.whl (12 kB)
Collecting torchfile
  Downloading torchfile-0.1.0.tar.gz (5.2 kB)
Collecting websocket-client
  Downloading websocket_client-1.2.1-py2.py3-none-any.whl (52 kB)
     |████████████████████████████████| 52 kB 1.5 MB/s 
Collecting jsonpointer>=1.9
  Downloading jsonpointer-2.1-py2.py3-none-an

## Preparing the dataset

Here we use the Russian Blue cat (89 images) → Grumpy cat (215 images) dataset from Fig. 8 of the paper.

![grumpycat.jpeg](https://raw.githubusercontent.com/ko4ro/paper_survey_colab/main/asset/figures/grumpycat.jpeg)
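The two domains are not paired image by image, which is why the options below use `dataset_mode: unaligned`. Our understanding of the repository's unaligned dataset is as follows:

- Its length is the size of the larger domain.
- Each A image is drawn in order.
- Each A image is matched with a B image chosen at random.

The class below is a hypothetical, simplified stand-in that illustrates this behavior with file names only. It does not load images and is not the repository's code.

```python
import random


class UnalignedPairs:
    """Hypothetical, simplified stand-in for an unaligned A/B image dataset."""

    def __init__(self, paths_a, paths_b, seed=0):
        self.paths_a = list(paths_a)
        self.paths_b = list(paths_b)
        self.rng = random.Random(seed)  # fixed seed only to make the sketch reproducible

    def __len__(self):
        # One "epoch" covers the larger of the two domains
        return max(len(self.paths_a), len(self.paths_b))

    def __getitem__(self, index):
        # Domain A is walked in order (wrapping around); domain B is sampled at random,
        # so there is no fixed correspondence between A and B images.
        a = self.paths_a[index % len(self.paths_a)]
        b = self.paths_b[self.rng.randrange(len(self.paths_b))]
        return {"A": a, "B": b}


ds = UnalignedPairs(["a0.jpg", "a1.jpg"], ["b0.jpg", "b1.jpg", "b2.jpg"])
print(len(ds), [ds[i]["A"] for i in range(len(ds))])
```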

In [2]:
dataset_name = "grumpifycat" #@param ["ae_photos","apple2orange", "summer2winter_yosemite", "horse2zebra", "monet2photo", "cezanne2photo", "ukiyoe2photo", "vangogh2photo", "maps", "cityscapes", "facades", "iphone2dslr_flower", "mini", "mini_pix2pix", "mini_colorization", "grumpifycat"]

In [3]:
!cd ./cut/ && bash datasets/download_cut_dataset.sh $dataset_name

+ FILE=grumpifycat
+ [[ grumpifycat != \a\e\_\p\h\o\t\o\s ]]
+ [[ grumpifycat != \a\p\p\l\e\2\o\r\a\n\g\e ]]
+ [[ grumpifycat != \s\u\m\m\e\r\2\w\i\n\t\e\r\_\y\o\s\e\m\i\t\e ]]
+ [[ grumpifycat != \h\o\r\s\e\2\z\e\b\r\a ]]
+ [[ grumpifycat != \m\o\n\e\t\2\p\h\o\t\o ]]
+ [[ grumpifycat != \c\e\z\a\n\n\e\2\p\h\o\t\o ]]
+ [[ grumpifycat != \u\k\i\y\o\e\2\p\h\o\t\o ]]
+ [[ grumpifycat != \v\a\n\g\o\g\h\2\p\h\o\t\o ]]
+ [[ grumpifycat != \m\a\p\s ]]
+ [[ grumpifycat != \c\i\t\y\s\c\a\p\e\s ]]
+ [[ grumpifycat != \f\a\c\a\d\e\s ]]
+ [[ grumpifycat != \i\p\h\o\n\e\2\d\s\l\r\_\f\l\o\w\e\r ]]
+ [[ grumpifycat != \m\i\n\i ]]
+ [[ grumpifycat != \m\i\n\i\_\p\i\x\2\p\i\x ]]
+ [[ grumpifycat != \m\i\n\i\_\c\o\l\o\r\i\z\a\t\i\o\n ]]
+ [[ grumpifycat != \g\r\u\m\p\i\f\y\c\a\t ]]
+ [[ grumpifycat == \c\i\t\y\s\c\a\p\e\s ]]
+ echo 'Specified [grumpifycat]'
Specified [grumpifycat]
+ URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/grumpifycat.zip
+ ZIP_FILE=./datasets/grumpifycat.zip
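The trace above stops after the URL and ZIP path are set, so the download and unzip steps are not shown. Assume the archive unpacks to `/content/cut/datasets/grumpifycat` with the usual unaligned layout, meaning one subfolder per split and domain, such as `trainA` and `trainB`. A quick sanity check is then to count the images in each subfolder. The helper `count_images` and its extension list are hypothetical and written only for this check.

```python
from pathlib import Path

# Assumed set of image extensions; extend it if the dataset uses others
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images(dataroot):
    """Return {subfolder name: number of image files} for each subfolder of dataroot."""
    root = Path(dataroot)
    counts = {}
    for sub in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[sub.name] = sum(
            1 for f in sub.iterdir() if f.is_file() and f.suffix.lower() in IMAGE_EXTS
        )
    return counts


# Usage in the notebook (path assumed from the cells above):
# print(count_images(f"/content/cut/datasets/{dataset_name}"))
```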

## Imports

In [4]:
import sys
ROOT_PATH = '/content/cut/'
sys.path.append(ROOT_PATH)
print(sys.path)

['', '/content', '/env/python', '/usr/lib/python37.zip', '/usr/lib/python3.7', '/usr/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/dist-packages', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.7/dist-packages/IPython/extensions', '/root/.ipython', '/content/cut/']


In [5]:
import time
import torch
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer

## Preparing the training options

In [6]:
opt = TrainOptions().parse()  # get training options

----------------- Options ---------------
                 CUT_mode: CUT                           
               batch_size: 1                             
                    beta1: 0.5                           
                    beta2: 0.999                         
          checkpoints_dir: ./checkpoints                 
           continue_train: False                         
                crop_size: 256                           
                 dataroot: placeholder                   
             dataset_mode: unaligned                     
                direction: AtoB                          
              display_env: main                          
             display_freq: 400                           
               display_id: None                          
            display_ncols: 4                             
             display_port: 8097                          
           display_server: http://localhost              
          display_winsize: 256

In [7]:
opt.dataroot = f'/content/cut/datasets/{dataset_name}'
opt.name = f'{dataset_name}_CUT'
opt.n_epochs = 1
opt.n_epochs_decay = 1
opt.print_freq = 1
opt.save_latest_freq = 1
opt.save_epoch_freq = 1
opt.num_threads = 2
opt.load_size = 286
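`n_epochs` and `n_epochs_decay` together define the learning-rate schedule. As we read the default `linear` policy that CUT inherits from the CycleGAN/pix2pix codebase, the learning rate works in two phases:

- It stays at its initial value for `n_epochs` epochs.
- It then decays linearly over the next `n_epochs_decay` epochs.

The function below reproduces that multiplier as a standalone sketch. The name `lr_multiplier` is hypothetical, and the 1-based epoch convention assumes `epoch_count = 1`.

```python
def lr_multiplier(epoch, n_epochs, n_epochs_decay):
    """Learning-rate multiplier during training epoch `epoch` (1-based, assuming epoch_count = 1).

    Constant (1.0) for the first n_epochs epochs, then decreasing linearly
    over the following n_epochs_decay epochs.
    """
    return 1.0 - max(0, epoch - n_epochs) / float(n_epochs_decay + 1)


# Settings used in this notebook: n_epochs = 1, n_epochs_decay = 1
for epoch in range(1, 1 + 1 + 1):
    print(epoch, lr_multiplier(epoch, n_epochs=1, n_epochs_decay=1))
# prints "1 1.0" then "2 0.5"
```

If this reading is right, the short run in this notebook trains epoch 1 at the full learning rate and epoch 2 at half of it.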

In [8]:
dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
dataset_size = len(dataset)    # get the number of images in the dataset.

dataset [UnalignedDataset] was created


In [9]:
model = create_model(opt)      # create a model given opt.model and other options
print('The number of training images = %d' % dataset_size)


model [CUTModel] was created
The number of training images = 214


In [None]:
visualizer = Visualizer(opt)   # create a visualizer that displays/saves images and plots

Setting up a new session...
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/urllib3/connection.py", line 159, in _new_conn
    (self._dns_host, self.port), self.timeout, **extra_kw)
  File "/usr/local/lib/python3.7/dist-packages/urllib3/util/connection.py", line 80, in create_connection
    raise err
  File "/usr/local/lib/python3.7/dist-packages/urllib3/util/connection.py", line 70, in create_connection
    sock.connect(sa)
ConnectionRefusedError: [Errno 111] Connection refused

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 600, in urlopen
    chunked=chunked)
  File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 354, in _make_request
    conn.request(method, url, **httplib_request_kw)
  File "/usr/lib/python3.7/http/client.py", line 1281, in request
    self._send_request(method, url, body

Exception in user code:
------------------------------------------------------------


[Errno 99] Cannot assign requested address
on_close() takes 1 positional argument but 3 were given
[Errno 99] Cannot assign requested address
on_close() takes 1 positional argument but 3 were given
Visdom python client failed to establish socket to get messages from the server. This feature is optional and can be disabled by initializing Visdom with `use_incoming_socket=False`, which will prevent waiting for this request to timeout.




Could not connect to Visdom server. 
 Trying to start a server....
Command: /usr/bin/python3 -m visdom.server -p 8097 &>/dev/null &
create web directory ./checkpoints/grumpifycat_CUT/web...


In [None]:
opt.visualizer = visualizer
total_iters = 0
optimize_time = 0.1
times = []
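`optimize_time` starts at 0.1 here. Inside the training loop it is updated as an exponential moving average that gives weight 0.005 to the newest measurement (`optimize_time = new * 0.005 + 0.995 * optimize_time`). The `time:` column in the loss log therefore shows this smoothed value, not the raw per-iteration time. For example, the logged `time: 0.163` at iteration 1 is consistent with a slow first iteration of roughly 12.7 s, since 0.995 × 0.1 + 0.005 × 12.7 ≈ 0.163. This back-calculation is ours and ignores display rounding. The helper `update_ema` is a hypothetical name for the same update.

```python
def update_ema(avg, new_value, weight=0.005):
    """Exponential moving average with the same weights as the training loop."""
    return new_value * weight + (1.0 - weight) * avg


avg = 0.1  # initial value used in the notebook
# One slow first iteration (e.g. one that also runs data_dependent_initialize
# and network setup) moves the smoothed time only slightly:
avg = update_ema(avg, 12.7)
print(round(avg, 4))  # prints 0.163
```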

In [None]:
for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
    epoch_start_time = time.time()  # timer for entire epoch
    iter_data_time = time.time()    # timer for data loading per iteration
    epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
    visualizer.reset()              # reset the visualizer: make sure it saves the results to HTML at least once every epoch

    dataset.set_epoch(epoch)
    for i, data in enumerate(dataset):  # inner loop within one epoch
        iter_start_time = time.time()  # timer for computation per iteration
        if total_iters % opt.print_freq == 0:
            t_data = iter_start_time - iter_data_time

        batch_size = data["A"].size(0)
        total_iters += batch_size
        epoch_iter += batch_size
        if len(opt.gpu_ids) > 0:
            torch.cuda.synchronize()
        optimize_start_time = time.time()
        if epoch == opt.epoch_count and i == 0:
            model.data_dependent_initialize(data)
            model.setup(opt)               # regular setup: load and print networks; create schedulers
            model.parallelize()
        model.set_input(data)  # unpack data from dataset and apply preprocessing
        model.optimize_parameters()   # calculate loss functions, get gradients, update network weights
        if len(opt.gpu_ids) > 0:
            torch.cuda.synchronize()
        optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time

        if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file
            save_result = total_iters % opt.update_html_freq == 0
            model.compute_visuals()
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

        if total_iters % opt.print_freq == 0:    # print training losses and save logging information to the disk
            losses = model.get_current_losses()
            visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data)
            if opt.display_id is None or opt.display_id > 0:
                visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

        if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
            print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
            print(opt.name)  # it's useful to occasionally show the experiment name on console
            save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
            model.save_networks(save_suffix)

        iter_data_time = time.time()

    if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
        model.save_networks('latest')
        model.save_networks(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
    model.update_learning_rate()                     # update learning rates at the end of every epoch.

  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "


---------- Networks initialized -------------
[Network G] Total number of parameters : 11.378 M
[Network F] Total number of parameters : 0.560 M
[Network D] Total number of parameters : 2.765 M
-----------------------------------------------
(epoch: 1, iters: 1, time: 0.163, data: 0.188) G_GAN: 0.955 D_real: 1.002 D_fake: 0.000 G: 8.403 NCE: 7.309 NCE_Y: 7.587 
saving the latest model (epoch 1, total_iters 1)
grumpifycat_CUT
(epoch: 1, iters: 2, time: 0.167, data: 0.003) G_GAN: 0.788 D_real: 0.322 D_fake: 0.029 G: 6.426 NCE: 5.613 NCE_Y: 5.664 
saving the latest model (epoch 1, total_iters 2)
grumpifycat_CUT
(epoch: 1, iters: 3, time: 0.171, data: 0.002) G_GAN: 0.590 D_real: 0.147 D_fake: 1.345 G: 6.209 NCE: 5.586 NCE_Y: 5.651 
saving the latest model (epoch 1, total_iters 3)
grumpifycat_CUT
(epoch: 1, iters: 4, time: 0.175, data: 0.004) G_GAN: 0.452 D_real: 0.797 D_fake: 0.630 G: 6.030 NCE: 5.596 NCE_Y: 5.559 
saving the latest model (epoch 1, total_iters 4)
grumpifycat_CUT
(epoch: 1,

  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "


(epoch: 2, iters: 1, time: 0.705, data: 0.152) G_GAN: 0.272 D_real: 0.244 D_fake: 0.313 G: 3.555 NCE: 3.412 NCE_Y: 3.155 
saving the latest model (epoch 2, total_iters 215)
grumpifycat_CUT
(epoch: 2, iters: 2, time: 0.707, data: 0.001) G_GAN: 0.327 D_real: 0.353 D_fake: 0.140 G: 3.467 NCE: 3.325 NCE_Y: 2.956 
saving the latest model (epoch 2, total_iters 216)
grumpifycat_CUT
(epoch: 2, iters: 3, time: 0.708, data: 0.002) G_GAN: 0.299 D_real: 0.209 D_fake: 0.164 G: 3.101 NCE: 2.947 NCE_Y: 2.656 
saving the latest model (epoch 2, total_iters 217)
grumpifycat_CUT
(epoch: 2, iters: 4, time: 0.710, data: 0.003) G_GAN: 0.292 D_real: 0.178 D_fake: 0.205 G: 3.023 NCE: 2.800 NCE_Y: 2.663 
saving the latest model (epoch 2, total_iters 218)
grumpifycat_CUT
(epoch: 2, iters: 5, time: 0.711, data: 0.005) G_GAN: 0.308 D_real: 0.145 D_fake: 0.221 G: 2.989 NCE: 2.690 NCE_Y: 2.671 
saving the latest model (epoch 2, total_iters 219)
grumpifycat_CUT
(epoch: 2, iters: 6, time: 0.713, data: 0.008) G_GAN: 0
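In every complete log line above, the printed `G` value matches the adversarial term plus the average of the two PatchNCE terms. For iteration 1, 0.955 + (7.309 + 7.587) / 2 = 8.403. This fits our reading of the CUT model under its default settings, with the identity NCE term enabled and λ_NCE = 1. Still, treat the formula as an inference from these logs rather than a documented guarantee. The helpers `parse_loss_line` and `check_total` are hypothetical and check the relation on two of the lines shown.

```python
import re


def parse_loss_line(line):
    """Extract the 'NAME: value' pairs that follow the closing parenthesis of a log line."""
    _, _, losses = line.partition(")")
    return {name: float(value) for name, value in re.findall(r"(\w+): ([-\d.]+)", losses)}


def check_total(losses, tol=2e-3):
    """True if G matches G_GAN + (NCE + NCE_Y) / 2, allowing for the 3-decimal log rounding."""
    expected = losses["G_GAN"] + (losses["NCE"] + losses["NCE_Y"]) / 2
    return abs(losses["G"] - expected) <= tol


log = [
    "(epoch: 1, iters: 1, time: 0.163, data: 0.188) G_GAN: 0.955 D_real: 1.002 D_fake: 0.000 G: 8.403 NCE: 7.309 NCE_Y: 7.587 ",
    "(epoch: 2, iters: 1, time: 0.705, data: 0.152) G_GAN: 0.272 D_real: 0.244 D_fake: 0.313 G: 3.555 NCE: 3.412 NCE_Y: 3.155 ",
]
for line in log:
    print(check_total(parse_loss_line(line)))  # prints True for both lines
```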

In [None]:
!zip -r /content/checkpoints.zip /content/checkpoints/

  adding: content/checkpoints/ (stored 0%)
  adding: content/checkpoints/experiment_name/ (stored 0%)
  adding: content/checkpoints/experiment_name/train_opt.txt (deflated 79%)
  adding: content/checkpoints/grumpifycat_CUT/ (stored 0%)
  adding: content/checkpoints/grumpifycat_CUT/2_net_G.pth (deflated 7%)
  adding: content/checkpoints/grumpifycat_CUT/latest_net_D.pth (deflated 7%)
  adding: content/checkpoints/grumpifycat_CUT/latest_net_F.pth (deflated 7%)
  adding: content/checkpoints/grumpifycat_CUT/1_net_F.pth (deflated 7%)
  adding: content/checkpoints/grumpifycat_CUT/web/ (stored 0%)
  adding: content/checkpoints/grumpifycat_CUT/web/images/ (stored 0%)
  adding: content/checkpoints/grumpifycat_CUT/web/images/epoch002_idt_B.png (deflated 0%)
  adding: content/checkpoints/grumpifycat_CUT/web/images/epoch002_real_B.png (deflated 0%)
  adding: content/checkpoints/grumpifycat_CUT/web/images/epoch002_real_A.png (deflated 0%)
  adding: content/checkpoints/grumpifycat_CUT/web/images/epoc
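If the `zip` command is unavailable, the standard library's `shutil.make_archive` can build an equivalent archive from Python. The helper `archive_checkpoints` is a hypothetical wrapper. Its default paths mirror the `/content/checkpoints` folder archived above; adjust them as needed.

```python
import os
import shutil


def archive_checkpoints(root_dir="/content", folder="checkpoints"):
    """Create <root_dir>/<folder>.zip containing the <folder> directory; return the archive path."""
    base_name = os.path.join(root_dir, folder)  # archive path without the ".zip" extension
    # root_dir/base_dir keep entry names relative, e.g. "checkpoints/grumpifycat_CUT/..."
    return shutil.make_archive(base_name, "zip", root_dir=root_dir, base_dir=folder)


# In Colab: archive_checkpoints() would write /content/checkpoints.zip
```

Unlike the shell command, whose entries start with `content/checkpoints/`, this version stores entries relative to `/content`, so they start with `checkpoints/`.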