In [1]:
import os

In [2]:
%pwd

'd:\\Production\\projects\\brain-tumor-classification-p5\\notebook'

In [3]:
# Move from notebook/ to the project root so relative config/artifact paths resolve.
# Guarded so re-running this cell (or Run All after a partial run) does not keep
# climbing the directory tree.
if os.path.basename(os.getcwd()) == "notebook":
    os.chdir("../")


In [4]:
%pwd

'd:\\Production\\projects\\brain-tumor-classification-p5'

In [5]:
import torch
import torch.nn as nn
import timm
from dataclasses import dataclass
from pathlib import Path
from src.brain_tumor_classification.utils.common import read_yaml, create_directories
from src.brain_tumor_classification.constants import *
from src.brain_tumor_classification import logger

  from .autonotebook import tqdm as notebook_tqdm


### Entity: stage configuration dataclass

In [6]:
@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Immutable settings for the base-model preparation stage.

    Built by ``ConfigurationManager.prepare_base_model_config`` from the
    ``prepare_base_model`` section of the config YAML (paths) and the
    ``base_model`` section of the params YAML (model hyper-parameters).
    """

    root_dir: Path  # stage artifact directory; created by the configuration manager
    base_model_path: Path  # where the pretrained model (with its new head) is saved
    updated_model_path: Path  # where the model is saved after optional freezing
    model_name: str  # timm model identifier passed to timm.create_model
    image_size: list  # not read in the visible cells; layout presumably [H, W, C] — TODO confirm in params.yaml
    num_classes: int  # output size of the classifier head passed to timm.create_model
    learning_rate: float  # not used in the visible cells; presumably consumed by a later training stage
    weight_decay: float  # not used in the visible cells; presumably consumed by a later training stage
    freeze_base: bool  # if True, every parameter except the classifier head is frozen

### Configuration manager

In [7]:
class ConfigurationManager:
    """Load the project's config and params YAML files and build typed stage configs.

    Args:
        config_filepath: Path to the main config YAML (defaults to ``CONFIG_FILE_PATH``).
        params_filepath: Path to the hyper-parameter YAML (defaults to ``PARAMS_FILE_PATH``).

    Side effects:
        Creates the ``artifacts_root`` directory named in the config file.
    """

    def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_directories([self.config.artifacts_root])

    def prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Build the ``PrepareBaseModelConfig`` for the base-model stage.

        Values are coerced to the types the dataclass declares. Path strings
        become ``Path`` objects. Numeric hyper-parameters are cast explicitly
        because YAML 1.1 loaders such as PyYAML parse exponent literals like
        ``1e-4`` (no decimal point) as strings rather than floats.

        Returns:
            PrepareBaseModelConfig: paths and hyper-parameters for this stage.

        Raises:
            ValueError: if a numeric hyper-parameter cannot be parsed as a number.

        Side effects:
            Creates the stage's ``root_dir``.
        """
        config = self.config.prepare_base_model
        params = self.params.base_model
        create_directories([config.root_dir])

        return PrepareBaseModelConfig(
            root_dir=Path(config.root_dir),
            base_model_path=Path(config.base_model_path),
            updated_model_path=Path(config.updated_model_path),
            model_name=str(params.model_name),
            image_size=params.image_size,
            num_classes=int(params.num_classes),
            learning_rate=float(params.learning_rate),
            weight_decay=float(params.weight_decay),
            # Deliberately not cast: bool("false") is True; rely on YAML's native booleans.
            freeze_base=params.freeze_base,
        )

### Component: prepare the base model

