Skip to content

Commit 77d488e

Browse files
author
Tete Xiao
committed
add upernet for ade20k
1 parent db3f45a commit 77d488e

File tree

4 files changed

+120
-13
lines changed

4 files changed

+120
-13
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ usage: train.py [-h] [--id ID] [--arch_encoder ARCH_ENCODER]
137137

138138

139139
## Evaluation
140-
1. Evaluate a trained network on the validation set. Add ```--visualize``` option to output visualizations shown in teaser.
140+
1. Evaluate a trained network on the validation set. Add ```--visualize``` option to output visualizations as shown in teaser.
141141
```bash
142142
python3 eval.py --id MODEL_ID --suffix SUFFIX
143143
```

dataset.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,6 @@ def __init__(self, odgt, opt, max_sample=-1):
164164
self.imgMaxSize = opt.imgMaxSize
165165
# max down sampling rate of network to avoid rounding during conv or pooling
166166
self.padding_constant = opt.padding_constant
167-
# down sampling rate of segm labe
168-
self.segm_downsampling_rate = opt.segm_downsampling_rate
169167

170168
# mean and std
171169
self.img_transform = transforms.Compose([

eval.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,6 @@ def main(args):
165165
help='maximum input image size of long edge')
166166
parser.add_argument('--padding_constant', default=8, type=int,
167167
help='maxmimum downsampling rate of the network')
168-
parser.add_argument('--segm_downsampling_rate', default=8, type=int,
169-
help='downsampling rate of the segmentation label')
170168

171169
# Misc arguments
172170
parser.add_argument('--ckpt', default='./ckpt',

models.py

Lines changed: 119 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def forward(self, feed_dict, *, segSize=None):
3131
if self.deep_sup_scale is not None: # use deep supervision technique
3232
(pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
3333
else:
34-
pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=False))
34+
pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
3535

3636
loss = self.crit(pred, feed_dict['seg_label'])
3737
if self.deep_sup_scale is not None:
@@ -41,7 +41,7 @@ def forward(self, feed_dict, *, segSize=None):
4141
acc = self.pixel_acc(pred, feed_dict['seg_label'])
4242
return loss, acc
4343
else: # inference
44-
pred = self.decoder(self.encoder(feed_dict['img_data']), segSize=segSize)
44+
pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize)
4545
return pred
4646

4747

@@ -94,10 +94,17 @@ def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''):
9494
orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
9595
net_encoder = ResnetDilated(orig_resnet,
9696
dilate_scale=16)
97+
elif arch == 'resnet101':
98+
orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
99+
net_encoder = Resnet(orig_resnet)
97100
elif arch == 'resnet101_dilated8':
98101
orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
99102
net_encoder = ResnetDilated(orig_resnet,
100103
dilate_scale=8)
104+
elif arch == 'resnet101_dilated16':
105+
orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
106+
net_encoder = ResnetDilated(orig_resnet,
107+
dilate_scale=16)
101108
else:
102109
raise Exception('Architecture undefined!')
103110

@@ -131,6 +138,18 @@ def build_decoder(self, arch='psp_bilinear_deepsup',
131138
num_class=num_class,
132139
fc_dim=fc_dim,
133140
use_softmax=use_softmax)
141+
elif arch == 'upernet_lite':
142+
net_decoder = UPerNet(
143+
num_class=num_class,
144+
fc_dim=fc_dim,
145+
use_softmax=use_softmax,
146+
fpn_dim=256)
147+
elif arch == 'upernet':
148+
net_decoder = UPerNet(
149+
num_class=num_class,
150+
fc_dim=fc_dim,
151+
use_softmax=use_softmax,
152+
fpn_dim=512)
134153
else:
135154
raise Exception('Architecture undefined!')
136155

@@ -162,17 +181,22 @@ def __init__(self, orig_resnet):
162181
self.layer3 = orig_resnet.layer3
163182
self.layer4 = orig_resnet.layer4
164183

165-
def forward(self, x):
184+
def forward(self, x, return_feature_maps=False):
185+
conv_out = []
186+
166187
x = self.relu1(self.bn1(self.conv1(x)))
167188
x = self.relu2(self.bn2(self.conv2(x)))
168189
x = self.relu3(self.bn3(self.conv3(x)))
169190
x = self.maxpool(x)
170191

