mup.py
from collections import defaultdict

import torch


def get_mup_multipliers(base_model, main_model):
    """
    Build a dict mapping each parameter name in `main_model` to a tuple
    (multiplier, is_matrix_like), where `multiplier` is the width ratio
    relative to `base_model`.
    """
    base_shapes = _get_shapes(base_model)
    model_shapes = _get_shapes(main_model)
    basenames = set(base_shapes.keys())
    names = set(model_shapes.keys())
    assert basenames == names, (
        f"`base_shapes` has extra names {basenames - names}. "
        f"`shapes` has extra names {names - basenames}."
    )
    multipliers = {}
    for name, b_shape in base_shapes.items():
        multipliers[name] = _get_multiplier(b_shape, model_shapes[name])
    return multipliers
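
# Illustrative (hypothetical) example of the returned dict, assuming a base width of 64
# scaled up to 256 and made-up parameter names:
#   {
#       "embed.weight": (4.0, False),   # only the width dimension grew: vector-like
#       "ffn.w1.weight": (4.0, True),   # both dimensions grow with width: matrix-like
#       "ffn.w1.bias": (4.0, False),
#   }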


def _get_multiplier(base_dims, dims):
    # The 'multiplier' is the ratio dim / base_dim for the **last dimension** that is
    # infinite (i.e. that scales with width). The weight is 'matrix-like' if it has
    # more than one infinite dimension.
    # E.g. if base_dims=[d1, d2_base] and dims=[d1, d2] we would return (d2 / d2_base, False).
    num_inf_dims = 0
    multiplier = 1
    for base_dim, dim in zip(base_dims, dims):
        assert isinstance(base_dim, int), f"Unknown base_dim type: {type(base_dim)}"
        if base_dim != dim:
            num_inf_dims += 1
            multiplier = dim / base_dim
    is_matrix_like = num_inf_dims > 1
    return (multiplier, is_matrix_like)
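
# Illustrative (hypothetical) calls, assuming a base width of 64 scaled to 256:
#   _get_multiplier((256, 64), (1024, 256))  -> (4.0, True)    # fan-in and fan-out both scale
#   _get_multiplier((64,), (256,))           -> (4.0, False)   # only one scaled dimension
#   _get_multiplier((10, 64), (10, 256))     -> (4.0, False)   # readout: fixed output size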


def mup_init(model, mup_multipliers_dict):
    """
    Apply the μP initialization corrections to `model` in place: set `width_mult`
    on every MuReadout and rescale its parameters, then rescale bias
    initializations by the square root of their width multiplier.
    """
    for name, module in model.named_modules():
        if isinstance(module, MuReadout):
            module.width_mult = mup_multipliers_dict[f"{name}.weight"][0]
            module._rescale_parameters()
    for name, param in model.named_parameters():
        if "layers" in name:
            name = ".".join(name.split(".")[2:])
        if "bias" in name and name in mup_multipliers_dict:
            param.data *= mup_multipliers_dict[name][0] ** 0.5


def build_optimizer_param_groups(model, mup_multipliers_dict, decoupled_wd=False, **optimizer_kwargs):
    """
    μP scales the learning rate according to whether a parameter is 'matrix-like'
    or 'vector-like'. We build param_groups based on this scaled lr.
    `optimizer_kwargs` must contain "lr" (and "weight_decay" when `decoupled_wd`
    is False), since both are rescaled per group below.
    """

    def new_group():
        new_g = {k: v for k, v in optimizer_kwargs.items()}
        new_g["params"] = []
        return new_g

    param_groups = defaultdict(new_group)
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Strip wrapper prefixes (FSDP's "_fsdp_wrapped_module.", DDP/DataParallel's
        # "module.") and per-block indices so names line up with the keys produced
        # by `get_mup_multipliers`.
        if "_fsdp_wrapped_module." in name:
            name = name.split("_fsdp_wrapped_module.")[-1]
        elif name.startswith("module."):
            name = name.split("module.")[-1]
        if "layers" in name:
            name = ".".join(name.split(".")[2:])
        multiplier, is_matrix_like = mup_multipliers_dict[name]
        if is_matrix_like:
            param_groups[multiplier]["params"].append(param)
        else:
            param_groups[1.0]["params"].append(param)
    for width_mult, group in param_groups.items():
        # Scale learning rate and weight decay accordingly
        group["lr"] /= width_mult
        if not decoupled_wd:
            group["weight_decay"] *= width_mult
    return list(param_groups.values())
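
# Hypothetical usage sketch (the model and hyperparameters are made up): pass the base
# hyperparameters through `optimizer_kwargs` and hand the groups to any torch optimizer.
#   param_groups = build_optimizer_param_groups(
#       model, mup_multipliers, decoupled_wd=False, lr=3e-4, weight_decay=0.1
#   )
#   optimizer = torch.optim.AdamW(param_groups)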


def _get_shapes(model):
    """
    Returns a dictionary of name:shape for each unique layer in a model.
    If a model comprises multiple 'blocks' (e.g. TransformerBlocks)
    we assume every block has the same dimensions.
    """
    shapes_dict = {}
    for name, param in model.named_parameters():
        if "layers.0" in name:
            name = ".".join(name.split(".")[2:])
            shapes_dict[name] = param.shape
        elif "layers" in name:
            name = ".".join(name.split(".")[2:])
            assert shapes_dict[name] == param.shape, "_get_shapes assumes all blocks have the same dimensions"
        else:
            shapes_dict[name] = param.shape
    return shapes_dict
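
# Illustrative (hypothetical) parameter names, assuming blocks live directly under a
# `layers` ModuleList: "layers.0.attn.q_proj.weight" and "layers.7.attn.q_proj.weight"
# both normalize to "attn.q_proj.weight", so only block 0 is recorded and later blocks
# are shape-checked against it.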


class MuReadout(torch.nn.Linear):
    """Drop-in replacement for all output linear layers.

    An "output" linear layer is one that maps from a width dimension (e.g.,
    `d_model` in a Transformer) to a non-width dimension (e.g., vocab size).
    This layer implements the version of μP with a 1/width multiplier and a
    constant variance initialization for both weights and biases.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        self.width_mult = None
        self._has_rescaled_params = False

    def _rescale_parameters(self):
        """
        Rescale parameters to convert SP initialization to μP initialization.

        Warning: This method is NOT idempotent and should be called only once
        unless you know what you are doing.
        """
        assert self.width_mult is not None, "Width multiplier not set - have you called mup_init on the model?"
        if self._has_rescaled_params:
            raise RuntimeError("`_rescale_parameters` has been called once before already.")
        if self.bias is not None:
            self.bias.data *= self.width_mult**0.5
        self.weight.data *= self.width_mult**0.5
        self._has_rescaled_params = True

    def forward(self, x):
        assert self.width_mult is not None, "Width multiplier not set - have you called mup_init on the model?"
        return super().forward(x / self.width_mult)
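

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the original module: the toy models, widths,
    # and hyperparameters below are hypothetical and only illustrate how the pieces fit
    # together (a width-64 "base" model and a width-256 "main" model of identical structure).
    import torch.nn as nn

    def make_model(width):
        return nn.Sequential(
            nn.Linear(16, width),     # input projection: fan-in fixed, fan-out scales
            nn.Linear(width, width),  # hidden projection: matrix-like, both dims scale
            MuReadout(width, 10),     # output head: applies the 1/width_mult forward scaling
        )

    base_model = make_model(64)
    main_model = make_model(256)

    multipliers = get_mup_multipliers(base_model, main_model)
    mup_init(main_model, multipliers)

    param_groups = build_optimizer_param_groups(
        main_model, multipliers, decoupled_wd=False, lr=3e-4, weight_decay=0.1
    )
    optimizer = torch.optim.AdamW(param_groups)

    # The matrix-like group ends up with lr / 4 and weight_decay * 4 relative to the base values.
    print([(g["lr"], g["weight_decay"], len(g["params"])) for g in param_groups])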