In [1]:
import os

# Run the rest of the notebook from one level above the kernel's starting
# directory so the relative config/artifact paths resolve. The starting
# directory is remembered on the first execution, so re-running this cell is
# idempotent instead of climbing one more level each time.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
PROJECT_ROOT = os.path.abspath(os.path.join(NOTEBOOK_DIR, os.pardir))
os.chdir(PROJECT_ROOT)

In [2]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class TrainingConfig:
    """Immutable settings consumed by the ``Training`` step.

    Assembled by ``ConfigurationManager.get_training_config`` from the
    ``training``, ``prepare_base_model`` and ``data_ingestion`` config sections
    plus ``EPOCHS`` / ``LEARNING_RATE`` from the params file.
    """

    # --- filesystem locations ---
    root_dir: Path            # working directory for training artifacts
    training_data: Path       # taken from data_ingestion.csv_file_path
    trained_model_path: Path  # destination of the fine-tuned state_dict
    base_model_path: Path     # state_dict loaded before fine-tuning

    # --- hyperparameters ---
    params_epochs: int
    params_learning_rate: float

@dataclass(frozen=True)
class PrepareCallbacksConfig:
    """Immutable paths used to build the TensorBoard and checkpoint callbacks.

    Assembled by ``ConfigurationManager.get_prepare_callback_config`` from the
    ``prepare_callbacks`` config section.
    """

    root_dir: Path                   # base directory of the callbacks stage
    tensorboard_root_log_dir: Path   # parent of the per-run TensorBoard log dirs
    checkpoint_model_filepath: Path  # file that model checkpoints are written to

In [3]:
from Consumer_Complaint_Analysis.constants import *
from Consumer_Complaint_Analysis.components import PrepareBaseModel
from Consumer_Complaint_Analysis.utils import read_yaml, create_directories

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class ConfigurationManager:
    """Load the project YAML files and hand out typed, per-stage configs.

    Creating an instance reads the config and params files with ``read_yaml``
    and makes sure the config's ``artifacts_root`` directory exists.
    """

    def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_directories([self.config.artifacts_root])

    def get_prepare_callback_config(self) -> PrepareCallbacksConfig:
        """Build the callbacks config, creating its output directories first.

        Creates the directory that will hold ``checkpoint_model_filepath`` and
        the TensorBoard root log directory, then returns the three paths.
        """
        section = self.config.prepare_callbacks
        checkpoint_file = section.checkpoint_model_filepath
        checkpoint_dir = os.path.dirname(checkpoint_file)
        tb_log_root = section.tensorboard_root_log_dir
        create_directories([Path(checkpoint_dir), Path(tb_log_root)])

        return PrepareCallbacksConfig(
            root_dir=Path(section.root_dir),
            tensorboard_root_log_dir=Path(tb_log_root),
            checkpoint_model_filepath=Path(checkpoint_file),
        )

    def get_training_config(self) -> TrainingConfig:
        """Build the training config, creating ``training.root_dir`` first.

        Paths come from the ``training``, ``prepare_base_model`` and
        ``data_ingestion`` sections; epochs and learning rate from the params.
        """
        section = self.config.training
        base_model_file = self.config.prepare_base_model.base_model_path
        data_file = self.config.data_ingestion.csv_file_path
        training_root = Path(section.root_dir)
        create_directories([training_root])

        return TrainingConfig(
            root_dir=training_root,
            base_model_path=Path(base_model_file),
            training_data=Path(data_file),
            trained_model_path=Path(section.trained_model_path),
            params_epochs=self.params.EPOCHS,
            params_learning_rate=self.params.LEARNING_RATE,
        )


In [5]:
import time
import torch
from torch.utils.tensorboard import SummaryWriter

