# Notebook to explore data and test image loaders

In [None]:
from skimage.io import imread
import os
import numpy as np
import seaborn as sns
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from utils.visualisation import plot_band_distribution
from utils.data import create_dataloaders, PlanetBaseDataset, normalized_image
from utils.evaluation import visualise_batch_predictions

# load the project configuration object
from config import config

# Root data directory: the cells below read its "images" and "labels"
# sub-folders and pass it to create_dataloaders / PlanetBaseDataset.
DATA_PATH = config.PATH_TO_DATA

%load_ext autoreload
%autoreload 2

### Visualise one image and mask

In [None]:
# Load the first image and its mask.
# Path components are passed separately so os.path.join inserts the separator
# of the current platform; a hard-coded backslash only resolves on Windows.
image_path = os.path.join(DATA_PATH, "images", "10000.tif")
label_path = os.path.join(DATA_PATH, "labels", "10000.tif")

image = imread(image_path)
print(image.shape)

label = imread(label_path)
print(label.shape)

In [None]:
# normalized_image comes from utils.data; presumably it rescales the raw
# values into a range imshow can render -- TODO confirm.
normalized = normalized_image(image)

fig, axes = plt.subplots(
    nrows=1, ncols=3, sharex=True, sharey=True, figsize=(18, 8)
)

# One (data, title) pair per panel, left to right. The band selections index
# the last axis of the normalized image.
panels = [
    (normalized[:, :, [0, 1, 2]], "True Color Image"),
    (normalized[:, :, [3, 0, 1]], "False Color Image (NINFR, RED, GREEN)"),
    (label, "Segmentation Mask"),
]
for ax, (data, title) in zip(axes, panels):
    ax.imshow(data)
    ax.set_title(title)
    ax.axis("off")

### Loop through all masks to inspect class distributions

In [None]:
label_paths = os.path.join(DATA_PATH, r"labels")

# Accumulators over all masks
positives = 0    # sum of mask values over every mask read so far
all_pixels = 0   # number of mask elements read so far
has_any = []     # per mask: does it contain at least one positive value?
output = []      # per mask: sum of mask values divided by element count

for i, mask_name in enumerate(os.listdir(label_paths)):
    mask_path = os.path.join(DATA_PATH, r"labels", mask_name)
    label = imread(mask_path)

    # NOTE(review): the sum equals the positive-pixel count only if masks are
    # encoded as 0/1 -- confirm against the label files.
    true_pixels = label.sum()
    total = np.prod(label.shape)  # number of elements in this mask
    has_true = bool(true_pixels > 0)

    output.append(true_pixels / total)
    has_any.append(has_true)
    positives += true_pixels
    all_pixels += total

    # progress message every 100 masks
    if i % 100 == 0:
        print(f"Processing Image {i}")

In [None]:
# Summary over all masks, using the accumulators filled by the loop above
overall_share = positives / all_pixels
print(f"Overall Proportion of Positive Pixels: {overall_share:.2f}")
print(f"All Masks have at least one positive: {all(has_any)}")

# Histogram (with KDE overlay) of the per-mask positive proportion
sns.histplot(output, kde=True)

## Test and visualise data loaders

In [None]:
# Build the train / validation / test loaders; create_dataloaders is
# imported from utils.data.
train_loader, val_loader, test_loader = create_dataloaders(
    DATA_PATH,
    batch_size=4,
)

