# Curriculum Mixup
# Performs Manifold Mixup and Output Mixup with an increasing alpha to get a gradual increase in difficulty.
# source: https://github.com/nestordemeure/ManifoldMixupV2/blob/master/dynamic_mixup.py
from torch.distributions.beta import Beta
from fastai.basics import *
from fastai.callback.schedule import *
from manifold_mixup import *
__all__ = ['DynamicManifoldMixup', 'DynamicOutputMixup']
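
# To make the "gradual increase in difficulty" concrete: with the default SchedCos
# annealer, alpha follows a cosine curve from alpha_min up to alpha_max over training.
# Illustrative values (a sketch, assuming fastai's SchedCos(start, end)(pct) call
# convention and the defaults alpha_min=0., alpha_max=0.6):
#   SchedCos(0., 0.6)(0.0) -> 0.0   (start of training: mixup essentially disabled)
#   SchedCos(0., 0.6)(0.5) -> 0.3   (halfway through training)
#   SchedCos(0., 0.6)(1.0) -> 0.6   (end of training: full-strength mixup)
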
class DynamicManifoldMixup(ManifoldMixup):
"Implements a scheduling policy on top of manifold mixup, letting you increase the difficulty progressively."
def __init__(self, alpha_min=0., alpha_max:float=0.6, scheduler=SchedCos, **kwargs):
"""
`alpha_min` is the minimum value of the parameter for the beta law, we recommand keeping it to 0.
`alpha_max` is the parameter for the beta law, we recommend a value between 0. and 1.
`scheduler` is the scheduling function (such as SchedLin, SchedCos, SchedNo, SchedExp or SchedPoly)
See the [Annealing](http://dev.fast.ai/callback.schedule#Annealing) section of fastai2's documentation for a list of available schedulers, ways to combine them and provide your own.
Note that you can pass a raw scheduler (`SchedCos`), that will go from 0 to alpha_max, but also a partially applied scheduler to have full control over the minimum and maximum values (`SchedCos(0.,0.8)`)
"""
if 'alpha' in kwargs:
# insures that the user is using alpha_max
raise Exception('`alpha` parameter detected, please use `alpha_max` (and optionally `alpha_min`) when calling a curriculum based mixup callback.')
print("Scheduler detected, growing alpha from", alpha_min, "to", alpha_max)
super().__init__(alpha=0., **kwargs)
self.alpha_min = alpha_min
self.alpha_max = alpha_max
self.scheduler = scheduler
def before_batch(self):
"Updates alpha as a function of the training percentage."
# we do the partial application here (and not in the constructor) to avoid a pickle ambiguity error on learn.export
# due to the fact that the partially applied function as the same name as the original function
alpha = self.scheduler(self.alpha_min, self.alpha_max)(self.pct_train)
self.distrib = Beta(tensor(alpha), tensor(alpha))
super().before_batch()
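# A minimal usage sketch (assuming a standard fastai `Learner`/`cnn_learner` setup and
# that `manifold_mixup.py` is importable); the callback name and keyword arguments come
# from this file, everything else (data, model, fit schedule) is illustrative:
#
#   learn = cnn_learner(dls, resnet18, cbs=DynamicManifoldMixup(alpha_max=0.6, scheduler=SchedCos))
#   learn.fit_one_cycle(10)
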
class DynamicOutputMixup(OutputMixup):
"Implements a scheduling policy on top of output mixup, letting you increase the difficulty progressively."
def __init__(self, alpha_min:float=0.0, alpha_max:float=0.6, scheduler=SchedCos, **kwargs):
"""
`alpha_min` is the minimum value of the parameter for the beta law, we recommand keeping it to 0.
`alpha_max` is the parameter for the beta law, we recommend a value between 0. and 1.
`scheduler` is the scheduling function (such as SchedLin, SchedCos, SchedNo, SchedExp or SchedPoly)
See the [Annealing](http://dev.fast.ai/callback.schedule#Annealing) section of fastai2's documentation for a list of available schedulers, ways to combine them and provide your own.
Note that you can pass a raw scheduler (`SchedCos`), that will go from 0 to alpha_max, but also a partially applied scheduler to have full control over the minimum and maximum values (`SchedCos(0.,0.8)`)
"""
if 'alpha' in kwargs:
# insures that the user is using alpha_max
raise Exception('`alpha` parameter detected, please use `alpha_max` (and optionally `alpha_min`) when calling a curriculum based mixup callback.')
print("Scheduler detected, growing alpha from", alpha_min, "to", alpha_max)
super().__init__(alpha=0., **kwargs)
self.alpha_min = alpha_min
self.alpha_max = alpha_max
self.scheduler = scheduler
def before_batch(self):
"Updates alpha as a function of the training percentage."
# we do the partial application here (and not in the constructor) to avoid a pickle ambiguity error on learn.export
# due to the fact that the partially applied function as the same name as the original function
alpha = self.scheduler(self.alpha_min, self.alpha_max)(self.pct_train)
self.distrib = Beta(tensor(alpha), tensor(alpha))
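
# DynamicOutputMixup is attached to a Learner the same way; a minimal sketch
# (the Learner construction is illustrative, only the callback and its keyword
# arguments come from this file):
#
#   learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(),
#                   cbs=DynamicOutputMixup(alpha_min=0., alpha_max=0.6, scheduler=SchedCos))
#   learn.fit_one_cycle(10)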