# patchup.py
import os
import sys
from enum import Enum

# Make the parent of this file's directory importable, for project-level imports.
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class PatchUpMode(Enum):
    SOFT = 'soft'
    HARD = 'hard'


class PatchUp(nn.Module):
    """
    PatchUp module.

    Applies either Soft PatchUp or Hard PatchUp to the feature maps produced by a
    convolutional module or convolutional residual block.
    """
    def __init__(self, block_size=7, gamma=0.9, patchup_type=PatchUpMode.SOFT):
        """
        PatchUp constructor.

        Args:
            block_size: An odd integer that defines the size of the square blocks in
                the mask that selects the contiguous features to be altered.
            gamma: A float in [0, 1]. The gamma in PatchUp decides the probability of
                altering a feature.
            patchup_type: A PatchUpMode enum value that selects either Soft PatchUp
                or Hard PatchUp.
        """
        super(PatchUp, self).__init__()
        self.patchup_type = patchup_type
        self.block_size = block_size
        self.gamma = gamma
        self.gamma_adj = None
        self.kernel_size = (block_size, block_size)
        self.stride = (1, 1)
        # 'Same' padding, so the max-pooled mask keeps the spatial size of the input.
        self.padding = (block_size // 2, block_size // 2)
        self.computed_lam = None

    def adjust_gamma(self, x):
        """
        The gamma in PatchUp decides the probability of altering a feature.
        Since PatchUp alters contiguous blocks in the feature maps rather than
        independent features, this function adjusts that probability accordingly.

        Args:
            x: Feature maps for a mini-batch produced by a convolutional module or layer.

        Returns:
            The adjusted gamma, a float in [0, 1].
        """
        return self.gamma * x.shape[-1] ** 2 / \
            (self.block_size ** 2 * (x.shape[-1] - self.block_size + 1) ** 2)
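
    # For example (hypothetical numbers, not from the original file): with
    # gamma = 0.9, block_size = 7, and 8x8 feature maps, adjust_gamma gives
    # 0.9 * 8**2 / (7**2 * (8 - 7 + 1)**2) = 57.6 / 196 ≈ 0.294.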

    def forward(self, x, targets=None, lam=None, patchup_type=PatchUpMode.SOFT):
        """
        Forward pass for the PatchUp module.

        Args:
            x: Feature maps for a mini-batch produced by a convolutional module or layer.
            targets: Targets of the samples in the mini-batch.
            lam: Use this to apply PatchUp with a fixed lambda instead of sampling one
                from the Beta distribution.
            patchup_type: Either Hard PatchUp or Soft PatchUp.

        Returns:
            target_a: Targets of the first samples in the randomly selected sample
                pairs in the mini-batch.
            target_b: Targets of the second samples in the randomly selected sample
                pairs in the mini-batch.
            target_reweighted: Targets re-weighted for the interpolated samples after
                altering patches with either Hard PatchUp or Soft PatchUp.
            x: The interpolated hidden representations.
            total_unchanged_portion: The portion of the hidden representation that
                remains unchanged after applying PatchUp.
        """
        self.patchup_type = patchup_type
        if self.training is None:
            raise RuntimeError("The model is set to neither training nor evaluation mode.")
        if not self.training:
            # At inference time (evaluation or test), PatchUp is not applied.
            return x, targets
        if lam is None:
            # If no fixed lambda is given, sample one from the Beta distribution with
            # alpha = beta = 2. Lambda is a float in the range [0, 1].
            lam = np.random.beta(2.0, 2.0)
        if self.gamma_adj is None:
            self.gamma_adj = self.adjust_gamma(x)
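        # Note: gamma_adj is computed once, from the first mini-batch seen, and then
        # cached; this implicitly assumes every feature map passed through this
        # module has the same spatial size.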
        p = torch.ones_like(x[0]) * self.gamma_adj
        # For each feature in the feature map, sample from Bernoulli(p). If the sample
        # for feature f_{ij} is 0, then Mask_{ij} = 1. If the sample for f_{ij} is 1,
        # then the entire square region of width and height block_size centered at
        # Mask_{ij} is set to 0.
        m_i_j = torch.bernoulli(p)
        mask_shape = len(m_i_j.shape)
        # After creating the binary mask for the first sample, use it as the pattern
        # for all samples in the mini-batch: simply expand it across the batch dimension.
        m_i_j = m_i_j.expand(x.size(0), m_i_j.size(0), m_i_j.size(1), m_i_j.size(2))
        # The following line marks the contiguous blocks that should be altered by
        # PatchUp, denoted as holes here.
        holes = F.max_pool2d(m_i_j, self.kernel_size, self.stride, self.padding)
        # The resulting binary mask contains 1 for the features that should remain
        # unchanged and 0 for the features that lie in the contiguous blocks selected
        # for interpolation.
        mask = 1 - holes
        unchanged = mask * x
        if mask_shape == 1:
            total_feats = x.size(1)
        else:
            total_feats = x.size(1) * (x.size(2) ** 2)
        total_changed_pixels = holes[0].sum()
        total_changed_portion = total_changed_pixels / total_feats
        total_unchanged_portion = (total_feats - total_changed_pixels) / total_feats
        # The following line gives the indices of the second samples in each pair,
        # permuted randomly.
        indices = np.random.permutation(x.size(0))
        target_shuffled_onehot = targets[indices]
        patches = None
        target_reweighted = None
        target_b = None
        if self.patchup_type == PatchUpMode.SOFT:
            # Apply the Soft PatchUp combining operation to the selected contiguous blocks.
            target_reweighted = total_unchanged_portion * targets + lam * total_changed_portion * targets + \
                target_shuffled_onehot * (1 - lam) * total_changed_portion
            patches = holes * x
            patches = patches * lam + patches[indices] * (1 - lam)
            target_b = lam * targets + (1 - lam) * target_shuffled_onehot
        elif self.patchup_type == PatchUpMode.HARD:
            # Apply the Hard PatchUp combining operation to the selected contiguous blocks.
            target_reweighted = total_unchanged_portion * targets + total_changed_portion * target_shuffled_onehot
            patches = holes * x
            patches = patches[indices]
            target_b = targets[indices]
        x = unchanged + patches
        target_a = targets
        return target_a, target_b, target_reweighted, x, total_unchanged_portion
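

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of applying PatchUp in training mode, assuming 8x8 feature
# maps from an upstream convolutional block and one-hot targets. The shapes,
# channel count, and number of classes below are hypothetical.
if __name__ == '__main__':
    patchup = PatchUp(block_size=7, gamma=0.9)
    patchup.train()  # PatchUp is only applied while the module is in training mode.
    x = torch.randn(4, 16, 8, 8)  # (batch, channels, height, width) feature maps
    targets = F.one_hot(torch.randint(0, 10, (4,)), num_classes=10).float()
    target_a, target_b, target_reweighted, mixed, unchanged_portion = patchup(
        x, targets, patchup_type=PatchUpMode.HARD)
    print(mixed.shape, unchanged_portion.item())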