In [None]:
import os

# Later cells read "config/config.yaml" relative to the working directory, so
# the notebook steps up one level (presumably to the project root - confirm).
# The guard makes the cell idempotent: an unconditional chdir("..") would climb
# one more directory every time the cell is re-run without a kernel restart.
if not os.path.isfile(os.path.join("config", "config.yaml")):
    os.chdir("..")

In [None]:
# Display the working directory currently in effect, to confirm the chdir above.
os.getcwd()

# constants

In [None]:
from dataclasses import dataclass
from text_summarization.utils import read_yaml


# Parsed project configuration; every constant below is read from it.
CONFIG = read_yaml("config/config.yaml")

# Shortcuts to the nested sections so the class body stays compact.
_MODEL_SECTION = CONFIG.MODEL
_TRAINER_SECTION = _MODEL_SECTION.TRAINER


@dataclass(frozen=True)
class ModelTrainerConstants:
    """Directory and file names used by the model-trainer stage, read from CONFIG.

    All members are plain class attributes (none is annotated, so the dataclass
    decorator generates no fields) and are accessed directly on the class.
    """

    # Top-level artifacts directory name.
    ARITFACTS_ROOT_DIR_NAME = CONFIG.ARITFACTS_ROOT_DIR_NAME

    # Directory name taken from the MODEL section.
    MODEL_ROOT_DIR_NAME = _MODEL_SECTION.ROOT_DIR_NAME

    # Names taken from the MODEL.TRAINER section.
    TRAINER_ROOT_DIR_NAME = _TRAINER_SECTION.ROOT_DIR_NAME
    BASE_ESTIMATOR_NAME = _TRAINER_SECTION.BASE_ESTIMATOR_NAME
    FINETUNED_ESTIMATOR_NAME = _TRAINER_SECTION.FINETUNED_ESTIMATOR_NAME
    PARAMS_FILE_NAME = _TRAINER_SECTION.PARAMS_FILE_NAME




In [None]:
# Echo each constant next to its label to check the values read from the config.
_constant_rows = (
    ("ARITFACTS_ROOT_DIR_NAME:", ModelTrainerConstants.ARITFACTS_ROOT_DIR_NAME),
    ("MODEL_ROOT_DIR_NAME:", ModelTrainerConstants.MODEL_ROOT_DIR_NAME),
    ("TRAINER_ROOT_DIR_NAME:", ModelTrainerConstants.TRAINER_ROOT_DIR_NAME),
    ("BASE_ESTIMATOR_NAME:", ModelTrainerConstants.BASE_ESTIMATOR_NAME),
    ("FINETUNED_ESTIMATOR_NAME:", ModelTrainerConstants.FINETUNED_ESTIMATOR_NAME),
    ("PARAMS_FILE_NAME:", ModelTrainerConstants.PARAMS_FILE_NAME),
)
for _label, _value in _constant_rows:
    print(_label, _value)

# entity

In [None]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelTrainerArtifacts:
    """Immutable record of the paths associated with the model-trainer stage.

    Field order matters: instances may be built positionally.
    """

    # Directory hierarchy: artifacts root -> model root -> trainer root.
    ARITFACTS_ROOT_DIR_PATH: Path
    MODEL_ROOT_DIR_PATH: Path
    TRAINER_ROOT_DIR_PATH: Path

    # Save locations for the base and the fine-tuned estimator.
    BASE_ESTIMATOR_PATH: Path
    FINETUNED_ESTIMATOR_PATH: Path

    # Location of the training-parameters file.
    PARAMS_FILE_PATH: Path




# configuration

In [None]:
from text_summarization.configuration import __timestamp
from dataclasses import dataclass
from pathlib import Path
import os


# Bind the imported run timestamp to a name without two leading underscores.
# Inside a class body, an identifier such as `__timestamp` is private-name-mangled
# (it would be looked up as `_ModelTrainerConfig__timestamp`) and raise NameError,
# so the alias has to be created here, at module level, outside the class.
_RUN_TIMESTAMP = __timestamp


@dataclass(frozen=True)
class ModelTrainerConfig:
    """Filesystem locations for the model-trainer stage.

    Every path is built from the names in ``ModelTrainerConstants``; the
    artifacts root additionally includes the ``__timestamp`` value imported
    from ``text_summarization.configuration``. All members are plain class
    attributes (none is annotated), so they are read directly on the class.
    """

    # <artifacts root name>/<timestamp>
    ARITFACTS_ROOT_DIR_PATH = os.path.join(ModelTrainerConstants.ARITFACTS_ROOT_DIR_NAME, _RUN_TIMESTAMP)
    # Each following directory is nested inside the previous one.
    MODEL_ROOT_DIR_PATH = os.path.join(ARITFACTS_ROOT_DIR_PATH, ModelTrainerConstants.MODEL_ROOT_DIR_NAME)
    TRAINER_ROOT_DIR_PATH = os.path.join(MODEL_ROOT_DIR_PATH, ModelTrainerConstants.TRAINER_ROOT_DIR_NAME)
    # Both estimators are stored under the trainer directory.
    BASE_ESTIMATOR_PATH = os.path.join(TRAINER_ROOT_DIR_PATH, ModelTrainerConstants.BASE_ESTIMATOR_NAME)
    FINETUNED_ESTIMATOR_PATH = os.path.join(TRAINER_ROOT_DIR_PATH, ModelTrainerConstants.FINETUNED_ESTIMATOR_NAME)
    # The params file is not placed under the artifacts tree; its name is used
    # as a path relative to the working directory.
    PARAMS_FILE_PATH = Path(ModelTrainerConstants.PARAMS_FILE_NAME)




