In [1]:
import os

In [2]:
# Show the current working directory (portable, pure-Python replacement for the `!pwd` shell magic).
print(os.getcwd())

/home/gourav/ML/Text_Classification_Model_Builder/research


In [3]:
# The notebook is started from research/, but config/config.yaml and artifacts/
# are resolved relative to the project root, so step up one level.
# Guarded so that re-running this cell does not keep climbing the directory tree.
if os.path.basename(os.getcwd()) == "research":
    os.chdir("../")

In [4]:
# Confirm we are now at the project root (portable replacement for `!pwd`).
print(os.getcwd())

/home/gourav/ML/Text_Classification_Model_Builder


In [18]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class TrainModelConfig:
    """Immutable settings for the model-training stage.

    Instances are built by ``ConfigurationManager.get_train_model_config``.
    The path fields are annotated as ``Path``, but dataclasses perform no
    conversion: they hold whatever the YAML loader produced (plain strings
    in the run recorded in this notebook).
    """

    # --- data locations / output ---
    train_data_path: Path  # directory expected to contain train_data.json
    val_data_path: Path  # directory expected to contain val_data.json
    save_model_dir: Path  # root directory under which trained-model output is written

    # --- model & training hyper-parameters ---
    model_name: str  # identifier passed to from_pretrained, e.g. "bert-base-uncased"
    num_labels: int  # number of target classes for the classification head
    epochs: int  # number of training epochs
    train_batch_size: int  # per-device batch size used for training
    val_batch_size: int  # per-device batch size used for evaluation


In [19]:
from src.constants import *
from src.utils.common import read_yaml, create_directories

In [22]:
class ConfigurationManager:
    """Central access point for the project's YAML configuration.

    On construction it parses the pipeline config file and the
    hyper-parameter file and makes sure the artifacts root directory exists.
    ``get_train_model_config`` then assembles those raw values into a
    ``TrainModelConfig``.
    """

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH,
    ):
        """Read both YAML files and create ``artifacts_root``.

        Args:
            config_filepath: path of the pipeline config YAML.
            params_filepath: path of the hyper-parameter YAML.
        """
        # Dotted access used below (e.g. ``self.config.train_model``) relies on
        # the object returned by the project helper ``read_yaml``.
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        artifacts_dir = self.config.artifacts_root
        create_directories([artifacts_dir])

    def get_train_model_config(self) -> TrainModelConfig:
        """Assemble the configuration for the model-training stage.

        Paths come from the ``train_model`` section of the config file.
        Model hyper-parameters come from ``model_params`` in the params file,
        and ``num_labels`` from the top level of the params file.

        Side effect: creates the ``saved_model_dir`` directory.

        Returns:
            A populated ``TrainModelConfig``.
        """
        label_count = self.params.num_labels
        stage = self.config.train_model
        hparams = self.params.model_params

        model_dir = stage.saved_model_dir
        create_directories([model_dir])

        return TrainModelConfig(
            train_data_path=stage.train_data_path,
            val_data_path=stage.val_data_path,
            save_model_dir=model_dir,
            model_name=hparams.model_name,
            num_labels=label_count,
            epochs=hparams.epochs,
            train_batch_size=hparams.train_batch_size,
            val_batch_size=hparams.val_batch_size,
        )

In [23]:
# Sanity check: resolve the training-stage configuration and print it so the
# resolved paths and hyper-parameters can be eyeballed before training.
conf = ConfigurationManager()
train_model_config = conf.get_train_model_config()
print(f"{train_model_config}")


[2024-05-26 00:50:39,220: INFO: common: yaml file: config/config.yaml loaded successfully]
[2024-05-26 00:50:39,235: INFO: common: yaml file: params.yaml loaded successfully]
[2024-05-26 00:50:39,242: INFO: common: already created directory: artifacts]
[2024-05-26 00:50:39,245: INFO: common: already created directory: artifacts/models]
TrainModelConfig(train_data_path='artifacts/split_data/', val_data_path='artifacts/split_data/', save_model_dir='artifacts/models', model_name='bert-base-uncased', num_labels=5, epochs=5, train_batch_size=4, val_batch_size=4)


