Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Add PSPNet model #388

Closed
wants to merge 27 commits into from
Closed

Add PSPNet model #388

wants to merge 27 commits into from

Conversation

mitmul
Copy link
Member

@mitmul mitmul commented Aug 10, 2017

Please merge #494 and #392 first.
This PR uses Conv2DBNActiv with a ChainerMN communicator and CityscapesSemanticSegmentationDataset.

This PR adds PSPNet model definition and it's inference demo code.
This also includes convert.py to convert the official caffemodel to Chainer model.

  • Ensure the convert.py correctly works
  • Add tests
  • Add evaluation scripts to ensure the conversion correctness

@mitmul mitmul changed the title [WIP] Add PSPNet model Add PSPNet model Oct 4, 2017
Copy link
Member

@yuyu2172 yuyu2172 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initial comment.


with self.init_scope():
self.input_size = input_size
self.trunk = DilatedFCN(n_blocks=n_blocks)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use extractor instead of trunk.
This is consistent with links such as FasterRCNN.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I changed it.

mid_stride = chainer.config.mid_stride
super(BottleneckConv, self).__init__()
with self.init_scope():
self.cbr1 = ConvBNReLU(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to make the attribute name of ConvBN* consistent across the library.
We have Conv2DBNActiv, which will be used by ResNet implementations.

Since cbr is specific to Conv->BN->ReLU, I want to avoid the name.
How about conv1, conv2, ... ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I changed the attribute name to conv1, conv2, ...

return F.relu(h + x)


class ResBlock(chainer.ChainList):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you merge ResBlock and DilatedResBlock?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merged.

h = F.max_pooling_2d(h, 3, 2, 1) # 1/4
h = self.res2(h)
h = self.res3(h) # 1/8
if chainer.config.train:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not always return h1 and h2?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

h1 is only for calculating auxiliary loss. That is completely useless during inference.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

predict should be called during inference and not __call__. There is not much inconvenience returning it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, I want to avoid using a global variable to configure behavior if possible.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I changed this part to return both h1 and h2 always.

if initialW is None:
chainer.config.initialW = chainer.initializers.HeNormal()
else:
chainer.config.initialW = initialW
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer not using global variables.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I replaced all such part to use a local variable given by the argument to the class.


# To calculate auxirally loss
if chainer.config.train:
self.cbr_aux = ConvBNReLU(None, 512, 3, 1, 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not like to change the configuration of a model using global variable.
It would not be much problem to instantiate these links. Why not always instantiate them?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, let's instantiate them always.

constructor. ``H, W`` is the input image size.

"""
if chainer.config.train:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not add compute_aux option instead of changing behavior using chainer.config.train?
There can be a situation when a user wants to compute aux with chainer.config.train=False.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. I'll change it.

img -= self.mean[:, None, None]
img = img.astype(np.float32, copy=False)
if self._use_pretrained_model:
# Pre-trained model is trained for BGR images
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you flip the order of in_channels of the weights in conversion code?
By doing so, this if statement can be deleted.
I want to make pspnet.py to be as simple as possible at the expense of more complex conversion code because people would rarely see the conversion code.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'll try to do it.

}
}

def __init__(self, n_class=None, input_size=None, n_blocks=None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about making PSPNet50 and PSPNet101, which inherit from PSPNet.
In my idea, PSPNet.__init__ would take extractor as input.

I have the designs of FasterRCNN and FasterRCNNVGG16 in my mind.


class ConvBNReLU(chainer.Chain):

def __init__(self, in_ch, out_ch, ksize, stride=1, pad=1, dilation=1):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make input variables/variable names similar to Conv2DBNActiv?

try:
from chainermn.links import MultiNodeBatchNormalization
except Exception:
warnings.warn('To perform batch normalization with multiple GPUs or '
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of raising a warning during import, it should be raised when PSPNet is instantiated.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix this. Please make sure that the error is raised only when comm is specified.

Also, we should install the latest release.

@yuyu2172
Copy link
Member

yuyu2172 commented Oct 9, 2017

I ran predict and an error was raised.
I used Python2.

  File "/home/yuyu2172/projects/chainercv/chainercv/links/model/pspnet/pspnet.py", line 420, in _tile_predict           
    for yy in six.moves.xrange(hh):                                                                                     
TypeError: integer argument expected, got float                                                                         

@yuyu2172
Copy link
Member

yuyu2172 commented Oct 9, 2017

Can you make sure that predict does not alter the value of input images?
Currently, prepare changes the value because the input is not copied.


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--img_fn', '-f', type=str)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use
parser.add_argument('image')
for parsing image path?
This is the style used by other demos.

parser.add_argument('--img_fn', '-f', type=str)
parser.add_argument('--gpu', '-g', type=int, default=-1)
parser.add_argument('--scales', '-s', type=float, nargs='*', default=None)
parser.add_argument('--out_dir', '-o', type=str, default='.')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change the demo to plot.show instead of plot.savefig?
As a consequence, could you remove this option?

parser.add_argument('--gpu', '-g', type=int, default=-1)
parser.add_argument('--scales', '-s', type=float, nargs='*', default=None)
parser.add_argument('--out_dir', '-o', type=str, default='.')
parser.add_argument('--model', '-m', type=str,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use model option when mentioning architectures of networks (e.g., SSD512 and SSD300).
Could you use pretrained_model as the name of option?

from chainercv.links import PSPNet
from chainercv.utils import read_image
from chainercv.visualizations import vis_image
from chainercv.visualizations import vis_label
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name of the function has been modified to vis_semantic_segmentation.


if args.model == 'voc2012':
model = PSPNet(pretrained_model='voc2012')
labels = voc_semantic_segmentation_label_names
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you change the name to label_names?
labels is (B, H, W) array in this context.

return pred


if __name__ == '__main__':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about changing the name of the file to predict_cityscapes.

@@ -1,3 +1,3 @@
[pep8]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should merge .pep8 to setup.cfg because this is how Chainer does it.

@@ -0,0 +1,3 @@
[flake8]
exclude = .eggs,*.egg,build,caffe_pb2.py,caffe_pb3.py,docs,.git

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please delete empty line.

try:
from chainermn.links import MultiNodeBatchNormalization
except Exception:
warnings.warn('To perform batch normalization with multiple GPUs or '
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix this. Please make sure that the error is raised only when comm is specified.

Also, we should install the latest release.

@yuyu2172
Copy link
Member

yuyu2172 commented Mar 6, 2018

@Hakuyume @mitmul

What do you think about reusing ResBlock for resnet for PSPNet?
This can be possible once we merge #494.

https://github.com/chainer/chainercv/blob/master/chainercv/links/model/resnet/resblock.py

@Hakuyume
Copy link
Member

Hakuyume commented Mar 6, 2018

👍 for reusing ResBlock in PSPNet if possible.

@mitmul
Copy link
Member Author

mitmul commented May 18, 2018

@yuyu2172 @Hakuyume Can I restart working on this? What's the situation around this PR?
Just updating this PR by following your reviews is fine?

@yuyu2172
Copy link
Member

I was actually working on it, but had not yet completed implementations.
I can finish the inference-part by this week. Is that OK?

https://github.com/yuyu2172/chainercv/tree/add-pspnet-infer-major-change

@yuyu2172
Copy link
Member

Reopening PR #610

@mitmul
Copy link
Member Author

mitmul commented May 22, 2018

@yuyu2172 Thanks, that's fine. I think it's better to close this PR. Thank you for handing over this.

@mitmul mitmul closed this May 22, 2018
@yuyu2172 yuyu2172 removed this from Stall in ChainerCV May 22, 2018
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants