In [None]:
import os
# Work from the project root so the relative "data/..." paths used below resolve.
# Set the DD_SYNTHESIS_ROOT environment variable to run the notebook from another checkout.
os.chdir(os.environ.get('DD_SYNTHESIS_ROOT', '/home/extra/micheal/dd_synthesis'))

# Precise cut and paste

In [None]:
# pip install opencv-python
from collections import namedtuple
from glob import glob
import cv2 as cv
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from utils.visualization import visualize_segmentation, plot_horizontal
from utils.mask import expand_label, get_dst_shadow

### Choose a source and a target image

In [None]:
# Class ids per label scheme; a field left as None is not mapped by unify_label.
Label = namedtuple("Label", ["instrument", "mirror", "ilm", "ipl", "pre", "bm"], defaults=[None] * 6)
# Target scheme every dataset is converted to.
unified_label = Label(1, 2, 3, 4, 5, 6)
# AROI ids (instrument / mirroring not set).
aroi_label = Label(ilm=19, ipl=57, pre=171, bm=190)
# OP (iOCT) ids (IPL / BM not set).
op_label = Label(instrument=2, mirror=4, ilm=1, pre=3)

def unify_label(label, src_labels, dst_labels, remove_list=()):
    """Map the class values of ``label`` from one label scheme to another.

    A new array is returned; ``label`` itself is left untouched. Every lookup is
    made against the original ``label``, so mappings never chain into each other
    (e.g. swapping two ids works).

    Args:
        label (numpy array): label map to convert.
        src_labels (namedtuple, e.g. Label): class ids used in ``label``.
        dst_labels (namedtuple, e.g. Label): class ids to convert to. A field that
            is None (or missing) in either scheme is left unmapped.
        remove_list (iterable(int)): original values in ``label`` that are set to 0
            in the result (applied after the mapping).

    Returns:
        numpy array: the converted copy of ``label``.
    """
    converted = label.copy()
    # Iterate over the source scheme's own fields instead of a global type, so any
    # namedtuple-based scheme works.
    for field in src_labels._fields:
        src_value = getattr(src_labels, field)
        dst_value = getattr(dst_labels, field, None)
        if src_value is not None and dst_value is not None:
            converted[label == src_value] = dst_value
    for value in remove_list:
        converted[label == value] = 0
    return converted

In [None]:
def get_bscan_label(img_name, bscan_prefix="data/ioct/bscans/val/", label_prefix="data/ioct/labels/val/"):
    """Load a B-scan and its label map, which share the same file name.

    Args:
        img_name (str): file name present in both directories.
        bscan_prefix (str): directory of the B-scans. Joined with ``os.path.join``,
            so a trailing slash is optional.
        label_prefix (str): directory of the label maps (trailing slash optional).

    Returns:
        tuple: ``(bscan, label)`` where ``bscan`` is the image object returned by
        ``Image.open`` and ``label`` is a numpy array of the label pixels.
    """
    bscan = Image.open(os.path.join(bscan_prefix, img_name))
    label = np.asarray(Image.open(os.path.join(label_prefix, img_name)))
    return bscan, label

In [None]:
def create_crossover(src_bscan, src_label, dst_bscan, dst_label, mask):
    """Paste the masked pixels of the source pair onto copies of the destination pair.

    Args:
        src_bscan, src_label (numpy array): pair the pixels are taken from.
        dst_bscan, dst_label (numpy array): pair the pixels are pasted onto.
        mask (numpy bool array): True where the source pixel is used.

    Returns:
        tuple: (cross-over B-scan, cross-over label). The inputs are not modified.
    """
    def paste(background, foreground):
        # Work on a copy so the destination arrays stay untouched.
        blended = background.copy()
        blended[mask] = foreground[mask]
        return blended

    return paste(dst_bscan, src_bscan), paste(dst_label, src_label)

