In [1]:
import os

# Run the notebook from the repository root so that relative paths such as
# "config/config.yaml" and "schemas/" resolve. The CRASH_DETECTION_ROOT
# environment variable overrides the default checkout location, so the
# notebook is not tied to one machine's absolute path.
PROJECT_ROOT = os.environ.get(
    "CRASH_DETECTION_ROOT",
    "/Users/morizin/Documents/Code/crash-detection-project",
)
os.chdir(PROJECT_ROOT)

In [2]:
# Both locations are relative to the current working directory.
SCHEMA_DIR = "schemas/"  # folder searched for "<dataset name>.yaml" schema files
CONFIG_FILE_PATH = "config/config.yaml"  # default config file for ConfigutationManager

In [6]:
from pydantic import BaseModel, model_validator
from pathlib import Path
from typing import Optional
from box import ConfigBox
import yaml
from glob import glob   
from src.crash_detection import logger
from src.crash_detection.utils.common import load_yaml

class DataSchema(BaseModel):
    """Description of one dataset, completed from its YAML schema file.

    Callers only need to supply ``name`` and ``path``. Once pydantic has built
    the model, ``model_post_init`` loads ``<SCHEMA_DIR>/<name>.yaml`` and
    overwrites every other field with the values found there.
    """

    name: str
    path: Path | str

    # Entries copied from the schema file's "train", "train_image_folder",
    # "test" and "test_image_folder" keys.
    train: Optional[str] = None
    train_image_folder: Optional[str] = None
    test: Optional[str] = None
    test_image_folder: Optional[str] = None

    # Column name -> expected dtype name, plus extra metadata from the schema.
    columns: Optional[dict[str, str]] = None
    categorical: Optional[list[str]] = None
    target: Optional[str] = None
    additional_properties: Optional[dict[str, str]] = None

    def model_post_init(self, __context__):
        """Fill the optional fields from the dataset's schema file.

        Args:
            __context__: Context object passed by pydantic; unused here.

        Raises:
            FileNotFoundError: No ``<SCHEMA_DIR>/<name>.yaml`` exists.
            KeyError: The schema file lacks one of the required keys
                (``train``, ``train_image_folder``, ``test``,
                ``test_image_folder``, ``columns``, ``target``). The exact
                class depends on what ``load_yaml`` returns.
        """
        file_path = os.path.join(SCHEMA_DIR, f"{self.name}.yaml")

        # Guard clause: fail early with a specific exception type. It is still
        # an ``Exception`` subclass, so existing broad handlers keep working.
        if not os.path.exists(file_path):
            e = f"Schema of dataset '{self.name}' not found"
            logger.error(e)
            raise FileNotFoundError(e)

        content = load_yaml(file_path)

        self.train = content["train"]
        self.train_image_folder = content["train_image_folder"]
        self.test = content["test"]
        self.test_image_folder = content["test_image_folder"]

        self.columns = content["columns"]
        # These two keys may be omitted from the schema file.
        self.categorical = content.get("categorical", [])
        self.additional_properties = content.get("additional_properties", {})
        self.target = content["target"]
        # No return value: the hook mutates the instance in place.

class DataValidataionConfig(BaseModel):
    """Settings bundle for the data-validation step.

    Instances are built by ``ConfigutationManager.get_data_validation_config``.
    """
    # Path of the report file; the manager builds it as "<REPORT_NAME>.yaml"
    # under the data directory.
    report_name: str
    # Data directory (artifact path + DATA_DIRECTORY_NAME) holding the datasets.
    indir: Path | str
    # Output directory (artifact path + REPORT_NAME); presumably where report
    # artifacts are written -- TODO confirm.
    outdir: Path | str
    # Switches for optional checks, judging by their names; the manager
    # currently sets all three to False -- TODO confirm their exact meaning.
    pixel_histogram: bool
    statistics: bool
    kl_divergence: bool
    # Dataset name -> schema describing that dataset.
    schemas: dict[str, DataSchema]

