This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

pytext fp16 optimizer
Summary: Write an optimizer wrapper in PyText that supports mixed precision training without amp.

Differential Revision: D16276949

fbshipit-source-id: aa2b859f98284bc5865993a7d3d208b6a0402092
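For context, a minimal sketch of a training step built on the precision helpers touched in this diff (the trainer-side wiring and names such as compute_loss are illustrative placeholders, not part of the change):

from pytext.utils import precision

def train_step(model, optimizer, batch, compute_loss, max_clip_norm=5.0):
    # model and optimizer are assumed to come from precision.initialize(model, optimizer)
    loss = compute_loss(model(batch))
    precision.backward(optimizer, loss)  # dispatches to FP16Optimizer.backward instead of amp.scale_loss when TEST is set
    precision.clip_grad_norm(model, optimizer, max_clip_norm)
    optimizer.step()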
Yuqing Liu authored and facebook-github-bot committed Jul 16, 2019
1 parent 4124119 commit cd1f8d2
Showing 2 changed files with 689 additions and 6 deletions.
361 changes: 355 additions & 6 deletions pytext/utils/precision.py
@@ -85,6 +85,7 @@ def finalize(self) -> bool:
_FP16_ENABLED = False
_OPT_LEVEL = None
_DELAY_UNSCALE = False
TEST = False


@contextmanager
@@ -118,7 +119,10 @@ def initialize(model, optimizer):
global _OPT_LEVEL

if _FP16_ENABLED:

_OPT_LEVEL = "O2" if model.SUPPORT_FP16_OPTIMIZER else "O1"
if TEST:
return model.half(), FP16Optimizer(optimizer)
return amp.initialize(model, optimizer, opt_level=_OPT_LEVEL)
else:
return model, optimizer
@@ -128,20 +132,30 @@ def backward(optimizer, loss):
if _FP16_ENABLED:
# 1. Use automatic loss scaling to best use fp16 range
# 2. Clear handle's cache of casted parameters
if loss > 0:
with amp.scale_loss(
loss, optimizer, delay_unscale=_DELAY_UNSCALE
) as scaled_loss:
scaled_loss.backward()

print(loss)
if TEST:
if loss > 0:
optimizer.backward(loss)
else:
loss.backward()
else:
loss.backward()
if loss > 0:
with amp.scale_loss(
loss, optimizer, delay_unscale=_DELAY_UNSCALE
) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()


def clip_grad_norm(model, optimizer, max_clip_norm):
print("------calling clip grad norm-------")
if _FP16_ENABLED:
# Refer: https://nvidia.github.io/apex/advanced.html
if TEST:
return optimizer.clip_grad_norm(max_clip_norm)

return torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer), max_clip_norm
)
@@ -189,3 +203,338 @@ def pad_length(n):
n = n + 8 - remainder

return n


"""fp16 optimizer wraps torch.optim to support mixed precision training
usage:
1 optim.zero_grad()
2 for i in range(N):
3 model.forward() ---- fp16 weights
4 [pre_process()] ---- fp16 grads upscale
5 optim.backward() ---- upscaled fp16 grads
6 [post_process()] ---- downscale and float to fp32 grads
7 optim.step() ---- fp32 weights and grads
class FP16_Optimizer:
= Properties:
- inner_optimizer:
= type: Torch.optim
= contents: optimizer in pytext (eg. Adam)
which is initialized with fp16 params already
- param_groups:
= type: list of dictionaries where key is string and value is a list.
= contents: eg. [{'params':[]}]
- temp_fp32_params
= types: same as param_groups
= purpose: to support accumulating grads calculation
= contents: contain the temp fp32 grads from backward()
and will be unscaled and added to inner optimizer
- scaler:
- flags: BOOLEAN
= weights_update_needed: whether need to copy weights from master to model
= grads_update_needed: whether need to copy grads from model to master
= Methods:
- __init__()
- zero_grad
= effects: clear the grads in self.param_groups(fp16)
- backward()
- post_process()
- step(loss)
class DynamicLossScaler:
= properties:
- init_scale: the beginning scale number
- scale_factor: the step length that we use to increase the scale
- scale_window: the upper bound of iterations among which no overflow is triggered
- tolerance: the upper bound of the frequency that overflow happens
- threshold: the minimum value of the scale
- is_overflow
- is_scaled: whether grads are scaled
= Methods:
- check_overflow
- upscale
- unscale
- update_scale
"""


class DynamicLossScaler(object):
def __init__(self, init_scale, scale_factor, scale_window, tolerance, threshold):

self.scale = init_scale
self.scale_factor = scale_factor
self.scale_window = scale_window
self.tolerance = tolerance
self.threshold = threshold
self._iter = 0
self._last_overflow_iter = 0
self._last_rescale_iter = 0
self._overflows_since_rescale = 0
self.is_overflow = False

def upscale(self, loss):
return loss * self.scale

def unscale(self, grad_data):
return grad_data.mul_(1 / self.scale)

def check_overflow(self, grad_norm):
print("fp16optim is checking overflow")
if (
grad_norm == float("inf")
or grad_norm == -float("inf")
or grad_norm != grad_norm
):
self.is_overflow = True
else:
self.is_overflow = False

def check_overflow_step(self, model_params):
print("fp16optim is checking overflow")
for group in model_params:
for p in group["params"]:
if p.grad is not None and self._check_overflow_step(p.grad.data):
return True
return False

def _check_overflow_step(self, grad_data):
try:
cpu_sum = float(grad_data.float().sum())
except RuntimeError as instance:
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if (
cpu_sum == float("inf")
or cpu_sum == -float("inf")
or cpu_sum != cpu_sum
):
return True
return False

def update_scale(self):
"""
= effects:
- if last overflow is far from now, time to increase scale
- if more overflow happens than we expected, time to decrease the scale
"""
self._iter += 1
if self.is_overflow:
self._last_overflow_iter = self._iter
self.scale = max(self.scale / self.scale_factor, 1)
elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
self.scale *= self.scale_factor
print("updated_scale in fp16 optim, new loss_scale is {}".format(self.scale))


class FP16Optimizer(object):
def __init__(
self,
init_optimizer,
init_scale=2.0 ** 16,
scale_factor=2,
scale_window=2000,
tolerance=0.05,
threshold=None,
):
"""
= input: init_optimizer (already initialized), init_scale, scale_factor,
scale_window, tolerance, threshold
= effects: initialize the optimizer and create the fp32 master params and loss scaling tools
= modifies:
- record references to the model params (fp16)
- replace the inner optimizer's params with fp32 master copies
- initialize the loss scaler
- initialize state from the inner optimizer
"""
self.inner_optimizer = init_optimizer
self.param_groups = []
for _i, group in enumerate(self.inner_optimizer.param_groups):
fp16_group = {}
for key, value in group.items():
if key == "params":
fp16_param = []
for j, p in enumerate(value):
fp16_param.append(p)
master_p = p.detach().clone().float()
master_p.requires_grad_(True)
group["params"][j] = master_p
fp16_group["params"] = fp16_param
else:
fp16_group[key] = value
self.param_groups.append(fp16_group)

self.loss_scaler = DynamicLossScaler(
init_scale, scale_factor, scale_window, tolerance, threshold
)
self.state = self.inner_optimizer.state
self.weights_update_needed = False
self.grads_update_needed = False
print("===========PyText FP16 wrapper is initialized======")

def zero_grad(self):
print("===========PyText FP16 wrapper is calling zero_grad===== ")
for group in self.param_groups:
for p in group["params"]:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()

def backward(self, loss):
"""
= input: loss
= effects: scale the loss and compute grads
= modifies:
- upscale the loss, so the resulting fp16 grads are scaled
- call backward() on the scaled loss
"""
print("===========PyText FP16 wrapper is calling backward===== ")
print("-----loss_scale is {}------".format(self.loss_scaler.scale))
scaled_loss = self.loss_scaler.upscale(loss)
scaled_loss.backward()
self.grads_update_needed = True

def _post_process_deprecated(self):
"""
= effects: check overflow and temporarily store the grads to support gradient accumulation
= modifies:
- check overflow:
= true: update_scale and skip
= false:
- update the period; if it is over the limit, update_scale
- copy the upscaled grads from backward() to temp_fp32_params
- convert the grads to fp32 and unscale them
- add them to the inner optimizer's param grads
"""
self.loss_scaler.is_overflow = self.loss_scaler.check_overflow_step(
self.param_groups
)
self.loss_scaler.update_scale()
if not self.loss_scaler.is_overflow:
self.temp_fp32_params = self.inner_optimizer.param_groups
for i, group in enumerate(self.param_groups):
for j, p in enumerate(group["params"]):
fp32_grads = self.temp_fp32_params[i]["params"][j].grad.data
fp32_grads = self.loss_scaler.unscale(
fp32_grads.copy_(p.grad.data.view_as(fp32_grads))
)
self.inner_optimizer.param_groups[i]["params"][
j
].grad.data += fp32_grads

def clip_grad_norm(self, max_norm):
print("===========PyText FP16 wrapper is calling clip_grad_norm===== ")
self._grads_from_model_to_master()
grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(), max_norm)
self.loss_scaler.check_overflow(grad_norm)
self.loss_scaler.update_scale()
if not self.loss_scaler.is_overflow:
return grad_norm
else:
print("for overflow debug: overflow happens")

def step(self):
"""
= effects:
- check whether the model grads have overflowed
- copy the grads from model to master
- call the inner optimizer's step()
- copy the weights back from the inner optimizer to the model
"""
self.loss_scaler.is_overflow = self.loss_scaler.check_overflow_step(
self.param_groups
)
self.loss_scaler.update_scale()
if not self.loss_scaler.is_overflow:
print("===========PyText FP16 wrapper is calling step===== ")
self._grads_from_model_to_master()
self.inner_optimizer.step()
self.weights_update_needed = True
self._weights_from_master_to_model()
else:
print(
"overflow happens, skip step, new loss scale is {}".format(
self.loss_scaler.scale
)
)

def _weights_from_master_to_model(self):
if self.weights_update_needed:
for i, group in enumerate(self.inner_optimizer.param_groups):
for j, p in enumerate(group["params"]):
self.param_groups[i]["params"][j].data.copy_(p.data)
# self.param_groups[i]["params"][j].data.half()
self.weights_update_needed = False

def _grads_from_model_to_master(self):
if self.grads_update_needed:
for i, group in enumerate(self.param_groups):
for j, p in enumerate(group["params"]):
if self.inner_optimizer.param_groups[i]["params"][j].grad is None:
self.inner_optimizer.param_groups[i]["params"][
j
].grad = torch.empty_like(
self.inner_optimizer.param_groups[i]["params"][j]
)
self.inner_optimizer.param_groups[i]["params"][j].grad.data.copy_(
p.grad.data
)
# self.inner_optimizer.param_groups[i]["params"][j].grad.data.float()
self.loss_scaler.unscale(
self.inner_optimizer.param_groups[i]["params"][j].grad.data
)
self.grads_update_needed = False

def state_dict(self):
state_dict = {}
state_dict["loss_scaler"] = self.loss_scaler
state_dict["loss_scale"] = self.loss_scaler.scale
state_dict["overflow"] = self.loss_scaler.overflow
state_dict["param_groups"] = self.param_groups
state_dict["optimizer_state_dict"] = self.inner_optimizer.state_dict()
return state_dict

def load_state_dict(self, state_dict):
self.loss_scaler = state_dict["loss_scaler"]
self.loss_scaler.scale = state_dict["loss_scale"]
self.loss_scaler.is_overflow = state_dict["overflow"]
self.inner_optimizer.load_state_dict(state_dict["optimizer_state_dict"])
for i, group in enumerate(state_dict["param_groups"]):
for j, p in enumerate(group["params"]):
self.param_groups[i]["params"][j].data.copy_(p.data)

def master_params(self):
for group in self.inner_optimizer.param_groups:
for p in group["params"]:
yield p

def finalize(self):
return self.inner_optimizer.finalize()

def __getstate__(self):
raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")

def __setstate__(self, state):
raise RuntimeError(
"FP16_Optimizer should be deserialized using load_state_dict()."
)

def _get_loss_scale(self):
return self.loss_scaler.scale

def _set_loss_scale(self, value):
self.loss_scaler.scale = value

def _get_state(self):
return self.state

def _set_state(self, value):
self.state = value

def _get_param_groups(self):
return self.param_groups

def _set_param_groups(self, value):
self.param_groups = value
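The state_dict() / load_state_dict() pair above lets the wrapper be checkpointed together with its loss-scale state. A minimal sketch, assuming a hypothetical checkpoint path:

import torch

def save_fp16_checkpoint(fp16_optim, path="checkpoint.pt"):
    # persists the inner optimizer state plus the current loss scale
    torch.save(fp16_optim.state_dict(), path)

def load_fp16_checkpoint(fp16_optim, path="checkpoint.pt"):
    # restores the loss scaler, the inner optimizer, and the fp16 param values
    fp16_optim.load_state_dict(torch.load(path))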