In [None]:
# Echo each derived path next to its label to check the resulting layout.
_path_rows = (
    ("ARITFACTS_ROOT_DIR_PATH:", ModelTrainerConfig.ARITFACTS_ROOT_DIR_PATH),
    ("MODEL_ROOT_DIR_PATH:", ModelTrainerConfig.MODEL_ROOT_DIR_PATH),
    ("TRAINER_ROOT_DIR_PATH:", ModelTrainerConfig.TRAINER_ROOT_DIR_PATH),
    ("BASE_ESTIMATOR_PATH:", ModelTrainerConfig.BASE_ESTIMATOR_PATH),
    ("FINETUNED_ESTIMATOR_PATH:", ModelTrainerConfig.FINETUNED_ESTIMATOR_PATH),
    ("PARAMS_FILE_PATH:", ModelTrainerConfig.PARAMS_FILE_PATH),
)
for _label, _value in _path_rows:
    print(_label, _value)

# components

In [None]:
from text_summarization.entity import DataTransformationArtifacts, ModelTrainerArtifacts
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from text_summarization.utils import create_dirs, load_json
from text_summarization.exception import CustomException
from text_summarization.logger import logging
from dataclasses import dataclass
import torch, sys, os
from pathlib import Path
import datasets



# Environment flags set before any trainer is built; by their names they are
# intended to turn off the MLflow and Weights & Biases reporting integrations.
os.environ.update(
    {
        "DISABLE_MLFLOW_INTEGRATION": "True",
        "WANDB_DISABLED": "True",
    }
)

