In [12]:
import os
%pwd

'c:\\Users\\aarav\\Downloads\\VisualSearch\\srcx\\research'

In [13]:
# Step out of the notebook folder so the relative config/params paths used
# by later cells resolve. The guard keeps this cell idempotent: re-running it
# must not climb another level up the directory tree.
if os.path.basename(os.getcwd()) == "research":
    os.chdir('../')

In [44]:
%pwd

'c:\\Users\\aarav\\Downloads\\VisualSearch\\srcx'

In [58]:
from dataclasses import dataclass
from pathlib import Path 

@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Read-only bundle of settings for the base-model preparation stage.

    The dataclass is frozen, so fields cannot be reassigned after construction.
    Values are filled in by ``ConfigurationManager.get_prepare_base_model_config``.
    """

    # Directory for this stage's artifacts; the saved model is written under it.
    root_dir: Path
    # Data location taken from the stage config (not read by the code visible
    # in this notebook -- presumably used by a later stage; verify).
    data_path: Path
    # Forwarded to ResNet50 as ``weights``.
    params_weights: str
    # Forwarded to ResNet50 as ``include_top``.
    params_include_top: bool
    # Forwarded to ResNet50 as ``input_shape``.
    params_image_size: list
    # Joined onto ``root_dir`` to form the path the model is dumped to.
    base_model: Path

In [59]:
from VisualSearch.constants import *
from VisualSearch.utils.common import read_yaml, create_directories

In [60]:
class ConfigurationManager:
    """Loads the project's YAML files and hands out per-stage config objects."""

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH,
        schema_filepath=SCHEMA_FILE_PATH,
    ):
        """Read the three YAML files and create the artifacts root directory.

        Args:
            config_filepath: Path of the main config YAML.
            params_filepath: Path of the params YAML.
            schema_filepath: Path of the schema YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        self.schema = read_yaml(schema_filepath)

        # Make sure the top-level artifacts directory is there before any
        # stage asks for its own sub-directory.
        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Assemble the settings for the base-model preparation stage.

        Creates the stage's ``root_dir`` as a side effect.

        Returns:
            A ``PrepareBaseModelConfig`` combining the ``prepare_base_model``
            section of the config with WEIGHTS, INPUT_SHAPE and INCLUDE_TOP
            from the params file.
        """
        stage = self.config.prepare_base_model
        create_directories([stage.root_dir])

        hyper = self.params
        return PrepareBaseModelConfig(
            root_dir=stage.root_dir,
            data_path=stage.data_path,
            params_weights=hyper.WEIGHTS,
            params_image_size=hyper.INPUT_SHAPE,
            params_include_top=hyper.INCLUDE_TOP,
            base_model=stage.base_model,
        )

In [61]:
import os
from VisualSearch import logger 
import joblib
import tensorflow as tf
from tensorflow.keras.layers import GlobalMaxPooling2D
from tensorflow.keras.applications.resnet50 import ResNet50


In [62]:
class PrepareBaseModel:
    """Builds a frozen, pooled ResNet50 backbone and saves it to disk."""

    def __init__(self, config: PrepareBaseModelConfig):
        """Keep the stage configuration for later use.

        Args:
            config: Supplies the ResNet50 arguments (weights, include_top,
                image size) and the output location (root_dir, base_model).
        """
        self.config = config

    def get_base_model(self):
        """Create the backbone, persist it with joblib, and return it.

        The ResNet50 is marked non-trainable and wrapped in a
        ``tf.keras.Sequential`` followed by ``GlobalMaxPooling2D``. The
        resulting model is stored on ``self.model`` and dumped to
        ``os.path.join(config.root_dir, config.base_model)``.

        Returns:
            The ``tf.keras.Sequential`` model that was written to disk.
        """
        # The config declares the image size as a list; Keras documents
        # ``input_shape`` as a tuple, so normalise it (None is passed through).
        image_size = self.config.params_image_size
        input_shape = tuple(image_size) if image_size is not None else None

        backbone = ResNet50(
            weights=self.config.params_weights,
            include_top=self.config.params_include_top,
            input_shape=input_shape,
        )
        backbone.trainable = False

        # NOTE(review): GlobalMaxPooling2D needs a 4-D feature map, so this
        # presumably relies on include_top being False -- confirm in params.yaml.
        self.model = tf.keras.Sequential([
            backbone,
            GlobalMaxPooling2D(),
        ])

        # NOTE(review): joblib pickles the Keras model; the format is kept
        # as-is so existing consumers that joblib.load this file keep working.
        # Consider the native Keras save format if those consumers can change.
        model_path = os.path.join(self.config.root_dir, self.config.base_model)
        joblib.dump(self.model, model_path)

        return self.model


In [64]:
# Run the base-model preparation stage end to end: load configuration,
# build the stage config, then create and save the model.
try:
    config = ConfigurationManager()
    prepare_base_model_config = config.get_prepare_base_model_config()
    prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)
    prepare_base_model.get_base_model()
except Exception as e:
    # Record the failure in the project log, then re-raise with a bare
    # `raise` so the original traceback is propagated untouched.
    logger.exception(e)
    raise

[2024-02-01 23:03:51,546: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-02-01 23:03:51,548: INFO: common: yaml file: params.yaml loaded successfully]
[2024-02-01 23:03:51,550: INFO: common: yaml file: schema.yaml loaded successfully]
[2024-02-01 23:03:51,551: INFO: common: created directory at: artifacts]
[2024-02-01 23:03:51,554: INFO: common: created directory at: artifacts/prepare_base_model]
