Skip to content

Commit

Permalink
Use embed.init_fn instead of scaled=True (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu committed Oct 10, 2023
1 parent c7f576a commit 9f142ac
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
2 changes: 0 additions & 2 deletions src/seamless_communication/models/unity/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,6 @@ def build_model(self) -> UnitYModel:
text_encoder_frontend = None
text_encoder = None

assert isinstance(text_embed.weight, Parameter)

final_proj = TiedProjection(text_embed.weight, bias=None)

if self.t2u_builder is None:
Expand Down
10 changes: 4 additions & 6 deletions src/seamless_communication/models/unity/t2u_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fairseq2.assets.card import AssetCard
from fairseq2.data import VocabularyInfo
from fairseq2.models.utils.arch_registry import ArchitectureRegistry
from fairseq2.nn.embedding import Embedding, StandardEmbedding
from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
from fairseq2.nn.projection import TiedProjection
from fairseq2.nn.transformer import (
Expand Down Expand Up @@ -242,8 +242,6 @@ def build_model(self) -> Union[UnitYT2UModel, UnitYNART2UModel]:

decoder = self.build_decoder()

assert isinstance(embed_unit.weight, Parameter)

final_proj = TiedProjection(embed_unit.weight, bias=None)

if self.config.nar_decoder_config is None:
Expand All @@ -265,13 +263,13 @@ def build_model(self) -> Union[UnitYT2UModel, UnitYNART2UModel]:
self.config.unit_pad_idx,
)

def build_unit_embedding(self) -> Embedding:
def build_unit_embedding(self) -> StandardEmbedding:
"""Build a unit embedding table."""
return StandardEmbedding(
num_embeddings=self.config.unit_vocabulary_size,
embedding_dim=self.config.model_dim,
pad_idx=self.config.unit_pad_idx,
scaled=True,
init_fn=init_scaled_embedding,
device=self.device,
dtype=self.dtype,
)
Expand Down Expand Up @@ -381,7 +379,7 @@ def build_nar_decoder_frontend(self, embed_unit: Embedding) -> NARDecoderFronten
num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
embedding_dim=self.config.model_dim,
pad_idx=text_pad_idx,
scaled=True,
init_fn=init_scaled_embedding,
device=self.device,
dtype=self.dtype,
)
Expand Down

0 comments on commit 9f142ac

Please sign in to comment.