In [1]:
import os
os.chdir("../")

In [2]:
%pwd

'/Users/goldyrana/work/dl/kidney_scans/end_to_end_dl_project_kidney_scans'

In [3]:
from dataclasses import dataclass
import os
from pathlib import Path
from src.common import read_yaml, config_path, params_path

In [4]:
@dataclass
class TrainModelConfig:
    root_dir: Path
    trained_model_path: Path
    updated_base_model_path: Path
    training_data: Path
    params_epochs: int
    params_batch_size: int
    params_is_augmentation: bool
    params_image_size: list

In [5]:
class ConfigurationManager:
    def __init__(self):
        self.config_path = config_path
        self.params_path = params_path
        
        self.config = read_yaml(self.config_path)
        self.params = read_yaml(self.params_path)
        
    def get_training_config(self) -> TrainModelConfig:
        print(self.config)
        training = self.config.training
        prepare_base_model = self.config.prepare_base_model
        params = self.params
        training_data = os.path.join(self.config.data_ingestion.extract_path, "data") 
        
        return TrainModelConfig(
            root_dir=Path(training.root_dir),
            trained_model_path=Path(training.trained_model_path),
            updated_base_model_path=Path(prepare_base_model.updated_base_model_path),
            training_data=Path(training_data),
            params_epochs=params.EPOCHS,
            params_batch_size=params.BATCH_SIZE,
            params_is_augmentation=params.AUGMENTATION,
            params_image_size=params.IMAGE_SIZE
            )

In [6]:
config = ConfigurationManager()
params = config.get_training_config()
params

{'data': 'data/raw', 'data_ingestion': {'root_dir': 'data/', 'source_url': 'https://drive.google.com/file/d/1ZF74F8h_419Lf-jK-k9TRMHY1GazwSQd/view?usp=drive_link', 'zip_path': 'data/raw/data.zip', 'extract_path': 'data/processed'}, 'prepare_base_model': {'root_dir': 'models/prepare_base_model', 'base_model_path': 'models/prepare_base_model/base_model.h5', 'updated_base_model_path': 'models/prepare_base_model/base_model_updated.h5'}, 'training': {'root_dir': 'models/training', 'trained_model_path': 'models/training/model.h5'}}


TrainModelConfig(root_dir=PosixPath('models/training'), trained_model_path=PosixPath('models/training/model.h5'), updated_base_model_path=PosixPath('models/prepare_base_model/base_model_updated.h5'), training_data=PosixPath('data/processed/data'), params_epochs=1, params_batch_size=16, params_is_augmentation=True, params_image_size=BoxList([224, 224, 3]))

In [7]:
import os
import urllib.request as request
from zipfile import ZipFile
import tensorflow as tf
import time

In [24]:
class Training:
    def __init__(self, config: TrainModelConfig):
        self.config = config
    
    @staticmethod
    def save_model(path: Path, model: tf.keras.Model):
        model.save(path)
      
    def get_base_model(self):
        self.model = tf.keras.models.load_model(
            self.config.updated_base_model_path)
        
    def train_valid_generator(self):
       
        data_generator_kwargs = dict(
            target_size=self.config.params_image_size[:-1],
            batch_size=self.config.params_batch_size,
            interpolation="bilinear"
        )
        
        validation_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(
            rescale = 1./255,
            validation_split = 0.20
        )
        
        self.valid_generator=validation_data_gen.flow_from_directory(
            directory=self.config.training_data,
            shuffle=False,
            **data_generator_kwargs)
        
        # keeping same as validataion data gen
        train_generator = validation_data_gen
        self.train_generator=train_generator.flow_from_directory(
            directory=self.config.training_data,
            shuffle=True,
            **data_generator_kwargs)
        
    def train(self):
        self.steps_per_epoch=10
        self.validation_steps=3
        
        self.model.fit(
            self.train_generator,
            epochs=self.config.params_epochs,
            steps_per_epoch=self.steps_per_epoch,
            validation_data=self.valid_generator,
        )
        
        self.save_model(path=self.config.trained_model_path, 
                        model=self.model)

In [25]:
try:
    config = ConfigurationManager()
    training_config = config.get_training_config()
    training = training(training_config)
    training.get_base_model()
    training.train_valid_generator()
    training.train()
except Exception as e:
    raise e

{'data': 'data/raw', 'data_ingestion': {'root_dir': 'data/', 'source_url': 'https://drive.google.com/file/d/1ZF74F8h_419Lf-jK-k9TRMHY1GazwSQd/view?usp=drive_link', 'zip_path': 'data/raw/data.zip', 'extract_path': 'data/processed'}, 'prepare_base_model': {'root_dir': 'models/prepare_base_model', 'base_model_path': 'models/prepare_base_model/base_model.h5', 'updated_base_model_path': 'models/prepare_base_model/base_model_updated.h5'}, 'training': {'root_dir': 'models/training', 'trained_model_path': 'models/training/model.h5'}}
Found 56 images belonging to 2 classes.
Found 56 images belonging to 2 classes.


  saving_api.save_model(
