diff --git a/LICENSE b/LICENSE new file mode 100755 index 0000000..67cbbfc --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2018 Kaiyu Yue + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100755 index 0000000..fdb39eb --- /dev/null +++ b/README.md @@ -0,0 +1,152 @@ +# Compact Generalized Non-local Network + +By [Kaiyu Yue](http://kaiyuyue.com), [Ming Sun](https://msunming.github.io/), [Yuchen Yuan](https://scholar.google.com.au/citations?user=QJAr1KoAAAAJ&hl=en), [Feng Zhou](http://www.f-zhou.com/bio.html), Errui Ding and Fuxin Xu + +## Introduction + +This is a [PyTorch](https://pytorch.org/) re-implementation of the paper Compact Generalized Non-local (CGNL) Network [ADD LINK]. It provides the CGNL models trained on [CUB-200](http://www.vision.caltech.edu/visipedia/CUB-200.html) and [ImageNet](http://image-net.org/index), as well as [COCO](http://cocodataset.org/) models based on [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark) from FAIR. + +![introfig](img/intro.jpg) + +## Citation + +If you find this code useful in your research or wish to refer to the baseline results published in our paper, please use the following BibTeX entry. + +``` +@article{CGNLNetwork2018, +  title={Compact Generalized Non-local Network}, +  author={Kaiyu Yue and Ming Sun and Yuchen Yuan and Feng Zhou and Errui Ding and Fuxin Xu}, +  journal={NIPS}, +  year={2018} +} +``` + +## Requirements + + * PyTorch >= 0.4.1 or 1.0 from a nightly release + * Python >= 3.5 + * torchvision >= 0.2.1 + * termcolor >= 1.1.0 + +## Environment + +The code is developed and tested on 8 Tesla P40 / V100-SXM2-16GB GPU cards under CentOS, with CUDA 9.2/8.0 and cuDNN 7.1 installed.
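A quick way to confirm that an installation matches these requirements is to query the versions from Python (an illustrative check, not part of this repo):

```python
import torch
import torchvision

print(torch.__version__)        # expect >= 0.4.1, or a 1.0 nightly build
print(torchvision.__version__)  # expect >= 0.2.1
print(torch.version.cuda)       # CUDA toolkit PyTorch was built against
print(torch.backends.cudnn.version(), torch.cuda.device_count())
```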
+ +## Baselines and Main Results on CUB-200 Dataset + +| File ID | Model | Best Top-1 (%) | Top-5 (%) | Google Drive | Baidu Pan | +|:---------- |:-------------------- |:--------------:|:---------:|:------------:|:---------:| +| 1832260500 | R-50 Base | 86.45 | 97.00 | [`link`](https://drive.google.com/file/d/1nk1y6YE5A-jcuxC5r-hdnU0B5dny8Nf2/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1QB4GLaq0rY9C7psItj8IdA) | +| 1832260501 | R-50 w/ 1 NL Block | 86.69 | 96.95 | [`link`](https://drive.google.com/file/d/1r15lAheWPHyo4V9aitceyG6h-rEPwzun/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1UERXF48oM8MKRfhRFdrHFw) | +| 1832260502 | R-50 w/ 1 CGNL Block | 87.06 | 96.91 | [`link`](https://drive.google.com/file/d/1XND1fPWCzlYTsfiggiZqYhFREAZ-DVTK/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1onJlxgEhQxTVwUdHVmNXRA) | +| 1832261010 | R-101 Base | 86.76 | 96.91 | [`link`](https://drive.google.com/file/d/1mZPz9E3eEwmusEBD-zpjrL6IceM8kxOq/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1NWTbRny-Qs6_7Ts40RkeAQ) | +| 1832261011 | R-101 w/ 1 NL Block | 87.04 | 97.01 | [`link`](https://drive.google.com/file/d/1eYKqnFxSVfCM_5qeTeiEG69PteaQZ3DT/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1JFu5JRjNo9drQZkDeZzPxw) | +| 1832261012 | R-101 w/ 1 CGNL Block | 87.28 | 97.20 | [`link`](https://drive.google.com/file/d/1oiRzne2nhS4aUwLWr3DzS7ke3D6HYMZ2/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1iTMZPd2BuXcyZl56sEWUrA) | + +### Notes: + - The input size is 448. + - The CGNL block with the dot product kernel is configured with 8 groups. + +| File ID | Model | Best Top-1 (%) | Top-5 (%) | Google Drive | Baidu Pan | +|:----------- |:---------------------- |:--------------:|:---------:|:------------:|:---------:| +| 1832260503x | R-50 w/ 1 CGNLx Block | 86.56 | 96.63 | [`link`](https://drive.google.com/file/d/1grkSRJfojIECZs85abvfelnxnVleSLlJ/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1RisiSxWMBjVDnhxeM2Ddaw) | +| 1832261013x | R-101 w/ 1 CGNLx Block | 87.18 | 97.03 | [`link`](https://drive.google.com/file/d/1AZPZk4IvmsbT6jcm0An_V249IbGK8ml6/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1FHhhGrj8qJe1bLMFeDClwQ) | + +### Notes: + - The input size is 448. + - The CGNLx block with Gaussian RBF \[[0](#reference)\]\[[1](#reference)\] kernel is configured with 8 groups. + - The Taylor expansion order for the kernel function is 3.
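To make the dot-product kernel configuration above concrete, the sketch below mirrors the `SpatialCGNL.kernel` computation defined later in this diff (in `model/resnet.py`): theta, phi and g are flattened to vectors, the attention collapses to a single scalar `phi · g` per group, and theta is rescaled by that scalar. The tensor shapes are illustrative toy values, not fixed by the repo.

```python
import torch

def cgnl_dot_kernel(t, p, g, use_scale=False):
    """Linear (dot-product) CGNL kernel for a single group.

    t, p, g: (b, c, h, w) outputs of the theta / phi / g 1x1 convs.
    """
    b, c, h, w = t.size()
    t = t.view(b, 1, c * h * w)           # (b, 1, chw)
    p = p.view(b, 1, c * h * w)           # (b, 1, chw)
    g = g.view(b, c * h * w, 1)           # (b, chw, 1)

    att = torch.bmm(p, g)                 # (b, 1, 1): one scalar per sample
    if use_scale:
        att = att.div((c * h * w) ** 0.5)

    x = torch.bmm(att, t)                 # rescale theta by that scalar
    return x.view(b, c, h, w)

# Toy shapes: one of the 8 groups at stage res4 for a 448x448 input.
t = p = g = torch.randn(2, 64, 28, 28)
print(cgnl_dot_kernel(t, p, g).shape)     # torch.Size([2, 64, 28, 28])
```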
+ +## Experiments on ImageNet Dataset + +| File ID | Model | Best Top-1 (%) | Top-5 (%) | Google Drive | Baidu Pan | +|:----------- |:-------------------- |:--------------:|:---------:|:------------:|:---------:| +| [torchvision](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L12) | R-50 Base | 76.15 | 92.87 | - | - | +| 1832261502 | R-50 w/ 1 CGNL Block | 77.69 | 93.63 | [`link`](https://drive.google.com/file/d/1ezE6_tblZdoFZTYw24NJIaP_A5E0xduS/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1105xXRHxeU_gMC0Mv-EU4w) | +| 1832261503 | R-50 w/ 1 CGNLx Block | 77.32 | 93.40 | [`link`](https://drive.google.com/file/d/1HSqYZvL8EOiQO47dqasZTVpufLjwtg0-/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1GVQwUzDln3rCOBL8_ACJWg) | +| [torchvision](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L14) | R-152 Base | 78.31 | 94.06 | - | - | +| 1832261522 | R-152 w/ 1 CGNL Block | 79.53 | 94.52 | [`link`](https://drive.google.com/file/d/1UKfwi3_Egj_pxFDcCV4sdQjFrb5dyDH2/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1tgJkWSfj9WlyqpXvRGDTig) | +| 1832261523 | R-152 w/ 1 CGNLx Block| 79.37 | 94.47 | [`link`](https://drive.google.com/file/d/14XsJta6FevDDvZG9evKAXeXj6dKG-lZJ/view?usp=sharing) | [`link`](https://pan.baidu.com/s/1kLBL6vsteq-wlj6XlIVqug) | + +### Notes: + - The input size is 224. + - The CGNL and CGNLx blocks are configured the same as in the CUB-200 experiments above. + +## Experiments on COCO based on Mask R-CNN in PyTorch 1.0 + +| backbone | type | lr sched | im / gpu | train mem(GB) | train time (s/iter) | total train time(hr) | inference time(s/im) | box AP | mask AP | model id | Google Drive | Baidu Pan | +| ----------------------- | ---- | -------- | -------- | ------------- | ------------------- | -------------------- | -------------------- | ------ | ------- |:--------:|:------------:|:----------:| +| R-50-C4 | Mask | 1x | 1 | 5.641 | 0.5434 | 27.3 | 0.18329 + 0.011 | 35.6 | 31.5 | [6358801](https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_C4_1x.pth) | - | - | +| R-50-C4 w/ 1 CGNL Block | Mask | 1x | 1 | 5.868 | 0.5785 | 28.5 | 0.20326 + 0.008 | 36.3 | 32.1 | - | [`link`](https://drive.google.com/file/d/1n1hVs2r0FIbsiZyQAQkwxGYtYhpq5RjH/view?usp=sharing) | [`link`](https://pan.baidu.com/s/18Yl5MRGzniNpMxfwKUFu7g) | + +### Notes: + - The CGNL model is trained using the same experimental strategy as in [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark). The block is configured the same as in the CUB-200 experiments above. + - To add the `CGNL` / `CGNLx` / `NL` blocks to the backbone of Mask R-CNN models, use the `maskrcnn-benchmark/modeling/backbone/resnet.py` and `maskrcnn-benchmark/utils/c2_model_loading.py` provided here to replace the original files. Please refer to the code for specific configurations. + - Due to the Linux virtual environment and data I/O speed, the `train time`, `total train time` and `inference time` numbers in the table above are larger than the official benchmarks. This does not affect the demonstration of the CGNL block's efficiency. + +## Getting Started + +### Prepare Dataset + + - Download the PyTorch ImageNet pretrained models from the [pytorch model zoo](https://pytorch.org/docs/stable/model_zoo.html#module-torch.utils.model_zoo). Alternative download links can be found in [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models). Put them in the `pretrained` folder.
+ + - Download the training and validation lists for the CUB-200 dataset from [Google Drive](https://drive.google.com/file/d/1coklkMoDeFy-JHVhHtmbcrxFLjn7ndrJ/view?usp=sharing) or [Baidu Pan](https://pan.baidu.com/s/1rhGlTxZWBn38j6JoDWPwwA). Download the ImageNet dataset and move validation images to labeled subfolders following the [tutorial](https://github.com/pytorch/examples/tree/master/imagenet). The training and validation lists can be found in [Google Drive](https://drive.google.com/file/d/1Qy1zYUUSY5v_13XcjcXYBhlqb2ofdw_v/view?usp=sharing) or [Baidu Pan](https://pan.baidu.com/s/1QjNgBaAKOYUseW8ygC_-hw). Put them in the `data` folder so that the directory tree looks like: + + ``` + ${THIS REPO ROOT} + `-- pretrained + |-- resnet50-19c8e357.pth + |-- resnet101-5d3b4d8f.pth + |-- resnet152-b121ed2d.pth + `-- data + `-- cub + `-- images + | |-- 001.Black_footed_Albatross + | |-- 002.Laysan_Albatross + | |-- ... + | |-- 200.Common_Yellowthroat + |-- cub_train.list + |-- cub_val.list + |-- images.txt + |-- image_class_labels.txt + |-- README + `-- imagenet + `-- img_train + | |-- n01440764 + | |-- n01734418 + | |-- ... + | |-- n15075141 + `-- img_val + | |-- n01440764 + | |-- n01734418 + | |-- ... + | |-- n15075141 + |-- imagenet_train.list + |-- imagenet_val.list + ``` + +### Perform Validation + +```bash +$ python train_val.py --arch '50' --dataset 'cub' --nl-type 'cgnl' --nl-num 1 --checkpoints ${FOLDER_DIR} --valid +``` + +### Train the Baselines + +```bash +$ python train_val.py --arch '50' --dataset 'cub' --nl-num 0 +``` + +### Train the NL and CGNL Networks + +```bash +$ python train_val.py --arch '50' --dataset 'cub' --nl-type 'cgnl' --nl-num 1 --warmup +``` + +## Reference + * \[0\] Y. Cui et al., [Kernel Pooling for Convolutional Neural Networks](http://openaccess.thecvf.com/content_cvpr_2017/html/Cui_Kernel_Pooling_for_CVPR_2017_paper.html), CVPR 2017. + * \[1\] T. Poggio et al., [Networks for Approximation and Learning](https://ieeexplore.ieee.org/document/58326), Proceedings of the IEEE, 1990. + +## License + +This code is released under the MIT License. See [LICENSE](LICENSE) for additional details.
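Besides the command-line entry points above, a model can also be built directly in Python through `model.resnet.model_hub` from this repo. A minimal sketch, assuming the pretrained weights listed above are already in `pretrained/` (the 200-class head and 448 crop follow the CUB-200 settings used by `train_val.py`):

```python
import torch
from model import resnet

# ResNet-50 with one CGNL block in res4, i.e. `--arch '50' --nl-type 'cgnl' --nl-num 1`.
model = resnet.model_hub('50', pretrained=True,
                         nl_type='cgnl', nl_nums=1, pool_size=14)

# Swap the ImageNet classifier head for CUB-200, as train_val.py does.
model.fc = torch.nn.Linear(2048, 200)

model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 448, 448))
print(logits.shape)  # torch.Size([1, 200])
```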
\ No newline at end of file diff --git a/img/intro.jpg b/img/intro.jpg new file mode 100755 index 0000000..428b228 Binary files /dev/null and b/img/intro.jpg differ diff --git a/lib/dataloader.py b/lib/dataloader.py new file mode 100755 index 0000000..7196e58 --- /dev/null +++ b/lib/dataloader.py @@ -0,0 +1,35 @@ +# -------------------------------------------------------- +# CGNL Network +# Copyright (c) 2018 Kaiyu Yue +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +"""Functions for dataloader +""" + +import os +import torch.utils.data as data +from PIL import Image, ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +class ImgLoader(data.Dataset): + def __init__(self, root, ann_file, transform=None, target_transform=None): + print('=> loading annotations from: ' + os.path.basename(ann_file) + ' ...') + self.root = root + with open(ann_file, 'r') as f: + self.imgs = f.readlines() + self.transform = transform + self.target_transform = target_transform + + def __getitem__(self, index): + ls = self.imgs[index].strip().split() + img_path = ls[0] + target = int(ls[1]) + img = Image.open( + os.path.join(self.root, img_path)).convert('RGB') + return self.transform(img), target + + def __len__(self): + return len(self.imgs) + diff --git a/maskrcnn_benchmark/modeling/backbone/resnet.py b/maskrcnn_benchmark/modeling/backbone/resnet.py new file mode 100644 index 0000000..0881c96 --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/resnet.py @@ -0,0 +1,615 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Variant of the resnet module that takes cfg as an argument. +Example usage. Strings may be specified in the config file. + model = ResNet( + "StemWithFixedBatchNorm", + "BottleneckWithFixedBatchNorm", + "ResNet50StagesTo4", + ) +Custom implementations may be written in user code and hooked in via the +`register_*` functions. +""" +from collections import namedtuple + +import math +import torch +import torch.nn.functional as F +from torch import nn + +from termcolor import cprint + +from maskrcnn_benchmark.layers import FrozenBatchNorm2d +from maskrcnn_benchmark.layers import Conv2d + +# ResNet stage specification +StageSpec = namedtuple( + "StageSpec", + [ + "index", # Index of the stage, eg 1, 2, ..,. 
5 +        "block_count",  # Number of residual blocks in the stage +        "return_features",  # True => return the last feature map from this stage +    ], +) + +# ----------------------------------------------------------------------------- +# Global cfgs for CGNL blocks +# ----------------------------------------------------------------------------- +nl_nums = 1 +nl_type = 'cgnlx' # cgnl | cgnlx | nl + +# ----------------------------------------------------------------------------- +# Standard ResNet models +# ----------------------------------------------------------------------------- +# ResNet-50 (including all stages) +ResNet50StagesTo5 = ( +    StageSpec(index=i, block_count=c, return_features=r) +    for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, False), (4, 3, True)) +) +# ResNet-50 up to stage 4 (excludes stage 5) +ResNet50StagesTo4 = ( +    StageSpec(index=i, block_count=c, return_features=r) +    for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, True)) +) +# ResNet-50-FPN (including all stages) +ResNet50FPNStagesTo5 = ( +    StageSpec(index=i, block_count=c, return_features=r) +    for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 6, True), (4, 3, True)) +) +# ResNet-101-FPN (including all stages) +ResNet101FPNStagesTo5 = ( +    StageSpec(index=i, block_count=c, return_features=r) +    for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 23, True), (4, 3, True)) +) + +class SpatialCGNL(nn.Module): +    """Spatial CGNL block with dot product kernel for image classification. +    """ +    def __init__(self, inplanes, planes, use_scale=False, groups=None): +        self.use_scale = use_scale +        self.groups = groups + +        super(SpatialCGNL, self).__init__() +        # conv theta +        self.t = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) +        # conv phi +        self.p = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) +        # conv g +        self.g = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) +        # conv z +        self.z = nn.Conv2d(planes, inplanes, kernel_size=1, stride=1, +                           groups=self.groups, bias=False) +        self.gn = nn.GroupNorm(num_groups=self.groups, num_channels=inplanes) + +        if self.use_scale: +            cprint("=> WARN: SpatialCGNL block uses 'SCALE'", \ +                   'yellow') +        if self.groups: +            cprint("=> WARN: SpatialCGNL block uses '{}' groups".format(self.groups), \ +                   'yellow') + +    def kernel(self, t, p, g, b, c, h, w): +        """The linear kernel (dot product). +        Args: +            t: output of conv theta +            p: output of conv phi +            g: output of conv g +            b: batch size +            c: channels number +            h: height of featuremaps +            w: width of featuremaps +        """ +        t = t.view(b, 1, c * h * w) +        p = p.view(b, 1, c * h * w) +        g = g.view(b, c * h * w, 1) + +        att = torch.bmm(p, g) + +        if self.use_scale: +            att = att.div((c*h*w)**0.5) + +        x = torch.bmm(att, t) +        x = x.view(b, c, h, w) + +        return x + +    def forward(self, x): +        residual = x + +        t = self.t(x) +        p = self.p(x) +        g = self.g(x) + +        b, c, h, w = t.size() + +        if self.groups and self.groups > 1: +            _c = int(c / self.groups) + +            ts = torch.split(t, split_size_or_sections=_c, dim=1) +            ps = torch.split(p, split_size_or_sections=_c, dim=1) +            gs = torch.split(g, split_size_or_sections=_c, dim=1) + +            _t_sequences = [] +            for i in range(self.groups): +                _x = self.kernel(ts[i], ps[i], gs[i], +                                 b, _c, h, w) +                _t_sequences.append(_x) + +            x = torch.cat(_t_sequences, dim=1) +        else: +            x = self.kernel(t, p, g, +                            b, c, h, w) + +        x = self.z(x) +        x = self.gn(x) + residual + +        return x + +class SpatialCGNLx(nn.Module): +    """Spatial CGNL block with Gaussian RBF kernel for image classification.
+ """ + def __init__(self, inplanes, planes, use_scale=False, groups=None, order=2): + self.use_scale = use_scale + self.groups = groups + self.order = order + + super(SpatialCGNLx, self).__init__() + # conv theta + self.t = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + # conv phi + self.p = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + # conv g + self.g = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + # conv z + self.z = nn.Conv2d(planes, inplanes, kernel_size=1, stride=1, + groups=self.groups, bias=False) + self.gn = nn.GroupNorm(num_groups=self.groups, num_channels=inplanes) + + if self.use_scale: + cprint("=> WARN: SpatialCGNLx block uses 'SCALE'", \ + 'yellow') + if self.groups: + cprint("=> WARN: SpatialCGNLx block uses '{}' groups".format(self.groups), \ + 'yellow') + + cprint('=> WARN: The Taylor expansion order in SpatialCGNLx block is {}'.format(self.order), \ + 'yellow') + + def kernel(self, t, p, g, b, c, h, w): + """The non-linear kernel (Gaussian RBF). + Args: + t: output of conv theata + p: output of conv phi + g: output of conv g + b: batch size + c: channels number + h: height of featuremaps + w: width of featuremaps + """ + + t = t.view(b, 1, c * h * w) + p = p.view(b, 1, c * h * w) + g = g.view(b, c * h * w, 1) + + # gamma + gamma = torch.Tensor(1).fill_(1e-4) + + # beta + beta = torch.exp(-2 * gamma) + + t_taylor = [] + p_taylor = [] + for order in range(self.order+1): + # alpha + alpha = torch.mul( + torch.div( + torch.pow( + (2 * gamma), + order), + math.factorial(order)), + beta) + + alpha = torch.sqrt( + alpha.cuda()) + + _t = t.pow(order).mul(alpha) + _p = p.pow(order).mul(alpha) + + t_taylor.append(_t) + p_taylor.append(_p) + + t_taylor = torch.cat(t_taylor, dim=1) + p_taylor = torch.cat(p_taylor, dim=1) + + att = torch.bmm(p_taylor, g) + + if self.use_scale: + att = att.div((c*h*w)**0.5) + + att = att.view(b, 1, int(self.order+1)) + x = torch.bmm(att, t_taylor) + x = x.view(b, c, h, w) + + return x + + def forward(self, x): + residual = x + + t = self.t(x) + p = self.p(x) + g = self.g(x) + + b, c, h, w = t.size() + + if self.groups and self.groups > 1: + _c = int(c / self.groups) + + ts = torch.split(t, split_size_or_sections=_c, dim=1) + ps = torch.split(p, split_size_or_sections=_c, dim=1) + gs = torch.split(g, split_size_or_sections=_c, dim=1) + + _t_sequences = [] + for i in range(self.groups): + _x = self.kernel(ts[i], ps[i], gs[i], + b, _c, h, w) + _t_sequences.append(_x) + + x = torch.cat(_t_sequences, dim=1) + else: + x = self.kernel(t, p, g, + b, c, h, w) + + x = self.z(x) + x = self.gn(x) + residual + + return x + + +class SpatialNL(nn.Module): + """Spatial NL block for image classification. + [https://github.com/facebookresearch/video-nonlocal-net]. 
+ """ + def __init__(self, inplanes, planes, use_scale=False): + self.use_scale = use_scale + + super(SpatialNL, self).__init__() + self.t = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + self.p = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + self.g = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + self.softmax = nn.Softmax(dim=2) + self.z = nn.Conv2d(planes, inplanes, kernel_size=1, stride=1, bias=False) + self.bn = nn.BatchNorm2d(inplanes) + + if self.use_scale: + cprint("=> WARN: SpatialNL block uses 'SCALE' before softmax", 'yellow') + + def forward(self, x): + residual = x + + t = self.t(x) + p = self.p(x) + g = self.g(x) + + b, c, h, w = t.size() + + t = t.view(b, c, -1).permute(0, 2, 1) + p = p.view(b, c, -1) + g = g.view(b, c, -1).permute(0, 2, 1) + + att = torch.bmm(t, p) + + if self.use_scale: + att = att.div(c**0.5) + + att = self.softmax(att) + x = torch.bmm(att, g) + + x = x.permute(0, 2, 1) + x = x.contiguous() + x = x.view(b, c, h, w) + + x = self.z(x) + x = self.bn(x) + residual + + return x + + +class ResNet(nn.Module): + def __init__(self, cfg): + super(ResNet, self).__init__() + + # If we want to use the cfg in forward(), then we should make a copy + # of it and store it for later use: + # self.cfg = cfg.clone() + + # Translate string names to implementations + stem_module = _STEM_MODULES[cfg.MODEL.RESNETS.STEM_FUNC] + stage_specs = _STAGE_SPECS[cfg.MODEL.BACKBONE.CONV_BODY] + transformation_module = _TRANSFORMATION_MODULES[cfg.MODEL.RESNETS.TRANS_FUNC] + + # Construct the stem module + self.stem = stem_module(cfg) + + # Constuct the specified ResNet stages + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + stage2_bottleneck_channels = num_groups * width_per_group + stage2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + self.stages = [] + self.return_features = {} + for stage_spec in stage_specs: + name = "layer" + str(stage_spec.index) + stage2_relative_factor = 2 ** (stage_spec.index - 1) + bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor + out_channels = stage2_out_channels * stage2_relative_factor + module = _make_stage( + transformation_module, + in_channels, + bottleneck_channels, + out_channels, + stage_spec.block_count, + num_groups, + cfg.MODEL.RESNETS.STRIDE_IN_1X1, + first_stride=int(stage_spec.index > 1) + 1, + ) + in_channels = out_channels + self.add_module(name, module) + self.stages.append(name) + self.return_features[name] = stage_spec.return_features + + # Optionally freeze (requires_grad=False) parts of the backbone + self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT) + + if nl_nums == 1: + for name, m in self._modules['layer3'][-2].named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, mean=0, std=0.01) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + + + def _freeze_backbone(self, freeze_at): + for stage_index in range(freeze_at): + if stage_index == 0: + m = self.stem # stage 0 is the stem + else: + m = getattr(self, "layer" + str(stage_index)) + for p in m.parameters(): + p.requires_grad = False + + def forward(self, x): + outputs = [] + x = self.stem(x) + for stage_name in self.stages: + x = getattr(self, stage_name)(x) + if self.return_features[stage_name]: + 
outputs.append(x) + return outputs + + +class ResNetHead(nn.Module): + def __init__( + self, + block_module, + stages, + num_groups=1, + width_per_group=64, + stride_in_1x1=True, + stride_init=None, + res2_out_channels=256, + ): + super(ResNetHead, self).__init__() + + stage2_relative_factor = 2 ** (stages[0].index - 1) + stage2_bottleneck_channels = num_groups * width_per_group + out_channels = res2_out_channels * stage2_relative_factor + in_channels = out_channels // 2 + bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor + + block_module = _TRANSFORMATION_MODULES[block_module] + + self.stages = [] + stride = stride_init + for stage in stages: + name = "layer" + str(stage.index) + if not stride: + stride = int(stage.index > 1) + 1 + module = _make_stage( + block_module, + in_channels, + bottleneck_channels, + out_channels, + stage.block_count, + num_groups, + stride_in_1x1, + first_stride=stride, + ) + stride = None + self.add_module(name, module) + self.stages.append(name) + + def forward(self, x): + for stage in self.stages: + x = getattr(self, stage)(x) + return x + + +def _make_stage( + transformation_module, + in_channels, + bottleneck_channels, + out_channels, + block_count, + num_groups, + stride_in_1x1, + first_stride, +): + blocks = [] + stride = first_stride + + for idx in range(block_count): + if idx == 5 and block_count == 6: + # print(in_channels, bottleneck_channels, out_channels, num_groups) + # 1024 256 1024 1 + if nl_type == 'nl': + blocks.append(SpatialNL( + in_channels, + int(in_channels/2), + use_scale=True)) + elif nl_type == 'cgnl': + blocks.append(SpatialCGNL( + in_channels, + int(in_channels/2), + use_scale=False, + groups=8)) + elif nl_type == 'cgnlx': + blocks.append(SpatialCGNLx( + in_channels, + int(in_channels/2), + use_scale=False, + groups=8, + order=3)) + else: + pass + + blocks.append( + transformation_module( + in_channels, + bottleneck_channels, + out_channels, + num_groups, + stride_in_1x1, + stride, + ) + ) + stride = 1 + in_channels = out_channels + return nn.Sequential(*blocks) + + +class BottleneckWithFixedBatchNorm(nn.Module): + def __init__( + self, + in_channels, + bottleneck_channels, + out_channels, + num_groups=1, + stride_in_1x1=True, + stride=1, + ): + super(BottleneckWithFixedBatchNorm, self).__init__() + + self.downsample = None + if in_channels != out_channels: + self.downsample = nn.Sequential( + Conv2d( + in_channels, out_channels, kernel_size=1, stride=stride, bias=False + ), + FrozenBatchNorm2d(out_channels), + ) + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + ) + self.bn1 = FrozenBatchNorm2d(bottleneck_channels) + # TODO: specify init for the above + + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride_3x3, + padding=1, + bias=False, + groups=num_groups, + ) + self.bn2 = FrozenBatchNorm2d(bottleneck_channels) + + self.conv3 = Conv2d( + bottleneck_channels, out_channels, kernel_size=1, bias=False + ) + self.bn3 = FrozenBatchNorm2d(out_channels) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = F.relu_(out) + + out = self.conv2(out) + out = self.bn2(out) + out = F.relu_(out) + + out0 = self.conv3(out) + out = 
self.bn3(out0) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = F.relu_(out) + + return out + + +class StemWithFixedBatchNorm(nn.Module): + def __init__(self, cfg): + super(StemWithFixedBatchNorm, self).__init__() + + out_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + + self.conv1 = Conv2d( + 3, out_channels, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = FrozenBatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu_(x) + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + return x + + +_TRANSFORMATION_MODULES = {"BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm} + +_STEM_MODULES = {"StemWithFixedBatchNorm": StemWithFixedBatchNorm} + +_STAGE_SPECS = { + "R-50-C4": ResNet50StagesTo4, + "R-50-C5": ResNet50StagesTo5, + "R-50-FPN": ResNet50FPNStagesTo5, + "R-101-FPN": ResNet101FPNStagesTo5, +} + + +def register_transformation_module(module_name, module): + _register_generic(_TRANSFORMATION_MODULES, module_name, module) + + +def register_stem_module(module_name, module): + _register_generic(_STEM_MODULES, module_name, module) + + +def register_stage_spec(stage_spec_name, stage_spec): + _register_generic(_STAGE_SPECS, stage_spec_name, stage_spec) + + +def _register_generic(module_dict, module_name, module): + assert module_name not in module_dict + module_dict[module_name] = module diff --git a/maskrcnn_benchmark/utils/c2_model_loading.py b/maskrcnn_benchmark/utils/c2_model_loading.py new file mode 100644 index 0000000..e9bd3aa --- /dev/null +++ b/maskrcnn_benchmark/utils/c2_model_loading.py @@ -0,0 +1,162 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging +import pickle +from collections import OrderedDict + +import torch + +from maskrcnn_benchmark.utils.model_serialization import load_state_dict + +# ----------------------------------------------------------------------------- +# Global cfgs for CGNL blocks +# ----------------------------------------------------------------------------- +nl_nums = 1 +nl_layer_id = 5 # r-50: 5 | r-101: 22 | r-152: 35 + + +def _rename_basic_resnet_weights(layer_keys): + layer_keys = [k.replace("_", ".") for k in layer_keys] + layer_keys = [k.replace(".w", ".weight") for k in layer_keys] + layer_keys = [k.replace(".bn", "_bn") for k in layer_keys] + layer_keys = [k.replace(".b", ".bias") for k in layer_keys] + layer_keys = [k.replace("_bn.s", "_bn.scale") for k in layer_keys] + layer_keys = [k.replace(".biasranch", ".branch") for k in layer_keys] + layer_keys = [k.replace("bbox.pred", "bbox_pred") for k in layer_keys] + layer_keys = [k.replace("cls.score", "cls_score") for k in layer_keys] + layer_keys = [k.replace("res.conv1_", "conv1_") for k in layer_keys] + + # RPN / Faster RCNN + layer_keys = [k.replace(".biasbox", ".bbox") for k in layer_keys] + layer_keys = [k.replace("conv.rpn", "rpn.conv") for k in layer_keys] + layer_keys = [k.replace("rpn.bbox.pred", "rpn.bbox_pred") for k in layer_keys] + layer_keys = [k.replace("rpn.cls.logits", "rpn.cls_logits") for k in layer_keys] + + # Affine-Channel -> BatchNorm enaming + layer_keys = [k.replace("_bn.scale", "_bn.weight") for k in layer_keys] + + # Make torchvision-compatible + layer_keys = [k.replace("conv1_bn.", "bn1.") for k in layer_keys] + + layer_keys = [k.replace("res2.", "layer1.") for k in layer_keys] + layer_keys = [k.replace("res3.", "layer2.") for k in layer_keys] + layer_keys = [k.replace("res4.", "layer3.") for k in layer_keys] + 
layer_keys = [k.replace("res5.", "layer4.") for k in layer_keys] + + layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys] + layer_keys = [k.replace(".branch2a_bn.", ".bn1.") for k in layer_keys] + layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys] + layer_keys = [k.replace(".branch2b_bn.", ".bn2.") for k in layer_keys] + layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys] + layer_keys = [k.replace(".branch2c_bn.", ".bn3.") for k in layer_keys] + + layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys] + layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys] + + # Make CGNL networks compatible + for idx, k in enumerate(layer_keys): + ks = k.split('.') + layer_name = '.'.join(ks[0:2]) + if nl_nums == 1 and \ + layer_name == 'layer3.{}'.format(nl_layer_id): + ks[1] = str(int(ks[1]) + 1) + k = '.'.join(ks) + # rename the key + layer_keys[idx] = k + + return layer_keys + +def _rename_fpn_weights(layer_keys, stage_names): + for mapped_idx, stage_name in enumerate(stage_names, 1): + suffix = "" + if mapped_idx < 4: + suffix = ".lateral" + layer_keys = [ + k.replace("fpn.inner.layer{}.sum{}".format(stage_name, suffix), "fpn_inner{}".format(mapped_idx)) for k in layer_keys + ] + layer_keys = [k.replace("fpn.layer{}.sum".format(stage_name), "fpn_layer{}".format(mapped_idx)) for k in layer_keys] + + + layer_keys = [k.replace("rpn.conv.fpn2", "rpn.conv") for k in layer_keys] + layer_keys = [k.replace("rpn.bbox_pred.fpn2", "rpn.bbox_pred") for k in layer_keys] + layer_keys = [ + k.replace("rpn.cls_logits.fpn2", "rpn.cls_logits") for k in layer_keys + ] + + return layer_keys + + +def _rename_weights_for_resnet(weights, stage_names): + original_keys = sorted(weights.keys()) + layer_keys = sorted(weights.keys()) + + # for X-101, rename output to fc1000 to avoid conflicts afterwards + layer_keys = [k if k != "pred_b" else "fc1000_b" for k in layer_keys] + layer_keys = [k if k != "pred_w" else "fc1000_w" for k in layer_keys] + + # performs basic renaming: _ -> . 
, etc + layer_keys = _rename_basic_resnet_weights(layer_keys) + + # FPN + layer_keys = _rename_fpn_weights(layer_keys, stage_names) + + # Mask R-CNN + layer_keys = [k.replace("mask.fcn.logits", "mask_fcn_logits") for k in layer_keys] + layer_keys = [k.replace(".[mask].fcn", "mask_fcn") for k in layer_keys] + layer_keys = [k.replace("conv5.mask", "conv5_mask") for k in layer_keys] + + # Keypoint R-CNN + layer_keys = [k.replace("kps.score.lowres", "kps_score_lowres") for k in layer_keys] + layer_keys = [k.replace("kps.score", "kps_score") for k in layer_keys] + layer_keys = [k.replace("conv.fcn", "conv_fcn") for k in layer_keys] + + # Rename for our RPN structure + layer_keys = [k.replace("rpn.", "rpn.head.") for k in layer_keys] + + key_map = {k: v for k, v in zip(original_keys, layer_keys)} + + logger = logging.getLogger(__name__) + logger.info("Remapping C2 weights") + max_c2_key_size = max([len(k) for k in original_keys if "_momentum" not in k]) + + new_weights = OrderedDict() + for k in original_keys: + v = weights[k] + if "_momentum" in k: + continue + # if 'fc1000' in k: + # continue + w = torch.from_numpy(v) + # if "bn" in k: + # w = w.view(1, -1, 1, 1) + logger.info("C2 name: {: <{}} mapped name: {}".format(k, max_c2_key_size, key_map[k])) + new_weights[key_map[k]] = w + + return new_weights + + +def _load_c2_pickled_weights(file_path): + with open(file_path, "rb") as f: + if torch._six.PY3: + data = pickle.load(f, encoding="latin1") + else: + data = pickle.load(f) + if "blobs" in data: + weights = data["blobs"] + else: + weights = data + return weights + + +_C2_STAGE_NAMES = { + "R-50": ["1.2", "2.3", "3.5", "4.2"], + "R-101": ["1.2", "2.3", "3.22", "4.2"], +} + +def load_c2_format(cfg, f): + # TODO make it support other architectures + state_dict = _load_c2_pickled_weights(f) + conv_body = cfg.MODEL.BACKBONE.CONV_BODY + arch = conv_body.replace("-C4", "").replace("-FPN", "") + stages = _C2_STAGE_NAMES[arch] + state_dict = _rename_weights_for_resnet(state_dict, stages) + return dict(model=state_dict) diff --git a/model/resnet.py b/model/resnet.py new file mode 100755 index 0000000..b9fe1a4 --- /dev/null +++ b/model/resnet.py @@ -0,0 +1,532 @@ +# -------------------------------------------------------- +# CGNL Network +# Copyright (c) 2018 Kaiyu Yue +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +"""Functions for model building. + Based on https://github.com/pytorch/vision +""" + +import math +import torch +import torch.nn as nn + +from termcolor import cprint +from collections import OrderedDict + +__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152'] + + +def model_hub(arch, pretrained=True, nl_type=None, nl_nums=None, + pool_size=7): + """Model hub. + """ + if arch == '50': + return resnet50(pretrained=pretrained, + nl_type=nl_type, + nl_nums=nl_nums, + pool_size=pool_size) + elif arch == '101': + return resnet101(pretrained=pretrained, + nl_type=nl_type, + nl_nums=nl_nums, + pool_size=pool_size) + elif arch == '152': + return resnet152(pretrained=pretrained, + nl_type=nl_type, + nl_nums=nl_nums, + pool_size=pool_size) + else: + raise NameError("The arch '{}' is not supported yet in this repo. \ + You can add it by yourself.".format(arch)) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding. 
+ """ + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SpatialCGNL(nn.Module): + """Spatial CGNL block with dot production kernel for image classfication. + """ + def __init__(self, inplanes, planes, use_scale=False, groups=None): + self.use_scale = use_scale + self.groups = groups + + super(SpatialCGNL, self).__init__() + # conv theta + self.t = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + # conv phi + self.p = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + # conv g + self.g = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + # conv z + self.z = nn.Conv2d(planes, inplanes, kernel_size=1, stride=1, + groups=self.groups, bias=False) + self.gn = nn.GroupNorm(num_groups=self.groups, num_channels=inplanes) + + if self.use_scale: + cprint("=> WARN: SpatialCGNL block uses 'SCALE'", \ + 'yellow') + if self.groups: + cprint("=> WARN: SpatialCGNL block uses '{}' groups".format(self.groups), \ + 'yellow') + + def kernel(self, t, p, g, b, c, h, w): + """The linear kernel (dot production). 
+ + Args: + t: output of conv theata + p: output of conv phi + g: output of conv g + b: batch size + c: channels number + h: height of featuremaps + w: width of featuremaps + """ + t = t.view(b, 1, c * h * w) + p = p.view(b, 1, c * h * w) + g = g.view(b, c * h * w, 1) + + att = torch.bmm(p, g) + + if self.use_scale: + att = att.div((c*h*w)**0.5) + + x = torch.bmm(att, t) + x = x.view(b, c, h, w) + + return x + + def forward(self, x): + residual = x + + t = self.t(x) + p = self.p(x) + g = self.g(x) + + b, c, h, w = t.size() + + if self.groups and self.groups > 1: + _c = int(c / self.groups) + + ts = torch.split(t, split_size_or_sections=_c, dim=1) + ps = torch.split(p, split_size_or_sections=_c, dim=1) + gs = torch.split(g, split_size_or_sections=_c, dim=1) + + _t_sequences = [] + for i in range(self.groups): + _x = self.kernel(ts[i], ps[i], gs[i], + b, _c, h, w) + _t_sequences.append(_x) + + x = torch.cat(_t_sequences, dim=1) + else: + x = self.kernel(t, p, g, + b, c, h, w) + + x = self.z(x) + x = self.gn(x) + residual + + return x + + +class SpatialCGNLx(nn.Module): + """Spatial CGNL block with Gaussian RBF kernel for image classification. + """ + def __init__(self, inplanes, planes, use_scale=False, groups=None, order=2): + self.use_scale = use_scale + self.groups = groups + self.order = order + + super(SpatialCGNLx, self).__init__() + # conv theta + self.t = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + # conv phi + self.p = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + # conv g + self.g = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + # conv z + self.z = nn.Conv2d(planes, inplanes, kernel_size=1, stride=1, + groups=self.groups, bias=False) + self.gn = nn.GroupNorm(num_groups=self.groups, num_channels=inplanes) + + if self.use_scale: + cprint("=> WARN: SpatialCGNLx block uses 'SCALE'", \ + 'yellow') + if self.groups: + cprint("=> WARN: SpatialCGNLx block uses '{}' groups".format(self.groups), \ + 'yellow') + + cprint('=> WARN: The Taylor expansion order in SpatialCGNLx block is {}'.format(self.order), \ + 'yellow') + + def kernel(self, t, p, g, b, c, h, w): + """The non-linear kernel (Gaussian RBF). 
+ + Args: + t: output of conv theata + p: output of conv phi + g: output of conv g + b: batch size + c: channels number + h: height of featuremaps + w: width of featuremaps + """ + + t = t.view(b, 1, c * h * w) + p = p.view(b, 1, c * h * w) + g = g.view(b, c * h * w, 1) + + # gamma + gamma = torch.Tensor(1).fill_(1e-4) + + # beta + beta = torch.exp(-2 * gamma) + + t_taylor = [] + p_taylor = [] + for order in range(self.order+1): + # alpha + alpha = torch.mul( + torch.div( + torch.pow( + (2 * gamma), + order), + math.factorial(order)), + beta) + + alpha = torch.sqrt( + alpha.cuda()) + + _t = t.pow(order).mul(alpha) + _p = p.pow(order).mul(alpha) + + t_taylor.append(_t) + p_taylor.append(_p) + + t_taylor = torch.cat(t_taylor, dim=1) + p_taylor = torch.cat(p_taylor, dim=1) + + att = torch.bmm(p_taylor, g) + + if self.use_scale: + att = att.div((c*h*w)**0.5) + + att = att.view(b, 1, int(self.order+1)) + x = torch.bmm(att, t_taylor) + x = x.view(b, c, h, w) + + return x + + def forward(self, x): + residual = x + + t = self.t(x) + p = self.p(x) + g = self.g(x) + + b, c, h, w = t.size() + + if self.groups and self.groups > 1: + _c = int(c / self.groups) + + ts = torch.split(t, split_size_or_sections=_c, dim=1) + ps = torch.split(p, split_size_or_sections=_c, dim=1) + gs = torch.split(g, split_size_or_sections=_c, dim=1) + + _t_sequences = [] + for i in range(self.groups): + _x = self.kernel(ts[i], ps[i], gs[i], + b, _c, h, w) + _t_sequences.append(_x) + + x = torch.cat(_t_sequences, dim=1) + else: + x = self.kernel(t, p, g, + b, c, h, w) + + x = self.z(x) + x = self.gn(x) + residual + + return x + + +class SpatialNL(nn.Module): + """Spatial NL block for image classification. + [https://github.com/facebookresearch/video-nonlocal-net]. + """ + def __init__(self, inplanes, planes, use_scale=False): + self.use_scale = use_scale + + super(SpatialNL, self).__init__() + self.t = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + self.p = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + self.g = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) + self.softmax = nn.Softmax(dim=2) + self.z = nn.Conv2d(planes, inplanes, kernel_size=1, stride=1, bias=False) + self.bn = nn.BatchNorm2d(inplanes) + + if self.use_scale: + cprint("=> WARN: SpatialNL block uses 'SCALE' before softmax", 'yellow') + + def forward(self, x): + residual = x + + t = self.t(x) + p = self.p(x) + g = self.g(x) + + b, c, h, w = t.size() + + t = t.view(b, c, -1).permute(0, 2, 1) + p = p.view(b, c, -1) + g = g.view(b, c, -1).permute(0, 2, 1) + + att = torch.bmm(t, p) + + if self.use_scale: + att = att.div(c**0.5) + + att = self.softmax(att) + x = torch.bmm(att, g) + + x = x.permute(0, 2, 1) + x = x.contiguous() + x = x.view(b, c, h, w) + + x = self.z(x) + x = self.bn(x) + residual + + return x + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, + nl_type=None, nl_nums=None, pool_size=7): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + + if not nl_nums: + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + else: + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + nl_type=nl_type, 
nl_nums=nl_nums) + + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(pool_size, stride=1) + self.dropout = nn.Dropout(0.5) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if nl_nums == 1: + for name, m in self._modules['layer3'][-2].named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, mean=0, std=0.01) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1, nl_type=None, nl_nums=None): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + if (i == 5 and blocks == 6) or \ + (i == 22 and blocks == 23) or \ + (i == 35 and blocks == 36): + if nl_type == 'nl': + layers.append(SpatialNL( + self.inplanes, + int(self.inplanes/2), + use_scale=True)) + elif nl_type == 'cgnl': + layers.append(SpatialCGNL( + self.inplanes, + int(self.inplanes/2), + use_scale=False, + groups=8)) + elif nl_type == 'cgnlx': + layers.append(SpatialCGNLx( + self.inplanes, + int(self.inplanes/2), + use_scale=False, + groups=8, + order=3)) + else: + pass + + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.dropout(x) + x = self.fc(x) + + return x + + +def load_partial_weight(model, pretrained, nl_nums, nl_layer_id): + """Loads the partial weights for NL/CGNL network. + """ + _pretrained = pretrained + _model_dict = model.state_dict() + _pretrained_dict = OrderedDict() + for k, v in _pretrained.items(): + ks = k.split('.') + layer_name = '.'.join(ks[0:2]) + if nl_nums == 1 and \ + layer_name == 'layer3.{}'.format(nl_layer_id): + ks[1] = str(int(ks[1]) + 1) + k = '.'.join(ks) + _pretrained_dict[k] = v + _model_dict.update(_pretrained_dict) + return _model_dict + + +def resnet50(pretrained=False, nl_type=None, nl_nums=None, **kwargs): + """Constructs a ResNet-50 model. + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], + nl_type=nl_type, nl_nums=nl_nums, **kwargs) + if pretrained: + _pretrained = torch.load('pretrained/resnet50-19c8e357.pth') + _model_dict = load_partial_weight(model, _pretrained, nl_nums, 5) + model.load_state_dict(_model_dict) + return model + + +def resnet101(pretrained=False, nl_type=None, nl_nums=None, **kwargs): + """Constructs a ResNet-101 model. 
+ """ + model = ResNet(Bottleneck, [3, 4, 23, 3], + nl_type=nl_type, nl_nums=nl_nums, **kwargs) + if pretrained: + _pretrained = torch.load('pretrained/resnet101-5d3b4d8f.pth') + _model_dict = load_partial_weight(model, _pretrained, nl_nums, 22) + model.load_state_dict(_model_dict) + return model + + +def resnet152(pretrained=False, nl_type=None, nl_nums=None, **kwargs): + """Constructs a ResNet-152 model. + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], + nl_type=nl_type, nl_nums=nl_nums, **kwargs) + if pretrained: + _pretrained = torch.load('pretrained/resnet152-b121ed2d.pth') + _model_dict = load_partial_weight(model, _pretrained, nl_nums, 35) + model.load_state_dict(_model_dict) + return model diff --git a/train_val.py b/train_val.py new file mode 100755 index 0000000..2fb7372 --- /dev/null +++ b/train_val.py @@ -0,0 +1,369 @@ +# -------------------------------------------------------- +# CGNL Network +# Copyright (c) 2018 Kaiyu Yue +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import os +import argparse +import time +import shutil +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.nn as nn +import torch.nn.parallel + +from torchvision import transforms +from termcolor import cprint +from lib import dataloader +from model import resnet + +# torch version +cprint('=> Torch Vresion: ' + torch.__version__, 'green') + +# args +parser = argparse.ArgumentParser(description='PyTorch Training') +parser.add_argument('--debug', '-d', dest='debug', action='store_true', + help='enable debug mode') +parser.add_argument('--warmup', '-w', dest='warmup', action='store_true', + help='using warmup strategy') +parser.add_argument('--print-freq', '-p', default=1, type=int, metavar='N', + help='print frequency (default: 10)') +parser.add_argument('--nl-nums', default=0, type=int, metavar='N', + help='number of the NL | CGNL block (default: 0)') +parser.add_argument('--nl-type', default=None, type=str, + help='choose NL | CGNL | CGNLx block to add (default: None)') +parser.add_argument('--arch', default='50', type=str, + help='the depth of resnet (default: 50)') +parser.add_argument('--valid', '-v', dest='valid', + action='store_true', help='just run validation') +parser.add_argument('--checkpoints', default='', type=str, + help='the dir of checkpoints') +parser.add_argument('--dataset', default='cub', type=str, + help='cub | imagenet (default: cub)') +parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', + help='initial learning rate (default: 0.01)') + +best_prec1 = 0 +best_prec5 = 0 + +def main(): + global args + global best_prec1, best_prec5 + + args = parser.parse_args() + + # simple args + debug = args.debug + if debug: cprint('=> WARN: Debug Mode', 'yellow') + + dataset = args.dataset + num_classes = 200 if dataset == 'cub' else 1000 + base_size = 512 if dataset == 'cub' else 256 + pool_size = 14 if base_size == 512 else 7 + workers = 0 if debug else 8 + batch_size = 2 if debug else 256 + if base_size == 512 and \ + args.arch == '152': + batch_size = 128 + drop_ratio = 0.1 + lr_drop_epoch_list = [31, 61, 81] + epochs = 100 + eval_freq = 1 + gpu_ids = [0] if debug else [0,1,2,3,4,5,6,7] + crop_size = 224 if base_size == 256 else 448 + + # args for the nl and cgnl block + arch = args.arch + nl_type = args.nl_type # 'cgnl' | 'cgnlx' | 'nl' + nl_nums = args.nl_nums # 1: stage res4 + + # warmup setting + WARMUP_LRS = [args.lr * 
(drop_ratio**len(lr_drop_epoch_list)), args.lr] + WARMUP_EPOCHS = 10 + + # data loader + if dataset == 'cub': + data_root = 'data/cub' + imgs_fold = os.path.join(data_root, 'images') + train_ann_file = os.path.join(data_root, 'cub_train.list') + valid_ann_file = os.path.join(data_root, 'cub_val.list') + elif dataset == 'imagenet': + data_root = 'data/imagenet' + imgs_fold = os.path.join(data_root) + train_ann_file = os.path.join(data_root, 'imagenet_train.list') + valid_ann_file = os.path.join(data_root, 'imagenet_val.list') + else: + raise NameError("WARN: The dataset '{}' is not supported yet.") + + train_dataset = dataloader.ImgLoader( + root = imgs_fold, + ann_file = train_ann_file, + transform = transforms.Compose([ + transforms.RandomResizedCrop( + size=crop_size, scale=(0.08, 1.25)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize( + [0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + ])) + + val_dataset = dataloader.ImgLoader( + root = imgs_fold, + ann_file = valid_ann_file, + transform = transforms.Compose([ + transforms.Resize(base_size), + transforms.CenterCrop(crop_size), + transforms.ToTensor(), + transforms.Normalize( + [0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + ])) + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size = batch_size, + shuffle = True, + num_workers = workers, + pin_memory = True) + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size = batch_size, + shuffle = False, + num_workers = workers, + pin_memory = True) + + # build model + model = resnet.model_hub(arch, + pretrained=True, + nl_type=nl_type, + nl_nums=nl_nums, + pool_size=pool_size) + + # change the fc layer + model._modules['fc'] = torch.nn.Linear(in_features=2048, + out_features=num_classes) + torch.nn.init.kaiming_normal_(model._modules['fc'].weight, + mode='fan_out', nonlinearity='relu') + print(model) + + # parallel + model = torch.nn.DataParallel(model, device_ids=gpu_ids).cuda() + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda() + + # optimizer + optimizer = torch.optim.SGD( + model.parameters(), + args.lr, + momentum=0.9, + weight_decay=1e-4) + + # cudnn + cudnn.benchmark = True + + # warmup + if args.warmup: + epochs += WARMUP_EPOCHS + lr_drop_epoch_list = list( + np.array(lr_drop_epoch_list) + WARMUP_EPOCHS) + cprint('=> WARN: warmup is used in the first {} epochs'.format( + WARMUP_EPOCHS), 'yellow') + + # valid + if args.valid: + cprint('=> WARN: Validation Mode', 'yellow') + print('start validation ...') + checkpoint_fold = args.checkpoints + checkpoint_best = os.path.join(checkpoint_fold, 'model_best.pth.tar') + print('=> loading state_dict from {}'.format(checkpoint_best)) + model.load_state_dict( + torch.load(checkpoint_best)['state_dict']) + prec1, prec5 = validate(val_loader, model, criterion) + print(' * Final Accuracy: Prec@1 {:.3f}, Prec@5 {:.3f}'.format(prec1, prec5)) + exit(0) + + # train + print('start training ...') + for epoch in range(0, epochs): + current_lr = adjust_learning_rate(optimizer, drop_ratio, epoch, lr_drop_epoch_list, + WARMUP_EPOCHS, WARMUP_LRS) + # train one epoch + train(train_loader, model, criterion, optimizer, epoch, epochs, current_lr) + + if nl_nums > 0: + checkpoint_name = '{}-r-{}-w-{}{}-block.pth.tar'.format(dataset, arch, nl_nums, nl_type) + else: + checkpoint_name = '{}-r-{}-base.pth.tar'.format(dataset, arch) + + if (epoch + 1) % eval_freq == 0: + prec1, prec5 = validate(val_loader, model, criterion) + is_best = prec1 > 
best_prec1 + best_prec1 = max(prec1, best_prec1) + best_prec5 = max(prec5, best_prec5) + print(' * Best accuracy: Prec@1 {:.3f}, Prec@5 {:.3f}'.format(best_prec1, best_prec5)) + save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'best_prec1': best_prec1, + 'optimizer' : optimizer.state_dict(), + }, is_best, filename=checkpoint_name) + + +def train(train_loader, model, criterion, optimizer, epoch, epochs, current_lr): + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to train mode + model.train() + + end = time.time() + for i, (input, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + target = target.cuda(non_blocking=True) + + # compute output + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(prec1[0], input.size(0)) + top5.update(prec5[0], input.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Epoch: [{0:3d}/{1:3d}][{2:3d}/{3:3d}]\t' + 'LR: {lr:.7f}\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + epoch, epochs, i, len(train_loader), + lr=current_lr, batch_time=batch_time, + data_time=data_time, loss=losses, top1=top1, top5=top5)) + + +def validate(val_loader, model, criterion): + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (input, target) in enumerate(val_loader): + target = target.cuda(non_blocking=True) + + # compute output + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(prec1[0], input.size(0)) + top5.update(prec5[0], input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(val_loader), batch_time=batch_time, loss=losses, + top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + return top1.avg, top5.avg + + +def adjust_learning_rate(optimizer, drop_ratio, epoch, lr_drop_epoch_list, + WARMUP_EPOCHS, WARMUP_LRS): + if args.warmup and epoch < WARMUP_EPOCHS: + # achieve the warmup lr + lrs = np.linspace(WARMUP_LRS[0], WARMUP_LRS[1], num=WARMUP_EPOCHS) + cprint('=> warmup lrs {}'.format(lrs), 'green') + for param_group in optimizer.param_groups: + param_group['lr'] = lrs[epoch] + current_lr = lrs[epoch] + else: + decay = drop_ratio if epoch in lr_drop_epoch_list else 1.0 + for param_group in optimizer.param_groups: + param_group['lr'] = args.lr * decay + args.lr *= decay + current_lr = 
args.lr + return current_lr + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k + """ + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + + +class AverageMeter(object): + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +if __name__ == '__main__': + main()
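As a footnote on the `--warmup` option used in the training command above: `adjust_learning_rate` ramps the learning rate linearly over the first 10 epochs, from `lr * drop_ratio**len(lr_drop_epoch_list)` up to the base `lr`, and shifts the step-decay epochs accordingly. A standalone sketch of the resulting warmup values under the default settings:

```python
import numpy as np

lr = 0.01                        # --lr default
drop_ratio = 0.1
lr_drop_epoch_list = [31, 61, 81]
WARMUP_EPOCHS = 10

# Same endpoints as train_val.py: from 0.01 * 0.1**3 = 1e-5 up to 1e-2.
WARMUP_LRS = [lr * (drop_ratio ** len(lr_drop_epoch_list)), lr]
print(np.linspace(WARMUP_LRS[0], WARMUP_LRS[1], num=WARMUP_EPOCHS))
```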