class PrepareCallback:
    """Build the TensorBoard-logging and checkpointing hooks used during training.

    One ``SummaryWriter`` is created lazily per instance and reused, so every
    event from this instance lands in a single ``tb_logs_at_<timestamp>`` run
    directory. Call ``close()`` when training is done.
    """

    def __init__(self, config: "PrepareCallbacksConfig"):
        self.config = config
        self._tb_writer = None      # created on first use of _create_tb_writer
        self._graph_logged = False  # add_graph only needs to run once per run

    @property
    def _create_tb_writer(self):
        """Return the shared ``SummaryWriter``, creating it on first access."""
        # Previously every access built a brand-new writer (and log directory),
        # so the graph and the scalar from one callback went to different runs
        # and none of the writers were ever closed.
        if self._tb_writer is None:
            timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
            tb_running_log_dir = os.path.join(
                self.config.tensorboard_root_log_dir,
                f"tb_logs_at_{timestamp}",
            )
            self._tb_writer = SummaryWriter(tb_running_log_dir)
        return self._tb_writer

    def save_ckpt(self, model, iteration=None):
        """Write ``model.state_dict()`` to ``config.checkpoint_model_filepath``.

        ``iteration`` is accepted for callback-signature compatibility and is
        unused; each call overwrites the same file. A ``state_dict`` is saved to
        match how ``Training`` loads weights via ``load_state_dict``.
        """
        # The original called ``torch.nn.util.checkpoint.save``, which does not exist.
        torch.save(model.state_dict(), self.config.checkpoint_model_filepath)

    def get_ckpt_callback(self, model=None):
        """Return a ``save_ckpt(model, iteration)`` callable bound to this config.

        ``model`` is kept for backward compatibility and ignored; the returned
        callable saves whichever model it is given.
        """
        return self.save_ckpt

    def get_tb_callback(self, model, inputs, outputs, iteration):
        """Log the graph (first call only) and the training loss, then checkpoint.

        Args:
            model: the module being trained.
            inputs: example batch used to trace the graph for ``add_graph``.
            outputs: the scalar training loss logged as ``Loss/train`` (float or
                0-dim tensor), despite the parameter name.
            iteration: global step recorded with the scalar.
        """
        writer = self._create_tb_writer
        if not self._graph_logged:
            writer.add_graph(model, inputs)
            self._graph_logged = True
        writer.add_scalar('Loss/train', outputs, iteration)
        writer.flush()
        self.save_ckpt(model, iteration)

    def close(self):
        """Flush and close the TensorBoard writer, if one was created."""
        if self._tb_writer is not None:
            self._tb_writer.close()
            self._tb_writer = None

In [6]:
from torch.nn import BCELoss
import torch.optim as optim
import torch.nn as nn
from torch.optim import lr_scheduler

