# <div align = 'center'> HuBMAP: Hacking the Kidney </div>
# <div align = 'center'> Identify glomeruli in human kidney tissue images </div>

# Table of contents <a id='0.1'></a>

1. [Introduction](#1)
2. [Import Packages](#2)
3. [Utility Functions](#4)
4. [Loading Data and overview](#3)
   * [3.1 Train Data](#5)
   * [3.2 HuBMAP-Metadata](#6)
   * [3.3 Test Data](#7)
   * [3.4 Train Images](#8)
   * [3.5 Test Images](#9)
5. [Image + Segmentation Mask](#10)
   * [4.1 Image Tiff File](#11)
   * [4.2 Annotation json file](#12)
6. [EDA](#13)
   * [Individual Features](#14)
   * [Pandas Metadata profiling](#15)

7. [Creating Dataset for Training](#16)
   * [Idea](#17)
   * [Tiling](#18)
   * [Visualisation](#19)
8. [Data Preparation](#20)
   * [Filtering low band density](#21)
   * [Augmentation](#22)

# 1. <a id='1'>Introduction</a>
[Table of contents](#0.1)

We aim to develop a segmentation algorithm to identify the "Glomerulus" in the kidney.

We are given histological images of the kidney and annotation information representing the glomerular segmentation. Also we can use anatomical structure segmentation information and additional information (including anonymized patient data) about each image.

# 2. <a id='2'>Import Packages</a>
[Table of contents](#0.1)

In [None]:
%env SM_FRAMEWORK=tf.keras

In [None]:

# basic
import os
import cv2
import collections
import sys, gc
import warnings
import time, math
import numpy as np
import pandas as pd
import os.path as osp
from glob import glob
from pathlib import Path
import pandas_profiling as pp
from tqdm.notebook import tqdm
from path import Path

# visualize
import seaborn as sn
import matplotlib.pyplot as plt
 
# image preprocessing 
import json
import rasterio
import skimage.io
import tifffile as tiff
import zipfile
from rasterio.windows import Window
from PIL import Image, ImageDraw
from IPython.display import clear_output, Image as displayImage, display

# kaggle datasets
# from kaggle_datasets import KaggleDatasets

# # deep learning
import tensorflow as tf
# import segmentation_models as sm
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, Callback, LearningRateScheduler

# # cross validation
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split

# # logging
import wandb
from wandb.keras import WandbCallback
from kaggle_secrets import UserSecretsClient

%matplotlib inline
warnings.filterwarnings('ignore')
print(f'Wandb Version: {wandb.__version__}')
print(f'Seaborn Version: {sn.__version__}')
print(f'Tensorflow Version: {tf.__version__}')

# 3. <a id='3'>Loading Data and Overview</a>
[Table of contents](#0.1)

There are 3 .csv files containing

* train
* test
* anonymous patient data

There are two additional folders (directories) containing

* images in .tiff format
* encoded annotations in .json format

## 3.1 <a id='5'>Train Data</a>
[Table of contents](#0.1)

There are 8 training images. This CSV contains the ids of the images in the train directory, and its "encoding" column holds the corresponding masks, run-length encoded (RLE).

In [None]:
train = pd.read_csv("../input/hubmap-kidney-segmentation/train.csv")
train.info()

In [None]:
train.head()

## 3.2 <a id='6'>HuBMAP metadata</a>
[Table of contents](#0.1)

This file includes additional information (including anonymized patient data) about each image

In [None]:
ds_info = pd.read_csv("../input/hubmap-kidney-segmentation/HuBMAP-20-dataset_information.csv")
ds_info.info()

In [None]:
ds_info.head()

## 3.3 <a id='7'>Test Data</a>
[Table of contents](#0.1)

There are 5 test images.

In [None]:
test = pd.read_csv("../input/hubmap-kidney-segmentation/sample_submission.csv")
test.info()

In [None]:
test

## 3.4 <a id='8'>Train Images</a>
[Table of contents](#0.1)

* tiff files are kidney image data.
* json files include unencoded annotations.

In [None]:
os.listdir("../input/hubmap-kidney-segmentation/train")

In [None]:
image_1 = tiff.imread('../input/hubmap-kidney-segmentation/train/' + train.iloc[2,0] + ".tiff")
img_id_1 = train.iloc[2,0]
print("This image's id:", img_id_1)
image_1.shape

plt.figure(figsize=(10, 10))
plt.imshow(image_1)

## 3.5 <a id='9'>Test Images</a>
[Table of contents](#0.1)

In [None]:
os.listdir("../input/hubmap-kidney-segmentation/test")

In [None]:
image_1 = tiff.imread('../input/hubmap-kidney-segmentation/test/' + test.iloc[1,0] + ".tiff")
img_id_1 = test.iloc[1,0]
print("This image's id:", img_id_1)
image_1.shape

plt.figure(figsize=(10, 10))
plt.imshow(image_1)

### <a id='4'>Utility File </a>

In [None]:
# https://www.kaggle.com/paulorzp/rle-functions-run-lenght-encode-decode
def mask2rle(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background.
    Pixels are scanned in column-major order (the array is transposed before
    flattening) and run starts are 1-based.
    Returns a space-separated string of alternating "start length" values
    (empty string for an all-background mask).
    '''
    flat = np.concatenate(([0], img.T.ravel(), [0]))
    # 1-based positions where the value changes: even entries open a run,
    # odd entries close it; turn each closing position into a run length.
    edges = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    edges[1::2] = edges[1::2] - edges[0::2]
    return ' '.join(map(str, edges))
 
def rle2mask(mask_rle, shape=(1600,256)):
    '''
    Decode a run-length string into a binary mask (inverse of mask2rle).

    mask_rle: space-separated "start length" values; starts are 1-based.
    shape: (width, height) of the mask. Runs are painted onto a flat buffer of
        width*height pixels, which is reshaped to `shape` and then transposed.
    Returns a uint8 numpy array of shape (height, width): 1 - mask, 0 - background.
    '''
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1
    run_ends = run_starts + np.asarray(tokens[1::2], dtype=int)
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1
    return flat.reshape(shape).T

In [None]:
def read_single(img_path, msk_path):
    """
    Load one image tile and its mask from disk with OpenCV.

    The image is read with cv2.IMREAD_COLOR (OpenCV channel order is BGR) and
    the mask with cv2.IMREAD_GRAYSCALE. cv2.imread does not raise on an
    unreadable file; it returns None, which is passed through to the caller.

    Returns:
        (image, mask) tuple.
    """
    return (
        cv2.imread(img_path, cv2.IMREAD_COLOR),
        cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE),
    )

def read_data(image_paths, mask_paths, gloms_only=False):
    """
    Load paired image/mask tile files into two numpy arrays.

    image_paths, mask_paths: sequences of file paths, paired position by
        position (zip stops at the shorter one); len(image_paths) is only used
        as the progress-bar total.
    gloms_only: when True, pairs whose mask has no non-zero pixel are skipped.

    Prints the resulting shapes and returns (images, masks) as numpy arrays.
    """
    kept_images, kept_masks = [], []
    pairs = zip(image_paths, mask_paths)
    for img_path, msk_path in tqdm(pairs, total=len(image_paths)):
        image, mask = read_single(img_path, msk_path)
        has_glom = np.count_nonzero(mask) > 0
        if gloms_only and not has_glom:
            continue
        kept_images.append(image)
        kept_masks.append(mask)

    images = np.array(kept_images)
    masks = np.array(kept_masks)
    print('images shape:', images.shape)
    print('masks shape:', masks.shape)
    return images, masks

# 4 <a id='10'>Image + Segmentation Mask</a>
[Table of contents](#0.1)

## 4.1 <a id='11'>Image Tiff file </a>
[Table of contents](#0.1)

We are given histological images of the kidney. These images are tiff format. We can load this data with tifffile module.

In [None]:
image_1 = tiff.imread('../input/hubmap-kidney-segmentation/train/' + train.iloc[2,0] + ".tiff")
img_id_1 = train.iloc[2,0]
print("This image's id:", img_id_1)
image_1.shape

plt.figure(figsize=(5,5))
plt.imshow(image_1)

### mask
We can decode mask from encoding column of train.csv.

In [None]:
mask_1 = rle2mask(train.iloc[2, 1], (image_1.shape[1], image_1.shape[0]))
mask_1.shape

plt.figure(figsize=(10,10))
plt.imshow(mask_1, cmap='coolwarm', alpha=0.5)

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(image_1)
plt.imshow(mask_1, cmap='coolwarm', alpha=0.5)

## 4.2 <a id='12'>Annotation json file</a>
[Table of contents](#0.1)

We also have two kinds of annotation files.

#### *Glomerulus segmentation file*
According to the dataset description, this file holds the same information as the RLE-encoded masks stored in train.csv.

#### *Anatomical structure file*
This file contains anatomical structure segmentations. They are intended to help us identify the various parts of the tissues

In [None]:
with open(f"../input/hubmap-kidney-segmentation/train/aaa6a05cc-anatomical-structure.json") as f:
    anatomical_structure_json = json.load(f)
    
anatomical_structure_json

In [None]:
def flatten(l):
    """
    Recursively yield the leaf elements of an arbitrarily nested iterable.

    str and bytes values are treated as leaves (not iterated character by
    character).
    """
    # `import collections` alone does not guarantee that the `collections.abc`
    # submodule is available as an attribute, so import it explicitly.
    from collections.abc import Iterable

    for el in l:
        if isinstance(el, Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el

def draw_structure(structures, im, width=100, fill='Red'):
    """
    Draw the outline of each anatomical-structure polygon onto an image.

    structures: list of annotation dicts as loaded from an
        *-anatomical-structure.json file. For each entry, the nested values in
        structure["geometry"]["coordinates"][0] are flattened and read as
        consecutive (x, y) points.
    im: numpy array of the image read from the tiff file.
    width: outline width in pixels (default 100).
    fill: outline colour understood by PIL (default 'Red').

    Returns a PIL.Image with the outlines drawn.
    """
    canvas = Image.fromarray(im)
    draw = ImageDraw.Draw(canvas)
    for structure in structures:
        coords = list(flatten(structure["geometry"]["coordinates"][0]))
        # Pair the flat coordinate stream into (x, y) points for draw.line.
        points = [tuple(coords[i:i + 2]) for i in range(0, len(coords), 2)]
        draw.line(points, width=width, fill=fill)
    return canvas



In [None]:
plt.figure(figsize=(8,8))
image_1_with_line = draw_structure(anatomical_structure_json, image_1)
plt.imshow(image_1_with_line)

# 5. <a id='13'>EDA</a>
[Table of contents](#0.1)


## 5.1.  <a id='14'>Individual Features</a>

In [None]:
ds_info.head()

In [None]:
ds_info.shape

There are 13 images, each described by 16 columns.

8 images are for training and the rest are for testing. The table includes anonymized patient data.

In [None]:
df_info = ds_info
df_info["split"] = "test"
df_info.loc[df_info["image_file"].isin(os.listdir(os.path.join("../input/hubmap-kidney-segmentation", "train"))), 
            "split"] = "train"
df_info["area"] = df_info["width_pixels"] * df_info["height_pixels"]

In [None]:
plt.figure(figsize=(16, 35))
plt.subplot(6, 2, 1)
sn.countplot(x="race", hue="split", data=df_info)
plt.subplot(6, 2, 2)
sn.countplot(x="ethnicity", hue="split", data=df_info)
plt.subplot(6, 2, 3)
sn.countplot(x="sex", hue="split", data=df_info)
plt.subplot(6, 2, 4)
sn.countplot(x="laterality", hue="split", data=df_info)
plt.subplot(6, 2, 5)
sn.histplot(x="age", hue="split", data=df_info)
plt.subplot(6, 2, 6)
sn.histplot(x="weight_kilograms", hue="split", data=df_info)
plt.subplot(6, 2, 7)
sn.histplot(x="height_centimeters", hue="split", data=df_info)
plt.subplot(6, 2, 8)
sn.histplot(x="bmi_kg/m^2", hue="split", data=df_info)
plt.subplot(6, 2, 9)
sn.histplot(x="percent_cortex", hue="split", data=df_info)
plt.subplot(6, 2, 10)
sn.histplot(x="percent_medulla", hue="split", data=df_info)
plt.subplot(6, 2, 11)
sn.histplot(x="area", hue="split", data=df_info);

## 5.2. <a id='15'>Pandas Metadata Profiling</a>
[Table of contents](#0.1)

In [None]:
#https://towardsdatascience.com/exploratory-data-analysis-with-pandas-profiling-de3aae2ddff3

metadata_profile = pp.ProfileReport(ds_info)

In [None]:
metadata_profile

# 6. <a id='16'>Creating the dataset for training </a>
[Table of contents](#0.1)

## 6.1. <a id='17'>Idea</a>
[Table of contents](#0.1)

The images have a very high resolution, which makes them hard to analyse and to use directly for training a model. To make things easier, the images can be tiled into small patches. I'll start with the image with the smallest resolution, i.e. image 7, aaa6a05cc.tiff.

### <div align = 'center'> Image Tiling <div/>
To begin, I will split 'aaa6a05cc.tiff' into tiles and store all files in the folder train_tiles:

Images will be stored in the folder train_tiles/images/ and mask files in the folder train_tiles/masks/. I'm also going to implement filtering: images with an empty (all-zero) mask located in the first/last 2 rows/columns are useless for further model training. Even then, some empty-mask images will remain, which is also useful for the model.

#### Idea:
* choose a tile size (preferably 256 x 256 or 512 x 512)
* aligning the tile with the image and cropping out
* save the cropped file to the designated location
* move the tile forward and repeat the process
* repeat the same process with the corresponding annotation file
    
#### Input:
* image file
* train.csv for annotation

#### Output:
* A file containing info about created dataset
* An archive file (train_tiles.tar) containing the tiled images and masks

## 6.2. <a id='18'>Tiling</a>
[Table of contents](#0.1)

In [None]:
# Paths used by the tiling section. exist_ok=True lets this cell be re-run
# without raising FileExistsError.
input_dir = '../input/hubmap-kidney-segmentation/train'
output_dir = '../output'
os.makedirs(output_dir, exist_ok=True)

In [None]:
#loading the CSVs

df = pd.read_csv(f'../input/hubmap-kidney-segmentation/train.csv')
sub_df = pd.read_csv(f'../input/hubmap-kidney-segmentation/sample_submission.csv')

In [None]:
# Those folders will store our images
os.makedirs(f'train_tiles/images', exist_ok=True)
os.makedirs(f'train_tiles/masks', exist_ok=True)

# This list will contain information about all our images
meta_ls = []

#defining tile size
tile_size = 256
#we can decreses the tile size to 256 X 256 to get even more number of images after tiling

# The break down starts here
for ix in range(1):
    img_id = df.id[ix]
    path = f"../input/hubmap-kidney-segmentation/train/aaa6a05cc.tiff"
    img = skimage.io.imread(path).squeeze()
    mask = rle2mask(train.iloc[2, 1], (image_1.shape[1], image_1.shape[0]))

    x_max, y_max = img.shape[:2]
    
    for x0 in tqdm(range(0, x_max, tile_size)):
        x1 = min(x_max, x0 + tile_size)
        for y0 in range(0, y_max, tile_size):
            y1 = min(y_max, y0 + tile_size)

            img_tile = img[x0:x1, y0:y1]
            mask_tile = mask[x0:x1, y0:y1]

            img_tile_path = f"train_tiles/images/{img_id}_{x0}-{x1}x_{y0}-{y1}y.png"
            mask_tile_path = f"train_tiles/masks/{img_id}_{x0}-{x1}x_{y0}-{y1}y.png"

            cv2.imwrite(img_tile_path, cv2.cvtColor(img_tile, cv2.COLOR_RGB2BGR))
            cv2.imwrite(mask_tile_path, mask_tile)

            meta_ls.append([
                img_id, x0, x1, y0, y1, img_tile_path, mask_tile_path
            ])

In [None]:
%%time
# Pack the tile folder into a single uncompressed archive:
# -c creates a new archive, -f names the archive file (train_tiles.tar).
!tar -cf train_tiles.tar train_tiles

In [None]:
#Creating the meta file

meta_df = pd.DataFrame(meta_ls, columns=['image_id', 'x0', 'x1', 'y0', 'y1', 'image_tile_path', 'mask_tile_path'])
meta_df.to_csv(f'train_metadata.csv', index=False)
meta_df.head()

In [None]:
#Count of Split images

len(os.listdir('train_tiles/images'))

In [None]:
from glob import glob
import random

multipleImages = glob('train_tiles/images/**')
def plotImages2(paths=None, n=9):
    """
    Show a random sample of tile images in a 3x3 grid.

    paths: image file paths to sample from; defaults to the module-level
        `multipleImages` list.
    n: number of images to show (capped at 9 grid cells and at the number of
        available paths, so small folders no longer make random.sample raise).

    cv2.imread returns BGR arrays, so each image is converted to RGB before
    plotting; otherwise matplotlib shows red and blue swapped.
    """
    if paths is None:
        paths = multipleImages
    sample = random.sample(paths, min(n, 9, len(paths)))
    plt.figure(figsize=(20, 20))
    for pos, img_path in enumerate(sample, start=1):
        plt.subplot(3, 3, pos)
        plt.imshow(cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB))
        plt.axis('off')

In [None]:
plotImages2()

# 8. <a id='20'>Data Preparation</a>
[Table of contents](#0.1)

In [None]:
import glob
# glob returns files in arbitrary (filesystem) order. Image and mask tiles share
# the same file name, so sorting both lists makes zip(image_paths, mask_paths)
# pair every image with its own mask.
image_paths = sorted(glob.glob("./train_tiles/images/*.png"))
mask_paths = sorted(glob.glob("./train_tiles/masks/*.png"))
len(image_paths)

In [None]:
len(mask_paths)

## 8.1 <a id='21'>Filtering low band density</a>
[Table of contents](#0.1)

In [None]:
# Per tile pair, record:
#  * lowband_density: number of pixels in the 4 lowest of np.histogram's
#    default 10 equal-width bins (bins span the tile's own min..max value);
#    used below to pick tiles that contain tissue (presumably darker than the
#    background - confirm on the data)
#  * mask_density: number of non-zero mask pixels
lowband_density_values = []
mask_density_values = []

for img_path, msk_path in tqdm(zip(image_paths, mask_paths), total=len(image_paths)):
    image, mask = read_single(img_path, msk_path)
    counts, _ = np.histogram(image)
    lowband_density_values.append(np.sum(counts[:4]))
    mask_density_values.append(np.count_nonzero(mask))

train_helper_df = pd.DataFrame(
    data=list(zip(image_paths, mask_paths, lowband_density_values, mask_density_values)),
    columns=['image_path', 'mask_path', 'lowband_density', 'mask_density'],
)
# astype returns a new DataFrame; assign it, otherwise the conversion is lost.
train_helper_df = train_helper_df.astype(dtype={'image_path': 'object', 'mask_path': 'object',
                                                'lowband_density': 'int64', 'mask_density': 'int64'})
train_helper_df

### 7.1.1 <a>selecting images with tissues</a>
[Table of contents](#0.1)


In [None]:
images_tissue = train_helper_df[train_helper_df.lowband_density>100].image_path
masks_tissue = train_helper_df[train_helper_df.lowband_density>100].mask_path
images_tissue.shape

In [None]:
images, masks = read_data(images_tissue[1200:1218], masks_tissue[1200:1218])

### 7.1.2 <a>Visualisation</a>
[Table of contents](#0.1)

In [None]:
max_rows = 6
max_cols = 6
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,18))
fig.suptitle('Sample Images', y=0.93)
# Each row of tiles is followed by a row with their masks, so only half of the
# grid cells hold tiles.
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(images[:plot_count], masks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    # read_data loads tiles with cv2 (BGR); convert so matplotlib shows true colours.
    ax[row, col].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax[row_masks, col].imshow(mas)

## 8.1. <a id='22'>Augmentation</a>
[Table of contents](#0.1)

Augmentation is applied only to tiles that contain glomeruli (gloms).

Validation samples are split off and kept aside. They are not used for augmentation, to avoid leaking training data into the validation set.

In [None]:
# Hold out 30% of the tissue tiles for validation (fixed seed for reproducibility).
# Images and masks are passed together, so the same split is applied to both.
(image_tissues_split, image_val_files,
 mask_tissues_split, mask_val_files) = train_test_split(
    images_tissue, masks_tissue, test_size=0.30, random_state=17)
print(
    "Split Counts\n"
    f"\tImage_files:\t{len(image_tissues_split)}\n"
    f"\tMask_files:\t{len(mask_tissues_split)}\n"
    f"\tVal Images:\t\t{len(image_val_files)}\n"
    f"\tVal Masks:\t\t{len(mask_val_files)}\n"
)

In [None]:
#https://albumentations.ai/ 
#https://www.kaggle.com/alexanderliao/image-augmentation-demo-with-albumentation

from albumentations import (
CLAHE,
ElasticTransform,
GridDistortion,
OpticalDistortion,
HorizontalFlip,
RandomBrightnessContrast,
RandomGamma,
HueSaturationValue,
RGBShift,
MedianBlur,
GaussianBlur,
GaussNoise,
ChannelShuffle,
CoarseDropout
)

def _build_augmentations():
    """Return the 14 albumentations transforms applied to each glom tile; list position k becomes the "_<k>" file suffix."""
    return [
        CLAHE(clip_limit=1.0, tile_grid_size=(8, 8), always_apply=False, p=1),
        ElasticTransform(p=1, alpha=120, sigma=512*0.05, alpha_affine=512*0.03),
        GridDistortion(p=1),
        OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5),
        HorizontalFlip(p=1),
        RandomBrightnessContrast(p=1),
        RandomGamma(p=1),
        HueSaturationValue(p=1),
        RGBShift(p=1),
        MedianBlur(p=1, blur_limit=5),
        GaussianBlur(p=1, blur_limit=3),
        GaussNoise(p=1),
        ChannelShuffle(p=1),
        CoarseDropout(p=1, max_holes=8, max_height=32, max_width=32),
    ]


def augment_data(image_paths, mask_paths, output_dir='./hubmap_256x256_augmented'):
    """
    Write 14 augmented copies of every image/mask tile pair that contains glomeruli.

    image_paths, mask_paths: paired tile paths (zipped position by position).
    output_dir: root folder. Augmented images go to <output_dir>/images_aug2/
        and masks to <output_dir>/masks_aug2/, named "<original stem>_<k>.png",
        where k = 0..13 is the transform's position in _build_augmentations().

    Pairs whose mask has no non-zero pixel are skipped.

    Returns:
        (images_aug, masks_aug): lists holding every augmented image and mask
        written by this call, in write order. All arrays are kept in memory.

    Raises:
        FileNotFoundError: if an image tile still cannot be read after the retry.
    """
    image_dir = os.path.join(output_dir, 'images_aug2')
    mask_dir = os.path.join(output_dir, 'masks_aug2')
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    # Transforms carry only their parameters, so build them once for all tiles.
    augmentations = _build_augmentations()

    # Accumulate across all tiles so the return value covers everything written.
    images_aug = []
    masks_aug = []

    for image, mask in tqdm(zip(image_paths, mask_paths), total=len(image_paths)):
        image_name = Path(image).stem
        mask_name = Path(mask).stem

        x, y = read_single(image, mask)

        # Augment only tiles with glomeruli.
        if np.count_nonzero(y) == 0:
            continue

        if x is None:
            # cv2.imread returned None: retry once with the last character of
            # the path dropped (behaviour kept from the original notebook).
            # NOTE(review): the reason for this retry is not visible here - confirm
            # whether it is still needed.
            image = image[:-1]
            x, y = read_single(image, mask)
            if x is None:
                raise FileNotFoundError(f"could not read image tile: {image}")

        for idx, aug in enumerate(augmentations):
            augmented = aug(image=x, mask=y)
            aug_image = augmented['image']
            aug_mask = augmented['mask']

            cv2.imwrite(os.path.join(image_dir, f"{image_name}_{idx}.png"), aug_image)
            cv2.imwrite(os.path.join(mask_dir, f"{mask_name}_{idx}.png"), aug_mask)

            images_aug.append(aug_image)
            masks_aug.append(aug_mask)

    return images_aug, masks_aug

images_aug, masks_aug = augment_data(image_tissues_split, mask_tissues_split)

In [None]:
import glob
aug_img_paths2 = glob.glob("./hubmap_256x256_augmented/images_aug2/*.png")
aug_msk_paths2 = glob.glob("./hubmap_256x256_augmented/masks_aug2/*.png")

print("Number of Augmented Images", len(aug_img_paths2))
print("Number of Augmented Masks", len(aug_msk_paths2))

In [None]:
aug_img_paths = aug_img_paths2[-100:]
aug_msk_paths = aug_msk_paths2[-100:]
aug_imgs, aug_msks = read_data(aug_img_paths, aug_msk_paths)

In [None]:
max_rows = 10
max_cols = 4
fig, ax = plt.subplots(max_rows, max_cols, figsize=(20,32))
# Each row of augmented tiles is followed by a row with their masks, so only
# half of the grid cells hold tiles.
plot_count = (max_rows*max_cols)//2
for idx, (img, mas) in enumerate(zip(aug_imgs[:plot_count], aug_msks[:plot_count])):
    row = (idx//max_cols)*2
    row_masks = row+1
    col = idx % max_cols
    # read_data loads tiles with cv2 (BGR); convert so matplotlib shows true colours.
    ax[row, col].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax[row_masks, col].imshow(mas)