/
nb_006a.py
81 lines (68 loc) · 2.91 KB
/
nb_006a.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#################################################
### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
#################################################
# file to edit: dev_nb/006a_unet.ipynb
from nb_006 import *
import gc
# Type alias for a list of activation shapes, one per hooked layer
# (e.g. [[1, 64, 128, 128], ...] as produced by `model_sizes` below).
Sizes = List[List[int]]
def in_channels(m:Model) -> List[int]:
    "Return the input-channel count of the first weighted layer of `m` (dim 1 of its weight)."
    first_weighted = next((l for l in flatten_model(m) if hasattr(l, 'weight')), None)
    if first_weighted is None: raise Exception('No weight layer')
    return first_weighted.weight.shape[1]
def model_sizes(m:Model, size:tuple=(256,256), full:bool=True) -> Tuple[Sizes,Tensor,Hooks]:
    """Pass a dummy input through the model to get the various sizes.

    Hooks every submodule of `m`, runs a zero tensor of shape
    (1, in_channels, *size) through it in eval mode, and collects each
    hooked output's shape.

    Returns `(sizes, output, hooks)` when `full=True` (caller owns the
    live hooks and must remove them); when `full=False` the hooks are
    removed here and only `sizes` is returned.
    """
    hooks = hook_outputs(m)
    ch_in = in_channels(m)
    x = torch.zeros(1, ch_in, *size)
    x = m.eval()(x)
    res = [o.stored.shape for o in hooks]
    if not full: hooks.remove()
    # BUG FIX: the original `return res,x,hooks if full else res` parsed as
    # `(res, x, (hooks if full else res))`, so `full=False` returned a bogus
    # 3-tuple `(res, x, res)`. Parenthesize so the conditional picks the
    # whole return value.
    return (res, x, hooks) if full else res
def get_sfs_idxs(sizes:Sizes, last:bool=True) -> List[int]:
    """Get the indexes of the layers where the size of the activation changes.

    `sizes` is a list of activation shapes (as from `model_sizes`). With
    `last=True`, compares the trailing (spatial) dimension of consecutive
    shapes and returns the indexes where it changes; with `last=False`,
    returns every index.
    """
    if last:
        feature_szs = [size[-1] for size in sizes]
        sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
        if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
    else:
        # BUG FIX: the original read `list(range(len(sfs)))` — `sfs` is
        # undefined here and raised NameError; the parameter is `sizes`.
        sfs_idxs = list(range(len(sizes)))
    return sfs_idxs
class UnetBlock(nn.Module):
    "A basic U-Net block: upsample, concat the hooked encoder activation, then two convs + batchnorm."
    def __init__(self, up_in_c:int, x_in_c:int, hook:Hook):
        super().__init__()
        self.hook = hook
        up_c = up_in_c // 2
        self.upconv = conv2d_trans(up_in_c, up_c)  # H, W -> 2H, 2W
        cat_c = up_c + x_in_c  # channels after concatenating the skip connection
        self.conv1 = conv2d(cat_c, cat_c // 2)
        out_c = cat_c // 2
        self.conv2 = conv2d(out_c, out_c)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, up_in:Tensor) -> Tensor:
        upsampled = self.upconv(up_in)
        # Skip connection: the hook holds the matching encoder activation.
        joined = torch.cat([upsampled, self.hook.stored], dim=1)
        out = F.relu(self.conv1(joined))
        out = F.relu(self.conv2(out))
        return self.bn(out)
class DynamicUnet(nn.Sequential):
    "Unet created from a given architecture"
    def __init__(self, encoder:Model, n_classes:int, last:bool=True):
        # Dummy forward pass at a fixed size records every layer's output
        # shape; `self.sfs` keeps the hooks alive so each UnetBlock can read
        # its stored encoder activation at forward time.
        imsize = (256,256)
        sfs_szs,x,self.sfs = model_sizes(encoder, size=imsize)
        # Indexes where activation size changes, walked deepest-first to
        # build the decoder from the bottleneck back up.
        sfs_idxs = reversed(get_sfs_idxs(sfs_szs, last))
        ni = sfs_szs[-1][1]  # channel count out of the encoder's last layer
        middle_conv = nn.Sequential(conv2d_relu(ni, ni*2, bn=True), conv2d_relu(ni*2, ni, bn=True))
        x = middle_conv(x)
        layers = [encoder, nn.ReLU(), middle_conv]
        # Trace the dummy tensor through each new block so the next block's
        # input channel count can be read off `x.shape[1]`.
        for idx in sfs_idxs:
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[idx])
            layers.append(unet_block)
            x = unet_block(x)
            ni = unet_block.conv2.out_channels
        # If the first hooked activation is smaller than the input, add one
        # final upsample to restore the original resolution.
        if imsize != sfs_szs[0][-2:]: layers.append(conv2d_trans(ni, ni))
        # 1x1 conv maps the final feature channels to per-class scores.
        layers.append(conv2d(ni, n_classes, 1))
        super().__init__(*layers)
    def __del__(self):
        # The hooks hold references into the encoder; drop them when the
        # model is garbage-collected.
        if hasattr(self, "sfs"): self.sfs.remove()