## Mayo Clinic - STRIP AI -  Understanding image processing 

Use `PIL` (the Pillow package) and `torchvision` to load and process images.

- Get image metadata
    - get file size and create/update timestamps via `pathlib`
    - get image metadata via `PIL` package
        - image height, width, mode, and so on
- resize images `PIL` package
    - use `PIL` thumbnail to resize images while keeping the original image height/width ratio
    - note that when converting a `PIL` image object to a numpy array, the data is in `[0, 255]`, not `[0, 1]`
- crop and pad images with `torchvision` transforms
    - use `torchvision` to crop and pad images
    - **crop** image: 
        - when the original size is 512*480 and the image is center-cropped to 512, the new image will be 512*512, and the additional area is filled with 0 (shown as black)
        - when the original size is 512*480 and the image is center-cropped to 480, the new image will be 480*480
    - **pad** image: 
        - when the original size is 512*480 and the image is padded by 10, the new image will be 532*500, because the padding is added to every border (10 pixels on each side); the additional area is filled with 0 (shown as black)
- add Gaussian blur to images with `torchvision` transforms
- normalize images via the `torchvision` transforms normalize function: 
    - first convert the `PIL` image object into a numpy array (the data range is `[0, 255]`)
    - then transpose the numpy array from height*width*channels (for RGB images, the number of channels is 3) to channels*height*width
    - scale the data range from `[0, 255]` to `[0, 1]`
    - normalize the data using `torchvision` *transforms.functional.normalize*
    - transpose the array back to height*width*channels
- remove white (empty background) space
    

In [None]:
import pandas as pd
import numpy as np
import os
from pathlib import Path

from datetime import datetime, timedelta
import time

import gc
import copy

import pyarrow.parquet as pq
import pyarrow as pa

 
from dateutil.relativedelta import relativedelta
from sklearn.preprocessing import StandardScaler, MinMaxScaler

from sklearn.metrics import mean_squared_error, roc_auc_score
from sklearn.model_selection import StratifiedKFold, KFold

pd.options.display.max_rows = 100
pd.options.display.max_columns = 100

import warnings
warnings.filterwarnings("ignore")

import pytorch_lightning as pl
random_seed=1234
pl.seed_everything(random_seed)



import torch
from torch import nn
import numpy as np


import torch
from torch.utils.data import (Dataset, DataLoader)


#basic libs

import pandas as pd
import numpy as np
import os
from pathlib import Path

from datetime import datetime, timedelta
import time
from dateutil.relativedelta import relativedelta

import gc
import copy

#additional data processing

import pyarrow.parquet as pq
import pyarrow as pa

from sklearn.preprocessing import StandardScaler, MinMaxScaler


#visualization
import seaborn as sns
import matplotlib.pyplot as plt

#load images
import matplotlib.image as mpimg
import PIL
from PIL import Image




#settings
pd.options.display.max_rows = 100
pd.options.display.max_columns = 100

Image.MAX_IMAGE_PIXELS = None

import warnings
warnings.filterwarnings("ignore")

import pytorch_lightning as pl
random_seed=1234
pl.seed_everything(random_seed)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt

In [None]:
# top-level folders available under ../input (os.walk yields (root, dirs, files))
input_root, input_subdirs, input_files = next(os.walk("../input"))
input_subdirs

In [None]:
# folder holding the competition's training images
img_folder = '../input/mayo-clinic-strip-ai/train'
# image explored in the rest of this notebook
img_path = '/'.join((img_folder, '777311_0.tif'))
# img_path = f'{img_folder}/006388_0.png'

### Get image metadata

In [None]:
# file-system info for the image: size in bytes (st_size) and timestamps (st_mtime, st_ctime, ...)
os.stat(img_path)

In [None]:
# get image metadata using Pillow
# https://pillow.readthedocs.io/en/stable/reference/Image.html#image-attributes

# the image is intentionally left open: `img` is reused by the cells below
img = Image.open(img_path)

meta_dict = {
    'filename': img.filename,
    'format': img.format,
    'mode': img.mode,
    'size': img.size,  # 2-tuple (width, height)
    'width': img.width,
    'height': img.height,
    'palette': img.palette,
    'info': img.info,
    # `is_animated` / `n_frames` are only defined by image plugins that support
    # multi-frame images; fall back to single-frame defaults for other formats
    # instead of raising AttributeError
    'is_animated': getattr(img, 'is_animated', False),
    'n_frames': getattr(img, 'n_frames', 1),
}

