[CodeCamp2023-584] Support DINO self-supervised learning in project (#1756)

* feat: implement DINO

* chore: delete debug code

* chore: implement pre-commit

* fix: fix imported package

* chore: pre-commit check
LALBJ committed Aug 23, 2023
1 parent 732b0f4 commit d2ccc44
Showing 18 changed files with 612 additions and 0 deletions.
26 changes: 26 additions & 0 deletions projects/dino/README.md
@@ -0,0 +1,26 @@
# Implementation for DINO

**NOTE**: We only guarantee the correctness of the forward pass; we are not responsible for a full reimplementation.

First, ensure you are in the root directory of MMPretrain. You then have two ways
to play with DINO in MMPretrain:

## Slurm

If you are using a cluster managed by Slurm, you can use the following command to
start your job:

```shell
GPUS_PER_NODE=8 GPUS=8 CPUS_PER_TASK=16 bash projects/dino/tools/slurm_train.sh mm_model dino projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py --amp
```

The above command will pre-train the model on a single node with 8 GPUs.

## PyTorch

If you are using a single machine without any cluster management software, you can use the following command:

```shell
NNODES=1 bash projects/dino/tools/dist_train.sh projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py 8 --amp
```
104 changes: 104 additions & 0 deletions projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py
@@ -0,0 +1,104 @@
model = dict(
    type='DINO',
    data_preprocessor=dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='mmpretrain.VisionTransformer', arch='b', patch_size=16),
    neck=dict(
        type='DINONeck',
        in_channels=768,
        out_channels=65536,
        hidden_channels=2048,
        bottleneck_channels=256),
    head=dict(
        type='DINOHead',
        out_channels=65536,
        num_crops=10,
        student_temp=0.1,
        center_momentum=0.9))
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='DINOMultiCrop',
        global_crops_scale=(0.4, 1.0),
        local_crops_scale=(0.05, 0.4),
        local_crops_number=8),
    dict(type='PackInputs')
]
train_dataloader = dict(
    batch_size=32,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        type='mmpretrain.ImageNet',
        data_root='/data/imagenet/',
        ann_file='meta/train.txt',
        data_prefix=dict(img_path='train/'),
        pipeline=train_pipeline,
    ))
optimizer = dict(type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05)
optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05),
    paramwise_cfg=dict(
        custom_keys=dict(
            ln=dict(decay_mult=0.0),
            bias=dict(decay_mult=0.0),
            pos_embed=dict(decay_mult=0.0),
            mask_token=dict(decay_mult=0.0),
            cls_token=dict(decay_mult=0.0))),
    loss_scale='dynamic')
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-09,
        by_epoch=True,
        begin=0,
        end=10,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=90,
        by_epoch=True,
        begin=10,
        end=100,
        convert_to_iter_based=True)
]
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100)
default_scope = 'mmpretrain'
default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
    sampler_seed=dict(type='DistSamplerSeedHook'))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
log_processor = dict(
    window_size=10,
    custom_cfg=[dict(data_src='', method='mean', window_size='global')])
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='UniversalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')],
    name='visualizer')
log_level = 'INFO'
load_from = None
resume = True
randomness = dict(seed=2, diff_rank_seed=True)
custom_hooks = [
    dict(
        type='DINOTeacherTempWarmupHook',
        warmup_teacher_temp=0.04,
        teacher_temp=0.04,
        teacher_temp_warmup_epochs=0,
        max_epochs=100)
]
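
As a sanity check, the config above can also be loaded and the model built outside the training scripts. The snippet below is a minimal sketch, assuming MMPretrain and MMEngine are installed and it is run from the MMPretrain root so that `projects.dino` resolves as a package; the wildcard import is only there to register the project's `DINO`, `DINONeck`, and `DINOHead` classes.

```python
# Minimal sketch: assumes mmpretrain/mmengine are installed and this is run
# from the MMPretrain root so `projects.dino` is importable; the wildcard
# import registers DINO, DINONeck and DINOHead with the MODELS registry.
from mmengine.config import Config

from mmpretrain.registry import MODELS
from projects.dino.models import *  # noqa: F401,F403

cfg = Config.fromfile(
    'projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py')
model = MODELS.build(cfg.model)
print(type(model).__name__)  # expected: DINO
```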
1 change: 1 addition & 0 deletions projects/dino/dataset/__init__.py
@@ -0,0 +1 @@
from .transform import * # noqa: F401,F403
3 changes: 3 additions & 0 deletions projects/dino/dataset/transform/__init__.py
@@ -0,0 +1,3 @@
from .processing import DINOMultiCrop

__all__ = ['DINOMultiCrop']
91 changes: 91 additions & 0 deletions projects/dino/dataset/transform/processing.py
@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random

from mmcv.transforms import RandomApply  # noqa: E501
from mmcv.transforms import BaseTransform, Compose, RandomFlip, RandomGrayscale

from mmpretrain.datasets.transforms import (ColorJitter, GaussianBlur,
                                            RandomResizedCrop, Solarize)
from mmpretrain.registry import TRANSFORMS


@TRANSFORMS.register_module()
class DINOMultiCrop(BaseTransform):
    """Multi-crop transform for DINO.

    This module applies the multi-crop transform for DINO.

    Args:
        global_crops_scale (tuple): Scale range of the global crops.
        local_crops_scale (tuple): Scale range of the local crops.
        local_crops_number (int): Number of local crops.
    """

    def __init__(self, global_crops_scale: tuple, local_crops_scale: tuple,
                 local_crops_number: int) -> None:
        super().__init__()
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale

        flip_and_color_jitter = Compose([
            RandomFlip(prob=0.5, direction='horizontal'),
            RandomApply([
                ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
            ],
                        prob=0.8),
            RandomGrayscale(
                prob=0.2,
                keep_channels=True,
                channel_weights=(0.114, 0.587, 0.2989),
            )
        ])

        self.global_transform_1 = Compose([
            RandomResizedCrop(
                224,
                crop_ratio_range=global_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
        ])

        self.global_transform_2 = Compose([
            RandomResizedCrop(
                224,
                crop_ratio_range=global_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
            Solarize(thr=128, prob=0.2),
        ])

        self.local_crops_number = local_crops_number
        self.local_transform = Compose([
            RandomResizedCrop(
                96,
                crop_ratio_range=local_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
        ])

    def transform(self, results: dict) -> dict:
        ori_img = results['img']
        crops = []
        results['img'] = ori_img
        crops.append(self.global_transform_1(results)['img'])
        results['img'] = ori_img
        crops.append(self.global_transform_2(results)['img'])
        for _ in range(self.local_crops_number):
            results['img'] = ori_img
            crops.append(self.local_transform(results)['img'])
        results['img'] = crops
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(global_crops_scale = {self.global_crops_scale}, '
        repr_str += f'local_crops_scale = {self.local_crops_scale}, '
        repr_str += f'local_crops_number = {self.local_crops_number})'
        return repr_str
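
The transform above leaves a list of crops in `results['img']`: two 224x224 global views followed by `local_crops_number` 96x96 local views. Below is a minimal usage sketch (an illustration only, assuming mmcv/mmpretrain are installed and `projects.dino` is importable from the MMPretrain root):

```python
# Illustrative sketch: assumes mmcv and mmpretrain are installed and that
# `projects.dino` is importable when run from the MMPretrain root.
import numpy as np

from projects.dino.dataset.transform import DINOMultiCrop

transform = DINOMultiCrop(
    global_crops_scale=(0.4, 1.0),
    local_crops_scale=(0.05, 0.4),
    local_crops_number=8)
results = {'img': np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)}
results = transform(results)
# 'img' now holds 2 global crops (224x224) plus 8 local crops (96x96).
print(len(results['img']))  # 10
```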
1 change: 1 addition & 0 deletions projects/dino/engine/__init__.py
@@ -0,0 +1 @@
from .hooks import * # noqa
3 changes: 3 additions & 0 deletions projects/dino/engine/hooks/__init__.py
@@ -0,0 +1,3 @@
from .dino_teacher_temp_warmup_hook import DINOTeacherTempWarmupHook

__all__ = ['DINOTeacherTempWarmupHook']
33 changes: 33 additions & 0 deletions projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py
@@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmengine.hooks import Hook

from mmpretrain.registry import HOOKS


@HOOKS.register_module()
class DINOTeacherTempWarmupHook(Hook):
    """Warmup teacher temperature for DINO.

    This hook warms up the teacher temperature to stabilize the training
    process.

    Args:
        warmup_teacher_temp (float): Warmup temperature for teacher.
        teacher_temp (float): Temperature for teacher.
        teacher_temp_warmup_epochs (int): Warmup epochs for teacher
            temperature.
        max_epochs (int): Maximum epochs for training.
    """

    def __init__(self, warmup_teacher_temp: float, teacher_temp: float,
                 teacher_temp_warmup_epochs: int, max_epochs: int) -> None:
        super().__init__()
        self.teacher_temps = np.concatenate(
            (np.linspace(warmup_teacher_temp, teacher_temp,
                         teacher_temp_warmup_epochs),
             np.ones(max_epochs - teacher_temp_warmup_epochs) * teacher_temp))

    def before_train_epoch(self, runner) -> None:
        runner.model.module.head.teacher_temp = self.teacher_temps[
            runner.epoch]
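
For intuition, the array built in `__init__` is a linear warmup followed by a constant value, indexed by epoch in `before_train_epoch`. Here is a short worked example with illustrative values (not the ones in the project config above, which sets `teacher_temp_warmup_epochs=0` and therefore keeps a constant 0.04):

```python
# Illustrative values only; the project config above disables warmup.
import numpy as np

warmup_teacher_temp, teacher_temp = 0.04, 0.07
teacher_temp_warmup_epochs, max_epochs = 30, 100

teacher_temps = np.concatenate((
    np.linspace(warmup_teacher_temp, teacher_temp, teacher_temp_warmup_epochs),
    np.ones(max_epochs - teacher_temp_warmup_epochs) * teacher_temp))

print(teacher_temps[0], teacher_temps[29], teacher_temps[-1])
# 0.04 0.07 0.07  (linear warmup over 30 epochs, then held constant)
```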
3 changes: 3 additions & 0 deletions projects/dino/models/__init__.py
@@ -0,0 +1,3 @@
from .algorithm import * # noqa
from .head import * # noqa
from .neck import * # noqa
3 changes: 3 additions & 0 deletions projects/dino/models/algorithm/__init__.py
@@ -0,0 +1,3 @@
from .dino import DINO

__all__ = ['DINO']
82 changes: 82 additions & 0 deletions projects/dino/models/algorithm/dino.py
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from torch import nn

from mmpretrain.models import BaseSelfSupervisor, CosineEMA
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


@MODELS.register_module()
class DINO(BaseSelfSupervisor):
    """Implementation for DINO.

    This module is proposed in `DINO: Emerging Properties in Self-Supervised
    Vision Transformers <https://arxiv.org/abs/2104.14294>`_.

    Args:
        backbone (dict): Config for backbone.
        neck (dict): Config for neck.
        head (dict): Config for head.
        pretrained (str, optional): Path for pretrained model.
            Defaults to None.
        base_momentum (float, optional): Base momentum for momentum update.
            Defaults to 0.99.
        data_preprocessor (dict, optional): Config for data preprocessor.
            Defaults to None.
        init_cfg (list[dict] | dict, optional): Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 pretrained: Optional[str] = None,
                 base_momentum: float = 0.99,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # create momentum model
        self.teacher = CosineEMA(
            nn.Sequential(self.backbone, self.neck), momentum=base_momentum)
        # weight normalization layer
        self.neck.last_layer = nn.utils.weight_norm(self.neck.last_layer)
        self.neck.last_layer.weight_g.data.fill_(1)
        self.neck.last_layer.weight_g.requires_grad = False
        self.teacher.module[1].last_layer = nn.utils.weight_norm(
            self.teacher.module[1].last_layer)
        self.teacher.module[1].last_layer.weight_g.data.fill_(1)
        self.teacher.module[1].last_layer.weight_g.requires_grad = False

    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[DataSample]) -> dict:
        global_crops = torch.cat(inputs[:2])
        local_crops = torch.cat(inputs[2:])
        # teacher forward
        teacher_output = self.teacher(global_crops)

        # student forward global
        student_output_global = self.backbone(global_crops)
        student_output_global = self.neck(student_output_global)

        # student forward local
        student_output_local = self.backbone(local_crops)
        student_output_local = self.neck(student_output_local)

        student_output = torch.cat(
            (student_output_global, student_output_local))

        # compute loss
        loss = self.head(student_output, teacher_output)

        return dict(loss=loss)
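
To make the crop bookkeeping in `loss` concrete, here is a hypothetical shape walk-through (a sketch only; it mirrors the `torch.cat` calls above without running the backbone, and assumes a batch of 4 with 2 global and 8 local crops as in the config):

```python
# Hypothetical shapes for batch_size=4 with 2 global + 8 local crops; this
# only mirrors the tensor bookkeeping in DINO.loss(), it does not run a model.
import torch

batch_size = 4
global_views = [torch.randn(batch_size, 3, 224, 224) for _ in range(2)]
local_views = [torch.randn(batch_size, 3, 96, 96) for _ in range(8)]
inputs = global_views + local_views   # one tensor per crop, as loss() expects

global_crops = torch.cat(inputs[:2])  # seen by both teacher and student
local_crops = torch.cat(inputs[2:])   # seen by the student only
print(global_crops.shape)             # torch.Size([8, 3, 224, 224])
print(local_crops.shape)              # torch.Size([32, 3, 96, 96])
```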
3 changes: 3 additions & 0 deletions projects/dino/models/head/__init__.py
@@ -0,0 +1,3 @@
from .dino_head import DINOHead

__all__ = ['DINOHead']
