[Feature] Dedicated MMClsWandbHook for MMClassification (Weights and Biases Integration) (open-mmlab#764)

* wandb integration

* visualize using wandb tables

* wandb tables enhanced

* Refactor MMClsWandbHook (open-mmlab#1)

* [Enhance] Add extra dataloader settings in configs. (open-mmlab#752)

* Use `train_dataloader`, `val_dataloader` and `test_dataloader` settings
in the `data` field to specify different arguments.

* Fix bug

* Fix bug

* [Enhance] Improve CPE performance by reducing memory copy. (open-mmlab#762)

* [Feature] Support resize relative position embedding in `SwinTransformer`. (open-mmlab#749)

* [Feature]: Add resize rel pos embed

* [Refactor]: Create a separated resize_rel_pos_bias_table func

* [Refactor]: Refactor rel pos embed bias

* [Refactor]: Move interpolate into func

* Remove index buffer only when window_size changes

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Feature] Add PoolFormer backbone and checkpoints. (open-mmlab#746)

* add PoolFormer

* fix some typos in PoolFormer

* fix lint error

* modify out_indices and gap

* fix typo

* fix lint

* fix typo

* fix typo in PoolFormer README

* fix lint

* Update some paths

* Refactor freeze_stages method

* Add unit tests

* Fix lint

Co-authored-by: mzr1996 <mzr1996@163.com>

* Bump version to v0.22.1 (open-mmlab#785)

* [Docs] Refine API reference. (open-mmlab#774)

* [Docs] Refine API reference

* Add PoolFormer

* [Docs] Fix docs.

* [Enhance] Reduce the memory usage of unit tests for Swin-Transformer. (open-mmlab#759)

* [Feature] Support VAN. (open-mmlab#739)

* add van

* fix config

* add metafile

* add test

* model convert script

* fix review

* fix lint

* fix the configs and improve docs

* rm debug lines

* add VAN into api

Co-authored-by: Yu Zhaohui <1105212286@qq.com>

* [Feature] Support DenseNet. (open-mmlab#750)

* init add densenet implementation

* Add config and converted models

* update meta

* add test for memory efficient

* Add docs

* add doc for jit

* Update checkpoint path

* Update readthedocs

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Fix] Use symbolic link in the API reference of Chinese docs.

* [Enhance] Support training on IPU and add fine-tuning configs of ViT. (open-mmlab#723)

* implement training and evaluation on IPU

* fp16 SOTA

* Throughput reaches 5600

* 123

* add poptorch dataloader

* change ipu_replicas to ipu-replicas

* add noqa to config long line(website)

* remove ipu dataloader test code

* del one blank line in test_builder

* refine the dataloader initialization

* fix a typo

* refine args for dataloader

* remove an annotated line

* process one more conflict

* adjust code structure in mmcv.ipu

* adjust ipu code structure in mmcv

* IPUDataloader to IPUDataLoader

* align with mmcv

* adjust according to mmcv

* mmcv code structure fixed

Co-authored-by: hudi <dihu@graphcore.ai>

* [Fix] Fix lint and mmcv version requirement for IPU.

* Bump version to v0.23.0 (open-mmlab#809)

* Refactor Wandb hook and refine docstring

Co-authored-by: XiaobingZhang <xiaobing.zhang@intel.com>
Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Co-authored-by: Weihao Yu <1090924009@qq.com>
Co-authored-by: takuoko <to78314910@gmail.com>
Co-authored-by: Yu Zhaohui <1105212286@qq.com>
Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com>
Co-authored-by: Hu Di <476658825@qq.com>
Co-authored-by: hudi <dihu@graphcore.ai>

* shuffle val data

* minor updates

* minor fix

Co-authored-by: Ma Zerun <mzr1996@163.com>
Co-authored-by: XiaobingZhang <xiaobing.zhang@intel.com>
Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Co-authored-by: Weihao Yu <1090924009@qq.com>
Co-authored-by: takuoko <to78314910@gmail.com>
Co-authored-by: Yu Zhaohui <1105212286@qq.com>
Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com>
Co-authored-by: Hu Di <476658825@qq.com>
Co-authored-by: hudi <dihu@graphcore.ai>
10 people committed Jun 2, 2022
1 parent 5b6f407 commit 4145384
Showing 6 changed files with 425 additions and 4 deletions.
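
For context, the sketch below shows how the new hook might be enabled from a config via ``log_config``. It is an illustration only: the argument names (``init_kwargs``, ``interval``, ``log_checkpoint``, ``num_eval_images``) are assumptions in the style of other MMCV logger hooks, not a definitive reference for this commit.

log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(
            type='MMClsWandbHook',
            # Assumed keyword arguments; check the hook's docstring for the
            # authoritative interface.
            init_kwargs=dict(project='mmclassification'),  # passed to wandb.init
            interval=10,             # logging interval in iterations
            log_checkpoint=True,     # also upload checkpoints as W&B artifacts
            num_eval_images=100),    # number of validation images to visualize
    ])
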
1 change: 1 addition & 0 deletions docs/en/api/core.rst
@@ -48,6 +48,7 @@ Hook
    ClassNumCheckHook
    PreciseBNHook
    CosineAnnealingCooldownLrUpdaterHook
+   MMClsWandbHook
 
 
 Optimizers
3 changes: 1 addition & 2 deletions mmcls/apis/train.py
@@ -8,9 +8,8 @@
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
                          build_optimizer, build_runner, get_dist_info)
-from mmcv.runner.hooks import DistEvalHook, EvalHook
 
-from mmcls.core import DistOptimizerHook
+from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
 from mmcls.datasets import build_dataloader, build_dataset
 from mmcls.utils import get_root_logger
 
4 changes: 3 additions & 1 deletion mmcls/core/evaluation/__init__.py
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .eval_hooks import DistEvalHook, EvalHook
 from .eval_metrics import (calculate_confusion_matrix, f1_score, precision,
                            precision_recall_f1, recall, support)
 from .mean_ap import average_precision, mAP
 from .multilabel_eval_metrics import average_performance
 
 __all__ = [
     'precision', 'recall', 'f1_score', 'support', 'average_precision', 'mAP',
-    'average_performance', 'calculate_confusion_matrix', 'precision_recall_f1'
+    'average_performance', 'calculate_confusion_matrix', 'precision_recall_f1',
+    'EvalHook', 'DistEvalHook'
 ]
78 changes: 78 additions & 0 deletions mmcls/core/evaluation/eval_hooks.py
@@ -0,0 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import torch.distributed as dist
from mmcv.runner import DistEvalHook as BaseDistEvalHook
from mmcv.runner import EvalHook as BaseEvalHook
from torch.nn.modules.batchnorm import _BatchNorm


class EvalHook(BaseEvalHook):
    """Non-distributed evaluation hook.
    Compared with the ``EvalHook`` in MMCV, this hook saves the latest
    evaluation results as an attribute for other hooks to use (like
    ``MMClsWandbHook``).
    """

    def __init__(self, dataloader, **kwargs):
        super(EvalHook, self).__init__(dataloader, **kwargs)
        self.latest_results = None

    def _do_evaluate(self, runner):
        """Perform evaluation and save the checkpoint."""
        results = self.test_fn(runner.model, self.dataloader)
        self.latest_results = results
        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
        key_score = self.evaluate(runner, results)
        # ``key_score`` may be ``None``, in which case saving the best
        # checkpoint is skipped.
        if self.save_best and key_score:
            self._save_ckpt(runner, key_score)


class DistEvalHook(BaseDistEvalHook):
    """Distributed evaluation hook.
    Compared with the ``DistEvalHook`` in MMCV, this hook saves the latest
    evaluation results as an attribute for other hooks to use (like
    ``MMClsWandbHook``).
    """

    def __init__(self, dataloader, **kwargs):
        super(DistEvalHook, self).__init__(dataloader, **kwargs)
        self.latest_results = None

    def _do_evaluate(self, runner):
        """Perform evaluation and save the checkpoint."""
        # Synchronization of BatchNorm's buffers (running_mean and
        # running_var) is not supported in PyTorch DDP, which may cause
        # inconsistent performance of models on different ranks, so we
        # broadcast BatchNorm's buffers of rank 0 to the other ranks.
        if self.broadcast_bn_buffer:
            model = runner.model
            for name, module in model.named_modules():
                if isinstance(module,
                              _BatchNorm) and module.track_running_stats:
                    dist.broadcast(module.running_var, 0)
                    dist.broadcast(module.running_mean, 0)

        tmpdir = self.tmpdir
        if tmpdir is None:
            tmpdir = osp.join(runner.work_dir, '.eval_hook')

        results = self.test_fn(
            runner.model,
            self.dataloader,
            tmpdir=tmpdir,
            gpu_collect=self.gpu_collect)
        self.latest_results = results
        if runner.rank == 0:
            print('\n')
            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
            key_score = self.evaluate(runner, results)
            # ``key_score`` may be ``None``, in which case saving the best
            # checkpoint is skipped.
            if self.save_best and key_score:
                self._save_ckpt(runner, key_score)
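
As a side note, the sketch below (not part of this commit) illustrates how a downstream hook could consume the ``latest_results`` attribute saved by the hooks above. The hook name and the logging it performs are hypothetical; only ``runner.hooks``, ``runner.logger`` and the ``HOOKS`` registry are standard MMCV APIs.

from mmcv.runner import HOOKS, Hook

from mmcls.core import DistEvalHook, EvalHook


@HOOKS.register_module()
class EvalResultsConsumerHook(Hook):
    """Hypothetical hook that reads the latest evaluation results."""

    def after_train_epoch(self, runner):
        # Locate the (Dist)EvalHook registered on the runner, if any.
        eval_hook = next((h for h in runner.hooks
                          if isinstance(h, (EvalHook, DistEvalHook))), None)
        if eval_hook is None or eval_hook.latest_results is None:
            return
        # ``latest_results`` holds the per-sample outputs collected during
        # evaluation; here we only report how many samples were evaluated.
        runner.logger.info(
            f'Latest evaluation produced '
            f'{len(eval_hook.latest_results)} results.')
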
3 changes: 2 additions & 1 deletion mmcls/core/hook/__init__.py
@@ -2,8 +2,9 @@
 from .class_num_check_hook import ClassNumCheckHook
 from .lr_updater import CosineAnnealingCooldownLrUpdaterHook
 from .precise_bn_hook import PreciseBNHook
+from .wandblogger_hook import MMClsWandbHook
 
 __all__ = [
     'ClassNumCheckHook', 'PreciseBNHook',
-    'CosineAnnealingCooldownLrUpdaterHook'
+    'CosineAnnealingCooldownLrUpdaterHook', 'MMClsWandbHook'
 ]