171-
x = self.layer1(x)
172-
x = self.layer2(x)
173-
x = self.layer3(x)
174-
x = self.layer4(x)
175-
return x
192+
x = self.layer1(x); conv_out.append(x);
193+
x = self.layer2(x); conv_out.append(x);
194+
x = self.layer3(x); conv_out.append(x);
195+
x = self.layer4(x); conv_out.append(x);
196+
197+
if return_feature_maps:
198+
return conv_out
199+
return [x]
176200

177201

178202
class ResnetDilated(nn.Module):
@@ -404,3 +428,90 @@ def forward(self, conv_out, segSize=None):
404428
_ = nn.functional.log_softmax(_, dim=1)
405429

406430
return (x, _)
431+
432+
433+
# upernet
434+
class UPerNet(nn.Module):
435+
def __init__(self, num_class=150, fc_dim=4096,
436+
use_softmax=False, pool_scales=(1, 2, 3, 6),
437+
fpn_inplanes=(256,512,1024,2048), fpn_dim=256):
438+
super(UPerNet, self).__init__()
439+
self.use_softmax = use_softmax
440+
441+
# PPM Module
442+
self.psp = []
443+
for scale in pool_scales:
444+
self.psp.append(nn.Sequential(
445+
nn.AdaptiveAvgPool2d(scale),
446+
nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
447+
SynchronizedBatchNorm2d(512),
448+
nn.ReLU(inplace=True)
449+
))
450+
self.psp = nn.ModuleList(self.psp)
451+
self.psp_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1)
452+
453+
# FPN Module
454+
self.fpn_in = []
455+
for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer
456+
self.fpn_in.append(nn.Sequential(
457+
nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),
458+
SynchronizedBatchNorm2d(fpn_dim),
459+
nn.ReLU(inplace=True)
460+
))
461+
self.fpn_in = nn.ModuleList(self.fpn_in)
462+
463+
self.fpn_out = []
464+
for i in range(len(fpn_inplanes) - 1): # skip the top layer
465+
self.fpn_out.append(nn.Sequential(
466+
conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
467+
conv3x3_bn_relu(fpn_dim, fpn_dim // 2, 1)
468+
))
469+
self.fpn_out = nn.ModuleList(self.fpn_out)
470+
471+
self.conv_last = nn.Sequential(
472+
conv3x3_bn_relu(fpn_dim + (len(fpn_inplanes)-1) * fpn_dim // 2, fpn_dim, 1),
473+
nn.Conv2d(fpn_dim, num_class, kernel_size=1)
474+
)
475+
476+
def forward(self, conv_out, segSize=None):
477+
conv5 = conv_out[-1]
478+
479+
input_size = conv5.size()
480+
psp_out = [conv5]
481+
for pool_scale in self.psp:
482+
psp_out.append(nn.functional.upsample(
483+
pool_scale(conv5),
484+
(input_size[2], input_size[3]),
485+
mode='bilinear'))
486+
psp_out = torch.cat(psp_out, 1)
487+
f = self.psp_conv(psp_out)
488+
489+
fpn_feature_list = [f]
490+
for i in reversed(range(len(conv_out) - 1)):
491+
conv_x = conv_out[i]
492+
conv_x = self.fpn_in[i](conv_x) # lateral branch
493+
494+
f = nn.functional.upsample(f, size=conv_x.size()[2:], mode='bilinear') # top-down branch
495+
f = conv_x + f
496+
497+
fpn_feature_list.append(self.fpn_out[i](f))
498+
499+
fpn_feature_list.reverse() # [P2 - P5]
500+
output_size = fpn_feature_list[0].size()[2:]
501+
fusion_list = [fpn_feature_list[0]]
502+
for i in range(1, len(fpn_feature_list)):
503+
fusion_list.append(nn.functional.upsample(
504+
fpn_feature_list[i],
505+
output_size,
506+
mode='bilinear'))
507+
fusion_out = torch.cat(fusion_list, 1)
508+
x = self.conv_last(fusion_out)
509+
510+
if self.use_softmax: # is True during inference
511+
x = nn.functional.upsample(x, size=segSize, mode='bilinear')
512+
x = nn.functional.softmax(x, dim=1)
513+
return x
514+
515+
x = nn.functional.log_softmax(x, dim=1)
516+
517+
return x

0 commit comments

Comments
 (0)