In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os
# Work from the repository root so the relative 'input/...' and 'src/...' paths
# used below resolve. Guarded so that re-running this cell (or starting the
# kernel in the root already) does not climb one directory too far.
if not os.path.isdir('src'):
    os.chdir('..')
# Absolute path + membership test: stays valid if the cwd changes later and
# does not pile up duplicate entries on re-run.
_src_dir = os.path.abspath('src')
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

In [None]:
import rasterio
import torch
import numpy as np
from pathlib import Path
import pandas as pd
import torchvision
from dataclasses import dataclass
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib widget 

from data import DfDataset, MaskGenerator, MainDataset

from build_data import DatasetsGen, init_datasets
from augs import create_augmented

In [None]:
# Load the training CSV and the YAML config; both paths are relative to the
# working directory selected by the os.chdir cell above.
df = pd.read_csv('input/hmib/train.csv')
cfg = OmegaConf.load('src/configs/u.yaml')
# presumably makes the project code behave as the rank-0 / single process — TODO confirm
cfg.PARALLEL.LOCAL_RANK=0

In [None]:
# Override the image and annotation roots stored in the config.
# NOTE(review): other cells read from 'input/hmib/...'; these values have no
# 'input/' prefix — presumably the project joins them onto a base dir; confirm.
cfg.DATA.root='hmib/train_images/'
cfg.DATA.ann_root='hmib/train_annotations/'

In [None]:
# Build the project's dataset generator from the config and initialise datasets
# for the listed split name(s); the result is indexed by 'TRAIN' below.
dg = DatasetsGen(cfg)
ds = init_datasets(cfg, dg, ['TRAIN'])

In [None]:
tds = ds['TRAIN']
# Fetch the item once: indexing the dataset twice loads it twice and, if the
# dataset applies random transforms, x and y could come from different draws.
sample = tds[0]
x, y = sample['x'], sample['y']

In [None]:
from dali import build_daliloaders

In [None]:
tds[0]

In [None]:
# Build loaders with the project's `dali.build_daliloaders` and keep the TRAIN one.
dls = build_daliloaders(cfg, ds)
tdl = dls['TRAIN']

In [None]:
# Take the first batch from the train loader and stop; `b` stays bound for the
# inspection cells below.
for b in tdl:
    break

In [None]:
# NOTE(review): `.contiguous` is referenced but not called, so this cell only
# displays the attribute itself — perhaps `.is_contiguous()` was intended.
b[0]['y'].contiguous

In [None]:
# Reach through two `.ds` wrappers of the train dataset to grab its first label entry.
l = tds.ds.ds.labels[0]

In [None]:
import dataclasses

In [None]:
# dataclasses.replace with no field overrides builds a new instance carrying the
# same field values, i.e. a shallow copy (requires `l` to be a dataclass instance).
lc = dataclasses.replace(l)

In [None]:
# Project mask generator; presumably 192 is the input size and mask_ratio=0
# means nothing is masked — TODO confirm against src/data.
mg = MaskGenerator(192, mask_ratio=.0)

In [None]:
# Call the generator once and keep the result for plotting.
m = mg()

In [None]:
# Show the generated mask in a new figure.
plt.figure()
plt.imshow(m)

In [None]:
# Raises ZeroDivisionError on purpose — presumably a barrier so that "Run All"
# stops here and the exploratory cells below are only run by hand.
1/0

In [None]:
# Wrap the datasets with the project's augmentation pipeline and keep the TRAIN one.
ads = create_augmented(cfg, ds)
atds = ads['TRAIN']

In [None]:
# First augmented item; 'x' and 'y' are indexed with [0], so each entry is
# presumably a sequence/batch of views — TODO confirm.
i = atds[0]
x,y = i['x'][0], i['y'][0]
x.shape, y.shape,  len(atds)

In [None]:
# Move the first axis to the end before plotting
# (assumes x is channels-first (C, H, W) — TODO confirm).
plt.figure()
plt.imshow(x.permute(1,2,0))

