In [2]:
import os

In [3]:
# Move from the notebook folder up to the project root so relative paths
# (config/config.yaml, params.yaml, artifacts/) resolve.
# A bare os.chdir("../") climbs one more level on every re-run of this cell;
# remembering the starting directory once keeps the cell idempotent.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, ".."))

In [5]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class DataPreprocessingConfig:
    """Immutable settings for the data-preprocessing stage.

    Instances are built by ``ConfigurationManager.get_data_preprocessing_config``
    and consumed by ``DataPreprocessing``.
    """

    # Root folder that holds one sub-folder of images per class.
    data_dir: Path
    # Samples per batch for every DataLoader (train / val / test).
    batch_size: int
    # Target image size; entries [0] and [1] are passed to transforms.Resize.
    # NOTE(review): a list field makes instances unhashable even though the
    # dataclass is frozen; a tuple would avoid that.
    image_size: list
    # Fraction of all images reserved for validation (truncated to an int count).
    val_split: float
    # Fraction of all images reserved for testing (truncated to an int count).
    test_split: float
    # When True, torch's RNG is seeded with `random_seed` before the split.
    shuffle: bool
    # Seed used when `shuffle` is True.
    random_seed: int
    # When True, random augmentations are added to the training transform.
    augmentation: bool


In [16]:
from cnnClassifier.constants import *
from cnnClassifier.utils.common import read_yaml, create_directories
# from cnnClassifier.entity import DataPreprocessingConfig
from pathlib import Path

class ConfigurationManager:
    """Reads the project's YAML files and hands out typed per-stage configs.

    On construction both YAML files are parsed with ``read_yaml`` and the
    ``artifacts_root`` directory named in the main config is created.

    Args:
        config_filepath: main configuration YAML; defaults to ``CONFIG_FILE_PATH``.
        params_filepath: hyper-parameter YAML; defaults to ``PARAMS_FILE_PATH``.
    """

    def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):
        # Config is parsed before params, so their load messages keep this order.
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        artifact_dirs = [self.config.artifacts_root]
        create_directories(artifact_dirs)

    def get_data_preprocessing_config(self) -> DataPreprocessingConfig:
        """Assemble a ``DataPreprocessingConfig`` from the loaded YAML files.

        The image folder is ``<data_ingestion.unzip_dir>/dataset-resized``; every
        other field is copied from the matching upper-case key of the params file.
        """
        ingestion_cfg = self.config.data_ingestion
        hparams = self.params
        image_root = Path(ingestion_cfg.unzip_dir, "dataset-resized")
        return DataPreprocessingConfig(
            data_dir=image_root,
            batch_size=hparams.BATCH_SIZE,
            image_size=hparams.IMAGE_SIZE,
            val_split=hparams.VAL_SPLIT,
            test_split=hparams.TEST_SPLIT,
            shuffle=hparams.SHUFFLE_DATASET,
            random_seed=hparams.RANDOM_SEED,
            augmentation=hparams.AUGMENTATION,
        )


In [17]:
import os
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class TrashNetDataset(Dataset):
    """Image-folder dataset: one sub-directory per class, label = class index.

    Args:
        data_dir: root directory containing one sub-folder per class name.
        transform: optional callable applied to each PIL image in ``__getitem__``.
        classes: class folder names in label order. Defaults to the six TrashNet
            categories (cardboard, glass, metal, paper, plastic, trash).

    Class folders that do not exist are skipped silently; only files whose names
    end in .png / .jpg / .jpeg (case-insensitive) are indexed.
    """

    _DEFAULT_CLASSES = ('cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash')
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir, transform=None, classes=None):
        self.data_dir = data_dir
        self.transform = transform
        # Stored as a list, as before, for callers that inspect it.
        self.classes = list(classes) if classes is not None else list(self._DEFAULT_CLASSES)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.image_paths = []
        self.labels = []

        for class_name in self.classes:
            class_dir = os.path.join(data_dir, class_name)
            if not os.path.exists(class_dir):
                continue
            # os.listdir order is filesystem-dependent; sorting makes the sample
            # index (and thus any seeded index-based split) reproducible across
            # machines and operating systems.
            for img_file in sorted(os.listdir(class_dir)):
                if img_file.lower().endswith(self._IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_dir, img_file))
                    self.labels.append(self.class_to_idx[class_name])

    def __len__(self):
        """Number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; image is RGB, transformed if a transform is set."""
        image_path = self.image_paths[idx]
        # `with` closes the file handle deterministically; convert() returns a
        # new, fully loaded image that remains valid after the file is closed.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label


