@@ -4,8 +4,8 @@
version: v1.0
# From now on, start to align train, data, and model settings with the train stage (refactor finished for data)
train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting
-name: unpadded3 #used for local dump and wandb log
-output_dir: /mnt/pollux/checkpoints/aj
+name: mup_test #used for local dump and wandb log
+output_dir: /mnt/pollux/checkpoints/ablations
dump_dir: '' # No need now
steps: 500000
seed: 777
24 changes: 17 additions & 7 deletions apps/Castor/model.py
@@ -15,6 +15,9 @@
from .modules.vae import VideoVAEArgs, create_vae
from .modules.vision_encoder import VisionEncoderArgs, create_vision_encoder

+from .modules.component import layer_init_kaiming_normal
+from mup import MuReadout
+
logger = logging.getLogger()


@@ -45,19 +48,26 @@ def __init__(self, input_dim: int, hidden_dim: int, encoder_dim: int):
super(AlignmentProjection, self).__init__()

self.proj = nn.Sequential(
-            nn.Sequential(
-                nn.Linear(input_dim, hidden_dim),
-                nn.SiLU(),
-                nn.Linear(hidden_dim, hidden_dim),
-                nn.SiLU(),
-                nn.Linear(hidden_dim, encoder_dim),
-            )
+            MuReadout(input_dim, hidden_dim), # mup
+            nn.SiLU(),
+            nn.Linear(hidden_dim, hidden_dim), # mup
+            nn.SiLU(),
+            nn.Linear(hidden_dim, encoder_dim), # mup
)

+        self.reset_parameters()

def forward(self, x):
x = self.proj(x)
return x

+    def reset_parameters(self):
+        # MuReadout has its own initialization
+        layer_init_kaiming_normal(self.proj[2])
+        nn.init.constant_(self.proj[4].weight, 0.) # initialize output weights to zero
+        if self.proj[4].bias is not None:
+            nn.init.constant_(self.proj[4].bias, 0.)


class Castor(nn.Module):
VERSION: str = "v1.0"
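Context for the MuReadout changes above (an aside, not part of the diff): a MuReadout layer scales its output by a width multiplier that only exists once base shapes have been attached to the model, and the per-layer learning-rate scaling of muP comes from pairing the model with one of mup's optimizers. A minimal wiring sketch under those assumptions; build_castor and the widths here are hypothetical stand-ins, not names from this repo:

    from mup import set_base_shapes, MuAdam

    model = build_castor(width=2048)  # target model, contains MuReadout layers
    base = build_castor(width=256)    # same architecture at a small base width
    delta = build_castor(width=512)   # a second width, used to infer which dims scale

    set_base_shapes(model, base, delta=delta)        # attach infshape metadata to every parameter
    optimizer = MuAdam(model.parameters(), lr=1e-3)  # width-aware learning rates per weight group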
193 changes: 60 additions & 133 deletions apps/Castor/modules/component.py
@@ -15,6 +15,7 @@
from xformers.ops import AttentionBias, fmha
from liger_kernel.transformers import LigerSwiGLUMLP, LigerRMSNorm, liger_rotary_pos_emb
from types import SimpleNamespace
+from mup import MuReadout

# fa3
from flash_attn_interface import flash_attn_varlen_func
@@ -23,11 +24,10 @@
flex_attention_comp = torch.compile(flex_attention)


-class InitStdFactor(Enum):
-    DISABLED = "disabled" # Init std is divided by 1.0
-    GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers)
-    CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
-    DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096
+def layer_init_kaiming_normal(x):
+    nn.init.kaiming_normal_(x.weight, a=1, mode='fan_in')
+    if x.bias is not None:
+        nn.init.constant_(x.bias, 0.)


@dataclass
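An aside on the scale of the new helper (not part of the diff): with a=1 and kaiming_normal_'s default leaky_relu nonlinearity, the gain is sqrt(2 / (1 + a^2)) = 1, so the resulting weight std is fan_in ** -0.5. That matches the dim ** -0.5 scale the removed trunc_normal_ initializations used for input projections, but drops the truncation and the extra factor previously applied to output projections. A quick check in plain PyTorch:

    import math
    import torch.nn as nn

    lin = nn.Linear(4096, 1024, bias=False)
    nn.init.kaiming_normal_(lin.weight, a=1, mode='fan_in')  # gain = sqrt(2 / (1 + 1**2)) = 1
    print(lin.weight.std().item())  # empirically about 0.0156
    print(1 / math.sqrt(4096))      # analytically 0.015625, i.e. fan_in ** -0.5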
@@ -482,26 +482,11 @@ def forward(

return output

-    def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.dim ** (-0.5))
-
-        for w in [self.wq, self.wk, self.wv]:
-            nn.init.trunc_normal_(
-                w.weight,
-                mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
-            )
-
-        nn.init.trunc_normal_(
-            self.wo.weight,
-            mean=0.0,
-            std=init_std / factor,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
-
+    def reset_parameters(self):
+        layer_init_kaiming_normal(self.wq)
+        layer_init_kaiming_normal(self.wk)
+        layer_init_kaiming_normal(self.wv)
+        layer_init_kaiming_normal(self.wo)
if isinstance(self.q_norm, RMSNorm):
self.q_norm.reset_parameters()
if isinstance(self.k_norm, RMSNorm):
@@ -540,44 +525,45 @@ def __init__(
self.liger_rotary_emb = liger_rotary_emb
self.liger_rms_norm = liger_rms_norm
self.window_size = window_size
+        self.qk_norm = qk_norm

self.wq = nn.Linear(
dim,
n_heads * self.head_dim,
bias=False,
-        )
-        self.wk = nn.Linear(
+        ) # mup
+        self.wk = MuReadout(
dim,
n_kv_heads * self.head_dim,
bias=False,
-        )
-        self.wv = nn.Linear(
+        ) # mup
+        self.wv = MuReadout(
dim,
n_kv_heads * self.head_dim,
bias=False,
-        )
-        nn.init.xavier_uniform_(self.wq.weight)
-        nn.init.xavier_uniform_(self.wk.weight)
-        nn.init.xavier_uniform_(self.wv.weight)
+        ) # mup

self.wo = nn.Linear(
n_heads * self.head_dim,
dim,
bias=False,
-        )
-        nn.init.xavier_uniform_(self.wo.weight)
+        ) # mup

-        if qk_norm:
+        if self.qk_norm:
self.q_norm = RMSNorm(self.head_dim, liger_rms_norm=liger_rms_norm)
self.k_norm = RMSNorm(self.head_dim, liger_rms_norm=liger_rms_norm)
else:
self.q_norm = self.k_norm = nn.Identity()

+        self.reset_parameters()

def reset_parameters(self, *args, **kwargs):
-        nn.init.xavier_uniform_(self.wq.weight)
-        nn.init.xavier_uniform_(self.wk.weight)
-        nn.init.xavier_uniform_(self.wv.weight)
-        nn.init.xavier_uniform_(self.wo.weight)
+        layer_init_kaiming_normal(self.wq)
+        # MuReadout layers have their own initialization
+        layer_init_kaiming_normal(self.wo)
+        if self.qk_norm:
+            self.q_norm.reset_parameters()
+            self.k_norm.reset_parameters()

# copied from huggingface modeling_llama.py
def _upad_input(
@@ -780,24 +766,24 @@ def __init__(
hidden_size=dim,
intermediate_size=hidden_dim,
hidden_act="silu",
-            )
+            ) # mup
self.ffn = LigerSwiGLUMLP(config)
else:
self.w1 = nn.Linear(
dim,
hidden_dim,
bias=False,
-            )
+            ) # mup
self.w3 = nn.Linear(
dim,
hidden_dim,
bias=False,
-            )
+            ) # mup
self.w2 = nn.Linear(
hidden_dim,
dim,
bias=False,
-            )
+            ) # mup

def forward(self, x: torch.Tensor) -> torch.Tensor:
# B S D
@@ -809,45 +795,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self.w2(F.silu(x1) * x3)
return output

-    def reset_parameters(self, init_std=None, factor=1.0):
-        in_init_std = init_std or (self.dim ** (-0.5))
-        out_init_std = init_std or (self.hidden_dim ** (-0.5))
-        out_init_std = out_init_std / factor
+    def reset_parameters(self):
if self.liger_ffn:
-            # Initialize LigerSwiGLUMLP parameters
-            # gate_proj and up_proj correspond to w1 and w3
-            for w in [self.ffn.gate_proj, self.ffn.up_proj]:
-                nn.init.trunc_normal_(
-                    w.weight,
-                    mean=0.0,
-                    std=in_init_std,
-                    a=-3 * in_init_std,
-                    b=3 * in_init_std,
-                )
-            # down_proj corresponds to w2
-            nn.init.trunc_normal_(
-                self.ffn.down_proj.weight,
-                mean=0.0,
-                std=out_init_std,
-                a=-3 * out_init_std,
-                b=3 * out_init_std,
-            )
+            layer_init_kaiming_normal(self.ffn.gate_proj)
+            layer_init_kaiming_normal(self.ffn.up_proj)
+            layer_init_kaiming_normal(self.ffn.down_proj)
else:
-            for w in [self.w1, self.w3]:
-                nn.init.trunc_normal_(
-                    w.weight,
-                    mean=0.0,
-                    std=in_init_std,
-                    a=-3 * in_init_std,
-                    b=3 * in_init_std,
-                )
-            nn.init.trunc_normal_(
-                self.w2.weight,
-                mean=0.0,
-                std=out_init_std,
-                a=-3 * out_init_std,
-                b=3 * out_init_std,
-            )
+            layer_init_kaiming_normal(self.w1)
+            layer_init_kaiming_normal(self.w3)
+            layer_init_kaiming_normal(self.w2)


class TransformerBlock(nn.Module):
@@ -901,11 +857,11 @@ def forward(
out = h + self.feed_forward(self.ffn_norm(h))
return out

-    def init_weights(self, init_std=None, factor=1.0):
-        self.attention.reset_parameters(init_std, factor)
+    def init_weights(self):
+        self.attention.reset_parameters()
self.attention_norm.reset_parameters()

-        self.feed_forward.reset_parameters(init_std, factor)
+        self.feed_forward.reset_parameters()
self.ffn_norm.reset_parameters()


@@ -1014,16 +970,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self.w1(F.silu(x))
return output

-    def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.in_dim ** (-0.5))
-        init_std = init_std / factor
-        nn.init.trunc_normal_(
-            self.w1.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
+    def reset_parameters(self):
+        layer_init_kaiming_normal(self.w1)


class TimestepEmbedder(nn.Module):
@@ -1033,10 +981,14 @@ class TimestepEmbedder(nn.Module):

def __init__(self, hidden_size: int, time_embedding_size: int = 256):
super().__init__()
-        self.w1 = nn.Linear(
-            time_embedding_size,
-            hidden_size,
-            bias=True,
+        self.mlp = nn.Sequential(
+            nn.Linear(
+                time_embedding_size,
+                hidden_size,
+                bias=True,
+            ), # mup: input weights
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True), # mup: hidden weights
)
self.w2 = nn.Linear(
hidden_size,
@@ -1046,6 +998,8 @@ def __init__(self, hidden_size: int, time_embedding_size: int = 256):
self.hidden_size = hidden_size
self.time_embedding_size = time_embedding_size

+        self.reset_parameters()

@staticmethod
def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000):
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
@@ -1065,30 +1019,12 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000):

def forward(self, t: torch.Tensor) -> torch.Tensor:
t_freq = self.timestep_embedding(t, self.time_embedding_size)
-        t_emb = self.w1(t_freq.to(self.w1.weight.dtype))
-        t_emb = self.w2(F.silu(t_emb))
+        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
return t_emb

-    def reset_parameters(self, init_std=None, factor=1.0):
-        in_init_std = init_std or (self.time_embedding_size ** (-0.5))
-        out_init_std = init_std or (self.hidden_size ** (-0.5))
-        out_init_std = out_init_std / factor
-        nn.init.trunc_normal_(
-            self.w1.weight,
-            mean=0.0,
-            std=in_init_std,
-            a=-3 * in_init_std,
-            b=3 * in_init_std,
-        )
-        nn.init.trunc_normal_(
-            self.w2.weight,
-            mean=0.0,
-            std=out_init_std,
-            a=-3 * out_init_std,
-            b=3 * out_init_std,
-        )
-        nn.init.normal_(self.w1.bias, std=0.02)
-        nn.init.normal_(self.w2.bias, std=0.02)
+    def reset_parameters(self):
+        layer_init_kaiming_normal(self.mlp[0])
+        layer_init_kaiming_normal(self.mlp[2])


class ImageEmbedder(nn.Module):
@@ -1102,21 +1038,12 @@ def __init__(self, in_dim, out_dim):
in_features=in_dim,
out_features=out_dim,
bias=True,
-        )
+        ) # mup: input weights
self.in_dim = in_dim

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w1(x)

-    def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.in_dim ** (-0.5))
-        init_std = init_std / factor
-        nn.init.trunc_normal_(
-            self.w1.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
-
-        nn.init.normal_(self.w1.bias, std=0.02)
+    def reset_parameters(self):
+        layer_init_kaiming_normal(self.w1)

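One way to sanity-check a muP integration like this end to end is the coordinate check that ships with mup: it records per-layer activation scales across widths, and the curves should stay roughly flat over training steps if the parametrization is correct. A sketch following the pattern in the mup README; build_castor, base_castor, delta_castor, and train_loader are hypothetical placeholders, not names from this repo:

    from mup import set_base_shapes
    from mup.coord_check import get_coord_data, plot_coord_data

    def lazy_model(width):
        # each entry must be a zero-argument callable that builds a fresh model
        return lambda: set_base_shapes(build_castor(width), base_castor, delta=delta_castor)

    models = {w: lazy_model(w) for w in (256, 512, 1024, 2048)}
    df = get_coord_data(models, train_loader)
    plot_coord_data(df, save_to='castor_mup_coord_check.png')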