In [13]:
import torch
from torch.utils.data import Dataset, DataLoader
import h5py
import pandas as pd

class SpectrogramDataset(Dataset):
    """Map-style dataset pairing HDF5 spectrograms with multi-hot labels from a CSV.

    The CSV must contain a ``key`` column naming the HDF5 dataset that stores each
    spectrogram. Every column after the first is treated as a label column, so
    ``key`` is expected to be the first column.

    Args:
        hdf5_file: Path to the HDF5 file holding one dataset per key.
        csv_file: Path to the CSV file with the ``key`` column and label columns.

    Attributes:
        hdf5_file_path: The path passed as ``hdf5_file``.
        labels: DataFrame read from ``csv_file``.
        label_map: Names of the label columns, in column order.
        hdf5_file: Lazily opened ``h5py.File`` handle, or ``None`` while closed.
    """

    def __init__(self, hdf5_file, csv_file):
        # Assign the handle first so __del__ stays safe even if reading the CSV raises.
        self.hdf5_file = None
        self.hdf5_file_path = hdf5_file
        self.labels = pd.read_csv(csv_file)
        self.label_map = self.labels.columns[1:].tolist()  # Effect label names

        # Convert keys and labels once here instead of on every __getitem__ call.
        # Missing label values become 0. These are snapshots: later edits to
        # self.labels are not reflected in the returned samples.
        self._keys = self.labels['key'].tolist()
        self._label_matrix = (
            self.labels.iloc[:, 1:].astype(float).fillna(0.0).to_numpy(dtype="float32")
        )

    def __len__(self):
        """Return the number of rows (samples) in the label CSV."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(spectrogram, label)`` for row ``idx``.

        Returns:
            spectrogram: float32 tensor of the stored array with a new leading
                dimension of size 1 added.
            label: float32 tensor holding the row's label columns.
        """
        # Open lazily so every DataLoader worker process ends up with its own
        # handle. SWMR mode only governs reading alongside a single writer; it
        # does not by itself make a shared handle safe.
        if self.hdf5_file is None:
            self.hdf5_file = h5py.File(self.hdf5_file_path, "r", swmr=True)

        key = self._keys[idx]
        spectrogram = torch.tensor(self.hdf5_file[key][()], dtype=torch.float32).unsqueeze(0)
        # torch.tensor copies, so callers cannot mutate the cached label matrix.
        label = torch.tensor(self._label_matrix[idx], dtype=torch.float32)

        return spectrogram, label

    def __getstate__(self):
        """Return picklable state without the open HDF5 handle.

        The copy sent to another process reopens the file on first access.
        """
        state = self.__dict__.copy()
        state['hdf5_file'] = None
        return state

    def __del__(self):
        """Close the HDF5 handle if one is open."""
        # getattr guards against instances whose __init__ never completed.
        handle = getattr(self, "hdf5_file", None)
        if handle is not None:
            self.hdf5_file = None
            handle.close()

In [14]:
# Locations of the HDF5 spectrogram store and its companion label CSV
h5_path = '/content/drive/MyDrive/Capstone 210/Small Agg Dataset/Output/spectrograms.h5'
csv_path = '/content/drive/MyDrive/Capstone 210/Small Agg Dataset/Output/labels.csv'

# Build the dataset, then wrap it in a shuffling loader backed by two worker processes
dataset = SpectrogramDataset(hdf5_file=h5_path, csv_file=csv_path)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Smoke test: run two passes over the loader and report each batch's tensor shapes
for epoch in range(2):
    print(f"Beginning epoch {epoch+1}")
    batches = enumerate(dataloader)
    for batch_idx, batch in batches:
        spectrograms, labels = batch
        print(f"Batch {batch_idx + 1}: {spectrograms.shape}, {labels.shape}")

print("Done")

Beginning epoch 1
Batch 1: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 2: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 3: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 4: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 5: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 6: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 7: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 8: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 9: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 10: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 11: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 12: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 13: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 14: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 15: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])
Batch 16: torch.Size([32, 1, 128, 626]), torch.Size([32, 12])