In [None]:
import os
import re
from glob import glob
import pathlib
from pathlib import Path
import random
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from PIL import Image
import idp_utils.data_handling.constants as C

%cd $C.ROOT_PATH

seed = 6
random.seed(seed)
np.random.seed(seed)

# 1 Prepare Fake Labels & Fake Scans

In [None]:
label_mode = 'hetero'

OP_split_path = C.SPLIT_PATTERN.format(data='OP', name='original')
# The AROI split follows the selected label mode instead of a second hard-coded copy of it.
AROI_split_path = C.SPLIT_PATTERN.format(data='AROI', name=label_mode)

OP_test_bscan_path = os.path.join(OP_split_path, 'bscans', 'test')
OP_test_label_path = os.path.join(OP_split_path, 'labels', 'test')
AROI_test_bscan_path = os.path.join(AROI_split_path, 'bscans', 'test')
AROI_test_label_path = os.path.join(AROI_split_path, 'labels', 'test')

# Count only PNG files so these totals agree with the '*.png' globs used to list
# the actual files later (stray entries such as .DS_Store would otherwise inflate them).
num_test_OP = len(glob(os.path.join(OP_test_bscan_path, '*.png')))
num_test_AROI = len(glob(os.path.join(AROI_test_bscan_path, '*.png')))
print('num of test in op:', num_test_OP)
print('num of test in aroi:', num_test_AROI)

In [None]:
def _first_png(folder):
    """Return the alphabetically first '*.png' file name in `folder`.

    Raises:
        FileNotFoundError: if `folder` contains no .png files.
    """
    names = sorted(f for f in os.listdir(folder) if f.endswith('.png'))
    if not names:
        raise FileNotFoundError(f'no .png files in {folder}')
    return names[0]


def _image_shape(path):
    """Return the NumPy array shape of the image at `path`, closing the file afterwards."""
    with Image.open(path) as img:
        return np.asarray(img).shape


# os.listdir order is arbitrary, so compare the first PNG *by name* in each folder;
# label and B-scan folders are expected to hold identically named files.
OP_label_sample_name = _first_png(OP_test_label_path)
OP_bscan_sample_name = _first_png(OP_test_bscan_path)
assert OP_label_sample_name == OP_bscan_sample_name, (OP_label_sample_name, OP_bscan_sample_name)
AROI_label_sample_name = _first_png(AROI_test_label_path)
AROI_bscan_sample_name = _first_png(AROI_test_bscan_path)
assert AROI_bscan_sample_name == AROI_label_sample_name, (AROI_bscan_sample_name, AROI_label_sample_name)

OP_bscan_shape = _image_shape(os.path.join(OP_test_bscan_path, OP_bscan_sample_name))
OP_label_shape = _image_shape(os.path.join(OP_test_label_path, OP_label_sample_name))
assert OP_bscan_shape == OP_label_shape, (OP_bscan_shape, OP_label_shape)
print(f'OP bscan shape: {OP_bscan_shape}')
AROI_bscan_shape = _image_shape(os.path.join(AROI_test_bscan_path, AROI_bscan_sample_name))
AROI_label_shape = _image_shape(os.path.join(AROI_test_label_path, AROI_label_sample_name))
assert AROI_bscan_shape == AROI_label_shape, (AROI_bscan_shape, AROI_label_shape)
print(f'AROI bscan shape: {AROI_bscan_shape}')


In [None]:
# Keep the OP test labels that contain at least one instrument class.
# Each path is added once: appending inside the per-class loop duplicated labels
# holding several instrument classes, letting the sampler pick one label twice.
# Paths are sorted so that the seeded sampling below is reproducible across filesystems.
OP_labels_with_instrument = []
for label_path in tqdm(sorted(glob(os.path.join(OP_test_label_path, '*.png'))), desc='scanning OP labels'):
    with Image.open(label_path) as img:
        label = np.asarray(img)
    if np.isin(label, C.INSTRUMENT_LABELS).any():
        OP_labels_with_instrument.append(label_path)
