## Generating examples of augmented images 

In [1]:
from bones_dataloader import *

# Location of the bone-channel dataset, relative to this notebook.
dataset_dir = "../../datasets/BONE_CHANNELS"

# Loader configuration: square 640x640 RGB patches, six per batch.
batch_size = 6
patch_size = (640, 640)
color_model = "RGB"

# Allowed strategies: "no_augmentation", "color_augmentation", "inpainting_augmentation", "standard", "random"
augmentation_strategy = "random"

# Candidate augmentations handed to the loader. Presumably `None` stands for
# "leave the image untouched" -- confirm in bones_dataloader.
# "color_transfer" and "inpainting" are deliberately left out here.
augmentation = [
    None,
    "horizontal_flip",
    "vertical_flip",
    "rotation",
    "transpose",
    "elastic_transformation",
    "grid_distortion",
    "optical_distortion",
    # "color_transfer",
    # "inpainting",
]

dataloaders = create_dataloader(
    tile_size="{}x{}".format(*patch_size),
    batch_size=batch_size,
    shuffle=True,
    img_input_size=patch_size,
    img_output_size=patch_size,
    dataset_dir=dataset_dir,
    color_model=color_model,
    augmentation=augmentation,
    augmentation_strategy=augmentation_strategy,
    start_epoch=1,
    validation_split=0.2,
)

# Number of samples in the train and test splits (train is measured first).
dataset_train_size, dataset_test_size = (len(dataloaders[split].dataset) for split in ('train', 'test'))

# The test size is logged before the train size.
logger.info(dataset_test_size)
logger.info(dataset_train_size)

2023-11-06 14:03:12,773 :: INFO load_dataset :: [training] ../../datasets/BONE_CHANNELS/training
2023-11-06 14:03:12,844 :: INFO load_dataset :: [training] ../../datasets/BONE_CHANNELS/training
2023-11-06 14:03:12,913 :: INFO load_dataset :: [testing] ../../datasets/BONE_CHANNELS/testing
2023-11-06 14:03:12,935 :: INFO create_dataloader :: Train images (640x640): 1258 augmentation: random
2023-11-06 14:03:12,935 :: INFO create_dataloader :: Valid images (640x640): 315 augmentation: no_augmentation
2023-11-06 14:03:12,936 :: INFO create_dataloader :: Test images (640x640): 464 augmentation: no_augmentation
2023-11-06 14:03:12,936 :: INFO <module> :: 464
2023-11-06 14:03:12,937 :: INFO <module> :: 1258


In [None]:
from PIL import Image 
  
# Pick training sample #4: element 0 is opened as the image, element 1 as its mask.
# NOTE(review): the debug `data_augmentation` that writes the example PNGs is defined in a
# later cell; run that cell first, otherwise this call resolves to whatever (if anything)
# the earlier star imports provide.
sample = dataloaders['train'].dataset.samples[4]
image = Image.open(sample[0])
mask = Image.open(sample[1])

# The mask doubles as the (unused unless color transfer is enabled) target image.
data_augmentation(image, mask, mask)


In [22]:
import os
import sys
import torch
import torch.nn as nn
import torchvision.utils as vutils
import torchvision.transforms.functional as TF

import random
import matplotlib.pyplot as plt

current_path = os.path.abspath('.')
root_path = os.path.dirname(os.path.dirname(current_path))
sys.path.append(root_path)

from sourcecode.wsi_image_utils import *
from sourcecode.logger_utils import *
from sourcecode.GAN.model.networks import Generator
from sourcecode.GAN.utils.tools import get_config, random_bbox, mask_image, is_image_file, default_loader, normalize, get_model_list

from torchvision import transforms

from albumentations import (
    Transpose,
    RandomRotate90,
    ElasticTransform,
    GridDistortion,
    OpticalDistortion
)

# Augmentations generated when `aug` is None: the seven geometric transforms this
# notebook always produced. "color_transfer" and "inpainting" must be requested
# explicitly because they need `target_img` / `GAN_model`.
_DEFAULT_EXAMPLE_AUGMENTATIONS = (
    "horizontal_flip",
    "vertical_flip",
    "rotation",
    "transpose",
    "elastic_transformation",
    "grid_distortion",
    "optical_distortion",
)

