-
Notifications
You must be signed in to change notification settings - Fork 0
/
Unex_net.py
349 lines (279 loc) · 13.1 KB
/
Unex_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
import torch
from torch import nn
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
from utils import *
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
# TensorBoard log location: <cwd>/runs/experiment.
# os.path.join is used instead of concatenating '\\' so the directory layout is
# the same on every platform (a backslash is only a path separator on Windows).
path = os.path.join(os.getcwd(), 'runs')
writer = SummaryWriter(os.path.join(path, 'experiment'))
__all__ = ['UNext']
import timm
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import types
import math
# from abc import ABCMeta, abstractmethod
# from mmcv.cnn import ConvModule
# import pdb
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 convolution.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution (defaults to 1).

    Returns:
        An ``nn.Conv2d`` with ``kernel_size=1`` and ``bias=False``.
    """
    # Forward the caller's stride; it used to be silently replaced by 1.
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
def shift(dim, xs=None, pad=0, H=None, W=None):
    """Roll channel chunks along ``dim`` by increasing offsets, then crop.

    The original body read ``xs``, ``self.pad``, ``H`` and ``W`` from a scope
    that does not exist at module level, so any call raised ``NameError``.
    They are now explicit, optional parameters.

    Args:
        dim: Tensor dimension passed to ``torch.roll``.
        xs: Sequence of 4-D tensors (chunks along dimension 1) that were
            already padded by ``pad`` on dimensions 2 and 3.
        pad: Padding width; chunk ``i`` is rolled by ``-pad + i``.
        H: Size kept on dimension 2 after cropping. Defaults to the padded
            size minus ``2 * pad``.
        W: Size kept on dimension 3 after cropping. Defaults to the padded
            size minus ``2 * pad``.

    Returns:
        The rolled chunks concatenated on dimension 1 and narrowed to
        ``H`` x ``W`` starting at offset ``pad``.

    Raises:
        TypeError: If ``xs`` is not supplied.
    """
    if xs is None:
        raise TypeError("shift() requires the padded channel chunks via 'xs'")
    rolled = [torch.roll(chunk, offset, dim) for chunk, offset in zip(xs, range(-pad, pad + 1))]
    x_cat = torch.cat(rolled, 1)
    if H is None:
        H = x_cat.shape[2] - 2 * pad
    if W is None:
        W = x_cat.shape[3] - 2 * pad
    # narrow() (rather than a pad:-pad slice) stays correct when pad == 0.
    x_cat = torch.narrow(x_cat, 2, pad, H)
    x_cat = torch.narrow(x_cat, 3, pad, W)
    return x_cat
class shiftmlp(nn.Module):
    """Token MLP that shifts channel groups along one spatial axis before each linear layer.

    Pipeline in ``forward``: shift along dimension 2 -> ``fc1`` -> depthwise
    conv -> activation -> dropout -> shift along dimension 3 -> ``fc2`` ->
    dropout.

    Args:
        in_features: Channel width of the incoming tokens.
        hidden_features: Width after ``fc1`` (defaults to ``in_features``).
        out_features: Width after ``fc2`` (defaults to ``in_features``).
        act_layer: Activation class instantiated with no arguments.
        drop: Dropout probability applied after the activation and after ``fc2``.
        shift_size: Number of channel chunks requested from ``torch.chunk``;
            ``shift_size // 2`` is used as the padding / maximum roll offset.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.shift_size = shift_size
        self.pad = shift_size // 2
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear, LayerNorm and Conv2d sub-modules (used via ``self.apply``)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Normal init scaled by the per-group fan-out of the kernel.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def shift(self, x, dim):
        """Shift channel chunks of a ``(B, C, H, W)`` tensor along ``dim``.

        The tensor is zero-padded by ``self.pad`` on both spatial dimensions,
        split into chunks on the channel dimension, each chunk is rolled by a
        different offset from ``-self.pad`` to ``self.pad``, and the result is
        cropped back to the original ``H`` x ``W``.

        Args:
            x: 4-D tensor ``(B, C, H, W)``.
            dim: Dimension to roll along (2 or 3 in ``forward``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        H, W = x.shape[2], x.shape[3]
        pad = self.pad
        x = F.pad(x, (pad, pad, pad, pad), "constant", 0)
        chunks = torch.chunk(x, self.shift_size, 1)
        rolled = [torch.roll(chunk, offset, dim) for chunk, offset in zip(chunks, range(-pad, pad + 1))]
        x = torch.cat(rolled, 1)
        # narrow() instead of x[..., pad:-pad] so that pad == 0 keeps the full extent.
        x = torch.narrow(x, 2, pad, H)
        return torch.narrow(x, 3, pad, W)

    def forward(self, x, H, W):
        """Apply the shifted MLP to a token sequence.

        Args:
            x: Tensor of shape ``(B, N, C)``; it is viewed as ``(B, C, H, W)``,
                so ``N`` must equal ``H * W``.
            H: Height of the token grid.
            W: Width of the token grid.

        Returns:
            Tensor of shape ``(B, N, out_features)``.
        """
        B, N, C = x.shape

        # Shift along dimension 2, then project tokens to the hidden width.
        feat = x.transpose(1, 2).view(B, C, H, W).contiguous()
        feat = self.shift(feat, 2)
        x = feat.reshape(B, C, H * W).transpose(1, 2)

        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)

        # After fc1 the channel width is hidden_features, not C; read it from
        # the tensor so hidden_features != in_features no longer breaks view().
        hidden = x.shape[-1]
        feat = x.transpose(1, 2).view(B, hidden, H, W).contiguous()
        feat = self.shift(feat, 3)
        x = feat.reshape(B, hidden, H * W).transpose(1, 2)

        x = self.fc2(x)
        x = self.drop(x)
        return x
class shiftedBlock(nn.Module):
    """Residual block: ``x + drop_path(mlp(norm2(x)))`` with a ``shiftmlp`` as the MLP.

    Only ``dim``, ``mlp_ratio``, ``drop``, ``drop_path``, ``act_layer`` and
    ``norm_layer`` are used; the remaining constructor arguments are accepted
    but not read.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        # Stochastic depth only when a positive rate is requested.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = shiftmlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser handed to ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            kernel_h, kernel_w = m.kernel_size[0], m.kernel_size[1]
            fan_out = (kernel_h * kernel_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Return ``x`` plus the (optionally drop-pathed) MLP branch output."""
        branch = self.mlp(self.norm2(x), H, W)
        return x + self.drop_path(branch)
class DWConv(nn.Module):
    """Depthwise 3x3 convolution applied to a ``(B, N, C)`` token sequence."""

    def __init__(self, dim=768):
        super().__init__()
        # groups == channels: one 3x3 filter per channel, padding keeps H and W.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """View tokens as a ``(B, C, H, W)`` grid, convolve, and flatten back to ``(B, N, C)``."""
        batch, _, channels = x.shape
        grid = x.transpose(1, 2).view(batch, channels, H, W)
        grid = self.dwconv(grid)
        return grid.flatten(2).transpose(1, 2)
class OverlapPatchEmbed(nn.Module):
    """Image to Patch Embedding.

    A strided convolution (kernel ``patch_size``, padding ``patch_size // 2``)
    followed by LayerNorm over the embedding dimension.

    Args:
        img_size: Nominal input size (int or pair); only used to fill in
            ``self.H``, ``self.W`` and ``self.num_patches``.
        patch_size: Convolution kernel size (int or pair).
        stride: Convolution stride.
        in_chans: Number of input channels.
        embed_dim: Number of output channels / token width.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride_hw = to_2tuple(stride)
        padding = (patch_size[0] // 2, patch_size[1] // 2)

        self.img_size = img_size
        self.patch_size = patch_size
        # Grid size of the projection output for the nominal image size, using
        # the Conv2d output formula. The previous img_size // patch_size ignored
        # stride and padding and so disagreed with what forward() produces.
        self.H = (img_size[0] + 2 * padding[0] - patch_size[0]) // stride_hw[0] + 1
        self.W = (img_size[1] + 2 * padding[1] - patch_size[1]) // stride_hw[1] + 1
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=padding)
        self.norm = nn.LayerNorm(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear, LayerNorm and Conv2d sub-modules (used via ``self.apply``)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Normal init scaled by the per-group fan-out of the kernel.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Project an image-like tensor to tokens.

        Args:
            x: 4-D tensor ``(B, in_chans, H_in, W_in)``.

        Returns:
            Tuple ``(tokens, H, W)`` where ``tokens`` has shape
            ``(B, H * W, embed_dim)`` and ``H``, ``W`` are the spatial sizes of
            the convolution output.
        """
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
class UNext(nn.Module):
    """Encoder/decoder segmentation network: 3 conv stages + 2 shifted-MLP stages.

    The encoder halves the resolution five times (three max-pools, two
    stride-2 patch embeddings); the decoder upsamples by 2 five times and adds
    the matching encoder feature map after each of the first four upsamplings.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        input_channels: Number of channels of the input image.
        img_size: Nominal input size, forwarded (divided by 4 and 8) to the
            patch embeddings.
        embed_dims: Channel widths; ``embed_dims[0]`` is the width after the
            third conv stage, ``embed_dims[1]`` and ``embed_dims[2]`` are the
            widths of the two tokenized stages. Further entries are unused.
        num_heads, qkv_bias, qk_scale, attn_drop_rate, sr_ratios: Forwarded to
            ``shiftedBlock``.
        drop_rate: Dropout probability inside the blocks.
        drop_path_rate: Upper end of the linearly spaced drop-path schedule.
        norm_layer: Normalisation class applied to token sequences.
        depths: Only ``sum(depths)`` is used (length of the drop-path
            schedule); entries 0 and 1 of that schedule are read.
        deep_supervision, patch_size, in_chans, mlp_ratios, **kwargs: Accepted
            but not used.

    NOTE(review): the skip additions use ``torch.add`` on feature maps from
    matching depths, so the input height/width presumably need to be multiples
    of 32 for the shapes to line up - confirm against callers.
    """

    ## Conv 3 + MLP 2 + shifted MLP
    def __init__(self, num_classes, input_channels=1, deep_supervision=False, img_size=224, patch_size=16, in_chans=3,
                 embed_dims=(32, 64, 128, 512),
                 num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=(1, 1, 1), sr_ratios=(8, 4, 2, 1), **kwargs):
        super().__init__()
        # Stage widths come from embed_dims instead of literal 32/64/128 so a
        # non-default embed_dims yields a consistent network. With the default
        # values every layer has the same shape as before.
        c1, c2, c3 = embed_dims[0], embed_dims[1], embed_dims[2]

        self.encoder1 = nn.Conv2d(input_channels, 8, 3, stride=1, padding=1)
        self.encoder2 = nn.Conv2d(8, 16, 3, stride=1, padding=1)
        self.encoder3 = nn.Conv2d(16, c1, 3, stride=1, padding=1)
        self.ebn1 = nn.BatchNorm2d(8)
        self.ebn2 = nn.BatchNorm2d(16)
        self.ebn3 = nn.BatchNorm2d(c1)
        self.norm3 = norm_layer(c2)
        self.norm4 = norm_layer(c3)
        self.dnorm3 = norm_layer(c2)
        self.dnorm4 = norm_layer(c1)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.block1 = nn.ModuleList([shiftedBlock(
            dim=c2, num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])
        self.block2 = nn.ModuleList([shiftedBlock(
            dim=c3, num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])
        self.dblock1 = nn.ModuleList([shiftedBlock(
            dim=c2, num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])
        self.dblock2 = nn.ModuleList([shiftedBlock(
            dim=c1, num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=c1,
                                              embed_dim=c2)
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=c2,
                                              embed_dim=c3)
        self.decoder1 = nn.Conv2d(c3, c2, 3, stride=1, padding=1)
        self.decoder2 = nn.Conv2d(c2, c1, 3, stride=1, padding=1)
        self.decoder3 = nn.Conv2d(c1, 16, 3, stride=1, padding=1)
        self.decoder4 = nn.Conv2d(16, 8, 3, stride=1, padding=1)
        self.decoder5 = nn.Conv2d(8, 8, 3, stride=1, padding=1)
        self.dbn1 = nn.BatchNorm2d(c2)
        self.dbn2 = nn.BatchNorm2d(c1)
        self.dbn3 = nn.BatchNorm2d(16)
        self.dbn4 = nn.BatchNorm2d(8)
        self.final = nn.Conv2d(8, num_classes, kernel_size=1)
        # Kept as an attribute; forward() returns the raw output of self.final.
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(B, input_channels, H, W)``.

        Returns:
            Output of the final 1x1 convolution, with ``num_classes`` channels
            (no softmax is applied here).
        """
        B = x.shape[0]
        ### Encoder
        ### Conv Stage
        ### Stage 1
        out = F.relu(F.max_pool2d(self.ebn1(self.encoder1(x)), 2, 2))
        t1 = out
        ### Stage 2
        out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)), 2, 2))
        t2 = out
        ### Stage 3
        out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)), 2, 2))
        t3 = out
        ### Tokenized MLP Stage
        ### Stage 4
        out, H, W = self.patch_embed3(out)
        for blk in self.block1:
            out = blk(out, H, W)
        out = self.norm3(out)
        # Tokens (B, H*W, C) -> feature map (B, C, H, W).
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t4 = out
        ### Bottleneck
        out, H, W = self.patch_embed4(out)
        for blk in self.block2:
            out = blk(out, H, W)
        out = self.norm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        ### Decoder
        ### Stage 4
        out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t4)
        _, _, H, W = out.shape
        # Feature map (B, C, H, W) -> tokens (B, H*W, C) for the decoder block.
        out = out.flatten(2).transpose(1, 2)
        for blk in self.dblock1:
            out = blk(out, H, W)
        ### Stage 3
        out = self.dnorm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t3)
        _, _, H, W = out.shape
        out = out.flatten(2).transpose(1, 2)
        for blk in self.dblock2:
            out = blk(out, H, W)
        out = self.dnorm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        ### Remaining conv decoder stages
        out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t2)
        out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t1)
        out = F.relu(F.interpolate(self.decoder5(out), scale_factor=(2, 2), mode='bilinear'))
        return self.final(out)
# Module-level network instance with one output channel, built at import time.
model = UNext(num_classes=1)
# # model = Unet()
# dummy_input = torch.rand(1, 1, 224, 224)
# with SummaryWriter(comment='Unet') as w:
# w.add_graph(model, (dummy_input, ))
# print(dummy_input.shape)
# tensorboard --logdir=D:\opencv-python\条纹投影\Unet_test\runs\Apr05_11-53-43_LAPTOP-85DL9FARUnet
# ch = 1
# h = 224
# w = 224
# summary(model, input_size=(ch, h, w), device='cpu')