# [lenet05] All Customizations in One Notebook

在這個教學裡面，我們示範怎麼把所有元件整合在一起。

這樣的使用方法最有彈性，但缺點在於這樣的便利容易讓我們：
1. 在開發的過程中寫出耦合度高的元件；
2. 直接修改程式碼，一次動到實驗中的多個變項，事後難以判斷究竟是哪一個改動影響了結果。

In [1]:
%%html
<!-- Style ipywidget outputs (e.g. the tqdm.auto progress bars, which render as widgets in notebooks) to blend in with the VS Code editor theme. -->
<style>
/* Remove the opaque background VS Code draws behind ipywidget outputs. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
/* Make Jupyter widgets inherit the VS Code editor's foreground color and font size. */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

## Example
SGD 泛化能力好，Adam 收斂速度快，那我們有沒有辦法前期先透過 Adam 訓練，後期再透過 SGD 收斂到一個泛化能力強的神經網路權重？這個優化方法稱為 SWATS（Switching from Adam to SGD），由 Keskar 和 Socher 在 2017 年的 [Improving Generalization Performance by Switching from Adam to SGD](https://arxiv.org/pdf/1712.07628) 提出。

在 Keskar 和 Socher 的研究中，SGD 的學習率和切換優化器的時機都是由梯度和超參數估計而來。這邊我們為求簡便，不嚴謹地利用 `loss < 2` 作為切換的門檻；SGD 的學習率也不做估計，而是直接以 1 為初始值，並搭配 `CosineAnnealingWarmRestarts` 排程調整。

修改模板的部分我們以 `# >> Modified`（標記單行）或 `# --- Modified`（標記一段程式碼的起訖）標記。

In [15]:
from __future__ import annotations

import torch
from torch import nn
from torch.utils.data import DataLoader
from functools import partial

import warnings
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

from tqdm.auto import tqdm
import numpy as np

from modules.base.trainer import BaseTrainer, TrainLogger
from modules.base.updater import BaseUpdater
from modules.base.validator import BaseValidator

import torch
import torchmetrics
import itertools
from tqdm.auto import tqdm
from modules.base.validator import BaseValidator


class CustomTrainer(BaseTrainer):
    """Trainer whose progress bar also reports which optimizer is currently active.

    Same loop structure as the base trainer, except that the update function
    returned by ``updater(module)`` must return ``(loss, optimizer_name)``
    instead of only the loss, so the optimizer phase can be displayed.
    """

    def __init__(
        self,
        max_iter: int = 10000,
        eval_step: int = 200,
        metric: Metric | None = None,
        validator: BaseValidator | None = None,
        checkpoint_dir: str = "./checkpoints/",
        device: Literal["cuda", "cpu"] = "cuda",
        unpack_item: Callable | Literal["monai", "pytorch"] = "pytorch",
        dev: bool = False,
    ):
        """Configure the training loop.

        Args:
            max_iter: Total number of optimization steps.
            eval_step: Validate every ``eval_step`` steps (and always on the last step).
            metric: Metric used to build a default ``BaseValidator`` when ``validator`` is None.
            validator: Validator to use; if None, a ``BaseValidator`` over ``metric`` is built.
            checkpoint_dir: Directory passed to ``TrainLogger`` and ``module.save``.
            device: Device that modules and batches are moved to.
            unpack_item: ``"pytorch"`` (``batch[0], batch[1]``), ``"monai"``
                (``batch["image"], batch["label"]``), or a custom callable
                mapping a batch to ``(images, targets)``.
            dev: Developer mode; overrides ``max_iter=10``, ``eval_step=3`` and
                ``checkpoint_dir="./debug/"`` and emits a ``UserWarning``.
        """
        super().__init__(max_iter, eval_step, metric, validator, checkpoint_dir, device, unpack_item, dev)
        self.max_iter = max_iter
        self.eval_step = eval_step
        self.checkpoint_dir = checkpoint_dir
        self.device = device

        self.pbar_description = "Training ({step} / {max_iter} Steps) ({optimizer}) (loss={loss:2.5f})"

        # Set up the validator (or the metric used to build a default one) for validation during training.
        self.metric = metric
        self.validator = validator if validator else BaseValidator(self.metric, is_train=True, device=self.device)
        self.metric = self.validator.metric

        # Function that splits a batch into (images, targets); depends on the dataset's __getitem__.
        if unpack_item == "pytorch":
            self.unpack_item = lambda batch: (batch[0].to(self.device), batch[1].to(self.device))
        elif unpack_item == "monai":
            self.unpack_item = lambda batch: (batch["image"].to(self.device), batch["label"].to(self.device))
        else:
            self.unpack_item = unpack_item

        # Developer mode: shrink the run so the whole pipeline can be smoke-tested quickly.
        if dev:
            self.max_iter = 10
            self.eval_step = 3
            self.checkpoint_dir = "./debug/"
            warnings.warn(
                "Trainer will be executed under developer mode. "
                f"max_iter = {self.max_iter}, "
                f"eval_step = {self.eval_step}, "
                f"checkpoint_dir = {self.checkpoint_dir} ",
                UserWarning,
            )

    @staticmethod
    def _cycle_batches(dataloader: DataLoader):
        """Yield batches from ``dataloader`` forever, one full pass after another.

        Each pass re-iterates the loader, so a shuffling DataLoader draws a
        fresh order per pass. Unlike calling ``next(iter(loader))`` every step,
        this visits every batch, and a non-shuffled loader does not keep
        returning only its first batch.

        Raises:
            ValueError: If a full pass yields no batches (which would otherwise loop forever).
        """
        while True:
            produced = False
            for batch in dataloader:
                produced = True
                yield batch
            if not produced:
                raise ValueError("train_dataloader yielded no batches.")

    def train(
        self,
        module: nn.Module,
        updater: BaseUpdater,
        *,
        train_dataloader: DataLoader | None = None,
        val_dataloader: DataLoader | None = None,
    ):
        """Run ``self.max_iter`` optimization steps, validating and checkpointing periodically.

        Args:
            module: Network to train. It must expose ``criterion`` (passed to the
                logger) and ``save(dir)`` (called when the validation score improves).
            updater: Called once as ``updater(module)``; the returned function is
                called per step with ``(images, targets)`` and must return
                ``(loss, optimizer_name)``.
            train_dataloader: Source of training batches (required).
            val_dataloader: Optional validation data; validation is skipped when falsy.

        Raises:
            ValueError: If ``train_dataloader`` is missing or yields no batches.
        """
        if train_dataloader is None:
            raise ValueError("train_dataloader must be provided to train().")

        self.show_training_info(module, train_dataloader=train_dataloader, val_dataloader=val_dataloader)

        # Initialize the progress bar and the logger.
        train_pbar = tqdm(range(self.max_iter), dynamic_ncols=True)
        logger = TrainLogger(self.checkpoint_dir)

        # Move the module before registering it, so the optimizers that
        # updater(module) builds see the parameters on the target device.
        module.to(self.device)
        best_metric = 0.0
        module_update = updater(module)
        batches = self._cycle_batches(train_dataloader)

        for step in train_pbar:
            module.train()

            # Forward / backward / optimizer step (delegated to the updater).
            batch = next(batches)
            images, targets = self.unpack_item(batch)
            loss, opt = module_update(images, targets)  # >> Modified

            # Report the current training state.
            info = {"step": step + 1, "max_iter": self.max_iter, "loss": loss, "optimizer": opt}  # >> Modified
            train_pbar.set_description(self.pbar_description.format(**info))
            logger.log_train(module.criterion, loss, step)

            # Validate every eval_step steps and on the final step, but only if
            # validation data exists. (The previous `a and b or c` parsed as
            # `(a and b) or c` and validated on the last step even without data.)
            is_eval_step = (step + 1) % self.eval_step == 0 or step == self.max_iter - 1
            if not is_eval_step or not val_dataloader:
                continue

            val_metrics = self.validator(module, val_dataloader, global_step=step)
            logger.log_val(self.metric, suffix=["Average"], value=(val_metrics,), step=step)

            # `x is not np.nan` is an identity test that never detects a computed
            # NaN, so test numerically and skip the comparison instead.
            val_metric = float(val_metrics)
            if np.isnan(val_metric):
                logger.info("Validation metric is NaN; skipping checkpoint comparison.")
                continue

            # Checkpoint whenever the validation score improves.
            if val_metric > best_metric:
                module.save(self.checkpoint_dir)
                logger.success(f"Model saved! Validation: (New) {val_metric:2.5f} > (Old) {best_metric:2.5f}")
                best_metric = val_metric
            else:
                logger.info(f"No improvement. Validation: (New) {val_metric:2.5f} <= (Old) {best_metric:2.5f}")



class CustomUpdater(BaseUpdater):
    """Updater that trains with Adam first, then switches to SGD (a simplified SWATS).

    Adam drives the updates until a step's training loss falls below
    ``switch_loss``; from then on SGD is used. Its learning rate follows a
    ``CosineAnnealingWarmRestarts`` schedule that is stepped once after every
    SGD update.

    Args:
        switch_loss: Loss threshold below which training switches from Adam to SGD.
        adam_lr: Learning rate for the Adam phase.
        sgd_lr: Initial learning rate for the SGD phase (the scheduler's base LR).
    """

    def __init__(self, switch_loss: float = 2.0, adam_lr: float = 0.001, sgd_lr: float = 1.0):
        self.phase = "adam"
        self.switch_loss = switch_loss
        self.adam_lr = adam_lr
        self.sgd_lr = sgd_lr

    def register_module(self, module):
        """Create both optimizers and the SGD scheduler for ``module``, starting in the Adam phase.

        Sets ``module.optimizer`` (initially Adam) and ``module.scheduler``
        (on the SGD optimizer), and returns ``update`` bound to ``module``.
        """
        self.check_module(module)
        # --- Modified
        # Reset the phase so that reusing this updater for another module does
        # not silently start that module in the SGD phase.
        self.phase = "adam"
        self.sgd_optimizer = torch.optim.SGD(module.net.parameters(), lr=self.sgd_lr)
        self.adam_optimizer = torch.optim.Adam(module.net.parameters(), lr=self.adam_lr)
        module.optimizer = self.adam_optimizer
        module.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.sgd_optimizer, T_0=10, T_mult=1, eta_min=0, last_epoch=-1
        )
        # --- Modified
        return partial(self.update, module)

    def update(self, module, images, targets, **kwargs) -> tuple[float, str]:
        """Run one optimization step and possibly switch from Adam to SGD.

        Returns:
            ``(loss, phase)``: the step's loss as a Python float, and the phase
            name (``"adam"`` or ``"sgd"``) that will be used for the next step.
        """
        module.optimizer.zero_grad()
        preds = module(images)
        loss = module.criterion(preds, targets)
        loss.backward()
        module.optimizer.step()
        # --- Modified
        loss_value = loss.item()
        if self.phase == "sgd":
            # Step the scheduler only after its own optimizer has stepped. On the
            # switching iteration Adam made the update, so the scheduler must not
            # advance yet; otherwise PyTorch warns and the first SGD LR is skipped.
            module.scheduler.step()
        elif loss_value < self.switch_loss:
            self.phase = "sgd"
            module.optimizer = self.sgd_optimizer
        return loss_value, self.phase
        # --- Modified
    

