In [5]:
import sys
import os
import cv2
from PIL import Image
import io
import albumentations as A
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import gridspec

In [6]:
img_in_dir = "/nfs/stak/users/wigginno/hpc-share/data/branching_data/images"
mask_in_dir = "/nfs/stak/users/wigginno/hpc-share/data/branching_data/masks/"

In [7]:
img_paths = []
mask_paths = []
for img_name in os.listdir(img_in_dir):
    if ".png" in img_name:
        img_paths.append(os.path.join(img_in_dir, img_name))
for mask_name in os.listdir(mask_in_dir):
    if ".png" in mask_name:
        mask_paths.append(os.path.join(mask_in_dir, mask_name))

In [8]:
img_paths

['/nfs/stak/users/wigginno/hpc-share/data/branching_data/images/branching_9.png',
 '/nfs/stak/users/wigginno/hpc-share/data/branching_data/images/branching_6.png',
 '/nfs/stak/users/wigginno/hpc-share/data/branching_data/images/branching_4.png',
 '/nfs/stak/users/wigginno/hpc-share/data/branching_data/images/branching_2.png',
 '/nfs/stak/users/wigginno/hpc-share/data/branching_data/images/branching_11.png',
 '/nfs/stak/users/wigginno/hpc-share/data/branching_data/images/branching_5.png',
 '/nfs/stak/users/wigginno/hpc-share/data/branching_data/images/branching_1.png',
 '/nfs/stak/users/wigginno/hpc-share/data/branching_data/images/branching_8.png',
 '/nfs/stak/users/wigginno/hpc-share/data/branching_data/images/branching_3.png',
 '/nfs/stak/users/wigginno/hpc-share/data/branching_data/images/branching_10.png',
 '/nfs/stak/users/wigginno/hpc-share/data/branching_data/images/branching_7.png']

In [10]:
img_mask_pairs = []
for img_path in img_paths:
    img_name = img_path.split("/")[-1]
    mask_names = [x.split("/")[-1] for x in mask_paths]
    mask_idx = mask_names.index(img_name) if img_name in mask_names else None
    if mask_idx is not None:
        img = cv2.imread(img_path, 0)
        mask = cv2.imread(mask_paths[mask_idx], 0)
        img_mask_pairs.append((img,mask))

In [11]:
transformed_pairs = []

In [30]:
"""
transform = A.Compose(
    [A.ElasticTransform(interpolation = cv2.INTER_NEAREST,
                        border_mode = cv2.BORDER_CONSTANT,
                        same_dxdy = True,
                        approximate = True,
                        p=1)],
    additional_targets={'mask':'image'}
)
"""
transform = A.Compose(
    [A.GridDistortion(num_steps=5,
                      interpolation = cv2.INTER_NEAREST,
                      border_mode = cv2.BORDER_CONSTANT,
                      p=1)],
    additional_targets={'mask':'image'}
)

In [31]:
for img, mask in img_mask_pairs:
    print("...")
    transformed = transform(image=img, mask=mask)
    transformed_pair = (transformed['image'], transformed['mask'])
    transformed_pairs.append(transformed_pair)

...
...
...
...
...
...
...
...
...
...
...


In [32]:
def savepng(fig, filename):
    # optional, for saving a pyplot figure to a png file
    # uncomment the calls to this function in the visualize() function
    # also, plot titles can be commented out in the visualize() function
    png1 = io.BytesIO()
    fig.savefig(png1, format='png', bbox_inches='tight')
    png2 = Image.open(png1)
    png2.save(filename)
    png1.close()

In [33]:
def visualize(img_mask_pair, transformed_pair, index):
    # visualize 4 images: img, mask, augmented_img, augmented_mask

    plt.axis('off')

    img_original = plt.figure(figsize=(10, 10))
    img_augmented = plt.figure(figsize=(10, 10))
    mask_original = plt.figure(figsize=(10, 10))
    mask_augmented = plt.figure(figsize=(10, 10))

    # No x/y axis labels
    img_original_axis = plt.Axes(img_original, [0., 0., 1., 1.])
    img_original_axis.set_axis_off()
    img_original_axis.title.set_text(f"Original image {index}")
    img_original_axis.title.set_size(20)
    img_original.add_axes(img_original_axis)

    img_augmented_axis = plt.Axes(img_augmented, [0., 0., 1., 1.])
    img_augmented_axis.set_axis_off()
    img_augmented_axis.title.set_text(f"Original mask {index}")
    img_augmented_axis.title.set_size(20)
    img_augmented.add_axes(img_augmented_axis)

    mask_original_axis = plt.Axes(mask_original, [0., 0., 1., 1.])
    mask_original_axis.set_axis_off()
    mask_original_axis.title.set_text(f"Distorted image {index}")
    mask_original_axis.title.set_size(20)
    mask_original.add_axes(mask_original_axis)

    mask_augmented_axis = plt.Axes(mask_augmented, [0., 0., 1., 1.])
    mask_augmented_axis.set_axis_off()
    mask_augmented_axis.title.set_text(f"Distorted mask {index}")
    mask_augmented_axis.title.set_size(20)
    mask_augmented.add_axes(mask_augmented_axis)

    img_original_axis.imshow(img_mask_pair[0], cmap='gray', vmin=0, vmax=255)
    img_augmented_axis.imshow(img_mask_pair[1], cmap='gray', vmin=0, vmax=255)
    mask_original_axis.imshow(transformed_pair[0], cmap='gray', vmin=0, vmax=255)
    mask_augmented_axis.imshow(transformed_pair[1], cmap='gray', vmin=0, vmax=255)
 
    savepng(img_original, f'img{index}_original.png')
    savepng(img_augmented, f'img{index}_original_mask.png')
    savepng(mask_original, f'img{index}_transformed.png')
    savepng(mask_augmented, f'img{index}_transformed_mask.png')

    plt.close('all')


In [34]:
#plt.close('all')

In [35]:
# Show pair n (in both its original and augmented forms)
for n in range(len(img_mask_pairs)):
    visualize(img_mask_pairs[n], transformed_pairs[n], n)

In [None]:
# Show all pairs
for i in range(len(img_mask_pairs)):
    visualize(img_mask_pairs[i], transformed_pairs[i], i)