In [None]:
import os
from dataclasses import dataclass
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from transformers import BertModel
from gslTranslater.constants import *
from gslTranslater.utils.common import read_yaml, create_directories
from gslTranslater.components.sign_language_translator import SignLanguageTranslator


In [None]:
# Move the kernel's working directory one level up from wherever it currently is.
# Presumably this makes the relative paths used by the config files resolve from
# the project root -- TODO confirm against the repository layout.
# NOTE(review): this cell is not idempotent. Each re-run climbs one more
# directory, so run it exactly once per fresh kernel.
os.chdir("../")

In [None]:
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable bundle of paths and hyper-parameters consumed by ``Training``.

    Instances are built by ``ConfigurationManager.get_training_config`` from
    the YAML config and params files.
    """

    # Directory for this stage; created by get_training_config().
    root_dir: Path
    # File the best-validation-loss state_dict is written to.
    trained_model_path: Path
    # state_dict file loaded into the model before training starts.
    updated_base_model_path: Path
    # Dataset root; "Train" and "Test" sub-folders are read from it.
    training_data: Path
    # Number of passes over the training loader.
    params_epochs: int
    # Batch size used for both the training and validation loaders.
    params_batch_size: int
    # Learning rate handed to the Adam optimizer.
    params_learning_rate: float
    # All entries except the last are passed to transforms.Resize
    # (presumably [height, width, channels] -- TODO confirm in params.yaml).
    params_image_size: list

In [None]:
class ConfigurationManager:
    """Reads the YAML config/params files and turns them into typed stage configs."""

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH,
    ):
        """Parse both YAML files and make sure the artifacts root is set up.

        Args:
            config_filepath: Location of the main configuration YAML.
            params_filepath: Location of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        # Hand the top-level artifacts location to the directory helper.
        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_training_config(self) -> TrainingConfig:
        """Assemble the ``TrainingConfig`` for the training stage.

        The training stage's root directory is passed to ``create_directories``
        as a side effect.

        Returns:
            A ``TrainingConfig`` filled from the ``training`` and
            ``prepare_base_model`` config sections plus the params file.
        """
        training_section = self.config.training
        base_model_section = self.config.prepare_base_model

        data_dir = Path(training_section.training_data_dir)
        stage_root = Path(training_section.root_dir)
        create_directories([stage_root])

        hyper_params = self.params
        return TrainingConfig(
            root_dir=stage_root,
            trained_model_path=Path(training_section.trained_model_path),
            updated_base_model_path=Path(base_model_section.updated_model_path),
            training_data=data_dir,
            params_epochs=hyper_params.EPOCHS,
            params_batch_size=hyper_params.BATCH_SIZE,
            params_learning_rate=hyper_params.LEARNING_RATE,
            params_image_size=hyper_params.IMAGE_SIZE,
        )