class CustomValidator(BaseValidator):
    """Validator that accumulates a torchmetrics-style metric over the whole validation set.

    Each batch calls ``self.metric(preds, targets)``, which updates the
    metric's running state; the returned score comes from ``self.metric.compute()``.
    The state is reset afterwards.
    """

    def validation(self, module, dataloader, global_step=None):
        """Evaluate ``module`` on one dataloader or a list/tuple of dataloaders.

        Args:
            module: Network to evaluate; it is switched to eval mode and moved
                to ``self.device``.
            dataloader: A dataloader, or a list/tuple of dataloaders whose
                ``None`` entries are skipped.
            global_step: Training step shown in the progress bar, if any.

        Returns:
            ``self.metric.compute()`` accumulated over every validated batch.
        """
        module.eval()
        module.to(self.device)

        if isinstance(dataloader, (list, tuple)):
            loaders = [dl for dl in dataloader if dl is not None]
        else:
            loaders = [dataloader]
        total_batches = sum(len(dl) for dl in loaders)

        try:
            with torch.no_grad(), tqdm(
                itertools.chain(*loaders),
                total=total_batches,
                dynamic_ncols=True,
            ) as pbar:
                for batch in pbar:
                    images, targets = self.unpack_item(batch)

                    # Use the module's dedicated inference path when it has one and it is enabled.
                    if getattr(module, "inference", False) and self.output_infer:
                        infer_out = module.inference(images)
                    else:
                        infer_out = module.forward(images)

                    # Update the running metric state and get this batch's value.
                    batch_metric = self.metric(infer_out, targets).item()  # >> Modified

                    info = {
                        "metric_name": self.metric.__class__.__name__,
                        "batch_metric": batch_metric,
                        "global_step": global_step,
                    }
                    pbar.set_description(self.pbar_description.format(**info))

            return self.metric.compute()  # >> Modified
        finally:
            # Always clear the accumulated state, including when validation is
            # interrupted (e.g. KeyboardInterrupt in a notebook). Otherwise stale
            # batches would leak into the next validation's score.
            self.metric.reset()  # >> Modified

