This repository has been archived by the owner on Aug 18, 2020. It is now read-only.
/
unet.py
94 lines (83 loc) · 4.42 KB
/
unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/15a_vision.models.unet.ipynb (unless otherwise specified).
__all__ = ['UnetBlock', 'ResizeToOrig', 'DynamicUnet']
# Cell
from ...torch_basics import *
from ...callback.hook import *
# Cell
def _get_sz_change_idxs(sizes):
"Get the indexes of the layers where the size of the activation changes."
feature_szs = [size[-1] for size in sizes]
sz_chg_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
if feature_szs[0] != feature_szs[1]: sz_chg_idxs = [0] + sz_chg_idxs
return sz_chg_idxs
# Cell
class UnetBlock(Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR` upsampling."
    @delegates(ConvLayer.__init__)
    def __init__(self, up_in_c, x_in_c, hook, final_div=True, blur=False, act_cls=defaults.activation,
                 self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        # `hook.stored` is read on every forward pass and supplies the tensor to concatenate
        self.hook = hook
        half_up_c = up_in_c//2
        self.shuf = PixelShuffle_ICNR(up_in_c, half_up_c, blur=blur, act_cls=act_cls, norm_type=norm_type)
        self.bn = BatchNorm(x_in_c)
        # Channel count once the `shuf` output and the stored tensor are concatenated
        cat_c = half_up_c + x_in_c
        # Channels are kept when `final_div` is true and halved otherwise
        out_c = cat_c if final_div else cat_c//2
        conv_kwargs = dict(act_cls=act_cls, norm_type=norm_type, **kwargs)
        self.conv1 = ConvLayer(cat_c, out_c, **conv_kwargs)
        attn = SelfAttention(out_c) if self_attention else None
        self.conv2 = ConvLayer(out_c, out_c, xtra=attn, **conv_kwargs)
        self.relu = act_cls()
        # Only the two convolutions are re-initialised with `init`
        apply_init(nn.Sequential(self.conv1, self.conv2), init)

    def forward(self, up_in):
        skip = self.hook.stored
        upsampled = self.shuf(up_in)
        target_hw = skip.shape[-2:]
        # Force the two tensors to share their last two dims before concatenating
        if upsampled.shape[-2:] != target_hw:
            upsampled = F.interpolate(upsampled, target_hw, mode='nearest')
        merged = torch.cat([upsampled, self.bn(skip)], dim=1)
        return self.conv2(self.conv1(self.relu(merged)))
# Cell
class ResizeToOrig(Module):
    "Resize `x` to the last two dims of `x.orig` when they differ, interpolating with `mode`."
    def __init__(self, mode='nearest'):
        self.mode = mode

    def forward(self, x):
        target_hw = x.orig.shape[-2:]
        # Nothing to do when the sizes already match: hand the input back untouched
        if x.shape[-2:] == target_hw: return x
        return F.interpolate(x, target_hw, mode=self.mode)
# Cell
class DynamicUnet(SequentialEx):
    "Create a U-Net from a given architecture."
    def __init__(self, encoder, n_classes, img_size, blur=False, blur_final=True, self_attention=False,
                 y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,
                 init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        # Keyword arguments shared by every layer that uses the configured activation
        layer_kwargs = dict(act_cls=act_cls, norm_type=norm_type, **kwargs)

        sizes = model_sizes(encoder, size=img_size)
        # Size-change indexes are walked in reverse order
        skip_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        self.sfs = hook_outputs([encoder[idx] for idx in skip_idxs], detach=False)
        # The dummy output is pushed through each new layer so the next one can read its channel count
        x = dummy_eval(encoder, img_size).detach()

        ni = sizes[-1][1]
        middle_conv = nn.Sequential(ConvLayer(ni, ni*2, **layer_kwargs),
                                    ConvLayer(ni*2, ni, **layer_kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]

        last_pos = len(skip_idxs) - 1
        attn_pos = len(skip_idxs) - 3  # the only position that may receive self-attention
        for pos, enc_idx in enumerate(skip_idxs):
            not_final = pos != last_pos
            unet_block = UnetBlock(int(x.shape[1]), int(sizes[enc_idx][1]), self.sfs[pos],
                                   final_div=not_final, blur=blur and (not_final or blur_final),
                                   self_attention=self_attention and (pos == attn_pos),
                                   init=init, **layer_kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        # An extra PixelShuffle_ICNR is added only when the first recorded size differs from `img_size`
        if img_size != sizes[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **layer_kwargs))
        layers.append(ResizeToOrig())
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(ResBlock(1, ni, ni//2 if bottle else ni, **layer_kwargs))
        layers.append(ConvLayer(ni, n_classes, ks=1, act_cls=None, norm_type=norm_type, **kwargs))
        # `layers[-2]` is whatever sits right before the final 1x1 conv:
        # the ResBlock when `last_cross` is set, otherwise the ResizeToOrig layer
        apply_init(nn.Sequential(middle_conv, layers[-2]), init)
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # `sfs` may be missing if construction failed before the hooks were registered
        if hasattr(self, "sfs"):
            self.sfs.remove()