@dataclass
class ModelTrainerComponents:
    """Loads a seq2seq model, fine-tunes it with ``Trainer`` and saves the result.

    Both fields have double-underscore names and are therefore name-mangled by
    Python; the generated ``__init__`` is meant to be called positionally:
    ``ModelTrainerComponents(data_transformation_config, model_trainer_config)``.
    """
    # Provides TOKENIZER_PATH and MODEL_REPO_ID.
    __data_transformation_config:DataTransformationArtifacts
    # Provides the directories to create, the model save paths and the params file path.
    __model_trainer_config:ModelTrainerArtifacts

    @staticmethod
    def __get_model(repo_id:str, path:Path) -> AutoModelForSeq2SeqLM:
        """Load a seq2seq model by repo id, move it to the available device and save a copy.

        Args:
            repo_id (str): repository id of model
            path (Path): path to save model locally

        Returns:
            AutoModelForSeq2SeqLM: model loaded from repo_id

        Raises:
            CustomException: wraps any exception raised while loading or saving.
        """
        try:
            logging.info("In __get_model")

            # use the GPU when torch reports one, otherwise the CPU
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info(f"device in use {{{device}}}")

            # load the pretrained model and move it to the selected device
            model = AutoModelForSeq2SeqLM.from_pretrained(repo_id).to(device)
            logging.info(f"model collected from {{{repo_id}}}")

            # keep a local copy of the base model
            model.save_pretrained(path)
            # Computed outside the f-string: reusing double quotes inside a
            # double-quoted f-string is a SyntaxError before Python 3.12.
            model_name = repo_id.split("/")[-1]
            logging.info(f"{{{model_name}}} saved at {{{path}}}")

            logging.info("Out __get_model")
            return model
        except Exception as e:
            # logging.exception (as in the sibling methods) records the traceback too
            logging.exception(e)
            raise CustomException(e, sys)

    @staticmethod
    def __get_trainer(
        model:AutoModelForSeq2SeqLM,
        tokenizer:AutoTokenizer,
        data_collator:DataCollatorForSeq2Seq,
        training_args:TrainingArguments,
        train_data:datasets.Dataset,
        validation_data:datasets.Dataset,
        callbacks:list,
    ) -> Trainer:
        """Build a ``Trainer`` from the given pieces.

        Args:
            model (AutoModelForSeq2SeqLM): model to train
            tokenizer (AutoTokenizer): tokenizer handed to the trainer
            data_collator (DataCollatorForSeq2Seq): collator used to build batches
            training_args (TrainingArguments): training configuration
            train_data (datasets.Dataset): dataset passed as ``train_dataset``
            validation_data (datasets.Dataset): dataset passed as ``eval_dataset``
            callbacks (list): trainer callbacks to register

        Returns:
            Trainer: the initialized trainer

        Raises:
            CustomException: wraps any exception raised while building the trainer.
        """
        try:
            logging.info("In __get_trainer")

            # Initialize the Trainer; the caller's callbacks are forwarded
            # (previously an empty list was always passed and the argument ignored).
            trainer = Trainer(
                model=model,
                tokenizer=tokenizer,
                args=training_args,
                data_collator=data_collator,
                train_dataset=train_data,
                eval_dataset=validation_data,
                callbacks=callbacks,
            )
            # plain string, so single braces give the same "{name}" look as the other log lines
            logging.info("{Trainer} initialized")

            logging.info("Out __get_trainer")
            return trainer
        except Exception as e:
            logging.exception(e)
            raise CustomException(e, sys)


    def start_model_training(self) -> ModelTrainerArtifacts:
        """Run the training stage: prepare directories, load data/tokenizer/model, train, save.

        Returns:
            ModelTrainerArtifacts: the model trainer config this component was built with

        Raises:
            CustomException: wraps any exception raised during the stage.
        """
        try:
            logging.info("In start_model_training")
            # create required dir's
            create_dirs(self.__model_trainer_config.ARITFACTS_ROOT_DIR_PATH)
            create_dirs(self.__model_trainer_config.MODEL_ROOT_DIR_PATH)
            create_dirs(self.__model_trainer_config.TRAINER_ROOT_DIR_PATH)
            logging.info("create required dir's")

            # NOTE(review): a reduced dataset is read from a hard-coded location for
            # faster training. The earlier (disabled) version loaded from
            # TRAIN_DATA_DIR_PATH / VALIDATION_DATA_DIR_PATH of the data
            # transformation config - switch back for a full run.
            train_data = datasets.load_from_disk("less_records_artifacts/train")
            validation_data = datasets.load_from_disk("less_records_artifacts/validation")

            logging.info("train and validation data collected for model training")

            # get tokenizer
            tokenizer_path = self.__data_transformation_config.TOKENIZER_PATH
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            logging.info(f"tokenizer loaded from {{{tokenizer_path}}}")

            # get model (also saves a copy of the base model)
            repo_id = self.__data_transformation_config.MODEL_REPO_ID
            base_model_path = self.__model_trainer_config.BASE_ESTIMATOR_PATH
            model = self.__get_model(repo_id, base_model_path)

            # get datacollator
            data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)

            # Set up training arguments from the params file.
            params = load_json(self.__model_trainer_config.PARAMS_FILE_PATH)
            # fp16 is opt-in through the USE_FP16 environment variable (case-insensitive "true")
            use_fp16 = os.environ.get("USE_FP16", "false").lower() == "true"
            training_args = TrainingArguments(
                **params,
                fp16=use_fp16,
            )

            # get trainer
            trainer = self.__get_trainer(
                model=model,
                tokenizer=tokenizer,
                data_collator=data_collator,
                training_args=training_args,
                train_data=train_data,
                validation_data=validation_data,
                callbacks=[]
            )

            # start training
            trainer.train()

            # save the fine-tuned model
            finetuned_model_path = self.__model_trainer_config.FINETUNED_ESTIMATOR_PATH
            model.save_pretrained(finetuned_model_path)

            logging.info("Out start_model_training")
            return self.__model_trainer_config
        except Exception as e:
            logging.exception(e)
            raise CustomException(e, sys)
        



# pipeline

In [None]:
from text_summarization.configuration import (
    DataTransformationConfig
)
from dataclasses import dataclass
from text_summarization.logger import logging


@dataclass
class ModelTrainerPipeline:
    """Pipeline entry point for the model-training stage."""

    def main(self) -> None:
        """Build the trainer component from the config classes and run training.

        The component is also kept on the instance as ``model_trainer``.
        """
        # The config classes themselves (not instances) are passed; their
        # values are read as class attributes.
        trainer_component = ModelTrainerComponents(DataTransformationConfig, ModelTrainerConfig)
        self.model_trainer = trainer_component
        trainer_component.start_model_training()





STAGE_NAME = "Model Training"

if __name__ == "__main__":
    # Banners are written both to stdout and to the log, around the stage run.
    initiated_banner = f"\n>>>>>>>>>>>>>>>>>>>>> {STAGE_NAME} initiated <<<<<<<<<<<<<<<<<<<<<"
    completed_banner = f"\n>>>>>>>>>>>>>>>>>>>>> {STAGE_NAME} completed <<<<<<<<<<<<<<<<<<<<<"

    print(initiated_banner)
    logging.info(initiated_banner)

    obj = ModelTrainerPipeline()
    obj.main()

    logging.info(completed_banner)
    print(completed_banner)