class Training(nn.Module):
    """Fine-tune the prepared base model and save the resulting weights.

    NOTE(review): this object orchestrates training rather than defining a
    network; it subclasses ``nn.Module`` only to keep the original interface.
    As a consequence ``train`` below shadows ``nn.Module.train``, so calling
    ``Training.eval()`` / ``Training.train(False)`` would invoke the training
    loop instead of toggling a mode. Toggle ``self.model`` instead.
    """

    def __init__(self, config: "TrainingConfig"):
        super().__init__()
        self.config = config
        # Single source of truth for device placement: the original read an
        # undefined ``self.device`` in ``train`` and hard-coded "cuda" in
        # ``validate`` without moving the model there.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_base_model(self):
        """Build the base model and load ``config.base_model_path`` into it.

        Returns:
            The ``PrepareBaseModel`` instance with the saved weights loaded
            (the original discarded it and returned ``None``).
        """
        model = PrepareBaseModel(config=self.config, df=self.config.training_data)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        state_dict = torch.load(self.config.base_model_path, map_location="cpu")
        # NOTE(review): the recorded run failed here. Every saved key
        # ("distilbert.*", "pre_classifier.*", "classifier.*") was reported as
        # *unexpected* with no *missing* keys, i.e. PrepareBaseModel registers
        # no parameters of its own. These weights presumably belong to the
        # DistilBERT classifier that PrepareBaseModel builds or wraps; load
        # them into that module instead. TODO confirm against PrepareBaseModel.
        model.load_state_dict(state_dict)
        return model

    @staticmethod
    def _logits(outputs):
        """Return the score tensor from a forward pass.

        Structured outputs exposing ``.logits`` (e.g. Hugging Face
        ``*ForSequenceClassification`` models) are unwrapped; plain tensors
        pass through unchanged.
        """
        return getattr(outputs, "logits", outputs)

    def _run_epoch(self, model, dataloader, criterion, optimizer=None):
        """Run one pass over ``dataloader``; update weights iff ``optimizer`` is given.

        Returns:
            ``(mean_loss, accuracy)``, both averaged per sample.

        Raises:
            ValueError: if ``dataloader`` yields no samples.
        """
        is_train = optimizer is not None
        model.train(is_train)
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        with torch.set_grad_enabled(is_train):
            for inputs, labels in dataloader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                if is_train:
                    optimizer.zero_grad()
                logits = self._logits(model(inputs))
                loss = criterion(logits, labels)
                if is_train:
                    loss.backward()
                    optimizer.step()
                batch_size = labels.size(0)
                # criterion returns a batch mean, so re-weight by batch size
                # to get a true per-sample average across uneven batches.
                total_loss += loss.item() * batch_size
                total_correct += (logits.argmax(dim=1) == labels).sum().item()
                total_samples += batch_size
        if total_samples == 0:
            raise ValueError("dataloader yielded no samples")
        return total_loss / total_samples, total_correct / total_samples

    def validate(self, validation_dataloader, loss_fn, model=None):
        """Evaluate a model on ``validation_dataloader`` without updating it.

        Args:
            validation_dataloader: iterable of ``(inputs, labels)`` batches.
            loss_fn: a loss module instance, or a loss class (e.g. ``BCELoss``)
                which is instantiated with no arguments.
            model: model to evaluate; defaults to a freshly loaded base model.

        Returns:
            ``(avg_val_loss, avg_val_acc)``, averaged per sample.

        Raises:
            ValueError: if ``validation_dataloader`` yields no samples.
        """
        if model is None:
            model = self.load_base_model()
        # Calling a loss *class* with (outputs, labels) would construct a module
        # rather than compute a loss, so instantiate it first.
        if isinstance(loss_fn, type):
            loss_fn = loss_fn()
        model.to(self.device)
        return self._run_epoch(model, validation_dataloader, loss_fn)

    def train(self, dataloaders=None, criterion=None):
        """Fine-tune the base model and save its ``state_dict``.

        Args:
            dataloaders: mapping with ``"train"`` and ``"val"`` iterables of
                ``(inputs, labels)`` batches; defaults to ``self.dataloaders``.
            criterion: loss module; defaults to ``nn.CrossEntropyLoss()``.

        Returns:
            The trained model (also kept as ``self.model``). Its weights are
            saved to ``config.trained_model_path``.

        Raises:
            ValueError: if no dataloaders are available.
        """
        if dataloaders is None:
            # The original read ``self.dataloaders`` but nothing ever set it.
            dataloaders = getattr(self, "dataloaders", None)
        if dataloaders is None:
            raise ValueError("train() needs dataloaders with 'train' and 'val' entries")

        self.model = self.load_base_model()
        self.model.to(self.device)
        # Accuracy is argmax-over-classes compared with integer labels, which
        # matches CrossEntropyLoss on logits. BCELoss (the original choice)
        # needs probabilities shaped like the targets and fails on this layout.
        self.criterion = criterion if criterion is not None else nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.params_learning_rate)
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=7, gamma=0.1)

        epochs = self.config.params_epochs
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}')
            print('-' * 10)
            for phase in ('train', 'val'):
                optimizer = self.optimizer if phase == 'train' else None
                epoch_loss, epoch_acc = self._run_epoch(
                    self.model, dataloaders[phase], self.criterion, optimizer
                )
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # Step the LR schedule once per epoch *after* the optimizer updates;
            # stepping first (as before) skips the initial learning rate.
            self.scheduler.step()

        self.save_model(self.model, self.config.trained_model_path)
        return self.model

    @staticmethod
    def save_model(model, path):
        """Save ``model.state_dict()`` to ``path``."""
        torch.save(model.state_dict(), path)

