In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os
os.chdir('..')
sys.path.append('src')

In [None]:
import rasterio
import torch
import numpy as np
from pathlib import Path
import pandas as pd
import torchvision
from dataclasses import dataclass
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2
from PIL import Image

import matplotlib.pyplot as plt
%matplotlib widget 

In [None]:
import json

import slideio
from matplotlib.widgets import PolygonSelector

In [None]:
# df = pd.read_csv('input/hmib/train.csv')
# cfg = OmegaConf.load('src/configs/u.yaml')

# GTEX

In [None]:
df = pd.read_csv('input/extra/gtex/GTExPortal.csv')

In [None]:
df.head()

In [None]:
# The large intestine has four parts: cecum, colon, rectum, and anal canal.
# Tissue labels compared against the `Tissue` column of the GTEx table.
# NOTE(review): 'Colon - Sigmoid ' carries a trailing space -- confirm that the
# CSV value really ends with a space, otherwise the equality filter matches nothing.

tissues = ['Spleen', 'Lung', 'Kidney - Cortex', 'Prostate', 'Colon - Transverse', 'Colon - Sigmoid ']

In [None]:
#df.Tissue.value_counts()

In [None]:
# for t in tissues:
#     sdf = df[df.Tissue == t]
#     break

In [None]:
sdf = df[df.Tissue == tissues[1]]

In [None]:
columns = df.columns

In [None]:
for c in columns:
    print(c, len(sdf[c].value_counts()))

In [None]:
sdf['Sex'].value_counts()

In [None]:
sdf['Age Bracket'].value_counts()

In [None]:
sdf['Hardy Scale'].value_counts()

In [None]:
sdf['Pathology Categories'].value_counts() # no_abnormalities

In [None]:
sdf['Pathology Notes'].value_counts()

In [None]:
# Print every pathology note that contains the standalone word 'alveolar'
# (once per occurrence); non-string entries are reported and skipped.
for note in sdf['Pathology Notes']:
    if isinstance(note, str):
        for token in note.split(' '):
            if token == 'alveolar':
                print(note)
    else:
        print('not string: ', note)

In [None]:
#t = sdf[sdf['Pathology Notes'] == '2 pieces, no abnormalities'].sample(10) # spleen
#t = sdf[sdf['Pathology Notes'] == '1 piece'].sample(10) # lung
# random_state pins the draw so the exported CSV / URL list is the same on every run.
t = sdf[sdf['Pathology Notes'] == '2 pieces'].sample(10, random_state=0) # prostate

In [None]:
t

In [None]:
t.to_csv('input/extra/gtex/prostate/prostate.csv')

In [None]:
def gen_urls(df, url='https://brd.nci.nih.gov/brd/imagedownload/'):
    """Build one download URL per entry of the 'Tissue Sample ID' column.

    Args:
        df: Table (or mapping) exposing a 'Tissue Sample ID' column of strings.
        url: Prefix that each sample ID is appended to.

    Returns:
        List of ``url + sample_id`` strings, in column order.
    """
    return [url + sample_id for sample_id in df['Tissue Sample ID']]

In [None]:
# Write the generated download URLs, one per line.
with open('input/extra/gtex/prostate/urls.txt', 'w') as f:
    f.writelines(url + '\n' for url in gen_urls(t))

# Viewer

In [None]:
def read_svs(p, targetscale=None):
    """Read scene 0 of an SVS slide as a single down-sampled block.

    Args:
        p: Path to the ``.svs`` file, as a string.
        targetscale: Optional scale; when given, the requested block size is
            ``h * targetscale / (resolution[0] * 1e6)``. When ``None`` a fixed
            size of 2000 is requested.

    Returns:
        ``(image, scaley, scalex, resolution)``: the block returned by
        ``scene.read_block``, the last two entries of ``scene.rect`` divided by
        1000, and ``scene.resolution``.
    """
    slide = slideio.open_slide(p, 'SVS')
    scene = slide.get_scene(0)
    resolution = scene.resolution
    # NOTE(review): slideio documents ``rect`` as (x, y, width, height), so `h`
    # may actually hold the width here -- confirm before relying on the names.
    _, _, h, w = scene.rect
    print('IMG resolution is ', resolution)
    if targetscale is None:
        size = 2000
    else:
        # Parenthesised so the size is divided by the resolution in microns
        # (presumably resolution is metres/pixel, hence * 1e6); the previous
        # form multiplied by 1e6 and requested an absurdly large block.
        size = int(h * (targetscale / (resolution[0] * 1e6)))
    image = scene.read_block(size=(size, 0))
    # NOTE(review): the factors assume a 1000-pixel preview while the default
    # request above is 2000 -- verify against whoever consumes scaley/scalex.
    scaley = h / 1000
    scalex = w / 1000
    return image, scaley, scalex, resolution

In [None]:
# slide = slideio.open_slide(image_path, 'SVS')
# num_scenes = slide.num_scenes
# scene = slide.get_scene(0)
# print(num_scenes, scene.name, scene.rect, scene.num_channels)

