In [1]:
import os 

In [2]:
%pwd

'd:\\Production\\projects\\chest-cancer-classification-p5\\notebook'

In [3]:
# The notebook runs from <project>/notebook; move to the project root so the
# relative config/artifact paths resolve. Guarded so re-running this cell does
# not climb above the project root.
if os.path.basename(os.getcwd()) == "notebook":
    os.chdir("../")

In [22]:
%pwd

'd:\\Production\\projects\\chest-cancer-classification-p5'

### Update the entity

In [23]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class DataIngestionConfig:
    """Immutable settings for the data-ingestion stage.

    Path-typed fields are coerced to ``pathlib.Path`` on construction, so
    callers may pass plain strings (e.g. values read from YAML) and still get
    the declared types. ``None`` is left as-is.
    """

    root_dir: Path  # working directory for ingestion artifacts
    source_url: str  # Google Drive share URL (gdrive) or dataset slug (kaggle)
    local_data_file: Path  # where the downloaded file / dataset ends up
    unzip_dir: Path  # extraction target when local_data_file is a .zip

    def __post_init__(self) -> None:
        # frozen=True blocks normal attribute assignment, so use object.__setattr__.
        for field_name in ("root_dir", "local_data_file", "unzip_dir"):
            value = getattr(self, field_name)
            if value is not None:
                object.__setattr__(self, field_name, Path(value))

    

### Update the configuration manager in src config

In [None]:
from src.brain_tumor_classification.constants import *
from src.brain_tumor_classification.utils.common import read_yaml, create_directories

In [25]:
class ConfigurationManager:
    """Reads the project YAML files and hands out per-stage config objects."""

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH
    ):
        # Parsed contents of the main config file and the params file.
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_data_ingestion_config(self, source_type: str) -> DataIngestionConfig:
        """Build the data-ingestion config for one download source.

        Args:
            source_type: 'gdrive' or 'kaggle'; selects the
                ``data_ingestion_gdrive`` / ``data_ingestion_kaggle`` section
                of the loaded config.

        Returns:
            DataIngestionConfig populated from the selected section.

        Raises:
            ValueError: if ``source_type`` is not a supported value.
        """
        known_sections = (
            ("gdrive", "data_ingestion_gdrive"),
            ("kaggle", "data_ingestion_kaggle"),
        )
        for candidate, section_name in known_sections:
            if source_type == candidate:
                section = getattr(self.config, section_name)
                break
        else:
            raise ValueError(f"Invalid source_type: {source_type}, choose 'gdrive' or 'kaggle'")

        # Make sure the stage's root directory exists before the config is handed out.
        create_directories([section.root_dir])

        return DataIngestionConfig(
            root_dir=section.root_dir,
            source_url=section.source_url,
            local_data_file=section.local_data_file,
            unzip_dir=section.unzip_dir,
        )

### Update the components

In [None]:
import os
import zipfile
import gdown
import kagglehub
import shutil
from src.brain_tumor_classification import logger