In [None]:
bscan_prefix = "data/ioct/bscans/val/"
label_prefix = "data/ioct/labels/val/"
img_name = "5d396575-e039-49da-a219-fe239e8bd9c88062-101.png"

# Source (OP / iOCT) image; its label values are remapped to the unified scheme.
bscan, label = get_bscan_label(img_name, bscan_prefix=bscan_prefix, label_prefix=label_prefix)
label = unify_label(label, src_labels=op_label, dst_labels=unified_label)
visualize_segmentation(bscan, label, show_original=True)

In [None]:
dst_bscan_prefix = "data/aroi/bscans/train/"
dst_label_prefix = "data/aroi/labels/train/"
dst_img_name = "patient21_raw0060.png"  # "patient15_raw0032.png"
# Destination (AROI) image; values 80, 160 and 240 of the original label are set to 0.
dst_bscan, dst_label = get_bscan_label(dst_img_name, bscan_prefix=dst_bscan_prefix,
                                       label_prefix=dst_label_prefix)
dst_label = unify_label(dst_label, src_labels=aroi_label, dst_labels=unified_label,
                        remove_list=[80, 160, 240])
visualize_segmentation(dst_bscan, dst_label, show_original=True)

## Part 1: Manipulate the B-scan

The goal is to artificially create a cross-over image.

### Try edge detection

In [None]:
# Convert the image object returned by get_bscan_label (Image.open) to a numpy array
# for the OpenCV / numpy operations in the following cells.
bscan = np.asarray(bscan)

In [None]:
# Smooth the B-scan with a 5x5 box filter; shown in grayscale like the other B-scan plots.
blurred = cv.blur(bscan, (5, 5))
plt.imshow(blurred, cmap='gray');

In [None]:
# Crude stand-in for edge detection: keep every pixel brighter than a fixed threshold.
# (Canny alternative: cv.Canny(blurred, 100, 200, L2gradient=False, apertureSize=3).)
EDGE_THRESHOLD = 100
edges = bscan > EDGE_THRESHOLD
fig, (ax_orig, ax_edges) = plt.subplots(1, 2)
edge_panels = ((ax_orig, bscan, 'Original Image'),
               (ax_edges, edges, f'Thresholded Image (> {EDGE_THRESHOLD})'))
for ax, img, title in edge_panels:
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

### Step 1: Copy and paste the tool with the manually expanded mask

Such a mask is very inaccurate.

TODO:
- improve mask

#### 1. Expand the instrument and mirroring labels

In [None]:
# Presumably grows the instrument region by 30 px (also upwards) and leaves the
# mirroring unexpanded; see utils.mask.expand_label to confirm.
expanded_label = expand_label(
    label,
    expand_upward=True,
    expansion_mirror=0,
    expansion_instrument=30,
    mirror_label=unified_label.mirror,
    instrument_label=unified_label.instrument,
)
plt.imshow(expanded_label)

#### 2. Cover the layers with the target label