In [8]:
# Build the stage configs, evaluate the base model, then fine-tune it.
# (The former ``try/except Exception as e: raise e`` wrapper was a no-op and
# only added a frame to the traceback.)
config = ConfigurationManager()
prepare_callbacks_config = config.get_prepare_callback_config()
prepare_callbacks = PrepareCallback(config=prepare_callbacks_config)
# NOTE(review): these callbacks are collected but never handed to Training.
callback_list = [prepare_callbacks.get_ckpt_callback, prepare_callbacks.get_tb_callback]

training_config = config.get_training_config()
training = Training(config=training_config)

# NOTE(review): assumes the loaded base model exposes a ``dataloaders`` mapping
# keyed by "train"/"val" (the keys Training.train iterates). TODO confirm
# against PrepareBaseModel; nothing in this notebook builds DataLoaders.
dataloaders = training.load_base_model().dataloaders
# Training.train reads ``self.dataloaders``, which was never assigned before.
training.dataloaders = dataloaders

# Pass a loss *instance*: validate calls ``loss_fn(outputs, labels)``, and
# calling the BCELoss class would construct a module, not compute a loss.
# CrossEntropyLoss matches the argmax-vs-integer-label accuracy in validate.
avg_val_loss, avg_val_acc = training.validate(
    validation_dataloader=dataloaders["val"],
    loss_fn=nn.CrossEntropyLoss(),
)
print(f"base model val Loss: {avg_val_loss:.4f} Acc: {avg_val_acc:.4f}")

training.train()

RuntimeError: Error(s) in loading state_dict for PrepareBaseModel:
	Unexpected key(s) in state_dict: "distilbert.embeddings.word_embeddings.weight", "distilbert.embeddings.position_embeddings.weight", "distilbert.embeddings.LayerNorm.weight", "distilbert.embeddings.LayerNorm.bias", "distilbert.transformer.layer.0.attention.q_lin.weight", "distilbert.transformer.layer.0.attention.q_lin.bias", "distilbert.transformer.layer.0.attention.k_lin.weight", "distilbert.transformer.layer.0.attention.k_lin.bias", "distilbert.transformer.layer.0.attention.v_lin.weight", "distilbert.transformer.layer.0.attention.v_lin.bias", "distilbert.transformer.layer.0.attention.out_lin.weight", "distilbert.transformer.layer.0.attention.out_lin.bias", "distilbert.transformer.layer.0.sa_layer_norm.weight", "distilbert.transformer.layer.0.sa_layer_norm.bias", "distilbert.transformer.layer.0.ffn.lin1.weight", "distilbert.transformer.layer.0.ffn.lin1.bias", "distilbert.transformer.layer.0.ffn.lin2.weight", "distilbert.transformer.layer.0.ffn.lin2.bias", "distilbert.transformer.layer.0.output_layer_norm.weight", "distilbert.transformer.layer.0.output_layer_norm.bias", "distilbert.transformer.layer.1.attention.q_lin.weight", "distilbert.transformer.layer.1.attention.q_lin.bias", "distilbert.transformer.layer.1.attention.k_lin.weight", "distilbert.transformer.layer.1.attention.k_lin.bias", "distilbert.transformer.layer.1.attention.v_lin.weight", "distilbert.transformer.layer.1.attention.v_lin.bias", "distilbert.transformer.layer.1.attention.out_lin.weight", "distilbert.transformer.layer.1.attention.out_lin.bias", "distilbert.transformer.layer.1.sa_layer_norm.weight", "distilbert.transformer.layer.1.sa_layer_norm.bias", "distilbert.transformer.layer.1.ffn.lin1.weight", "distilbert.transformer.layer.1.ffn.lin1.bias", "distilbert.transformer.layer.1.ffn.lin2.weight", "distilbert.transformer.layer.1.ffn.lin2.bias", "distilbert.transformer.layer.1.output_layer_norm.weight", "distilbert.transformer.layer.1.output_layer_norm.bias", 
