-
Notifications
You must be signed in to change notification settings - Fork 18
/
aan.py
175 lines (125 loc) · 5.04 KB
/
aan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import common
def make_model(args, parent=False):
    """Factory entry point: build an ``AAN`` network from ``args``.

    ``parent`` is accepted for signature compatibility with other model
    factories and is not used here.
    """
    model = AAN(args)
    return model
def make_layer(block, n_layers):
    """Stack ``n_layers`` freshly constructed modules into an ``nn.Sequential``.

    Args:
        block: zero-argument callable returning an ``nn.Module``. It is called
            once per layer, so every layer is a separate instance with its own
            parameters.
        n_layers: number of layers to create.

    Returns:
        ``nn.Sequential`` holding the layers in construction order.
    """
    return nn.Sequential(*[block() for _ in range(n_layers)])
class PA(nn.Module):
    """Pixel attention: gate the input element-wise with a learned sigmoid map.

    A 1x1 convolution (``nf`` -> ``nf`` channels) followed by a sigmoid yields
    weights in (0, 1) with the same shape as the input; the output is the
    input multiplied by those weights.
    """

    def __init__(self, nf):
        super(PA, self).__init__()
        self.conv = nn.Conv2d(nf, nf, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.conv(x))
        return x * gate
# Attention Branch
class AttentionBranch(nn.Module):
    """Attention branch used inside an AAB block.

    An attention map is computed from the input as
    ``sigmoid(k2(lrelu(k1(x))))``. It multiplies a separate 3x3 projection
    ``k3(x)``, and the product is refined by a final 3x3 conv ``k4``.

    Args:
        nf: number of input/output channels (all convs are ``nf`` -> ``nf``).
        k_size: kernel size of ``k1``, ``k3`` and ``k4``. Padding is
            ``(k_size - 1) // 2``, which preserves the spatial size for odd
            kernel sizes.
    """

    def __init__(self, nf, k_size=3):
        super(AttentionBranch, self).__init__()
        pad = (k_size - 1) // 2
        # Module creation order is kept so state_dict keys and RNG-driven
        # initialisation match the original layout.
        self.k1 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=pad, bias=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.k2 = nn.Conv2d(nf, nf, 1)  # 1x1 conv producing the attention logits
        self.sigmoid = nn.Sigmoid()
        self.k3 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=pad, bias=False)
        self.k4 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=pad, bias=False)

    def forward(self, x):
        # The in-place LeakyReLU acts on k1's fresh output, never on ``x``.
        attn_map = self.sigmoid(self.k2(self.lrelu(self.k1(x))))
        gated = self.k3(x) * attn_map
        return self.k4(gated)
