In [16]:
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List

In [2]:
# Later cells presumably resolve relative paths (e.g. the YAML config files) from the
# project root, taken here to be the parent of the directory the notebook was launched from.
# Record the launch directory only once so that re-running this cell is idempotent;
# a bare os.chdir("..") would move one level higher on every run.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = Path.cwd()
os.chdir(NOTEBOOK_START_DIR.parent)

In [17]:
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable settings for the model-training stage.

    NOTE(review): the ``Path`` annotations are not enforced by ``dataclass``; the
    values actually passed in come straight from the YAML config and may be plain
    strings — confirm against ``read_yaml``.
    """

    # Directory the trained model is saved into (joined with model_filename).
    root_dir: Path
    # Image directory read with flow_from_directory (one sub-directory per class).
    training_data_path: Path
    # Location of the model loaded with tf.keras.models.load_model before fitting.
    pretrained_model_path: Path
    # File name of the trained model inside root_dir.
    model_filename: str
    # Whether random augmentation is applied to the training subset.
    augmentation: bool
    # Number of images per batch for both training and validation flows.
    batch_size: int
    # Number of training epochs.
    epochs: int
    # Full input shape; the last element (presumably the channel count) is dropped
    # to form the flow target_size — e.g. [224, 224, 3] -> [224, 224].
    image_size: List[int]

In [18]:
from cnnClassifier.constants import CONFIG_FILE_PATH, PARAMS_FILE_PATH
from cnnClassifier.utils.common import read_yaml, create_directories

In [19]:
class ConfigManager:
    """Reads the project YAML files and turns them into per-stage config objects."""

    def __init__(self,
                 config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):
        """Load the pipeline configuration and the hyper-parameter file.

        Args:
            config_filepath: YAML file holding per-stage settings such as the
                ``training`` section (defaults to CONFIG_FILE_PATH).
            params_filepath: YAML file holding hyper-parameters such as
                ``BATCH_SIZE`` and ``EPOCHS`` (defaults to PARAMS_FILE_PATH).
        """
        # read_yaml's result is accessed attribute-style below (e.g. .training.root_dir).
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

    def get_training_config(self) -> TrainingConfig:
        """Build a TrainingConfig from the ``training`` section and the hyper-parameters.

        Returns:
            TrainingConfig with paths taken from ``config.training`` and
            augmentation / batch size / epochs / image size taken from params.
        """
        section = self.config.training
        hyper = self.params
        return TrainingConfig(
            root_dir=section.root_dir,
            training_data_path=section.training_data_path,
            pretrained_model_path=section.pretrained_model_path,
            model_filename=section.model_filename,
            augmentation=hyper.AUGMENTATION,
            batch_size=hyper.BATCH_SIZE,
            epochs=hyper.EPOCHS,
            image_size=hyper.IMAGE_SIZE,
        )

In [20]:
import os
import tensorflow as tf
from cnnClassifier.exception import CustomException
from cnnClassifier.logger import logging

In [21]:
class Training:
    """Fine-tune a pre-trained Keras model on an image directory and save the result.

    Usage (as in the runner cell)::

        training = Training(config=training_config)
        training._train_valid_generator()
        training.train()
    """

    def __init__(self, config: TrainingConfig):
        """Store ``config`` and load the pre-trained model immediately.

        Args:
            config: Training-stage settings (paths, augmentation flag, batch size,
                epochs, image size).
        """
        self.config = config
        self.model = None            # set by get_base_model()
        self.train_generator = None  # set by _train_valid_generator()
        self.valid_generator = None  # set by _train_valid_generator()
        self.get_base_model()

    def get_base_model(self):
        """Load the model stored at ``config.pretrained_model_path`` into ``self.model``."""
        self.model = tf.keras.models.load_model(self.config.pretrained_model_path)

    @staticmethod
    def save_model(model: tf.keras.Model, path: Path):
        """Save ``model`` to ``path``, creating the parent directory if it is missing.

        Args:
            model: The Keras model to persist.
            path: Destination (str or path-like).
        """
        parent_dir = os.path.dirname(os.fspath(path))
        if parent_dir:
            # Nothing earlier creates config.root_dir, and model.save may not
            # create missing parent directories on its own.
            os.makedirs(parent_dir, exist_ok=True)
        model.save(path)

    def _train_valid_generator(self):
        """Build ``self.valid_generator`` and ``self.train_generator`` from one directory.

        Both generators rescale pixels by 1/255 and use ``validation_split=0.2``, so the
        ``"validation"`` and ``"training"`` subsets split the same image directory.
        The validation flow is never augmented and not shuffled; the training flow is
        shuffled and, when ``config.augmentation`` is true, randomly augmented.
        """
        data_generator_kwargs = dict(
            rescale=1./255,
            validation_split=0.2,
        )

        valid_data_generator = tf.keras.preprocessing.image.ImageDataGenerator(**data_generator_kwargs)

        dataflow_kwargs = dict(
            # image_size carries a trailing element (presumably channels) that
            # target_size must not include.
            target_size=self.config.image_size[:-1],
            batch_size=self.config.batch_size,
            interpolation="bilinear",
        )

        self.valid_generator = valid_data_generator.flow_from_directory(
            self.config.training_data_path,
            subset="validation",
            shuffle=False,
            **dataflow_kwargs
        )

        if self.config.augmentation:
            train_data_generator = tf.keras.preprocessing.image.ImageDataGenerator(
                rotation_range=15,
                horizontal_flip=True,
                width_shift_range=0.2,
                height_shift_range=0.2,
                shear_range=0.1,
                zoom_range=0.1,
                # Same rescale/validation_split so the subsets stay consistent.
                **data_generator_kwargs
            )
        else:
            train_data_generator = valid_data_generator

        self.train_generator = train_data_generator.flow_from_directory(
            directory=self.config.training_data_path,
            subset="training",
            shuffle=True,
            **dataflow_kwargs
        )

    def train(self):
        """Fit ``self.model`` for ``config.epochs`` epochs, then save it.

        Builds the data generators first if ``_train_valid_generator`` has not been
        called yet. The fitted model is written to
        ``os.path.join(config.root_dir, config.model_filename)``.

        Raises:
            ValueError: if the training subset contains no images.
        """
        if self.train_generator is None or self.valid_generator is None:
            self._train_valid_generator()

        if self.train_generator.samples == 0:
            raise ValueError(
                f"No training images found in {self.config.training_data_path!s}"
            )

        # Ceil division: the final partial batch is used every epoch, and a subset
        # smaller than batch_size gets one step instead of zero.
        train_steps = math.ceil(self.train_generator.samples / self.train_generator.batch_size)

        # The validation subset can be empty (e.g. very few images per class);
        # train without validation then instead of passing validation_steps=0.
        if self.valid_generator.samples > 0:
            validation_data = self.valid_generator
            validation_steps = math.ceil(self.valid_generator.samples / self.valid_generator.batch_size)
        else:
            validation_data = None
            validation_steps = None

        self.model.fit(
            self.train_generator,
            epochs=self.config.epochs,
            steps_per_epoch=train_steps,
            validation_data=validation_data,
            validation_steps=validation_steps,
        )

        path = os.path.join(self.config.root_dir, self.config.model_filename)
        self.save_model(model=self.model, path=path)


    


In [22]:
# Run the training stage end to end: build the config, load the pre-trained model,
# create the train/validation generators, fit, and save the trained model.
# Exceptions propagate unchanged so the full traceback shows in the cell output.
config = ConfigManager()
training_config = config.get_training_config()
training = Training(config=training_config)
training._train_valid_generator()
training.train()

Found 68 images belonging to 2 classes.
Found 275 images belonging to 2 classes.
Epoch 1/10


2024-07-21 17:28:49.270476: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2024-07-21 17:28:52.525659: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
