In [None]:
import utils.data as data
from utils.visualisation import visualise_image
from config import config
from torch.utils.data import DataLoader
import numpy as np

device = config.device
DATA_PATH = config.PATH_TO_DATA

%load_ext autoreload
%autoreload 2

In [None]:
# Build the Planet dataset, requesting band indices 0-3, and print its summary.
dataset = data.PlanetDataset(data_dir=DATA_PATH, bands=[0, 1, 2, 3])
print(dataset)

# Wrap the loader in an iterator straight away so that later cells can pull
# one batch at a time with next(dataloader). Batches of 4, no shuffling.
dataloader = iter(
    DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
    )
)

### 1. Example: loading data samples from the dataset

In [None]:
# Pull one batch from the iterator and report its type and shapes.
sample, mask = next(dataloader)
print(type(sample))
print(f"Batch sample shape (batch, x, y, n_channels): {sample.shape}")
print(f"Batch mask shape (batch, x, y): {mask.shape}")

# Plot the first image of the batch together with its mask.
visualise_image(
    sample[0].squeeze().numpy(),
    mask[0].squeeze().numpy(),
)

### 2. Choosing a framework to implement transforms / data augmentation functions

In [None]:
### Option 1 : torchgeo + Kornia
import torch

# ATTENTION: this framework requires sample tensors of shape (batch, n_channels, x, y).
# Docs:
#   https://torchgeo.readthedocs.io/en/latest/tutorials/transforms.html
#   https://torchgeo.readthedocs.io/en/latest/api/transforms.html
#
# Move the last axis (channels) to position 1.
# NOTE: this rebinds `sample`, so running this cell twice in a row permutes an
# already-permuted tensor; draw a fresh batch before re-running it.
sample = sample.permute(0, 3, 1, 2)


In [None]:
# Display the batch shape; expected to read (batch, n_channels, x, y) once the permute cell has run.
sample.shape


In [None]:

from torchvision.transforms import v2
import kornia.augmentation as K
from torchgeo.transforms import AugmentationSequential, indices

# Demo pipeline in which every augmentation is configured with p=1, applied
# jointly to the "image" and "mask" entries of the input dict.
transforms = AugmentationSequential(
    # There are very handy torchgeo functions for calculating indices!
    # NOTE(review): assumes band 3 is NIR and band 0 is red - confirm against
    # the band order that PlanetDataset actually returns.
    indices.AppendNDVI(index_nir=3, index_red=0),
    K.RandomHorizontalFlip(p=1),
    K.RandomVerticalFlip(p=1),
    # Note that the random blur is applied to the image but not to the mask!
    K.RandomBoxBlur(kernel_size=(10, 10), border_type='reflect', normalized=True, p=1),
    data_keys=["image", "mask"],
)
transformed_tuple = transforms({"image": sample, "mask": mask})

# Move the channel axis back to the end so the image is compatible with
# Luca's visualise_image.
transformed_tuple['image'] = transformed_tuple['image'].permute(0, 2, 3, 1)

visualise_image(
    transformed_tuple['image'][0].squeeze().numpy(),
    transformed_tuple['mask'][0].squeeze().numpy(),
)

print(f"Note that we have a new channel corresponding to NDVI: {transformed_tuple['image'].shape}")


In [None]:
# Draw the next batch from the iterator; this rebinds `sample` and `mask`,
# discarding the permuted `sample` left over from the cells above.
sample, mask = next(dataloader)


In [None]:
### Option 2 : torch transforms v2 (implemented in beta since March 2023)
### NOTE: `v2` is imported for this experiment, but the pipeline below is still built from
### torchgeo's AugmentationSequential + Kornia augmentations, now with random
### probabilities and per-channel normalisation.
from torchvision.transforms import v2
import kornia.augmentation as K
from torchgeo.transforms import AugmentationSequential, indices

# Rebuild the dataset and a shuffled loader so this cell does not depend on
# how many batches the earlier cells have already consumed.
dataset = data.PlanetDataset(data_dir=DATA_PATH, bands=[0,1,2,3])
print(dataset)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True, sampler=None,
           batch_sampler=None, num_workers=0)
dataloader = iter(dataloader)

sample, mask = next(dataloader)
print(type(sample))
print(f"Batch sample shape (batch, x, y, n_channels): {sample.shape}")
print(f"Batch mask shape (batch, x, y): {mask.shape}")

# Show the untouched first image/mask pair while the batch is still channels-last.
visualise_image(sample[0].squeeze().numpy(),mask[0].squeeze().numpy())