# raw_string = slide.raw_metadata
# raw_string.split("|")

In [None]:
# image = scene.read_block((0, 15000, 20000, 20000), size=(500,0))
#image = scene.read_block(size=(1000,0))

In [None]:
# plt.figure()
# plt.imshow(image)

In [None]:
ff = Path('input/extra/gtex/spleen/').glob('*.svs')
ff = list(ff)
ff = iter(ff)

In [None]:
image_path = next(ff)
image, scaley, scalex, resolution = read_svs(str(image_path))

In [None]:
image.shape, image.max()

In [None]:
mask = image.mean(2)

In [None]:
# plt.figure()
# plt.hist(mask.flatten(), bins=50);

In [None]:
timage = image.copy()
timage[mask>240] = 0

In [None]:
mask.shape

In [None]:
rmask = cv2.resize((mask>240).astype(np.uint8), np.array(mask.shape[::-1])//10)

In [None]:
slide = slideio.open_slide(str(image_path), 'SVS')
scene = slide.get_scene(0)
size = 10000 
image = scene.read_block(size=(size,0))

In [None]:
scene.size

In [None]:
scene.read_block((0,0,256,256)).shape

In [None]:
fig2, ax2 = plt.subplots()
#fig2.show()
ax2.imshow(image)
# selector2 = PolygonSelector(ax2, lambda *args: None)

In [None]:
# poly = np.array(selector2.verts)
# poly[:,0] *= scalex
# poly[:, 1]*= scaley

# dst = Path('input/extra/gtex/polys/')/image_path.parent.name / image_path.with_suffix('.json').name
# dst.parent.mkdir(exist_ok=True, parents=True)
# with open(dst, 'w') as f:
#     json.dump(poly.astype(int).tolist(), f)

# Cutting

In [None]:
def start_points(size, split_size, overlap=0):
    """Return the start offsets of windows of ``split_size`` covering ``size``.

    Windows advance by ``int(split_size * (1 - overlap))``; the final window is
    aligned to the end of the axis so that it never runs past ``size``.

    Args:
        size: Length of the axis being tiled.
        split_size: Length of each window.
        overlap: Fraction of a window shared with its neighbour (0 = none).

    Returns:
        List of integer offsets, starting with 0. When a single window already
        covers the axis (``split_size >= size``) the result is ``[0]``.

    Raises:
        ValueError: If the resulting stride is not positive (e.g. overlap >= 1),
            which previously caused an endless loop.
    """
    stride = int(split_size * (1 - overlap))
    if stride <= 0:
        raise ValueError(
            f'overlap={overlap} gives a non-positive stride for split_size={split_size}')

    # One window is enough; the old code returned [0, size - split_size] here,
    # i.e. a duplicate (or negative) offset.
    if split_size >= size:
        return [0]

    points = [0]
    pt = stride
    while pt + split_size < size:
        points.append(pt)
        pt += stride
    # Last window is shifted back so it ends exactly at the axis end.
    points.append(size - split_size)
    return points

def splitter(img, crop_w, crop_h, overlap=0):
    """Lazily yield crops of ``img`` in row-major order.

    Args:
        img: Array whose first two axes are (height, width).
        crop_w: Crop width.
        crop_h: Crop height.
        overlap: Overlap fraction forwarded to ``start_points``.

    Yields:
        ``(crop, y, x, crop_h, crop_w)`` where ``(y, x)`` is the top-left
        corner of the crop inside ``img``.
    """
    height, width = img.shape[:2]
    lefts = start_points(width, crop_w, overlap)
    tops = start_points(height, crop_h, overlap)

    for top in tops:
        bottom = top + crop_h
        for left in lefts:
            yield img[top:bottom, left:left + crop_w], top, left, crop_h, crop_w
            
# Per-dataset, per-organ imaging parameters, looked up as
# imaging_measurements[dataset][quantity][organ].
# presumably pixel_size is in micrometres per pixel and tissue_thickness in
# micrometres -- TODO confirm against the dataset documentation.
imaging_measurements = {
    'hpa': {
        'pixel_size': {
            'kidney': 0.4,
            'prostate': 0.4,
            'largeintestine': 0.4,
            'spleen': 0.4,
            'lung': 0.4
        },
        'tissue_thickness': {
            'kidney': 4,
            'prostate': 4,
            'largeintestine': 4,
            'spleen': 4,
            'lung': 4
        }
    },
    'hubmap': {
        'pixel_size': {
            'kidney': 0.5,
            'prostate': 6.263,
            'largeintestine': 0.229,
            'spleen': 0.4945,
            'lung': 0.7562
        },
        'tissue_thickness': {
        'kidney': 10,
            'prostate': 5,
            'largeintestine': 8,
            'spleen': 4,
            'lung': 5
        }
    }
}

def read_svs(p, scale, targetscale=None):
    """Read scene 0 of an SVS slide as a single down-sampled block.

    Args:
        p: Path to the ``.svs`` file, as a string.
        scale: Extra divisor applied to the requested size when ``targetscale``
            is given.
        targetscale: Optional scale; when given the requested size is
            ``h * targetscale / scale / (resolution[0] * 1e6)``, otherwise 1000.

    Returns:
        ``(image, scaley, scalex, resolution)``: the block from
        ``scene.read_block``, the last two entries of ``scene.rect`` divided by
        1000, and ``scene.resolution``.
    """
    slide = slideio.open_slide(p, 'SVS')
    scene = slide.get_scene(0)
    resolution = scene.resolution
    # NOTE(review): slideio documents ``rect`` as (x, y, width, height); the
    # h/w names below follow the original code -- confirm the intended order.
    _, _, h, w = scene.rect
    print('IMG resolution is ', resolution)
    if targetscale is None:
        size = 1000
    else:
        size = int(h * (targetscale / scale / (resolution[0] * 1e6)))
    image = scene.read_block(size=(size, 0))
    return image, h / 1000, w / 1000, resolution

def read_tiff(p, scale=1, **kwargs):
    """Read a TIFF with rasterio and return it as a channels-last array.

    Args:
        p: Path of the file to open.
        scale: Down-scaling divisor; the image is resized by ``1/scale`` on
            both axes. ``1`` (the default) returns the image unchanged.
        **kwargs: Accepted and ignored.

    Returns:
        Array obtained from ``dataset.read()`` transposed from
        (bands, rows, cols) to (rows, cols, bands), optionally resized.
    """
    # Context manager closes the dataset even if read() raises.
    with rasterio.open(p) as f:
        a = f.read().transpose(1, 2, 0)
    if scale != 1:
        a = cv2.resize(a, (0, 0), fx=1/scale, fy=1/scale)
    return a

def read_jpeg(p, scale=1):
    """Read an image with OpenCV and return it with channels in RGB order.

    Args:
        p: Path of the image file.
        scale: Down-scaling divisor; the image is resized by ``1/scale`` on
            both axes. ``1`` (the default) returns the image unchanged.

    Returns:
        The decoded image with its first and third channels swapped,
        optionally resized.

    Raises:
        OSError: If OpenCV could not read or decode the file.
    """
    i = cv2.imread(str(p))
    if i is None:
        # imread signals failure by returning None instead of raising.
        raise OSError(f'cv2.imread could not read {p}')
    # imread yields BGR; this is the same channel swap the original performed
    # under the name COLOR_RGB2BGR.
    i = cv2.cvtColor(i, cv2.COLOR_BGR2RGB)
    if scale != 1:
        i = cv2.resize(i, (0, 0), fx=1/scale, fy=1/scale)
    return i

In [None]:
# source = imaging_measurements['hpa']['pixel_size']
# target = imaging_measurements['hubmap']['pixel_size']
# extra = resolution[0] * 1e6
# source, extra

In [None]:
#set(df.organ)

In [None]:
organ = 'spleen'
SCALE = 3
# dst = Path(f'input/extra/hpa/cuts_{SCALE}/')

dst = Path(f'input/preprocessed/cuts_{SCALE}/')

# dst = Path(f'input/extra/gtex/cuts_{SCALE}/')
dst = dst / organ

# ff = Path(f'input/extra/gtex/images/{organ}/').glob('*.svs')

# ff = Path(f'input/extra/hpa/images/').glob(f'{organ}*.jpg')

df = pd.read_csv('input/hmib/train.csv')
# Select by boolean mask instead of feeding index labels to positional iloc,
# which only lined up because the index happened to be 0..n-1.
organ_rows = df[df.organ == organ]
idxs = organ_rows.index
ff = [Path('input/hmib/train_images/') / f'{image_id}.tiff' for image_id in organ_rows['id']]

ff = sorted(list(ff))#[:100]
len(ff)

In [None]:
CH, CW = 512, 512
# dst is nested (cuts_<SCALE>/<organ>), so create missing parents too; done
# once up front instead of on every iteration.
dst.mkdir(parents=True, exist_ok=True)

for image_path in tqdm(ff):
    image = read_tiff(image_path, SCALE)
    # image = read_jpeg(image_path, SCALE)
    # image, _, _, _ = read_svs(str(image_path), targetscale=.4, scale=SCALE)
    # print(image.shape, image_path)
    H,W,C = image.shape

    # Keyword arguments: splitter's positional order is (crop_w, crop_h).
    g = splitter(image, crop_w=CW, crop_h=CH, overlap=0)

    for img, y,x,h,w in g:
        # Skip crops whose mean intensity is 230 or above.
        if img.mean()<230:
            s = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            # NOTE(review): `:4` pads with spaces, so file names contain blanks
            # (other cells reference such names) -- kept as is.
            name = f"{image_path.stem}_{y:4}_{x:4}.png"
            cv2.imwrite(str(dst / name), s)

In [None]:
# image_path = next(ff)
# image, scaley, scalex, resolution = read_svs(str(image_path), targetscale=.4)
# image, scaley, scalex, resolution = read_svs(str(image_path), targetscale=.4)
# read_tiff requires a scale; 1 reads the image at its stored size.
# NOTE(review): confirm full resolution is what this cell intended.
image = read_tiff(image_path, scale=1)

In [None]:
CH, CW = 512, 512
g = splitter(image, CH, CW, overlap=0)
H,W,C = image.shape
H,W,C

In [None]:
# Output folder named after the directory that holds the current image.
dst = Path('input/extra/gtex/cuts/') / image_path.parent.name
# parents=True so a missing 'cuts' directory does not raise FileNotFoundError.
dst.mkdir(parents=True, exist_ok=True)

In [None]:
for img, y,x,h,w in tqdm(g, total=H*W//CH//CW):
    # cy, cx = y//100, x//100
    # cy = min(t.shape[0], cy-1)
    # cx = min(t.shape[1], cx-1)
    
    if img.mean()<230:
        s = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        name = f"{image_path.stem}_{y:4}_{x:4}.png"
        cv2.imwrite(str(dst / name), s)
        #break
    # break

In [None]:
Image.open('input/preprocessed/rle1024/images/11064.png')

In [None]:
Image.open('input/extra/gtex/cuts/prostate/GTEX-16NFA-1726_10240_10240.png')

In [None]:
t.shape

In [None]:
plt.figure()
plt.imshow(image[::100, ::100])

In [None]:
from torchvision import transforms

In [None]:
n = transforms.Normalize(mean=.5, std=.5)

In [None]:
i = torch.rand(4,3,64,64) * 255

In [None]:
ni = n(i)

In [None]:
ni.std()

In [None]:
t = transforms.ToTensor()

In [None]:
t(Image.open('input/preprocessed/cuts/prostate/10912_   0_1024.png'))

In [None]:
splitter(img, 256, 256)

In [None]:
Image.open('input/hmib/test_images/10078.tiff')

In [None]:
from torchvision import  transforms

In [None]:
# NOTE(review): torchvision's Resize takes a required `size` argument, so this
# call raises TypeError as written -- looks like leftover exploration.
transforms.Resize()

# BCI

In [None]:
a_dst = Path('input/extra/bci/train/HE/')
a_files = a_dst.rglob('*.png')
b_files = Path('input/extra/bci/train/IHC/').rglob('*.png')
a_files = list(a_files)
b_files = list(b_files)

len(a_files)

In [None]:
a_dst = Path('../../tmp/test/PAMBuH/input/CUTS/glomi_x33_1024/imgs/')
a_files = a_dst.rglob('*.png')
b_files = Path('../../tmp/test/PAMBuH/input/CUTS/glomi_x33_1024/masks/').rglob('*.png')
a_files = list(a_files)
b_files = list(b_files)

len(a_files)

In [None]:
import random

In [None]:
ff = random.choices(b_files, k=16)
ff

In [None]:
def load_img(p):
    """Open the image at ``p``, resize it to 256x256 and return a torch tensor.

    The tensor is created with ``torch.from_numpy`` from the resized image's
    numpy array.
    """
    resized = Image.open(p).resize((256, 256))
    return torch.from_numpy(np.array(resized))

In [None]:
ii = []
for f in ff:
    f = a_dst / f.parent.name / f.name
    i = load_img(f)
    ii.append(i)

In [None]:
# ias = []
# ibs = []


# for i, fb in tqdm(enumerate(b_files)):
#     fa = a_dst / fb.name
#     i1, i2 = Image.open(fa), Image.open(fb)
#     i1 = i1.resize((256,256))
#     i2 = i2.resize((256,256))
    
#     i1 = np.array(i1)
#     i2 = np.array(i2)
#     ias.append(torch.from_numpy(i1))
#     ibs.append(torch.from_numpy(i2))
#     break

In [None]:
i1 = torch.stack(ii)
i1.shape

In [None]:
# make_grid returns (C, H, W); matplotlib wants (H, W, C), i.e. permute(1, 2, 0).
# The previous permute(2, 1, 0) swapped rows and columns, so the grid was shown transposed.
gr = torchvision.utils.make_grid(i1.float().permute(0,3,1,2),nrow=4, normalize=True).permute(1,2,0)
gr.shape

In [None]:
plt.figure()
plt.imshow(gr)

In [None]:
it = iter(b_files)

In [None]:
fb = next(it)
fa = a_dst / fb.name
i1, i2 = Image.open(fa), Image.open(fb)


In [None]:
fb, fa

In [None]:
i2.size

In [None]:
i2.resize((256,256))

In [None]:
i1.resize((256,256))

# Hub old

In [None]:
# Imported here because this cell runs before the notebook's later import
# cell; without them a fresh kernel fails with NameError.
import multiprocessing as mp
from contextlib import contextmanager


@contextmanager
def poolcontext(*args, **kwargs):
    """Context manager yielding an ``mp.Pool`` that is always terminated.

    All arguments are forwarded to ``mp.Pool``. The pool is terminated when the
    ``with`` block exits, including when the block raises.
    """
    pool = mp.Pool(*args, **kwargs)
    try:
        yield pool
    finally:
        # Without try/finally an exception in the with-body skipped terminate()
        # and left the pool running.
        pool.terminate()
    
def mp_func(foo, args, n):
    """Map ``foo`` over consecutive chunks of ``args`` in a process pool.

    ``args`` is sliced into chunks of length ``n`` and the pool is also created
    with ``n`` processes; each chunk is passed to ``foo`` as one argument. The
    mapped results are not returned.
    """
    chunks = []
    for start in range(0, len(args), n):
        chunks.append(args[start:start + n])
    with poolcontext(processes=n) as pool:
        pool.map(foo, chunks)
    
def mp_foo(foo, args):
    """Call ``foo`` with the positional arguments unpacked from ``args``."""
    result = foo(*args)
    return result

In [None]:
from shapely import geometry
from typing import Tuple
import numpy as np
import rasterio
from rasterio.windows import Window
#from utils import jread, get_basics_rasterio, json_record_to_poly, flatten_2dlist, get_cortex_polygons, gen_pt_in_poly

import os
import math
import json
import random
import argparse
import datetime
import itertools
from pathlib import Path
from functools import partial
import multiprocessing as mp
import multiprocessing.pool
from contextlib import contextmanager
from typing import Tuple, List, Dict, Callable

import cv2
import torch
import rasterio
import numpy as np
from shapely import geometry



def jread(path: str) -> Dict:
    """Load and return the JSON document stored at ``path``."""
    with open(str(path), 'r') as f:
        return json.load(f)


def jdump(data: Dict, path: str) -> None:
    """Serialise ``data`` as JSON (4-space indent) into the file at ``path``."""
    with open(str(path), 'w') as f:
        f.write(json.dumps(data, indent=4))


def filter_ban_str_in_name(s: str, bans: List[str]):
    """Return True when ``str(s)`` contains at least one of the ``bans`` substrings."""
    name = str(s)
    hits = [banned in name for banned in bans]
    return any(hits)


def get_filenames(path: str, pattern: str, filter_out_func: Callable) -> List[Path]:
    """Glob ``path`` for ``pattern`` and drop entries rejected by a filter.

    Args:
        path: Directory to search.
        pattern: Glob pattern passed to ``Path.glob``.
        filter_out_func: Called with each path; entries for which it returns a
            truthy value are removed.

    Returns:
        List of remaining ``Path`` objects (the annotation previously claimed
        ``str``), in the order produced by ``glob``.

    Raises:
        AssertionError: If nothing matches the pattern, or nothing survives
            the filter.
    """
    filenames = list(Path(path).glob(pattern))
    assert (filenames), f'There is no matching filenames for {path}, {pattern}'
    filenames = [fn for fn in filenames if not filter_out_func(fn)]
    assert (filenames), f'There is no matching filenames for {filter_out_func}'
    return filenames


def polyg_to_mask(polyg: np.ndarray, wh: Tuple[int, int], fill_value: int) -> np.ndarray:
    """Rasterise one polygon into a new uint8 mask.

    Args:
        polyg: Polygon vertices; converted to int32 before drawing.
        wh: Mask dimensions; the array is allocated with shape
            ``(wh[0], wh[1])``.
            NOTE(review): the name suggests (width, height) while numpy shapes
            are (rows, cols) -- confirm for non-square masks.
        fill_value: Value written inside the polygon.

    Returns:
        The mask, zero outside the polygon.
    """
    contours = np.int32([polyg])
    canvas = np.zeros((wh[0], wh[1]), dtype=np.uint8)
    cv2.fillPoly(canvas, contours, fill_value)  # draws in place
    return canvas


def json_record_to_poly(record: Dict) -> List[geometry.Polygon]:
    """Convert one annotation record into a list of shapely polygons.

    The record's ``['geometry']['coordinates']`` is interpreted by length: one
    entry is treated as a single polygon ring, several entries as a
    multi-polygon whose first ring per entry is used.
    NOTE(review): the geometry 'type' field is not consulted -- verify this
    heuristic against the annotation format.

    Raises:
        Exception: If the coordinates list is empty.
        ValueError: If shapely cannot build a polygon from the coordinates
            (previously this path printed the error and then failed with
            UnboundLocalError on the return statement).
    """
    coordinates = record['geometry']['coordinates']
    num_polygons = len(coordinates)
    if num_polygons == 1:     # Polygon
        list_coords = [coordinates[0]]
    elif num_polygons > 1:    # MultiPolygon
        list_coords = [coordinates[i][0] for i in range(num_polygons)]
    else:
        raise Exception("No polygons are found")

    try:
        return [geometry.Polygon(coords) for coords in list_coords]
    except Exception as e:
        raise ValueError(f'Could not build polygons ({e}) from {list_coords}') from e

def get_basics_rasterio(name):
    """Open ``name`` with rasterio and return ``(dataset, shape, band_count)``.

    The dataset is returned still open; closing it is left to the caller.
    """
    dataset = rasterio.open(str(name))
    basics = (dataset, dataset.shape, dataset.count)
    return basics

def get_tiff_block(ds, x, y, w, h=None, bands=3):
    """Read a rectangular window from the first ``bands`` bands of ``ds``.

    Args:
        ds: Open dataset exposing ``read(indexes, window=...)``.
        x, y: Window offsets, passed as the first two ``Window`` arguments.
        w: Window width.
        h: Window height; defaults to ``w`` (square window).
        bands: Number of bands to read, starting from band 1.
    """
    height = w if h is None else h
    band_indexes = list(range(1, bands + 1))
    window = rasterio.windows.Window(x, y, w, height)
    return ds.read(band_indexes, window=window)

def save_tiff_uint8_single_band(img, path, bits=1):
    """Write a 2-D uint8 array as a single-band GeoTIFF.

    Args:
        img: 2-D array of dtype uint8.
        path: Destination file path.
        bits: Passed to rasterio as the ``nbits`` creation option.

    Raises:
        AssertionError: If ``img`` is not uint8.
    """
    assert img.dtype == np.uint8
    if img.max() <= 1. : print(f"Warning: saving tiff with max value is <= 1, {path}")
    h, w = img.shape
    # Context manager closes the file even if write() fails.
    with rasterio.open(path, 'w', driver='GTiff', height=h, width=w,
                       count=1, nbits=bits, dtype=np.uint8) as dst:
        dst.write(img, 1)


def get_cortex_polygons(anot_structs_json: Dict) -> List[geometry.Polygon]:
    """Return the polygons of all records classified as 'Cortex'."""
    cortex_polygons = get_polygons_by_type(anot_structs_json, name='Cortex')
    return cortex_polygons

def get_polygons_by_type(anot_structs_json: Dict, name: str) -> List[geometry.Polygon]:
    """Collect polygons from every record whose classification name equals ``name``.

    Each record is expected to provide
    ``record['properties']['classification']['name']``; matching records are
    converted with ``json_record_to_poly`` and their polygons concatenated in
    record order.
    """
    matching = (rec for rec in anot_structs_json
                if rec['properties']['classification']['name'] == name)
    polygons = []
    for record in matching:
        polygons.extend(json_record_to_poly(record))
    return polygons

def flatten_2dlist(list2d: List) -> List:
    """Flatten one level of nesting: concatenate the items of each sub-iterable."""
    return [item for sub in list2d for item in sub]

def tiff_merge_mask(path_tiff, path_mask, path_dst, path_mask2=None):
    """Write a 3-band GeoTIFF overlaying one or two masks on an image.

    Band 1 receives the first mask, band 2 the per-pixel mean of the image
    bands, and band 3 the second mask when ``path_mask2`` is given. A mask
    whose maximum is <= 1 is multiplied by 200 before being stored.

    Whole rasters are read into memory at once, so memory use is high for
    large images.

    Args:
        path_tiff: Source image.
        path_mask: Mask stored in band 1.
        path_dst: Output path.
        path_mask2: Optional second mask stored in band 3.
    """
    # Context managers make sure every dataset is closed, also on error.
    with rasterio.open(path_tiff) as src:
        img = src.read()
    with rasterio.open(path_mask) as src:
        mask = src.read()
    #assert mask.max() <= 1 + 1e-6

    if img.shape[0] == 1:
        img = np.repeat(img, 3, 0)

    red = mask * 200 if mask.max() <= 1 + 1e-6 else mask
    # Order matters: the mean is taken before band 0 is overwritten by the mask.
    img[1,...] = img.mean(0)
    img[0,...] = red

    if path_mask2 is not None:
        with rasterio.open(path_mask2) as src:
            mask2 = src.read()
        blue = mask2 * 200 if mask2.max() <= 1 + 1e-6 else mask2
        #assert mask2.max() <= 1 + 1e-6
        img[2,...] = blue

    _, h, w = img.shape
    with rasterio.open(path_dst, 'w', driver='GTiff', height=h, width=w,
                       count=3, dtype=np.uint8) as dst:
        dst.write(img, [1,2,3]) # 3 bands


def gen_pt_in_poly(polygon: geometry.Polygon,
                   max_num_attempts=50) -> geometry.Point:
    """Draw a random point inside ``polygon`` by rejection sampling.

    Candidates are sampled uniformly from the polygon's bounding box. If none
    of ``max_num_attempts`` candidates lies within the polygon, the polygon's
    centroid is returned instead.
    """
    min_x, min_y, max_x, max_y = polygon.bounds

    attempt = 0
    while attempt < max_num_attempts:
        attempt += 1
        x = random.uniform(min_x, max_x)
        y = random.uniform(min_y, max_y)
        candidate = geometry.Point([x, y])
        if candidate.within(polygon):
            return candidate
    return polygon.centroid


def rgb2gray(rgb: np.ndarray) -> np.ndarray:
    """Gets np.ndarray (3, ...) or (..., 3) and returns gray scale np.ndarray (...).

    An input whose first axis has length 3 is treated as channels-first and is
    converted to channels-last before weighting. The channels are combined
    with the weights 0.299 R + 0.587 G + 0.114 B.
    """
    if rgb.shape[0] == 3:
        # (3, ...) -> (..., 3); moveaxis works for any rank, unlike the former
        # pair of swapaxes calls, which assumed exactly three dimensions.
        rgb = np.moveaxis(rgb, 0, -1)
    # Blue weight is 0.114 (the previous 0.144 made the weights sum to 1.03).
    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])


