stochastic_layers.py

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Stochastic forward functions for ResNet Bottleneck modules."""

from typing import Optional

import torch
import torch.nn as nn
from torch.fx import GraphModule
from torchvision.models.resnet import Bottleneck

__all__ = ['make_resnet_bottleneck_stochastic', 'BlockStochasticModule']


def block_stochastic_forward(self, x):
    """ResNet Bottleneck forward function where the layers are randomly
    skipped with probability ``drop_rate`` during training.
    """
    identity = x

    # Always run the block at eval time; during training, run it with
    # probability ``1 - drop_rate``.
    sample = (not self.training) or bool(torch.bernoulli(1 - self.drop_rate))

    if sample:
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if not self.training:
            # Scale by the keep probability at eval time so activations match
            # their expected value under training-time sampling.
            out = out * (1 - self.drop_rate)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
    else:
        # Block is skipped: pass the (possibly downsampled) identity through.
        if self.downsample is not None:
            out = self.relu(self.downsample(identity))
        else:
            out = identity
    return out
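

# A minimal sketch (not part of the library API): the keep/skip decision above is a
# single Bernoulli draw per forward pass, so during training the block runs on
# roughly ``1 - drop_rate`` of steps. This is why eval-time outputs are scaled by
# the same factor: it matches the expected value of the block's training-time output.
def _example_keep_probability(drop_rate: float = 0.2, trials: int = 10000) -> float:
    draws = torch.bernoulli(torch.full((trials,), 1 - drop_rate))
    return draws.mean().item()  # close to 1 - drop_rate, the keep probability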


def _sample_drop(x: torch.Tensor, sample_drop_rate: float, is_training: bool):
    """Randomly drops samples from the input batch according to the ``sample_drop_rate``.

    This is implemented by setting the samples to be dropped to zeros.
    """
    keep_probability = (1 - sample_drop_rate)
    if not is_training:
        return x * keep_probability

    # One mask entry per sample, shaped to broadcast across the remaining dims.
    rand_dim = [x.shape[0]] + [1] * len(x.shape[1:])
    sample_mask = keep_probability + torch.rand(rand_dim, dtype=x.dtype, device=x.device)
    sample_mask.floor_()  # binarize: 1 with probability ``keep_probability``, else 0
    x *= sample_mask
    return x
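

# A minimal sketch (not part of the library API): empirically checks that
# ``_sample_drop`` zeroes roughly ``sample_drop_rate`` of the samples in a batch
# during training. The mask trick works because ``keep_probability + U[0, 1)``
# floors to 1 with probability ``keep_probability`` and to 0 otherwise.
def _example_sample_drop(sample_drop_rate: float = 0.3) -> float:
    x = torch.ones(10000, 3, 4, 4)
    out = _sample_drop(x, sample_drop_rate, is_training=True)
    # Fraction of samples kept; should be close to ``1 - sample_drop_rate``.
    return (out.flatten(1).abs().sum(dim=1) > 0).float().mean().item()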


def sample_stochastic_forward(self, x):
    """ResNet Bottleneck forward function where samples are randomly
    dropped with probability ``drop_rate`` during training.
    """
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)
    out = self.relu(out)

    out = self.conv3(out)
    out = self.bn3(out)

    if self.downsample is not None:
        identity = self.downsample(x)

    if self.drop_rate:
        out = _sample_drop(out, self.drop_rate, self.training)
    out += identity

    return self.relu(out)
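

# A minimal sketch (not part of the library API): manually converts a single
# torchvision Bottleneck to the sample-wise variant, mirroring the binding that
# ``make_resnet_bottleneck_stochastic`` performs below. Each sample in the batch is
# independently dropped from the main path; the identity path always survives.
def _example_sample_stochastic_bottleneck():
    block = Bottleneck(inplanes=256, planes=64)  # 256 channels in and out
    block.drop_rate = torch.tensor(0.3)
    block.forward = sample_stochastic_forward.__get__(block)  # bind as a method
    block.train()
    out = block(torch.randn(8, 256, 14, 14))
    return out.shape  # torch.Size([8, 256, 14, 14])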


def make_resnet_bottleneck_stochastic(module: Bottleneck, module_index: int, module_count: int, drop_rate: float,
                                      drop_distribution: str, stochastic_method: str):
    """Model surgery policy that dictates how to convert a ResNet Bottleneck layer into a stochastic version."""
    if drop_distribution == 'linear':
        # Scale the drop rate linearly with depth so later blocks drop more often.
        drop_rate = ((module_index + 1) / module_count) * drop_rate
    module.drop_rate = torch.tensor(drop_rate)

    stochastic_func = block_stochastic_forward if stochastic_method == 'block' else sample_stochastic_forward
    module.forward = stochastic_func.__get__(module)  # Bind new forward function to ResNet Bottleneck Module

    return module
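

# A minimal sketch (not part of the library API): applies the surgery policy to
# every Bottleneck in a torchvision ResNet-50. In Composer this function is
# normally driven by the StochasticDepth algorithm rather than called directly,
# so treat this loop as illustrative only.
def _example_make_resnet50_stochastic():
    from torchvision.models import resnet50

    model = resnet50()
    bottlenecks = [m for m in model.modules() if isinstance(m, Bottleneck)]
    for index, block in enumerate(bottlenecks):
        # With drop_distribution='linear', deeper blocks get larger drop rates.
        make_resnet_bottleneck_stochastic(block,
                                          module_index=index,
                                          module_count=len(bottlenecks),
                                          drop_rate=0.2,
                                          drop_distribution='linear',
                                          stochastic_method='block')
    return model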


class BlockStochasticModule(nn.Module):
    """A convenience class that stochastically executes the provided main path of a residual block.

    Args:
        main (GraphModule): Operators in the main (non-residual) path of a residual block.
        residual (GraphModule | None): Operators, if any, in the residual path of a residual block.
        drop_rate (float): The base probability of dropping this layer. Must be between 0.0 (inclusive)
            and 1.0 (inclusive).

    Returns:
        BlockStochasticModule: An instance of :class:`.BlockStochasticModule`.
    """

    def __init__(self, main: GraphModule, residual: Optional[GraphModule] = None, drop_rate: float = 0.2):
        super().__init__()
        self.drop_rate = torch.tensor(drop_rate)
        self.main = main
        self.residual = residual

    def forward(self, x):
        sample = (not self.training) or bool(torch.bernoulli(1 - self.drop_rate))

        # The residual side may or may not have any operations; default to the
        # identity if it has none.
        residual_result = x
        if self.residual:
            residual_result = self.residual(x)

        if sample:
            main_result = self.main(x)
            if not self.training:
                # Scale by the keep probability at eval time to match the
                # training-time expectation of the main path's output.
                main_result = main_result * (1 - self.drop_rate)
            residual_result = torch.add(main_result, residual_result)
        return residual_result
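

# A minimal sketch (not part of the library API): wraps a toy main path in a
# ``BlockStochasticModule``. ``torch.fx.symbolic_trace`` is used here so both paths
# would be ``GraphModule`` instances, matching the documented argument types; in
# practice these graphs would come from splitting a traced residual block.
def _example_block_stochastic_module():
    from torch.fx import symbolic_trace

    main = symbolic_trace(nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()))
    block = BlockStochasticModule(main, residual=None, drop_rate=0.5)
    block.train()
    out = block(torch.randn(2, 8, 4, 4))  # the main path runs with probability 0.5
    return out.shape  # torch.Size([2, 8, 4, 4])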