# The augmentation framework needs channels-first float input:
# (batch, x, y, n_channels) -> (batch, n_channels, x, y).
print(sample.shape)
sample = torch.permute(sample,(0,3,1,2)).float()
print(sample.shape)

# Per-channel mean/std for normalisation. These must be computed AFTER the
# permute: dim=(0, 2, 3) only reduces over batch and spatial axes when the
# channel axis is at position 1. (The statistics come from this single batch.)
means = torch.mean(sample, dim=(0, 2, 3))
stds = torch.std(sample, dim=(0, 2, 3))

transforms = AugmentationSequential(
    K.Resize((224, 224)),
    K.RandomResizedCrop(size=(224, 224), p=0.5),

    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    # Note that the random blur is applied to the image but not to the mask!
    K.RandomBoxBlur(kernel_size=(10, 10), border_type='reflect', normalized=True, p=0.5),

    # Augmentations tried and left disabled:
    # K.RandomGrayscale(p=1), only works when we have 3 input channels
    # K.ColorJitter(p=1), only works when we have 3 input channels
    # K.RandomPosterize(bits=4, p=1), # is a bit funky, need to normalize first
    # K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 0.1), p=0.2), # duplicate of the box blur above
    K.Normalize(mean=means, std=stds),

    data_keys=["image","mask"],
)


transformed_tuple = transforms({"image" : sample, "mask" : mask})
# Back to channels-last for visualise_image.
transformed_tuple['image'] = torch.permute(transformed_tuple['image'],(0,2,3,1))
visualise_image(transformed_tuple['image'][0].squeeze().numpy(),transformed_tuple['mask'][0].squeeze().numpy())


In [None]:
### Option 3 : base Kornia AugmentationSequential (3-band input)
### Old Experiments with Base Kornia - currently used in production

# Imported here as well so the cell does not rely on an earlier cell having
# brought `torch` and `K` into the kernel namespace.
import torch
import kornia.augmentation as K
from utils.visualisation import visualise_image_3_channels

# Only three bands are requested here so that RandomGrayscale can be used.
dataset = data.PlanetDataset(data_dir=DATA_PATH, bands=[0,1,2])
print(dataset)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True, sampler=None,
           batch_sampler=None, num_workers=0)
dataloader = iter(dataloader)

sample, mask = next(dataloader)
print(type(sample))
print(f"Batch sample shape (batch, x, y, n_channels): {sample.shape}")
print(f"Batch mask shape (batch, x, y): {mask.shape}")

# Show the untouched first image/mask pair while the batch is still channels-last.
visualise_image_3_channels(sample[0].squeeze().numpy(),mask[0].squeeze().numpy())

# Kornia needs channels-first float input:
# (batch, x, y, n_channels) -> (batch, n_channels, x, y).
print(sample.shape)
sample = torch.permute(sample,(0,3,1,2)).float()
print(sample.shape)

# Per-channel mean/std for normalisation. These must be computed AFTER the
# permute: dim=(0, 2, 3) only reduces over batch and spatial axes when the
# channel axis is at position 1. (The statistics come from this single batch.)
means = torch.mean(sample, dim=(0, 2, 3))
stds = torch.std(sample, dim=(0, 2, 3))

# Unlike the torchgeo wrapper used in the cells above, the mask is converted
# by hand here: cast to float and insert a channel axis -> (batch, 1, x, y).
mask = mask.float()
mask = mask[:, None, :, :]
print(mask.shape)


transforms = K.container.AugmentationSequential(
    K.Resize((224, 224)),
    K.RandomResizedCrop(size=(224, 224), p=0.5),
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    # Note that the random blur is applied to the image but not to the mask!
    K.RandomBoxBlur(kernel_size=(20, 20), border_type='reflect', normalized=True, p=0.5),
    K.RandomGrayscale(p=0.2), # only works when we have 3 input channels
    K.Normalize(mean=means, std=stds),
    data_keys=["image","mask"],
)

# The base Kornia container takes positional tensors and returns them in
# data_keys order; re-key them so the result looks like the torchgeo output.
transformed = transforms(sample, mask)
transformed_tuple = {k: v for k, v in zip(["image", "mask"], transformed)}

# Move channels back to the last axis before plotting, matching the layout
# passed to visualise_image_3_channels for the raw batch above.
transformed_tuple['image'] = torch.permute(transformed_tuple['image'],(0,2,3,1))
visualise_image_3_channels(transformed_tuple['image'][0].squeeze().numpy(),transformed_tuple['mask'][0].squeeze().numpy())