def save_arr_as_tiff(arr: np.ndarray, path: str, nbits: int = 8) -> None:
    """Write an array of shape (num_bands, h, w) to ``path`` as a uint8 GeoTIFF.

    Args:
        arr: Array with bands on the first axis.
        path: Destination file path.
        nbits: Passed to rasterio as the ``nbits`` creation option.

    Returns:
        None; the result is the file written at ``path``.
    """
    num_bands, h, w = arr.shape

    # Context manager closes the file even if write() fails.
    with rasterio.open(path, 'w', driver='GTiff',
                       height=h, width=w, count=num_bands,
                       nbits=nbits, dtype=np.uint8) as dst:
        dst.write(arr)


class NoDaemonProcess(mp.Process):
    """Process whose ``daemon`` flag always reads False and ignores writes.

    Pool workers are normally flagged daemonic; overriding the flag keeps that
    assignment from taking effect.
    """
    # make 'daemon' attribute always return False
    def _get_daemon(self):
        return False
    def _set_daemon(self, value):
        pass
    daemon = property(_get_daemon, _set_daemon)


class NoDaemonContext(type(mp.get_context())):
    """Default multiprocessing context that creates ``NoDaemonProcess`` workers.

    Note: ``mp.get_context()`` is evaluated when this class is defined, which
    fixes the default start method from that point on.
    """
    Process = NoDaemonProcess


