# [lenet04] BaseValidator as a Validator Template

This tutorial shows how to customize the template to build your own Validator.

A Validator is used in two places: validation during training, and testing after training.


## Introduction

In [14]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

from modules.base.validator import BaseValidator
from print_source import print_source

print_source(BaseValidator)

The structure is largely the same as the Trainer. The main difference is that `__init__` takes two additional arguments, `is_train` and `output_infer`.

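1. `is_train` indicates whether this Validator runs inside a training loop. If it does, the current training step is displayed.
2. `output_infer` decides whether the network's output during validation comes from `forward` or from `inference`. Typically, `forward` returns logits, while `inference` returns labels.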
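
To make the `forward` / `inference` distinction concrete, here is a minimal sketch in plain Python. `TinyClassifier` and `run_model` are hypothetical names used only for illustration and are not part of the tutorial's modules. The dispatch mirrors the `getattr(module, "inference", False) and self.output_infer` check used in the `validation` method later in this tutorial.

```python
class TinyClassifier:
    """Hypothetical stand-in for a network; not part of the tutorial's modules."""

    def forward(self, x):
        # Pretend these are the raw class scores (logits) for one sample.
        return [0.1 * x, 0.5 * x, -0.2 * x]

    def inference(self, x):
        # Post-process the logits into a predicted label via argmax.
        logits = self.forward(x)
        return max(range(len(logits)), key=logits.__getitem__)


def run_model(module, x, output_infer):
    # Use `inference` only when the module defines it and `output_infer` is set;
    # otherwise fall back to `forward`.
    if getattr(module, "inference", False) and output_infer:
        return module.inference(x)
    return module.forward(x)


model = TinyClassifier()
logits = run_model(model, 2.0, output_infer=False)  # list of three scores
label = run_model(model, 2.0, output_infer=True)    # index of the largest score
print(logits)
print(label)
```

Which output you want depends on the metric: some metrics accept logits directly, while others expect predicted labels.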

As for this snippet in `validation`:

```python
if not isinstance(dataloader, (list, tuple)):
    dataloader = [dataloader]
else:
    dataloader = [dl for dl in dataloader if dl is not None]
data_iter = itertools.chain(*dataloader)
pbar = tqdm(
    data_iter,
    total=sum(len(dl) for dl in dataloader),
    dynamic_ncols=True,
)
```

When you pass in a list containing several dataloaders, it chains them together for you.
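
The chaining behaviour can be tried in isolation. In this sketch, plain lists stand in for `DataLoader` objects (an assumption made so the example runs without PyTorch); each element plays the role of one batch.

```python
import itertools

# Plain lists stand in for DataLoader objects; each item is one "batch".
val_a = ["a0", "a1", "a2"]
val_b = ["b0", "b1"]
dataloader = [val_a, None, val_b]  # None entries are filtered out

if not isinstance(dataloader, (list, tuple)):
    dataloader = [dataloader]
else:
    dataloader = [dl for dl in dataloader if dl is not None]

# Iterate over all loaders back to back, as if they were a single loader.
data_iter = itertools.chain(*dataloader)
# Total number of batches, used as the progress bar length.
total = sum(len(dl) for dl in dataloader)

batches = list(data_iter)
print(total, batches)
```

Because `itertools.chain` is lazy, each loader is only iterated once the previous one is exhausted.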

## Example 

Here we implement a Validator that works with [TorchMetrics](https://lightning.ai/docs/torchmetrics/stable/pages/quickstart.html). <br>
Following the TorchMetrics API, we place the metric calls at the corresponding points in the Validator. <br>
Changes are marked with `# >> Modified`.
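
The three modified lines rely on the TorchMetrics pattern: calling the metric returns the value for the current batch and accumulates state, `compute()` returns the value over everything accumulated so far, and `reset()` clears that state. The sketch below imitates the pattern in plain Python. `RunningAccuracy` is a hypothetical stand-in written for illustration, not a TorchMetrics class.

```python
class RunningAccuracy:
    """Hypothetical stand-in that mimics the call / compute / reset pattern."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def __call__(self, preds, targets):
        # Score the current batch and fold it into the running state.
        hits = sum(int(p == t) for p, t in zip(preds, targets))
        self.correct += hits
        self.total += len(targets)
        return hits / len(targets)

    def compute(self):
        # Aggregate over every batch seen since the last reset.
        return self.correct / self.total

    def reset(self):
        self.correct = 0
        self.total = 0


metric = RunningAccuracy()
batch_1 = metric([1, 0, 2, 2], [1, 0, 2, 1])  # 3 of 4 correct
batch_2 = metric([0, 0], [1, 0])              # 1 of 2 correct
overall = metric.compute()                    # 4 of 6 correct
metric.reset()                                # ready for the next validation run
print(batch_1, batch_2, overall)
```

Resetting after `compute()` matters when the same Validator is reused across several validation rounds. Without it, later rounds would mix in predictions from earlier ones.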

In [10]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import torch
import torchmetrics
import itertools
from tqdm.auto import tqdm
from modules.base.validator import BaseValidator

class CustomValidator(BaseValidator):
    def validation(self, module, dataloader, global_step=None):

        module.eval()
        module.to(self.device)

        if not isinstance(dataloader, (list, tuple)):
            dataloader = [dataloader]
        else:
            dataloader = [dl for dl in dataloader if dl is not None]
        data_iter = itertools.chain(*dataloader)
        pbar = tqdm(
            data_iter,
            total=sum(len(dl) for dl in dataloader),
            dynamic_ncols=True,
        )

        with torch.no_grad():
            for batch in pbar:
                # Unpack the batch into inputs and targets
                images, targets = self.unpack_item(batch)

                # Get inferred / forwarded results of module
                if getattr(module, "inference", False) and self.output_infer:
                    infer_out = module.inference(images)
                else:
                    infer_out = module.forward(images)

                # Compute the batch metric; calling the metric also accumulates state for compute()
                batch_metric = self.metric(infer_out, targets).item() # >> Modified

                # Update progressbar
                info = {
                    "metric_name": self.metric.__class__.__name__,
                    "batch_metric": batch_metric,
                    "global_step": global_step,
                }
                desc = self.pbar_description.format(**info)
                pbar.set_description(desc)

        output = self.metric.compute() # >> Modified
        self.metric.reset() # >> Modified
        return output

In [7]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

from modules.base.trainer import BaseTrainer
from modules.base.updater import BaseUpdater
from modules.base.validator import BaseValidator

from torch import nn
from mnist_dataloaders import train_dataloader, val_dataloader, test_dataloader
from lenet import LeNet5, batch_acc

validator = CustomValidator(metric=torchmetrics.classification.Accuracy(task="multiclass", num_classes=10).to("cuda"))
updater = BaseUpdater()
trainer = BaseTrainer(max_iter=1000, eval_step=334, validator=validator)


# train
print("Train:")
lenet = LeNet5().cuda()
trainer.train(module=lenet, updater=updater, train_dataloader=train_dataloader, val_dataloader=val_dataloader)

# test
print("\n Test:")
lenet.load("./checkpoints")
validator.validation(module=lenet, dataloader=test_dataloader)

Train:
--------
Device: cuda
# of Training Samples: 211
# of Validation Samples: 47
Max iteration: 1000 steps (validates per 334 steps)
Checkpoint directory: ./checkpoints/
Evaluation metric: MulticlassAccuracy
--------


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

2024-08-19 16:08:34.109 | SUCCESS  | modules.base.trainer:success:71 - Model saved! Validation: (New) 0.81783 > (Old) 0.00000


  0%|          | 0/47 [00:00<?, ?it/s]

2024-08-19 16:08:46.299 | SUCCESS  | modules.base.trainer:success:71 - Model saved! Validation: (New) 0.90783 > (Old) 0.81783


  0%|          | 0/47 [00:00<?, ?it/s]

2024-08-19 16:08:58.429 | SUCCESS  | modules.base.trainer:success:71 - Model saved! Validation: (New) 0.92933 > (Old) 0.90783

 Test:


  0%|          | 0/79 [00:00<?, ?it/s]

tensor(0.9376, device='cuda:0')

Done.