In [None]:
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.transforms import v2
import os
import torch
import torch.nn.functional as TF
import matplotlib.pyplot as plt

In [None]:
class SquarePad:
    """Zero-pad a tensor on its two trailing dimensions so that they become equal.

    The input is treated as ``(..., H, W)``. The shorter of the two trailing
    dimensions is padded symmetrically up to the length of the longer one.
    When the difference is odd, the extra row/column is added at the
    bottom/right so the result is exactly square.
    """

    def __call__(self, image):
        """Return a padded copy of ``image`` whose last two dimensions are equal.

        Args:
            image: tensor with at least two dimensions, laid out ``(..., H, W)``.

        Returns:
            Tensor of shape ``(..., S, S)`` with ``S = max(H, W)``.
        """
        height, width = image.shape[-2:]
        side = max(height, width)

        # Split each deficit in two halves. Computing the second half as
        # "whatever is left" (instead of truncating both halves) keeps the
        # output square when the deficit is odd.
        pad_top = (side - height) // 2
        pad_bottom = side - height - pad_top
        pad_left = (side - width) // 2
        pad_right = side - width - pad_left

        # F.pad expects the padding of the LAST dimension first:
        # (left, right, top, bottom).
        return TF.pad(image, [pad_left, pad_right, pad_top, pad_bottom])

In [None]:
class MinMaxScalerVectorized(object):
    """Min-max scale a floating-point tensor along dimension 1, in place.

    For every slice along dimension 1 the values are mapped to
    ``(x - min) / (max - min)``. Slices whose values are all equal use a
    divisor of 1 and therefore come out as zeros.

    NOTE(review): for a ``(C, H, W)`` image, dimension 1 would be the height
    axis, i.e. every column of every channel is scaled independently.
    Confirm this is intended rather than a per-image or per-channel range.
    """

    def __call__(self, image):
        """Scale ``image`` in place and return that same tensor object."""
        upper = image.max(dim=1, keepdim=True)[0]
        lower = image.min(dim=1, keepdim=True)[0]

        span = upper - lower
        # Avoid dividing by zero on constant slices.
        span[span == 0.] = 1.

        image.mul_(1.0 / span)
        # The minimum is taken AFTER the multiplication on purpose: the image
        # has already been rescaled, so the offset to remove is the rescaled
        # minimum, not the original one.
        rescaled_lower = image.min(dim=1, keepdim=True)[0]
        image.sub_(rescaled_lower)
        return image

In [None]:
class BaseDataset(Dataset):
    """Dataset yielding every image file of a directory after a fixed preprocessing.

    Each item is read with ``read_image`` and passed through, in order:
    square padding, resize to ``img_size``, conversion to ``float32`` and
    min-max scaling.

    Args:
        img_dir: directory holding the images. A relative path is resolved
            against the current working directory.
        img_size: target size forwarded to ``v2.Resize``.
        antialias: antialias flag forwarded to ``v2.Resize``.
    """

    def __init__(self, img_dir, img_size, antialias=True):
        self.img_dir = img_dir
        self.img_size = img_size
        self.antialias = antialias

        # os.listdir returns entries in arbitrary order and also reports
        # sub-directories (e.g. .ipynb_checkpoints). Keep regular files only
        # and sort them so that a given index always maps to the same image.
        self.image_names = sorted(
            name
            for name in os.listdir(img_dir)
            if os.path.isfile(os.path.join(img_dir, name))
        )

        self.transform = v2.Compose([
            # Pad to a square first so the resize does not distort the aspect ratio.
            SquarePad(),
            v2.Resize(size=self.img_size, antialias=self.antialias),
            # Convert to float32 before the in-place min-max scaling.
            v2.ConvertImageDtype(torch.float32),
            MinMaxScalerVectorized()
            # Dataset-wide normalisation, currently disabled:
            # v2.Normalize(mean=[0.09371563792228699, 0.0821407213807106, 0.08119282871484756], std=[0.16423960030078888, 0.14668214321136475, 0.14383625984191895])
        ])

    def __len__(self):
        """Number of image files found in ``img_dir`` at construction time."""
        return len(self.image_names)

    def __getitem__(self, index):
        """Load and preprocess the image at position ``index`` of ``image_names``."""
        file_name = self.image_names[index]  # picks all images in the directory

        image_path = os.path.join(os.getcwd(), self.img_dir, file_name)
        image = read_image(image_path)

        return self.transform(image)

    def save_dataset(self, filename):
        """Serialize this dataset object to ``filename`` with ``torch.save``.

        NOTE(review): this saves the dataset object itself (file names and
        transform objects), not the decoded images. Loading it back
        presumably needs the same class definitions to be importable;
        verify before relying on it across notebooks.
        """
        with open(filename, 'wb') as f:
            torch.save(self, f)

In [None]:
class PlotsDataset(BaseDataset):
    """Labelled variant of ``BaseDataset``: items come from a ground-truth table.

    Args:
        labels: dataframe with the labels. This class reads the columns
            ``filename`` and ``elevation`` (normalized). The original notes
            also list ``elevation_avg``, which is not read here.
        *args, **kwargs: forwarded to ``BaseDataset`` (``img_dir``,
            ``img_size`` as a tuple with the size of the image, ``antialias``).
    """

    def __init__(self, labels, *args, **kwargs):
        self.labels = labels
        super().__init__(*args, **kwargs)

    def __len__(self):
        """Number of rows in ``labels``."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the ``index``-th row of ``labels``."""
        row = self.labels.iloc[index]  # picks only those in the ground truth

        image_path = os.path.join(os.getcwd(), self.img_dir, row['filename'])
        image = self.transform(read_image(image_path))
        label = row['elevation']

        return image, label

    def get_means_stds(self):
        """Per-channel mean and standard deviation over all transformed images.

        Every image is loaded and stacked in memory, so all transformed images
        must share the same shape.

        Returns:
            ``(means, stds)``: two lists of Python floats, one entry per channel.
        """
        # Index explicitly instead of iterating ``self``: iteration over an
        # object that only defines __getitem__ stops on the first IndexError.
        images = torch.stack([self[i][0] for i in range(len(self))])

        # Stacking adds a leading sample dimension, so dim 1 is the channel
        # axis. Reducing over every other dimension yields one statistic per
        # channel, whatever the number of channels (a fixed 3-way chunk only
        # works for exactly three channels).
        reduce_dims = (0, 2, 3)
        means = images.mean(dim=reduce_dims).tolist()
        stds = images.std(dim=reduce_dims).tolist()

        return means, stds