In [None]:
# Run from the project root so relative paths used later (e.g. './gen_image_with_mask.jpeg') resolve inside it.
# NOTE(review): this path duplicates ROOT in the config cell; keep the two in sync.
%cd /content/drive/MyDrive/CV02/polyp_dataset_project09

/content/drive/MyDrive/CV02/polyp_dataset_project09


In [None]:
import os, json, random
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Seed every RNG library imported in this cell so Restart & Run All is reproducible.
# NOTE: keep equal to RANDOM_STATE in the config cell.
SEED = 2024
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)  # last statement returns None, so the cell prints nothing

In [None]:
# ---- Hyper-parameters ----
RANDOM_STATE = 2024
TEST_SIZE = 0.2

BATCH_SIZE = 128
NUM_WORKERS = 2
PIN_MEMORY = True
IMAGE_SIZE = 128

LEARNING_RATE = 1e-3
NUM_EPOCHS = 200

# ---- Paths (all relative to the project root) ----
ROOT = '/content/drive/MyDrive/CV02/polyp_dataset_project09'

IMAGE_DIR = os.path.join(ROOT, 'images')
OUTPUT_DIR = os.path.join(ROOT, 'seg_output')
BEST_MODEL_PATH = os.path.join(ROOT, 'models', 'unet_realgen.pth')

# Generated samples: images and their masks live in sibling folders.
GEN_DIR = os.path.join(ROOT, 'generated')
GEN_IMAGE_DIR = os.path.join(GEN_DIR, 'images')
GEN_MASK_DIR = os.path.join(GEN_DIR, 'images-mask')

# Split manifests.
TRAIN_JSON = os.path.join(ROOT, 'polyp_train.json')
VAL_JSON = os.path.join(ROOT, 'polyp_valid.json')
TEST_JSON = os.path.join(ROOT, 'polyp_test.json')

In [None]:
def load_json(path):
    """Parse the JSON file at ``path`` and return the decoded object."""
    with open(path, 'r') as handle:
        return json.load(handle)

def get_full_path(path, root=None):
    """Resolve a manifest-relative image path (e.g. './images/x.jpg') to a full path.

    Args:
        path: path as stored in the split JSON, optionally starting with './'.
        root: directory to resolve against; defaults to the module-level ROOT.

    Returns:
        ``os.path.join(root, path)`` with any leading './' removed from ``path``.
    """
    if root is None:
        root = ROOT
    # Strip only a *leading* './'. A global str.replace('./', '') would also eat the
    # './' inside '../' or 'name./', silently corrupting such paths.
    while path.startswith('./'):
        path = path[2:]
    return os.path.join(root, path)

def convert_to_mask_path(image_path):
    """Return the path of the mask file paired with ``image_path``.

    The mask is expected at the same location with the last '/images/' directory
    component replaced by '/images-mask/'.

    Raises:
        ValueError: if ``image_path`` has no '/images/' component, or the derived
            mask file does not exist.
    """
    head, sep, tail = image_path.rpartition('/images/')
    if not sep:
        # Without this guard the derived "mask" path would equal the image path, which
        # exists, so the image itself would be returned as its own mask.
        raise ValueError(f"Cannot derive mask path: no '/images/' component in {image_path}")
    # Replace only the last occurrence so an '/images/' earlier in the root is untouched.
    mask_path = f'{head}/images-mask/{tail}'
    if os.path.isfile(mask_path):
        return mask_path
    raise ValueError(f'Mask path not found. {mask_path}')

# All generated image files. Sorted because os.listdir order is arbitrary and
# filesystem-dependent, which would make the seeded random.sample() below
# non-reproducible; non-files (e.g. a stray .ipynb_checkpoints dir) are skipped.
gen_image_paths = sorted(
    path
    for path in (os.path.join(GEN_IMAGE_DIR, fn) for fn in os.listdir(GEN_IMAGE_DIR))
    if os.path.isfile(path)
)

# Real images only: gen_image_paths is NOT added to the training set here.
# NOTE(review): BEST_MODEL_PATH is named 'unet_realgen', which suggests real+generated
# training may have been intended. Confirm; if so, extend train_image_paths with
# gen_image_paths.
def _to_full_paths(split_items):
    """Full image paths for the entries of one split manifest's 'images' list."""
    return [get_full_path(item['image_path']) for item in split_items]

