In [None]:
from torch.utils.data import DataLoader
from utils.visualisation import visualise_image, plot_band_distribution
from utils.data import create_dataloaders, PlanetDataset
import torch
import config

# Device handle read from the project config module (presumably a torch
# device — confirm in config). It is not referenced in the cells visible here.
device = config.device
# Dataset location; passed to create_dataloaders and PlanetDataset below.
DATA_PATH = config.PATH_TO_DATA

# Reload imported modules before each cell execution so that edits to the
# project's own modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Illustrate Data Loaders

In [None]:
# Build the three loaders consumed by the illustration cells further down,
# using a batch size of 32.
train_loader, val_loader, test_loader = create_dataloaders(DATA_PATH, batch_size=32)

In [None]:
def illustrate_data_loader(loader, show_n_images: int, visualise=None):
    """Plot the first image/mask pair of the first ``show_n_images`` batches.

    Args:
        loader: Iterable yielding ``(batch_sample, batch_masks)`` pairs of
            tensors. Element 0 of each is permuted with ``(1, 2, 0)`` before
            plotting, so both are expected to be 3-D per item (presumably
            channels-first — confirm against the dataset).
        show_n_images: Number of batches to draw from; one figure is produced
            per batch. Values ``<= 0`` produce no figure and consume nothing
            from ``loader``.
        visualise: Optional callable taking ``(image_array, label_array)``
            NumPy arrays. Defaults to ``visualise_image``.

    Returns:
        None.
    """
    # Guard first so that a non-positive request does not pull a batch at all.
    if show_n_images <= 0:
        return
    if visualise is None:
        visualise = visualise_image

    # Count from 1 and stop right after the n-th figure: this draws exactly
    # ``show_n_images`` figures and never fetches a batch that is not shown.
    for shown, (batch_sample, batch_masks) in enumerate(loader, start=1):
        image = torch.permute(batch_sample[0], (1, 2, 0))
        label = torch.permute(batch_masks[0], (1, 2, 0))
        visualise(image.numpy(), label.numpy())
        if shown >= show_n_images:
            return
        
def speedtest_dataloader(size, same, num_workers=0):
    """Build fresh loaders and run one full pass over the training loader.

    Intended to be wrapped in ``%timeit``. The loaders are created inside the
    function, so their construction is part of whatever is being timed.

    Args:
        size: Forwarded to ``create_dataloaders`` as ``batch_size``.
        same: Forwarded to ``create_dataloaders`` as ``batch_transforms``.
        num_workers: Forwarded to ``create_dataloaders`` as ``num_workers``.

    Returns:
        None.
    """
    loaders = create_dataloaders(
        DATA_PATH,
        batch_size=size,
        batch_transforms=same,
        num_workers=num_workers,
    )
    training_loader, _val_loader, _test_loader = loaders
    # Exhaust the training loader; every batch is fetched and then discarded.
    for _inputs, _targets in training_loader:
        continue

In [None]:
# Loader benchmarks. Positional arguments of speedtest_dataloader are
# (batch_size, batch_transforms[, num_workers]); each call rebuilds the loaders
# and iterates the whole training loader.
# NOTE(review): this first call is not wrapped in %timeit, unlike the cells
# below — presumably a warm-up run; confirm it is not a missing %timeit.
speedtest_dataloader(4, None)

In [None]:
%timeit speedtest_dataloader(4, True)

In [None]:
%timeit speedtest_dataloader(32, None)

In [None]:
%timeit speedtest_dataloader(32, True)

In [None]:
%timeit speedtest_dataloader(64, None)

In [None]:
%timeit speedtest_dataloader(64, True)

In [None]:
# Same settings as the previous cell but with num_workers=4. Author's note:
# this was slower; open question whether the multi-worker overhead dominates
# because batches are only loaded here and never passed through a network.
%timeit speedtest_dataloader(64, True, 4)

In [None]:
# Plot the first image/mask pair of successive batches from each of the three
# loaders built above.
illustrate_data_loader(train_loader, 10)

In [None]:
illustrate_data_loader(val_loader, 5)

In [None]:
illustrate_data_loader(test_loader, 5)

#### Illustrate distribution of raw values across bands

In [None]:
# Dataset restricted to band indices 0-3, wrapped in an unshuffled,
# single-process loader with batch size 4. `dataloader` ends up bound to the
# loader's iterator so that the next cell can pull batches with next().
dataset = PlanetDataset(data_dir=DATA_PATH, bands=[0, 1, 2, 3])
dataloader = iter(
    DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
    )
)

In [None]:
# Pull the next batch from the iterator and plot the band distribution of every
# image in it. Re-running this cell advances the iterator to the following
# batch, and raises StopIteration once the loader is exhausted.
sample, mask = next(dataloader)
# Iterate over the batch dimension itself rather than a hard-coded range(4), so
# a shorter batch (e.g. the last one) does not raise an IndexError.
for band_image in sample:
    plot_band_distribution(band_image)

In [None]:
# pre-calculated means, std, mins, max of raw images calculated on the full test set:
means = torch.tensor([ 265.7371,  445.2234,  393.7881, 2773.2734])
stds = torch.tensor([ 91.8786, 110.0122, 191.7516, 709.2327])
mins = torch.tensor([ 0., 21.,  6., 77.])
max = torch.tensor([ 4433.,  5023.,  8230., 10000.])

"""
means = torch.mean(train_sample.float(), dim=(0, 1, 2))
std = torch.std(train_sample.float(), dim=(0, 1, 2))
min = torch.amin(train_sample.float(), dim=(0, 1, 2))
max = torch.amax(train_sample.float(), dim=(0, 1, 2))
"""