In [None]:
import os

# Go to the root directory if needed: when the notebook is started from inside
# a "notebooks" folder, step up one level.
# os.path.basename is used instead of splitting on "/" so the check also works
# with the Windows path separator.
initial_cwd = os.getcwd()
print(f"Current working directory: {initial_cwd}")
if os.path.basename(initial_cwd) == "notebooks":
    os.chdir("..")
    print(f"Changed working directory to: {os.getcwd()}")

In [None]:
import torch
import src.utils.data as data

from torchvision import transforms

# Settings
# Named seed constant so the value is visible and tunable in one place.
SEED = 0
# Seed PyTorch's global random number generator for reproducible runs.
torch.manual_seed(SEED)

Loading data

In [None]:
# Load the data as a DataModule (split into training, validation and test data).
data_folder_path = "data/raw/burst_images/"

# Transform pipeline handed to the DataModule: a single resize step to 193 x 240
# (torchvision's Resize takes the size as (height, width)).
resize_transform = transforms.Compose(
    [transforms.Resize((193, 240), antialias=True)]
)

data_module = data.ECallistoDataModule(
    data_folder=data_folder_path,
    transform=resize_transform,
    batch_size=32,
    num_workers=0,
    val_ratio=0.2,
    test_ratio=0.2,
)

# Prepare the datasets that the dataloaders below are built from.
data_module.setup()

In [None]:
# Fetch one batch from the training dataloader; a batch unpacks into three
# parts: the data, the file names and the labels.
first_train_batch = next(iter(data_module.train_dataloader()))
batch_data, batch_filenames, batch_labels = first_train_batch

# Keep the first entry of each part for a quick look.
first_data_in_batch, first_timestamp_in_batch, first_folder_number_in_batch = (
    batch_data[0],
    batch_filenames[0],
    batch_labels[0],
)

print("First data in batch:", first_data_in_batch)
print("Timestamp of first data:", first_timestamp_in_batch)
print("Folder number of data:", first_folder_number_in_batch)

Get samples per class for each dataloader

In [None]:
def count_samples_per_class(dataloader):
    """Count how many samples of each class the dataloader's dataset contains.

    The function iterates over ``dataloader.dataset`` directly (not over
    batches), so every sample is fetched once.

    Args:
        dataloader: Object exposing a ``dataset`` attribute that yields
            3-tuples whose last element is the class label.

    Returns:
        dict: Mapping ``label -> number of samples``, ordered by label.
    """
    counts = {}
    for _, _, label in dataloader.dataset:
        # 0-dim tensors hash by identity, so two equal tensor labels would end
        # up as separate keys. Convert scalar-like labels (0-dim tensors,
        # NumPy scalars) to plain Python values before counting.
        if hasattr(label, "item") and getattr(label, "ndim", None) == 0:
            label = label.item()
        counts[label] = counts.get(label, 0) + 1
    return dict(sorted(counts.items()))

# Report the class distribution of each split, one split after the other.
train_class_counts = count_samples_per_class(data_module.train_dataloader())
print(f"Distribution of classes in training set: {train_class_counts}")

val_class_counts = count_samples_per_class(data_module.val_dataloader())
print(f"Distribution of classes in validation set: {val_class_counts}")

test_class_counts = count_samples_per_class(data_module.test_dataloader())
print(f"Distribution of classes in test set: {test_class_counts}")

Then compare these counts with the number of samples in each dataset split.

In [None]:
# Total number of samples in each split of the DataModule.
train_size = len(data_module.train_dataset)
print(f"Train Dataset Length: {train_size}")

val_size = len(data_module.val_dataset)
print(f"Val Dataset Length: {val_size}")

test_size = len(data_module.test_dataset)
print(f"Test Dataset Length: {test_size}")