print(f'num of OP labels with instrument: {len(OP_labels_with_instrument)}')

In [None]:
# Draw one distinct instrument-bearing OP label per AROI test label.
if len(OP_labels_with_instrument) < num_test_AROI:
    raise ValueError(
        f'only {len(OP_labels_with_instrument)} OP labels with instruments, '
        f'need {num_test_AROI} (one per AROI test image)')
rng = np.random.default_rng(seed=seed)
sample_labels = rng.choice(len(OP_labels_with_instrument), num_test_AROI, replace=False)
OP_labels = np.array(OP_labels_with_instrument)[sample_labels]

# Sorted so the OP<->AROI pairing (by index) is reproducible across filesystems.
AROI_labels = sorted(glob(os.path.join(AROI_test_label_path, '*.png')))
assert len(AROI_labels) >= num_test_AROI, (len(AROI_labels), num_test_AROI)

In [None]:
AROI_UNIQUE_LABELS = C.HETERO_AROI_LABELS + C.FLUID_LABELS


def _make_fake_label(op_label, aroi_label):
    """Overlay the instrument classes of an OP label onto an AROI label.

    Within the column range spanned by instrument pixels, every AROI class in
    AROI_UNIQUE_LABELS is cleared to 0; instrument pixels then take the OP value.
    Non-instrument OP classes are ignored.

    Returns a new writable array with the AROI label's dtype.
    Raises ValueError if the two labels have different shapes.
    """
    if op_label.shape != aroi_label.shape:
        raise ValueError(f'OP label shape {op_label.shape} != AROI label shape {aroi_label.shape}')
    instrument_mask = np.isin(op_label, C.INSTRUMENT_LABELS)
    fake_label = np.array(aroi_label, copy=True)  # writable copy; PIL-backed arrays may be read-only

    # Columns (axis 1) containing any instrument pixel; extra channel axes are reduced too.
    col_has_instrument = instrument_mask.reshape(*instrument_mask.shape[:2], -1).any(axis=(0, 2))
    instrument_cols = np.flatnonzero(col_has_instrument)
    if instrument_cols.size:
        start, end = instrument_cols[0], instrument_cols[-1]
        band = fake_label[:, start:end + 1]  # a view, so edits land in fake_label
        band[np.isin(band, AROI_UNIQUE_LABELS)] = 0

    # Overwrite instead of adding: summing class ids produced bogus labels wherever
    # an instrument pixel overlapped a non-zero AROI pixel.
    fake_label[instrument_mask] = op_label[instrument_mask]
    return fake_label


fake_labels = []
for i in range(num_test_AROI):
    with Image.open(OP_labels[i]) as op_img, Image.open(AROI_labels[i]) as aroi_img:
        fake_labels.append(_make_fake_label(np.asarray(op_img), np.asarray(aroi_img)))

In [None]:
# Pair each fake label with the real OP B-scan its instrument mask came from:
# the B-scan has the label's file name but lives in the OP B-scan test folder.
# (Alternative previously tried: an all-black canvas, np.zeros(AROI_bscan_shape, dtype=np.uint8).)
# Building the path from the basename avoids rewriting "labels" elsewhere in the path.
OP_bscans = []
for label_path in OP_labels:
    bscan_path = os.path.join(OP_test_bscan_path, os.path.basename(label_path))
    with Image.open(bscan_path) as img:
        OP_bscans.append(np.asarray(img))
fake_bscans = np.asarray(OP_bscans)

In [None]:
# Create <fake split>/{labels,bscans}/{train,val,test}; existing folders are left as-is.
fake_split_folder = C.SPLIT_PATTERN.format(data='OPAROI', name='fake')
for subdir in ('labels', 'bscans'):
    for split_name in ('train', 'val', 'test'):
        Path(fake_split_folder, subdir, split_name).mkdir(parents=True, exist_ok=True)

