In [None]:
import torch
from typing import TypeVar
from tqdm.auto import tqdm

from segmentation.scr.utils.rle_coding import *
from segmentation.scr.tilling_dataset import Tilling_Dataset
import matplotlib.pyplot as plt
from skimage import color

# Generic placeholder for DataFrame-typed annotations. TypeVar's name argument
# must equal the variable it is bound to (PEP 484), otherwise type checkers reject it.
# NOTE(review): not referenced in the cells shown; `pandas.DataFrame` itself may be the better annotation.
PandasDataFrame = TypeVar('PandasDataFrame')

In [None]:
# Tiled kidney_3 dataset, indexed directly by the plotting/statistics cells below.
# NOTE(review): exact meaning of use_random_sub / empty_tile_pct / sample_limit is
# defined in Tilling_Dataset — confirm there.
# Forward slashes are accepted by Python file APIs on Windows as well, whereas a
# backslash path only resolves on Windows.
tilling_1 = Tilling_Dataset(
    name_data='kidney_3_tilling',
    path_to_df='data/kidney_3_tilling.csv',
    use_random_sub=False,
    empty_tile_pct=0,
    sample_limit=None,
    transform=None,
)

In [None]:
# Share of tiles whose `is_empty` flag is set in the tiling table.
# `y=` only applies to DataFrame.plot.pie; on a Series it is ignored, so it is not passed.
ax_empty = (
    tilling_1.df['is_empty']
    .value_counts()
    .rename(index={True: 'empty', False: 'non-empty'})  # readable wedge labels instead of True/False
    .plot.pie(autopct='%.1f%%', legend=False)
)
ax_empty.set_title('Percent empty images')
ax_empty.set_ylabel('');  # drop pandas' default y-label (the series name / "count")

In [None]:
# Show the `is_empty` entries that are False, i.e. the non-empty tiles.
tilling_1.df['is_empty'].loc[lambda flags: flags.eq(False)]

In [None]:
# Visual sanity check on one tile: raw image | mask overlaid on image | mask alone.
img, mask, bbx, size = tilling_1[5]

# Channel-first tensor -> channel-last array for matplotlib/skimage
# (assumes img is (C, H, W) — TODO confirm in Tilling_Dataset).
img_hwc = img.permute(1, 2, 0).numpy()
mask_np = mask.numpy()

fig, axarr = plt.subplots(ncols=3, figsize=(12, 6))
ax_img, ax_overlay, ax_mask = axarr
ax_img.imshow(img_hwc, cmap="gray")
overlay = color.label2rgb(mask_np, img_hwc, bg_label=0, bg_color=(1., 1., 1.), alpha=0.25)
ax_overlay.imshow(overlay)
ax_mask.imshow(mask, vmin=0, interpolation='antialiased', interpolation_stage='rgba')

# Dataset statistics: per-channel mean and standard deviation over all tiles

In [None]:
# Ordered (unshuffled) pass over every tile, consumed by the statistics cell below.
loader = torch.utils.data.DataLoader(
    tilling_1,
    batch_size=5,
    shuffle=False,
    # Pinned host memory only speeds up host->GPU copies; without CUDA PyTorch
    # warns and ignores the flag, so enable it only when a GPU is present.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Per-channel mean and standard deviation over every pixel in the dataset,
# computed in one streaming pass via Var[x] = E[x^2] - E[x]^2.
# Running sums are kept in float64: float32 sums over millions of pixels lose
# precision, and the E[x^2] - E[x]^2 subtraction magnifies that error. Casting
# first also avoids integer overflow in the squares if tiles are integer-typed.
psum = None
psum_sq = None
num_pixels = 0

pbar = tqdm(loader, total=len(loader), desc='Calculate')
for batch in pbar:
    # First item of each collated batch is the image tensor; masks/boxes/sizes are unused here.
    images = batch[0].to(torch.float64)
    batch_size, num_channels, height, width = images.shape
    num_pixels += batch_size * height * width

    if psum is None:
        # Size the accumulators from the data rather than assuming 3 channels.
        psum = torch.zeros(num_channels, dtype=torch.float64)
        psum_sq = torch.zeros(num_channels, dtype=torch.float64)
    psum += images.sum(dim=(0, 2, 3))
    psum_sq += images.square().sum(dim=(0, 2, 3))

if num_pixels == 0:
    raise ValueError('loader produced no pixels; cannot compute mean/std')

# mean and STD, returned as float32 tensors (the dtype this cell has always produced)
mean64 = psum / num_pixels
# Rounding can push the variance slightly below zero, which would make sqrt return NaN.
var64 = (psum_sq / num_pixels - mean64 ** 2).clamp_min(0)
total_mean = mean64.float()
total_var = var64.float()
total_std = var64.sqrt().float()

# output
print("mean: " + str(total_mean))
print("std:  " + str(total_std))


In [None]:
# Same three-panel view as above, for one tile after the training augmentation.
# Index the Dataset itself (a DataLoader is not subscriptable); like the cells
# above, each sample unpacks to (img, mask, bbx, size).
# NOTE(review): `train_transform` is not defined anywhere in this notebook —
# define or import it in an earlier cell before running this one.
img, mask, bbx, size = tilling_1[8]
augmented = train_transform(image=img, mask=mask)
img, mask = augmented['image'], augmented['mask']

# Channel-first tensor -> channel-last array for matplotlib/skimage
# (assumes the transform returns a (C, H, W) tensor — TODO confirm).
img_hwc = torch.permute(img, (1, 2, 0)).numpy()

fig, axarr = plt.subplots(ncols=3, figsize=(12, 6))
axarr[0].imshow(img_hwc, cmap="gray")
axarr[1].imshow(color.label2rgb(mask.numpy(), img_hwc, bg_label=0, bg_color=(1., 1., 1.), alpha=0.25))
axarr[2].imshow(mask, vmin=0, interpolation='antialiased', interpolation_stage='rgba')