meta_dict

### Load and resize images

In [None]:
# image dimensions: PIL reports size as (width, height)
img_width, img_height = img.size
print(img.size, img_height, img_width)

In [None]:
%%time
#display the original image
plt.figure(figsize=(8, 8))
plt.imshow(img)
plt.show()

In [None]:
# Pillow version in use: the enum-style constants (Image.Resampling, Image.Transpose)
# used in the next cell only exist in newer releases
# https://stackoverflow.com/questions/71738218/module-pil-has-not-attribute-resampling
pil_version = PIL.__version__
print(pil_version)

In [None]:
# create a thumbnail of the image (fits within 1024x1024, aspect ratio preserved)
# and rotate portrait images by 90 degrees so they end up in landscape orientation.
#
# Pillow >= 9.1 exposes the resampling / transpose constants through the
# `Image.Resampling` and `Image.Transpose` enums; older releases only provide the
# module-level constants. Resolve them once instead of duplicating the logic.
if hasattr(Image, 'Resampling'):  # Pillow >= 9.1.0
    lanczos = Image.Resampling.LANCZOS
    rotate_90 = Image.Transpose.ROTATE_90
else:  # Pillow < 9.1.0
    lanczos = Image.LANCZOS
    rotate_90 = Image.ROTATE_90

# thumbnail() resizes `img` in place
img.thumbnail((1024, 1024), resample=lanczos, reducing_gap=10)
if img.height > img.width:
    img = img.transpose(rotate_90)

plt.figure(figsize=(8, 8))
plt.imshow(img)
plt.show()

In [None]:
print(img.size)
# pixel value range of the thumbnail as uint8
thumb_arr = np.asarray(img, np.uint8)
thumb_arr.min(), thumb_arr.max()

### Crop and pad images with torchvision transforms

In [None]:
# center-crop the thumbnail to 1024x1024 with torchvision; a side shorter than
# 1024 is zero-padded (black) before cropping
img2 = transforms.functional.center_crop(img, 1024)
print(img2.size)
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img2)
plt.show()

In [None]:
# pixel value range of the center-cropped image
crop_arr = np.asarray(img2, np.uint8)
crop_arr.min(), crop_arr.max()

In [None]:
# pad every border with 10 pixels (default fill value 0, i.e. black)
img3 = transforms.functional.pad(img, 10)
print(img3.size)
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img3)
plt.show()

In [None]:
# pixel value range of the padded image
pad_arr = np.asarray(img3, np.uint8)
pad_arr.min(), pad_arr.max()

### Add Gaussian Blur to images

In [None]:
# in the functional API, kernel_size=(kx, ky) and sigma=(sigma_x, sigma_y) are
# per-axis values, not a (min, max) range to sample from as in transforms.GaussianBlur
img4 = transforms.functional.gaussian_blur(img, kernel_size=(5, 9), sigma=(0.1, 5))
print(img4.size)
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img4)
plt.show()

In [None]:
# note: in the original run the thumbnail spanned [0, 255] before blurring and
# [1, 254] after it -- blurring averages neighbouring pixels, so extremes can shrink
blur_arr = np.asarray(img4, np.uint8)
blur_arr.min(), blur_arr.max()

In [None]:
# try different kernel sizes with a fixed sigma; each kernel_size is (kx, ky)
# https://dsp.stackexchange.com/questions/10057/gaussian-blur-standard-deviation-radius-and-kernel-size
kernel_size_options = [(5, 5), (5, 9), (15, 15), (25, 45)]
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(16, 8))

# axes.T.flat walks the grid column by column: (0,0), (1,0), (0,1), (1,1)
for i, (k_sizes, ax) in enumerate(zip(kernel_size_options, axes.T.flat)):
    img4 = transforms.functional.gaussian_blur(img, kernel_size=k_sizes, sigma=(0.1, 5))
    img4_arr = np.asarray(img4, np.uint8)
    print(i, k_sizes, img4.size, img4_arr.min(), img4_arr.max())
    ax.imshow(img4)
    ax.set_title(f'{k_sizes}')

plt.show()

In [None]:
# try different sigma values with a fixed kernel size
kernel_size = (5, 9)
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(16, 8))

for i, sigma in enumerate([(0.1, 5), (0.05, 1), (0.01, 1), (0.8, 10)]):
    img4 = transforms.functional.gaussian_blur(img, kernel_size=kernel_size, sigma=sigma)
    img4_arr = np.asarray(img4, np.uint8)
    # report the parameters actually used in this iteration (the previous version
    # printed `k_sizes`, a stale variable left over from the kernel-size sweep)
    print(i, kernel_size, sigma, img4.size, img4_arr.min(), img4_arr.max())
    axes[i % 2, i // 2].imshow(img4)
    axes[i % 2, i // 2].set_title(f'{sigma}')

plt.show()

In [None]:
# try combinations of sigma and kernel_size: a small and a large kernel for each of two sigmas
param_grid = [
    [(0.1, 5), (5, 9)],
    [(0.1, 5), (35, 65)],
    [(0.8, 10), (5, 9)],
    [(0.8, 10), (35, 65)],
]
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(16, 8))

for i, (sigma, k_sizes) in enumerate(param_grid):
    img4 = transforms.functional.gaussian_blur(img, kernel_size=k_sizes, sigma=sigma)
    img4_arr = np.asarray(img4, np.uint8)
    print(i, k_sizes, img4.size, img4_arr.min(), img4_arr.max())
    ax = axes[i % 2, i // 2]
    ax.imshow(img4)
    ax.set_title(f'kernel_size={k_sizes}, sigma={sigma}')

plt.show()

In [None]:
# blur the center-cropped image with the same settings as the first blur
img4 = transforms.functional.gaussian_blur(img2, kernel_size=(5, 9), sigma=(0.1, 5))
print(img4.size)
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img4)
plt.show()

In [None]:
# pixel value range of the blurred, center-cropped image
crop_blur_arr = np.asarray(img4, np.uint8)
crop_blur_arr.min(), crop_blur_arr.max()

### Normalize image

- to apply the `torchvision` transforms normalize function:
    - first convert the `PIL` image object into a numpy array (the data range is `[0, 255]`)
    - then transpose the numpy array from height*width*channels (for RGB images, the number of channels is 3) to channels*height*width
    - scale the data range from `[0, 255]` to `[0, 1]`
    - normalize the data using `torchvision` *transforms.functional.normalize*
    - transpose the array back to height*width*channels
    

In [None]:
# per-channel mean/std commonly used with torchvision's ImageNet-pretrained models
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

hwc = np.asarray(img)                 # height x width x channels, values in [0, 255]
print(hwc.shape)
print(hwc.min(), hwc.max())
chw = hwc.transpose((2, 0, 1))        # channels x height x width, the layout torchvision uses
print(chw.shape)
chw = chw / 255                       # scale to [0, 1] first; the mean/std above refer to [0, 1] data
print(chw.min(), chw.max())
normalized = transforms.functional.normalize(torch.Tensor(chw), imagenet_mean, imagenet_std)
img5 = normalized.numpy().transpose((1, 2, 0))   # back to height x width x channels
print(img5.min(), img5.max())

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(16, 8))

# top row: original thumbnail and the raw normalized array
# (imshow clips RGB float values to [0, 1])
axes[0, 0].imshow(img)
axes[0, 0].set_title('original')
axes[0, 1].imshow(img5)
axes[0, 1].set_title('normalized')
# bottom left: normalized array explicitly clipped to [0, 1]
axes[1, 0].imshow(np.clip(img5, 0, 1))
axes[1, 0].set_title('normalized, clipped to [0, 1]')

# bottom right: min-max rescale of the normalized array into [0, 255].
# The normalized data contains negative values, so dividing by its maximum alone
# does not map it into a displayable range; rescaling `img` (the un-normalized
# image) would not show the normalized data at all.
img5_lo, img5_hi = img5.min(), img5.max()
img5_span = img5_hi - img5_lo
img5_1 = (img5 - img5_lo) / (img5_span if img5_span > 0 else 1)  # guard constant images
img5_1 = np.round(img5_1 * 255).astype(np.uint8)
axes[1, 1].imshow(img5_1)
axes[1, 1].set_title('normalized, min-max rescaled to [0, 255]')

plt.show()



In [None]:
# convert the normalized image to uint8 the same way as the clipped panel above:
# clip to [0, 1], scale to [0, 255], then cast. Casting the raw normalized floats
# (which include negatives) to uint8 first and multiplying afterwards would wrap
# around in uint8 arithmetic instead of saturating.
img5_uint8 = (np.clip(img5, 0, 1) * 255).astype(np.uint8)
img5_uint8.min(), img5_uint8.max()

In [None]:
# value range of the rescaled display array
np.min(img5_1), np.max(img5_1)

In [None]:
# apply the same normalization pipeline to the center-cropped image
hwc = np.asarray(img2)
print(hwc.shape)

chw = hwc.transpose((2, 0, 1)) / 255   # CHW layout, values scaled to [0, 1]
normalized = transforms.functional.normalize(
    torch.Tensor(chw), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
)
img5 = normalized.numpy().transpose((1, 2, 0))  # back to HWC for plotting
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img5)   # imshow clips RGB float values to [0, 1]
plt.show()

In [None]:
# value range of the normalized, center-cropped image
print(np.min(img5), np.max(img5))

### remove empty space

The following is based on this [notebook](https://www.kaggle.com/code/jirkaborovec/bloodclots-eda-load-wsi-prune-background?scriptVersionId=101797769)

In [None]:
#https://www.kaggle.com/code/jirkaborovec/bloodclots-eda-load-wsi-prune-background?scriptVersionId=101797769

def prune_image_rows_cols(im, mask, thr=0.990):
    """Drop image columns and rows that are almost entirely masked.

    A column (row) is removed when the fraction of truthy ``mask`` entries in
    that column (row) is strictly greater than ``thr``. Both fractions are
    computed on the full ``mask``, so the row decision is not affected by the
    columns that get removed.

    Args:
        im: array whose first two axes are (rows, cols); trailing axes such
            as colour channels are kept as-is.
        mask: 2-D array with the same number of rows and columns as ``im``;
            truthy entries mark pixels to be treated as empty.
        thr: fraction above which a row/column counts as empty.

    Returns:
        A new array containing only the kept rows and columns of ``im``.
    """
    im = np.asarray(im)
    # fraction of masked pixels per column and per row
    col_frac = np.sum(mask, axis=0) / float(mask.shape[0])
    row_frac = np.sum(mask, axis=1) / float(mask.shape[1])
    # written as a negated `>` so NaN fractions (possible for an empty mask axis)
    # are kept rather than removed
    keep_cols = ~(col_frac > thr)
    keep_rows = ~(row_frac > thr)
    # one boolean selection per axis instead of an np.delete call (a full copy
    # of the array) for every removed column/row
    return im[keep_rows][:, keep_cols]


def mask_median(im, val=255, margin=5):
    """Paint pixels close to or above the per-channel median with ``val``.

    A pixel is masked when, in every one of the first three channels, its value
    is at least that channel's median minus ``margin`` (presumably the bright
    slide background -- confirm for the data at hand).

    Args:
        im: array indexed as (rows, cols, channels) with at least 3 channels.
            It is modified in place.
        val: value written to every channel of the masked pixels.
        margin: how far below the channel median a value may be and still be
            masked (default 5, the previously hard-coded value).

    Returns:
        ``(im, mask)``: the modified input array (same object) and the boolean
        (rows, cols) mask of painted pixels.
    """
    masks = [im[..., c] >= np.median(im[:, :, c]) - margin for c in range(3)]
    # np.logical_and is a binary ufunc: passing three masks positionally makes
    # the third one the `out` argument, silently ignoring (and overwriting) that
    # channel. Reduce over all channel masks instead.
    mask = np.logical_and.reduce(masks)
    im[mask, :] = val
    return im, mask


In [None]:
# paint pixels near the channel medians, then prune rows/columns that are almost fully masked
img6, mask6 = mask_median(np.array(img))
img6 = prune_image_rows_cols(img6, mask6)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))

orig_arr = np.asarray(img)
print(img.size, img6.shape)
print(orig_arr.min(), orig_arr.max())
print(img6.min(), img6.max())
# left: original thumbnail, right: masked and pruned result
for ax, shown in zip(axes, (img, img6)):
    ax.imshow(shown)

plt.show()