Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 55 additions & 31 deletions autoparallel/_testing/models/dsv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import math
from dataclasses import dataclass, field
from typing import Callable, ClassVar, Literal, Optional, Tuple
from typing import Callable, ClassVar, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -1534,27 +1534,9 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
self.model_args = model_args

def init_weights(self, buffer_device: torch.device | None = None) -> None:
    """Initialize all model weights and buffers in place.

    Delegates to the module-level ``_init_weights_*`` helpers so that the
    full model and the pipeline-stage wrappers initialize identically.

    Args:
        buffer_device: device on which to (re)build the ``freqs_cis``
            buffer; when ``None`` the helper falls back to the buffer's
            current device.
    """
    _init_weights_tok_embeddings(self)
    _init_weights_layers(self, buffer_device)
    _init_weights_norm_and_output(self)

def forward(
self,
Expand Down Expand Up @@ -1593,12 +1575,13 @@ def forward(


class DeepSeekV3StageI(nn.Module):
def __init__(self, layers, config):
def __init__(self, layers, model_args):
    """Intermediate pipeline stage holding only a slice of transformer layers.

    Args:
        layers: mapping of layer-name -> TransformerBlock for this stage.
        model_args: model configuration used to precompute RoPE frequencies.
    """
    super().__init__()
    self.layers = layers
    self.model_args = model_args
    # Non-persistent: freqs_cis is recomputed, never loaded from a checkpoint.
    freqs = precompute_freqs_cis(model_args)
    self.register_buffer("freqs_cis", freqs, persistent=False)

def forward(self, h):
# intermediate stages only have layers
Expand All @@ -1607,14 +1590,12 @@ def forward(self, h):
return h

def init_weights(self, buffer_device: torch.device | None = None) -> None:
    """Initialize this stage's layers and refresh the RoPE frequency buffer."""
    _init_weights_layers(self, buffer_device)


class DeepSeekV3Stage0(DeepSeekV3StageI):
def __init__(self, embed, layers, config):
super().__init__(layers, config)
def __init__(self, embed, layers, model_args):
    """First pipeline stage: token embeddings plus the leading layers.

    Args:
        embed: token-embedding module (or None when sharded away).
        layers: mapping of layer-name -> TransformerBlock for this stage.
        model_args: model configuration forwarded to the base stage.
    """
    # nn.Module attributes may only be set after the base __init__ runs.
    super().__init__(layers, model_args)
    self.tok_embeddings = embed

def forward(self, tokens):
Expand All @@ -1623,20 +1604,63 @@ def forward(self, tokens):
# torch.Size([1024, 1024, 2048])
return super().forward(h)

def init_weights(self, buffer_device: torch.device | None = None) -> None:
    """Initialize the embedding table, then the shared per-layer weights."""
    _init_weights_tok_embeddings(self)
    super().init_weights(buffer_device=buffer_device)


class DeepSeekV3StageN(DeepSeekV3StageI):
def __init__(self, layers, norm, output, config):
super().__init__(layers, config)
def __init__(self, layers, norm, output, model_args):
    """Last pipeline stage: trailing layers plus final norm and output head.

    Args:
        layers: mapping of layer-name -> TransformerBlock for this stage.
        norm: final RMS/Layer norm module (or None when sharded away).
        output: output projection / LM head (or None when sharded away).
        model_args: model configuration forwarded to the base stage.

    NOTE(review): the original re-assigned ``self.model_args`` here, but
    ``DeepSeekV3StageI.__init__`` already stores it — the duplicate
    assignment is dropped.
    """
    super().__init__(layers, model_args)
    self.norm = norm
    self.output = output

def forward(self, h):
    """Run the trailing layers, then the final norm and output projection.

    Either ``norm`` or ``output`` may be None (sharded to another rank),
    in which case that step is skipped.
    """
    hidden = super().forward(h)
    if self.norm is not None:
        hidden = self.norm(hidden)
    if self.output is not None:
        return self.output(hidden)
    return hidden

def init_weights(self, buffer_device: torch.device | None = None) -> None:
    """Initialize the shared per-layer weights, then the norm/output head."""
    super().init_weights(buffer_device=buffer_device)
    _init_weights_norm_and_output(self)


######################
# Pipeline stuff end #
######################


def _init_weights_tok_embeddings(self: Union[DeepSeekV3Model, DeepSeekV3Stage0]):
if self.tok_embeddings is not None:
nn.init.normal_(self.tok_embeddings.weight)


def _init_weights_layers(
    self: Union[DeepSeekV3Model, DeepSeekV3StageI],
    buffer_device: torch.device | None,
):
    """Rebuild the RoPE frequency buffer and initialize every transformer layer.

    Args:
        self: full model or pipeline stage owning ``freqs_cis`` and ``layers``.
        buffer_device: device for the recomputed ``freqs_cis``; defaults to
            the existing buffer's device when None.
    """
    device = buffer_device if buffer_device is not None else self.freqs_cis.device  # type: ignore[assignment]
    with torch.device(device):  # type: ignore[arg-type]
        self.freqs_cis = precompute_freqs_cis(self.model_args)
    for block in self.layers.values():
        if block is None:
            # Layer sharded to another pipeline rank — nothing to init here.
            continue
        assert isinstance(block, TransformerBlock)
        block.init_weights(buffer_device=device)  # type: ignore[arg-type]


def _init_weights_norm_and_output(self: Union[DeepSeekV3Model, DeepSeekV3StageN]):
if self.norm is not None:
self.norm.reset_parameters()
if self.output is not None:
final_out_std = self.model_args.dim**-0.5
cutoff_factor = 3
nn.init.trunc_normal_(
self.output.weight,
mean=0.0,
std=final_out_std,
a=-cutoff_factor * final_out_std,
b=cutoff_factor * final_out_std,
)