# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy.random as random
import torch
import torch.nn as nn
import torch.nn.functional as F

from MinkowskiEngine import SparseTensor


class MinkowskiGRN(nn.Module):
    """ GRN layer for sparse tensors.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, dim))
        self.beta = nn.Parameter(torch.zeros(1, dim))

    def forward(self, x):
        cm = x.coordinate_manager
        in_key = x.coordinate_map_key

        # Gx: per-channel global L2 response over all points in the batch.
        Gx = torch.norm(x.F, p=2, dim=0, keepdim=True)
        # Nx: each channel's response relative to the mean across channels.
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return SparseTensor(
            self.gamma * (x.F * Nx) + self.beta + x.F,
            coordinate_map_key=in_key,
            coordinate_manager=cm)
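

# Hedged usage sketch (illustrative, not part of the original file): builds a
# tiny SparseTensor and runs it through MinkowskiGRN. Assumes MinkowskiEngine
# is installed; coordinates are (N, D+1) ints with the batch index in the
# first column. At init gamma = beta = 0, so the layer is the identity.
def _example_minkowski_grn():
    coords = torch.IntTensor([[0, 0, 0], [0, 1, 0], [0, 0, 1]])  # batch, x, y
    feats = torch.randn(3, 8)  # 3 points, 8 channels
    x = SparseTensor(feats, coordinates=coords)
    grn = MinkowskiGRN(dim=8)
    return grn(x)  # SparseTensor with the same coordinates, modulated features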


class MinkowskiDropPath(nn.Module):
    """ Drop Path (stochastic depth) for sparse tensors.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(MinkowskiDropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        cm = x.coordinate_manager
        in_key = x.coordinate_map_key
        keep_prob = 1 - self.drop_prob
        # Per-sample binary mask: each batch element of the sparse tensor is
        # kept (all ones) or dropped (all zeros) as a whole.
        mask = torch.cat([
            torch.ones(len(_)) if random.uniform(0, 1) > self.drop_prob
            else torch.zeros(len(_)) for _ in x.decomposed_coordinates
        ]).view(-1, 1).to(x.device)
        if keep_prob > 0.0 and self.scale_by_keep:
            # Rescale survivors so the expected activation is unchanged.
            mask.div_(keep_prob)
        return SparseTensor(
            x.F * mask,
            coordinate_map_key=in_key,
            coordinate_manager=cm)
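

# Hedged usage sketch (illustrative, not part of the original file): drop path
# only acts in training mode; each batch item is zeroed with probability
# drop_prob and the survivors are rescaled by 1 / (1 - drop_prob).
def _example_minkowski_drop_path():
    coords = torch.IntTensor([[0, 0, 0], [1, 0, 0]])  # two batch items
    feats = torch.randn(2, 8)
    x = SparseTensor(feats, coordinates=coords)
    drop = MinkowskiDropPath(drop_prob=0.5)
    drop.train()  # stochastic depth is a no-op in eval mode
    return drop(x)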


class MinkowskiLayerNorm(nn.Module):
    """ Channel-wise layer normalization for sparse tensors.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-6,
    ):
        super(MinkowskiLayerNorm, self).__init__()
        self.ln = nn.LayerNorm(normalized_shape, eps=eps)

    def forward(self, input):
        # Normalize the dense feature matrix, then rewrap it as a sparse
        # tensor sharing the input's coordinates.
        output = self.ln(input.F)
        return SparseTensor(
            output,
            coordinate_map_key=input.coordinate_map_key,
            coordinate_manager=input.coordinate_manager)
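

# Hedged usage sketch (illustrative, not part of the original file):
# MinkowskiLayerNorm applies a dense nn.LayerNorm over the feature matrix,
# so features are normalized per point across channels.
def _example_minkowski_layer_norm():
    coords = torch.IntTensor([[0, 0, 0], [0, 1, 1]])
    feats = torch.randn(2, 8)
    x = SparseTensor(feats, coordinates=coords)
    norm = MinkowskiLayerNorm(8)
    return norm(x)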


class LayerNorm(nn.Module):
    """ LayerNorm that supports two data formats: channels_last (default) or
    channels_first, referring to the ordering of the dimensions in the inputs.
    channels_last corresponds to inputs with shape
    (batch_size, height, width, channels), while channels_first corresponds to
    inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(
                f"Unsupported data format: {self.data_format}")
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Manual layer norm over the channel dimension (dim=1).
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
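

# Hedged usage sketch (illustrative, not part of the original file): the same
# module normalizes NCHW ("channels_first") or NHWC ("channels_last") tensors
# over the channel dimension.
def _example_layer_norm():
    ln_first = LayerNorm(64, data_format="channels_first")
    ln_last = LayerNorm(64, data_format="channels_last")
    y1 = ln_first(torch.randn(2, 64, 7, 7))  # (B, C, H, W)
    y2 = ln_last(torch.randn(2, 7, 7, 64))   # (B, H, W, C)
    return y1, y2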


class GRN(nn.Module):
    """ GRN (Global Response Normalization) layer for channels-last inputs
    of shape (batch_size, height, width, channels).
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Gx: per-channel global L2 response aggregated over space (H, W).
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Nx: each channel's response relative to the mean across channels.
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x
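

# Hedged usage sketch (illustrative, not part of the original file): GRN
# expects channels-last input (B, H, W, C). Since gamma = beta = 0 at init,
# the layer starts as the identity and learns to recalibrate channels.
def _example_grn():
    grn = GRN(dim=64)
    x = torch.randn(2, 7, 7, 64)  # (B, H, W, C)
    return grn(x)  # same shape as the input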