-
Notifications
You must be signed in to change notification settings - Fork 2
/
modules.py
47 lines (42 loc) · 2.24 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""
Module definitions for tensor reorder-based compressed models.
Written by Matej Ulicny.
"""
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch_dct as dct
class CompConv(nn.Module):
    """2-D convolution whose kernel is decoded from truncated DCT coefficients.

    The flattened kernel (``no * ni * kernel_size**2`` values) is split into
    ``g`` rows of ``n = (no * ni * kernel_size**2) // g`` values. Each row
    stores only its first ``int(n / r)`` DCT coefficients; the remaining ones
    are implicitly zero. On every forward pass the rows are zero-padded back
    to length ``n``, inverse-DCT'd, their columns are reordered by ``index``,
    and the result is reshaped into a ``(no, ni, kernel_size, kernel_size)``
    kernel.

    Args:
        ni: Input channels used in the kernel shape. NOTE(review): the kernel
            is reshaped to ``(no, ni, k, k)`` regardless of ``groups``, so with
            ``groups > 1`` ``F.conv2d`` expects ``ni * groups`` input channels.
        no: Output channels.
        kernel_size: Side length of the square kernel (an int).
        stride, padding, groups: Forwarded unchanged to ``F.conv2d``.
        bias: If True, add a learnable bias of size ``no`` initialised to zero.
        g: Number of rows the flattened kernel is split into; must divide
            ``no * ni * kernel_size**2``.
        r: Compression ratio; each row keeps ``int(n / r)`` coefficients.
        progressive: If True, the effective ratio becomes
            ``r * sqrt(no * ni * kernel_size**2) / ref_size + 1``.
        ref_size: Reference size used by the progressive ratio.

    Raises:
        ValueError: If ``g`` does not divide ``no * ni * kernel_size**2``.
    """

    def __init__(self, ni, no, kernel_size, stride=1, padding=0, bias=True,
                 groups=1, g=4, r=2, progressive=False, ref_size=64):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.ni = ni
        self.no = no
        self.groups = groups

        numel = no * ni * kernel_size ** 2
        # Fail fast: otherwise the reshape in forward() fails with an opaque error.
        if numel % g != 0:
            raise ValueError(
                f"g={g} must divide no * ni * kernel_size**2 = {numel}")
        self.r = r * (numel ** 0.5 / ref_size) + 1 if progressive else r

        row_len = numel // g
        # Column permutation applied after the inverse DCT. It defaults to the
        # identity: the previous torch.IntTensor(row_len) left it uninitialised
        # (arbitrary memory) until a checkpoint overwrote it. It stays an int32,
        # non-trainable Parameter so state_dict keys and parameters() are unchanged.
        self.index = nn.Parameter(
            torch.arange(row_len, dtype=torch.int32), requires_grad=False)
        kept = int(row_len / self.r)
        self.weight = nn.Parameter(nn.init.kaiming_normal_(
            torch.empty(g, kept), mode='fan_out', nonlinearity='relu'))
        self.bias = nn.Parameter(torch.zeros(no)) if bias else None

    def forward(self, x):
        """Decode the kernel from its DCT coefficients and convolve ``x`` with it."""
        # Zero-pad each row's kept coefficients on the right to full length, then
        # invert the orthonormal DCT (torch_dct transforms the last dimension).
        pad = self.index.size(0) - self.weight.size(1)
        filt = dct.idct(F.pad(self.weight, (0, pad)), norm='ortho')
        # Reorder the columns of every row by the stored index.
        filt = filt[:, self.index.long()]
        filt = torch.reshape(
            filt, (self.no, self.ni, self.kernel_size, self.kernel_size))
        return F.conv2d(x, filt, bias=self.bias, stride=self.stride,
                        padding=self.padding, groups=self.groups)
class CompLinear(nn.Module):
    """Linear layer whose weight matrix is decoded from truncated DCT coefficients.

    The flattened ``(no, ni)`` weight is split into ``g`` rows of
    ``n = (no * ni) // g`` values. Each row stores only its first
    ``int(n / r)`` DCT coefficients; the remaining ones are implicitly zero.
    On every forward pass the rows are zero-padded back to length ``n``,
    inverse-DCT'd, their columns are reordered by ``index``, and the result is
    reshaped into a ``(no, ni)`` matrix.

    Args:
        ni: Input features.
        no: Output features.
        bias: If True, add a learnable bias of size ``no`` initialised to zero.
        g: Number of rows the flattened weight is split into; must divide
            ``no * ni``.
        r: Compression ratio; each row keeps ``int(n / r)`` coefficients.
        progressive: If True, the effective ratio becomes
            ``r * sqrt(no * ni) / ref_size + 1``.
        ref_size: Reference size used by the progressive ratio.

    Raises:
        ValueError: If ``g`` does not divide ``no * ni``.
    """

    def __init__(self, ni, no, bias=True, g=4, r=2, progressive=False,
                 ref_size=64):
        super().__init__()
        self.ni = ni
        self.no = no

        numel = no * ni
        # Fail fast: otherwise the reshape in forward() fails with an opaque error.
        if numel % g != 0:
            raise ValueError(f"g={g} must divide no * ni = {numel}")
        self.r = r * (numel ** 0.5 / ref_size) + 1 if progressive else r

        row_len = numel // g
        # Column permutation applied after the inverse DCT. It defaults to the
        # identity: the previous torch.IntTensor(row_len) left it uninitialised
        # (arbitrary memory) until a checkpoint overwrote it. It stays an int32,
        # non-trainable Parameter so state_dict keys and parameters() are unchanged.
        self.index = nn.Parameter(
            torch.arange(row_len, dtype=torch.int32), requires_grad=False)
        kept = int(row_len / self.r)
        self.weight = nn.Parameter(nn.init.kaiming_normal_(
            torch.empty(g, kept), mode='fan_out', nonlinearity='relu'))
        self.bias = nn.Parameter(torch.zeros(no)) if bias else None

    def forward(self, x):
        """Decode the weight matrix from its DCT coefficients and apply it to ``x``."""
        # Zero-pad each row's kept coefficients on the right to full length, then
        # invert the orthonormal DCT (torch_dct transforms the last dimension).
        pad = self.index.size(0) - self.weight.size(1)
        filt = dct.idct(F.pad(self.weight, (0, pad)), norm='ortho')
        # Reorder the columns of every row by the stored index.
        filt = filt[:, self.index.long()]
        filt = torch.reshape(filt, (self.no, self.ni))
        return F.linear(x, filt, bias=self.bias)