# Reference

**MMEngine intro**: https://mmengine.readthedocs.io/en/latest/get_started/introduction.html

**MMEngine: 15 minutes to get started**: https://mmengine.readthedocs.io/en/latest/get_started/15_minutes.html

# Content

## Build a model

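A model class needs to inherit from `BaseModel` (`mmengine.model.BaseModel`).

The `forward` method receives one batch from the dataloader. A list or tuple batch, such as the `[imgs, labels]` pair that CIFAR-10 produces, is unpacked into positional arguments (here `imgs` and `labels`), and `BaseModel` additionally passes a `mode` keyword argument.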

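The `mode` parameter has the following common options:
* loss: return a `dict` of losses; used during training
* predict: return the predictions together with the ground truth; used during validation and testing, where the output is passed to the evaluator
* tensor: return the raw network output (here the logits); used, for example, in model complexity analysis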

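```python
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        # Randomly initialized torchvision ResNet-50 (1000-class head)
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels=None, mode="tensor"):
        x = self.resnet(imgs)  # logits, shape (N, 1000)
        if mode == "loss":
            # Training: return a dict of losses
            return {"loss": F.cross_entropy(x, labels)}
        elif mode == "predict":
            # Validation/testing: predictions and ground truth for the evaluator
            return x, labels
        elif mode == "tensor":
            # Raw network output, e.g. for complexity analysis
            return x
        raise ValueError(f"Invalid mode: {mode!r}")
```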
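When the runner drives the model, `BaseModel.train_step` calls `forward` with `mode="loss"`, while `val_step` and `test_step` use `mode="predict"`; a list or tuple batch is unpacked into positional arguments and a dict batch into keyword arguments. The plain-Python sketch below mimics that dispatch so the contract can be run without PyTorch. `ToyModel` and `run_forward` are hypothetical stand-ins, not MMEngine APIs, and the `loss` branch is a hand-written mirror of `F.cross_entropy` (the batch mean of the negative log-softmax of the true class).

```python
import math


class ToyModel:
    """Hypothetical stand-in for a BaseModel subclass, using plain lists."""

    def forward(self, imgs, labels=None, mode="tensor"):
        # Pretend network: each "image" is already a row of class scores (logits).
        logits = [list(img) for img in imgs]
        if mode == "loss":
            # Mean over the batch of -log softmax(logits)[label], i.e. cross-entropy.
            total = 0.0
            for row, y in zip(logits, labels):
                m = max(row)  # shift for a numerically stable log-sum-exp
                log_norm = m + math.log(sum(math.exp(v - m) for v in row))
                total += log_norm - row[y]
            return {"loss": total / len(labels)}
        elif mode == "predict":
            return logits, labels
        elif mode == "tensor":
            return logits
        raise ValueError(f"Invalid mode: {mode!r}")

    __call__ = forward  # allow model(...) calls, as with an nn.Module


def run_forward(model, data, mode):
    """Hypothetical helper: unpack a batch the way BaseModel does before forward."""
    if isinstance(data, dict):
        return model(**data, mode=mode)
    if isinstance(data, (list, tuple)):
        return model(*data, mode=mode)
    raise TypeError(f"Unsupported batch type: {type(data).__name__}")


# One batch shaped like a CIFAR-10 batch: [imgs, labels]
batch = [[[2.0, 0.0], [0.0, 1.0]], [0, 1]]
model = ToyModel()
loss = run_forward(model, batch, mode="loss")            # what a training step uses
preds, gts = run_forward(model, batch, mode="predict")  # what a validation step uses
print(loss, preds, gts)
```

The same `forward` serves all three phases; only `mode` changes, which is why the real model keeps the branching inside a single method.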

## Build a dataset and dataloader

For this simple example, torchvision's built-in CIFAR-10 dataset is sufficient, and the dataloaders are ordinary PyTorch `DataLoader` objects.

```python
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Per-channel (RGB) mean and std used by Normalize
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])

train_dataloader = DataLoader(
    batch_size=32,
    shuffle=True,
    dataset=torchvision.datasets.CIFAR10(
        './data/cifar10',
        train=True,
        download=True,
        # Training: random crop + horizontal flip for augmentation
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)
        ])
    )
)

val_dataloader = DataLoader(
    batch_size=32,
    shuffle=False,
    dataset=torchvision.datasets.CIFAR10(
        './data/cifar10',
        train=False,
        download=True,
        # Validation: no augmentation, only tensor conversion and normalization
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)
        ])
    )
)
```

```text
Files already downloaded and verified
Files already downloaded and verified
```
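`transforms.ToTensor()` maps pixel values from [0, 255] to [0.0, 1.0], and `transforms.Normalize(**norm_cfg)` then rescales each channel as `(x - mean[c]) / std[c]`. The plain-Python sketch below applies that arithmetic to a single pixel using the `norm_cfg` values from the cell above; `normalize_pixel` is a hypothetical helper for illustration, not torchvision's implementation.

```python
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])


def normalize_pixel(rgb_uint8, mean, std):
    """Hypothetical helper: mimic ToTensor + Normalize for one RGB pixel (0-255 ints)."""
    scaled = [v / 255.0 for v in rgb_uint8]  # ToTensor: [0, 255] -> [0.0, 1.0]
    # Normalize: subtract the channel mean, divide by the channel std
    return [(x - m) / s for x, m, s in zip(scaled, mean, std)]


pixel = (125, 123, 114)  # hypothetical pixel, close to the channel means
print([round(v, 3) for v in normalize_pixel(pixel, **norm_cfg)])
print([round(v, 3) for v in normalize_pixel((255, 255, 255), **norm_cfg)])
```

A pixel near the dataset mean lands close to 0 in every channel, while a pure white pixel lands roughly 2.3 to 2.8 standard deviations above it, which is the centred, unit-scale input range the network is trained on.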


## Build an evaluation metric

The evaluation metric class needs to inherit from `BaseMetric` and implement the `process` and `compute_metrics` methods.

The `process` method takes `data_batch` (the batch of data from the dataloader) and `data_samples` (the model outputs for that batch, i.e. what `forward` returns in `predict` mode) as parameters. It stores whatever per-batch information is needed in the `self.results` list.

`compute_metrics` takes a `results` parameter, which holds all the information saved by `process`, and returns the computed evaluation metrics as a `dict`.

```python
from mmengine.evaluator import BaseMetric


class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        # data_samples is the (scores, labels) pair returned by forward(mode="predict")
        score, gt = data_samples
        self.results.append({
            "batch_size": len(gt),
            "correct": (score.argmax(dim=1) == gt).sum().cpu()
        })

    def compute_metrics(self, results):
        # results holds every record appended by process()
        total_correct = sum(item["correct"] for item in results)
        total_size = sum(item["batch_size"] for item in results)
        return dict(accuracy=100 * total_correct / total_size)
```
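Roughly, `process` is called once per validation batch, and once the loop finishes `compute_metrics` is called on the accumulated `self.results` (in the real `BaseMetric`, `evaluate` also gathers results across processes in distributed runs and clears them afterwards). The plain-Python sketch below imitates that lifecycle with a hypothetical `MiniAccuracy` class that uses lists instead of tensors, so the bookkeeping is easy to follow; its `evaluate` is a simplified assumption, not MMEngine's signature.

```python
class MiniAccuracy:
    """Hypothetical stand-in for a BaseMetric subclass (lists instead of tensors)."""

    def __init__(self):
        self.results = []  # per-batch records filled by process()

    def process(self, data_batch, data_samples):
        scores, gt = data_samples  # same (scores, labels) pair that mode="predict" returns
        preds = [row.index(max(row)) for row in scores]  # argmax per sample
        self.results.append({
            "batch_size": len(gt),
            "correct": sum(p == y for p, y in zip(preds, gt)),
        })

    def compute_metrics(self, results):
        total_correct = sum(item["correct"] for item in results)
        total_size = sum(item["batch_size"] for item in results)
        return dict(accuracy=100 * total_correct / total_size)

    def evaluate(self):
        # Simplified: the real BaseMetric.evaluate also collects results across ranks.
        metrics = self.compute_metrics(self.results)
        self.results.clear()  # ready for the next validation run
        return metrics


metric = MiniAccuracy()
# Two hypothetical validation batches, each a (scores, labels) pair
metric.process(None, ([[0.9, 0.1], [0.2, 0.8]], [0, 0]))
metric.process(None, ([[0.3, 0.7]], [1]))
print(metric.evaluate())
```

Storing per-batch counts rather than a running accuracy keeps the final number exact even when the last batch is smaller than the others.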

## Build a runner and run the task

```python
from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    # The model used for training and validation.
    # It must follow the BaseModel interface (forward with a `mode` argument)
    model=MMResNet50(),
    # Working directory where training logs and checkpoints are saved
    work_dir='./logs_and_checkpoints',
    # The train dataloader needs to follow the PyTorch DataLoader protocol
    train_dataloader=train_dataloader,
    # Optimizer wrapper: wraps the optimizer and adds features such as
    # AMP and gradient accumulation
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # Training configs: number of training epochs, validation interval, etc.
    # by_epoch selects epoch-based (True) or iteration-based (False) training;
    # it affects the frequency of logging, checkpoint saving, and validation
    # Ref1: https://mmengine.readthedocs.io/en/latest/common_usage/set_interval.html
    # Ref2: https://mmengine.readthedocs.io/en/latest/common_usage/epoch_to_iter.html
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    # The validation dataloader also needs to follow the PyTorch DataLoader protocol
    val_dataloader=val_dataloader,
    # Validation configs for any additional validation parameters (none here)
    val_cfg=dict(),
    # Validation evaluator: the custom Accuracy metric defined above
    val_evaluator=dict(type=Accuracy),
    # Resume training (True) or start from scratch (False). With resume=True and
    # no load_from, the latest checkpoint in work_dir is used if one exists
    resume=True,
    # Optionally specify a checkpoint path explicitly
    # load_from='./logs_and_checkpoints/epoch_2.pth',
)

runner.train()
```
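With `by_epoch=True`, `max_epochs` and `val_interval` count epochs, and one epoch is one pass over `train_dataloader`. The plain-Python arithmetic below shows how many iterations that corresponds to for CIFAR-10 (50,000 training and 10,000 test images) with `batch_size=32`, assuming the PyTorch default `drop_last=False`, so the final partial batch still counts as an iteration.

```python
import math

train_size, val_size = 50_000, 10_000  # CIFAR-10 train / test split sizes
batch_size = 32
max_epochs = 5
val_interval = 1  # validate every epoch (by_epoch=True)

# drop_last=False (PyTorch default): the last partial batch is kept, hence ceil
iters_per_epoch = math.ceil(train_size / batch_size)
total_train_iters = iters_per_epoch * max_epochs
val_iters = math.ceil(val_size / batch_size)
num_val_runs = max_epochs // val_interval

print(iters_per_epoch, total_train_iters, val_iters, num_val_runs)  # 1563 7815 313 5
```

Going the other way, an iteration-based run of about the same length would be something like `train_cfg=dict(by_epoch=False, max_iters=7815, val_interval=1563)`; Ref2 above describes the other settings (hooks, schedulers) that also need switching.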

## Compute the FLOPs and parameters of a model
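The analysis below prints a per-module table, and its `conv1` row can be reproduced by hand as a sanity check. In torchvision's ResNet-50, `conv1` is a 7x7 convolution from 3 to 64 channels with stride 2, padding 3 and no bias, so a 224x224 input gives a 112x112 output. The sketch assumes, as in fvcore (which, as far as I know, this analysis module is adapted from), that one multiply-accumulate is counted as one FLOP; under that convention the results agree with the 9.408K parameters, 0.118G FLOPs and 0.803M activations in the table.

```python
# conv1 of torchvision's ResNet-50: Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
in_ch, out_ch, k = 3, 64, 7
h_in = w_in = 224
stride, padding = 2, 3

# Standard convolution output-size formula
h_out = (h_in + 2 * padding - k) // stride + 1
w_out = (w_in + 2 * padding - k) // stride + 1

params = out_ch * in_ch * k * k       # weight tensor (64, 3, 7, 7); no bias
flops = params * h_out * w_out        # one multiply-accumulate per weight per output pixel
activations = out_ch * h_out * w_out  # elements in the output feature map

print(h_out, params, flops, activations)  # 112 9408 118013952 802816
```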

```python
from mmengine.analysis import get_model_complexity_info

input_shape = (3, 224, 224)  # (C, H, W) of one input image
model = MMResNet50()
analysis_results = get_model_complexity_info(model, input_shape)
print(analysis_results["out_table"])
print("Model Flops: {}".format(analysis_results["flops_str"]))
print("Model Parameters: {}".format(analysis_results["params_str"]))
```


+------------------------+----------------------+------------+--------------+
| module                 | #parameters or shape | #flops     | #activations |
+------------------------+----------------------+------------+--------------+
| resnet                 | 25.557M              | 4.145G     | 11.115M      |
|  conv1                 |  9.408K              |  0.118G    |  0.803M      |
|   conv1.weight         |   (64, 3, 7, 7)      |            |              |
|  bn1                   |  0.128K              |  4.014M    |  0           |
|   bn1.weight           |   (64,)              |            |              |
|   bn1.bias             |   (64,)              |            |              |
|  layer1                |  0.216M              |  0.69G     |  4.415M      |
|   layer1.0             |   75.008K            |   0.241G   |   2.007M     |
|    layer1.0.conv1      |    4.096K         