In [17]:
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel

class MMResNet50(BaseModel):
    """torchvision ResNet-50 exposed through MMEngine's ``BaseModel`` interface.

    Args:
        num_classes (int, optional): When given, the final fully connected
            layer is replaced by a freshly initialised one with this many
            outputs. When ``None`` (default) the pretrained classification
            head is kept unchanged.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        # NOTE: newer torchvision releases deprecate ``pretrained=True`` in
        # favour of ``weights=...``; the flag is kept here so the cell still
        # runs on older torchvision versions.
        self.resnet = torchvision.models.resnet50(pretrained=True)
        if num_classes is not None:
            # Local import: the notebook only imports ``torch.nn.functional``.
            from torch import nn
            self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, imgs, labels=None, mode='tensor'):
        """Run the backbone and shape the result according to ``mode``.

        Args:
            imgs: Batch of input images fed to the ResNet.
            labels: Ground-truth class indices. Required for ``'loss'``;
                passed through untouched for ``'predict'``.
            mode (str): ``'loss'``, ``'predict'`` or ``'tensor'``.

        Returns:
            * ``'loss'``: ``{'loss': <cross-entropy between logits and labels>}``
            * ``'predict'``: the tuple ``(logits, labels)``
            * ``'tensor'``: the raw logits

        Raises:
            ValueError: If ``mode`` is none of the supported values. The
                original implementation silently returned ``None`` here.
        """
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        if mode == 'predict':
            return x, labels
        if mode == 'tensor':
            return x
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'loss', 'predict' or 'tensor'.")

In [18]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Directory handed to torchvision's CIFAR10 dataset as its root; shared by the
# train and validation splits so it only has to be changed in one place.
DATA_ROOT = '/home/akiyo/nfs/zhang/dataset'
# Number of samples per batch for both dataloaders.
BATCH_SIZE = 128

# Per-channel mean / std forwarded to ``transforms.Normalize``.
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])

# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip as augmentation, then tensor conversion and normalisation.
train_transformer = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(**norm_cfg)
])

# Validation pipeline: no augmentation, only tensor conversion + normalisation.
valid_transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(**norm_cfg)
])
# Backward-compatible alias for the original (misspelled) name, in case other
# cells still refer to it.
vaild_transformer = valid_transformer

train_dataloader = DataLoader(
    batch_size=BATCH_SIZE, shuffle=True,
    dataset=torchvision.datasets.CIFAR10(
        DATA_ROOT,
        train=True, download=True, transform=train_transformer
))

val_dataloader = DataLoader(
    batch_size=BATCH_SIZE, shuffle=False,
    dataset=torchvision.datasets.CIFAR10(
        DATA_ROOT,
        train=False, download=True, transform=valid_transformer
))

Files already downloaded and verified
Files already downloaded and verified


In [19]:
from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    """Top-1 classification accuracy, reported as a percentage."""

    def process(self, data_batch, data_samples):
        """Record the per-batch statistics needed to compute accuracy.

        Args:
            data_batch: The input batch; not used by this metric.
            data_samples: A ``(score, gt)`` pair where ``score`` holds the
                per-class scores (argmax is taken over dim 1) and ``gt`` the
                ground-truth class indices.
        """
        score, gt = data_samples
        # Store plain Python ints rather than 0-d tensors so the accumulated
        # results (and the final metric) are ordinary numbers.
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().item(),
        })

    def compute_metrics(self, results):
        """Aggregate the per-batch records into the final metric.

        Args:
            results (list[dict]): Records appended by :meth:`process`, each
                with the keys ``'batch_size'`` and ``'correct'``.

        Returns:
            dict: ``{'accuracy': <float percentage in [0, 100]>}``. When no
            samples were processed the accuracy is reported as ``0.0``
            instead of raising ``ZeroDivisionError``.
        """
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        if total_size == 0:
            # Nothing was evaluated; avoid dividing by zero.
            return dict(accuracy=0.0)
        # ``float`` also covers records that still carry 0-d tensors.
        return dict(accuracy=100 * float(total_correct) / total_size)

In [20]:
from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    # --- model & output location -------------------------------------
    # Model to train and validate; it has to follow the interface the
    # runner expects from a model.
    model=MMResNet50(),
    # Working directory for the run (checkpoints are written here).
    work_dir='/home/akiyo/sandbox/work_dirs/quick_start',

    # --- training ----------------------------------------------------
    # Any object following the PyTorch data loader protocol.
    train_dataloader=train_dataloader,
    # Optimizer wrapper config: SGD with lr=0.001 and momentum=0.9. The
    # wrapper is also where extras such as AMP or gradient accumulation
    # would be configured.
    optim_wrapper={'optimizer': {'type': SGD, 'lr': 0.001, 'momentum': 0.9}},
    # Epoch-based training loop: 5 epochs, validation interval of 1.
    train_cfg={'by_epoch': True, 'max_epochs': 5, 'val_interval': 1},

    # --- validation --------------------------------------------------
    # Also has to follow the PyTorch data loader protocol.
    val_dataloader=val_dataloader,
    # No additional validation-loop options are set.
    val_cfg={},
    # Evaluator built from the custom ``Accuracy`` metric.
    val_evaluator={'type': Accuracy},
)

runner.train()

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/akiyo/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 35.1MB/s]


12/15 17:25:26 - mmengine - [4m[37mINFO[0m - 
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.8.15 (default, Nov 24 2022, 15:19:38) [GCC 11.2.0]
    CUDA available: True
    numpy_random_seed: 146479247
    GPU 0,1,2,3,4,5,6,7: Tesla V100-SXM2-32GB
    CUDA_HOME: /usr/local/cuda
    NVCC: Cuda compilation tools, release 11.4, V11.4.152
    GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
    PyTorch: 1.11.0+cu113
    PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.5.2 (Git Hash a9302535553c73243c632ad3c4c80beec3d19a1e)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.3
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-ge



12/15 17:25:26 - mmengine - [4m[37mINFO[0m - Checkpoints will be saved to /home/akiyo/sandbox/work_dirs/quick_start.
12/15 17:25:28 - mmengine - [4m[37mINFO[0m - Epoch(train) [1][ 10/391]  lr: 1.0000e-03  eta: 0:03:52  time: 0.1197  data_time: 0.0564  memory: 1159  loss: 10.1526
12/15 17:25:29 - mmengine - [4m[37mINFO[0m - Epoch(train) [1][ 20/391]  lr: 1.0000e-03  eta: 0:03:49  time: 0.1170  data_time: 0.0651  memory: 1159  loss: 5.6264
12/15 17:25:30 - mmengine - [4m[37mINFO[0m - Epoch(train) [1][ 30/391]  lr: 1.0000e-03  eta: 0:03:49  time: 0.1212  data_time: 0.0669  memory: 1159  loss: 3.5788
12/15 17:25:31 - mmengine - [4m[37mINFO[0m - Epoch(train) [1][ 40/391]  lr: 1.0000e-03  eta: 0:03:49  time: 0.1214  data_time: 0.0704  memory: 1159  loss: 2.4943
12/15 17:25:32 - mmengine - [4m[37mINFO[0m - Epoch(train) [1][ 50/391]  lr: 1.0000e-03  eta: 0:03:44  time: 0.1093  data_time: 0.0636  memory: 1159  loss: 2.2143
12/15 17:25:33 - mmengine - [4m[37mINFO[0m - Epoch(

KeyboardInterrupt: 