In [9]:
from src.utils.common import load_json, join_path
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

  from .autonotebook import tqdm as notebook_tqdm
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [24]:
class TrainModel:
    """Fine-tune a sequence-classification model with ``transformers.Trainer``.

    Reads the pre-split ``train_data.json`` / ``val_data.json`` files, wraps
    them as ``datasets.Dataset`` objects, trains, saves the final model under
    ``save_model_dir/model_name`` and evaluates it on the validation split.
    """

    # Optimizer learning rate. TrainModelConfig has no field for it (yet).
    LEARNING_RATE = 2e-5

    def __init__(self, config: "TrainModelConfig"):
        """Store the stage configuration.

        Args:
            config: the ``TrainModelConfig`` describing data paths and
                hyper-parameters.
        """
        # The previous signature was ``config = TrainModelConfig``, which is a
        # default *value* rather than an annotation. ``TrainModel()`` therefore
        # silently stored the class object and only failed later.
        self.config = config

    def dataset_format(self, data):
        """Convert a mapping into a ``datasets.Dataset`` via ``Dataset.from_dict``.

        Args:
            data: mapping loaded from the split JSON file (presumably
                column name -> list of values; confirm against the data-split
                stage that writes these files).

        Returns:
            The resulting ``Dataset``.
        """
        return Dataset.from_dict(data)

    def compute_metrics(self, pred):
        """Compute accuracy and macro precision/recall/F1 for ``Trainer``.

        Args:
            pred: evaluation output exposing ``predictions`` (the logits, or a
                tuple whose first element is the logits) and ``label_ids``.

        Returns:
            Dict with keys ``Accuracy``, ``F1``, ``Precision`` and ``Recall``.
        """
        labels = pred.label_ids
        logits = pred.predictions
        # Some model configurations return extra outputs alongside the logits.
        if isinstance(logits, tuple):
            logits = logits[0]
        preds = logits.argmax(-1)

        # zero_division=0 yields the same value sklearn uses by default for
        # classes that are never predicted, without emitting a warning each time.
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average="macro", zero_division=0
        )
        acc = accuracy_score(labels, preds)

        return {
            "Accuracy": acc,
            "F1": f1,
            "Precision": precision,
            "Recall": recall,
        }

    def _load_split(self, data_dir, file_name):
        """Load ``file_name`` from ``data_dir`` and wrap it as a ``Dataset``."""
        return self.dataset_format(load_json(Path(join_path(data_dir, file_name))))

    def train_model(self):
        """Fine-tune the configured model, save it and evaluate it.

        Returns:
            The metrics dict returned by ``Trainer.evaluate`` on the
            validation split.
        """
        config = self.config

        train_data = self._load_split(config.train_data_path, "train_data.json")
        val_data = self._load_split(config.val_data_path, "val_data.json")

        training_args = TrainingArguments(
            output_dir=join_path(config.save_model_dir, config.model_name),
            num_train_epochs=config.epochs,
            per_device_train_batch_size=config.train_batch_size,
            per_device_eval_batch_size=config.val_batch_size,
            learning_rate=self.LEARNING_RATE,
            disable_tqdm=False,
        )

        model = AutoModelForSequenceClassification.from_pretrained(
            config.model_name, num_labels=config.num_labels
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_data,
            eval_dataset=val_data,
            compute_metrics=self.compute_metrics,
        )

        trainer.train()

        # Persist the final fine-tuned weights to training_args.output_dir.
        # Nothing else in this method writes the trained model out explicitly.
        trainer.save_model()

        # TrainingArguments above sets no evaluation strategy. Run one explicit
        # evaluation on the validation split; this is what invokes compute_metrics.
        return trainer.evaluate()
        



In [14]:
# Run the training stage end to end. The error message is printed for a quick
# glance, then re-raised so the cell fails visibly with a full traceback
# instead of appearing to succeed.
try:
    config = ConfigurationManager()
    train_model_config = config.get_train_model_config()

    train_model = TrainModel(train_model_config)
    train_model.train_model()
except Exception as e:
    print(e)
    raise
    

[2024-05-16 13:58:13,896: INFO: common: yaml file: config/config.yaml loaded successfully]
[2024-05-16 13:58:13,904: INFO: common: yaml file: params.yaml loaded successfully]
[2024-05-16 13:58:13,907: INFO: common: already created directory: artifacts]


Error while downloading from https://cdn-lfs.huggingface.co/bert-base-cased/1d8bdcee6021e2c25f0325e84889b61c2eb26b843eef5659c247af138d64f050?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1716107295&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjEwNzI5NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9iZXJ0LWJhc2UtY2FzZWQvMWQ4YmRjZWU2MDIxZTJjMjVmMDMyNWU4NDg4OWI2MWMyZWIyNmI4NDNlZWY1NjU5YzI0N2FmMTM4ZDY0ZjA1MD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=Zmv76YpyqqWQCdYvEIEBLyZUcGW8vuibpD%7EyEBuUczAvrCRvH4PJzC59-TqalS2wTOZAZKf%7EgSGrNd2TLaJT1AK332Qvzcpypd1pxcxjoWhUZs6JCxRCqxVv6emd19K0Ee%7ExOVfT0vrMYeK13EojMSTEaKPq6Knd2IwxQqzaKSONeNcYgEdAfrD5TsPzjhKaaUUMDFHaflyrXkXGEhuYaB5tNyG9WZkorOib%7Em7n%7EN%7EOSqerLhhTam48BgWI4pD7y-iqotBS-iBxPFz6t-sFF7fqp4sHRcDr1QNevNtuYwLYNRYgcyuSV-WZzzch2838qHAfAByglunGDinkaTbbjQ__&Key-Pair-Id=KVTP0A1DKRTAX: HTTPS

Trying to resume download...]
(MaxRetryError('HTTPSConnectionPool(host=\'cdn-lfs.huggingface.co\', port=443): Max retries exceeded with url: /bert-base-cased/1d8bdcee6021e2c25f0325e84889b61c2eb26b843eef5659c247af138d64f050?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1716107295&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjEwNzI5NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9iZXJ0LWJhc2UtY2FzZWQvMWQ4YmRjZWU2MDIxZTJjMjVmMDMyNWU4NDg4OWI2MWMyZWIyNmI4NDNlZWY1NjU5YzI0N2FmMTM4ZDY0ZjA1MD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=Zmv76YpyqqWQCdYvEIEBLyZUcGW8vuibpD~yEBuUczAvrCRvH4PJzC59-TqalS2wTOZAZKf~gSGrNd2TLaJT1AK332Qvzcpypd1pxcxjoWhUZs6JCxRCqxVv6emd19K0Ee~xOVfT0vrMYeK13EojMSTEaKPq6Knd2IwxQqzaKSONeNcYgEdAfrD5TsPzjhKaaUUMDFHaflyrXkXGEhuYaB5tNyG9WZkorOib~m7n~N~OSqerLhhTam48BgWI4pD7y-iqotBS-iBxPFz6t-sFF7fqp4sHRcDr1QNevNtuYwLYNRYgcyu