In [26]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm 

class MyDataset(Dataset):
    """Image-folder dataset yielding ``(image_tensor, class_name)`` pairs.

    ``data_dir`` is expected to contain one sub-directory per class; every
    ``.jpg`` / ``.jpeg`` / ``.png`` file (any letter case) directly inside a
    class directory becomes one sample labelled with that directory's name.

    Attributes:
        data_dir: Root directory that was scanned.
        classes: Sorted names of the class sub-directories of ``data_dir``.
        images: List of ``(image_path, class_name)`` tuples, ordered by class
            and then by file name.
        transform: Callable applied to each RGB ``PIL.Image`` in
            ``__getitem__``.
    """

    # Extensions are compared against the lower-cased file name.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_dir, transform=None):
        """Index every image file found under ``data_dir``.

        Args:
            data_dir: Directory holding one sub-directory per class.
            transform: Optional callable applied to each loaded RGB image.
                When omitted, images are resized to 256x256, converted to a
                tensor and normalized (see ``_default_transform``).
        """
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary order and also lists plain
        # files, so keep only directories and sort them for a stable order.
        self.classes = sorted(
            entry for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry))
        )
        self.images = []

        for class_dir in self.classes:
            class_path = os.path.join(data_dir, class_dir)
            self.images.extend(
                (os.path.join(class_path, file_name), class_dir)
                for file_name in sorted(os.listdir(class_path))
                if file_name.lower().endswith(self.IMAGE_EXTENSIONS)
            )

        # Build the pipeline once here instead of on every __getitem__ call.
        self.transform = transform if transform is not None else self._default_transform()

    @staticmethod
    def _default_transform():
        """Return the default resize -> tensor -> normalize pipeline."""
        return transforms.Compose([
            transforms.Resize((256, 256)),  # Resize to a fixed size
            transforms.ToTensor(),  # Convert to PyTorch tensor
            # Per-channel mean/std; these are the values commonly used with
            # ImageNet-pretrained torchvision models.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, convert and transform the sample at position ``idx``.

        Returns:
            ``(tensor_image, label)`` where ``label`` is the class directory
            name. If the image cannot be loaded or transformed, the error is
            printed and ``(None, None)`` is returned instead of raising.
            NOTE: the DataLoader's default collate function cannot batch
            ``None``, so callers must filter such samples out (for example
            with a custom ``collate_fn``).
        """
        image_path, label = self.images[idx]
        try:
            # The context manager closes the underlying file handle once the
            # pixel data has been read.
            with Image.open(image_path) as image:
                # Grayscale / palette / RGBA images are brought to 3 channels.
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                tensor_image = self.transform(image)
            return tensor_image, label
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort the
            # whole run, so report it and hand back a sentinel.
            print(f"Error loading data from {image_path}: {e}")
            return None, None

data_path = "train"  # Update this directory path - preferably keep the code next to the data

# Instance of the custom dataset
dataset = MyDataset(data_path)

# Parameters for the DataLoader
batch_size = 64
shuffle = True
num_workers = 0  # 0 loads in the main process; raise it (e.g. to 4) on a capable machine


def collate_skip_failed(batch):
    """Collate a batch, dropping samples that MyDataset failed to load.

    MyDataset.__getitem__ returns ``(None, None)`` for unreadable images, and
    the default collate function raises ``TypeError`` on ``None``. Filtering
    first lets one bad file cost a single sample instead of the whole run.

    Returns:
        The default-collated ``(images, labels)`` batch of the samples that
        loaded, or ``None`` when every sample in the batch failed.
    """
    loaded = [sample for sample in batch if sample[0] is not None]
    if not loaded:
        return None
    return torch.utils.data.dataloader.default_collate(loaded)


# DataLoader instance
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    collate_fn=collate_skip_failed,
)

# Iterate over batches of data
for batch in tqdm(dataloader, desc="Training", unit="batch"):
    if batch is None:  # every image in this batch failed to load
        continue
    images, labels = batch
    # Process your image data batch here

<torch.utils.data.dataloader.DataLoader object at 0x1484674d0>


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 82/82 [03:03<00:00,  2.23s/batch]