class AAB(nn.Module):
    """AAB block: a residual block that softly mixes an attention branch and a
    non-attention branch.

    Forward pass:
      1. A 1x1 conv plus LeakyReLU produces features ``f``.
      2. The Attention Dropout Module (ADM) maps the global average of ``f``
         through Linear -> ReLU -> Linear to ``K`` logits per sample.
      3. Those logits, divided by the temperature ``t``, go through a softmax.
      4. Softmax entries 0 and 1 weight the attention branch and the
         non-attention branch, respectively.
      5. The weighted sum goes through LeakyReLU and a 1x1 conv, and the
         block input is added back.

    Args:
        nf: number of channels (input and output).
        reduction: hidden-width divisor of the ADM (hidden size ``nf // reduction``).
        K: number of ADM output logits. NOTE(review): ``forward`` reads only
            the first two softmax entries, so anything other than ``K=2``
            looks unsupported. ``K < 2`` fails at indexing, and ``K > 2``
            yields weights that do not sum to 1.
        t: softmax temperature applied to the ADM logits.
    """

    def __init__(self, nf, reduction=4, K=2, t=30):
        super(AAB, self).__init__()
        self.t = t
        self.K = K
        self.conv_first = nn.Conv2d(nf, nf, kernel_size=1, bias=False)
        self.conv_last = nn.Conv2d(nf, nf, kernel_size=1, bias=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Attention Dropout Module: pooled channel vector -> K branch logits.
        hidden = nf // reduction
        self.ADM = nn.Sequential(
            nn.Linear(nf, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, self.K, bias=False),
        )

        # attention branch
        self.attention = AttentionBranch(nf)

        # non-attention branch: 3x3 conv for A2N.
        # The A2N-M variant (recommended, fewer parameters) uses a 1x1 conv here:
        #   nn.Conv2d(nf, nf, kernel_size=1, bias=False)
        self.non_attention = nn.Conv2d(nf, nf, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        # The 4-way unpack also rejects non-4-D input, as before.
        batch, channels, _, _ = x.shape
        identity = x

        feat = self.lrelu(self.conv_first(x))

        # Per-sample branch weights from the globally pooled features.
        logits = self.ADM(self.avg_pool(feat).view(batch, channels))
        weights = F.softmax(logits / self.t, dim=1)
        w_att = weights[:, 0].view(batch, 1, 1, 1)
        w_non = weights[:, 1].view(batch, 1, 1, 1)

        mixed = self.attention(feat) * w_att + self.non_attention(feat) * w_non

        out = self.conv_last(self.lrelu(mixed))
        out += identity
        return out
class AAN(nn.Module):
    """Super-resolution network built from a trunk of AAB blocks.

    Pipeline, as implemented in ``forward``:
      1. A 3x3 conv lifts the 3-channel input to ``nf`` feature channels.
      2. ``nb`` AAB blocks followed by a 3x3 conv form the trunk. The trunk
         output is added back to the lifted features.
      3. Upsampling stages each do a nearest-neighbour resize, a conv, pixel
         attention (PA) with LeakyReLU, and a conv with LeakyReLU. There is
         one stage by ``scale`` for x2/x3, and two x2 stages for x4.
      4. A final 3x3 conv maps to 3 output channels.
      5. The bilinearly upsampled input image is added to the result.

    Args:
        args: options object. Only ``args.scale[0]`` (the upscaling factor)
            is read.

    Raises:
        ValueError: if the scale factor is not 2, 3 or 4. Those are the only
            factors ``forward`` has an upsampling path for.
    """

    # Upscaling factors that ``forward`` has an upsampling path for.
    _SUPPORTED_SCALES = (2, 3, 4)

    def __init__(self, args):
        super(AAN, self).__init__()
        in_nc = 3   # input image channels
        out_nc = 3  # output image channels
        nf = 40     # feature channels in the AAB trunk
        unf = 24    # feature channels in the upsampling stages
        nb = 16     # number of AAB blocks in the trunk
        scale = args.scale[0]
        # Fail fast. Previously an unsupported scale built successfully, then
        # forward skipped every upsampling branch and conv_last (expecting
        # ``unf`` channels) received ``nf``-channel features, raising an
        # obscure channel-mismatch error.
        if scale not in self._SUPPORTED_SCALES:
            raise ValueError(
                f"AAN only supports scale factors {self._SUPPORTED_SCALES}, got {scale!r}"
            )

        # AAB
        AAB_block_f = functools.partial(AAB, nf=nf)
        self.scale = scale

        ### first convolution
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)

        ### main blocks
        self.AAB_trunk = make_layer(AAB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        #### upsampling
        self.upconv1 = nn.Conv2d(nf, unf, 3, 1, 1, bias=True)
        self.att1 = PA(unf)
        self.HRconv1 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)

        if self.scale == 4:
            # x4 needs a second x2 stage.
            self.upconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)
            self.att2 = PA(unf)
            self.HRconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)

        self.conv_last = nn.Conv2d(unf, out_nc, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def _up_stage(self, fea, factor, upconv, att, hrconv):
        """Run one upsampling stage: nearest resize by ``factor``, then conv,
        PA + LeakyReLU, and conv + LeakyReLU."""
        fea = upconv(F.interpolate(fea, scale_factor=factor, mode='nearest'))
        fea = self.lrelu(att(fea))
        return self.lrelu(hrconv(fea))

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.AAB_trunk(fea))
        fea = fea + trunk

        if self.scale == 4:
            fea = self._up_stage(fea, 2, self.upconv1, self.att1, self.HRconv1)
            fea = self._up_stage(fea, 2, self.upconv2, self.att2, self.HRconv2)
        else:
            # x2 or x3: other factors are rejected in __init__.
            fea = self._up_stage(fea, self.scale, self.upconv1, self.att1, self.HRconv1)

        out = self.conv_last(fea)
        # ILR: the input upsampled bilinearly to the output size
        # (presumably "interpolated LR").
        ILR = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
        out = out + ILR
        return out