-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
613 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Implementation for DINO | ||
|
||
**NOTE**: We only guarantee correctness of the forward pass, not responsible for full reimplementation. | ||
|
||
First, ensure you are in the root directory of MMPretrain, then you have two choices | ||
to play with DINO in MMPretrain: | ||
|
||
## Slurm | ||
|
||
If you are using a cluster managed by Slurm, you can use the following command to | ||
start your job: | ||
|
||
```shell | ||
GPUS_PER_NODE=8 GPUS=8 CPUS_PER_TASK=16 bash projects/dino/tools/slurm_train.sh mm_model dino projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py --amp | ||
``` | ||
|
||
The above command will pre-train the model on a single node with 8 GPUs. | ||
|
||
## PyTorch | ||
|
||
If you are using a single machine, without any cluster management software, you can use the following command | ||
|
||
```shell | ||
NNODES=1 bash projects/dino/tools/dist_train.sh projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py 8 | ||
--amp | ||
``` |
103 changes: 103 additions & 0 deletions
103
projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
model = dict( | ||
type='DINO', | ||
data_preprocessor=dict( | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
bgr_to_rgb=True), | ||
backbone=dict(type='mmcls.VisionTransformer', arch='b', patch_size=16), | ||
neck=dict( | ||
type='DINONeck', | ||
in_channels=768, | ||
out_channels=65536, | ||
hidden_channels=2048, | ||
bottleneck_channels=256), | ||
head=dict( | ||
type='DINOHead', | ||
out_channels=65536, | ||
num_crops=10, | ||
student_temp=0.1, | ||
center_momentum=0.9)) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='DINOMultiCrop', | ||
global_crops_scale=(0.4, 1.0), | ||
local_crops_scale=(0.05, 0.4), | ||
local_crops_number=8), | ||
dict(type='PackInputs', meta_keys=['img_path']) | ||
] | ||
train_dataloader = dict( | ||
batch_size=32, | ||
num_workers=16, | ||
persistent_workers=True, | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
collate_fn=dict(type='default_collate'), | ||
dataset=dict( | ||
type='mmcls.ImageNet', | ||
data_root='/home/liushi_22151211/imagenet/classification', | ||
# ann_file='meta/train.txt', | ||
data_prefix=dict(img_path='train/'), | ||
pipeline=train_pipeline, | ||
)) | ||
optimizer = dict(type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05) | ||
optim_wrapper = dict( | ||
type='AmpOptimWrapper', | ||
optimizer=dict( | ||
type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05), | ||
paramwise_cfg=dict( | ||
custom_keys=dict( | ||
ln=dict(decay_mult=0.0), | ||
bias=dict(decay_mult=0.0), | ||
pos_embed=dict(decay_mult=0.0), | ||
mask_token=dict(decay_mult=0.0), | ||
cls_token=dict(decay_mult=0.0))), | ||
loss_scale='dynamic') | ||
param_scheduler = [ | ||
dict( | ||
type='LinearLR', | ||
start_factor=1e-09, | ||
by_epoch=True, | ||
begin=0, | ||
end=10, | ||
convert_to_iter_based=True), | ||
dict( | ||
type='CosineAnnealingLR', | ||
T_max=90, | ||
by_epoch=True, | ||
begin=10, | ||
end=100, | ||
convert_to_iter_based=True) | ||
] | ||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100) | ||
default_scope = 'mmpretrain' | ||
default_hooks = dict( | ||
runtime_info=dict(type='RuntimeInfoHook'), | ||
timer=dict(type='IterTimerHook'), | ||
logger=dict(type='LoggerHook', interval=100), | ||
param_scheduler=dict(type='ParamSchedulerHook'), | ||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1), | ||
sampler_seed=dict(type='DistSamplerSeedHook')) | ||
env_cfg = dict( | ||
cudnn_benchmark=False, | ||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), | ||
dist_cfg=dict(backend='nccl')) | ||
log_processor = dict( | ||
window_size=10, | ||
custom_cfg=[dict(data_src='', method='mean', window_size='global')]) | ||
vis_backends = [dict(type='LocalVisBackend')] | ||
visualizer = dict( | ||
type='UniversalVisualizer', | ||
vis_backends=[dict(type='LocalVisBackend')], | ||
name='visualizer') | ||
log_level = 'INFO' | ||
load_from = None | ||
resume = True | ||
randomness = dict(seed=2, diff_rank_seed=True) | ||
custom_hooks = [ | ||
dict( | ||
type='DINOTeacherTempWarmupHook', | ||
warmup_teacher_temp=0.04, | ||
teacher_temp=0.04, | ||
teacher_temp_warmup_epochs=0, | ||
max_epochs=100) | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .transform import * # noqa: F401,F403 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .processing import DINOMultiCrop | ||
|
||
__all__ = ['DINOMultiCrop'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import random | ||
|
||
from mmcv.transforms import RandomApply # noqa: E501 | ||
from mmcv.transforms import BaseTransform, Compose, RandomFlip, RandomGrayscale | ||
|
||
from mmpretrain.datasets.transforms import (ColorJitter, GaussianBlur, | ||
RandomResizedCrop, Solarize) | ||
from mmpretrain.registry import TRANSFORMS | ||
|
||
|
||
@TRANSFORMS.register_module() | ||
class DINOMultiCrop(BaseTransform): | ||
"""Multi-crop transform for DINO. | ||
This module applies the multi-crop transform for DINO. | ||
Args: | ||
global_crops_scale (int): Scale of global crops. | ||
local_crops_scale (int): Scale of local crops. | ||
local_crops_number (int): Number of local crops. | ||
""" | ||
|
||
def __init__(self, global_crops_scale: int, local_crops_scale: int, | ||
local_crops_number: int) -> None: | ||
super().__init__() | ||
self.global_crops_scale = global_crops_scale | ||
self.local_crops_scale = local_crops_scale | ||
|
||
flip_and_color_jitter = Compose([ | ||
RandomFlip(prob=0.5, direction='horizontal'), | ||
RandomApply([ | ||
ColorJitter( | ||
brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1) | ||
], | ||
prob=0.8), | ||
RandomGrayscale( | ||
prob=0.2, | ||
keep_channels=True, | ||
channel_weights=(0.114, 0.587, 0.2989), | ||
) | ||
]) | ||
|
||
self.global_transform_1 = Compose([ | ||
RandomResizedCrop( | ||
224, | ||
crop_ratio_range=global_crops_scale, | ||
interpolation='bicubic'), | ||
flip_and_color_jitter, | ||
GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), | ||
]) | ||
|
||
self.global_transform_2 = Compose([ | ||
RandomResizedCrop( | ||
224, | ||
crop_ratio_range=global_crops_scale, | ||
interpolation='bicubic'), | ||
flip_and_color_jitter, | ||
GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), | ||
Solarize(thr=128, prob=0.2), | ||
]) | ||
|
||
self.local_crops_number = local_crops_number | ||
self.local_transform = Compose([ | ||
RandomResizedCrop( | ||
96, | ||
crop_ratio_range=local_crops_scale, | ||
interpolation='bicubic'), | ||
flip_and_color_jitter, | ||
GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), | ||
]) | ||
|
||
def transform(self, results: dict) -> dict: | ||
ori_img = results['img'] | ||
crops = [] | ||
results['img'] = ori_img | ||
crops.append(self.global_transform_1(results)['img']) | ||
results['img'] = ori_img | ||
crops.append(self.global_transform_2(results)['img']) | ||
for _ in range(self.local_crops_number): | ||
results['img'] = ori_img | ||
crops.append(self.local_transform(results)['img']) | ||
results['img'] = crops | ||
return results | ||
|
||
def __repr__(self) -> str: | ||
repr_str = self.__class__.__name__ | ||
repr_str += f'(global_crops_scale = {self.global_crops_scale}, ' | ||
repr_str += f'local_crops_scale = {self.local_crops_scale}, ' | ||
repr_str += f'local_crop_number = {self.local_crops_number})' | ||
return repr_str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .hooks import * # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .dino_teacher_temp_warmup_hook import DINOTeacherTempWarmupHook | ||
|
||
__all__ = ['DINOTeacherTempWarmupHook'] |
33 changes: 33 additions & 0 deletions
33
projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import numpy as np | ||
from mmengine.hooks import Hook | ||
|
||
from mmpretrain.registry import HOOKS | ||
|
||
|
||
@HOOKS.register_module() | ||
class DINOTeacherTempWarmupHook(Hook): | ||
"""Warmup teacher temperature for DINO. | ||
This hook warmups the temperature for teacher to stabilize the training | ||
process. | ||
Args: | ||
warmup_teacher_temp (float): Warmup temperature for teacher. | ||
teacher_temp (float): Temperature for teacher. | ||
teacher_temp_warmup_epochs (int): Warmup epochs for teacher | ||
temperature. | ||
max_epochs (int): Maximum epochs for training. | ||
""" | ||
|
||
def __init__(self, warmup_teacher_temp: float, teacher_temp: float, | ||
teacher_temp_warmup_epochs: int, max_epochs: int) -> None: | ||
super().__init__() | ||
self.teacher_temps = np.concatenate( | ||
(np.linspace(warmup_teacher_temp, teacher_temp, | ||
teacher_temp_warmup_epochs), | ||
np.ones(max_epochs - teacher_temp_warmup_epochs) * teacher_temp)) | ||
|
||
def before_train_epoch(self, runner) -> None: | ||
runner.model.module.head.teacher_temp = self.teacher_temps[ | ||
runner.epoch] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .algorithm import * # noqa | ||
from .head import * # noqa | ||
from .neck import * # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .dino import DINO | ||
|
||
__all__ = ['DINO'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import List, Optional, Union | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from mmpretrain.models import BaseSelfSupervisor, CosineEMA | ||
from mmpretrain.registry import MODELS | ||
from mmpretrain.structures import DataSample | ||
|
||
|
||
@MODELS.register_module() | ||
class DINO(BaseSelfSupervisor): | ||
"""Implementation for DINO. | ||
This module is proposed in `DINO: Emerging Properties in Self-Supervised | ||
Vision Transformers <https://arxiv.org/abs/2104.14294>`_. | ||
Args: | ||
backbone (dict): Config for backbone. | ||
neck (dict): Config for neck. | ||
head (dict): Config for head. | ||
pretrained (str, optional): Path for pretrained model. | ||
Defaults to None. | ||
base_momentum (float, optional): Base momentum for momentum update. | ||
Defaults to 0.99. | ||
data_preprocessor (dict, optional): Config for data preprocessor. | ||
Defaults to None. | ||
init_cfg (list[dict] | dict, optional): Config for initialization. | ||
Defaults to None. | ||
""" | ||
|
||
def __init__(self, | ||
backbone: dict, | ||
neck: dict, | ||
head: dict, | ||
pretrained: Optional[str] = None, | ||
base_momentum: float = 0.99, | ||
data_preprocessor: Optional[dict] = None, | ||
init_cfg: Optional[Union[List[dict], dict]] = None) -> None: | ||
super().__init__( | ||
backbone=backbone, | ||
neck=neck, | ||
head=head, | ||
pretrained=pretrained, | ||
data_preprocessor=data_preprocessor, | ||
init_cfg=init_cfg) | ||
|
||
# create momentum model | ||
self.teacher = CosineEMA( | ||
nn.Sequential(self.backbone, self.neck), momentum=base_momentum) | ||
# weight normalization layer | ||
self.neck.last_layer = nn.utils.weight_norm(self.neck.last_layer) | ||
self.neck.last_layer.weight_g.data.fill_(1) | ||
self.neck.last_layer.weight_g.requires_grad = False | ||
self.teacher.module[1].last_layer = nn.utils.weight_norm( | ||
self.teacher.module[1].last_layer) | ||
self.teacher.module[1].last_layer.weight_g.data.fill_(1) | ||
self.teacher.module[1].last_layer.weight_g.requires_grad = False | ||
|
||
def loss(self, inputs: torch.Tensor, | ||
data_samples: List[DataSample]) -> dict: | ||
global_crops = torch.cat(inputs[:2]) | ||
local_crops = torch.cat(inputs[2:]) | ||
# teacher forward | ||
teacher_output = self.teacher(global_crops) | ||
|
||
# student forward global | ||
student_output_global = self.backbone(global_crops) | ||
student_output_global = self.neck(student_output_global) | ||
|
||
# student forward local | ||
student_output_local = self.backbone(local_crops) | ||
student_output_local = self.neck(student_output_local) | ||
|
||
student_output = torch.cat( | ||
(student_output_global, student_output_local)) | ||
|
||
# compute loss | ||
loss = self.head(student_output, teacher_output) | ||
|
||
return dict(loss=loss) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .dino_head import DINOHead | ||
|
||
__all__ = ['DINOHead'] |
Oops, something went wrong.