In [None]:
import os
# Work from the project root so the relative data/ and submodule paths below resolve.
# Set the IDP_ROOT environment variable on other machines; the fallback is the
# original author's checkout.
PROJECT_ROOT = os.environ.get('IDP_ROOT', '/home/extra/micheal/IDP')
os.chdir(PROJECT_ROOT)

# Precise cut and paste

In this notebook, we train the pix2pix model to learn an artificially combined b-scan of OCT and OP from their combined segmentation labels.

In [None]:
# pip install opencv-python
from collections import namedtuple
from glob import glob
from pathlib import Path
from os.path import join
import cv2 as cv
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from idp_utils.visualization import visualize_segmentation, plot_horizontal
from idp_utils.data_handling.mask import expand_label, get_dst_shadow
import idp_utils.data_handling.constants as C

### Randomly choose source and target

First, we randomly choose an OCT and an iOCT image from the datasets, whose segmentations and b-scans will later be combined for training.


This code has been moved to `idp_utils.data_handling.ulabel`:

```python
Label = namedtuple("Label", "instrument mirror ilm ipl rpe bm", defaults=(None,)*6)
# Integer value of each class per dataset; fields left as None do not exist in that dataset.
unified_label = Label(instrument=1, mirror=2, ilm=3, ipl=4, rpe=5, bm=6)
aroi_label = Label(ilm=19, ipl=57, rpe=171, bm=190)
op_label = Label(instrument=2, mirror=4, ilm=1, rpe=3)

def unify_label(label, src_labels, dst_labels, remove_list=[]):
    """Transform values in label from src_labels to dst_labels.
    It will return a new label. The original label will be left untouched.
    Matching is always done against the ORIGINAL ``label`` values, so remappings
    never cascade (a value mapped onto another class's source value is not
    remapped a second time).
    Args:
        label (numpy array)
        src_labels (namedtuple(Label)): value of each class in ``label``;
            classes that are None are not remapped
        dst_labels (namedtuple(Label)): target value of each class;
            classes that are None are not remapped
        remove_list (list(int)): all labels in this list will be set to 0
            (also matched against the original values of ``label``)
    Returns:
        numpy array: the remapped copy of ``label``
    """
    label_copy = label.copy()
    for l in Label._fields:
        s_label = getattr(src_labels, l)
        d_label = getattr(dst_labels, l)
        if s_label is not None and d_label is not None:
            label_copy[label == s_label] = d_label
    # The mutable default is harmless here: remove_list is only iterated, never modified.
    for l in remove_list:
        label_copy[label == l] = 0
    return label_copy
```

In [None]:
from idp_utils.data_handling.ulabel import unified_label, aroi_label, op_label, unify_label

In [None]:
def get_bscan_label(img_name, bscan_prefix="data/ioct/bscans/val/", label_prefix="data/ioct/labels/val/"):
    """Load one b-scan and its segmentation label, which share the file name ``img_name``.

    Args:
        img_name (str): file name present in both ``bscan_prefix`` and ``label_prefix``.
        bscan_prefix (str): folder holding the b-scan images.
        label_prefix (str): folder holding the label images.
    Returns:
        tuple: (b-scan as a PIL ``Image``, label as a numpy array)
    """
    # Both files are opened in context managers so their handles are closed right
    # away instead of whenever the garbage collector gets to them.
    with Image.open(os.path.join(bscan_prefix, img_name)) as bscan_img:
        # copy() loads the pixels into a new image that is independent of the file.
        bscan = bscan_img.copy()
    with Image.open(os.path.join(label_prefix, img_name)) as label_img:
        label = np.asarray(label_img)
    return bscan, label

In [None]:
def create_crossover(src_bscan, src_label, dst_bscan, dst_label, mask):
    """Paste the masked pixels of the source onto copies of the destination.

    The same ``mask`` is applied to the b-scans and to the labels, so the returned
    b-scan and label stay aligned. The inputs are not modified.

    Args:
        src_bscan, src_label: source b-scan / label (arrays or array-likes such as
            PIL images).
        dst_bscan, dst_label: destination b-scan / label; must have the same shapes
            as their source counterparts.
        mask (np.ndarray): index selecting the pixels taken from the source
            (a boolean mask in this notebook).
    Returns:
        tuple(np.ndarray, np.ndarray): (cross-over b-scan, cross-over label).
    Raises:
        ValueError: if source and destination shapes differ.
    """
    # np.asarray lets PIL images be passed directly (they have no .copy() that
    # supports boolean-mask assignment).
    src_bscan, dst_bscan = np.asarray(src_bscan), np.asarray(dst_bscan)
    src_label, dst_label = np.asarray(src_label), np.asarray(dst_label)
    if src_bscan.shape != dst_bscan.shape:
        raise ValueError(f"b-scan shape mismatch: source {src_bscan.shape} vs destination {dst_bscan.shape}")
    if src_label.shape != dst_label.shape:
        raise ValueError(f"label shape mismatch: source {src_label.shape} vs destination {dst_label.shape}")
    # .copy() also turns read-only arrays (e.g. from PIL) into writable ones.
    cross_bscan = dst_bscan.copy()
    cross_bscan[mask] = src_bscan[mask]
    cross_label = dst_label.copy()
    cross_label[mask] = src_label[mask]
    return cross_bscan, cross_label

Visualize the intra-operative b-scans (OP dataset) and their segmentations.

In [None]:
# Load one intra-operative (OP / 'ioct') validation b-scan and its label, remap the
# label values to the unified label set, and visualize the segmentation
# (show_original=True presumably also draws the raw b-scan - see idp_utils.visualization).
bscan_prefix = os.path.join(C.SPLIT_PATTERN.format(data='ioct'), 'bscans', 'val')
label_prefix = os.path.join(C.SPLIT_PATTERN.format(data='ioct'), 'labels', 'val')
img_name = "5d396575-e039-49da-a219-fe239e8bd9c88062-101.png"

src_bscan, src_label = get_bscan_label(img_name, bscan_prefix, label_prefix)
src_label=unify_label(src_label, src_labels=op_label, dst_labels=unified_label)
visualize_segmentation(src_bscan, src_label, show_original=True)

Visualize the OCT b-scans (AROI dataset) and their segmentations.

In [None]:
# Load one AROI training b-scan and its label as the destination image, remap the
# label to the unified set and drop the AROI values 80/160/240, which have no
# unified class (presumably fluid annotations - TODO confirm).
dst_bscan_prefix = os.path.join(C.SPLIT_PATTERN.format(data='aroi'), 'bscans', 'train') # "data/aroi/bscans/train/"
dst_label_prefix = os.path.join(C.SPLIT_PATTERN.format(data='aroi'), 'labels', 'train') # "data/aroi/labels/train/"
dst_img_name = "patient21_raw0060.png"  # alternative example: "patient15_raw0032.png"
dst_bscan, dst_label = get_bscan_label(dst_img_name, dst_bscan_prefix, dst_label_prefix)
dst_label = unify_label(dst_label, src_labels=aroi_label, dst_labels=unified_label, remove_list=[80, 160, 240])
visualize_segmentation(dst_bscan, dst_label, show_original=True)

## Part 1: Manipulate on the bscan

The goal is to artificially create a cross-over image.

To identify which part should be copied from which image, we need to know the location of the instrument and the mirror in the image. We can use the segmentation labels to identify these two parts, by **generating masks** from them.

### Step 1: Copy and paste the tool with the manually expanded mask

Such a mask is very inaccurate.

TODO:
- improve mask

#### 1. Expand instrument and mirroring label

In [None]:
# Grow the instrument region of the source label (expansion 30, presumably pixels)
# and leave the mirror region unchanged; expand_upward=True is forwarded to
# idp_utils.data_handling.mask.expand_label, which defines its exact semantics.
expanded_label = expand_label(src_label,
                              instrument_label=unified_label.instrument,
                              mirror_label=unified_label.mirror,
                              expansion_instrument=30,
                              expansion_mirror=0,
                              expand_upward=True)
plt.imshow(expanded_label)

#### 2. Cover the layers with the target label

In [None]:
def get_dst_shadow(src_label, dst_label, instrument_label=2, mirror_label=4,
                   top_layer_label=1, margin_above=0, pad_left=0, pad_right=0):
    """Get the shadow of the source tool in the destination label.

    Takes the instrument and mirror labels in the source as well as the top-layer
    label in the destination into account. Note that this definition shadows the
    ``get_dst_shadow`` imported from ``idp_utils.data_handling.mask`` at the top of
    the notebook; the cells below use this version.

    The horizontal extent starts at the leftmost/rightmost tool column of
    ``src_label``. It is then adjusted across tool-edge columns where the top layer is
    still visible in the source, and narrowed by ``pad_left`` / ``pad_right``. For
    every column in ``[left_bound, right_bound)``, the shadow spans from the
    destination's top layer (minus ``margin_above``) down to the last row.

    Args:
        src_label (np.ndarray): 2-D source label map that contains the tool.
        dst_label (np.ndarray): 2-D destination label map that contains the layers;
            assumed to have the same height as ``src_label``.
        instrument_label (int): value of the instrument in ``src_label``.
        mirror_label (int): value of the instrument's mirroring in ``src_label``.
        top_layer_label (int): value of the top layer in both label maps.
        margin_above (int): rows to include above the top layer, for cases where the
            human-labeled top layer is inaccurate and leaves some pixels out. The
            resulting start row is clamped at 0.
        pad_left (int): columns by which the left bound is moved right, for cases where
            the area directly under the instrument still contains layers. Skipped if it
            would reach the right bound.
        pad_right (int): columns by which the right bound is moved left. Skipped if it
            would reach the left bound.

    Returns:
        tuple(np.ndarray, np.ndarray): int64 row indices and column indices of the
        shadow pixels; both empty when ``src_label`` has no tool pixels.
    """
    empty = (np.array([], dtype=np.int64), np.array([], dtype=np.int64))
    img_height, img_width = src_label.shape
    # Rows/columns of every instrument or mirror pixel in the source.
    _, tool_cols = np.where((src_label == instrument_label) | (src_label == mirror_label))
    if tool_cols.size == 0:
        return empty
    left_bound = int(tool_cols.min())
    right_bound = int(tool_cols.max())

    # While the source still shows the top layer at the tool's left edge, move
    # left_bound right (it stops at the last such column).
    for col in range(left_bound, img_width):
        if np.any(src_label[:, col] == top_layer_label):
            left_bound = col
        else:
            break
    # Same for the right edge, moving right_bound left.
    for col in range(right_bound, -1, -1):
        if np.any(src_label[:, col] == top_layer_label):
            right_bound = col
        else:
            break

    # Narrow the shadow by the requested padding, but never past the opposite bound.
    if left_bound + pad_left < right_bound:
        left_bound += pad_left
    if right_bound - pad_right > left_bound:
        right_bound -= pad_right

    row_chunks, col_chunks = [], []
    # Highest (smallest-row) top-layer start seen so far. Columns where the destination
    # has no top layer reuse it. None means no column with a layer has been seen yet,
    # which keeps it distinct from a legitimate start row of 0.
    min_upperbound = None
    for col in range(left_bound, right_bound):
        top_layer_rows = np.flatnonzero(dst_label[:, col] == top_layer_label)
        if top_layer_rows.size == 0:
            if min_upperbound is None:
                continue
            upperbound = min_upperbound
        else:
            upperbound = max(int(top_layer_rows.min()) - margin_above, 0)
            min_upperbound = upperbound if min_upperbound is None else min(min_upperbound, upperbound)
        rows = np.arange(upperbound, img_height, dtype=np.int64)  # upperbound to bottom
        row_chunks.append(rows)
        col_chunks.append(np.full_like(rows, col))
    if not row_chunks:
        return empty
    # Concatenate once instead of growing the arrays inside the loop (quadratic).
    return np.concatenate(row_chunks), np.concatenate(col_chunks)

In [None]:
# Destination pixels in the shadow of the (expanded) source tool: for each shadowed
# column, the rows from the AROI ILM (minus a 5-row margin) down to the image bottom.
shadow_x, shadow_y = get_dst_shadow(expanded_label, dst_label, instrument_label=unified_label.instrument,
                                    mirror_label=unified_label.mirror, top_layer_label=unified_label.ilm,
                                    margin_above=5, pad_left=0)

In [None]:
# Paste region = expanded tool pixels (instrument + mirror) plus the shadow they
# cast on the destination layers.
classes_of_interest = [unified_label.instrument, unified_label.mirror]
mask = np.isin(expanded_label, classes_of_interest)
mask[shadow_x, shadow_y] = True
plt.imshow(mask)

#### 3. Copy and paste

In [None]:
# Turn both PIL b-scans into numpy arrays so they can be indexed with the boolean mask.
dst_bscan, src_bscan = np.asarray(dst_bscan), np.asarray(src_bscan)

In [None]:
# Paste the masked source pixels onto a copy of the destination b-scan.
cross_bscan = dst_bscan.copy()
cross_bscan[mask] = src_bscan[mask]
# The b-scans are single-channel grayscale images (saved with mode 'L' below), so
# display them in gray rather than matplotlib's default colormap.
plt.imshow(cross_bscan, cmap='gray');

## Part 2: Manipulate the label

The goal of this part is to create the label corresponding to the b-scan from the first part. The intuition is that manipulating the label map is very precise; we can then generate an image from the manipulated label map.

### 1. Create Cross-over labels

In [None]:
# Build the cross-over b-scan and label with the same mask, so they stay aligned.
cross_bscan, cross_label = create_crossover(src_bscan, src_label, dst_bscan, dst_label, mask)
visualize_segmentation(cross_bscan, cross_label, show_original=True)

### 2. Expand Instrument Labels (Optional) 

Here I'll expand the instrument and mirroring labels. In a previous experiment, we found that the generated instruments are blurred.

The intuition is that maybe a better coverage of the label helps generate clearer instruments.

In [None]:
from idp_utils.data_handling.mask import expand_label
from idp_utils.data_handling.ulabel import unified_label

In [None]:
# Quick check of the unified value used for the instrument class.
unified_label.instrument

In [None]:
# Grow the instrument (by 20) and mirror (by 40) regions of the cross-over label.
# Note: this rebinds `expanded_label`, which in Part 1 held the expanded source label.
expanded_label = expand_label(cross_label,
                              instrument_label=unified_label.instrument,
                              mirror_label=unified_label.mirror,
                              expansion_instrument=20,
                              expansion_mirror=40)
plt.imshow(expanded_label)

In [None]:
# Overlay the expanded label on the cross-over b-scan to check how well it covers the tool.
visualize_segmentation(cross_bscan, expanded_label, show_original=True)

## Part 3: Create a dataset

Per the requirements of [pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/datasets.md#pix2pix-datasets), we have to create the dataset and store it in a specific format.

The ideal way is to implement a custom sampler [link](https://discuss.pytorch.org/t/how-to-generate-random-pairs-at-each-epoch/112065/2).

Here I'll do it in a simpler way.
1. Randomly choose an OP and an AROI image
2. Create a mask based on the OP and the AROI image
3. Make the cross over label and bscan
4. Save to disk

### 1. Create a pair

In [None]:
from tqdm import tqdm
import random

# Split roots (each containing train/val/test) of the OP ('ioct') and AROI datasets.
op_bscan_folder, op_label_folder = (
    os.path.join(C.SPLIT_PATTERN.format(data='ioct'), sub) for sub in ('bscans', 'labels')
)
aroi_bscan_folder, aroi_label_folder = (
    os.path.join(C.SPLIT_PATTERN.format(data='aroi'), sub) for sub in ('bscans', 'labels')
)
def _list_paired_files(bscan_prefix, label_prefix, role):
    """Return sorted (bscans, labels) path lists from two folders, paired by file name.

    Raises ValueError if the b-scan folder is empty, the counts differ, or names differ.
    """
    bscans = sorted(glob(os.path.join(bscan_prefix, '*')))
    labels = sorted(glob(os.path.join(label_prefix, '*')))
    if not bscans:
        raise ValueError(f"No {role} bscans found in {bscan_prefix!r}")
    if len(bscans) != len(labels):
        raise ValueError(f"Length mismatch for {role} bscans and labels ({len(bscans)}!={len(labels)})")
    for bscan_path, label_path in zip(bscans, labels):
        if os.path.basename(bscan_path) != os.path.basename(label_path):
            raise ValueError(f"Unpaired {role} files: {bscan_path!r} vs {label_path!r}")
    return bscans, labels


def sample(src_bscan_prefix, src_label_prefix, dst_bscan_prefix, bscan_label_prefix, n_pairs=1000, seed=None):
    """
    Sample pairs of source and target images. Source is where we will cut the instruments and
    mirrorings, destination is where we want to preserve the layers. Each b-scan folder and
    its label folder are expected to contain files with identical names.

    Args:
        src_bscan_prefix (str): folder with the source b-scans.
        src_label_prefix (str): folder with the source labels.
        dst_bscan_prefix (str): folder with the destination b-scans.
        bscan_label_prefix (str): folder with the destination labels (historical name;
            it is the destination counterpart of ``src_label_prefix``).
        n_pairs (int): number of pairs to draw, with replacement.
        seed (int or None): seed for a private ``random.Random``; None draws from the
            global ``random`` state.

    Returns:
        pairs (list(tuple)): a list of tuples of paths [(src_bscan, src_label, dst_bscan, dst_label)]

    Raises:
        ValueError: if a folder is empty or its b-scans and labels cannot be paired.
    """
    # glob() returns files in arbitrary order, so both lists are sorted and checked by
    # file name; otherwise a b-scan could be paired with another image's label.
    src_bscans, src_labels = _list_paired_files(src_bscan_prefix, src_label_prefix, 'source')
    # Use the argument (not a notebook-global `dst_label_prefix`) for the destination labels.
    dst_bscans, dst_labels = _list_paired_files(dst_bscan_prefix, bscan_label_prefix, 'destination')
    rng = random if seed is None else random.Random(seed)
    pairs = []
    for _ in range(n_pairs):
        src_idx = rng.randrange(len(src_bscans))
        dst_idx = rng.randrange(len(dst_bscans))
        pairs.append((src_bscans[src_idx], src_labels[src_idx], dst_bscans[dst_idx], dst_labels[dst_idx]))
    return pairs

In [None]:
# Seed the global RNG so the sampled pairs, and therefore the generated dataset,
# are reproducible across runs.
SEED = 42
random.seed(SEED)

sampled_pairs = {}
# NOTE: the *_prefix loop variables are module-level names; they overwrite
# dst_bscan_prefix / dst_label_prefix from the single-pair demo above.
for split, n_pairs in zip(['train', 'val', 'test'], [10000, 1000, 1000]):
    src_bscan_prefix = join(op_bscan_folder, split)
    src_label_prefix = join(op_label_folder, split)
    dst_bscan_prefix = join(aroi_bscan_folder, split)
    dst_label_prefix = join(aroi_label_folder, split)
    sampled_pairs[split] = sample(src_bscan_prefix, src_label_prefix, dst_bscan_prefix, dst_label_prefix, n_pairs)

### 2. Create a mask for a pair

In [None]:
def load_as_array(img_path, label_type=None):
    """Load the image at ``img_path`` into a numpy array.

    If ``label_type`` is set, the label values are converted to the unified label set.

    Args:
        img_path (str): path of the image file.
        label_type (str or None): None (no conversion), 'op' or 'aroi'.
    Returns:
        np.ndarray: the pixel array (a remapped copy when ``label_type`` is set).
    Raises:
        ValueError: if ``label_type`` is not one of the accepted values.
    """
    if label_type not in (None, 'op', 'aroi'):
        raise ValueError(f"label_type must be None, 'op' or 'aroi', got {label_type!r}")
    # The context manager closes the file handle instead of leaving it to the GC;
    # np.asarray reads the pixel data before the file is closed.
    with Image.open(img_path) as img:
        img_arr = np.asarray(img)
    if label_type == 'op':
        img_arr = unify_label(img_arr, src_labels=op_label, dst_labels=unified_label)
    elif label_type == 'aroi':
        # 80/160/240 are AROI values not listed in aroi_label; they are zeroed out
        # (presumably the fluid annotations - TODO confirm).
        img_arr = unify_label(img_arr, src_labels=aroi_label, dst_labels=unified_label, remove_list=[80, 160, 240])
    return img_arr

def create_mask(src_bscan, src_label, dst_bscan, dst_label,
                expansion_instrument=30, expansion_mirror=0, margin_above=20):
    """Build the boolean paste mask for one (OP source, AROI destination) pair.

    The mask is True on the expanded source instrument/mirror pixels and on the
    destination region in their shadow (computed by ``get_dst_shadow``).

    Args:
        src_bscan: source b-scan. Not used to build the mask; kept so the four loaded
            arrays can be splatted in (``create_mask(*loaded_pair)``).
        src_label (np.ndarray): source label in unified values.
        dst_bscan: destination b-scan (unused, see ``src_bscan``).
        dst_label (np.ndarray): destination label in unified values.
        expansion_instrument (int): instrument expansion passed to ``expand_label``.
        expansion_mirror (int): mirror expansion passed to ``expand_label``.
        margin_above (int): margin passed to ``get_dst_shadow``.
    Returns:
        np.ndarray: boolean mask with the shape of the (expanded) source label.
    """
    # 1. expand the source tool labels (expand_upward=True - see expand_label)
    expanded_src_label = expand_label(src_label,
                                      instrument_label=unified_label.instrument,
                                      mirror_label=unified_label.mirror,
                                      expansion_instrument=expansion_instrument,
                                      expansion_mirror=expansion_mirror,
                                      expand_upward=True)
    # 2. get the shadowed area based on the dst label
    shadow_x, shadow_y = get_dst_shadow(expanded_src_label,
                                        dst_label,
                                        instrument_label=unified_label.instrument,
                                        mirror_label=unified_label.mirror,
                                        top_layer_label=unified_label.ilm,
                                        margin_above=margin_above,
                                        pad_left=0,
                                        pad_right=0)
    # 3. combine tool pixels and shadow. The mask is sized from the label rather than
    #    the b-scan: it is built from label coordinates and is also applied to labels.
    mask = np.isin(expanded_src_label, [unified_label.instrument, unified_label.mirror])
    mask[shadow_x, shadow_y] = True
    return mask

### 3. Create cross over bscan and label

In [None]:
# Walk one sampled training pair through the pipeline: load + unify labels,
# build the mask, and paste the source tool region onto the destination.
sampled_pair = sampled_pairs['train'][233]
print(sampled_pair)
src_bscan_path, src_label_path, dst_bscan_path, dst_label_path = sampled_pair
# Order matches the create_mask / create_crossover signatures, so it can be splatted.
loaded_pair = (
    load_as_array(src_bscan_path),
    load_as_array(src_label_path, label_type='op'),
    load_as_array(dst_bscan_path),
    load_as_array(dst_label_path, label_type='aroi')    
)
mask = create_mask(*loaded_pair)
cross_bscan, cross_label = create_crossover(*loaded_pair, mask)
# optional: expand cross label
# cross_label = expand_label(cross_label,
#                            instrument_label=unified_label.instrument,
#                            mirror_label=unified_label.mirror,
#                            expansion_instrument=20,
#                            expansion_mirror=40)
visualize_segmentation(cross_bscan, cross_label, show_original=True)

In [None]:
# Show the sampled pair side by side: source, destination, cross-over result and the mask.
plot_horizontal(loaded_pair+(cross_bscan, cross_label, mask),
                ['s_bscan', 's_label', 'd_bscan', 'd_label', 'c_bscan', 'c_label', 'mask'],
                figsize=(15, 4))

### 4. Mass produce a dataset

In [None]:
# Pre-create <cross root>/{bscans,labels}/{train,val,test}.
# NOTE(review): nothing below writes into these 'cross' folders (the generated data
# goes to 'cross_large') - confirm whether this cell is still needed.
cross_dataset_path = C.SPLIT_PATTERN.format(data='cross')
for typ, split in [(t, s) for t in ['bscans', 'labels'] for s in ['train', 'val', 'test']]:
    data_dir = join(cross_dataset_path, typ, split)
    Path(data_dir).mkdir(parents=True, exist_ok=True)

In [None]:
from os.path import basename
def create_cross_pair(src_bscan_path, src_label_path, dst_bscan_path, dst_label_path, expand=False):
    """Load one sampled (OP source, AROI destination) path quadruple and build its cross-over.

    Labels are unified on load ('op' for the source, 'aroi' for the destination). The
    paste mask comes from ``create_mask`` and is applied with ``create_crossover``.

    Args:
        src_bscan_path, src_label_path (str): OP b-scan / label files.
        dst_bscan_path, dst_label_path (str): AROI b-scan / label files.
        expand (bool): if True, grow the instrument (20) and mirror (40) regions of the
            resulting label with ``expand_label``.
    Returns:
        tuple: (cross-over b-scan, cross-over label)
    """
    src_bscan = load_as_array(src_bscan_path)
    src_label = load_as_array(src_label_path, label_type='op')
    dst_bscan = load_as_array(dst_bscan_path)
    dst_label = load_as_array(dst_label_path, label_type='aroi')
    paste_mask = create_mask(src_bscan, src_label, dst_bscan, dst_label)
    cross_bscan, cross_label = create_crossover(src_bscan, src_label, dst_bscan, dst_label, paste_mask)
    if not expand:
        return cross_bscan, cross_label
    expanded = expand_label(cross_label,
                            instrument_label=unified_label.instrument,
                            mirror_label=unified_label.mirror,
                            expansion_instrument=20,
                            expansion_mirror=40)
    return cross_bscan, expanded

In [None]:
# Generate the expanded cross-over dataset: <cross_large>/{bscans,labels}/<split>/<name>.png
expanded_cross_dataset_path = C.SPLIT_PATTERN.format(data='cross_large')
bscan_dir = join(expanded_cross_dataset_path, 'bscans')
label_dir = join(expanded_cross_dataset_path, 'labels')
# Write every split that was sampled above. The test split is included so the test
# phase of the combined pix2pix dataset (used in Part 5) is not empty.
WRITE_SPLITS = ['train', 'val', 'test']
for split in WRITE_SPLITS:
    for typ_dir in (bscan_dir, label_dir):
        Path(typ_dir, split).mkdir(parents=True, exist_ok=True)

for split in WRITE_SPLITS:
    for pair in tqdm(sampled_pairs[split], desc=split):
        cross_bscan, cross_label = create_cross_pair(*pair, expand=True)
        # "<source stem>-<destination stem>.png". Path.stem keeps dots inside the stem
        # intact. A pair sampled twice maps to the same name and is simply overwritten.
        img_name = f"{Path(pair[0]).stem}-{Path(pair[2]).stem}.png"
        Image.fromarray(cross_bscan, mode='L').save(join(bscan_dir, split, img_name))
        Image.fromarray(cross_label, mode='L').save(join(label_dir, split, img_name))

## Part 4. Train a pix2pix on the artificially generated images

Check [pix2pix train/test](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix#pix2pix-traintest)

### 1. Create a dataset with the dataset creation script

`python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data`

In [None]:
# Output folder of the combined (A|B) pix2pix dataset; the next cell assigns the same value again.
cross_combined_dir = C.DATASET_PATTERN.format(data='cross_large_pix')

In [None]:
cross_combined_dir = C.DATASET_PATTERN.format(data='cross_large_pix')
!mkdir -p cross_combined_dir
cross_combined_dir

In [None]:
# DANGER: uncommenting this deletes the whole combined dataset folder (use to rebuild it from scratch).
# !rm -rf $cross_combined_dir

In [None]:
# Print the resolved input (A = labels, B = b-scans) and output folders before combining.
!echo label_dir: $label_dir
!echo bscan_dir: $bscan_dir
!echo cross_combined_dir: $cross_combined_dir

In [None]:
# Combine each label (A) with its b-scan (B) into the aligned A|B image format that
# pix2pix reads (see the pix2pix dataset docs linked above).
!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A $label_dir --fold_B $bscan_dir --fold_AB $cross_combined_dir --grayscale

In [None]:
!ls -l data/datasets/cross_large_pix/train | wc -l

In [None]:
# expanded
# NOTE(review): label_dir / bscan_dir still point at the 'cross_large' folders, i.e. the
# same inputs used for $cross_combined_dir above, so both combined datasets are built
# from identical data - confirm this is intended.
expanded_cross_combined_dir = C.DATASET_PATTERN.format(data='expanded_cross_pix')
!echo label_dir: $label_dir
!echo bscan_dir: $bscan_dir
!echo expanded_cross_combined_dir: $expanded_cross_combined_dir
!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A $label_dir --fold_B $bscan_dir --fold_AB $expanded_cross_combined_dir --grayscale

In [None]:
# Show the resolved path of the expanded combined dataset.
expanded_cross_combined_dir

### 2. Train a pix2pix model

`python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA`

In [None]:
# List pix2pix's training options.
!python pytorch-CycleGAN-and-pix2pix/train.py  --help

#### Vanilla cross label

**Intuition**: The desired output is the bscan with part from the AROI and part from the OP. Therefore we can explicitly fabricate the desired output as well as the cross-label input, with simple cut-and-paste.

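The model evaluated in Part 5 was trained with `--name pix_fused`; the options log below is from the run named `cross_large_pix`.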

```verbose log
----------------- Options ---------------
               batch_size: 64                            	[default: 1]
                    beta1: 0.5                           
          checkpoints_dir: ./checkpoints                 
           continue_train: False                         
                crop_size: 256                           
                 dataroot: data/datasets/cross_large_pix 	[default: None]
             dataset_mode: aligned                       
                direction: AtoB                          
              display_env: main                          
             display_freq: 400                           
               display_id: 1                             
            display_ncols: 4                             
             display_port: 8097                          
           display_server: http://localhost              
          display_winsize: 256                           
                    epoch: latest                        
              epoch_count: 1                             
                 gan_mode: vanilla                       
                  gpu_ids: 0                             
                init_gain: 0.02                          
                init_type: normal                        
                 input_nc: 3                             
                  isTrain: True                          	[default: None]
                lambda_L1: 100.0                         
                load_iter: 0                             	[default: 0]
                load_size: 286                           
                       lr: 0.0002                        
           lr_decay_iters: 50                            
                lr_policy: linear                        
         max_dataset_size: inf                           
                    model: pix2pix                       	[default: cycle_gan]
                 n_epochs: 400                           	[default: 100]
           n_epochs_decay: 100                           
               n_layers_D: 3                             
                     name: cross_large_pix               	[default: experiment_name]
                      ndf: 64                            
                     netD: basic                         
                     netG: unet_256                      
                      ngf: 64                            
               no_dropout: False                         
                  no_flip: False                         
                  no_html: False                         
                     norm: batch                         
              num_threads: 4                             
                output_nc: 3                             
                    phase: train                         
                pool_size: 0                             
               preprocess: resize_and_crop               
               print_freq: 500                           	[default: 100]
             save_by_iter: False                         
          save_epoch_freq: 20                            	[default: 5]
         save_latest_freq: 5000                          
           serial_batches: False                         
                   suffix:                               
         update_html_freq: 1000                          
                use_wandb: False                         
                  verbose: False       
```

In [None]:
# Train pix2pix on the cross-over dataset (A = label -> B = b-scan, i.e. AtoB) on GPU 1,
# saving a checkpoint every 20 epochs.
!CUDA_VISIBLE_DEVICES=1 python pytorch-CycleGAN-and-pix2pix/train.py \
    --dataroot $cross_combined_dir \
    --name cross_large_pix \
    --model pix2pix \
    --direction AtoB \
    --n_epochs 400 \
    --print_freq 500 \
    --batch_size 64 \
    --save_epoch_freq 20

#### Expanded labels

**Intuition**: The expanded instrument labels cover the full area of the instruments, which helps guide the model to generate clearer instruments.

In [None]:
# train with expanded
!CUDA_VISIBLE_DEVICES=1 python pytorch-CycleGAN-and-pix2pix/train.py \
    --dataroot $expanded_cross_combined_dir \
    --name "expanded_pix_fused" \
    --model pix2pix \
    --direction AtoB \
    --n_epochs 400 \
    --print_freq 500 \
    --batch_size 64 \
    --save_epoch_freq 20

In [None]:
# Dataset root and experiment name for the one-hot runs. The path is built from the same
# pattern as the other dataset folders instead of being hard-coded.
dataroot = C.DATASET_PATTERN.format(data='cross_combined')
name = 'onehot_cross_pix'

#### One-hot input

We train a model with `input_nc=n_labels`, `output_nc=1` and `dataset_mode=oct`.

The *oct* mode is implemented in the submodule pix2pix.

**Intuition**: instead of cramming the labels into a single channel as arbitrary sequential integers, we can one-hot encode them so that different labels are semantically independent.

In [None]:
# One-hot run: 7 input channels (background + the 6 unified classes) and 1 grayscale
# output channel, using the custom `oct` dataset mode from submodules/pix2pix.
!echo $dataroot
!CUDA_VISIBLE_DEVICES=1 python submodules/pix2pix/train.py \
    --dataroot $dataroot \
    --name $name \
    --model pix2pix \
    --direction AtoB \
    --n_epochs 100 \
    --print_freq 500 \
    --batch_size 64 \
    --save_epoch_freq 20 \
    --input_nc 7 \
    --output_nc 1 \
    --dataset_mode oct

In [None]:
# Same one-hot setup but with 3 output channels and 200 epochs (rebinds `name`).
name = 'onehot_cross_pix_o3'
!echo $dataroot
!CUDA_VISIBLE_DEVICES=1 python submodules/pix2pix/train.py \
    --dataroot $dataroot \
    --name $name \
    --model pix2pix \
    --direction AtoB \
    --n_epochs 200 \
    --print_freq 500 \
    --batch_size 64 \
    --save_epoch_freq 20 \
    --input_nc 7 \
    --output_nc 3 \
    --dataset_mode oct

## Part 5: Test the Trained Model

Results are saved under `<results_dir>/pix_fused/test_{epoch}`, where `<results_dir>` is the `--results_dir` passed to `test.py` in the cell below.

In [None]:
for ep in [10, 20, 30, 40, 50, 60, 100, 150]:
    !CUDA_VISIBLE_DEVICES=1 python pytorch-CycleGAN-and-pix2pix/test.py  \
        --dataroot $cross_combined_dir \
        --name "pix_fused" \
        --model pix2pix \
        --direction AtoB \
        --epoch {ep} \
        --results_dir "fused_on_combined"

Evaluate on iOCT data:

In [None]:
# Evaluate the `pix_fused` model on the iOCT ('uop') dataset at several saved epochs.
uop_dataset_dir = C.DATASET_PATTERN.format(data='uop')
for ep in [80, 100, 150]:
    !CUDA_VISIBLE_DEVICES=1 python pytorch-CycleGAN-and-pix2pix/test.py  \
        --dataroot $uop_dataset_dir \
        --name "pix_fused" \
        --model pix2pix \
        --direction AtoB \
        --epoch {ep} \
        --batch_size 10 \
        --results_dir "./results/fused_on_ioct"

In [None]:
# Show the resolved iOCT ('uop') dataset path (same value as in the previous cell).
uop_dataset_dir = C.DATASET_PATTERN.format(data='uop')
uop_dataset_dir

Evaluate the model with expanded instrument label on iOCT:

In [None]:
# Evaluate the expanded-label model on the iOCT ('uop') dataset at several saved epochs.
uop_dataset_dir = C.DATASET_PATTERN.format(data='uop')
for ep in [20, 40, 60, 140, 300]:
    !CUDA_VISIBLE_DEVICES=1 python pytorch-CycleGAN-and-pix2pix/test.py  \
        --dataroot $uop_dataset_dir \
        --name "expanded_pix_fused" \
        --model pix2pix \
        --direction AtoB \
        --epoch {ep} \
        --batch_size 10 \
        --results_dir "./results/expanded_pix_fused"

Evaluate the model with **onehot** label:

The results don't look good for one-hot-labeled input. According to [this comment](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/575#issuecomment-476904244), the author of pix2pix pointed out that the input encoding is treated as an RGB image and therefore won't work well with too many categories. He further noted that pix2pixHD and SPADE treat the input maps as one-hot label maps.

In [None]:
# Evaluate the one-hot model on the iOCT ('uop') dataset; the channel counts and dataset
# mode must match the ones used for training.
uop_dataset_dir = C.DATASET_PATTERN.format(data='uop')
name = 'onehot_cross_pix'
results_dir = "./results/" + name
for ep in [200]:
    !CUDA_VISIBLE_DEVICES=1 python submodules/pix2pix/test.py  \
        --dataroot $uop_dataset_dir \
        --name $name \
        --model pix2pix \
        --direction AtoB \
        --epoch {ep} \
        --batch_size 10 \
        --results_dir $results_dir \
        --input_nc 7 \
        --output_nc 1 \
        --dataset_mode oct