class NoDaemonPool(mp.pool.Pool):
    """Pool whose workers are non-daemonic.

    On Python 3.8+ ``Pool.Process`` is a static factory called with the context
    as first argument, so assigning a Process class to it no longer works; the
    worker class is supplied through a custom context instead.
    """

    def __init__(self, *args, **kwargs):
        kwargs['context'] = NoDaemonContext()
        super().__init__(*args, **kwargs)

def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator


class TFReader:
    """Reads tiff files.

    If subdatasets are available, then use them, otherwise just handle as usual.

    The underlying rasterio datasets stay open until ``close()`` is called,
    the reader leaves a ``with`` block, or the reader is finalized.
    """

    def __init__(self, path_to_tiff_file: str):
        # Initialised before any I/O so close()/__del__ are safe even when
        # opening fails part-way.
        self.ds = None
        self.list_ds = []
        self.is_subsets_avail = False

        self.ds = rasterio.open(path_to_tiff_file)
        self.subdatasets = self.ds.subdatasets
        self.is_subsets_avail = len(self.subdatasets) > 0
        if self.is_subsets_avail:
            # Append one by one so already-opened subdatasets are tracked (and
            # closable) if a later open fails.
            for path_to_subdataset in self.subdatasets:
                self.list_ds.append(rasterio.open(path_to_subdataset))

    def read(self, window: Tuple[None, Window] = None, boundless: bool=True):
        """Read the whole raster, or ``window`` of it.

        With subdatasets, each one is read and the results are stacked with
        ``np.vstack``; otherwise the main dataset is read. ``boundless`` is
        only forwarded when a window is given.
        """
        if self.is_subsets_avail:
            output = np.vstack([ds.read() for ds in self.list_ds]) if window is None else \
                np.vstack([ds.read(window=window, boundless=boundless) for ds in self.list_ds])
        else:
            output = self.ds.read() if window is None else \
                self.ds.read(window=window, boundless=boundless)
        return output

    @property
    def shape(self):
        """Shape of the main dataset, as reported by rasterio."""
        return self.ds.shape

    def close(self) -> None:
        """Close the subdatasets and the main dataset."""
        for ds in getattr(self, 'list_ds', []):
            ds.close()
        self.list_ds = []
        ds = getattr(self, 'ds', None)
        if ds is not None:
            ds.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        # A finalizer must never raise; previously this failed with
        # AttributeError when __init__ had not finished.
        try:
            self.close()
        except Exception:
            pass

            