In [None]:
def illustrate_data_loader(loader, show_n_images: int):
    """Plot the first sample of successive batches drawn from ``loader``.

    For each visited batch a three-panel figure is drawn: bands ``[0, 1, 2]``,
    bands ``[3, 0, 1]`` and the segmentation mask. Only the first element of
    every batch is shown.

    Args:
        loader: Iterable yielding ``(batch_sample, batch_masks)`` pairs of
            torch tensors. Each image is permuted with ``(1, 2, 0)`` before
            plotting, so images are presumably channel-first with at least
            four bands and masks are 2-D per sample -- TODO confirm against
            the dataset class.
        show_n_images: Maximum number of figures to draw. At most this many
            batches are plotted; fewer if the loader is exhausted first.
    """
    for i, (batch_sample, batch_masks) in enumerate(loader):
        # Check the limit *before* plotting so that exactly `show_n_images`
        # figures are produced (and none at all for 0).
        if i >= show_n_images:
            return

        # Move the band axis last, which is the layout imshow expects.
        image = torch.permute(batch_sample[0], (1, 2, 0))
        # Append a trailing singleton axis to the 2-D mask.
        label = batch_masks[0].unsqueeze(-1)

        normalized = normalized_image(image.numpy())

        _, axes = plt.subplots(
            nrows=1, ncols=3, sharex=True, sharey=True, figsize=(18, 8)
        )
        panels = [
            (normalized[:, :, [0, 1, 2]], "True Color Image"),
            (normalized[:, :, [3, 0, 1]], "False Color Image (NINFR, RED, GREEN)"),
            (label, "Segmentation Mask"),
        ]
        for ax, (data, title) in zip(axes, panels):
            ax.imshow(data)
            ax.set_title(title)
            ax.axis("off")
        
def speedtest_dataloader(size, same, num_workers=0):
    """Build the training loader and consume it once from start to end.

    Intended to be wrapped in a timing call. Nothing is computed on the
    batches and nothing is returned.

    Args:
        size: Forwarded to ``create_dataloaders`` as ``batch_size``.
        same: Forwarded to ``create_dataloaders`` as ``batch_transforms``.
        num_workers: Forwarded to ``create_dataloaders`` as ``num_workers``.
    """
    train_loader, _val_loader, _test_loader = create_dataloaders(
        DATA_PATH,
        batch_size=size,
        batch_transforms=same,
        num_workers=num_workers,
    )
    # Exhaust the loader; every batch is unpacked and then discarded.
    for _images, _masks in train_loader:
        continue

In [None]:
# Rebuild the loaders with all four bands and transforms switched off
train_loader, val_loader, test_loader = create_dataloaders(
    DATA_PATH,
    batch_size=4,
    bands=[0, 1, 2, 3],
    transforms=False,
)

# Take the first training batch and insert a size-1 axis at position 1 of the masks
batch_sample, batch_masks = next(iter(train_loader))
batch_masks = batch_masks.unsqueeze(1)

for tensor in (batch_sample, batch_masks):
    print(tensor.shape)

# repurpose the batch prediction function by simply passing the mask as the prediction
visualise_batch_predictions(
    batch_sample,
    batch_masks,
    batch_masks,
    bands=[0, 1, 2, 3],
    rescale=False,
)

In [None]:
# Plot samples drawn from the training loader
illustrate_data_loader(train_loader, 10)

In [None]:
# Plot samples drawn from the validation loader
illustrate_data_loader(val_loader, 5)

In [None]:
# Plot samples drawn from the test loader
illustrate_data_loader(test_loader, 5)

#### Illustrate distribution of raw values across bands

In [None]:
# Dataset over all four bands, wrapped in an unshuffled loader of batch size 4.
dataset = PlanetBaseDataset(data_dir=DATA_PATH, bands=[0, 1, 2, 3])

# Keep only the iterator so that later cells can pull batches with next().
dataloader = iter(
    DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
    )
)

In [None]:
# Plot pixel values for all bands for every image of the next batch.
# Iterating over the batch itself (instead of a hard-coded range(4)) keeps
# this cell working if the batch size changes or a batch comes back short.

sample, mask = next(dataloader)
for image in sample:
    plot_band_distribution(image)

### Calculated means and standard deviations for normalization

In [None]:


"""
means = torch.mean(train_sample.float(), dim=(0, 1, 2))
std = torch.std(train_sample.float(), dim=(0, 1, 2))
min = torch.amin(train_sample.float(), dim=(0, 1, 2))
max = torch.amax(train_sample.float(), dim=(0, 1, 2))
"""


# pre-calculated means, std, mins, max of raw images calculated on the full train set:
means = torch.tensor([ 265.7371,  445.2234,  393.7881, 2773.2734])
stds = torch.tensor([ 91.8786, 110.0122, 191.7516, 709.2327])
mins = torch.tensor([ 0., 21.,  6., 77.])
max = torch.tensor([ 4433.,  5023.,  8230., 10000.])