In [None]:
# Save each fake label and its paired OP B-scan under one shared name,
# "<OP label file>__<AROI label file>.png". Both parts keep their own ".png";
# the visualization section splits result names on '__' to recover them.
# os.path.basename (rather than split('/')) keeps this working with Windows paths.
for i in range(num_test_AROI):
    pair_name = '{}__{}.png'.format(os.path.basename(OP_labels[i]), os.path.basename(AROI_labels[i]))
    Image.fromarray(fake_labels[i]).save(os.path.join(fake_split_folder, 'labels', 'test', pair_name))
    Image.fromarray(fake_bscans[i]).save(os.path.join(fake_split_folder, 'bscans', 'test', pair_name))

# 2 Generate the Fake-Label Dataset

In [None]:
# Inputs (fake split: labels + B-scans) and output folder for the combined pix2pix dataset.
data, name = "OPAROI", "fake"
fake_split_root = C.SPLIT_PATTERN.format(data=data, name=name)
label_folder = os.path.join(fake_split_root, "labels")
bscan_folder = os.path.join(fake_split_root, "bscans")
dataset_folder = C.DATASET_PATTERN.format(data=data, name=name)

In [None]:
# Combine labels (A, from label_folder) with B-scans (B, from bscan_folder) into
# paired A|B training images under dataset_folder, using pix2pix's combine script.
!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py \
    --fold_A "$label_folder" \
    --fold_B "$bscan_folder" \
    --fold_AB "$dataset_folder" 

# 3 Test

## 3.1 Test with Hetero-label

Here the same class is assigned a different label value in each dataset.

In [None]:
# Run pix2pix inference with the 'oparoi_heterolabel_pix2pix' checkpoint on the fake
# OPAROI dataset built above (the name suggests hetero-label training — confirm).
# --num_test is num_test_AROI, the number of fake pairs written to the test split.
data, name = "OPAROI", "fake"
dataset_folder = C.DATASET_PATTERN.format(data=data, name=name)
checkpoint_name = 'oparoi_heterolabel_pix2pix'

!python pytorch-CycleGAN-and-pix2pix/test.py \
        --dataroot $dataset_folder \
        --direction AtoB \
        --name $checkpoint_name \
        --model pix2pix \
        --num_test $num_test_AROI 

## 3.2 Test with Homo-label
Here the same class is assigned the same label value in every dataset.

In [None]:
# Run pix2pix inference with the 'oparoi_homolabel_pix2pix' checkpoint on the OPAROI
# "original" dataset (the name suggests homo-label training — confirm).
# NOTE(review): --num_test is hard-coded to 5355, presumably the size of this
# dataset's test split — confirm, or derive it from the dataset folder.
data, name = "OPAROI", "original"
dataset_folder = C.DATASET_PATTERN.format(data=data, name=name)
checkpoint_name = 'oparoi_homolabel_pix2pix'

!python pytorch-CycleGAN-and-pix2pix/test.py \
        --dataroot $dataset_folder \
        --direction AtoB \
        --name $checkpoint_name \
        --model pix2pix \
        --num_test 5355

# 4 Visualization

In [None]:
def find_fake_img(path, file_name):
    """Return the path of the generated ('fake_B') PNG in `path` whose name contains `file_name`.

    Candidates are scanned in sorted order so the result does not depend on the
    filesystem's listing order. A match where `file_name` is not glued to further
    letters/digits is preferred (so 'x1' picks 'x1_fake_B.png' over 'x10_fake_B.png').
    Otherwise the first plain substring match is returned.

    Returns None when no 'fake_B' file contains `file_name`.
    """
    candidates = []
    for candidate in sorted(glob(os.path.join(path, '*.png'))):
        base = os.path.basename(candidate)  # portable, unlike split('/')
        if 'fake_B' in base and file_name in base:
            candidates.append(candidate)

    whole_token = re.compile(r'(?<![0-9A-Za-z])' + re.escape(file_name) + r'(?![0-9A-Za-z])')
    for candidate in candidates:
        if whole_token.search(os.path.basename(candidate)):
            return candidate
    return candidates[0] if candidates else None