In [None]:
def get_dst_shadow(src_label, dst_label, instrument_label=2, mirror_label=4,
                   top_layer_label=1, margin_above=0, pad_left=0, pad_right=0):
    """Locate the pixels of the destination that lie in the "shadow" of the source tool.

    NOTE: this redefines the ``get_dst_shadow`` imported from ``utils.mask`` at the top
    of the notebook; as long as this cell runs after the import cell, later cells use
    this version.

    The shadow is built column by column:
    1. Horizontally, it covers the columns spanned by the instrument/mirroring in
       ``src_label``. At each edge it is shrunk inwards past the columns where
       ``src_label`` still contains the top layer, so the destination layers are kept
       there.
    2. Vertically, each shadow column runs from the top of ``top_layer_label`` in
       ``dst_label`` (minus ``margin_above``, clamped at row 0) down to the last row.
       A column without the top layer in ``dst_label`` reuses the highest upper bound
       seen so far. Columns before the first such bound are skipped.

    Args:
        src_label (numpy array): 2-D source label map (rows, columns).
        dst_label (numpy array): 2-D destination label map, indexed with the same
            column indices as ``src_label``.
        instrument_label (int): value of the instrument class.
        mirror_label (int): value of the mirroring class.
        top_layer_label (int): value of the top layer class.
        margin_above (int): extra rows to include above the destination top layer.
            This covers cases where the human-labeled top layer is inaccurate and
            leaves some pixels out. The start row never goes below 0.
        pad_left (int): columns to trim from the left edge of the shadow. This covers
            cases where the area directly under the instrument still contains some
            layers. It is skipped if it would leave no columns.
        pad_right (int): the same as ``pad_left``, for the right edge.

    Returns:
        tuple(numpy array, numpy array): int64 row indices and column indices of the
        shadow pixels, usable as ``mask[rows, cols] = True``. Both are empty when
        ``src_label`` contains no instrument/mirroring.
    """
    img_height, _ = src_label.shape
    empty = np.array([], dtype=np.int64)

    # Columns covered by the instrument or its mirroring in the source label.
    _, tool_cols = np.where(np.logical_or(src_label == instrument_label,
                                          src_label == mirror_label))
    if len(tool_cols) == 0:
        return empty, empty.copy()

    # The shadow spans the half-open column range [left, right_end).
    left = int(np.min(tool_cols))
    right_end = int(np.max(tool_cols)) + 1
    # True for every source column that still contains the top layer.
    src_has_layer = np.any(src_label == top_layer_label, axis=0)
    # Move both edges inwards past the columns that still show the layer. The edges
    # stop when they meet, which leaves an empty range.
    while left < right_end and src_has_layer[left]:
        left += 1
    while right_end > left and src_has_layer[right_end - 1]:
        right_end -= 1
    # Optional extra trimming; each side applies only if columns remain afterwards.
    if left + pad_left < right_end:
        left += pad_left
    if right_end - pad_right > left:
        right_end -= pad_right

    row_chunks, col_chunks = [], []
    highest_upper = None  # smallest upper bound (highest in the image) seen so far
    for col in range(left, right_end):
        layer_rows = np.flatnonzero(dst_label[:, col] == top_layer_label)
        if layer_rows.size > 0:
            # Start margin_above rows above the layer, but never above row 0.
            upper = max(int(layer_rows.min()) - margin_above, 0)
            highest_upper = upper if highest_upper is None else min(highest_upper, upper)
        elif highest_upper is None:
            # No layer in this column and none seen yet: nothing to anchor on.
            continue
        else:
            # Layer missing in this destination column: reuse the highest bound so far.
            upper = highest_upper
        rows = np.arange(upper, img_height, dtype=np.int64)  # upper bound to bottom
        row_chunks.append(rows)
        col_chunks.append(np.full_like(rows, col))
    if not row_chunks:
        return empty, empty.copy()
    # Concatenate once instead of growing the arrays inside the loop.
    return np.concatenate(row_chunks), np.concatenate(col_chunks)

In [None]:
# Shadow of the expanded tool inside the destination label. NOTE: margin_above=5 here,
# while create_mask (used for the dataset) uses 20.
shadow_x, shadow_y = get_dst_shadow(
    expanded_label,
    dst_label,
    top_layer_label=unified_label.ilm,
    instrument_label=unified_label.instrument,
    mirror_label=unified_label.mirror,
    pad_left=0,
    margin_above=5,
)

In [None]:
classes_of_interest = [unified_label.instrument, unified_label.mirror]
mask = np.zeros(label.shape, dtype=bool)
# Tool and mirroring pixels of the expanded source label ...
mask[np.isin(expanded_label, classes_of_interest)] = True
# ... plus their shadow in the destination.
mask[shadow_x, shadow_y] = True
plt.imshow(mask)

#### 3. Copy and paste

