import torch
from torch import nn
import torch.nn.functional as F

from .utils import *
from fastai.vision.models.unet import _get_sz_change_idxs, hook_outputs
from fastai.layers import init_default, ConvLayer
from fastai.callback.hook import model_sizes

def conv2d(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias=False, init=nn.init.kaiming_normal_):
    "Create and initialize `nn.Conv2d` layer."
    if padding is None: padding = ks // 2
    return init_default(nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=padding, bias=bias), init)
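
# Illustrative: `conv2d(256, 256)` yields a Kaiming-initialized 3x3 convolution
# with stride 1, "same" padding (ks // 2 = 1) and no bias term.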

class LateralUpsampleMerge(nn.Module):
    "Merge the features coming from the downsample path (stored in `hook`) with the upsample path."
    def __init__(self, ch, ch_lat, hook):
        super().__init__()
        self.hook = hook
        self.conv_lat = conv2d(ch_lat, ch, ks=1, bias=True)

    def forward(self, x):
        # Resize `x` to the hooked feature map's spatial size, then add the 1x1-projected lateral features.
        return self.conv_lat(self.hook.stored) + F.interpolate(x, self.hook.stored.shape[-2:], mode='nearest')
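
# Illustrative shapes for one merge (assuming `ch`=256 and a hooked feature
# map of spatial size 32x32):
#   self.hook.stored: (bs, ch_lat, 32, 32) --1x1 conv--->   (bs, 256, 32, 32)
#   x:                (bs, 256, h, w) --nearest resize-->   (bs, 256, 32, 32)
#   output:           elementwise sum,                      (bs, 256, 32, 32)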

class RetinaNet(nn.Module):
    "Implements RetinaNet from https://arxiv.org/abs/1708.02002"
    def __init__(self, encoder:nn.Module, n_classes, final_bias=0., chs=256, n_anchors=9, flatten=True):
        super().__init__()
        self.n_classes,self.flatten = n_classes,flatten
        imsize = (256,256)
        # Probe the encoder with a dummy input size to find the layers where the
        # feature-map size changes, and hook their outputs for the lateral connections.
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sz_change_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        self.encoder = encoder
        # Top of the pyramid: a 1x1 projection of the last encoder feature map (P5),
        # plus two stride-2 convs for the coarser levels (P6 and P7 in the paper).
        self.c5top5 = conv2d(sfs_szs[-1][1], chs, ks=1, bias=True)
        self.c5top6 = conv2d(sfs_szs[-1][1], chs, stride=2, bias=True)
        self.p6top7 = nn.Sequential(nn.ReLU(), conv2d(chs, chs, stride=2, bias=True))
        # Lateral merges build the finer pyramid levels from two of the hooked features.
        self.merges = nn.ModuleList([LateralUpsampleMerge(chs, sfs_szs[idx][1], hook)
                                     for idx,hook in zip(sfs_idxs[-2:-4:-1], self.sfs[-2:-4:-1])])
        self.smoothers = nn.ModuleList([conv2d(chs, chs, 3, bias=True) for _ in range(3)])
        # One subnet for classification, one for box regression, shared across levels.
        self.classifier = self._head_subnet(n_classes, n_anchors, final_bias, chs=chs)
        self.box_regressor = self._head_subnet(4, n_anchors, 0., chs=chs)

    def _head_subnet(self, n_classes, n_anchors, final_bias=0., n_conv=4, chs=256):
        "Helper function to create one of the subnets for classification/regression."
        layers = [ConvLayer(chs, chs, bias=True, norm_type=None) for _ in range(n_conv)]
        layers += [conv2d(chs, n_classes * n_anchors, bias=True)]
        # Zero-init the final conv and set its bias to `final_bias`.
        layers[-1].bias.data.zero_().add_(final_bias)
        layers[-1].weight.data.fill_(0)
        return nn.Sequential(*layers)
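
    # Note on `final_bias` (illustrative): the RetinaNet paper initializes the
    # classification bias to -log((1 - pi) / pi) for a small foreground prior pi;
    # with pi = 0.01 this is roughly -4.6, which you would pass as `final_bias`.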

    def _apply_transpose(self, func, p_states, n_classes):
        # The final output of the classifier/regressor is bs x (k * n_anchors) x h x w.
        # We permute it to bs x h x w x n_anchors x k, then flatten to bs x -1 x k so we
        # can concatenate all the results into bs x anchors x k (the non-flattened
        # version exists for debugging only).
        if not self.flatten:
            sizes = [[p.size(0), p.size(2), p.size(3)] for p in p_states]
            return [func(p).permute(0,2,3,1).view(*sz,-1,n_classes) for p,sz in zip(p_states,sizes)]
        else:
            return torch.cat([func(p).permute(0,2,3,1).contiguous().view(p.size(0),-1,n_classes) for p in p_states],1)
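
    # Shape walk-through for `_apply_transpose` (illustrative, with 9 anchors,
    # a 32x32 feature map and k outputs per anchor):
    #   func(p):       (bs, 9*k, 32, 32)
    #   after permute: (bs, 32, 32, 9*k)
    #   after view:    (bs, 32*32*9, k)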

    def forward(self, x):
        c5 = self.encoder(x)
        # Build the top pyramid levels from C5, then prepend the finer levels
        # produced by the lateral merges.
        p_states = [self.c5top5(c5.clone()), self.c5top6(c5)]
        p_states.append(self.p6top7(p_states[-1]))
        for merge in self.merges: p_states = [merge(p_states[0])] + p_states
        # Smooth the first three levels with 3x3 convs before applying the heads.
        for i, smooth in enumerate(self.smoothers[:3]):
            p_states[i] = smooth(p_states[i])
        return [self._apply_transpose(self.classifier, p_states, self.n_classes),
                self._apply_transpose(self.box_regressor, p_states, 4),
                [[p.size(2), p.size(3)] for p in p_states]]

    def __del__(self):
        # Remove the registered forward hooks when the model is garbage-collected.
        if hasattr(self, "sfs"): self.sfs.remove()
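
# Minimal smoke-test sketch (not part of the original module; left commented
# because the relative import above means this file is meant to run inside its
# package). Assumes torchvision >= 0.13 and a ResNet cut before its pooling and
# head so the encoder is an indexable `nn.Sequential`; the class count and
# input size are illustrative only.
#
#   import torchvision
#   backbone = torchvision.models.resnet50(weights=None)
#   encoder = nn.Sequential(*list(backbone.children())[:-2])
#   model = RetinaNet(encoder, n_classes=21, chs=256, n_anchors=9)
#   class_preds, box_preds, sizes = model(torch.randn(2, 3, 256, 256))
#   class_preds.shape  # (2, total_anchors, 21)
#   box_preds.shape    # (2, total_anchors, 4)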