# File-name suffixes that differ from the augmentation name; every other augmentation
# uses its own name as the suffix.
_EXAMPLE_FILE_SUFFIXES = {"horizontal_flip": "hflip", "vertical_flip": "vflip"}


def _save_example(image, mask, suffix, output_dir, file_prefix):
    """Write `image` (and `mask`, when not None) as PNGs named from `file_prefix` and `suffix`."""
    os.makedirs(output_dir, exist_ok=True)
    image.save(os.path.join(output_dir, "{}_{}.png".format(file_prefix, suffix)))
    if mask is not None:
        mask.save(os.path.join(output_dir, "{}_mask_{}.png".format(file_prefix, suffix)))


def _apply_joint_transform(transform, image, mask, mask_shape):
    """Run an albumentations transform on image and mask together; return both as PIL images."""
    # Without a mask, feed an all-zero uint8 array of `mask_shape`. A plain np.zeros() is
    # float64, which PIL turns into a mode "F" image that cannot be saved as PNG.
    mask_array = np.array(mask) if mask is not None else np.zeros(mask_shape, dtype=np.uint8)
    augmented = transform(image=np.array(image), mask=mask_array)
    return Image.fromarray(augmented['image']), Image.fromarray(augmented['mask'])


def _color_transfer(image, target_image):
    """Recolour `image` with the project helper `transfer_color`, using `target_image` as reference."""
    original_img_lab = TF.to_tensor(image).permute(1, 2, 0).numpy()
    target_img_lab = TF.to_tensor(target_image).permute(1, 2, 0).numpy()
    _, _, augmented_img = transfer_color(original_img_lab, target_img_lab)
    return transforms.ToPILImage()(torch.from_numpy(augmented_img).permute(2, 0, 1))


def _inpaint_random_region(image, gan_model):
    """Inpaint a random box inside a randomly placed crop of `image` and paste the crop back."""
    width, height = image.size
    sourcecode_dir = os.path.dirname(os.path.abspath('.'))
    config_file = os.path.join(sourcecode_dir, 'GAN/configs/config_imagenet_ocdc.yaml')
    config = get_config(config_file)

    crop_size = config['image_shape']
    if width < crop_size[0] or height < crop_size[1]:
        raise ValueError("Image of size {}x{} is smaller than the inpainting crop {}x{}".format(
            width, height, crop_size[0], crop_size[1]))
    # Top-left corner of the crop handed to the GAN. randint's upper bound is exclusive:
    # +1 lets the crop touch the right/bottom edge and avoids a ValueError when the image
    # is exactly crop-sized.
    left = np.random.randint(0, width - crop_size[0] + 1)
    top = np.random.randint(0, height - crop_size[1] + 1)

    cropped_region = image.crop((left, top, left + crop_size[0], top + crop_size[1]))
    cropped_region = lab_to_rgb(pil_to_np(cropped_region))
    cropped_region = transforms.ToTensor()(cropped_region)
    inpainting_img = cropped_region.detach().clone().mul_(2).add_(-1)         # normalize between -1 and 1
    inpainting_img = inpainting_img.unsqueeze(dim=0).to(dtype=torch.float32)  # adds the batch dimension

    bboxes = random_bbox(config, batch_size=inpainting_img.size(0))
    inpainting_img, inpainting_mask = mask_image(inpainting_img, bboxes, config)

    if torch.cuda.is_available():
        gan_model = nn.parallel.DataParallel(gan_model)
        inpainting_img = inpainting_img.cuda()
        inpainting_mask = inpainting_mask.cuda()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, x2, _ = gan_model(inpainting_img, inpainting_mask)
    # Keep the original pixels outside the masked box and the GAN output inside it.
    inpainted_result = x2 * inpainting_mask + inpainting_img * (1. - inpainting_mask)
    inpainted_result = inpainted_result.squeeze(0).add_(1).div_(2)  # renormalize between 0 and 1
    inpainted_result = transforms.ToTensor()(rgb_to_lab(inpainted_result.permute(1, 2, 0).cpu().detach().numpy()))

    augmented_img = TF.to_tensor(image)
    augmented_img[:, top:top + crop_size[1], left:left + crop_size[0]] = inpainted_result.squeeze(0)
    return transforms.ToPILImage()(augmented_img)