In [16]:
# Train LeNet-5 with the Adam -> SGD switching updater, then report accuracy on the test set.
import torch
import torchmetrics
from mnist_dataloaders import train_dataloader, val_dataloader, test_dataloader
from lenet import LeNet5

# Pick the device once and reuse it everywhere, so the cell also runs on
# machines without a GPU instead of failing on a hard-coded "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

validator = CustomValidator(
    metric=torchmetrics.classification.Accuracy(task="multiclass", num_classes=10).to(device),
    device=device,
)
updater = CustomUpdater()
trainer = CustomTrainer(max_iter=3000, eval_step=3000, validator=validator, device=device)

# train
print("Train:")
lenet = LeNet5().to(device)
trainer.train(module=lenet, updater=updater, train_dataloader=train_dataloader, val_dataloader=val_dataloader)

# test
print("\n Test:")
validator.validation(module=lenet, dataloader=test_dataloader)

Train:
--------
Device: cuda
# of Training Samples: 211
# of Validation Samples: 47
Max iteration: 3000 steps (validates per 3000 steps)
Checkpoint directory: ./checkpoints/
Evaluation metric: MulticlassAccuracy
--------


  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

[32m2024-08-19 18:21:28.500[0m | [32m[1mSUCCESS [0m | [36mmodules.base.trainer[0m:[36msuccess[0m:[36m71[0m - [32m[1mModel saved! Validation: (New) 0.95517 > (Old) 0.00000[0m

 Test:


  0%|          | 0/79 [00:00<?, ?it/s]

tensor(0.9576, device='cuda:0')

完成。