class GdalSampler:
    """Iterates over img with annotation, returns tuples of img, mask
    (plus the border crop when a border file was given).
    """

    def __init__(self, img_path: str,
                 mask_path: str,
                 img_polygons_path: str,
                 img_wh: Tuple[int, int],
                 border_path=None,
                 rand_shift_range: Tuple[int, int] = (0, 0)) -> None:
        """Open the rasters and compute one crop centre per annotated polygon.

        Args:
            img_path: Image raster.
            mask_path: Mask raster.
            img_polygons_path: JSON file with the annotation records.
            img_wh: Crop size, unpacked as ``(w, h)``.
            border_path: Optional extra raster cropped alongside img and mask.
            rand_shift_range: Stored but not used by the code in this class.
        """
        self._records_json = jread(img_polygons_path)
        self._mask = TFReader(mask_path)
        self._img = TFReader(img_path)
        self._border = TFReader(border_path) if border_path is not None else None
        self._wh = img_wh
        self._count = -1
        self._rand_shift_range = rand_shift_range
        # Get 1d list of polygons
        polygons = flatten_2dlist([json_record_to_poly(record) for record in self._records_json])
        self._polygons_centroid = [np.round(polygon.centroid) for polygon in polygons]

    def __iter__(self):
        return self

    def __len__(self):
        # Samples are indexed by polygon, and one record can yield several
        # polygons, so the length is the polygon count (it used to be the
        # record count, which could cut iteration short).
        return len(self._polygons_centroid)

    def __next__(self):
        self._count += 1
        if self._count < len(self):
            return self.__getitem__(self._count)
        else:
            # Reset so the sampler can be iterated again.
            self._count = -1
            raise StopIteration("Failed to proceed to the next step")

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """Return the crops centred on polygon ``idx``.

        Returns ``(img, mask)``, or ``(img, mask, border)`` when a border
        raster was supplied.
        """
        y,x = self._polygons_centroid[idx]
        w,h = self._wh
        y,x = y-h//2, x-w//2 # align center of crop with poly
        # NOTE(review): the window is passed as two (start, stop) ranges; which
        # of x / y ends up as rows vs. columns, and whether w / h pair with the
        # right axis for non-square crops, should be confirmed against rasterio.
        window = ((x, x+w),(y, y+h))
        img = self._img.read(window=window, boundless=True)
        mask = self._mask.read(window=window, boundless=True)
        if self._border is not None:
            return img, mask, self._border.read(window=window, boundless=True)

        return img, mask

    def __del__(self):
        # Drop the reader references; guarded so a failed __init__ does not
        # trigger AttributeError during finalization.
        for attr in ('_mask', '_img'):
            if hasattr(self, attr):
                delattr(self, attr)