In [None]:
# Both B-scans as numpy arrays so they can be indexed with the boolean mask.
dst_bscan, bscan = (np.asarray(img) for img in (dst_bscan, bscan))

In [None]:
# Paste the masked source pixels onto a copy of the destination B-scan (grayscale plot).
cross_bscan = dst_bscan.copy()
cross_bscan[mask] = bscan[mask]
plt.imshow(cross_bscan, cmap='gray')

## Part 2: Manipulate the label

The goal of this part is to create the label corresponding to the image from the first part. The intuition is that manipulating the label map is very precise; we can then generate an image from the manipulated label map.

In [None]:
cross_bscan, cross_label = create_crossover(src_bscan=bscan, src_label=label,
                                            dst_bscan=dst_bscan, dst_label=dst_label,
                                            mask=mask)
visualize_segmentation(cross_bscan, cross_label, show_original=True)

## Part 3: Create a dataset

Per the [pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/datasets.md#pix2pix-datasets) dataset requirements, we have to create the dataset and lay it out in a specific format.

The ideal way is to implement a custom sampler [link](https://discuss.pytorch.org/t/how-to-generate-random-pairs-at-each-epoch/112065/2).

Here I'll do it in a simpler way.
1. Randomly choose an OP image and an AROI image
2. Create a mask based on the OP and the AROI image
3. Make the cross-over label and B-scan
4. Save to disk

### 1. Create a pair

In [None]:
from tqdm import tqdm
import random

# Dataset roots; the split name ('train' / 'val') is appended to each of them.
op_bscan_folder, op_label_folder = "data/ioct/bscans/", "data/ioct/labels/"
aroi_bscan_folder, aroi_label_folder = "data/aroi/bscans/", "data/aroi/labels/"
def _paired_files(bscan_dir, label_dir):
    """Return the sorted B-scan and label paths of one directory pair, checked to line up 1:1."""
    # glob returns files in arbitrary order; sort so index i refers to the same image in both.
    bscans = sorted(glob(os.path.join(bscan_dir, '*')))
    labels = sorted(glob(os.path.join(label_dir, '*')))
    assert len(bscans) != 0, f"No files found in {bscan_dir}"
    assert len(bscans) == len(labels), f"Length mismatch for bscans and labels ({len(bscans)}!={len(labels)})"
    mismatched = [(b, l) for b, l in zip(bscans, labels)
                  if os.path.basename(b) != os.path.basename(l)]
    assert not mismatched, f"B-scan/label file names do not match, e.g. {mismatched[0]}"
    return bscans, labels


def sample(src_bscan_prefix, src_label_prefix, dst_bscan_prefix, bscan_label_prefix, n_pairs=1000, seed=None):
    """
    Sample pairs of source and target images. The source is where we cut the instruments and
    mirrorings from; the destination is where we want to preserve the layers. Each directory is
    expected to directly contain the images, and a B-scan and its label share the same file name.

    Args:
        src_bscan_prefix (str): directory of the source B-scans.
        src_label_prefix (str): directory of the source labels.
        dst_bscan_prefix (str): directory of the destination B-scans.
        bscan_label_prefix (str): directory of the destination *labels*. The name is kept for
            backward compatibility.
        n_pairs (int): number of pairs to draw. Pairs are drawn independently, so the same pair
            can occur more than once.
        seed (int or None): if given, draw with a private ``random.Random(seed)``; otherwise
            use the global ``random`` state.

    Returns:
        pairs (list(tuple)): a list of tuples of paths [(src_bscan, src_label, dst_bscan, dst_label)]

    Raises:
        AssertionError: if a directory is empty, if the B-scan and label counts differ, or if
            their file names do not line up.
    """
    rng = random if seed is None else random.Random(seed)
    src_bscans, src_labels = _paired_files(src_bscan_prefix, src_label_prefix)
    dst_bscans, dst_labels = _paired_files(dst_bscan_prefix, bscan_label_prefix)
    pairs = []
    for _ in tqdm(range(n_pairs)):
        src_idx = rng.randrange(len(src_bscans))
        dst_idx = rng.randrange(len(dst_bscans))
        pairs.append((src_bscans[src_idx], src_labels[src_idx],
                      dst_bscans[dst_idx], dst_labels[dst_idx]))
    return pairs

In [None]:
# Seed the global `random` state so the sampled pairs, and therefore the exported
# dataset, can be regenerated exactly.
SAMPLING_SEED = 0
random.seed(SAMPLING_SEED)
sampled_pairs = {}
for split, n_pairs in zip(['train', 'val'], [3000, 300]):
    src_bscan_prefix = op_bscan_folder + split
    src_label_prefix = op_label_folder + split
    dst_bscan_prefix = aroi_bscan_folder + split
    dst_label_prefix = aroi_label_folder + split
    sampled_pairs[split] = sample(src_bscan_prefix, src_label_prefix, dst_bscan_prefix, dst_label_prefix, n_pairs)

### 2. Create a mask for a pair

In [None]:
def load_as_array(img_path, label_type=None):
    """Load an image file into a numpy array, optionally unifying its label values.

    Args:
        img_path (str): path of the image to read.
        label_type (str or None): None returns the pixels unchanged. 'op' or 'aroi'
            remaps the values from that dataset's scheme to ``unified_label``; for
            'aroi', the original values 80, 160 and 240 are also set to 0.

    Returns:
        numpy array: the (possibly remapped) image.

    Raises:
        ValueError: if ``label_type`` is not None, 'op' or 'aroi'. Previously a typo
            silently skipped the conversion.
    """
    if label_type not in (None, 'op', 'aroi'):
        raise ValueError(f"label_type must be None, 'op' or 'aroi', got {label_type!r}")
    # The context manager closes the file as soon as the pixels have been read.
    with Image.open(img_path) as img:
        img_arr = np.asarray(img)
    if label_type == 'op':
        img_arr = unify_label(img_arr, src_labels=op_label, dst_labels=unified_label)
    elif label_type == 'aroi':
        img_arr = unify_label(img_arr, src_labels=aroi_label, dst_labels=unified_label, remove_list=[80, 160, 240])
    return img_arr

def create_mask(src_bscan, src_label, dst_bscan, dst_label,
                expansion_instrument=30, margin_above=20):
    """Build the boolean paste mask for one (source, destination) pair.

    The mask marks the expanded instrument and mirroring pixels of the source, plus
    their shadow in the destination as computed by ``get_dst_shadow``.

    Args:
        src_bscan (numpy array): source B-scan; only its shape is used (the mask shape).
        src_label (numpy array): source label map in the unified scheme.
        dst_bscan (numpy array): destination B-scan. It is unused; callers pass the
            unpacked 4-tuple, e.g. ``create_mask(*loaded_pair)``.
        dst_label (numpy array): destination label map in the unified scheme.
        expansion_instrument (int): forwarded to ``expand_label`` for the instrument class.
        margin_above (int): forwarded to ``get_dst_shadow``.

    Returns:
        numpy bool array of shape ``src_bscan.shape``: True where source pixels are pasted.
    """
    # 2. expand the source tool label (mirroring is not expanded)
    expanded_src_label = expand_label(src_label,
                                      instrument_label=unified_label.instrument,
                                      mirror_label=unified_label.mirror,
                                      expansion_instrument=expansion_instrument,
                                      expansion_mirror=0,
                                      expand_upward=True)
    # 3. get the shadowed area based on the dst label
    shadow_rows, shadow_cols = get_dst_shadow(expanded_src_label,
                                              dst_label,
                                              instrument_label=unified_label.instrument,
                                              mirror_label=unified_label.mirror,
                                              top_layer_label=unified_label.ilm,
                                              margin_above=margin_above,
                                              pad_left=0,
                                              pad_right=0)
    # 4. create the mask: tool/mirroring pixels plus their shadow
    mask = np.zeros(src_bscan.shape, dtype=bool)
    mask[np.isin(expanded_src_label, [unified_label.instrument, unified_label.mirror])] = True
    mask[shadow_rows, shadow_cols] = True
    return mask

### 3. Create the cross-over B-scan and label

In [None]:
sampled_pair = sampled_pairs['train'][43]  # a fixed training pair to inspect
print(sampled_pair)
src_bscan_path, src_label_path, dst_bscan_path, dst_label_path = sampled_pair
# Source arrays first, then destination arrays; labels are converted to the unified scheme.
src_arrays = (load_as_array(src_bscan_path), load_as_array(src_label_path, label_type='op'))
dst_arrays = (load_as_array(dst_bscan_path), load_as_array(dst_label_path, label_type='aroi'))
loaded_pair = src_arrays + dst_arrays
mask = create_mask(*loaded_pair)
cross_bscan, cross_label = create_crossover(*loaded_pair, mask)
visualize_segmentation(cross_bscan, cross_label, show_original=True)

In [None]:
# Side by side: source pair, destination pair, cross-over pair and the paste mask.
panels = loaded_pair + (cross_bscan, cross_label, mask)
panel_titles = ['s_bscan', 's_label', 'd_bscan', 'd_label', 'c_bscan', 'c_label', 'mask']
plot_horizontal(panels, panel_titles, figsize=(15, 4))

### 4. Mass produce a dataset

In [None]:
# Create every output directory the export writes to (bscans and labels, train and val);
# previously only data/cross/labels/train was created.
for _split in ['train', 'val']:
    for _kind in ['bscans', 'labels']:
        os.makedirs(os.path.join('data/cross', _kind, _split), exist_ok=True)

In [None]:
from os.path import basename

# Output roots of the cross-over dataset; a sub-directory per split is written below.
bscan_dir, label_dir = 'data/cross/bscans/', 'data/cross/labels/'
def create_cross_pair(src_bscan_path, src_label_path, dst_bscan_path, dst_label_path):
    """Load one sampled (source, destination) pair from disk and build its cross-over.

    Returns:
        tuple: (cross-over B-scan, cross-over label) as numpy arrays.
    """
    src_bscan = load_as_array(src_bscan_path)
    src_label = load_as_array(src_label_path, label_type='op')
    dst_bscan = load_as_array(dst_bscan_path)
    dst_label = load_as_array(dst_label_path, label_type='aroi')
    paste_mask = create_mask(src_bscan, src_label, dst_bscan, dst_label)
    return create_crossover(src_bscan, src_label, dst_bscan, dst_label, paste_mask)

for split in ['train', 'val']:
    bscan_out_dir = os.path.join(bscan_dir, split)
    label_out_dir = os.path.join(label_dir, split)
    # Create the output directories here so this cell does not depend on them existing.
    os.makedirs(bscan_out_dir, exist_ok=True)
    os.makedirs(label_out_dir, exist_ok=True)
    for pair in tqdm(sampled_pairs[split]):
        cross_bscan, cross_label = create_cross_pair(*pair)
        cross_bscan = Image.fromarray(cross_bscan, mode='L')
        cross_label = Image.fromarray(cross_label, mode='L')
        # File name "<src stem>-<dst stem>.png"; splitext keeps stems with extra dots intact.
        # NOTE: if the same (src, dst) pair was sampled twice, the later file overwrites the earlier.
        src_stem = os.path.splitext(basename(pair[0]))[0]
        dst_stem = os.path.splitext(basename(pair[2]))[0]
        img_name = f"{src_stem}-{dst_stem}.png"
        cross_bscan.save(os.path.join(bscan_out_dir, img_name))
        cross_label.save(os.path.join(label_out_dir, img_name))