In [1]:
import os
import pickle
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Any, Optional
import numpy as np
import torchvision.transforms as transforms

In [2]:
class PickleDataset(Dataset):
    def __init__(self, data_dir: str, transform: Optional[callable] = None):
        """
        Initialize the dataset with a directory containing pickle files.
        
        Args:
            data_dir (str): Path to the directory containing pickle files
            transform (callable, optional): Optional transform to be applied on the data
        """
        self.data_dir = data_dir
        self.transform = transform
        self.pkl_files = [f for f in os.listdir(data_dir) if f.endswith('.pkl')]
        self.data_cache = {}  # Cache for loaded pickle files
        self.total_samples = 0
        self.file_offsets = []  # Store the starting index for each file
        
        # Calculate total samples and file offsets
        current_offset = 0
        for pkl_file in self.pkl_files:
            file_path = os.path.join(data_dir, pkl_file)
            with open(file_path, 'rb') as f:
                data = pickle.load(f)
                if isinstance(data, (list, tuple)):
                    num_samples = len(data)
                elif isinstance(data, dict):
                    # Assuming the first key contains the data
                    first_key = next(iter(data))
                    num_samples = len(data[first_key])
                else:
                    raise ValueError(f"Unsupported data format in {pkl_file}")
                
                self.total_samples += num_samples
                self.file_offsets.append((current_offset, current_offset + num_samples))
                current_offset += num_samples

    def __len__(self) -> int:
        return self.total_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single item from the dataset.
        
        Args:
            idx (int): Index of the item to retrieve
            
        Returns:
            Dict[str, torch.Tensor]: Dictionary containing the data as tensors
        """
        # Find which file contains this index
        file_idx = 0
        for i, (start, end) in enumerate(self.file_offsets):
            if start <= idx < end:
                file_idx = i
                local_idx = idx - start
                break
        
        # Load the file if not in cache
        pkl_file = self.pkl_files[file_idx]
        if pkl_file not in self.data_cache:
            file_path = os.path.join(self.data_dir, pkl_file)
            with open(file_path, 'rb') as f:
                self.data_cache[pkl_file] = pickle.load(f)
        
        data = self.data_cache[pkl_file]
        
        # Convert data to tensors
        if isinstance(data, (list, tuple)):
            item = data[local_idx]
        elif isinstance(data, dict):
            item = {k: v[local_idx] for k, v in data.items()}
        
        # Convert numpy arrays to tensors
        if isinstance(item, dict):
            tensor_dict = {}
            for k, v in item.items():
                if isinstance(v, np.ndarray):
                    tensor_dict[k] = torch.from_numpy(v)
                elif isinstance(v, (int, float)):
                    tensor_dict[k] = torch.tensor(v)
                else:
                    tensor_dict[k] = v
            item = tensor_dict
        
        # Apply transform if specified
        if self.transform is not None:
            item = self.transform(item)
            
        return item

In [3]:
def create_dataloader(
    data_dir: str,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    transform: Optional[callable] = None
) -> DataLoader:
    """
    Create a DataLoader for the pickle dataset.
    
    Args:
        data_dir (str): Path to the directory containing pickle files
        batch_size (int): Batch size for the DataLoader
        shuffle (bool): Whether to shuffle the data
        num_workers (int): Number of worker processes for loading data
        transform (callable, optional): Optional transform to be applied on the data
        
    Returns:
        DataLoader: PyTorch DataLoader instance
    """
    dataset = PickleDataset(data_dir, transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True
    ) 

In [4]:
transform = transforms.Compose([
        transforms.ToTensor(),  # Convert numpy arrays to tensors
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
    ])

In [5]:
dataloader = create_dataloader(
        data_dir='./data',
        batch_size=16,
        shuffle=True,
        num_workers=2,
        transform=transform
    )

In [6]:
# Run the verification checks
# Method 1: Check dataset size and sample loading
print(f"Total dataset size: {len(dataloader.dataset)} samples")

Total dataset size: 17630 samples


In [7]:
# Method 2: Check batch loading
first_batch = next(iter(dataloader))
print("First batch shapes:", {k: v.shape for k, v in first_batch.items()})

RuntimeError: DataLoader worker (pid(s) 9780, 32080) exited unexpectedly

In [9]:
print(torch.device)

<class 'torch.device'>