In [None]:
class Training:
    """Fine-tunes the sign-language model and checkpoints the best weights.

    Expected call order: ``get_base_model()`` -> ``train_valid_loader()`` ->
    ``train()``.
    """

    def __init__(self, config: "TrainingConfig"):
        """Store the configuration; the model and loaders are created later.

        Args:
            config: Paths and hyper-parameters for this training run.
        """
        self.config = config
        # Populated by get_base_model() / train_valid_loader(). Kept as None so
        # train() can fail with a clear message when it is called too early.
        self.model = None
        self.train_loader = None
        self.valid_loader = None

    def get_base_model(self):
        """Build the model, load the prepared base weights and enable train mode.

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raise when the
            checkpoint is missing or does not match the architecture.
        """
        model = SignLanguageTranslator(
            # Calling resnet50() with no arguments yields a randomly initialised
            # network under both the legacy ``pretrained`` API and the newer
            # ``weights`` API, and avoids the deprecation warning that an
            # explicit ``pretrained=False`` triggers.
            cnn_model=models.resnet50(),
            transformer_model=BertModel.from_pretrained('nlpaueb/bert-base-greek-uncased-v1'),
            # NOTE(review): None is forwarded as-is; what the model does with a
            # missing tokenizer length is not visible here -- confirm.
            tokenizer_len=None
        )
        # map_location="cpu" lets a checkpoint that was saved on a GPU be read
        # on a machine without one; load_state_dict then copies the tensors
        # into the (CPU-resident) parameters created above.
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # produced by this project's own pipeline.
        state_dict = torch.load(self.config.updated_base_model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        model.train()
        # Publish the model only after the weights loaded successfully.
        self.model = model

    def train_valid_loader(self):
        """Create the training and validation ``DataLoader`` objects.

        Images are read with ``ImageFolder`` from the ``Train`` and ``Test``
        sub-folders of ``config.training_data``.
        NOTE(review): validation uses the ``Test`` folder, so the checkpoint is
        selected on the same data a later evaluation would likely use.
        """
        transform = transforms.Compose([
            # Drops the last entry of the configured size (presumably the
            # channel count -- TODO confirm) before resizing.
            transforms.Resize(self.config.params_image_size[:-1]),
            transforms.ToTensor(),
        ])

        train_dataset = datasets.ImageFolder(root=self.config.training_data / "Train", transform=transform)
        valid_dataset = datasets.ImageFolder(root=self.config.training_data / "Test", transform=transform)

        self.train_loader = DataLoader(train_dataset, batch_size=self.config.params_batch_size, shuffle=True)
        self.valid_loader = DataLoader(valid_dataset, batch_size=self.config.params_batch_size, shuffle=False)

    @staticmethod
    def save_model(path: Path, model: nn.Module):
        """Write ``model``'s state_dict to ``path``, creating parent folders.

        Args:
            path: Destination file for the checkpoint.
            model: Module whose parameters and buffers are saved.
        """
        # torch.save does not create missing directories on its own.
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), path)

    def _run_epoch(self, loader, criterion, device, optimizer=None) -> float:
        """Run one pass over ``loader`` and return the mean of the per-batch losses.

        With an ``optimizer`` the pass trains the model (train mode, gradients,
        parameter updates); without one it only evaluates (eval mode, no grad).
        """
        is_training = optimizer is not None
        self.model.train(is_training)
        loss_sum = 0.0
        with torch.set_grad_enabled(is_training):
            for images, labels in loader:
                # Keep the batch on the same device as the model so training
                # also works after the model has been moved to an accelerator.
                images = images.to(device)
                labels = labels.to(device)
                if is_training:
                    optimizer.zero_grad()
                outputs = self.model(images)
                loss = criterion(outputs, labels)
                if is_training:
                    loss.backward()
                    optimizer.step()
                loss_sum += loss.item()
        return loss_sum / len(loader)

    def train(self):
        """Train for ``config.params_epochs`` epochs, saving the best checkpoint.

        After every epoch the validation loss is computed; whenever it improves
        on the best value seen so far the state_dict is written to
        ``config.trained_model_path``.

        Raises:
            RuntimeError: If the model or the data loaders have not been
                created yet.
        """
        if self.model is None:
            raise RuntimeError("Call get_base_model() before train().")
        if self.train_loader is None or self.valid_loader is None:
            raise RuntimeError("Call train_valid_loader() before train().")

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.params_learning_rate)
        # Adam has already rejected an empty parameter list, so next() is safe.
        device = next(self.model.parameters()).device
        best_val_loss = float('inf')

        for epoch in range(self.config.params_epochs):
            train_loss = self._run_epoch(self.train_loader, criterion, device, optimizer)
            print(f"Epoch {epoch + 1}/{self.config.params_epochs}, Loss: {train_loss}")

            # Validation phase
            val_loss = self._run_epoch(self.valid_loader, criterion, device)
            print(f"Validation Loss after Epoch {epoch + 1}: {val_loss}")

            # Checkpoint: Save the model if it has the best validation loss so far
            if val_loss < best_val_loss:
                print(f"Validation loss improved from {best_val_loss} to {val_loss}. Saving checkpoint...")
                best_val_loss = val_loss
                self.save_model(path=self.config.trained_model_path, model=self.model)

        print("Training completed. Best validation loss was:", best_val_loss)

In [None]:
# Run the training stage end to end: read the configuration, build the model,
# create the data loaders, then train. The former ``try/except Exception as e:
# raise e`` wrapper only re-raised the same exception, so errors now propagate
# directly with an unmodified traceback.
config = ConfigurationManager()
training_config = config.get_training_config()
training = Training(config=training_config)
training.get_base_model()
training.train_valid_loader()
training.train()