In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# import sys
# import os
# os.chdir('..')
# sys.path.append('src')

In [None]:
import torch
import timm

In [None]:
import os
import sys
import importlib

import fire
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np

from omegaconf import OmegaConf

import torch
import cv2
from PIL import Image
import rasterio
from torchvision import transforms

import ttach as tta
from collections import defaultdict

import matplotlib.pyplot as plt
%matplotlib widget

In [None]:
def init_data_module_from_checkpoint(root, name, file_name):
    """Load ``<root>/src/<file_name>`` as a fresh module object named ``name``.

    The module is executed, but it is not registered in ``sys.modules`` by
    this function.

    Args:
        root: ``pathlib.Path`` of the experiment directory (must support ``/``).
        name: Name given to the loaded module.
        file_name: File name of the Python source, relative to ``root/src``.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for the file.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is
    # loaded, so import it explicitly where it is used.
    import importlib.util

    source = root / f'src/{file_name}'
    spec = importlib.util.spec_from_file_location(name, str(source))
    if spec is None or spec.loader is None:
        raise ImportError(f'Cannot build an import spec for {source}')
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


def init_modules(p, module_name='network'):
    """Import ``<p>/src/<module_name>.py`` from a checkpoint's source snapshot.

    Cached copies of the requested module and of a fixed set of helper
    modules are dropped from ``sys.modules`` first, so that imports performed
    while executing the module resolve against ``<p>/src`` rather than a
    previously loaded checkpoint.

    Args:
        p: Experiment root directory (``str`` or ``Path``) containing ``src/``.
        module_name: Name of the module file (without ``.py``) to load.

    Returns:
        The freshly executed module object.
    """
    p = Path(p)
    src_dir = str(p / 'src')

    # ``pop(..., None)`` ignores modules that were never imported, without a
    # bare ``except`` that would also hide unrelated errors.
    for stale in (module_name, 'buildingblocks', 'basemodels', 'segformer', 'tools_tv'):
        sys.modules.pop(stale, None)

    sys.path.insert(0, src_dir)
    try:
        return init_data_module_from_checkpoint(p, module_name, f'{module_name}.py')
    finally:
        # Always undo the path change, even when loading fails, and remove the
        # entry we added rather than whatever happens to sit at index 0.
        try:
            sys.path.remove(src_dir)
        except ValueError:
            pass


def tiff_reader(fn):
    """Read a raster file with rasterio and return it with the band axis last.

    Args:
        fn: Path of the image (anything ``rasterio.open`` accepts).

    Returns:
        ``numpy.ndarray`` holding the ``read()`` result transposed with
        ``(1, 2, 0)``.
    """
    # The context manager closes the dataset handle; the original left it
    # open for every image processed in a loop.
    with rasterio.open(fn) as src:
        bands = src.read()
    return np.array(bands.transpose(1, 2, 0))


def get_inferer(model_path):
    """Build an inference closure for a saved checkpoint.

    The experiment root is taken to be three directory levels above
    ``model_path``; the config is read from ``<root>/src/configs/u.yaml`` and
    the ``network`` / ``tools_tv`` modules are loaded from ``<root>/src``.

    Args:
        model_path: Checkpoint file (``str`` or ``Path``). The weights are
            read from ``['model_state']['cls']`` of the loaded object.

    Returns:
        ``(inf, cfg)`` where ``inf(x)`` runs the model on a CUDA tensor and
        returns ``{'yb': <sigmoid of the model's 'yb' output>}``, and ``cfg``
        is the loaded OmegaConf config.
    """
    model_path = Path(model_path)  # also accept plain string paths
    root = model_path.parent.parent.parent
    cfg = OmegaConf.load(root / 'src/configs/u.yaml')

    network = init_modules(root, 'network')
    tools_tv = init_modules(root, 'tools_tv')

    m = network.model_select(cfg)()
    # Load onto the CPU first so the checkpoint does not depend on the GPU
    # index it was saved from; the model is moved to CUDA right after.
    weights = torch.load(model_path, map_location='cpu')['model_state']['cls']
    m.load_state_dict(weights)
    m = m.cuda().eval()

    def inf(x):
        """Normalise ``x`` with the config statistics and run the model."""
        with torch.cuda.amp.autocast(enabled=True), torch.no_grad():
            x = x.contiguous()
            x = tools_tv.batch_quantile(x, q=.005)
            # Only the first MEAN/STD entry is used for all channels.
            x = (x - cfg.AUGS.MEAN[0]) / cfg.AUGS.STD[0]
            x = x.clamp(-cfg.FEATURES.CLAMP, cfg.FEATURES.CLAMP)
            pred = m(dict(xb=x))
            pred = dict(yb=pred['yb'].sigmoid())
        return pred
    return inf, cfg


def preocess_images(cfg, images, infer, reader, dst, scale):
    """Run ``infer`` on every image and write each prediction to ``dst`` as PNG.

    Args:
        cfg: Experiment config (currently unused, see TODO below).
        images: Iterable of image paths.
        infer: Callable taking a float CUDA tensor with a leading batch axis
            and returning a dict with a ``'yb'`` prediction tensor.
        reader: Callable turning a path into a 3-D ``numpy.ndarray``.
        dst: Output directory; files are named ``<image stem>.png``.
        scale: Currently unused.

    Returns:
        None.

    Raises:
        ValueError: If ``reader`` returns an array that is not 3-dimensional.
    """
    EXT = '.png'
    # TODO: cfg stats
    dst = Path(dst)

    for fn in tqdm(images):
        img = reader(fn)
        if img.ndim != 3:
            raise ValueError(f'expected a 3-D image array, got shape {img.shape} for {fn}')
        x = torch.from_numpy(img).unsqueeze(0).float().cuda()

        pred = infer(x)['yb'].cpu().float()
        # Collapse the prediction channels into one map scaled to 0..255.
        depred = (pred[0].permute(1, 2, 0) * 255.).sum(-1)
        # The channel sum can exceed 255; clip so the uint8 cast cannot wrap.
        depred = depred.clamp(0, 255).numpy().astype(np.uint8)
        # ``depred`` is single-channel (H, W): replicate it into three
        # channels. RGB2BGR expects 3- or 4-channel input.
        out_path = dst / Path(fn).with_suffix(EXT).name
        cv2.imwrite(str(out_path), cv2.cvtColor(depred, cv2.COLOR_GRAY2BGR))
    return

def inferencer(path, organ=None, images_folder=None, extension='*', atob=True, gpun=0, scale=1):
    """Predict masks for training images and save them next to the checkpoint.

    Image ids are read from ``../hmib/input/hmib/train.csv`` and the TIFFs
    from the ``train_images`` folder beside it. Outputs go to
    ``<path.parent.parent>/train_images_<checkpoint stem>``.

    Args:
        path: Checkpoint file (``str`` or ``Path``).
        organ: Value of the ``organ`` column to select; ``None`` processes
            every row of the CSV.
        images_folder, extension, atob, gpun: Accepted but currently unused.
        scale: Forwarded to ``preocess_images``.
    """
    #os.environ['CUDA_VISIBLE_DEVICES'] = str(gpun)
    path = Path(path)

    # ``get_inferer`` takes only the checkpoint path; passing ``scale`` as a
    # second positional argument raised TypeError.
    infer, cfg = get_inferer(path)

    images_root = Path('../hmib/input/hmib/')
    df = pd.read_csv(str(images_root / 'train.csv'))
    ids = df['id'] if organ is None else df.loc[df['organ'] == organ, 'id']
    images = [images_root / 'train_images' / f'{image_id}.tiff' for image_id in ids]

    dst = path.parent.parent / f'train_images_{path.with_suffix("").name}'
    dst.mkdir(parents=True, exist_ok=True)

    preocess_images(cfg, images, infer, reader=tiff_reader, dst=dst, scale=scale)

def png_reader(p):
    """Open image ``p`` with PIL and return its pixels with the last axis moved first.

    The array is transposed with ``(2, 0, 1)``, so the decoded image must be
    3-dimensional.
    """
    pixels = np.array(Image.open(str(p)))
    channels_first = pixels.transpose(2, 0, 1)
    return channels_first

In [None]:

class EnsembleInfer:
    """Combine the dict outputs of several inference callables.

    Every callable is invoked with the same input; the values collected per
    key are stacked and reduced over the ensemble axis.

    Args:
        infers: Iterable of callables ``inf(xb, **kwargs) -> dict`` of tensors.
        mode: ``'avg'`` (mean over models, default) or ``'max'`` (element-wise
            maximum over models).
    """
    def __init__(self, infers, mode='avg'):
        self.infers = infers
        self.mode = mode

    def __call__(self, xb, **kwargs):
        """Return ``{key: reduced tensor}`` for every key the models emit.

        Raises:
            ValueError: If ``self.mode`` is neither ``'avg'`` nor ``'max'``.
        """
        if self.mode not in ('avg', 'max'):
            raise ValueError(f"unknown ensemble mode {self.mode!r}; expected 'avg' or 'max'")

        res = defaultdict(list)
        for inf in self.infers:
            for k, v in inf(xb, **kwargs).items():
                res[k].append(v)

        reduced = {}
        for k, v in res.items():
            stacked = torch.stack(v)
            if self.mode == 'avg':
                reduced[k] = stacked.mean(0)
            else:
                # ``Tensor.max(dim)`` returns (values, indices); keep only the
                # values so both modes yield a plain tensor.
                reduced[k] = stacked.max(0).values

        return reduced
    

class CTTA(torch.nn.Module):
    """Test-time augmentation wrapper around a dict-returning inference callable.

    For each transformer the input is augmented, passed through ``infer``,
    and the outputs are collected per key. Keys listed in ``keys_demask`` are
    mapped back with ``deaugment_mask``; keys listed in ``ignore_keys`` are
    dropped. The collected values are then reduced over the augmentations.

    Args:
        infer: Callable ``infer(xb, **kwargs) -> dict`` of tensors.
        transformers: Iterable of objects exposing ``augment_image`` and
            ``deaugment_mask`` (custom transforms or e.g.
            ``tta.aliases.d4_transform()``).
        ignore_keys: Output keys to discard.
        keys_demask: Output keys to de-augment before reduction.
        mode: ``'avg'`` (mean) or ``'max'`` (element-wise maximum).
    """
    def __init__(self, infer, transformers, ignore_keys=('ds', 'cls'), keys_demask=('yb',), mode='avg'):
        super().__init__()
        self.transformers = transformers
        self.infer = infer
        self.keys_demask = keys_demask
        self.mode = mode
        self.ignore_keys = ignore_keys

    def forward(self, xb, **kwargs):
        """Return a plain ``dict`` of reduced predictions.

        Raises:
            ValueError: If ``self.mode`` is neither ``'avg'`` nor ``'max'``.
        """
        if self.mode not in ('avg', 'max'):
            raise ValueError(f"unknown TTA mode {self.mode!r}; expected 'avg' or 'max'")

        collected = defaultdict(list)
        for transformer in self.transformers:
            axb = transformer.augment_image(xb)
            pred = self.infer(axb.contiguous(), **kwargs)
            assert isinstance(pred, dict), type(pred)
            for k, v in pred.items():
                if k in self.ignore_keys:
                    continue
                if k in self.keys_demask:
                    v = transformer.deaugment_mask(v)
                collected[k].append(v)

        # Build a fresh plain dict so callers never get a defaultdict that
        # silently creates empty lists on missing keys.
        reduced = {}
        for k, v in collected.items():
            stacked = torch.stack(v)
            if self.mode == 'avg':
                reduced[k] = stacked.mean(0)
            else:
                reduced[k] = stacked.max(0)[0]

        return reduced


Information needed to split LB-public into LB-public-Hubmap and LB-public-HPA:

"roughly 550 test images "
There are exactly 529 test images, of which Hubmap=448 and HPA=81.

public test: 55% of the test data
291 --> Hubmap=210 (0.7216), HPA=81 (0.2783)

private test: 45% of the test data
238 --> Hubmap=238

In [None]:
# Generate pseudo-label PNGs for the images listed in a split file.
# Predictions are written to `dst`; mkdir below does not create missing parent directories.
dst = Path('../input/predict/pseudo_0/')#model_path.parent.parent / f'val_images_{model_path.with_suffix("").name}'
dst.mkdir(exist_ok=True)
    
    
# NOTE: the `break` at the end of the loop body means only split 0 is processed.
for split in range(4):
    # Each split file is a headerless CSV of row positions into train.csv.
    df_idx = pd.read_csv(f'../input/splits/{split}.csv', header=None)

    # `images_root` stays defined after this cell and is reused by later cells.
    images_root = Path('../input/hmib/')
    df = pd.read_csv(str(images_root / 'train.csv'))
    images = [Path('../input/preprocessed/rle1024/images') / f'{df.iloc[idx].id}.png' for idx in df_idx.values.flatten()]
    
    # Per-split checkpoint table, currently disabled in favour of the single
    # hard-coded checkpoint assigned to `model_paths` below.
#     models = {
#         0:[
#             Path(f'../output/08-16/12-01-22_unet_convnext_small_in22ft1k/split_0/models/e7_t100_cmax_ema_0.7507.pth'),
#             # Path(f'../output/08-16/13-39-43_unet_dm_nfnet_f2/split_0/models/e5_t100_cmax_ema_0.7579.pth'),
#         ],
#         1:[
#             Path(f'../output/08-16/12-22-15_unet_convnext_small_in22ft1k/split_1/models/e6_t100_cmax_ema_0.7964.pth'),
#             # Path(f'../output/08-16/14-15-04_unet_dm_nfnet_f2/split_1/models/e5_t100_cmax_ema_0.8186.pth'),
#         ],
#         2:[
#             Path(f'../output/08-16/12-41-59_unet_convnext_small_in22ft1k/split_2/models/e9_t100_cmax_ema_0.7961.pth'),
#             # Path(f'../output/08-16/14-50-24_unet_dm_nfnet_f2/split_2/models/e6_t100_cmax_ema_0.8129.pth'),
#         ],
#         3:[
#             Path(f'../output/08-16/13-05-05_unet_convnext_small_in22ft1k/split_3/models/e7_t100_cmax_ema_0.8172.pth'),
#             # Path(f'../output/08-16/15-28-17_unet_dm_nfnet_f2/split_3/models/e2_t100_cmax_ema_0.8104.pth'),

#         ]}

    # The same split_0 checkpoint is used regardless of `split`.
    model_paths = [Path('../output/08-08/18-57-32_unet_resnet34/split_0/models/e64_t100_cmax_ema_0.7356.pth')]#models[split]

    # One inference closure per checkpoint; `cfg` ends up being the config of
    # the last checkpoint loaded.
    infers = []
    for model_path in model_paths:
        infer, cfg = get_inferer(model_path, )
        # transforms = tta.Compose([tta.HorizontalFlip(), tta.VerticalFlip(), tta.Rotate90([90,270])])
        # tta_infer = CTTA(infer, transforms)
        infers.append(infer)

    # Average the models' predictions (EnsembleInfer defaults to 'avg').
    ensinfer = EnsembleInfer(infers)
    infer = ensinfer#infers[1]

    # preocess_images writes the PNGs and returns None, so `r` is None here.
    r = preocess_images(cfg, images, infer, png_reader, dst, 1)
    
    break

In [None]:
# images = list(Path('../input/preprocessed/rle1024/images/').glob('*.png'))
# len(images)
# parti = images[:4]
# len(parti)

In [None]:
#Image.open('../output/08-16/14-15-04_unet_dm_nfnet_f2/split_1/val_images_e5_t100_cmax_ema_0.8186/18445.png')

In [None]:
# Pack ground-truth mask, grayscale image and pseudo-label into one
# 3-channel PNG per training image, grouped into one folder per organ.
pseudos = Path('../input/predict/pseudo/')
images = Path('../input/preprocessed/rle1024/images/')
masks = Path('../input/preprocessed/rle1024/masks/')
dst = Path('../input/predict/combined/')

# NOTE: `images_root` must already be defined by the pseudo-label generation cell.
df = pd.read_csv(str(images_root / 'train.csv'))
# `total` lets tqdm show a percentage and ETA; iterrows() has no length.
for i,row in tqdm(df.iterrows(), total=len(df)):
    idx = row.id
    img = np.array(Image.open(images / f"{idx}.png"))
    img = img.mean(2)  # average over the last axis -> single-channel image
    # presumably masks are stored as 0/1, hence the rescale to 0/255 -- TODO confirm
    mask = np.array(Image.open(masks / f"{idx}.png")) * 255.
    pseudo = np.array(Image.open(pseudos / f"{idx}.png"))[...,0] # rgb for some reason
    
    # Channel order before conversion: (mask, image, pseudo).
    r = np.stack([mask, img, pseudo], -1).astype(np.uint8)
    fn = dst / row.organ / f"{idx}.png"
    # parents=True: `dst` itself may not exist yet on a fresh checkout.
    fn.parent.mkdir(parents=True, exist_ok=True)
    cv2.imwrite(str(fn), cv2.cvtColor(r, cv2.COLOR_RGB2BGR))
        
    #break

In [None]:
mask.max(), img.max(), pseudo.max()

In [None]:
r = Image.open('../input/predict/combined/largeintestine/28791.png')

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(r)

In [None]:
def get_dice(x, y, eps = 1e-6):
    """Dice coefficient ``(2*sum(x*y) + eps) / (sum(x) + sum(y) + eps)``.

    Args:
        x, y: Array-likes of equal shape (typically boolean or 0/1 masks).
        eps: Smoothing term; makes the score 1.0 when both inputs are empty.

    Returns:
        The Dice score as a float.
    """
    # Promote to float64 first: with narrow integer dtypes such as uint8 the
    # element-wise product wraps around and corrupts the intersection.
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    intersection = (x * y).sum()
    dice = ((2. * intersection + eps) / (x.sum() + y.sum() + eps))
    return dice

In [None]:
a = Image.open('../input/preprocessed/rle1024/masks/10044.png')
b = Image.open('../input/preprocessed/rle1024/masks/10274.png')
aa = np.array(a)
bb = np.array(b)

In [None]:
# Dice of mask `bb` against itself.
# NOTE(review): both arguments are `bb`; if the intent was to compare the two
# loaded masks, one argument should presumably be `aa.flatten()` -- confirm.
get_dice(bb.flatten(), bb.flatten())

In [None]:
# plt.figure()
# plt.imshow(bb)

In [None]:
# plt.figure()
# plt.imshow(aa)

In [None]:
# Score pseudo-labels against ground-truth masks, collecting
# (dice, image id) pairs per organ in `dices`.
pseudos = Path('../input/predict/pseudo_0/')
# images = Path('../input/preprocessed/rle1024/images/')
masks = Path('../input/preprocessed/rle1024/masks/')
# dst = Path('../input/predict/combined/')

# NOTE: `images_root` must already be defined by the pseudo-label generation cell.
df = pd.read_csv(str(images_root / 'train.csv'))
dices = defaultdict(list)

# `total` lets tqdm show a percentage and ETA; iterrows() has no length.
for i,row in tqdm(df.iterrows(), total=len(df)):
    idx = row.id
    try:
        mask = np.array(Image.open(masks / f"{idx}.png"))
        pseudo = np.array(Image.open(pseudos / f"{idx}.png"))[...,0] # rgb for some reason
    except FileNotFoundError:
        # Images without a mask or prediction file are skipped.
        continue
    # Ignore a 128-pixel border on every side when scoring.
    mask = mask[128:-128, 128:-128]
    pseudo = pseudo[128:-128, 128:-128]
    # Binarise: the pseudo-label is rescaled from 0..255 before thresholding at 0.5.
    x = pseudo.flatten() / 255. > .5
    y = mask.flatten() > .5
    
    dice = get_dice(x, y)
    dices[row.organ].append((dice, idx))
    #break

In [None]:
# Print mean/std Dice per organ and compute the mean of the per-organ means.
total = 0
cnt = 0        # number of scored images over all organs
n_organs = 0   # number of organs that have at least one score
for k, v in dices.items():
    if not v:
        continue
    cnt += len(v)
    dd = [i[0] for i in v]
    idxs = [i[1] for i in v]
    print(f"{k:15} {np.mean(dd):.3f}, {np.std(dd):.3f}")
    total += np.mean(dd)
    n_organs += 1
# Divide by the organs actually scored instead of a hard-coded 5, so the
# average stays correct when only some organs have predictions.
total /= max(n_organs, 1)

In [None]:
cnt

In [None]:
plt.figure()
plt.imshow(np.array(Image.open('../input/preprocessed/tiff1024/images/5102.tiff'))*255)

In [None]:
Image.open('../input/preprocessed/rle1024//images/5102.png')

In [None]:
Image.open('../input/predict/pseudo_single_notta/5102.png')

prostate        0.769, 0.205
spleen          0.701, 0.257
lung            0.269, 0.258
kidney          0.916, 0.108
largeintestine  0.880, 0.107

0.7069703251310566

In [None]:
total

In [None]:
# Select one organ and order its images by Dice score, lowest first.
k = 'kidney'
v = dices[k]
dd, idxs = [], []
for score, image_id in v:
    dd.append(score)
    idxs.append(image_id)
si = np.argsort(dd)  # positions into dd/idxs, ascending by score

In [None]:
j=2
dd[si[j]], idxs[si[j]]

In [None]:
plt.figure()
plt.plot(dd)