# unet.py
import torch
import torch.nn as nn
class EncoderBlock(nn.Module):
    # Consists of Conv -> ReLU -> MaxPool
    def __init__(self, in_chans, out_chans, layers=2, sampling_factor=2, padding="same"):
        """
        Parameters:
            in_chans: number of channels in the input
            out_chans: number of channels in the output
            layers: number of Conv -> ReLU pairs in the block
            sampling_factor: downsampling factor. Same as the upsampling factor in DecoderBlock
            padding: padding applied to Conv2d
        """
        super().__init__()
        self.encoder = nn.ModuleList()
        self.encoder.append(nn.Conv2d(in_chans, out_chans, 3, 1, padding=padding))
        self.encoder.append(nn.ReLU())
        for _ in range(layers - 1):
            self.encoder.append(nn.Conv2d(out_chans, out_chans, 3, 1, padding=padding))
            self.encoder.append(nn.ReLU())
        self.mp = nn.MaxPool2d(sampling_factor)

    def forward(self, x):
        """
        Input:
            x: input feature map of size (B, C_in, H, W)
        Returns:
            mp_out: pooled feature map of size (B, C_out, H//sampling_factor, W//sampling_factor)
            x: pre-pool feature map of size (B, C_out, H, W), kept for the skip connection
        """
        for enc in self.encoder:
            x = enc(x)
        mp_out = self.mp(x)
        return mp_out, x
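# A minimal sketch (not part of the original module) showing EncoderBlock's two
# outputs, assuming a 1-channel 64x64 input and the default settings.
def _demo_encoder_block():
    enc = EncoderBlock(in_chans=1, out_chans=64)
    pooled, pre_pool = enc(torch.randn(1, 1, 64, 64))
    # pooled:   (1, 64, 32, 32) -- downsampled map passed to the next block
    # pre_pool: (1, 64, 64, 64) -- full-resolution map kept for the skip connection
    return pooled.shape, pre_pool.shape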
class DecoderBlock(nn.Module):
    # Consists of a transposed convolution (upsampling) -> Conv -> ReLU
    def __init__(self, in_chans, out_chans, layers=2, skip_connection=True, sampling_factor=2, padding="same"):
        """
        Parameters:
            in_chans: number of channels in the input
            out_chans: number of channels in the output
            layers: number of Conv -> ReLU pairs in the block
            skip_connection: whether or not to use skip connections
            sampling_factor: upsampling factor. Same as the downsampling factor in EncoderBlock
            padding: padding applied to Conv2d
        """
        super().__init__()
        skip_factor = 1 if skip_connection else 2
        self.decoder = nn.ModuleList()
        self.tconv = nn.ConvTranspose2d(in_chans, in_chans // 2, sampling_factor, sampling_factor)
        self.decoder.append(nn.Conv2d(in_chans // skip_factor, out_chans, 3, 1, padding=padding))
        self.decoder.append(nn.ReLU())
        for _ in range(layers - 1):
            self.decoder.append(nn.Conv2d(out_chans, out_chans, 3, 1, padding=padding))
            self.decoder.append(nn.ReLU())
        self.skip_connection = skip_connection
        self.padding = padding

    def forward(self, x, enc_features=None):
        """
        Input:
            x: input feature map of size (B, C_in, H, W)
            enc_features: skip feature map from the matching EncoderBlock, used when skip connections are enabled
        Returns:
            output feature map of size (B, C_out, H*sampling_factor, W*sampling_factor)
        """
        x = self.tconv(x)
        if self.skip_connection:
            if self.padding != "same":
                # Center-crop enc_features to the spatial size of x
                w = x.size(-1)
                c = (enc_features.size(-1) - w) // 2
                enc_features = enc_features[:, :, c:c+w, c:c+w]
            x = torch.cat((enc_features, x), dim=1)
        for dec in self.decoder:
            x = dec(x)
        return x
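# A minimal sketch (not part of the original module) showing how DecoderBlock
# consumes deeper features plus the matching encoder skip features, assuming
# padding="same" so no cropping is needed before concatenation.
def _demo_decoder_block():
    dec = DecoderBlock(in_chans=128, out_chans=64)
    x = torch.randn(1, 128, 16, 16)    # deeper features (C_in channels)
    skip = torch.randn(1, 64, 32, 32)  # encoder features (C_in//2 channels)
    out = dec(x, skip)                 # tconv halves channels and doubles H, W
    return out.shape                   # torch.Size([1, 64, 32, 32])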
class UNet(nn.Module):
    def __init__(self, nclass=1, in_chans=1, depth=5, layers=2, sampling_factor=2, skip_connection=True, padding="same"):
        """
        Parameters:
            nclass: number of output classes
            in_chans: number of channels in the input
            depth: depth of the U-Net (number of encoder blocks)
            layers: number of Conv -> ReLU pairs per block
            sampling_factor: upsampling & downsampling factor
            skip_connection: whether or not to use skip connections
            padding: padding applied to Conv2d
        """
        super().__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        out_chans = 64
        for _ in range(depth):
            self.encoder.append(EncoderBlock(in_chans, out_chans, layers, sampling_factor, padding))
            in_chans, out_chans = out_chans, out_chans * 2
        out_chans = in_chans // 2
        for _ in range(depth - 1):
            self.decoder.append(DecoderBlock(in_chans, out_chans, layers, skip_connection, sampling_factor, padding))
            in_chans, out_chans = out_chans, out_chans // 2
        # Add a 1x1 convolution to produce the final class logits
        self.logits = nn.Conv2d(in_chans, nclass, 1, 1)

    def forward(self, x):
        encoded = []
        for enc in self.encoder:
            x, enc_output = enc(x)
            encoded.append(enc_output)
        # The deepest pre-pool features serve as the bottleneck
        x = encoded.pop()
        for dec in self.decoder:
            enc_output = encoded.pop()
            x = dec(x, enc_output)
        # Return the logits
        return self.logits(x)
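# A minimal sketch (not part of the original module) running the full model on
# a dummy batch. With padding="same" the logits keep the input resolution; H
# and W should be divisible by sampling_factor**(depth-1) so the upsampled maps
# line up with the skip features without cropping.
if __name__ == "__main__":
    model = UNet(nclass=1, in_chans=1, depth=5)
    x = torch.randn(1, 1, 128, 128)
    print(model(x).shape)  # torch.Size([1, 1, 128, 128])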