In [5]:
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List

In [2]:
# Move from the notebook's directory up one level (presumably the project root,
# so relative config/params paths resolve — TODO confirm). Guarded by a sentinel
# so that re-running this cell does not keep climbing the directory tree.
if "PROJECT_ROOT" not in globals():
    os.chdir("..")
    PROJECT_ROOT = Path.cwd()

In [6]:
@dataclass(frozen=True)
class BaseModelConfig:
    """Immutable settings for building and saving the VGG16-based classifier.

    Instances are built by ``ConfigManager.get_base_model_config`` from the
    project's config and params YAML files.
    """

    root_dir: Path                    # directory both model files are written to
    base_model_filename: str          # file name for the bare VGG16 backbone
    updated_base_model_filename: str  # file name for backbone + classification head
    augmentation: bool                # not read in this notebook section (presumably used in training)
    image_size: List[int]             # VGG16 `input_shape`, e.g. [224, 224, 3]
    batch_size: int                   # not read in this notebook section
    include_top: bool                 # whether VGG16 keeps its fully-connected top layers
    epochs: int                       # not read in this notebook section
    classes: int                      # number of units in the softmax output layer
    weights: str                      # VGG16 `weights` argument (e.g. "imagenet")
    learning_rate: float              # learning rate for the SGD optimizer


In [7]:
from cnnClassifier.constants import CONFIG_FILE_PATH, PARAMS_FILE_PATH
from cnnClassifier.utils.common import read_yaml, create_directories

In [11]:
class ConfigManager:
    """Read the project's config/params YAML files and build typed config objects.

    NOTE(review): ``read_yaml`` is assumed to return an object with attribute
    access (e.g. a ConfigBox); the methods below rely on ``self.config.base_model``
    and ``self.params.IMAGE_SIZE``-style lookups — confirm against
    ``cnnClassifier.utils.common``.
    """

    def __init__(self,
                 config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):
        """Load both YAML files.

        Args:
            config_filepath: Path of the config YAML (paths and file names).
            params_filepath: Path of the params YAML (model hyper-parameters).
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

    def get_base_model_config(self) -> BaseModelConfig:
        """Build the base-model settings and ensure its output directory exists.

        Returns:
            A ``BaseModelConfig`` combining the ``base_model`` section of the
            config file with hyper-parameters from the params file.
        """
        config = self.config.base_model

        create_directories([config.root_dir])

        return BaseModelConfig(
            # The YAML value is presumably a plain string; convert it so the
            # field matches the ``Path`` annotation on BaseModelConfig.root_dir.
            root_dir=Path(config.root_dir),
            base_model_filename=config.base_model_filename,
            updated_base_model_filename=config.updated_base_model_filename,
            augmentation=self.params.AUGMENTATION,
            image_size=self.params.IMAGE_SIZE,
            batch_size=self.params.BATCH_SIZE,
            include_top=self.params.INCLUDE_TOP,
            epochs=self.params.EPOCHS,
            classes=self.params.CLASSES,
            weights=self.params.WEIGHTS,
            learning_rate=self.params.LEARNING_RATE,
        )

In [14]:
import os, sys
import urllib.request as request
from zipfile import ZipFile
import tensorflow as tf

from cnnClassifier.exception import CustomException
from cnnClassifier.logger import logging

In [19]:
class BaseModel:
    """Build a VGG16 backbone, attach a small classification head, and save both.

    Call ``get_base_model()`` first, then ``update_full_model()``.
    """

    def __init__(self, config: BaseModelConfig):
        """Store the settings; the models themselves are built by later calls.

        Args:
            config: Model hyper-parameters and output locations.
        """
        self.config = config
        self.model = None       # VGG16 backbone, set by get_base_model()
        self.full_model = None  # backbone + head, set by update_full_model()

    @staticmethod
    def save_model(model: tf.keras.Model, path: "str | Path"):
        """Save ``model`` to ``path`` via Keras ``Model.save``.

        Args:
            model: The Keras model to persist.
            path: Destination file path (callers in this class pass a ``str``).
        """
        model.save(path)

    def get_base_model(self):
        """Instantiate VGG16 from the config and save it.

        The model is stored on ``self.model`` and written to
        ``<root_dir>/<base_model_filename>``.
        """
        self.model = tf.keras.applications.VGG16(
            input_shape=self.config.image_size,
            weights=self.config.weights,
            include_top=self.config.include_top,
        )
        path = os.path.join(self.config.root_dir, self.config.base_model_filename)
        self.save_model(model=self.model, path=path)

    def _prepare_full_model(self, freeze_all=True, freeze_till=None):
        """Freeze backbone layers, add a classification head, and compile.

        Args:
            freeze_all: If True, every backbone layer is made non-trainable.
            freeze_till: Used only when ``freeze_all`` is False. When it is a
                positive int, all backbone layers except the last ``freeze_till``
                are frozen. When it is None or <= 0, nothing is frozen.

        Returns:
            The compiled ``tf.keras.Model`` (backbone -> Flatten -> Dense(16, relu)
            -> Dense(classes, softmax)).

        Raises:
            RuntimeError: If ``get_base_model()`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError(
                "Base model has not been built; call get_base_model() first."
            )

        if freeze_all:
            for layer in self.model.layers:
                layer.trainable = False
        elif freeze_till is not None and freeze_till > 0:
            # Keep only the last `freeze_till` layers trainable.
            for layer in self.model.layers[:-freeze_till]:
                layer.trainable = False

        flattened = tf.keras.layers.Flatten()(self.model.output)
        hidden = tf.keras.layers.Dense(16, activation="relu")(flattened)
        output = tf.keras.layers.Dense(self.config.classes,
                                       activation="softmax")(hidden)

        full_model = tf.keras.Model(inputs=self.model.input,
                                    outputs=output)

        # CategoricalCrossentropy presumably expects one-hot labels from the
        # data pipeline — TODO confirm against the training generator.
        full_model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=self.config.learning_rate),
            loss=tf.keras.losses.CategoricalCrossentropy(),
            metrics=["accuracy"],
        )
        full_model.summary()
        return full_model

    def update_full_model(self, freeze_all=True, freeze_till=None):
        """Build the full classifier and save it.

        Args:
            freeze_all: Forwarded to ``_prepare_full_model``.
            freeze_till: Forwarded to ``_prepare_full_model``.

        The result is stored on ``self.full_model`` and written to
        ``<root_dir>/<updated_base_model_filename>``.
        """
        self.full_model = self._prepare_full_model(freeze_all=freeze_all,
                                                   freeze_till=freeze_till)
        path = os.path.join(self.config.root_dir, self.config.updated_base_model_filename)
        self.save_model(model=self.full_model, path=path)
    


In [20]:
try:
    config = ConfigManager()
    base_model_config = config.get_base_model_config()
    base_model = BaseModel(base_model_config)
    base_model.get_base_model()
    base_model.update_full_model()

except Exception as e:
    # Catch every failure (YAML/IO/TensorFlow), not only CustomException, so it
    # is logged and re-raised in the project's wrapped form. The original
    # exception is chained so its traceback is preserved.
    error = CustomException(e, sys)
    logging.error(error.error_message)
    raise error from e

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0   

  saving_api.save_model(
