In [12]:
import os
os.chdir("../")

In [1]:
from dataclasses import dataclass
from pathlib import Path

In [2]:
@dataclass(frozen=True)
class TrainingConfig:
    root_dir: Path
    trained_model_path: Path
    updated_base_model_path: Path
    training_data: Path
    params_epochs: int
    params_batch_size: int
    params_is_augmentation: bool
    params_image_size: list



@dataclass(frozen=True)
class PrepareCallbacksConfig:
    root_dir: Path
    tensorboard_root_log_dir: Path
    checkpoint_model_filepath: Path

In [3]:
from ChickenDiseaseClassifier.constants import *
from ChickenDiseaseClassifier.utils.common import read_yaml, create_directories

In [None]:
class ConfigurationManager:
    def __init__(
        self, 
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_directories([self.config.artifacts_root])


    
    def get_prepare_callback_config(self) -> PrepareCallbacksConfig:
        config = self.config.prepare_callbacks
        model_ckpt_dir = os.path.dirname(config.checkpoint_model_filepath)
        create_directories([
            Path(model_ckpt_dir),
            Path(config.tensorboard_root_log_dir)
        ])

        prepare_callback_config = PrepareCallbacksConfig(
            root_dir=Path(config.root_dir),
            tensorboard_root_log_dir=Path(config.tensorboard_root_log_dir),
            checkpoint_model_filepath=Path(config.checkpoint_model_filepath)
        )

        return prepare_callback_config
    
    
    def get_training_config(self) -> TrainingConfig:
        training = self.config.training
        prepare_base_model = self.config.prepare_base_model
        params = self.params
        training_data = os.path.join(self.config.data_ingestion.unzip_dir, "Chicken-fecal-images")
        create_directories([
            Path(training.root_dir)
        ])

        training_config = TrainingConfig(
            root_dir=Path(training.root_dir),
            trained_model_path=Path(training.trained_model_path),
            updated_base_model_path=Path(prepare_base_model.updated_base_model_path),
            training_data=Path(training_data),
            params_epochs=params.EPOCHS,
            params_batch_size=params.BATCH_SIZE,
            params_is_augmentation=params.AUGMENTATION,
            params_image_size=params.IMAGE_SIZE
        )

        return training_config

In [6]:
import time
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

In [7]:
class PrepareCallback:
    def __init__(self, config: PrepareCallbacksConfig):
        self.config = config

    def get_tb_ckpt_callbacks(self):
        timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")

        tb_logger = TensorBoardLogger(
            save_dir=self.config.tensorboard_root_log_dir,
            name=f"tb_logs_at_{timestamp}"
        )

        checkpoint_callback = ModelCheckpoint(
            dirpath=self.config.root_dir,
            filename="best-checkpoint",
            save_top_k=1,
            monitor="val_loss",
            mode="min"
        )

        return tb_logger, checkpoint_callback

In [9]:
import os
import urllib.request as request
from zipfile import ZipFile
import torch
import pytorch_lightning as pl
from pathlib import Path
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

In [10]:
class Training:
    def __init__(self, config: TrainingConfig):
        self.config = config

    def get_base_model(self):
        self.model = torch.load(self.config.updated_base_model_path)
    def train_valid_dataloader(self):
        h, w = self.config.params_image_size[:2]

        base_transform = transforms.Compose([
            transforms.Resize((h, w)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        if self.config.params_is_augmentation:
            train_transform = transforms.Compose([
                transforms.Resize((h, w)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(40),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])
        else:
            train_transform = base_transform

        full_dataset = datasets.ImageFolder(
            root=self.config.training_data,
            transform=train_transform
        )

        val_size = int(0.2 * len(full_dataset))
        train_size = len(full_dataset) - val_size

        train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

        self.train_loader = DataLoader(
            train_ds,
            batch_size=self.config.params_batch_size,
            shuffle=True,
            num_workers=4
        )

        self.val_loader = DataLoader(
            val_ds,
            batch_size=self.config.params_batch_size,
            shuffle=False,
            num_workers=4
        )

    @staticmethod
    def save_model(path: Path, model: torch.nn.Module):
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), path)


    def train(self, callbacks):
        trainer = pl.Trainer(
            max_epochs=self.config.params_epochs,
            callbacks=callbacks,
            accelerator="auto"
        )

        trainer.fit(
            self.model,
            train_dataloaders=self.train_loader,
            val_dataloaders=self.val_loader
        )

        trainer.save_checkpoint(self.config.trained_model_path)

In [14]:
try:
    config = ConfigurationManager()
    prepare_callbacks_config = config.get_prepare_callback_config()
    prepare_callbacks = PrepareCallback(config=prepare_callbacks_config)
    callback_list = prepare_callbacks.get_tb_ckpt_callbacks()

    training_config = config.get_training_config()
    training = Training(config=training_config)
    training.get_base_model()
    training.train_valid_dataloader
    training.train(
        callback_list=callback_list
    )
    
except Exception as e:
    raise e

[2025-12-31 10:46:45,582: INFO: common: yaml file: config\config.yaml loaded succefully]
[2025-12-31 10:46:45,584: INFO: common: yaml file: params.yaml loaded succefully]
[2025-12-31 10:46:45,585: INFO: common: created directory at: artifacts]
[2025-12-31 10:46:45,585: INFO: common: created directory at: artifacts\prepare_callbacks\checkpoint_dir]
[2025-12-31 10:46:45,586: INFO: common: created directory at: artifacts\prepare_callbacks\tensorboard_log_dir]
[2025-12-31 10:46:45,587: INFO: common: created directory at: artifacts\training]


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL torchvision.models.vgg.VGG was not an allowed global by default. Please use `torch.serialization.add_safe_globals([torchvision.models.vgg.VGG])` or the `torch.serialization.safe_globals([torchvision.models.vgg.VGG])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.