In [None]:
# Same channels-last view for the target
# (assumes y is channels-first (C, H, W) — TODO confirm).
plt.figure()
plt.imshow(y.permute(1,2,0))

In [None]:
def batch_quantile(b, q=.01):
    """Clip every sample of a batch to its own [q, 1 - q] quantile range.

    Args:
        b: floating-point tensor whose first dimension indexes the samples;
            any number of trailing dimensions is accepted.
        q: lower tail fraction; the upper bound is the (1 - q) quantile.

    Returns:
        Tensor with the same shape as ``b`` where, per sample, values below the
        q-quantile / above the (1 - q)-quantile are replaced by that quantile.
    """
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
    flat = b.reshape(b.shape[0], -1)
    # One batch axis plus a singleton per remaining dim, so the bounds broadcast
    # against ``b`` whatever its rank (a fixed (-1, 1, 1, 1) only fits 4-D input).
    bound_shape = (-1,) + (1,) * (b.dim() - 1)
    upper = torch.quantile(flat, q=1 - q, dim=1).view(bound_shape)
    lower = torch.quantile(flat, q=q, dim=1).view(bound_shape)
    return torch.max(torch.min(b, upper), lower)

def nppclip(a, p=1):
    """Clamp ``a`` to the range between its p-th and (100 - p)-th percentiles.

    Args:
        a: array-like input; percentiles are taken over all elements.
        p: percentile (0-100 scale) cut from each tail.

    Returns:
        Array of the same shape with out-of-range values set to the bounds.
    """
    lower_bound = np.percentile(a, p)
    upper_bound = np.percentile(a, 100 - p)
    return np.clip(a, lower_bound, upper_bound)

In [None]:
# Quick-look render of the most recent resized image.
# NOTE(review): `ar` is only assigned inside the export loop further down, so
# this cell fails on a fresh top-to-bottom run.
t = ar#[500:2500,500:2500]
# Clip to the 1st..99th percentile (nppclip default p=1).
t = nppclip(t)
# Reverse the order of the first three channels (indices 2, 1, 0).
t = np.stack([t[...,2],t[...,1],t[...,0]], -1)
# Min-max scale to [0, 1] for imshow.
t = (t - t.min()) / (t.max() - t.min())
plt.figure(figsize=(10,10))
plt.imshow(t)

In [None]:
# Source images / annotations and the destination for the tiled PNG export.
imgs_path = Path('input/hmib/train_images/')
ann_path = Path('input/hmib/train_annotations/')
dst = Path('input/preprocessed/png1024overlap')
# The export loop writes into dst/images and dst/masks; create them (and any
# missing parents) up front. exist_ok keeps the cell re-runnable.
for out_dir in (dst / 'images', dst / 'masks'):
    out_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Rows to export; swap in the commented filter to restrict to a single organ.
#tdf = df[df.organ == 'spleen']
tdf = df

In [None]:
from data import read_ann, convert_ann
from rasterio.features import rasterize

In [None]:
# Resize every image and its rasterized annotation mask to H x W, cut both into
# overlapping ch x cw tiles and write them as '<id>_<tile index>.png' under
# dst/images and dst/masks.
# NOTE: `splitter` is defined in a later cell of this notebook — run that cell first.
H, W = 1024, 1024   # size each image / mask is resized to before tiling
ch, cw = 512,512    # tile height, tile width
imgs = []
masks = []

# cv2.imwrite reports failure (e.g. a missing directory) only through its
# return value, so make sure the targets exist and check every write.
for _sub in ('images', 'masks'):
    (dst / _sub).mkdir(parents=True, exist_ok=True)