In [None]:
# Real AROI test B-scans, plus the image folders of the 'pix2pix_op_original'
# and 'pix2pix_aroi_original' test runs.
AROI_split_path = C.SPLIT_PATTERN.format(data='AROI', name='hetero')
AROI_real_test_bscan_path = os.path.join(AROI_split_path, 'bscans', 'test')
_results_subdirs = ('test_latest', 'images')
OP_fake_test_bscan_path = os.path.join('results', 'pix2pix_op_original', *_results_subdirs)
AROI_fake_test_bscan_path = os.path.join('results', 'pix2pix_aroi_original', *_results_subdirs)

In [None]:
def parse_metric(file_path):
    """Extract the metric text embedded in a pix2pix result file name.

    The file name (not the full path, so spaces in directories cannot shift
    tokens) is split on whitespace. Tokens from index 2 onward are kept, and the
    last one is cut at its first '_' (dropping a tail such as "_fake_B.png").

    Returns the kept tokens joined by single spaces, or '' when the name has
    fewer than three whitespace-separated tokens.
    """
    tokens = os.path.basename(file_path).split()[2:]
    if not tokens:
        return ''
    tokens[-1] = tokens[-1].split('_')[0]
    return ' '.join(tokens)

In [None]:
# Collect the generated ("fake_B") images of the hetero-label test run.
# Sorted so the seeded random pick in the next cell is reproducible across filesystems.
checkpoint_name = 'oparoi_heterolabel_pix2pix'
img_sample_folder = os.path.join('results', checkpoint_name, 'test_latest', 'images')
imgs = sorted(glob(os.path.join(img_sample_folder, '*.png')))
imgs_fakeB = [img for img in imgs if 'fake_B' in os.path.basename(img)]

In [None]:
def _open_array(path):
    """Load the image at `path` into a NumPy array, closing the file handle."""
    with Image.open(path) as img:
        return np.asarray(img)


def _swap_in_name(path, old, new):
    """Replace `old` with `new` in the file-name part of `path` only (never in directories)."""
    folder, file_name = os.path.split(path)
    return os.path.join(folder, file_name.replace(old, new))


if not imgs_fakeB:
    raise FileNotFoundError('no fake_B images found; run the hetero-label test first')
# randrange excludes the upper bound; randint(0, n) could return n and index past the end.
rand_idx = random.randrange(len(imgs_fakeB))
fake_B_sample_path = imgs_fakeB[rand_idx]

# Presumably result names look like "<OP file>__<AROI file> <metrics>_fake_B.png".
pair_name = os.path.basename(fake_B_sample_path).split(' ')[0]
OP_file = pair_name.split('__')[0]
AROI_file = pair_name.split('__')[1]

fake_B_sample = _open_array(fake_B_sample_path)
real_A_sample = _open_array(_swap_in_name(fake_B_sample_path, 'fake_B', 'real_A'))
real_OP_sample = _open_array(_swap_in_name(fake_B_sample_path, 'fake_B', 'real_B'))
# PIL's Image.resize takes (width, height); NumPy shapes are (height, width, ...).
AROI_resize_shape = (fake_B_sample.shape[1], fake_B_sample.shape[0])
with Image.open(os.path.join(AROI_real_test_bscan_path, AROI_file)) as aroi_img:
    real_AROI_sample = np.asarray(aroi_img.resize(AROI_resize_shape))

fake_OP_file = find_fake_img(OP_fake_test_bscan_path, OP_file.split('.')[0])
fake_AROI_file = find_fake_img(AROI_fake_test_bscan_path, AROI_file.split('.')[0])
# find_fake_img returns None when nothing matches; fail clearly instead of inside Image.open.
if fake_OP_file is None:
    raise FileNotFoundError(f'no fake_B result for {OP_file} in {OP_fake_test_bscan_path}')