In [7]:
from src.crash_detection.utils.common import load_yaml
from src.crash_detection.constants import SCHEMA_DIR, DATA_DIRECTORY_NAME, REPORT_NAME

class ConfigutationManager:
    """Loads the project YAML config and builds typed config objects from it."""

    def __init__(self, config_file_path=CONFIG_FILE_PATH):
        """Load the YAML config and remember its ``artifact-path`` entry."""
        loaded = load_yaml(config_file_path)
        self.config_file = loaded
        self.artifact_path = loaded["artifact-path"]

    def get_data_validation_config(self):
        """Assemble the ``DataValidataionConfig`` for all configured data sources.

        One ``DataSchema`` is created per entry of ``data_sources`` in the
        config, rooted at ``<artifact_path>/<DATA_DIRECTORY_NAME>/<source>``.
        The three optional checks are switched off.
        """
        data_root = os.path.join(self.artifact_path, DATA_DIRECTORY_NAME)

        schemas_by_name = {
            source: DataSchema(name=source, path=os.path.join(data_root, source))
            for source in self.config_file.data_sources
        }

        return DataValidataionConfig(
            report_name=os.path.join(data_root, f"{REPORT_NAME}.yaml"),
            indir=data_root,
            outdir=os.path.join(self.artifact_path, REPORT_NAME),
            pixel_histogram=False,
            statistics=False,
            kl_divergence=False,
            schemas=schemas_by_name,
        )

In [8]:
# Build the manager from the default config file (CONFIG_FILE_PATH), then
# derive the data-validation settings, which also loads each dataset's schema.
cfg = ConfigutationManager()
data_validation_config = cfg.get_data_validation_config()

2026-01-22 02:30:22,237 [INFO] : common - Successfully Loaded YAML file : config/config.yaml
2026-01-22 02:30:22,245 [INFO] : common - Successfully Loaded YAML file : schemas/gta-crash.yaml


In [None]:
import pandas as pd

def load_csv(file_path: Path | str) -> pd.DataFrame:
    return pd.read_csv(file_path)

class DataValidationComponent:
    """Checks each configured dataset's training CSV against its schema."""

    def __init__(self, config: DataValidataionConfig):
        """Store the validation settings; no work is done until ``validate``."""
        self.config = config

    def validate(self):
        """Log a warning for every declared column that is missing or mistyped.

        For each schema in ``self.config.schemas`` the training CSV at
        ``<schema.path>/<schema.train>`` is loaded and every column declared
        in ``schema.columns`` is compared with the loaded column's dtype.
        Findings are only logged; nothing is returned and mismatches do not
        raise.

        A dataset whose schema names no training file is skipped with a
        warning, and a schema without declared columns has no per-column
        checks. Both fields are optional on the schema model.
        """
        for schema_name, schema in self.config.schemas.items():
            logger.info(f"Validating dataset: {schema_name}")

            # `train` is Optional; joining a path with None would raise TypeError.
            if not schema.train:
                logger.warning(
                    f"Dataset {schema_name} has no train file in its schema; skipping"
                )
                continue

            df = load_csv(os.path.join(schema.path, schema.train))
            logger.info(f"Loaded data shape: {df.shape}")

            # `columns` is Optional as well; treat None as "nothing declared".
            declared_columns = schema.columns or {}
            for col, dtype in declared_columns.items():
                if col not in df.columns:
                    logger.warning(f"Column {col} is missing in the dataset")
                elif not pd.api.types.is_dtype_equal(df[col].dtype, dtype):
                    logger.warning(f"Column {col} has dtype {df[col].dtype}, expected {dtype}")
            
        


# Instantiate the validator. This only stores the config; call
# `data_validation.validate()` to run the checks.
data_validation = DataValidationComponent(config=data_validation_config)

SyntaxError: invalid syntax (1295452104.py, line 19)