In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
os.chdir('..')
sys.path.append('src')

In [None]:
import torch
import numpy as np
from tqdm.notebook import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2
import rasterio
from rasterio.windows import Window
from PIL import Image
plt.figure(figsize=(10, 10))

from postprocessing import SmoothTiles
from sampler import get_basics_rasterio
from utils import create_dir

In [None]:
from ipywidgets import IntProgress
from pathlib import Path

In [None]:
size_w=128

# General

In [None]:
def test_inf(batch):
    return batch[:, :1]

def read_frame(fd, x, y, h, w=None, c=3, batch_size=8):
    """Read a window of bands ``1..c`` from an open rasterio dataset.

    Argument naming is unusual: ``x``/``y`` are the column/row offsets, ``h``
    is passed as the window *width* and ``w`` (defaulting to ``h``) as the
    window *height*, following rasterio's ``Window(col_off, row_off, width, height)``.
    ``batch_size`` is accepted but unused.

    Returns:
        ``torch.ByteTensor`` built from the array returned by ``fd.read``.
    """
    height = h if w is None else w
    bands = [band for band in range(1, c + 1)]
    window = Window(x, y, h, height)
    return torch.ByteTensor(fd.read(bands, window=window))

def block_reader(path, inf, size=512, infer_list_flg=False):
    """Run tiled inference over a whole raster and stitch the predictions.

    The raster is processed in horizontal strips of ``size`` rows taken every
    ``size // 2`` rows. Each strip is cut into two interleaved sets of
    ``size`` x ``size`` tiles:
    - "left" tiles start every ``size`` columns;
    - "right" tiles are shifted by ``size // 2``.

    Every prediction is cropped to its central ``size // 2`` x ``size // 2``
    region, so the kept crops tile the image with no gaps or overlaps. The
    ``size // 4`` border dropped on each side gives the model context around
    every kept pixel. Pixels outside the raster are zero-padded.

    Args:
        path: raster path passed to ``get_basics_rasterio``. Its ``shape`` is
            used as ``(height, width)``.
        inf: inference callable. It receives a ``(B, C, size, size)`` uint8
            batch, or a list of ``(1, C, size, size)`` tensors when
            ``infer_list_flg`` is True. It is expected to return
            ``(B, C_out, size, size)``; the centre crop assumes this spatial
            size (TODO confirm for new models).
        size: tile side in pixels. Must be a positive multiple of 4 so the kept
            centre (``size - 2 * (size // 4)``) equals the stride ``size // 2``.
        infer_list_flg: pass the tiles to ``inf`` as a list instead of a batch.

    Returns:
        ``(height, width)`` tensor for a single-channel prediction, otherwise
        ``(C_out, height, width)``.

    Raises:
        ValueError: if ``size`` is not a positive multiple of 4.
    """
    if size < 4 or size % 4:
        raise ValueError(f'size must be a positive multiple of 4, got {size}')
    fd, shape, _ = get_basics_rasterio(path)
    print(shape)
    height, width = shape[0], shape[1]
    margin = size // 4  # context discarded on each side of every prediction
    stride = size // 2  # kept centre size == step between tiles/strips
    n_cols = -(-width // size)  # ceil(width / size): number of left/right tile pairs
    # Left padding aligns the first kept centre with column 0. Right padding
    # makes room for the last "right" tile, which ends at n_cols * size + stride.
    pad_left = margin
    pad_right = n_cols * size + stride - margin - width

    rows = []
    # Strip ``ny`` keeps rows [ny + margin, ny + margin + stride). Starting at
    # -margin puts the first kept row at 0. Stopping once ny + margin >= height
    # avoids strips whose kept part is entirely outside the raster.
    for ny in tqdm(range(-margin, height - margin, stride), desc='rows'):
        # Read only the part of the strip inside the raster, then pad it
        # to a full ``size`` rows (top padding only for the first strip(s),
        # bottom padding only near the last row).
        y0, y1 = max(ny, 0), min(ny + size, height)
        strip = read_frame(fd, 0, y0, width, y1 - y0)
        pad = (pad_left, pad_right, y0 - ny, ny + size - y1)
        strip = F.pad(strip, pad=pad, mode='constant', value=0).unsqueeze(0)

        # Interleave left/right tiles so the cropped predictions come out
        # in left-to-right order.
        tiles = []
        for j in range(n_cols):
            left = j * size
            right = left + stride
            tiles.append(strip[..., left:left + size])
            tiles.append(strip[..., right:right + size])

        if infer_list_flg:
            preds = inf(tiles)
        else:
            preds = inf(torch.cat(tiles, 0))
        preds = preds[:, :, margin:-margin, margin:-margin]
        # (B, C_out, stride, stride) -> (C_out, stride, B * stride)
        rows.append(torch.cat(list(preds), dim=-1))

    mosaic = torch.cat(rows, dim=-2)[..., :height, :width]
    # Drop the channel axis for single-channel predictions.
    return mosaic.squeeze(0)

def prt(img):
    """Print ``img.shape`` and display the image with matplotlib.

    2-D inputs are shown as-is. Other inputs are converted to a NumPy array and
    transposed from channel-first ``(C, H, W)`` to ``(H, W, C)`` before display.
    """
    print(img.shape)
    pixels = np.array(img)
    if len(img.shape) != 2:
        pixels = pixels.transpose(1, 2, 0)
    plt.imshow(pixels)
    plt.show()

def save_tiff_uint8_single_band(img, path):
    """Write a 2-D uint8 image to ``path`` as a single-band GeoTIFF.

    Args:
        img: ``torch.Tensor`` or ``numpy.ndarray`` of shape ``(H, W)`` and dtype
            uint8. Tensors are detached and moved to the CPU before conversion.
        path: destination file path.

    Raises:
        TypeError: if ``img`` is neither a tensor nor an ndarray.
        AssertionError: if the dtype is not uint8.

    NOTE(review): the file is written with ``nbits=1``. Values other than 0/1
    are presumably not preserved; confirm this is intended (binary masks).
    """
    if isinstance(img, torch.Tensor):
        # .cpu()/.detach() so CUDA tensors and tensors requiring grad work too.
        img = img.detach().cpu().numpy()
    elif not isinstance(img, np.ndarray):  # was `numpy.ndarray` -> NameError
        raise TypeError(f'Want torch.Tensor or numpy.ndarray, but got {type(img)}')
    assert img.dtype == np.uint8, f'Want dtype uint8, but got {img.dtype}'
    h, w = img.shape
    # Context manager guarantees the dataset is closed even if write() fails.
    with rasterio.open(path, 'w', driver='GTiff', height=h, width=w, count=1,
                       nbits=1, dtype=np.uint8) as dst:
        dst.write(img, 1)  # rasterio band indexes are 1-based
    print(f'Save to {path}')
    
def postprocess(infer_func, src_folder, dst_folder, save_predicts=True):
    """Run ``block_reader`` tiled inference on the ``*.tiff`` images of a folder.

    Args:
        infer_func: callable ``[BxCxHxW] -> [BxCxHxW]`` forwarded to ``block_reader``.
        src_folder: folder with test images (``*.tiff``).
        dst_folder: folder for the output (RLE and predictions); created via ``create_dir``.
        save_predicts: currently unused (see NOTE below).

    Returns:
        The prediction for the first image in sorted filename order.

    Raises:
        FileNotFoundError: if ``src_folder`` contains no ``*.tiff`` file.

    NOTE(review): this is still a debugging version. It returns after the
    first image and does not save anything yet.
    """
    # Sort so the processed/returned image does not depend on filesystem order.
    img_paths = sorted(Path(src_folder).glob('*.tiff'))
    create_dir(dst_folder)
    if not img_paths:
        raise FileNotFoundError(f'No *.tiff images found in {src_folder}')
    for img_path in tqdm(img_paths, desc='Images', leave=False):
        img = block_reader(img_path, infer_func)
        # TODO: when save_predicts, call
        # save_tiff_uint8_single_band(img, Path(dst_folder) / img_path.name)
        # and drop the early return below.
        return img

ll = postprocess(test_inf, '/mnt/storage/HuBMAP/train/', './output')

In [None]:
ll.shape

# Read image

In [None]:
name ='0486052bb'
path = f'/mnt/storage/HuBMAP/train/{name}.tiff'
#0486052bb-anatomical-structure.json
#!ls /mnt/storage/HuBMAP/train

In [None]:
def read_frame(fd, x, y, h, w=None, c=3, batch_size=8):
    """Read bands ``1..c`` of ``fd`` inside a window and return them as uint8.

    ``x``/``y`` are the column/row offsets. ``h`` becomes the window width and
    ``w`` (default: ``h``) the window height, in rasterio's
    ``Window(col_off, row_off, width, height)`` order. ``batch_size`` is unused.
    NOTE(review): this redefines the identical ``read_frame`` from the General section.
    """
    if w is None:
        w = h
    band_indexes = list(range(1, c + 1))
    pixels = fd.read(band_indexes, window=Window(x, y, h, w))
    return torch.ByteTensor(pixels)

def save_frame(img, name="file.tiff"):
    """Save an array-like image to ``name`` with Pillow (format chosen from the extension)."""
    pil_image = Image.fromarray(img)
    pil_image.save(name)
    
def get_batch(fd, x=512, batch=8):
    """Placeholder that returns the fixed pair ``(10000, 100)``.

    All arguments are currently ignored.
    TODO: the commented-out version returned ``read_frame(fd, k, k, h, h)``
    with these two values as offset and size.
    """
    offset, side = 10000, 100
    return offset, side

def get_sub_batch(fd, x, y, h, batch_size=8, sub_size=2):
    """Read one square ``h`` x ``h`` frame at offset ``(x, y)`` via ``read_frame``.

    ``batch_size`` and ``sub_size`` are accepted but unused.
    """
    frame = read_frame(fd, x, y, h)
    return frame

def simple_infer(path, inf, save_path, size=512, batch_size=4, rows=range(10, 15)):
    """Run ``inf`` on non-overlapping ``size`` x ``size`` tiles of a raster.

    Tiles are read row by row over the tile rows listed in ``rows`` and, inside
    each row, left to right over every tile column. Tiles that would cross the
    right/bottom border are shifted back inside the raster instead of padded.

    Args:
        path: raster path for ``get_basics_rasterio``. Its ``shape`` is used as
            ``(height, width)``, as in ``block_reader``.
        inf: callable applied to each ``(B, C, size, size)`` uint8 batch.
        save_path: currently unused.
        size: tile side in pixels.
        batch_size: number of tiles per call to ``inf``.
        rows: tile-row indices to process. The default is the debugging subset
            10..14 used so far; pass ``range(ceil(height / size))`` for the
            whole image.

    Returns:
        ``torch.cat`` of all ``inf`` outputs: one entry per tile in row-major
        order (``len(rows) * n_cols`` tiles).
    """
    fd, shape, channel = get_basics_rasterio(path)
    height, width = shape[0], shape[1]
    n_rows = -(-height // size)  # ceil
    n_cols = -(-width // size)   # ceil; was computed from shape[0] (height)
    print(n_cols, n_rows)
    batch, heatmaps = [], []
    for row in tqdm(rows, desc='rows'):
        # Clamp row offsets against the height (was compared with shape[1]).
        y = min(row * size, height - size)
        for col in tqdm(range(n_cols), desc='columns', leave=False):
            # Clamp column offsets against the width (was compared with shape[0]).
            x = min(col * size, width - size)
            tile = read_frame(fd, x, y, size, c=channel)
            batch.append(tile.unsqueeze(0))
            if len(batch) == batch_size:
                heatmaps.append(inf(torch.cat(batch, 0)))
                batch = []
    if batch:  # flush the last, partially filled batch
        heatmaps.append(inf(torch.cat(batch, 0)))
    return torch.cat(heatmaps, 0)
#     return torch.cat([torch.cat([column for column in torch.cat(heatmaps, 0)[raw*nx:(raw+1)*nx]], 1) for raw in range(ny)], 0)

# prt(torch.cat(rr, 0))
#     return torch.cat(heatmaps, 0)

#     return torch.reshape(torch.cat(heatmaps, 0), (nx, ny, size, size))

# def infer(path, inf, size=512, batch_size=4):
#     fd, shape, channel = get_basics_rasterio(path)
# #     nx, ny = [i//size + bool(i % size) for i in shape]
#     nx, ny = np.ceil(np.array(shape)/size).astype(int)
#     x, y = 10000, 10000
#     print(nx, ny)
#     nx, ny = 3, 2
#     for i in range(np.ceil(nx*ny/batch_size).astype(int)):
#         batch = get_sub_batch(fd, x, y, h=size, batch_size=batch_size)
#         return batch
#         inf(batch)
#         print(i)
#     for nyi in range(ny*nx//):
#         y += nyi*size
#         for nxi in range(nx):
#             x += nxi*size
#             if x + size > shape[0]: pass
#             if y + size > shape[1]: pass
#             img = read_frame(fd, x, y, size, c=channel)
#             print(img.shape)
#     return (nx, ny) #shape, channel
#     return get_banch(fd)
#     return (shape, channel)

def test_inf(batch):
    return batch[:, 0]

def prt(img):
    """Print ``img.shape`` and show ``img`` with ``plt.imshow`` (no channel transpose)."""
    print(img.shape)
    as_array = np.array(img)
    plt.imshow(as_array)
    plt.show()

ll = simple_infer(path, test_inf, None)
ll.shape
# prt(ll)

In [None]:
prt(ll)

In [None]:
rr = [torch.cat([i for i in ll[ii*51:(ii+1)*51]], 1) for ii in range(5)]
# torch.cat([i for i in ll], 1)
prt(torch.cat(rr, 0))

In [None]:
ww = torch.reshape(ll, (51, 5, 512, 512))
ww.shape
www = []
for i in ww:
    www.append(torch.cat([j for j in i], 0))
# print(len(www))
ww1 = torch.cat(www, 1)
# print(ww1.shape)

# rr = [torch.cat([i for i in ll[ii*51:(ii+1)*51]], 1) for ii in ww]
# prt(torch.cat(rr, 0))

In [None]:
prt(ww1)

In [None]:
rr = torch.cat([i for i in ll], 1)
prt(rr)

In [None]:
ll.permute(1, 2, 0).shape

In [None]:
ww = torch.reshape(ll.permute(1, 2, 0), (512, 512, 51, 5))
print(ww.shape)
# rr = torch.cat([i for i in ww[:, 0]], 1)
# prt(rr)

In [None]:
torch.cat([torch.cat([column for column in img[raw*nx:(raw+1)*nx]], 1) for raw in range(ny)], 0)

def concat_xy(img, nx, ny):
    """Stitch a flat stack of tiles back into one mosaic.

    Args:
        img: tensor whose first axis enumerates ``nx * ny`` tiles in row-major
            order (all tiles of mosaic row 0, then row 1, ...). Each tile has
            shape ``(..., h, w)``.
        nx: number of tiles per mosaic row (columns).
        ny: number of mosaic rows.

    Returns:
        Tensor of shape ``(..., ny * h, nx * w)``.

    Raises:
        ValueError: if ``nx``/``ny`` are not positive or ``img`` does not hold
            exactly ``nx * ny`` tiles.
    """
    if nx <= 0 or ny <= 0 or img.shape[0] != nx * ny:
        raise ValueError(
            f'Expected {nx * ny} tiles ({ny} rows x {nx} columns), got {img.shape[0]}')
    # Join each row's tiles along the width, then stack rows along the height.
    mosaic_rows = [torch.cat(list(img[r * nx:(r + 1) * nx]), dim=-1) for r in range(ny)]
    return torch.cat(mosaic_rows, dim=-2)


#         ll.append(torch.cat(heatmaps, 0))
#     return img.shape, torch.cat(heatmaps, 0)
#     n = int(np.sqrt(img.shape[-1]))
#     big_img = torch.reshape(img, (img.shape[0], img.shape[1], n, n))
#     sub_img = []
#     for i in range(n):
#         tpl_img = tuple([j for j in big_img[:, :, :, i].permute(2, 0, 1)])
#         sub_img.append(np.concatenate(tpl_img, axis=1))
#     return np.concatenate(tuple(sub_img), axis=0)
r = concat_xy(ww)

In [None]:
prt(r)

In [None]:
smooth = st.merge(img_main, img_sub)
# print(smooth.shape)
# plt.imshow(concat(smooth.to('cpu')))

# Generate Test

In [None]:
def get_random_prediction(size=128, n=2, cuda=True):
    """Build random N(0, 1) test predictions.

    Returns ``(img_main, img_sub)`` with shapes ``(size, size, n**2)`` and
    ``(size, size, (n+1)**2)``. Both are moved to CUDA when ``cuda`` is True
    and a GPU is available, otherwise they stay on the CPU.
    """
    device = 'cuda' if (cuda and torch.cuda.is_available()) else 'cpu'
    main_shape = (size, size, n ** 2)
    sub_shape = (size, size, (n + 1) ** 2)
    img_main = torch.normal(0, 1, size=main_shape).to(device)
    img_sub = torch.normal(0, 1, size=sub_shape).to(device)
    return img_main, img_sub

def get_black_prediction(size=128, n=2, cuda=True):
    """Build constant test predictions.

    Returns ``(img_main, img_sub)``: zeros of shape ``(size, size, n**2)`` and
    ones of shape ``(size, size, (n+1)**2)``, on CUDA when requested and available.
    """
    if cuda and torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    img_main = torch.zeros((size, size, n ** 2)).to(device)
    img_sub = torch.ones((size, size, (n + 1) ** 2)).to(device)
    return img_main, img_sub

# img_main, img_sub = get_random_prediction(size_w, 3)
img_main, img_sub = get_black_prediction(size_w, 3)
print(img_main.shape, img_sub.shape)

# Show window

In [None]:
st = SmoothTiles()

In [None]:
plt.plot(st.triangle(size_w))

In [None]:
plt.imshow(st.window_2D(size_w, st.triangle))

In [None]:
plt.plot(st.gauss(size_w, 20))

In [None]:
plt.imshow(st.window_2D(size_w, st.gauss))

In [None]:
plt.plot(st.spline(size_w))

In [None]:
plt.imshow(st.window_2D(size_w, st.spline))

# Smooth Tiles

In [None]:
def concat(img):
    """Arrange the ``n*n`` channel planes of ``img`` into an ``n x n`` mosaic.

    ``img`` has shape ``(H, W, n*n)``, where ``n = int(sqrt(img.shape[-1]))``.
    Channel ``a * n + b`` lands in mosaic row ``b``, column ``a``. Returns a
    NumPy array of shape ``(n * H, n * W)``.
    """
    n = int(np.sqrt(img.shape[-1]))
    grid = torch.reshape(img, (img.shape[0], img.shape[1], n, n))
    mosaic_rows = [
        np.concatenate(tuple(grid[:, :, col, row] for col in range(n)), axis=1)
        for row in range(n)
    ]
    return np.concatenate(tuple(mosaic_rows), axis=0)

In [None]:
smooth = st.merge(img_main, img_sub)
print(smooth.shape)
plt.imshow(concat(smooth.to('cpu')))