In [1]:
from torch import optim
import torch as th
from torch import nn
from dataclasses import dataclass
import os
from src.constants import *
from src.utils.common import *
# The notebook presumably lives one level below the project root; later cells
# rely on relative paths (config/config.yaml, params.yaml, data/...), so move to
# the root. Guarded so re-running this cell does not keep climbing the tree.
if not os.path.isdir("src"):
    os.chdir("../")
print(os.getcwd())

/Users/goldyrana/mess/deep_learning/projects/blood_group_detection


In [2]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms

In [5]:
@dataclass
class ImageTransformationConfig:
    """Settings consumed by ImageTransformation to build transforms and DataLoaders.

    Populated by ConfigurationManager from the params YAML
    (``image_transformation`` section) and the config YAML (``data.train_path``).
    """

    # Image params
    # [height, width]; passed straight to transforms.Resize.
    image_shape: list
    # Normalization statistics for transforms.Normalize. These are real-valued;
    # a scalar applies to every channel, and Normalize also accepts per-channel
    # sequences.
    mean: float
    std: float
    # Directory params
    train_path: str
    # DataLoader params
    batch_size: int
    shuffle: bool
    num_workers: int

    
class ConfigurationManager:
    """Reads the project's config and params YAML files and builds typed configs."""

    def __init__(self):
        # read_yaml and the *_FILE_PATH constants come from the wildcard
        # `src` imports. The results are accessed with dot notation below, so
        # read_yaml presumably returns an attribute-style mapping (e.g. ConfigBox).
        self.config = read_yaml(CONFIG_FILE_PATH)
        self.params = read_yaml(PARAMS_FILE_PATH)

    def get_image_transformation_params(self) -> ImageTransformationConfig:
        """Assemble an ImageTransformationConfig.

        Image and DataLoader settings come from ``params.image_transformation``.
        The training directory comes from ``config.data.train_path``.
        """
        section = self.params.image_transformation
        return ImageTransformationConfig(
            # Resize target in [height, width] order.
            image_shape=[section.height, section.width],
            mean=section.mean,
            std=section.std,
            train_path=self.config.data.train_path,
            batch_size=section.batch_size,
            shuffle=section.shuffle,
            num_workers=section.workers,
        )

In [8]:
# Load the image-transformation settings from the YAML files and show them.
# ImageTransformation is defined in a later cell, so it is deliberately not
# instantiated here: doing so raises NameError under Restart & Run All.
# The transformer and loader are built in a later cell, after the class exists.
config_manager = ConfigurationManager()
config_params = config_manager.get_image_transformation_params()
config_params

2025-01-14 22:18:26,426 - root - INFO - Yaml read successfully from config/config.yaml
2025-01-14 22:18:26,429 - root - INFO - Yaml read successfully from params.yaml


{'training': {'epochs': 5, 'batch_size': 32}, 'mlflow_params': {'uri': 'https://dagshub.com/RajeshGoldy/blood_group_detection.mlflow'}, 'image_params': {'height': 100, 'width': 96, 'channels': 1, 'no_categories': 8}, 'image_transformation': {'height': 100, 'width': 96, 'mean': 0, 'std': 1, 'batch_size': 32, 'shuffle': True, 'workers': 4}}


In [7]:

class ImageTransformation:
    """Builds torchvision preprocessing pipelines and DataLoaders from a config."""

    def __init__(self, config):
        """
        Args:
            config: ImageTransformationConfig-like object. This class reads its
                image_shape, mean, std, batch_size, shuffle and num_workers fields.
        """
        self.config = config

    def get_image_transformer(self, data="train"):
        """Return the preprocessing pipeline for a data split.

        Every split currently gets the same deterministic pipeline
        (Resize -> ToTensor -> Normalize); there is no train-time augmentation.
        ``data`` is therefore used only for logging.

        Args:
            data: split name, e.g. "train", "test" or "val".

        Returns:
            A ``transforms.Compose`` pipeline.
        """
        logger.info(f"Initializing transformation on {data} data")
        # NOTE(review): with mean=0 and std=1 (the values in the params output
        # above), Normalize is an identity. ImageFolder's default loader yields
        # RGB images while params list image_params.channels=1. If the model
        # expects a single channel, a transforms.Grayscale step is probably
        # needed. Confirm.
        return transforms.Compose([
            transforms.Resize(self.config.image_shape),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.config.mean, std=self.config.std),
        ])

    def get_data_loader(self, image_folder_path, transform, shuffle=None):
        """Wrap an ImageFolder directory in a DataLoader.

        Args:
            image_folder_path: root directory laid out as one sub-folder per class.
            transform: pipeline applied to each image (see get_image_transformer).
            shuffle: overrides ``config.shuffle`` when given. Pass False for
                validation/test loaders so evaluation order is stable.

        Returns:
            A ``DataLoader`` using the configured batch size and worker count.
        """
        if shuffle is None:
            shuffle = self.config.shuffle

        image_dataset = ImageFolder(image_folder_path, transform=transform)
        return DataLoader(
            image_dataset,
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            num_workers=self.config.num_workers,
        )
        

In [9]:
if __name__ == "__main__":
    # Inside a notebook __name__ is always "__main__", so this always runs.
    # The guard only matters once this code is moved into a script/module.
    config_manager = ConfigurationManager()
    config_params = config_manager.get_image_transformation_params()

    ob = ImageTransformation(config_params)
    train_transformer = ob.get_image_transformer(data="train")
    # Read the training directory from config (data.train_path) rather than
    # hard-coding it here, so the config stays the single source of truth.
    train_loader = ob.get_data_loader(config_params.train_path, train_transformer)

2025-01-14 22:18:55,246 - root - INFO - Yaml read successfully from config/config.yaml
2025-01-14 22:18:55,250 - root - INFO - Yaml read successfully from params.yaml
2025-01-14 22:18:55,251 - root - INFO - Initializing transformation on train data


{'training': {'epochs': 5, 'batch_size': 32}, 'mlflow_params': {'uri': 'https://dagshub.com/RajeshGoldy/blood_group_detection.mlflow'}, 'image_params': {'height': 100, 'width': 96, 'channels': 1, 'no_categories': 8}, 'image_transformation': {'height': 100, 'width': 96, 'mean': 0, 'std': 1, 'batch_size': 32, 'shuffle': True, 'workers': 4}}