"distilbert.transformer.layer.2.attention.q_lin.weight", "distilbert.transformer.layer.2.attention.q_lin.bias", "distilbert.transformer.layer.2.attention.k_lin.weight", "distilbert.transformer.layer.2.attention.k_lin.bias", "distilbert.transformer.layer.2.attention.v_lin.weight", "distilbert.transformer.layer.2.attention.v_lin.bias", "distilbert.transformer.layer.2.attention.out_lin.weight", "distilbert.transformer.layer.2.attention.out_lin.bias", "distilbert.transformer.layer.2.sa_layer_norm.weight", "distilbert.transformer.layer.2.sa_layer_norm.bias", "distilbert.transformer.layer.2.ffn.lin1.weight", "distilbert.transformer.layer.2.ffn.lin1.bias", "distilbert.transformer.layer.2.ffn.lin2.weight", "distilbert.transformer.layer.2.ffn.lin2.bias", "distilbert.transformer.layer.2.output_layer_norm.weight", "distilbert.transformer.layer.2.output_layer_norm.bias", "distilbert.transformer.layer.3.attention.q_lin.weight", "distilbert.transformer.layer.3.attention.q_lin.bias", "distilbert.transformer.layer.3.attention.k_lin.weight", "distilbert.transformer.layer.3.attention.k_lin.bias", "distilbert.transformer.layer.3.attention.v_lin.weight", "distilbert.transformer.layer.3.attention.v_lin.bias", "distilbert.transformer.layer.3.attention.out_lin.weight", "distilbert.transformer.layer.3.attention.out_lin.bias", "distilbert.transformer.layer.3.sa_layer_norm.weight", "distilbert.transformer.layer.3.sa_layer_norm.bias", "distilbert.transformer.layer.3.ffn.lin1.weight", "distilbert.transformer.layer.3.ffn.lin1.bias", "distilbert.transformer.layer.3.ffn.lin2.weight", "distilbert.transformer.layer.3.ffn.lin2.bias", "distilbert.transformer.layer.3.output_layer_norm.weight", "distilbert.transformer.layer.3.output_layer_norm.bias", "distilbert.transformer.layer.4.attention.q_lin.weight", "distilbert.transformer.layer.4.attention.q_lin.bias", "distilbert.transformer.layer.4.attention.k_lin.weight", "distilbert.transformer.layer.4.attention.k_lin.bias", 
"distilbert.transformer.layer.4.attention.v_lin.weight", "distilbert.transformer.layer.4.attention.v_lin.bias", "distilbert.transformer.layer.4.attention.out_lin.weight", "distilbert.transformer.layer.4.attention.out_lin.bias", "distilbert.transformer.layer.4.sa_layer_norm.weight", "distilbert.transformer.layer.4.sa_layer_norm.bias", "distilbert.transformer.layer.4.ffn.lin1.weight", "distilbert.transformer.layer.4.ffn.lin1.bias", "distilbert.transformer.layer.4.ffn.lin2.weight", "distilbert.transformer.layer.4.ffn.lin2.bias", "distilbert.transformer.layer.4.output_layer_norm.weight", "distilbert.transformer.layer.4.output_layer_norm.bias", "distilbert.transformer.layer.5.attention.q_lin.weight", "distilbert.transformer.layer.5.attention.q_lin.bias", "distilbert.transformer.layer.5.attention.k_lin.weight", "distilbert.transformer.layer.5.attention.k_lin.bias", "distilbert.transformer.layer.5.attention.v_lin.weight", "distilbert.transformer.layer.5.attention.v_lin.bias", "distilbert.transformer.layer.5.attention.out_lin.weight", "distilbert.transformer.layer.5.attention.out_lin.bias", "distilbert.transformer.layer.5.sa_layer_norm.weight", "distilbert.transformer.layer.5.sa_layer_norm.bias", "distilbert.transformer.layer.5.ffn.lin1.weight", "distilbert.transformer.layer.5.ffn.lin1.bias", "distilbert.transformer.layer.5.ffn.lin2.weight", "distilbert.transformer.layer.5.ffn.lin2.bias", "distilbert.transformer.layer.5.output_layer_norm.weight", "distilbert.transformer.layer.5.output_layer_norm.bias", "pre_classifier.weight", "pre_classifier.bias", "classifier.weight", "classifier.bias". 