for i,row in tqdm(tdf.iterrows(), total=len(tdf)):
    f = (imgs_path / str(row['id'])).with_suffix('.tiff')
    # Context manager closes the dataset even if read() raises.
    with rasterio.open(f) as fd:
        a = fd.read().transpose(1,2,0)   # bands-first -> bands-last
    h,w,c = a.shape
    ar = cv2.resize(a, (W, H)).astype(np.uint8)
    name, ext = f.stem, 'png'

    # NOTE(review): cv2.imwrite treats 3-channel arrays as BGR; if the TIFF
    # bands are RGB the PNGs are stored channel-swapped — confirm that the
    # reader of these files expects that.
    # splitter's signature is (img, crop_w, crop_h): width first.
    for ii, (s,*_) in enumerate(splitter(ar, cw, ch)):
        out_file = dst / 'images' / f'{name}_{ii}.{ext}'
        if not cv2.imwrite(str(out_file), s):
            raise IOError(f'failed to write {out_file}')

    annf = (ann_path / f.name).with_suffix('.json')
    data = read_ann(annf)
    poly = convert_ann(data)
    # Rasterize at the original resolution, then resize like the image.
    mask = rasterize([poly], out_shape=(h,w))
    mask = cv2.resize(mask, (W,H))

    for ii, (s,*_) in enumerate(splitter(mask, cw, ch)):
        out_file = dst / 'masks' / f'{name}_{ii}.{ext}'
        if not cv2.imwrite(str(out_file), s):
            raise IOError(f'failed to write {out_file}')

    # In-memory collection is disabled; re-enable these two appends before
    # running the cell that stacks `imgs` / `masks` into tensors.
    # masks.append(mask)
    # imgs.append(ar)

In [None]:
# Check the stored shape of one mask PNG, read without any conversion.
# NOTE(review): this reads from 'png1024', not the 'png1024overlap' directory
# used as `dst` above; cv2.imread returns None for a missing file, which would
# make `.shape` raise AttributeError.
cv2.imread('input/preprocessed/png1024/masks/10044.png', cv2.IMREAD_UNCHANGED).shape

In [None]:
# Convert the collected NumPy arrays to tensors and stack them into batches.
# NOTE(review): the appends that fill `imgs` / `masks` are commented out in the
# export loop, so both lists are empty and torch.stack raises unless those
# appends are re-enabled first.
imgs = [torch.from_numpy(a) for a in imgs]
masks = [torch.from_numpy(a) for a in masks]
imgs = torch.stack(imgs)
masks = torch.stack(masks)
# Add a trailing singleton axis to the stacked masks.
masks = masks.unsqueeze(-1)
imgs.shape, masks.shape

In [None]:
# Distinct organ names. Sorted so the order — and therefore organs[0] in the
# next cell — is identical on every kernel start (plain set order of strings
# varies between interpreter runs).
organs = sorted(set(df.organ))
organs

In [None]:
# Index labels of the rows belonging to the first organ in `organs`.
organ = organs[0]
idxs = df[df.organ==organ].index

In [None]:
# Overrides the organ-based `idxs` from the previous cell with positions 0..63.
idxs = range(64)

In [None]:
# Pick the selected images / masks (needs `imgs` and `masks` stacked as tensors above).
i,m = imgs[idxs], masks[idxs]

In [None]:
m.shape

In [None]:
s=0
# m is permuted to channels-first for make_grid, which returns a (C, H, W)
# image. permute(1, 2, 0) moves channels last for imshow while keeping rows and
# columns in place (permute(2, 1, 0) would transpose the picture); this matches
# the tile-grid cell further down.
gr = torchvision.utils.make_grid(m.float().permute(0,3,1,2)[s:s+100], normalize=True).permute(1,2,0)
gr.shape

In [None]:
# NOTE(review): `t` is rebound in several cells (a NumPy array in the preview
# cell, a torch tensor in the tile-stacking cell); `.numpy()` only exists on
# the tensor version, so this cell depends on execution order.
plt.hist(t.flatten().numpy(), bins=30);

In [None]:
# Display the grid built by whichever `gr = ...` cell ran most recently.
plt.figure(figsize=(20,10))
plt.imshow(gr)

In [None]:
import cv2


