This repository has been archived by the owner on Apr 4, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
unet.py
124 lines (94 loc) · 5.62 KB
/
unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# --------------------------------------------------------------------------
# Source: https://github.com/UdonDa/3D-UNet-PyTorch/blob/master/src/model.py
# --------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
def conv_block_3d(in_dim, out_dim, activation):
return nn.Sequential(
nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
nn.BatchNorm3d(out_dim),
activation,)
def conv_trans_block_3d(in_dim, out_dim, activation):
return nn.Sequential(
nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.BatchNorm3d(out_dim),
activation,)
def max_pooling_3d():
    """Return a 2x2x2 max-pooling layer (stride 2) that halves every spatial dim."""
    pool = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
    return pool
def conv_block_2_3d(in_dim, out_dim, activation):
    """Two stacked 3x3x3 convolutions: (conv -> BN -> activation) then (conv -> BN).

    NOTE(review): the second conv is deliberately NOT followed by an activation
    here — this matches the upstream reference implementation; confirm intended.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels (for both convolutions).
        activation: activation module used inside the first sub-block.
    """
    first = conv_block_3d(in_dim, out_dim, activation)
    second_conv = nn.Conv3d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)
    return nn.Sequential(first, second_conv, nn.BatchNorm3d(out_dim))
class UNet(nn.Module):
    """3D U-Net: five encoder levels, a bridge, and five decoder levels with
    skip connections (channel-wise concat of the matching encoder output).

    Channel widths grow by a factor of 2 per level, from ``num_filters`` at the
    top to ``num_filters * 32`` at the bridge. After every transposed conv the
    feature map is resized with trilinear interpolation to exactly match the
    skip tensor, so odd input sizes still concatenate cleanly.

    NOTE(review): the output head (``self.out``) ends in the LeakyReLU
    activation, so raw outputs are post-activation — matches the upstream
    reference; confirm this is intended for the loss being used.
    """

    def __init__(self, in_dim, out_dim, num_filters):
        """Args:
            in_dim: input channel count.
            out_dim: output channel count.
            num_filters: base channel width of the first encoder level.
        """
        super(UNet, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filters = num_filters
        act = nn.LeakyReLU(0.2, inplace=True)  # one shared (stateless) instance
        f = num_filters

        # Encoder: five double-conv stages, each followed by 2x max-pooling.
        # (Registration order matches a hand-unrolled down_1..pool_5 layout.)
        enc_channels = [(in_dim, f), (f, f * 2), (f * 2, f * 4),
                        (f * 4, f * 8), (f * 8, f * 16)]
        for level, (c_in, c_out) in enumerate(enc_channels, start=1):
            setattr(self, f"down_{level}", conv_block_2_3d(c_in, c_out, act))
            setattr(self, f"pool_{level}", max_pooling_3d())

        # Bridge at the coarsest resolution.
        self.bridge = conv_block_2_3d(f * 16, f * 32, act)

        # Decoder: (transposed-conv channels, post-concat channels, output channels)
        # per level; post-concat = transposed-conv out + matching skip width.
        dec_channels = [(f * 32, f * 48, f * 16), (f * 16, f * 24, f * 8),
                        (f * 8, f * 12, f * 4), (f * 4, f * 6, f * 2),
                        (f * 2, f * 3, f * 1)]
        for level, (c_t, c_cat, c_out) in enumerate(dec_channels, start=1):
            setattr(self, f"trans_{level}", conv_trans_block_3d(c_t, c_t, act))
            setattr(self, f"up_{level}", conv_block_2_3d(c_cat, c_out, act))

        # Output head: single conv block down to out_dim channels.
        self.out = conv_block_3d(f, out_dim, act)

    def forward(self, x):
        # Encoder: keep each stage's pre-pool output as a skip connection.
        skips = []
        h = x
        for level in range(1, 6):
            h = getattr(self, f"down_{level}")(h)
            skips.append(h)
            h = getattr(self, f"pool_{level}")(h)

        h = self.bridge(h)

        # Decoder: upsample, snap to the skip's spatial size, concat, convolve.
        # Skips are consumed deepest-first (down_5 pairs with trans_1/up_1).
        for level, skip in enumerate(reversed(skips), start=1):
            h = getattr(self, f"trans_{level}")(h)
            h = F.interpolate(h, size=skip.shape[-3:], mode='trilinear',
                              align_corners=False)
            h = torch.cat([h, skip], dim=1)
            h = getattr(self, f"up_{level}")(h)

        return self.out(h)