In [8]:
class PrepareBaseModel:
    """Create a pretrained timm model, optionally freeze its backbone, and save it.

    Intended call order: ``load_model()`` -> ``freeze_base_layers()`` -> ``save_updated_model()``.

    NOTE(review): models are saved as whole pickled modules (not ``state_dict``),
    so loaders need timm importable and, on recent PyTorch releases where
    ``torch.load`` defaults to ``weights_only=True``, must pass
    ``weights_only=False``. Consider saving ``state_dict`` instead.

    Args:
        config: Paths and hyper-parameters for this stage.
    """

    def __init__(self, config: PrepareBaseModelConfig):
        self.config = config
        self.model = None  # populated by load_model()

    def _require_model(self) -> nn.Module:
        """Return the loaded model, or raise if ``load_model`` has not been called."""
        if self.model is None:
            raise RuntimeError("Model is not loaded; call load_model() first.")
        return self.model

    def _save(self, path) -> None:
        """Pickle the loaded model to ``path``, creating parent directories as needed."""
        model = self._require_model()
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model, path)

    def load_model(self):
        """Create the pretrained model with a ``num_classes`` head and save it.

        Passing ``num_classes`` to ``timm.create_model`` means the pretrained
        classifier weights are not loaded for the new head (the
        "Missing keys (head.fc.weight, head.fc.bias)" log line is expected).

        Side effects:
            Writes the model to ``config.base_model_path``.
        """
        logger.info(f"Loading pretrained model: {self.config.model_name}")
        self.model = timm.create_model(
            self.config.model_name,
            pretrained=True,
            num_classes=self.config.num_classes
        )
        logger.info(f"Model {self.config.model_name} loaded successfully.")

        # Save raw base model (pretrained backbone + new head) before any freezing.
        self._save(self.config.base_model_path)
        logger.info(f"Base model saved at: {self.config.base_model_path}")

    def freeze_base_layers(self):
        """Freeze every parameter except the classifier head when ``config.freeze_base`` is set.

        Raises:
            RuntimeError: if freezing is requested but the model is not loaded.
        """
        if not self.config.freeze_base:
            logger.info("freeze_base is False; all layers remain trainable.")
            return

        model = self._require_model()
        logger.info("Freezing base layers...")
        for param in model.parameters():
            param.requires_grad = False
        # Re-enable gradients on the classifier so fine-tuning still updates the head.
        for param in model.get_classifier().parameters():
            param.requires_grad = True
        logger.info("Base layers frozen, classifier head trainable.")

        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        logger.info(f"Trainable parameters: {trainable:,} / {total:,}")

    def save_updated_model(self):
        """Save the (possibly frozen) model to ``config.updated_model_path``.

        Raises:
            RuntimeError: if the model has not been loaded.
        """
        self._save(self.config.updated_model_path)
        logger.info(f"Updated model saved at: {self.config.updated_model_path}")


In [9]:
# Run the base-model preparation stage end to end.
try:
    config = ConfigurationManager()
    prepare_base_model_config = config.prepare_base_model_config()
    prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)

    prepare_base_model.load_model()
    prepare_base_model.freeze_base_layers()
    prepare_base_model.save_updated_model()

except Exception:
    # Record the full traceback in the project log, then re-raise unchanged.
    # A bare `raise` keeps the original traceback; `raise e` would add this frame.
    logger.exception("Base model preparation stage failed.")
    raise

[2025-08-31 18:26:17,716: INFO: common: yaml file: config\config.yaml loaded successfully]
[2025-08-31 18:26:17,720: INFO: common: yaml file: config\params.yaml loaded successfully]
[2025-08-31 18:26:17,722: INFO: common: created directory at: artifacts]
[2025-08-31 18:26:17,726: INFO: common: created directory at: artifacts/prepare_base_model]
[2025-08-31 18:26:17,727: INFO: 4290034370: Loading pretrained model: swin_tiny_patch4_window7_224]
[2025-08-31 18:26:18,111: INFO: _builder: Loading pretrained weights from Hugging Face hub (timm/swin_tiny_patch4_window7_224.ms_in1k)]
[2025-08-31 18:26:18,782: INFO: _hub: [timm/swin_tiny_patch4_window7_224.ms_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.]
[2025-08-31 18:26:18,979: INFO: _builder: Missing keys (head.fc.weight, head.fc.bias) discovered while loading pretrained weights. This is expected if model is being adapted.]
[2025-08-31 18:26:19,000: INFO: 4290034370: Mo