def start_points(size, split_size, overlap=0):
    """Return the start offsets of windows of ``split_size`` covering ``size``.

    Windows advance by ``int(split_size * (1 - overlap))``; the final window is
    shifted back so that it ends exactly at ``size``.

    Args:
        size: length of the axis to cover.
        split_size: window length.
        overlap: fraction of a window shared with its neighbour (0 = none).

    Returns:
        Increasing list of start offsets, always beginning with 0. When the
        window is at least as long as the axis a single start ``[0]`` is
        returned (instead of a duplicate or negative second offset).

    Raises:
        ValueError: if the resulting stride is not positive (e.g. overlap >= 1),
            which would otherwise never terminate.
    """
    last = size - split_size
    if last <= 0:
        return [0]
    stride = int(split_size * (1 - overlap))
    if stride <= 0:
        raise ValueError(
            f'overlap={overlap} gives a non-positive stride for split_size={split_size}'
        )
    # Regular steps strictly before the last start, then the end-aligned window.
    return list(range(0, last, stride)) + [last]



def splitter(img, crop_w, crop_h, overlap=0.5):
    """Yield overlapping crops of ``img`` in row-major order.

    Args:
        img: array-like with ``.shape`` starting with (height, width) and
            supporting 2-D slicing.
        crop_w: crop width.
        crop_h: crop height.
        overlap: fraction shared by neighbouring crops along each axis
            (default 0.5, the value that was previously hard-coded).

    Yields:
        ``(crop, x_offset, y_offset, crop_h, crop_w)`` — note the order: column
        offset first, then row offset. ``crop`` is sliced straight from ``img``;
        no explicit copy is made.
    """
    img_h, img_w, *_ = img.shape
    x_starts = start_points(img_w, crop_w, overlap)
    y_starts = start_points(img_h, crop_h, overlap)

    for top in y_starts:
        for left in x_starts:
            yield img[top:top + crop_h, left:left + crop_w], left, top, crop_h, crop_w

In [None]:
ar.shape

In [None]:
# Lazy generator of 512x512 tiles over `mask`, the last mask left by the export loop.
ss = splitter(mask, 512, 512)

In [None]:
# Collect every tile as a tensor with a trailing singleton axis and stack them.
# NOTE(review): splitter yields (crop, column offset, row offset, crop_h, crop_w),
# so the variable named `y` receives the x/column offset and `x` the row offset;
# the loop also rebinds the notebook-level `h` and `w`.
t = []
for s, y,x,h,w in ss:
    print(s.shape, y,x,h,w)
    t.append(torch.from_numpy(s).unsqueeze(-1))
t = torch.stack(t)

In [None]:
t.shape

In [None]:
# Arrange the tiles three per row; the tiles are permuted to channels-first for
# make_grid and the resulting grid back to channels-last for imshow.
gr = torchvision.utils.make_grid(t.permute(0,3,1,2), nrow=3).permute(1,2,0)
gr.shape

In [None]:
# Scale by 255 before display — presumably the mask tiles hold 0/1 values; TODO confirm.
plt.figure()
plt.imshow(gr*255)

In [None]:

def generate_block_coords(H, W, block_size):
    """Yield the top-left corners of a non-overlapping block grid over H x W.

    Args:
        H: total height.
        W: total width.
        block_size: ``(h, w)`` block height and width.

    Yields:
        ``(cy, cx, h, w)`` for every block, column by column (all rows of the
        first column, then the next column). Enough blocks are generated to
        cover the whole area, so blocks on the bottom/right edge may extend
        past H / W — offsets are not clamped.
    """
    h, w = block_size
    # Integer ceiling division: number of blocks needed along each axis.
    n_rows = (H + h - 1) // h
    n_cols = (W + w - 1) // w

    for col in range(n_cols):
        cx = col * w   # horizontal offset advances by the block width
        for row in range(n_rows):
            cy = row * h   # vertical offset advances by the block height
            yield cy, cx, h, w

In [None]:
# Creates the generator only; no coordinates are produced until it is iterated.
g = generate_block_coords(1024, 1024, (256,256))

In [None]:
# Grab the first (index, row) pair of the dataframe and stop.
for i, r in df.iterrows():
    break

In [None]:
r

