diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py index ff09c65..18eb897 100644 --- a/autoparallel/_testing/models/dsv3.py +++ b/autoparallel/_testing/models/dsv3.py @@ -5,7 +5,7 @@ import math from dataclasses import dataclass, field -from typing import Callable, ClassVar, Literal, Optional, Tuple +from typing import Callable, ClassVar, Literal, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -1534,27 +1534,9 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): self.model_args = model_args def init_weights(self, buffer_device: torch.device | None = None) -> None: - buffer_device = buffer_device or self.freqs_cis.device # type: ignore[has-type] - with torch.device(buffer_device): - self.freqs_cis = precompute_freqs_cis(self.model_args) - if self.tok_embeddings is not None: - nn.init.normal_(self.tok_embeddings.weight) - for layer in self.layers.values(): - if layer is not None: - assert isinstance(layer, TransformerBlock) - layer.init_weights(buffer_device=buffer_device) - if self.norm is not None: - self.norm.reset_parameters() - final_out_std = self.model_args.dim**-0.5 - cutoff_factor = 3 - if self.output is not None: - nn.init.trunc_normal_( - self.output.weight, - mean=0.0, - std=final_out_std, - a=-cutoff_factor * final_out_std, - b=cutoff_factor * final_out_std, - ) + _init_weights_tok_embeddings(self) + _init_weights_layers(self, buffer_device) + _init_weights_norm_and_output(self) def forward( self, @@ -1593,12 +1575,13 @@ def forward( class DeepSeekV3StageI(nn.Module): - def __init__(self, layers, config): + def __init__(self, layers, model_args): super().__init__() self.layers = layers self.register_buffer( - "freqs_cis", precompute_freqs_cis(config), persistent=False + "freqs_cis", precompute_freqs_cis(model_args), persistent=False ) + self.model_args = model_args def forward(self, h): # intermediate stages only have layers @@ -1607,14 +1590,12 @@ def forward(self, h): return h def init_weights(self, buffer_device: torch.device | None = None) -> None: - for layer in self.layers.values(): - if layer is not None: - layer.init_weights(buffer_device=buffer_device) + _init_weights_layers(self, buffer_device) class DeepSeekV3Stage0(DeepSeekV3StageI): - def __init__(self, embed, layers, config): - super().__init__(layers, config) + def __init__(self, embed, layers, model_args): + super().__init__(layers, model_args) self.tok_embeddings = embed def forward(self, tokens): @@ -1623,12 +1604,17 @@ def forward(self, tokens): # torch.Size([1024, 1024, 2048]) return super().forward(h) + def init_weights(self, buffer_device: torch.device | None = None) -> None: + _init_weights_tok_embeddings(self) + super().init_weights(buffer_device=buffer_device) + class DeepSeekV3StageN(DeepSeekV3StageI): - def __init__(self, layers, norm, output, config): - super().__init__(layers, config) + def __init__(self, layers, norm, output, model_args): + super().__init__(layers, model_args) self.norm = norm self.output = output + self.model_args = model_args def forward(self, h): h = super().forward(h) @@ -1636,7 +1622,45 @@ def forward(self, h): output = self.output(h) if self.output is not None else h return output + def init_weights(self, buffer_device: torch.device | None = None) -> None: + super().init_weights(buffer_device=buffer_device) + _init_weights_norm_and_output(self) + ###################### # Pipeline stuff end # ###################### + + +def _init_weights_tok_embeddings(self: Union[DeepSeekV3Model, DeepSeekV3Stage0]): + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + + +def _init_weights_layers( + self: Union[DeepSeekV3Model, DeepSeekV3StageI], + buffer_device: torch.device | None, +): + if buffer_device is None: + buffer_device = self.freqs_cis.device # type: ignore[assignment] + with torch.device(buffer_device): # type: ignore[arg-type] + self.freqs_cis = precompute_freqs_cis(self.model_args) + for layer in self.layers.values(): + if layer is not None: + assert isinstance(layer, TransformerBlock) + layer.init_weights(buffer_device=buffer_device) # type: ignore[arg-type] + + +def _init_weights_norm_and_output(self: Union[DeepSeekV3Model, DeepSeekV3StageN]): + if self.norm is not None: + self.norm.reset_parameters() + if self.output is not None: + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + )