### Biomy社むけのipynb

#### STEP (0)
*   必要なモジュールをインストール
*   condaを利用して、オリジナルのenvironment.ymlを実行する場合は、segmentation-modelsのみでOK
*   モデルは[ここ](https://github.com/hasibzunair/MoNuSAC-ISBI-2020/releases/tag/v0.0.1)からDLしておく必要あり、保存先のディレクトリは変数MODEL_PATHにおく

In [None]:
!apt-get install openslide-tools
!apt-get install python-openslide

In [None]:
!pip install openslide-python

In [None]:
# README.mdでは文中”NOTE”に記載があるので注意
!pip install segmentation-models



#### STEP (1)
svs/ndpi/tif等のWSIを以下のフォルダ構成(階層)のように格納すること

---
```
DATA
├── CONTAIN
│   ├── SUBJECT_1
│   ├── SUBJECT_2
│   ├── ...
│   └── SUBJECT_N
│       └── WSI.ndpi
│
└── UNCONTAIN
    ├── SUBJECT_1
    ├── SUBJECT_2
    ├── ...
    └── SUBJECT_M
        └── WSI.ndpi

```

In [None]:
import openslide
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw
import os 
import time
import cv2
from tqdm import tqdm
import skimage.draw
import random
import keras
from skimage.transform import resize
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
import skimage.io
import efficientnet.tfkeras
from tensorflow.keras.models import load_model
from scipy import ndimage as ndi
from PIL import Image, ImagePalette
import matplotlib.pyplot as plt
import matplotlib.colors
import seaborn as sns
from IPython.display import display

# Randomly generated PIL palette, created once at import time.
# NOTE(review): not referenced by any code visible in this notebook export.
NUCLEI_PALETTE = ImagePalette.random()

In [None]:
# Load the pre-trained model that was downloaded beforehand.
# MODEL_PATH is the directory that contains unet_efficientnetb3_multiclass.h5;
# replace the placeholder below with the real location.
MODEL_PATH = 'PRETRAINED_MODEL_PATH'
# compile=False: the model is loaded without compiling it
# (presumably because it is only used for inference here).
MODEL = load_model('{}/unet_efficientnetb3_multiclass.h5'.format(MODEL_PATH), compile=False)

#### STEP (2)
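WSIを細切れにクロップする。<br>
サムネイルをグレースケール化して反転した値(255 - 画素値。背景の白は0付近、暗い領域ほど大きい)が変数THRESHを超える画素を、細胞が多いところの候補としてあたりをつけている。すべての領域を候補にする場合は、変数を以下に設定すること。
```
THRESH = -1  # 反転後の値は0〜255なので、-1にするとすべての画素が候補になる(255にすると候補が0件になる)
```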

In [None]:
# Settings for the cropping step
BASE = Path('DATA')
DATA_SETNAME = 'test_images'  # name of the folder where the cropped tiles are saved
SAVE_BASE = BASE.joinpath(DATA_SETNAME)
# Threshold applied to the inverted grayscale thumbnail (255 - gray):
# thumbnail pixels whose inverted value is greater than THRESH become crop candidates.
THRESH = 70
NUM_CREATE = 10  # number of crops drawn per slide
CROP_W, CROP_H = 96, 96
LEVEL = 0  # WSI pyramid level used when reading the crop
# Sub-folder of BASE whose subject folders are processed.
TARGET_DIR = 'CONTAIN'

In [None]:
np.random.seed(0)  # fixed seed so the sampled crop positions are reproducible


def split_img_thumbnail(nih_base, get_type, w=512, h=512):
    """Crop random tiles from every slide under ``BASE/get_type`` and save them.

    For each subject folder the first entry is opened with OpenSlide, a
    thumbnail of at most ``w`` x ``h`` pixels is created, and thumbnail pixels
    whose inverted gray value (``255 - gray``) exceeds ``THRESH`` are treated
    as candidate positions. ``NUM_CREATE`` candidates are drawn (with
    replacement), mapped back to level-0 coordinates and read as
    ``CROP_W`` x ``CROP_H`` regions at pyramid level ``LEVEL``. Each crop is
    shown next to the annotated thumbnail and written to
    ``SAVE_BASE/<slide stem>/images/<row0>_<col0>.jpg``.

    Args:
        nih_base: Unused; kept for backward compatibility. The module-level
            ``BASE`` is what is actually scanned.
        get_type: Name of the sub-folder of ``BASE`` to process.
        w: Maximum thumbnail width in pixels.
        h: Maximum thumbnail height in pixels.

    Relies on the module-level settings ``BASE``, ``SAVE_BASE``, ``THRESH``,
    ``NUM_CREATE``, ``CROP_W``, ``CROP_H`` and ``LEVEL``. Failures on a single
    slide are reported and the next slide is processed.
    """
    subject_dirs = [dir_path for dir_path in BASE.joinpath(get_type).iterdir() if dir_path.is_dir()]
    for subject_id_dir in subject_dirs:
        entries = list(subject_id_dir.iterdir())
        if not entries:
            print('no slide found in:{}'.format(subject_id_dir))
            continue
        # Each subject folder is assumed to hold exactly one slide file.
        svs_path = entries[0]
        try:
            print('*'*30)
            # Context manager closes the slide handle even when something fails.
            with openslide.OpenSlide(str(svs_path)) as csvs_slide:
                # The thumbnail is used to decide where to crop.
                k = csvs_slide.get_thumbnail((w, h))
                gray_k = k.convert('L')
                # Inverted gray value: white background -> ~0, dark tissue -> large.
                v = 255 - np.array(gray_k)
                rows, cols = np.where(v > THRESH)
                num_cell = rows.size
                if num_cell == 0:
                    print('no pixel above THRESH={} in:{}'.format(THRESH, svs_path))
                    continue

                # get_thumbnail keeps the aspect ratio, so the real thumbnail can be
                # smaller than (w, h); derive the scale from its actual size.
                slide_w, slide_h = csvs_slide.dimensions
                thumb_w, thumb_h = k.size
                rate_x = slide_w / thumb_w  # level-0 pixels per thumbnail pixel, horizontal
                rate_y = slide_h / thumb_h  # level-0 pixels per thumbnail pixel, vertical

                # Sampling is with replacement: duplicates are unlikely for a small
                # NUM_CREATE, and a duplicate simply overwrites the same output file.
                select_idxs = np.random.choice(num_cell, NUM_CREATE)
                for select_idx in select_idxs:
                    row, col = rows[select_idx], cols[select_idx]
                    print('Top　Left=({},{})'.format(row, col))
                    print('Showing Resize Rate={}'.format(rate_x**(-1)))

                    # NOTE: a crop starting near the border may extend past the slide;
                    # nothing is done about that here.
                    ox, oy = int(row*rate_y), int(col*rate_x)  # level-0 (row, column)
                    # read_region expects the level-0 location as (x, y) = (column, row).
                    crop_img = csvs_slide.read_region((oy, ox), LEVEL, (CROP_W, CROP_H))

                    # Mark the rough crop position on a copy of the thumbnail.
                    # The box spans 1000 level-0 pixels so that it stays visible.
                    tmp_k = k.copy()
                    draw = ImageDraw.Draw(tmp_k)
                    draw.rectangle((col, row,
                                    col+int(rate_x**(-1)*1000), row+int(rate_y**(-1)*1000)),
                                   outline=(255, 0, 0), width=10)
                    fig = plt.figure()
                    ax = fig.add_subplot(1, 2, 1)
                    ax.imshow(tmp_k)
                    ax.axis(False)

                    # Show the crop itself at the requested level.
                    ax = fig.add_subplot(1, 2, 2)
                    ax.imshow(crop_img)
                    ax.set_title(svs_path.stem.split('.')[-1])
                    ax.axis(False)
                    plt.show()
                    plt.close(fig)

                    # Save as JPEG (the alpha channel is dropped first).
                    out_dir = SAVE_BASE.joinpath(svs_path.stem, 'images')
                    out_dir.mkdir(parents=True, exist_ok=True)
                    crop_rgb_img = crop_img.convert('RGB')
                    save_at = out_dir.joinpath('{}_{}.jpg'.format(ox, oy))
                    crop_rgb_img.save(str(save_at))
        except Exception as exc:
            # Keep going with the next slide, but say what actually went wrong.
            print('missing to load the file:{}'.format(svs_path))
            print('  cause: {!r}'.format(exc))

In [None]:
# Crop tiles from every slide under BASE/TARGET_DIR and save them under SAVE_BASE.
split_img_thumbnail(BASE, TARGET_DIR)

#### STEP (3)


In [None]:
# Define paths
# Folder that holds the cropped tiles (string form of BASE/DATA_SETNAME).
TEST_DATASET_PATH = os.path.join(BASE, DATA_SETNAME)

In [None]:
def create_directory(directory):
    """Make sure ``directory`` exists, creating missing parents as needed.

    Nothing happens when the path already exists.

    Args:
        directory: Path of the folder to create, e.g. ``"folder/"``.
    """
    if os.path.exists(directory):
        return
    os.makedirs(directory)

In [None]:
def pad(img, pad_size=96):
    """Reflect-pad ``img`` so that height and width become multiples of ``pad_size``.

    The missing pixels are split between both sides of each axis (the extra
    pixel of an odd amount goes to the bottom/right) and filled with
    ``cv2.BORDER_REFLECT_101``.

    Args:
        img: Array whose first two axes are (height, width).
        pad_size: The side length both axes are rounded up to a multiple of.

    Returns:
        ``(padded_img, (x_min_pad, y_min_pad, x_max_pad, y_max_pad))``.
        NOTE: when ``pad_size == 0`` the input is returned as-is, *without*
        the padding tuple.
    """
    if pad_size == 0:
        return img

    def _split(extent):
        # Amount needed before/after one axis to reach the next multiple.
        remainder = extent % pad_size
        if remainder == 0:
            return 0, 0
        missing = pad_size - remainder
        before = int(missing / 2)
        return before, missing - before

    rows, cols = img.shape[:2]
    y_min_pad, y_max_pad = _split(rows)
    x_min_pad, x_max_pad = _split(cols)

    padded = cv2.copyMakeBorder(
        img, y_min_pad, y_max_pad, x_min_pad, x_max_pad, cv2.BORDER_REFLECT_101
    )
    return padded, (x_min_pad, y_min_pad, x_max_pad, y_max_pad)


def unpad(img, pads):
    """Strip the border described by ``pads`` from ``img``.

    Args:
        img: Array whose first two axes are (height, width).
        pads: ``(x_min_pad, y_min_pad, x_max_pad, y_max_pad)``.

    Returns:
        The un-padded view of ``img``.
    """
    left, top, right, bottom = pads
    rows, cols = img.shape[:2]
    row_span = slice(top, rows - bottom)
    col_span = slice(left, cols - right)
    return img[row_span, col_span]


def read_nuclei(path):
    """Read an image file; multi-channel images are cut down to 3 channels.

    Arrays with more than two axes (input images) keep only the first three
    channels of the last axis. Two-dimensional arrays (masks) are returned
    untouched.
    """
    loaded = skimage.io.imread(path)
    is_multichannel = len(loaded.shape) > 2
    return loaded[:, :, :3] if is_multichannel else loaded


def save_nuclei(path, img):
    """Write ``img`` to ``path`` using ``skimage.io.imsave``."""
    skimage.io.imsave(path, img)

    
def sliding_window(image, step, window):
    """Enumerate windows over ``image`` in row-major order.

    Args:
        image: Array whose first two axes are (height, width).
        step: Stride in pixels, used for both axes.
        window: ``(width, height)`` of each window.

    Returns:
        ``(x_loc, y_loc, cells)``: three lists of equal length with the left
        coordinate, the top coordinate and the image slice of every window.
        Windows at the right/bottom edge may be smaller than ``window``.
    """
    origins = [
        (left, top)
        for top in range(0, image.shape[0], step)
        for left in range(0, image.shape[1], step)
    ]
    x_loc = [left for left, _ in origins]
    y_loc = [top for _, top in origins]
    cells = [image[top:top + window[1], left:left + window[0]] for left, top in origins]
    return x_loc, y_loc, cells


def extract_patches(image, step, patch_size):
    """Cut ``image`` into patches and return them stacked in one array.

    Patch origins come from ``sliding_window``. A patch that is clipped by the
    image border is reflect-padded with ``pad`` before it is collected.

    Args:
        image: Array whose first two axes are (height, width).
        step: Stride between patch origins.
        patch_size: ``(width, height)`` of a patch.

    Returns:
        ``np.array`` built from the list of patches.

    NOTE: ``patch_size[0]`` is used as the side length for both the row and
    the column slice and for padding, so only square patch sizes behave as
    the comparison against ``(patch_size[0], patch_size[1])`` expects.
    """
    side = patch_size[0]
    expected_wh = (patch_size[0], patch_size[1])

    # Only the origins are needed; the slices are taken again below.
    x_pos, y_pos, _ = sliding_window(image, step, expected_wh)

    collected = []
    for left, top in zip(x_pos, y_pos):
        tile = image[top:top + side, left:left + side]
        actual_wh = (tile.shape[1], tile.shape[0])  # W, H
        if actual_wh != expected_wh:
            # Border patch: grow it back to the full size by reflection.
            tile, _ = pad(tile, pad_size=side)
        collected.append(tile)

    return np.array(collected)
    
# Compute Panoptic quality metric for each image
def Panoptic_quality(ground_truth_image, predicted_image):
    """Compute the Panoptic Quality (PQ) of one instance-labelled prediction.

    A ground-truth instance and a predicted instance are matched when their
    IoU is greater than 0.5. Then

        PQ = sum(IoU of matched pairs) / (TP + 0.5 * FP + 0.5 * FN)

    where TP is the number of matched pairs, FN the number of unmatched
    ground-truth instances and FP the number of unmatched predicted instances.

    Args:
        ground_truth_image: Integer label image, 0 = background.
        predicted_image: Integer label image of the same shape, 0 = background.

    Returns:
        The PQ value as a float.

    Raises:
        ZeroDivisionError: if neither image contains any instance.
    """
    gt = np.asarray(ground_truth_image)
    pred = np.asarray(predicted_image)

    # Drop background by value, not by position, so that an image without any
    # background pixel does not lose a real instance.
    gt_ids = [i for i in np.unique(gt) if i != 0]
    pred_ids = [j for j in np.unique(pred) if j != 0]

    # ground-truth id -> (matched predicted id, IoU)
    matched_instances = {}
    for gt_id in gt_ids:
        gt_mask = gt == gt_id
        # Only predicted instances that overlap this ground-truth instance can match.
        for pred_id in np.unique(pred[gt_mask]):
            if pred_id == 0:
                continue
            pred_mask = pred == pred_id
            intersection = np.count_nonzero(gt_mask & pred_mask)
            union = np.count_nonzero(gt_mask | pred_mask)
            IOU = intersection / union
            if IOU > 0.5:
                matched_instances[gt_id] = pred_id, IOU

    TP = len(matched_instances)
    FN = len(gt_ids) - TP
    # A prediction is a false positive when it was not matched to any ground
    # truth. The comparison must use the *predicted* id of each match; ids of
    # matching instances generally differ between the two images.
    matched_pred_ids = {pred_id for pred_id, _ in matched_instances.values()}
    FP = sum(1 for pred_id in pred_ids if pred_id not in matched_pred_ids)
    sum_IOU = sum(iou for _, iou in matched_instances.values())

    PQ = sum_IOU/(TP+0.5*FP+0.5*FN)

    return PQ

In [None]:
# SAME CODE BLOCK AS IN 6_inference.ipynb
# Helper function for data visualization
def visualize(**images):
    """Show the given images side by side in a single row.

    Each keyword becomes one panel; its name (underscores replaced by spaces,
    title-cased) is used as the panel title. All panels share a five-colour
    map normalised to the range 0..4 (5 classes including background).
    """
    class_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
        "", ["black", "red", "yellow", "blue", "green"]
    )
    class_norm = plt.Normalize(0, 4)

    panel_count = len(images)
    plt.figure(figsize=(18, 16))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(images[name], cmap=class_cmap, norm=class_norm)
    plt.show()
    
        
def prep(img):
    """Binarise ``img`` at 0.5 and resize it to ``(image_cols, image_rows)``.

    NOTE(review): ``image_cols`` and ``image_rows`` are read as globals and are
    not defined in the visible part of this notebook - confirm they exist
    before calling this function.
    """
    binary = (img.astype('float32') > 0.5).astype(np.uint8)  # threshold
    return resize(binary, (image_cols, image_rows), preserve_range=True)


def visualize_results(image, mask):
    """Plot ``image`` (left) next to its class ``mask`` (right).

    The mask uses a five-colour map normalised to 0..4
    (5 classes including background).
    """
    _, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(16, 16))

    class_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
        "", ["black", "red", "yellow", "blue", "green"]
    )
    class_norm = plt.Normalize(0, 4)

    image_ax.imshow(image)
    mask_ax.imshow(mask, cmap=class_cmap, norm=class_norm)


def vis_gray(image, mask):
    """Plot ``image`` (left) next to ``mask`` rendered in grayscale (right)."""
    _, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(16, 16))
    image_ax.imshow(image)
    mask_ax.imshow(mask, cmap='gray')


def predict(im):
    """Run the network on one patch and return the per-pixel class index.

    WARNING: this uses the module-level ``MODEL`` (the original code relied on
    a global model as well).

    Args:
        im: A single patch; a leading batch axis is added before prediction.

    Returns:
        The argmax over the last axis of the squeezed model output.
    """
    batch = np.expand_dims(im, axis=0)
    class_scores = MODEL.predict(batch)
    return np.argmax(class_scores.squeeze(), axis=-1)


def instance_seg(image):
    """Split a mask into instances with a distance-transform watershed.

    The Euclidean distance transform of ``image`` is computed, its local
    maxima (3x3 footprint, restricted to ``image``) are used as markers and a
    watershed on the negated distance map, masked by ``image``, assigns one
    label per marker.

    Args:
        image: Mask to split; non-zero pixels are foreground.

    Returns:
        Label image as returned by ``skimage.segmentation.watershed``.
    """
    distance = ndi.distance_transform_edt(image)
    # The ``indices=False`` keyword of peak_local_max was deprecated and later
    # removed from scikit-image. Asking for coordinates (the default) and
    # painting them into a boolean mask works on old and new releases alike.
    peak_coords = peak_local_max(distance, footprint=np.ones((3, 3)), labels=image)
    local_maxi = np.zeros(distance.shape, dtype=bool)
    local_maxi[tuple(peak_coords.T)] = True
    markers = ndi.label(local_maxi)[0]
    labels = watershed(-distance, markers, mask=image)
    return labels



def whole_slide_predict(whole_image):
    """Predict a class mask for an image of arbitrary size using 96x96 tiles.

    * If either side of the image is shorter than 96 pixels, the whole image
      is reflect-padded with ``pad``, predicted in one go and un-padded.
    * Otherwise the image is walked in non-overlapping 96x96 tiles; tiles that
      are clipped by the border are padded, predicted and un-padded, full
      tiles are predicted directly. The results are written into one mask.

    Args:
        whole_image: Array whose first two axes are (height, width).

    Returns:
        ``np.uint8`` array of shape ``(height, width)`` with the predicted
        class index of every pixel.

    Raises:
        AssertionError: in the tiled path, if a tile prediction does not have
            the tile's shape or the tile's dtype differs from the ``uint8``
            prediction.
    """
    tile = 96
    img_h, img_w = whole_image.shape[0], whole_image.shape[1]

    # Image smaller than one tile: pad, infer on the whole image, crop back.
    if img_h < tile or img_w < tile:
        padded, pad_locs = pad(whole_image, pad_size=tile)
        return unpad(predict(padded), pad_locs).astype(np.uint8)

    # Tile origins; the slices returned by sliding_window are not used.
    x_pos, y_pos, _ = sliding_window(whole_image, tile, (tile, tile))
    pred = np.zeros((img_h, img_w), dtype=np.uint8)

    for left, top in zip(x_pos, y_pos):
        patch = whole_image[top:top + tile, left:left + tile]
        is_full_tile = (patch.shape[1], patch.shape[0]) == (tile, tile)  # W, H

        if is_full_tile:
            processed = predict(patch).astype(np.uint8)
        else:
            # Border tile: pad up to the tile size, infer, then crop back.
            patch_rs, pad_locs = pad(patch, pad_size=tile)
            assert patch.dtype == patch_rs.dtype, "Wrong data type after resizing!"
            processed = unpad(predict(patch_rs), pad_locs).astype(np.uint8)

        assert patch.shape[:2] == processed.shape, "Wrong shape!"
        assert patch.dtype == processed.dtype, "Wrong data type in prediction!"

        # Store the tile result in the full-size mask.
        pred[top:top + tile, left:left + tile] = processed

    return pred


#### STEP (4)

In [None]:
# Class index -> class name. The order was taken on trust from another notebook.
label_map = {
            '0':'background',
            '1':'Epithelial',
            '2':'Lymphocyte',
            '4':'Macrophage',
            '3':'Neutrophil',
        }

# Run the model on every cropped tile: TEST_DATASET_PATH/<subject>/<image dir>/<file>.
for subject_dir in Path(TEST_DATASET_PATH).iterdir():
    if not subject_dir.is_dir():
        # Skip stray files next to the subject folders.
        continue
    for img_dir in subject_dir.iterdir():
        if not img_dir.is_dir():
            continue
        # glob('**/*') also yields directories; keep regular files only.
        image_fns = sorted(img_path for img_path in img_dir.glob('**/*') if img_path.is_file())
        for idx, image_path in enumerate(image_fns):
            print("Index: ",idx)
            # image_path already starts with TEST_DATASET_PATH; joining it onto
            # TEST_DATASET_PATH again would duplicate the prefix for relative paths.
            image = skimage.io.imread(str(image_path))
            print("Image shape:", image.shape)

            pred = whole_slide_predict(image)
            print(pred.dtype)
            # Post processing to refine predictions
            pred_filt = cv2.medianBlur(pred.astype(np.uint8), 5)

            print(image.shape, pred.shape)
            print("Uniques predicted", np.unique(pred))
            print([label_map.get(str(i), 'unknown({})'.format(i)) for i in np.unique(pred)])
            assert image.shape[:2] == pred.shape, "Image missmatch"

            #visualize_results(image, pred)

            visualize(
                    image=image,
                    Predicted_mask = pred,
                    Filtered_mask = pred_filt
                )