In [1]:
import os
os.chdir("../")
%pwd

'/Users/isma/Documents/Portfolio/kitchen_robot'

In [4]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelTrainingConfig:
    root_dir: Path
    feature_size: int
    target_size: int
    batch_size: int
    epochs: int
    train_eval_dataset: Path
    base_model_path: Path
    updated_model_path: Path

In [5]:
from kitchen_robot.constants import *
# from kitchen_robot.entity.config_entity import (
#     PrepareDatasetConfig,
# )
from kitchen_robot.utils.common import read_yaml, create_directories


class ConfigurationManager:
    def __init__(
        self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH
    ):

        print(config_filepath)

        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])

    def get_model_training_config(self) -> ModelTrainingConfig:
        config = self.config.model_training
        create_directories([config.root_dir])

        model_training_config = ModelTrainingConfig(
            root_dir=config.root_dir,
            feature_size=self.config.prepare_model.feature_size,
            target_size=self.config.prepare_model.target_size,
            batch_size=self.params.batch_size,
            epochs=self.params.epochs,
            train_eval_dataset=config.train_eval_dataset,
            base_model_path=config.base_model_path,
            updated_model_path=config.updated_model_path,
        )

        return model_training_config

In [7]:
import torch
from torch.utils.data import Subset, DataLoader, random_split
from kitchen_robot.components.prepare_model import PredictionModel

In [13]:
class ModelTraining:
    def __init__(self, config: ModelTrainingConfig):
        self.config = config

    def get_base_model(self):
        self.model = PredictionModel(self.config.feature_size, self.config.target_size)
        self.model.load_state_dict(torch.load(self.config.base_model_path))

    def train_valid_generator(self):
        loaded_dataset = torch.load(self.config.train_eval_dataset)

        train_size = int(0.9 * len(loaded_dataset))
        eval_size = len(loaded_dataset) - train_size

        train_dataset, eval_dataset = random_split(loaded_dataset, [train_size, eval_size])

        self.train_loader = DataLoader(train_dataset, batch_size=self.config.batch_size, shuffle=True)
        self.eval_loader = DataLoader(eval_dataset, batch_size=self.config.batch_size)

    @staticmethod
    def save_model(path, model):
        torch.save(model.state_dict(), path)

    def train(self):
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.model.parameters())

        for epoch in range(self.config.epochs):
            self.model.train()
            running_loss = 0.0

            for inputs, targets in self.train_loader:
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()

            epoch_loss = running_loss / len(self.train_loader)
            print(f"Epoch {epoch+1}, Loss: {epoch_loss}")

        self.save_model(self.config.updated_model_path, self.model)
        print("Training complete")

In [14]:
try:
    config = ConfigurationManager()
    training_config = config.get_model_training_config()
    training = ModelTraining(config=training_config)
    training.get_base_model()
    training.train_valid_generator()
    training.train()
    
except Exception as e:
    raise e

config/config.yaml
[2024-02-27 23:46:33,109: INFO: common: created directory at: artifacts]
[2024-02-27 23:46:33,110: INFO: common: created directory at: artifacts/model_training]
Epoch 1, Loss: 137237.20103346457
Training complete