class DataPreprocessing:
    """Build train / validation / test ``DataLoader``s for the TrashNet images.

    Args:
        config: a ``DataPreprocessingConfig``; this class reads its ``data_dir``,
            ``batch_size``, ``image_size``, ``val_split``, ``test_split``,
            ``shuffle``, ``random_seed`` and ``augmentation`` fields.
    """

    # Per-channel normalisation constants (the widely used ImageNet statistics;
    # presumably chosen to match a pretrained backbone - confirm).
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, config):
        self.config = config

    def _get_transforms(self, train=True):
        """Return the preprocessing pipeline for one split.

        Every split is resized to ``(image_size[0], image_size[1])``, converted
        to a tensor and normalised. When ``train`` is True *and*
        ``config.augmentation`` is enabled, random flip / rotation / colour
        jitter / translation are inserted between the resize and the tensor
        conversion.
        """
        size = (self.config.image_size[0], self.config.image_size[1])
        steps = [transforms.Resize(size)]
        if train and self.config.augmentation:
            steps += [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
        ]
        return transforms.Compose(steps)

    def get_dataloaders(self):
        """Split the image folder and wrap each split in a ``DataLoader``.

        Split sizes are ``int(test_split * N)`` and ``int(val_split * N)``; the
        remainder goes to training. Only the training loader shuffles, and only
        the training split gets the augmenting transform.

        Returns:
            ``(train_loader, val_loader, test_loader)``; each loader's ``dataset``
            is a ``torch.utils.data.Subset``.

        Raises:
            ValueError: if no images are found, or the split fractions are
                negative or leave no training samples.
        """
        import copy

        dataset = TrashNetDataset(
            data_dir=self.config.data_dir,
            transform=None
        )

        dataset_size = len(dataset)
        if dataset_size == 0:
            raise ValueError(f"No images found under {self.config.data_dir!s}")

        test_size = int(self.config.test_split * dataset_size)
        val_size = int(self.config.val_split * dataset_size)
        train_size = dataset_size - val_size - test_size
        if val_size < 0 or test_size < 0 or train_size <= 0:
            raise ValueError(
                f"Invalid split fractions val_split={self.config.val_split}, "
                f"test_split={self.config.test_split}: they must be non-negative "
                f"and leave at least one training sample out of {dataset_size}."
            )

        # Seeds torch's global RNG, which random_split below draws from.
        # NOTE(review): when shuffle is False the split is still random, just
        # unseeded; confirm that is the intended meaning of SHUFFLE_DATASET.
        if self.config.shuffle:
            torch.manual_seed(self.config.random_seed)

        train_dataset, val_dataset, test_dataset = random_split(
            dataset, [train_size, val_size, test_size]
        )

        # random_split returns Subsets that all reference the SAME `dataset`.
        # Setting `.dataset.transform` on each one overwrote a single attribute,
        # so the eval transform silently replaced the training augmentation.
        # Give every split its own shallow copy (sharing the read-only path and
        # label lists) so each keeps an independent transform.
        split_modes = ((train_dataset, True), (val_dataset, False), (test_dataset, False))
        for subset, is_train in split_modes:
            split_view = copy.copy(dataset)
            split_view.transform = self._get_transforms(train=is_train)
            subset.dataset = split_view

        batch_size = self.config.batch_size
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        return train_loader, val_loader, test_loader


In [18]:
from cnnClassifier import logger
try:
    logger.info(">>>>> Data Preprocessing stage started <<<<<")

    # Load configuration (also creates the artifacts directory).
    config = ConfigurationManager()
    data_preprocessing_config = config.get_data_preprocessing_config()

    # Create preprocessing component
    preprocessing = DataPreprocessing(config=data_preprocessing_config)

    # Get dataloaders
    train_loader, val_loader, test_loader = preprocessing.get_dataloaders()

    # Report the number of batches per split.
    print(f"Train loader batches: {len(train_loader)}")
    print(f"Validation loader batches: {len(val_loader)}")
    print(f"Test loader batches: {len(test_loader)}")

    # Sanity check: load one training batch and show its tensor shapes.
    images, labels = next(iter(train_loader))
    print(f"Sample batch - images shape: {images.shape}, labels shape: {labels.shape}")

    logger.info(">>>>> Data Preprocessing stage completed <<<<<")

except Exception as e:
    logger.exception(e)
    # Bare `raise` re-raises the active exception with its original traceback
    # instead of adding this line as an extra re-raise entry.
    raise

[2025-09-28 01:07:30,323: INFO: 2741232552: >>>>> Data Preprocessing stage started <<<<<]
[2025-09-28 01:07:30,326: INFO: common: yaml file: config\config.yaml loaded successfully]
[2025-09-28 01:07:30,328: INFO: common: yaml file: params.yaml loaded successfully]
[2025-09-28 01:07:30,329: INFO: common: created directory at: artifacts]
Train loader batches: 56
Validation loader batches: 16
Test loader batches: 8
Sample batch - images shape: torch.Size([32, 3, 224, 224]), labels shape: torch.Size([32])
[2025-09-28 01:07:30,566: INFO: 2741232552: >>>>> Data Preprocessing stage completed <<<<<]