In [None]:
def to_gray(i):
    """Average ``i`` over its last axis and repeat the result 3 times along that axis."""
    mean_channel = np.mean(i, axis=-1, keepdims=True)
    return np.repeat(mean_channel, 3, axis=-1)

def mp_sampler(dst, i_fn, m_fn, a_fn, wh, idxs, scale=3):
    """Crop image/mask pairs around annotated polygons and save them as PNGs.

    For every index in ``idxs`` a crop of size ``wh`` is read through
    ``GdalSampler``, down-scaled by ``scale`` and written to
    ``dst/images/<slide>/<idx>.png`` and ``dst/masks/<slide>/<idx>.png``, where
    ``<slide>`` is ``i_fn`` without its suffix and ``<idx>`` is zero-padded to
    six digits.

    Args:
        dst: Output root directory (a ``Path``).
        i_fn: Image raster path (a ``Path``).
        m_fn: Mask raster path.
        a_fn: Annotation JSON path.
        wh: Crop size passed to the sampler; also used to derive the output
            size.
        idxs: Polygon indices to export.
        scale: Integer down-scaling divisor (was a hard-coded 3).
    """
    # GdalSampler is defined in this notebook; no `sampler` module is imported
    # in the visible cells, so the former `sampler.GdalSampler` could not resolve.
    s = GdalSampler(i_fn, m_fn, a_fn, wh)
    out_size = (wh[0] // scale, wh[1] // scale)

    # Output folders do not depend on idx, so create them once.
    slide_name = i_fn.with_suffix('').name
    img_dir = dst / 'images' / slide_name
    mask_dir = dst / 'masks' / slide_name
    os.makedirs(str(img_dir), exist_ok=True)
    os.makedirs(str(mask_dir), exist_ok=True)

    for idx in idxs:
        #i,m,b = s[idx]
        i, m = s[idx]

        orig_name = (str(idx).zfill(6) + '.png')

        # (bands, rows, cols) -> (rows, cols, bands)
        i = i.transpose(1,2,0)
        m = m.transpose(1,2,0)

        # Swap first and third channel before handing the image to cv2.imwrite.
        i = cv2.cvtColor(i, cv2.COLOR_BGR2RGB)

        # Mask replicated to 3 channels and multiplied by 255.
        m = 255 * m.repeat(3,-1).astype(np.uint8)

        i = cv2.resize(i, out_size, interpolation=cv2.INTER_AREA)
        # NOTE(review): INTER_AREA produces intermediate grey values on mask
        # edges -- confirm that is acceptable downstream.
        m = cv2.resize(m, out_size, interpolation=cv2.INTER_AREA)

        cv2.imwrite(str(img_dir / orig_name), i)
        cv2.imwrite(str(mask_dir / orig_name), m)
    return


In [None]:
# Scratch copy of the first lines of mp_sampler, run at top level.
# NOTE(review): `wh`, `i_fn`, `m_fn`, `a_fn` and `sampler` are not assigned in
# any visible cell, so this presumably fails on a fresh kernel -- confirm.
_wh, _wh_mask = wh, wh        
s = sampler.GdalSampler(i_fn, m_fn, a_fn, _wh)
SCALE = 3

for idx in idxs:
    #i,m,b = s[idx]  
    # Only the sample is fetched here; nothing is written.
    i, m = s[idx]  
