In [31]:
import pandas as pd
from torch.utils.data import Dataset
from typing import Optional, List
from pathlib import Path
import rasterio
from tqdm import tqdm
import numpy as np
import cv2
import warnings
from segmentation.scr.rle_coding import *
from segmentation.config import Configs as CFG
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)


In [66]:
class TiledDataset(Dataset):
    def __init__(
        self,
        name_data: str,
        path_img_dir: str,
        path_lb_dir: str,
        tile_size: Optional[list] = CFG.tile_size,
        overlap_pct: float = CFG.overlap_pct,
        strong_empty: bool = True,
        sample_limit: int = 30000 * 10,
        cache_dir: str = None,
    ):
        self.name_data = name_data
        self.path_img_dir = Path(path_img_dir)
        self.path_lb_dir = Path(path_lb_dir)
        self.tile_size = np.array(tile_size)
        self.overlap_pct = overlap_pct
        self.strong_empty = strong_empty
        self.sample_limit = sample_limit
        self.cache_dir = cache_dir
        
        self.path_img_dir = sorted(list(self.path_img_dir.rglob("*.tif")))
        self.samples = []
        for p_img in self.path_img_dir[:500]:
            p_lb = self.path_lb_dir / p_img.name
            with rasterio.open(p_img) as reader:
                width, height = reader.width, reader.height
                img_r = reader.read()
                px_max, px_min = img_r.max(), img_r.min()
            # print(p_lb)
            self.samples.append((p_img, p_lb, [px_min, px_max], [height, width]))

        min_overlap = float(overlap_pct) * 0.01
        max_stride = self.tile_size * (1.0 - min_overlap)

        list_tiles = []
        empty = 0
        nonempty = 0

        for file_path, label_path, px_stats, img_dims in tqdm(
            self.samples, total=len(self.samples), desc="Generating tiles"
        ):
            # [(C,H,W),...]
            height, width = img_dims
            num_patches = np.ceil(np.array([height, width]) / max_stride).astype(
                np.int64
            )
            starts = [
                np.int32(np.linspace(0, height - self.tile_size[0], num_patches[0])),
                np.int32(np.linspace(0, width - self.tile_size[1], num_patches[1])),
            ]
            stops = [starts[0] + self.tile_size[0], starts[1] + self.tile_size[1]]
            mask = cv2.imread(str(label_path), cv2.IMREAD_GRAYSCALE).astype(np.uint8)
            rle = rle_encode(mask)

            for y1, y2 in zip(starts[0], stops[0]):
                for x1, x2 in zip(starts[1], stops[1]):
                    mask_tile = mask[y1:y2, x1:x2]
                    is_empty = np.all(mask_tile == 0)

                    if self.strong_empty:
                        is_empty = is_empty or (mask_tile.sum() < (0.05 * self.tile_size[0]))
                    if is_empty:
                        empty += 1
                    else:
                        nonempty += 1
                    
                    list_tiles.append(
                        (
                            file_path,
                            rle,
                            is_empty,
                            (x1, y1, x2 - x1, y2 - y1),
                            px_stats,
                            (height, width),
                        )
                    )
        print(empty, nonempty )
        self.tiles = []
        if self.cache_dir:
            if isinstance(self.cache_dir, str):
                self.cache_dir = Path(self.cache_dir)
            cache_dir1 = self.cache_dir / self.name_data
            cache_dir1.mkdir(parents=True, exist_ok=True)
            
        self.df = pd.DataFrame(columns =['path_img', 'rle', 'is_empty', 'bbx', 'px_stats', 'size'])
        self.path_df = self.cache_dir / (self.name_data +'.csv')

        if self.sample_limit < len(list_tiles):
            pos_idxs_to_sample = np.random.choice(
                len(list_tiles), min(sample_limit, len(list_tiles)), replace=False
            )
            self.tiles = list(map(list_tiles.__getitem__, pos_idxs_to_sample))
        else:
            self.tiles = list_tiles

        # rasterio.windows.Window(col_off, row_off, width, h

In [67]:
train_set = TiledDataset(
    name_data='kidney_1_tilling',
    path_img_dir=CFG.path_img_kidney1,
                         path_lb_dir=CFG.path_lb_kidney1,
                         cache_dir=CFG.cache_dir 
                         )

Generating tiles: 100%|██████████| 500/500 [00:03<00:00, 128.55it/s]

754 2246





In [29]:
train_set.tiles[400]

(WindowsPath('data/train/kidney_1_dense/images/0066.tif'),
 '358934 1 359846 1 388995 2 420817 2 420842 2 421729 2 421755 1 438219 2 457343 1 458225 1 459137 1 471020 2 471932 2 472845 1',
 True,
 (0, 791, 512, 512),
 [19416, 31357],
 (1303, 912))

In [46]:
bbx = train_set.tiles[400][3]

In [48]:
x, y, w, h = bbx

In [42]:
img = cv2.imread(str(train_set.tiles[400][0]), cv2.IMREAD_GRAYSCALE).astype(np.uint8)

In [56]:
lb lb = cv2.imread('data/train/kidney_1_dense/labels/0066.tif', cv2.IMREAD_GRAYSCALE).astype(np.uint8)

In [14]:
pd.read_csv('data\kidney_1_tilling.csv')

Unnamed: 0,path_img,rle,is_empty,bbx,px_stats,size


In [None]:
train_set.path_img_dir[0].parent.parent

WindowsPath('data/train/kidney_1_dense')

In [None]:
x = pd.DataFrame(columns =['x', 'y']) 


In [60]:
x.loc[1,:] = [1,1]

In [61]:
x

Unnamed: 0,x,y
0,1,1
1,1,1


In [51]:
1528+12146 

13674

In [46]:
len(train_set.list_tiles)

3000

In [20]:
train_set.samples[0]

(WindowsPath('data/train/kidney_1_dense/images/0000.tif'),
 WindowsPath('data/train/kidney_1_dense/labels/0000.tif'),
 [18515, 36640],
 [1303, 912])

In [26]:
mask.shape

(1303, 912)

In [31]:
mask = cv2.imread('data/train/kidney_1_dense/labels/1000.tif', cv2.IMREAD_UNCHANGED).astype(np.uint8)
shape = mask.shape
rle = rle_encode(mask)
img_rle = rle_decode(rle, shape)

In [35]:
mask[mask == 255] = 1

In [36]:
(img_rle == mask).all()

True

In [39]:
rle_1000 = df_train[df_train['id'] == 'kidney_1_dense_1000']['rle'].values
img_rle = rle_decode(rle_1000[0], shape)


'0000.tif'