<div style="padding:20px;color:#2b2d42;margin:0;font-size:180%;text-align:center;display:fill;border-radius:5px;background-color:white;overflow:hidden;font-weight:600">[WSI-HuBMAP] Dataset creation 256x256</div>

<img src="https://drive.google.com/uc?id=1pbIvjTlhGywfhiMTqcsdOB5LSHlklM90" style="border-radius:5px">

<h5 style="text-align: center; font-family: Verdana; font-size: 12px; font-style: normal; font-weight: bold; text-decoration: None; text-transform: none; letter-spacing: 1px; color: black; background-color: #ffffff;">Created by: Nghi Huynh</h5>

# <div style="padding:20px;color:white;margin:0;font-size:100%;text-align:left;display:fill;border-radius:5px;background-color:#735d78;overflow:hidden">1. Imports</div>

In [None]:
import gc
import os
import cv2
import zipfile
import rasterio
import numpy as np
import pandas as pd
import tifffile as tiff
from PIL import Image
import matplotlib.pyplot as plt
from rasterio.windows import Window
from torch.utils.data import Dataset
from tqdm.notebook import tqdm
import math
import re

# <div style="padding:20px;color:white;margin:0;font-size:100%;text-align:left;display:fill;border-radius:5px;background-color:#735d78;overflow:hidden">2. Configs</div>

In [None]:
# Central notebook settings; every key is read later via config['...'].
# NOTE(review): the title says 256x256, but 'resize' is (768, 768) — confirm
# which output size this run is meant to produce.
# 'resolution', 'bs', 'nfolds', 'fold' and 'NUM_WORKERS' are not referenced
# elsewhere in this notebook as shown.
config = dict(
    resize=(768, 768),         # dsize (width, height) handed to cv2.resize
    resolution=(1024, 1024),
    DATA='../input/hubmap-organ-segmentation/train_images',   # holds <id>.tiff
    MASKS='../input/hubmap-organ-segmentation/train.csv',     # read for id/rle/organ
    Window=(250, 2670),        # (start, stop) pixel range used for rows AND cols
    bs=64,
    nfolds=4,
    fold=0,
    NUM_WORKERS=4,
    OUT_TRAIN='train.zip',     # zip receiving <id>.png images
    OUT_MASKS='masks.zip',     # zip receiving <id>.png masks
)

# <div style="padding:20px;color:white;margin:0;font-size:100%;text-align:left;display:fill;border-radius:5px;background-color:#735d78;overflow:hidden">3. Helper functions</div>

In [None]:
# functions to convert encoding to mask and mask to encoding
def enc2mask(encs, shape):
    '''
    Decode one or more run-length encodings into a single label mask.

    Args:
    encs: iterable of RLE strings "start length start length ..." with 1-based
          starts over pixels in column-major order. The m-th encoding is
          painted with label m + 1 (later encodings overwrite earlier ones).
          Missing entries (None, or a float NaN as pandas yields for empty
          cells) are skipped.
    shape: (width, height) of the image, i.e. (img.shape[1], img.shape[0]).

    Returns:
    uint8 array of shape (height, width); 0 is background.
    '''
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for m, enc in enumerate(encs):
        # pandas stores missing values as float NaN (Python float / float64),
        # so test against any float type, not only np.float32.
        if enc is None or (isinstance(enc, (float, np.floating)) and np.isnan(enc)):
            continue
        s = enc.split()
        # pair up (start, length) tokens; a dangling odd token is ignored
        for start, length in zip(s[0::2], s[1::2]):
            start = int(start) - 1  # RLE starts are 1-based
            img[start:start + int(length)] = 1 + m
    # reshape to (width, height) then transpose -> column-major (height, width)
    return img.reshape(shape).T

def mask2enc(mask, n=1):
    '''
    Run-length encode each label 1..n of a mask.

    Pixels are visited in column-major order (mask.T flattened) and run starts
    are 1-based, the same layout enc2mask decodes.

    Args:
    mask: 2-D label array.
    n: number of labels to encode (labels 1..n).

    Returns:
    list with n entries: an RLE string "start length ...", or np.nan when the
    label does not occur in the mask.
    '''
    flat = mask.T.flatten()
    out = []
    for label in range(1, n + 1):
        binary = (flat == label).astype(np.int8)
        if not binary.sum():
            out.append(np.nan)
            continue
        # zero-pad both ends so every run has a rising and a falling edge
        padded = np.pad(binary, 1)
        edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
        edges[1::2] -= edges[::2]  # turn each run's end position into its length
        out.append(' '.join(map(str, edges)))
    return out

In [None]:
def visualize_batch(df, organ, slide):
    '''
    Plot up to 25 images of one organ on a 5x5 grid, cropped to a square
    window, with the mask overlaid.

    Args:
    df: DataFrame indexed by image id with 'rle' and 'organ' columns
        (as built from config['MASKS']).
    organ: organ name, compared against df['organ'].
    slide: (start, stop) pixel range applied to both rows and columns.
    '''
    # Select the first 25 matching rows up front instead of counting by hand.
    subset = df[df['organ'] == organ].head(25)
    plt.figure(figsize=(30, 30))
    for count, (index, row) in enumerate(subset.iterrows()):
        path = os.path.join(config['DATA'], str(index) + '.tiff')
        # The context manager closes the rasterio handle after reading;
        # previously the handle was dropped without ever being closed.
        with rasterio.open(path, num_threads='all_cpus') as src:
            # enc2mask takes (width, height); rasterio's shape is (rows, cols)
            mask = enc2mask([row['rle']], (src.shape[1], src.shape[0]))
            # read window slide
            img = src.read([1, 2, 3], window=Window.from_slices(slide, slide))
        mask = mask[slide[0]:slide[1], slide[0]:slide[1]]

        plt.subplot(5, 5, count + 1)
        # (bands, rows, cols) -> (rows, cols, bands) for matplotlib
        plt.imshow(np.transpose(img, (1, 2, 0)))
        plt.imshow(mask, cmap='seismic', alpha=0.4)
        plt.title(f'{index}', size=40)
        plt.axis('off')
        del img, mask

    plt.suptitle(f'{organ}', size=50, weight="bold", y=1.0)
    plt.show()

In [None]:
# Load the training annotations, keep id/rle/organ, and index rows by image id.
df_masks = (
    pd.read_csv(config['MASKS'])
    .loc[:, ['id', 'rle', 'organ']]
    .set_index('id')
)
df_masks.head()

# <div style="padding:20px;color:white;margin:0;font-size:100%;text-align:left;display:fill;border-radius:5px;background-color:#735d78;overflow:hidden">4. Create HPA dataset class</div>

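Based on the host's note that the tissue area within each HPA image is around 2500 x 2500 pixels, my approach is to create a whole-image dataset that contains only the tissue area. Each image is cropped to the square window `config['Window']` (pixels 250 to 2670 on both axes) and then resized to `config['resize']`.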

* [256x256](https://www.kaggle.com/datasets/nghihuynh/hubmap-2022-wsi-256x256)
* [512x512](https://www.kaggle.com/datasets/nghihuynh/hubmap-2022-wsi-512x512)
* [768x768](https://www.kaggle.com/datasets/nghihuynh/hubmap-2022-wsi-768x768)

`Window.from_slices((row_start, row_stop), (col_start, col_stop))`

In [None]:
class HPADataset(Dataset):
    '''
    One-item dataset yielding the resized tissue crop of a single HPA image.

    Args:
    idx: image id; the file opened is config['DATA']/<idx>.tiff.
    resize: (width, height) of the returned img and mask (cv2 dsize order).
    slide: (start, stop) pixel range applied to both rows and columns.
    encs: iterable of RLE strings decoded with enc2mask, or None for no mask.

    Call close() (or use the instance in a `with` statement) to release the
    rasterio handles opened here.
    '''
    def __init__(self, idx, resize=config['resize'], slide=config['Window'], encs=None):
        self.data = rasterio.open(os.path.join(config['DATA'], str(idx) + '.tiff'),
                                  num_threads='all_cpus')
        # Some images have issues with their format: they do not expose 3
        # bands directly, so their subdatasets are opened and read in
        # _read_rgb. Always defined so close()/_read_rgb can rely on it.
        self.layers = []
        if self.data.count != 3:
            self.layers = [rasterio.open(sub) for sub in self.data.subdatasets]

        self.shape = self.data.shape
        self.slide = slide
        self.resize = resize
        self.idx = idx

        # enc2mask expects (width, height); rasterio's shape is (rows, cols)
        self.mask = enc2mask(encs, (self.shape[1], self.shape[0])) if encs is not None else None

    def __len__(self):
        return 1

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.close()
        return False

    def close(self):
        '''Close the main rasterio dataset and any opened subdataset layers.'''
        for layer in self.layers:
            layer.close()
        self.data.close()

    def _read_rgb(self, window):
        '''Read 3 channels inside `window` as a (3, rows, cols) array.'''
        if self.data.count == 3 or len(self.layers) < 3:
            return self.data.read([1, 2, 3], window=window)
        # NOTE(review): assumes each subdataset holds one channel and that they
        # are listed in R, G, B order — confirm on the affected files.
        return np.stack([layer.read(1, window=window) for layer in self.layers[:3]])

    def __getitem__(self, idx):
        '''
        Return (img, mask, image_id) for the cropped, resized window.

        img is resized with INTER_AREA to (height, width, 3); mask is the
        INTER_NEAREST-resized label crop, or None when no encodings were given.
        '''
        window = Window.from_slices(self.slide, self.slide)
        dsize = (self.resize[0], self.resize[1])

        img = self._read_rgb(window)
        # rasterio returns (bands, rows, cols); cv2 expects (rows, cols, bands)
        img = cv2.resize(np.transpose(img, (1, 2, 0)), dsize,
                         interpolation=cv2.INTER_AREA)

        mask = None
        if self.mask is not None:
            mask = self.mask[self.slide[0]:self.slide[1], self.slide[0]:self.slide[1]]
            # nearest-neighbour keeps mask labels discrete
            mask = cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)

        return img, mask, self.idx

# <div style="padding:20px;color:white;margin:0;font-size:100%;text-align:left;display:fill;border-radius:5px;background-color:#735d78;overflow:hidden">5. Visualize batches</div>

Visualize a batch of images from one organ, cropped to the tissue area.

In [None]:
# Show up to 25 lung images cropped to the tissue window, with masks overlaid.
visualize_batch(df=df_masks, organ='lung', slide=config['Window'])

# <div style="padding:20px;color:white;margin:0;font-size:100%;text-align:left;display:fill;border-radius:5px;background-color:#735d78;overflow:hidden">6. Dataset creation</div>

In [None]:
# Per-image channel means of x and x**2 (x = pixel / 255), for dataset mean/std.
x_tot, x2_tot = [], []
with zipfile.ZipFile(config['OUT_TRAIN'], 'w') as img_out, \
        zipfile.ZipFile(config['OUT_MASKS'], 'w') as mask_out:
    for index, row in tqdm(df_masks.iterrows(), total=len(df_masks)):
        # Skip image 31800 before opening it; building the dataset first would
        # open the TIFF and decode its full mask only to discard them.
        if index == 31800:
            continue
        # Pass only the RLE: iterating the whole row would also hand the organ
        # name to enc2mask as if it were an encoding.
        ds = HPADataset(index, encs=[row['rle']])
        try:
            for i in range(len(ds)):
                img, m, _ = ds[i]

                x = img / 255.0
                x_tot.append(x.reshape(-1, 3).mean(0))
                x2_tot.append((x ** 2).reshape(-1, 3).mean(0))

                # write data (cv2 encodes BGR, so convert from RGB first)
                img = cv2.imencode('.png', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))[1]
                img_out.writestr(f'{index}.png', img)
                m = cv2.imencode('.png', m)[1]
                mask_out.writestr(f'{index}.png', m)
        finally:
            # release the TIFF handle opened by HPADataset
            ds.data.close()

# image stats: all crops share one size, so averaging per-image means is exact
img_avr = np.array(x_tot).mean(0)
img_std = np.sqrt(np.array(x2_tot).mean(0) - img_avr ** 2)
print('mean:', img_avr, ', std:', img_std)

# <div style="padding:20px;color:white;margin:0;font-size:100%;text-align:left;display:fill;border-radius:5px;background-color:#735d78;overflow:hidden">7. Sanity check</div>

In [None]:
# Reload a 4x4 grid of written images/masks from the zip archives to eyeball them.
cols, rows = 4, 4
idx0 = 20  # offset into the sorted file list (after skipping its first 8 names)
fig = plt.figure(figsize=(cols * 4, rows * 4))
with zipfile.ZipFile(config['OUT_TRAIN'], 'r') as img_arch, \
     zipfile.ZipFile(config['OUT_MASKS'], 'r') as msk_arch:
    fnames = sorted(img_arch.namelist())[8:]
    for i in range(rows):
        for j in range(cols):
            idx = i + j * cols
            if idx0 + idx >= len(fnames):
                continue  # fewer images than grid cells: leave this cell empty
            name = fnames[idx0 + idx]  # masks are stored under the same name
            img = cv2.imdecode(np.frombuffer(img_arch.read(name), np.uint8),
                               cv2.IMREAD_COLOR)
            # imdecode yields BGR; matplotlib expects RGB
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            mask = cv2.imdecode(np.frombuffer(msk_arch.read(name), np.uint8),
                                cv2.IMREAD_GRAYSCALE)
            fig.add_subplot(rows, cols, idx + 1)
            plt.axis('off')
            plt.imshow(img)
            plt.imshow(mask, alpha=0.2)
plt.show()