In [None]:
import os
# Move the working directory one level up and display the result.
# NOTE(review): presumably this makes the project root the cwd so relative
# config paths resolve -- confirm against where the notebook lives.
os.chdir('../')
%pwd

In [None]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelConfig:
    """Immutable settings bundle produced by ``ConfigurationManager.get_model_config``.

    The first four fields are filesystem locations; the remaining fields are
    forwarded, under the same keyword names, to ``built_transformer`` when the
    model is instantiated.
    """

    # Output locations for the model stage.
    root_dir: Path
    verification_info_dir: Path
    # Joined onto ``verification_info_dir`` to form the summary file path.
    verification_summary_file: Path
    # NOTE(review): presumably the file for the initial weights -- not used
    # by the code visible in this notebook; confirm.
    verification_weights_file: Path

    # Model hyperparameters, passed straight to ``built_transformer``.
    # NOTE(review): meanings below are inferred from the names -- confirm
    # against the transformer implementation.
    src_vocab_size: int  # presumably source vocabulary size
    tgt_vocab_size: int  # presumably target vocabulary size
    src_seq_len: int  # presumably maximum source sequence length
    tgt_seq_len: int  # presumably maximum target sequence length
    d_model: int  # presumably embedding / hidden width
    N: int  # presumably number of stacked layers
    h: int  # presumably number of attention heads
    dropout: float
    d_ff: int  # presumably feed-forward inner width

In [None]:
from transformerEnFa.constants import *
from transformerEnFa.utils.common import read_yaml, create_directories

In [None]:
class ConfigurationManager:
    """Load the YAML configuration files and build typed config objects.

    On construction the config and params files are read with ``read_yaml``
    and the ``artifacts_root`` directory named in the config is created.
    """

    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):
        """Read both YAML files and create the artifacts root directory.

        Args:
            config_filepath: Location of the main configuration YAML.
            params_filepath: Location of the parameters YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])

    def get_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` from the ``model_config`` section.

        Creates ``root_dir`` and ``verification_info_dir`` as a side effect.

        Returns:
            A ``ModelConfig`` whose path fields are ``pathlib.Path`` objects
            (matching the dataclass annotations) and whose hyperparameters
            are copied verbatim from the configuration.
        """
        section = self.config.model_config

        # Directory creation receives the raw config values, exactly as before.
        create_directories([section.root_dir])
        create_directories([section.verification_info_dir])

        # ModelConfig declares these four fields as Path, while the YAML
        # loader yields plain values; convert so the object honours its
        # own annotations.
        return ModelConfig(
            root_dir=Path(section.root_dir),
            verification_info_dir=Path(section.verification_info_dir),
            verification_summary_file=Path(section.verification_summary_file),
            verification_weights_file=Path(section.verification_weights_file),
            src_vocab_size=section.src_vocab_size,
            tgt_vocab_size=section.tgt_vocab_size,
            src_seq_len=section.src_seq_len,
            tgt_seq_len=section.tgt_seq_len,
            d_model=section.d_model,
            N=section.N,
            h=section.h,
            dropout=section.dropout,
            d_ff=section.d_ff,
        )

In [None]:
# Smoke-test the configuration classes: constructing the manager reads both
# YAML files and creates the artifacts root; get_model_config() additionally
# creates the model stage directories and returns the ModelConfig.
config = ConfigurationManager()
model_config = config.get_model_config()

In [None]:
import torch
from pathlib import Path
from transformerEnFa.models.transformer import built_transformer
from transformerEnFa.config.configuration import ConfigurationManager
from transformerEnFa.utils.model_utils import save_model_summary, save_initial_weights
from transformerEnFa.logging import logger

class ModelVerificationTrainingPipeline:
    """Pipeline stage that instantiates the transformer and saves its summary.

    Construction loads the model configuration through
    ``ConfigurationManager`` and selects the torch device; ``main`` builds
    the model on that device and writes a summary file.
    """

    def __init__(self):
        self.config_manager = ConfigurationManager()
        self.config = self.config_manager.get_model_config()
        self.device = self.get_device()

    @staticmethod
    def _detect_device_name() -> str:
        """Return ``"cuda"``, ``"mps"`` or ``"cpu"``, in that order of preference."""
        if torch.cuda.is_available():
            return "cuda"
        # Older PyTorch builds have no ``torch.backends.mps`` module at all;
        # look it up defensively instead of raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def get_device(self):
        """Choose the compute device, log what was selected, and return it.

        Returns:
            A ``torch.device`` for CUDA if available, otherwise MPS if
            available, otherwise the CPU.
        """
        device = self._detect_device_name()
        logger.info(f"Using device: {device}")
        if device == 'cuda':
            logger.info(f"Device name: {torch.cuda.get_device_name(0)}")
            # total_memory is reported in bytes; convert to GiB for the log.
            logger.info(f"Device memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.2f} GB")
        elif device == 'mps':
            logger.info("Device name: Apple Metal Performance Shaders (MPS)")
        else:
            logger.info("NOTE: If you have a GPU, consider using it for training.")
        return torch.device(device)

    def main(self):
        """Instantiate the model, then write its summary to disk.

        Raises:
            Exception: Any error raised while building the model or creating
                the output directory is logged and re-raised unchanged.
        """
        try:
            # Building the model catches errors in its initialization code.
            model = built_transformer(
                src_vocab_size=self.config.src_vocab_size,
                tgt_vocab_size=self.config.tgt_vocab_size,
                src_seq_len=self.config.src_seq_len,
                tgt_seq_len=self.config.tgt_seq_len,
                d_model=self.config.d_model,
                N=self.config.N,
                h=self.config.h,
                dropout=self.config.dropout,
                d_ff=self.config.d_ff
            ).to(self.device)
            logger.info("Model instantiation successful.")
            create_directories([self.config.verification_info_dir])

            # NOTE: no forward-pass check is performed here; only
            # construction and device placement are verified.

        except Exception as e:
            # logger.exception records the traceback; the bare ``raise``
            # re-raises the original exception without altering it.
            logger.exception(f"Model verification failed: {e}")
            raise

        # Save the model summary.
        # NOTE(review): ``save_initial_weights`` is imported and the config
        # carries ``verification_weights_file``, but the weights are not
        # saved here -- confirm whether that step is intentionally omitted.
        save_model_summary(
            model,
            Path(self.config.verification_info_dir) / self.config.verification_summary_file,
            input_size=(self.config.src_seq_len,),
            device=str(self.device)
        )
      

# Usage: constructing the pipeline loads the model configuration and selects
# a device; main() then instantiates the model and writes its summary file.

model_verification_pipeline = ModelVerificationTrainingPipeline()
model_verification_pipeline.main()