polyp_train = load_json(TRAIN_JSON)['images']
polyp_val = load_json(VAL_JSON)['images']
polyp_test = load_json(TEST_JSON)['images']

train_image_paths = _to_full_paths(polyp_train)
val_image_paths = _to_full_paths(polyp_val)
test_image_paths = _to_full_paths(polyp_test)

print(f'{len(train_image_paths)} train images.')
print(f'{len(val_image_paths)} val images.')
print(f'{len(test_image_paths)} test images.')

test_image_paths[0]

1200 train images.
400 val images.
400 test images.


'/content/drive/MyDrive/CV02/polyp_dataset_project09/images/NeoPolyp-Small/00fd197cd955fa095f978455cef3593c.jpg'

In [None]:
# Random subset of generated images to preview (drawn from the seeded `random` module).
# Capped at the number available: random.sample raises ValueError if k > population.
NUM_PREVIEW = 50
sample_gen_image_paths = random.sample(gen_image_paths, min(NUM_PREVIEW, len(gen_image_paths)))

In [None]:
def normalize_image(image):
    """Return ``image`` as values in [0, 1].

    Unsigned-integer arrays are divided by their dtype's maximum (255 for uint8,
    65535 for uint16); floating arrays are clipped to [0, 1].

    Raises:
        ValueError: for any other dtype (signed ints, bool, ...).
    """
    if np.issubdtype(image.dtype, np.unsignedinteger):
        return image / float(np.iinfo(image.dtype).max)
    if np.issubdtype(image.dtype, np.floating):
        return np.clip(image, 0, 1)
    raise ValueError(f"Unsupported image dtype: {image.dtype}")


def overlay_mask(image, mask, alpha=0.4):
    """Alpha-blend ``mask`` on top of ``image``.

    Args:
        image: (H, W) grayscale, (H, W, 3) RGB or (H, W, 4) RGBA array, expected to
            already be in [0, 1] (it is not normalized here).
        mask: (H, W) or (H, W, C) array; normalized to [0, 1] with normalize_image.
        alpha: weight of the mask in the blend.

    Returns:
        ``(1 - alpha) * image + alpha * mask``; (H, W, 3) for the inputs above.

    Raises:
        ValueError: if image and mask differ in height/width.
    """
    if image.shape[:2] != mask.shape[:2]:
        raise ValueError("Image and mask must have the same dimensions")

    def _to_rgb(arr):
        # Check ndim first: for a 2-D array shape[-1] is the *width*, so a bare
        # `shape[-1] > 3` test would truncate grayscale images to 3 columns.
        if arr.ndim == 2:
            return np.stack((arr, arr, arr), axis=-1)
        # Drop alpha / extra channels (e.g. RGBA PNGs).
        return arr[..., :3]

    image = _to_rgb(image)
    mask = normalize_image(_to_rgb(mask))
    return ((1 - alpha) * image) + (alpha * mask)

# Plot each sampled generated image with its mask blended on top, then save the grid.
n_cols = 10
n_rows = max(1, (len(sample_gen_image_paths) + n_cols - 1) // n_cols)  # ceil, >= 1 row
# 4x4 inches per cell (40x20 for a 5x10 grid); squeeze=False keeps `axs` 2-D even
# when there is a single row.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

# Hide every axis up front so unused grid cells stay blank.
for ax in axs.flat:
    ax.axis('off')

for i, img_path in enumerate(sample_gen_image_paths):
    mask_path = convert_to_mask_path(img_path)
    image = normalize_image(mpimg.imread(img_path))
    # Normalize the mask *before* collapsing channels. Averaging a uint8 mask first
    # would give floats in [0, 255] that normalize_image only clips, saturating soft edges.
    mask = normalize_image(mpimg.imread(mask_path))
    if mask.ndim == 3:
        # Average colour channels only; an alpha channel would otherwise leak into the mask.
        mask = mask[..., :3].mean(axis=2)

    ax = axs[i // n_cols, i % n_cols]
    ax.imshow(overlay_mask(image, mask))
    ax.set_title(os.path.basename(img_path))

plt.tight_layout()
# Save before show(): with some backends the figure may be released once shown.
fig.savefig('./gen_image_with_mask.jpeg', format='jpeg')
plt.show()
plt.close(fig)  # free the large figure's memory