-
Notifications
You must be signed in to change notification settings - Fork 183
/
balance.py
133 lines (100 loc) · 3.89 KB
/
balance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
r"""
Support for monitoring loss in Megatron
"""
import torch
from fmoe.balance import reset_balance_profile
from fmoe.balance import update_balance_profile
from fmoe.utils import get_torch_default_comm
from .distributed import get_moe_group
# Module-level store of gate balance metrics, keyed by metric name
# (iterated as metric names in add_balance_log); values are filled in by
# fmoe.balance.update_balance_profile — exact value shape depends on fmoe.
balance_dict = {}
# Number of transformer layers being monitored; set via reset_gate_hook().
num_layers = 0
def reset_gate_hook(_num_layers=None):
    r"""
    Reset the accumulated balance metrics for all monitored layers.

    If ``_num_layers`` is given, also update the module-level layer count
    before resetting.
    """
    from megatron import get_args

    global balance_dict, num_layers
    if _num_layers is not None:
        num_layers = _num_layers
    strategy = get_args().balance_strategy
    reset_balance_profile(balance_dict, num_layers, strategy)
def get_balance_profile():
    r"""Return the module-level dict of accumulated gate balance metrics."""
    # Reading a module global needs no ``global`` declaration.
    return balance_dict
def generate_megatron_gate_hook(layer_idx, num_expert_global):
    r"""
    Build a gate hook bound to ``layer_idx`` that records balance statistics
    for that layer into the module-level ``balance_dict``.
    """
    from megatron import get_args

    # Captured once here so every hook invocation reuses the same strategy.
    strategy = get_args().balance_strategy

    def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_context):
        global balance_dict
        update_balance_profile(
            balance_dict,
            gate_top_k_idx,
            gate_score_top_k,
            gate_context,
            layer_idx,
            num_expert_global,
            strategy,
        )

    return megatron_gate_hook
def add_balance_log(model, writer, iteration):
    r"""
    Average each layer's gate balance metrics across all ranks and log them
    to TensorBoard ``writer`` (per-layer and layer-mean) on the last rank.
    Calling this clears the per-gate accumulated loss (``clear=True`` below).
    """
    from megatron import is_last_rank
    # Unwrap a DDP-style wrapper to reach the underlying model.
    if hasattr(model, 'module'):
        model = model.module
    # One row per layer; clear=True resets each gate's accumulator so the
    # next logging interval starts fresh.
    balance_dict_tensor = torch.vstack(
        [l.mlp.gate.get_loss(clear=True) for l in model.language_model.transformer.layers]
    ).detach()
    # Sum across every rank, then divide to get the world-wide mean.
    world_group = get_torch_default_comm()
    world_size = torch.distributed.get_world_size(group=world_group)
    torch.distributed.all_reduce(balance_dict_tensor, group=world_group)
    balance_dict_tensor /= world_size
    # Only one rank writes to TensorBoard to avoid duplicate scalars.
    # balance_dict keys are the metric names; rows of the stacked tensor are
    # assumed to line up with that key order — TODO confirm against fmoe.
    if writer and is_last_rank():
        for idx, metric_name in enumerate(balance_dict):
            for layer_id, val in enumerate(balance_dict_tensor[idx]):
                writer.add_scalar(
                    f"balance-{metric_name}/layer-{layer_id}", val.item(), iteration
                )
            writer.add_scalar(
                f"balance-{metric_name}/all",
                balance_dict_tensor[idx].mean().item(),
                iteration,
            )
def patch_forward_step(forward_step_func):
    r"""
    Patch model's forward_step_func to support balance loss
    """
    from megatron.mpu import is_pipeline_last_stage
    from megatron import get_args
    # No balance strategy configured: return the function unmodified.
    if not get_args().balance_strategy:
        return forward_step_func
    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)
        # Balance loss is only added on the last pipeline stage, and the
        # 'naive' strategy carries no explicit loss term.
        if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
            return output
        loss_name = args.balance_strategy + "_loss"
        # Unwrap a DDP-style wrapper to reach the underlying model.
        if hasattr(model, 'module'):
            model = model.module
        # clear=False: the accumulators are left intact for add_balance_log.
        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
            for l in model.language_model.transformer.layers]
        (loss, state_dict), bal_loss = (
            output,
            torch.cat(loss_list).mean() * args.balance_loss_weight
        )
        # average across moe group (detached copy — for reporting only)
        moe_group = get_moe_group()
        world_size = torch.distributed.get_world_size(group=moe_group)
        averaged_bal_loss = bal_loss.clone().detach()
        torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
        averaged_bal_loss /= world_size
        # NOTE: the *local* (non-averaged) bal_loss is added to the training
        # loss so gradients stay rank-local; the averaged value is only logged.
        loss += bal_loss
        state_dict[loss_name] = averaged_bal_loss
        return loss, state_dict
    return forward_step_with_balance_loss
def patch_model_provider(model_provider):
    r"""
    Wrap ``model_provider`` so the model it builds is converted into a
    FastMoE model via ``fmoefy`` using the configured expert settings.
    """
    from megatron import get_args

    def fmoefied_model_provider():
        from .layers import fmoefy
        args = get_args()
        base_model = model_provider()
        # Per-expert hidden size shrinks as top_k grows.
        hidden_hidden = 4 * args.hidden_size // args.top_k
        return fmoefy(
            base_model,
            num_experts=args.num_experts,
            hidden_hidden_size=hidden_hidden,
            top_k=args.top_k,
        )

    return fmoefied_model_provider