In [27]:
class DataIngestion:
    """Download a dataset from Google Drive or Kaggle and unpack it if it is a zip."""

    def __init__(self, config: DataIngestionConfig, source: str):
        """
        Args:
            config: paths and source URL/slug for this ingestion run.
            source: 'gdrive' or 'kaggle'.
        """
        self.config = config
        self.source = source
        # Kaggle only: the kagglehub cache path returned by the last download.
        self.downloaded_path = None

    def download(self) -> str:
        """Download the dataset from the configured source.

        Returns:
            ``config.local_data_file``, i.e. where the data now lives.

        Raises:
            ValueError: if ``source`` is neither 'gdrive' nor 'kaggle'.
        """
        if self.source == "gdrive":
            self._download_from_gdrive()
        elif self.source == "kaggle":
            self._download_from_kaggle()
        else:
            raise ValueError(f"Unsupported source: {self.source}")
        return self.config.local_data_file

    @staticmethod
    def _gdrive_file_id(url: str) -> str:
        """Extract the file id from a Google Drive share or download URL.

        Handles ``.../file/d/<id>/...`` and ``...?id=<id>`` forms, and falls back
        to the previous ``split('/')[-2]`` rule for anything else.
        """
        from urllib.parse import parse_qs, urlparse

        parsed = urlparse(url)
        query_ids = parse_qs(parsed.query).get("id")
        if query_ids:
            return query_ids[0]
        parts = parsed.path.split("/")
        if "d" in parts:
            idx = parts.index("d")
            if idx + 1 < len(parts) and parts[idx + 1]:
                return parts[idx + 1]
        return url.split("/")[-2]

    def _download_from_gdrive(self):
        """Download a single Google Drive file to ``config.local_data_file``."""
        dataset_url = self.config.source_url
        zip_download_path = self.config.local_data_file
        os.makedirs(self.config.root_dir, exist_ok=True)

        logger.info(f"Downloading data from {dataset_url} into file {zip_download_path}")
        file_id = self._gdrive_file_id(dataset_url)
        # Direct-download endpoint (the old prefix had a stray '/' after '?').
        prefix = "https://drive.google.com/uc?export=download&id="
        downloaded = gdown.download(prefix + file_id, str(zip_download_path), quiet=False)
        # NOTE(review): some gdown releases return None on failure instead of
        # raising; fail here rather than later in extract_zip_file.
        if downloaded is None:
            raise RuntimeError(f"gdown failed to download {dataset_url}")
        logger.info(f"Downloaded data from {dataset_url} into file {zip_download_path}")

    def _download_from_kaggle(self):
        """Fetch a Kaggle dataset with kagglehub and move it to ``config.local_data_file``."""
        dataset_slug = self.config.source_url
        target = self.config.local_data_file
        os.makedirs(self.config.root_dir, exist_ok=True)

        logger.info(f"Downloading Kaggle dataset '{dataset_slug}'")
        self.downloaded_path = kagglehub.dataset_download(dataset_slug)
        logger.info(f"Dataset downloaded to cache: {self.downloaded_path}")

        # A previous run may already have moved the cached files away while
        # kagglehub still reports the cache path (NOTE(review): depends on
        # kagglehub's cache bookkeeping — confirm). Check the source BEFORE
        # deleting the target so a re-run never destroys the only copy.
        if not os.path.exists(self.downloaded_path):
            if os.path.exists(target):
                logger.info(f"Cache path {self.downloaded_path} is gone; reusing existing data at {target}")
                return
            raise FileNotFoundError(
                f"kagglehub returned {self.downloaded_path}, but it does not exist "
                f"and {target} is missing too; clear the kagglehub cache and retry."
            )

        # Move instead of copy (avoid duplicates). Remove any stale target first;
        # it may be a directory from an earlier run or a plain file/symlink.
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target)
        elif os.path.lexists(target):
            os.remove(target)
        shutil.move(self.downloaded_path, target)

        logger.info(f"Kaggle dataset moved to {self.config.local_data_file}")

    def extract_zip_file(self):
        """Extract ``config.local_data_file`` into ``config.unzip_dir`` when it is a .zip.

        Any other path (e.g. a moved Kaggle dataset directory) is left as-is.
        """
        if not str(self.config.local_data_file).endswith(".zip"):
            logger.info("No extraction needed (not a zip file).")
            return

        unzip_path = self.config.unzip_dir
        os.makedirs(unzip_path, exist_ok=True)
        # zipfile strips absolute paths and '..' components from member names
        # during extraction, so entries stay under unzip_path.
        with zipfile.ZipFile(self.config.local_data_file, "r") as zip_ref:
            zip_ref.extractall(unzip_path)

        logger.info(f"Extracted zip file to {unzip_path}")


### Update the pipeline

In [None]:
# Choose the raw-data source once so the config section and the downloader agree.
DATA_SOURCE = "kaggle"  # "kaggle" or "gdrive"

config = ConfigurationManager()
data_ingestion_cfg = config.get_data_ingestion_config(DATA_SOURCE)
data_ingestion = DataIngestion(config=data_ingestion_cfg, source=DATA_SOURCE)
data_ingestion.download()
data_ingestion.extract_zip_file()