In [1]:
import os
import random
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset
from tifffile import imread

# set root path and seed
PROJECT_ROOT = Path(os.getcwd()) #Path(r"C:\Users\tdoro\DLMS\mandatory_task")

# use mat-nr as seed
RANDOM_SEED = 3778660
random.seed(RANDOM_SEED)

class EuroSatMsDataset(Dataset):
    def __init__(self, dataset_root_dir, split_name):
        self.dataset_root_dir = dataset_root_dir
        self.img_labels = pd.read_csv(self.dataset_root_dir / "EuroSAT_MS" / (split_name + ".csv"))

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.dataset_root_dir, self.img_labels.iloc[idx, 0])
        image = imread(img_path)
        label = self.img_labels.iloc[idx, 1]
        return image, label

ds_train = EuroSatMsDataset(PROJECT_ROOT, "train")
ds_test = EuroSatMsDataset(PROJECT_ROOT, "test")
ds_val = EuroSatMsDataset(PROJECT_ROOT, "val")

In [2]:
ds_train[0]

(array([[[1536, 1560, 1665, ..., 3675, 2542, 3526],
         [1536, 1560, 1665, ..., 3675, 2542, 3526],
         [1520, 1562, 1645, ..., 3602, 2468, 3560],
         ...,
         [1598, 1550, 1614, ..., 3752, 2671, 3218],
         [1599, 1522, 1605, ..., 3755, 2675, 3220],
         [1600, 1522, 1582, ..., 3754, 2682, 3222]],
 
        [[1536, 1560, 1665, ..., 3675, 2542, 3526],
         [1536, 1560, 1665, ..., 3675, 2542, 3526],
         [1520, 1562, 1645, ..., 3602, 2468, 3560],
         ...,
         [1598, 1550, 1614, ..., 3752, 2671, 3218],
         [1599, 1522, 1605, ..., 3755, 2675, 3220],
         [1600, 1522, 1582, ..., 3754, 2682, 3222]],
 
        [[1551, 1578, 1665, ..., 3741, 2608, 3480],
         [1551, 1578, 1665, ..., 3741, 2608, 3480],
         [1538, 1585, 1660, ..., 3718, 2577, 3497],
         ...,
         [1603, 1559, 1616, ..., 3761, 2679, 3211],
         [1604, 1574, 1627, ..., 3764, 2688, 3221],
         [1605, 1560, 1612, ..., 3764, 2697, 3228]],
 
        ...,