def data_augmentation(input_image, target_img, output_mask, img_input_size=(640, 640), img_output_size=(640, 640),
                      aug=None, GAN_model=None, output_dir="temp", file_prefix="205_r3c7"):
    """Generate one example of each requested augmentation and save it to disk.

    Every augmentation is applied independently to the resized input image; they are
    not chained. Each result is written to ``<output_dir>/<file_prefix>_<suffix>.png``
    and, when a mask is available, ``<output_dir>/<file_prefix>_mask_<suffix>.png``.
    The flips use the suffixes "hflip"/"vflip"; every other augmentation uses its name.

    Args:
        input_image: PIL image to augment.
        target_img: reference PIL image for "color_transfer"; may be None otherwise.
        output_mask: PIL mask of `input_image`, or None. A mask with no pixel > 0 is
            treated as absent.
        img_input_size: unused; kept for interface compatibility.
        img_output_size: size passed to ``TF.resize`` for the image, target and mask.
        aug: iterable of augmentation names to generate (``None`` entries are skipped).
            Defaults to the seven geometric augmentations.
        GAN_model: inpainting generator; required only for "inpainting".
        output_dir: directory for the PNG files; created if missing.
        file_prefix: file-name prefix of every saved PNG.

    Returns:
        List of the augmentation names that were generated, in order.

    Raises:
        ValueError: for an unknown augmentation name, for "color_transfer" without
            `target_img`, for "inpainting" without `GAN_model`, or when the image is
            smaller than the inpainting crop.
    """
    image = TF.resize(input_image, size=img_output_size)
    target_image = TF.resize(target_img, size=img_output_size) if target_img is not None else None
    # Masks without any positive pixel are dropped.
    mask = (TF.resize(output_mask, size=img_output_size)
            if output_mask is not None and np.any(pil_to_np(output_mask) > 0) else None)

    requested = _DEFAULT_EXAMPLE_AUGMENTATIONS if aug is None else aug
    used_augmentations = []
    for name in requested:
        if name is None:
            continue  # "no augmentation" placeholder, as in the loader's augmentation list

        if name == "horizontal_flip":
            aug_image = TF.hflip(image)
            aug_mask = TF.hflip(mask) if mask is not None else None
        elif name == "vertical_flip":
            aug_image = TF.vflip(image)
            aug_mask = TF.vflip(mask) if mask is not None else None
        elif name == "rotation":
            aug_image, aug_mask = _apply_joint_transform(RandomRotate90(p=1), image, mask, img_output_size)
        elif name == "transpose":
            aug_image, aug_mask = _apply_joint_transform(Transpose(p=1), image, mask, img_output_size)
        elif name == "elastic_transformation":
            alpha = random.randint(100, 200)
            # NOTE(review): `alpha_affine` is deprecated/removed in newer albumentations
            # releases -- confirm it is accepted by the installed version.
            transform = ElasticTransform(p=1, alpha=alpha, sigma=alpha * 0.05, alpha_affine=alpha * 0.03)
            aug_image, aug_mask = _apply_joint_transform(transform, image, mask, img_output_size)
        elif name == "grid_distortion":
            aug_image, aug_mask = _apply_joint_transform(GridDistortion(p=1), image, mask, img_output_size)
        elif name == "optical_distortion":
            transform = OpticalDistortion(p=1, distort_limit=1, shift_limit=0.5)
            aug_image, aug_mask = _apply_joint_transform(transform, image, mask, img_output_size)
        elif name == "color_transfer":
            if target_image is None:
                raise ValueError("'color_transfer' requires a target_img")
            aug_image = _color_transfer(image, target_image)
            # Presumably a pure recolouring, so the geometry -- and hence the mask -- is unchanged.
            aug_mask = mask
        elif name == "inpainting":
            if GAN_model is None:
                raise ValueError("'inpainting' requires a GAN_model")
            aug_image = _inpaint_random_region(image, GAN_model)
            # Inpainting overwrites pixels in place, so the mask is unchanged.
            aug_mask = mask
        else:
            raise ValueError("Unknown augmentation: {!r}".format(name))

        used_augmentations.append(name)
        _save_example(aug_image, aug_mask, _EXAMPLE_FILE_SUFFIXES.get(name, name), output_dir, file_prefix)

    return used_augmentations
