@@ -31,7 +31,7 @@ def forward(self, feed_dict, *, segSize=None):
         if self.deep_sup_scale is not None: # use deep supervision technique
             (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
         else:
-            pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=False))
+            pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
 
         loss = self.crit(pred, feed_dict['seg_label'])
         if self.deep_sup_scale is not None:
@@ -41,7 +41,7 @@ def forward(self, feed_dict, *, segSize=None):
             acc = self.pixel_acc(pred, feed_dict['seg_label'])
             return loss, acc
         else: # inference
-            pred = self.decoder(self.encoder(feed_dict['img_data']), segSize=segSize)
+            pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize)
             return pred
 
 
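(Note: both the training and inference paths now call the encoder with return_feature_maps=True, so every decoder receives a list of stage feature maps rather than a single tensor; the Resnet.forward change further down makes the plain encoder honor that flag.)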
@@ -94,10 +94,17 @@ def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''):
             orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
             net_encoder = ResnetDilated(orig_resnet,
                                         dilate_scale=16)
+        elif arch == 'resnet101':
+            orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
+            net_encoder = Resnet(orig_resnet)
         elif arch == 'resnet101_dilated8':
             orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
             net_encoder = ResnetDilated(orig_resnet,
                                         dilate_scale=8)
+        elif arch == 'resnet101_dilated16':
+            orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
+            net_encoder = ResnetDilated(orig_resnet,
+                                        dilate_scale=16)
         else:
             raise Exception('Architecture undefined!')
 
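A minimal usage sketch for the two new encoder branches, assuming a ModelBuilder instance as defined in this file (the builder name is illustrative; fc_dim=2048 matches ResNet-101's final stage):

    builder = ModelBuilder()
    # plain, undilated ResNet-101 backbone
    enc_plain = builder.build_encoder(arch='resnet101', fc_dim=2048)
    # ResNet-101 dilated to output stride 16
    enc_d16 = builder.build_encoder(arch='resnet101_dilated16', fc_dim=2048)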
@@ -131,6 +138,18 @@ def build_decoder(self, arch='psp_bilinear_deepsup',
                 num_class=num_class,
                 fc_dim=fc_dim,
                 use_softmax=use_softmax)
+        elif arch == 'upernet_lite':
+            net_decoder = UPerNet(
+                num_class=num_class,
+                fc_dim=fc_dim,
+                use_softmax=use_softmax,
+                fpn_dim=256)
+        elif arch == 'upernet':
+            net_decoder = UPerNet(
+                num_class=num_class,
+                fc_dim=fc_dim,
+                use_softmax=use_softmax,
+                fpn_dim=512)
         else:
             raise Exception('Architecture undefined!')
 
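The two new decoder branches differ only in fpn_dim (256 for upernet_lite vs. 512 for upernet). A matching sketch, assuming build_decoder exposes the num_class, fc_dim, and use_softmax parameters it forwards above:

    net_decoder = builder.build_decoder(
        arch='upernet', fc_dim=2048, num_class=150, use_softmax=False)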
@@ -162,17 +181,22 @@ def __init__(self, orig_resnet):
         self.layer3 = orig_resnet.layer3
         self.layer4 = orig_resnet.layer4
 
-    def forward(self, x):
+    def forward(self, x, return_feature_maps=False):
+        conv_out = []
+
         x = self.relu1(self.bn1(self.conv1(x)))
         x = self.relu2(self.bn2(self.conv2(x)))
         x = self.relu3(self.bn3(self.conv3(x)))
         x = self.maxpool(x)
 
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-        return x
+        x = self.layer1(x); conv_out.append(x)
+        x = self.layer2(x); conv_out.append(x)
+        x = self.layer3(x); conv_out.append(x)
+        x = self.layer4(x); conv_out.append(x)
+
+        if return_feature_maps:
+            return conv_out
+        return [x]
 
 
 class ResnetDilated(nn.Module):
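(Note: Resnet.forward now collects the outputs of layer1 through layer4 and returns them as a list when return_feature_maps=True; for a standard ResNet-101 these carry 256, 512, 1024, and 2048 channels, matching UPerNet's default fpn_inplanes below. When the flag is False it returns [x] so decoders can index the result uniformly.)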
@@ -404,3 +428,90 @@ def forward(self, conv_out, segSize=None):
         _ = nn.functional.log_softmax(_, dim=1)
 
         return (x, _)
+
+
+# upernet
+class UPerNet(nn.Module):
+    def __init__(self, num_class=150, fc_dim=4096,
+                 use_softmax=False, pool_scales=(1, 2, 3, 6),
+                 fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256):
+        super(UPerNet, self).__init__()
+        self.use_softmax = use_softmax
+
+        # PPM Module
+        self.psp = []
+        for scale in pool_scales:
+            self.psp.append(nn.Sequential(
+                nn.AdaptiveAvgPool2d(scale),
+                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
+                SynchronizedBatchNorm2d(512),
+                nn.ReLU(inplace=True)
+            ))
+        self.psp = nn.ModuleList(self.psp)
+        self.psp_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1)
+
+        # FPN Module
+        self.fpn_in = []
+        for fpn_inplane in fpn_inplanes[:-1]:  # skip the top layer
+            self.fpn_in.append(nn.Sequential(
+                nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),
+                SynchronizedBatchNorm2d(fpn_dim),
+                nn.ReLU(inplace=True)
+            ))
+        self.fpn_in = nn.ModuleList(self.fpn_in)
+
+        self.fpn_out = []
+        for i in range(len(fpn_inplanes) - 1):  # skip the top layer
+            self.fpn_out.append(nn.Sequential(
+                conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
+                conv3x3_bn_relu(fpn_dim, fpn_dim // 2, 1)
+            ))
+        self.fpn_out = nn.ModuleList(self.fpn_out)
+
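+        # conv_last fuses fpn_dim channels from the PPM top with
+        # fpn_dim // 2 channels from each remaining FPN level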
+        self.conv_last = nn.Sequential(
+            conv3x3_bn_relu(fpn_dim + (len(fpn_inplanes) - 1) * fpn_dim // 2, fpn_dim, 1),
+            nn.Conv2d(fpn_dim, num_class, kernel_size=1)
+        )
+
+    def forward(self, conv_out, segSize=None):
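+        # conv_out: encoder feature maps ordered shallow to deep;
+        # the deepest map (conv5) feeds the PPM context pooling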
+        conv5 = conv_out[-1]
+
+        input_size = conv5.size()
+        psp_out = [conv5]
+        for pool_scale in self.psp:
+            psp_out.append(nn.functional.upsample(
+                pool_scale(conv5),
+                (input_size[2], input_size[3]),
+                mode='bilinear'))
+        psp_out = torch.cat(psp_out, 1)
+        f = self.psp_conv(psp_out)
+
+        fpn_feature_list = [f]
+        for i in reversed(range(len(conv_out) - 1)):
+            conv_x = conv_out[i]
+            conv_x = self.fpn_in[i](conv_x)  # lateral branch
+
+            f = nn.functional.upsample(f, size=conv_x.size()[2:], mode='bilinear')  # top-down branch
+            f = conv_x + f
+
+            fpn_feature_list.append(self.fpn_out[i](f))
+
+        fpn_feature_list.reverse()  # [P2 - P5]
+        output_size = fpn_feature_list[0].size()[2:]
+        fusion_list = [fpn_feature_list[0]]
+        for i in range(1, len(fpn_feature_list)):
+            fusion_list.append(nn.functional.upsample(
+                fpn_feature_list[i],
+                output_size,
+                mode='bilinear'))
+        fusion_out = torch.cat(fusion_list, 1)
+        x = self.conv_last(fusion_out)
+
+        if self.use_softmax:  # is True during inference
+            x = nn.functional.upsample(x, size=segSize, mode='bilinear')
+            x = nn.functional.softmax(x, dim=1)
+            return x
+
+        x = nn.functional.log_softmax(x, dim=1)
+
+        return x
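Putting it together, a minimal end-to-end sketch of the new UPerNet path, assuming this file's ModelBuilder (input size and class count are illustrative):

    import torch

    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch='resnet101', fc_dim=2048)
    net_decoder = builder.build_decoder(
        arch='upernet', fc_dim=2048, num_class=150, use_softmax=True)

    with torch.no_grad():
        feats = net_encoder(torch.randn(1, 3, 512, 512), return_feature_maps=True)
        probs = net_decoder(feats, segSize=(512, 512))  # (1, 150, 512, 512) softmax scores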