I am looking for a way to convert a DeepSpeed fp16 model to fp32.
First, I pretrained with bing_bert (ZeRO optimizer, stage 2).
Second, I fine-tuned on SQuAD (fp16).
Then I converted the "mp_rank_00_model_states.pt" file to fp32 with the code below (#2) and fine-tuned (#1).
Is there a better way?
Also, where should I put this?
"checkpoint_state_dict['optimizer']['fp32_groups_flat']"
model.network, optimizer, _, scheduler = deepspeed.initialize(args=args,
model=model.network,
model_parameters=checkpoint_state_dict['optimizer']['fp32_groups_flat'] ) <-- here?
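For comparison, here is a minimal sketch of the usual call, assuming the standard DeepSpeed API: model_parameters normally receives the model's own trainable parameters, and DeepSpeed builds its fp16/fp32 optimizer groups from them internally.

# A minimal sketch of the usual pattern (not necessarily the right fix here):
# pass the model's trainable parameters and let DeepSpeed construct the
# fp16 master weights and fp32 optimizer groups itself.
model_engine, optimizer, _, scheduler = deepspeed.initialize(
    args=args,
    model=model.network,
    model_parameters=[p for p in model.network.parameters() if p.requires_grad])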
#1. finetune
# Load the checkpoint on CPU and restore the module weights;
# strict=False tolerates keys missing from the fine-tuning head.
checkpoint_state_dict = torch.load(args.checkpoint, map_location=torch.device("cpu"))
model.network.load_state_dict(checkpoint_state_dict['module'], strict=False)
#2. converted to fp32
Adapted from https://github.com/fastai/fastai2/blob/master/fastai2/callback/fp16.py
"""
import torch

def convert_module(module, dtype):
    # Convert this module's floating-point parameters and their gradients.
    for param in module.parameters(recurse=False):
        if param is not None:
            if param.data.dtype.is_floating_point:
                param.data = param.data.to(dtype=dtype)
            if param._grad is not None and param._grad.data.dtype.is_floating_point:
                param._grad.data = param._grad.data.to(dtype=dtype)
    # Convert floating-point buffers (e.g. batch-norm running statistics).
    for buf in module.buffers(recurse=False):
        if buf is not None and buf.data.dtype.is_floating_point:
            buf.data = buf.data.to(dtype=dtype)

def convert_network(network, dtype):
    for module in network.modules():
        # Skip affine batch-norm layers, which are kept in fp32 for stability.
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
            continue
        convert_module(module, dtype)
        # Re-flatten RNN weights after the dtype change.
        if isinstance(module, torch.nn.RNNBase):
            module.flatten_parameters()
    return network
"""