if fake_AROI_file is None:
    raise FileNotFoundError(f'no fake_B result for {AROI_file} in {AROI_fake_test_bscan_path}')
fake_OP_sample = _open_array(fake_OP_file)
fake_AROI_sample = _open_array(fake_AROI_file)

# Metric text: whitespace tokens 2+ of the file name, last token cut at its first '_'.
fake_metrics = []
for result_path in (fake_OP_file, fake_AROI_file):
    tokens = os.path.basename(result_path).split()[2:]
    if tokens:
        tokens[-1] = tokens[-1].split('_')[0]
    fake_metrics.append(" ".join(tokens))
fake_OP_metric, fake_AROI_metric = fake_metrics

# (row, col), image, title, colormap — None keeps matplotlib's default colormap.
panels = [
    ((0, 0), real_A_sample, 'combined label', None),
    ((0, 1), fake_B_sample, 'generated image', None),
    ((1, 0), fake_OP_sample, 'fake OP  ' + fake_OP_metric, 'gray'),
    ((1, 1), fake_AROI_sample, 'fake AROI  ' + fake_AROI_metric, 'gray'),
    ((2, 0), real_OP_sample, 'real OP', None),
    ((2, 1), real_AROI_sample, 'real AROI', 'gray'),
]
fig, axs = plt.subplots(3, 2, figsize=(30, 12))
for pos, image, title, cmap in panels:
    axs[pos].imshow(image, cmap=cmap)
    axs[pos].set_title(title, fontsize=15)

plt.tight_layout()
plt.subplots_adjust(wspace=0)

In [None]:
# def visualize_combined(dataset='AROI', label_mode='hetero', checkpoint_name='oparoi_heterolabel_pix2pix'):
#     AROI_split_path = C.SPLIT_PATTERN.format(data=dataset, name=label_mode)
#     AROI_test_bscan_path = os.path.join(AROI_split_path, 'bscans', 'test')
    
#     img_sample_folder = os.path.join('results', checkpoint_name, 'test_latest', 'images')
#     imgs = glob(os.path.join(img_sample_folder, '*.png'))
#     imgs_fakeB = []
#     for img in imgs:
#         img_name = img.split('/')[-1]
#         if 'fake_B' in img_name:
#             imgs_fakeB.append(img)
            
#     rand_idx = random.randint(0, len(imgs_fakeB))
#     fake_B_sample_path = imgs_fakeB[rand_idx]

#     AROI_file = fake_B_sample_path.split('/')[-1].split('__')[-1].split()[0]

#     fake_B_sample = np.asarray(Image.open(fake_B_sample_path))
#     real_A_sample = np.asarray(Image.open(re.sub('fake_B', 'real_A', fake_B_sample_path)))
#     real_OP_sample = np.asarray(Image.open(re.sub('fake_B', 'real_B', fake_B_sample_path)))
#     AROI_resize_shape = (fake_B_sample.shape[0], fake_B_sample.shape[1])
#     real_AROI_sample = np.asarray(Image.open(os.path.join(AROI_test_bscan_path, AROI_file)).resize(AROI_resize_shape))

#     fig, axs = plt.subplots(1, 4, figsize=(20, 8))
#     axs[0].imshow(real_A_sample)
#     axs[0].set_title('combined label', fontsize=30)
#     axs[1].imshow(fake_B_sample)
#     axs[1].set_title('generated image', fontsize=30)
#     axs[2].imshow(real_OP_sample)
#     axs[2].set_title('real OP', fontsize=30)
#     axs[3].imshow(real_AROI_sample, cmap='gray')
#     axs[3].set_title('real AROI', fontsize=30)

#     plt.tight_layout()

In [None]:
# visualize_combined(dataset='OPAROI', label_mode='original', checkpoint_name='oparoi_homolabel_pix2pix')