<a href="https://colab.research.google.com/github/marcvonrohr/machine_learning/blob/main/lab_4/lab_04_custom_datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img align="center" style="max-width: 1000px" src="https://github.com/HSG-AIML-Teaching/ML2025-Lab/blob/main/lab_4/figures/banner.png?raw=1">

<img align="right" style="max-width: 200px; height: auto" src="https://github.com/HSG-AIML-Teaching/ML2025-Lab/blob/main/lab_4/figures/hsg_logo.png?raw=1">

## Lab 04 - Custom Datasets in PyTorch

Machine Learning, University of St. Gallen, Spring Term 2025


In this tutorial, we implement a custom PyTorch dataset that loads the images from a given dataset folder and prepares them as inputs for training and evaluation. Although the structure of datasets can vary significantly, the principles in this tutorial should apply to any PyTorch dataset, regardless of its folder structure or file formats.

Lab Objectives:
- Understand dataset structures and how to process dataset files.
- Learn how to implement a PyTorch dataset class.


## Example: A Multi-Folder Dataset

In this example, we use a dataset called **Omniglot**, in which the images of each class are stored in a separate folder. We want to load the images from each folder, which corresponds to one class, and return them as instances of that class.

First, let's download the files we need from this link: https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip

For more information about the dataset please refer to this link:
https://github.com/brendenlake/omniglot/

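To read the list of files in a folder, we use Python's `glob` module. It is part of the standard library, so nothing needs to be installed for it; the cell below installs the third-party `glob2` package, which is an optional extension and is not used in the rest of this lab: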

In [19]:
!pip install glob2



In [33]:
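# `!cd` only changes the directory of a temporary subshell; the `%cd` magic changes the notebook's working directory
%cd /content/
!mkdir -p dataset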
!wget https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip

--2025-03-24 11:15:16--  https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9464212 (9.0M) [application/zip]
Saving to: ‘images_background.zip.5’


2025-03-24 11:15:16 (123 MB/s) - ‘images_background.zip.5’ saved [9464212/9464212]



Next, we extract the archive into `/content/dataset` and move its contents up one level, so that each alphabet (e.g. `Greek`) ends up directly under `dataset/`:

In [39]:
import os
import shutil
import zipfile

# Path to the downloaded ZIP file
zip_file_path = '/content/images_background.zip'

# Target folder for the extracted files
extract_dir = '/content/dataset'

# Create the target folder if it does not exist yet
os.makedirs(extract_dir, exist_ok=True)

# Extract the ZIP file
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

print(f'The file {zip_file_path} was successfully extracted to {extract_dir}.')

# The archive contains a top-level folder 'images_background'; move its contents up one level
source_dir = '/content/dataset/images_background'
target_dir = '/content/dataset'

# Move all files and folders
for item in os.listdir(source_dir):
    source_path = os.path.join(source_dir, item)
    target_path = os.path.join(target_dir, item)
    shutil.move(source_path, target_path)

# Delete the now empty source folder
os.rmdir(source_dir)

print(f'Moved all files and folders from {source_dir} to {target_dir} and deleted {source_dir}.')

The file /content/images_background.zip was successfully extracted to /content/dataset.
Moved all files and folders from /content/dataset/images_background to /content/dataset and deleted /content/dataset/images_background.

Now, let's see how we can retrieve and print the list of folders for a given root directory:

In [40]:
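import glob

# Each entry returned by glob is a path such as "dataset/Greek/character20";
# the last path component is the folder (class) name. The order is not guaranteed.
folder_names = [f.split("/")[-1] for f in glob.glob("dataset/Greek/*")]
print(folder_names)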

['character20', 'character07', 'character12', 'character23', 'character18', 'character22', 'character03', 'character15', 'character01', 'character06', 'character11', 'character08', 'character05', 'character21', 'character16', 'character04', 'character17', 'character02', 'character24', 'character10', 'character14', 'character13', 'character19', 'character09']
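Note that `glob` also matches plain files and returns paths in an order that depends on the file system, as the unsorted output above shows. A hedged alternative uses `pathlib`, keeps only directories, and sorts the names. `list_class_folders` is a hypothetical helper name, and the temporary folder tree is a stand-in for `dataset/Greek`:

```python
import tempfile
from pathlib import Path


def list_class_folders(root):
    """Return the sorted names of the sub-directories of `root` (hypothetical helper)."""
    return sorted(p.name for p in Path(root).iterdir() if p.is_dir())


# Throwaway folder tree standing in for dataset/Greek (hypothetical contents)
tmp_root = Path(tempfile.mkdtemp())
for name in ["character10", "character02", "character01"]:
    (tmp_root / name).mkdir()
(tmp_root / "README.txt").write_text("not a class folder")  # plain files are skipped

class_names = list_class_folders(tmp_root)
print(class_names)  # ['character01', 'character02', 'character10']
```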


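Since each folder corresponds to a class, we map every class name to an integer class ID. Sorting the names first makes the mapping reproducible, because `glob` does not return them in a fixed order: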

In [41]:
name_to_id = {name: id for (id, name) in enumerate(sorted(folder_names))}

print(name_to_id)

{'character01': 0, 'character02': 1, 'character03': 2, 'character04': 3, 'character05': 4, 'character06': 5, 'character07': 6, 'character08': 7, 'character09': 8, 'character10': 9, 'character11': 10, 'character12': 11, 'character13': 12, 'character14': 13, 'character15': 14, 'character16': 15, 'character17': 16, 'character18': 17, 'character19': 18, 'character20': 19, 'character21': 20, 'character22': 21, 'character23': 22, 'character24': 23}
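The mapping only goes from names to IDs. When you later want to turn a predicted class ID back into a folder name, the inverse dictionary is handy. A minimal sketch on a few hypothetical names, using `example_` prefixes so that running it does not overwrite `name_to_id` from the cell above:

```python
# Hypothetical example names, in an unsorted order such as glob might return
example_names = ["character03", "character01", "character02"]
example_name_to_id = {name: idx for idx, name in enumerate(sorted(example_names))}

# Invert the mapping: class ID -> class name
example_id_to_name = {idx: name for name, idx in example_name_to_id.items()}

print(example_name_to_id)     # {'character01': 0, 'character02': 1, 'character03': 2}
print(example_id_to_name[2])  # character03
```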


Next, we collect the paths of all images in the dataset and assign each image the ID of the folder it is stored in:

In [42]:
all_files = glob.glob("./dataset/Greek/*/*.png")
all_label = [name_to_id[path.split("/")[-2]] for path in all_files]

print(all_label)

[19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
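The expression `path.split("/")[-2]` relies on `/` as the path separator, which holds on Colab but not for native Windows paths. `pathlib.Path(...).parent.name` extracts the folder name for paths produced on the current OS, and `collections.Counter` gives a quick check that the classes are balanced (in Omniglot, each character was drawn by 20 people, so every label should appear 20 times). A minimal sketch on hypothetical paths:

```python
from collections import Counter
from pathlib import Path

# Hypothetical example paths in the layout dataset/Greek/<class_name>/<file>.png
example_paths = [
    "dataset/Greek/character02/a.png",
    "dataset/Greek/character01/b.png",
    "dataset/Greek/character02/c.png",
]
example_name_to_id = {"character01": 0, "character02": 1}

# Path(p).parent.name is the name of the folder that contains the file
example_labels = [example_name_to_id[Path(p).parent.name] for p in example_paths]
print(example_labels)           # [1, 0, 1]
print(Counter(example_labels))  # Counter({1: 2, 0: 1})
```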

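Then, let's define a class that collects the file list and returns individual samples by index (random ordering, if needed, is handled later by the `DataLoader`). A map-style PyTorch `Dataset` needs to implement two methods: `__len__`, which returns the number of samples, and `__getitem__`, which returns the sample at a given index: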

In [43]:
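# CODE TO BE IMPLEMENTED DURING THE TUTORIAL SESSION
import glob

from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image

class MyDataset(Dataset):
    """Map-style dataset for a <root>/<class_name>/<image>.png folder layout."""

    def __init__(self, root, transform=None) -> None:
        super().__init__()

        self.transform = transform

        # Each sub-folder of root is one class; sorting makes the ID assignment deterministic
        folder_names = [f.split("/")[-1] for f in glob.glob(root + "/*")]
        name_to_id = {name: idx for (idx, name) in enumerate(sorted(folder_names))}

        # Collect all image paths and derive each label from its parent folder name
        self.all_paths = glob.glob(root + "/*/*.png")
        self.all_label = [name_to_id[path.split("/")[-2]] for path in self.all_paths]

    def __len__(self):
        # Number of samples; the DataLoader's sampler uses this to generate indices
        return len(self.all_paths)

    def __getitem__(self, index):
        # Images are loaded lazily, only when this sample is requested
        path_i = self.all_paths[index]
        image = Image.open(path_i)

        if self.transform is not None:
            image = self.transform(image)

        label = self.all_label[index]
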
        return image, label


Finally, we need to test the implemented PyTorch dataset class.

In [44]:
my_transform = transforms.ToTensor()
my_dataset = MyDataset(root="./dataset/Greek", transform=my_transform)


In [45]:
len(my_dataset)

480
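To prepare separate inputs for training and evaluation, a dataset can be split into disjoint subsets with `torch.utils.data.random_split`. A minimal sketch, assuming an 80/20 split and a fixed seed (both arbitrary choices); a `TensorDataset` with the same length (480) stands in for `my_dataset` so that the cell runs on its own:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in with the same number of samples as my_dataset (480)
toy_dataset = TensorDataset(torch.arange(480))

# A seeded generator makes the split reproducible across runs
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(toy_dataset, [384, 96], generator=generator)

print(len(train_set), len(val_set))  # 384 96
# With the real data: random_split(my_dataset, [384, 96], generator=generator)
```

With only 20 images per class, a purely random split can leave some classes under-represented in the validation set; a stratified split (grouping indices by label before splitting) keeps the class proportions equal.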

In [46]:
my_dataset[13]

(tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]),
 19)

Now, we use the dataset to create a `DataLoader` and iterate through it in batches of 32 samples.

In [47]:
from torch.utils.data import DataLoader

my_dataloader = DataLoader(my_dataset, batch_size=32, num_workers=0)

In [48]:
for batch in my_dataloader:
    image, label = batch
    print(image.shape)

torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
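Every batch above contains 32 samples because 480 = 15 × 32. In general, the last batch is smaller unless `drop_last=True` is set, and for training you would typically pass `shuffle=True` so that the sample order changes every epoch. A sketch with a hypothetical 100-sample stand-in dataset:

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in: 100 blank single-channel 105x105 images with labels 0..23
images = torch.zeros(100, 1, 105, 105)
labels = torch.arange(100) % 24
toy_dataset = TensorDataset(images, labels)

# shuffle=True: new random order each epoch; drop_last=False: keep the final partial batch
loader = DataLoader(toy_dataset, batch_size=32, shuffle=True, drop_last=False)

batch_sizes = [x.shape[0] for x, y in loader]
print(batch_sizes)                       # [32, 32, 32, 4]
print(len(loader), math.ceil(100 / 32))  # 4 4
```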
