# cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k.py
# Distillation config: PSPNet-R101 teacher -> PSPNet-R18 student.
# Base config files this config builds on: dataset, runtime and schedule.
_base_ = [
    '../../_base_/datasets/mmseg/cityscapes.py',
    '../../_base_/mmseg_runtime.py',
    '../../_base_/schedules/mmseg/schedule_80k.py',
]

# Normalization settings shared by every backbone and head defined below.
norm_cfg = dict(
    type='SyncBN',
    requires_grad=True,
)
# Student network: PSPNet with a ResNet-18 (V1c) backbone.
student = dict(
    type='mmseg.EncoderDecoder',
    backbone=dict(
        type='ResNetV1c',
        # Backbone weights come from a pretrained checkpoint.
        init_cfg=dict(
            type='Pretrained',
            checkpoint='open-mmlab://resnet18_v1c',
        ),
        depth=18,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True,
    ),
    # Main segmentation head (19 classes), fed from backbone output index 3.
    decode_head=dict(
        type='PSPHead',
        in_channels=512,
        in_index=3,
        channels=128,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
        ),
    ),
    # Auxiliary head on backbone output index 2; its loss is weighted 0.4
    # against 1.0 for the main head.
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=256,
        in_index=2,
        channels=64,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=0.4,
        ),
    ),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
)
# Teacher network: PSPNet with a ResNet-101 (V1c) backbone.
# NOTE(review): no init_cfg / checkpoint is set for the teacher in this file;
# presumably the trained teacher weights are supplied externally (for example
# through config overrides at launch time) - confirm before training.
teacher = dict(
    type='mmseg.EncoderDecoder',
    backbone=dict(
        type='ResNetV1c',
        depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True,
    ),
    # Only the main segmentation head is configured for the teacher.
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
        ),
    ),
)
# Distillation algorithm: the student is the trained architecture, and a
# single non-trainable teacher is paired with it through one loss component.
algorithm = dict(
    type='GeneralDistill',
    architecture=dict(
        type='MMSegArchitecture',
        model=student,
    ),
    distiller=dict(
        type='SingleTeacherDistiller',
        teacher=teacher,
        teacher_trainable=False,
        components=[
            # Pairs the `decode_head.conv_seg` module of the student with the
            # module of the same name in the teacher.
            dict(
                student_module='decode_head.conv_seg',
                teacher_module='decode_head.conv_seg',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_logits',
                        tau=1,
                        loss_weight=5,
                    ),
                ],
            ),
        ],
    ),
)
# NOTE(review): presumably forwarded to the distributed data-parallel wrapper
# so that parameters receiving no gradient in a step are tolerated - confirm
# against the training entry point.
find_unused_parameters = True