In [None]:
# NOTE(review): TiffImages, JsonAnnotations and DataPair are not imported in any
# visible cell of this notebook, so a fresh kernel raises NameError here —
# presumably they live in the project's `data` module; confirm and import.
# Also rebinds `imgs` and `data`, which earlier cells use for other objects.
imgs = TiffImages('input/hmib/train_images/')
anns = JsonAnnotations('input/hmib/train_annotations/')
data = DataPair(imgs, anns)

In [None]:
# NOTE(review): `read_meta` is not imported in any visible cell either — confirm its module.
labels = read_meta(df)
len(labels)

In [None]:
# Call the image/annotation pair with the first label; rebinds `r` (previously a dataframe row).
r = data(labels[0])

In [None]:
# Build the project's MainDataset directly from explicit paths and the fold-0
# index file. Rebinds `ds`, which earlier held the mapping returned by init_datasets.
ds = MainDataset(cfg,
                 'input/hmib/train_images/',
                 'input/hmib/train_annotations/',
                 'input/hmib/train.csv',
                 [Path('input/splits/0.csv'),],
                 train=True,
                 BaseLoader=DfDataset,
                 rate=1
                )

In [None]:
ds[2]

In [None]:
# Fold-0 row positions, stored without header or index (see the fold export cell below).
inddf = pd.read_csv('input/splits/0.csv', header=None, index_col=None)#.values()

In [None]:
# Project dataset built from the data pair, the metadata frame and the fold
# indices; rebinds `ds` once more.
ds = DfDataset(data, df, inddf)
len(ds)

In [None]:
# Plain PyTorch loader with default settings apart from the batch size.
dl = torch.utils.data.DataLoader(ds, batch_size=4)

In [None]:
# Take the first batch and stop; rebinds `b` from the earlier loader cell.
for b in dl:
    break

In [None]:
b['y'].shape

In [None]:
# Re-read the training CSV (same file as at the top) before creating the folds.
df = pd.read_csv('input/hmib/train.csv')

In [None]:
from sklearn.model_selection import StratifiedKFold, KFold


In [None]:
# Four shuffled folds with a fixed seed.
# NOTE(review): despite the name `skf` this is plain KFold, which does not
# stratify on the organ labels passed to split() in the next cell. Use
# StratifiedKFold (already imported) if per-organ balance is wanted — doing so
# would change the contents of the split files.
skf = KFold(n_splits=4, random_state=42, shuffle=True)
# skf = KFold(n_splits=4, )

In [None]:
# Write the held-out row positions of every fold to input/splits/<fold>.csv
# (one integer per line, no header and no index column).
folds = []
split_dir = Path('input/splits')
# Create the target directory if needed so to_csv does not fail on a clean checkout.
split_dir.mkdir(parents=True, exist_ok=True)
for i, (_, fold) in enumerate(skf.split(df, df['organ'])):
    folds.append(fold)
    # index / header take booleans; False is the documented way to omit them.
    pd.Series(fold).to_csv(split_dir / f'{i}.csv', index=False, header=False)

In [None]:
[len(f) for f in folds]

In [None]:
df.organ.value_counts()

In [None]:
# Organ distribution within `fold`, i.e. the last fold left over from the loop above.
df.iloc[fold].organ.value_counts()

In [None]:
import torch
# Load the 'dino_vits16' entry point of the facebookresearch/dino repo via
# torch.hub; this fetches the repo (and presumably weights) on first use, so it
# needs network access unless already cached.
vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')


In [None]:
# All-zero dummy batch: 4 samples, 3 channels, 192 x 224 (non-square on purpose or not — TODO confirm).
inp = torch.zeros(4,3,192,224)

In [None]:
# Forward pass to check the output shape; runs with autograd enabled and without
# an explicit .eval() call. Rebinds `r`.
r = vits16(inp)
r.shape

In [None]:
vits16.blocks[0].norm1

In [None]:
# Print every parameter name with its shape.
for k,v in vits16.named_parameters():
    print(k, v.shape)