diff --git a/src/seamless_communication/models/unity/builder.py b/src/seamless_communication/models/unity/builder.py index 5da29d14..8e4f5864 100644 --- a/src/seamless_communication/models/unity/builder.py +++ b/src/seamless_communication/models/unity/builder.py @@ -293,8 +293,6 @@ def build_model(self) -> UnitYModel: text_encoder_frontend = None text_encoder = None - assert isinstance(text_embed.weight, Parameter) - final_proj = TiedProjection(text_embed.weight, bias=None) if self.t2u_builder is None: diff --git a/src/seamless_communication/models/unity/t2u_builder.py b/src/seamless_communication/models/unity/t2u_builder.py index 3da83638..b5ced1d7 100644 --- a/src/seamless_communication/models/unity/t2u_builder.py +++ b/src/seamless_communication/models/unity/t2u_builder.py @@ -11,7 +11,7 @@ from fairseq2.assets.card import AssetCard from fairseq2.data import VocabularyInfo from fairseq2.models.utils.arch_registry import ArchitectureRegistry -from fairseq2.nn.embedding import Embedding, StandardEmbedding +from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding from fairseq2.nn.position_encoder import SinusoidalPositionEncoder from fairseq2.nn.projection import TiedProjection from fairseq2.nn.transformer import ( @@ -242,8 +242,6 @@ def build_model(self) -> Union[UnitYT2UModel, UnitYNART2UModel]: decoder = self.build_decoder() - assert isinstance(embed_unit.weight, Parameter) - final_proj = TiedProjection(embed_unit.weight, bias=None) if self.config.nar_decoder_config is None: @@ -265,13 +263,13 @@ def build_model(self) -> Union[UnitYT2UModel, UnitYNART2UModel]: self.config.unit_pad_idx, ) - def build_unit_embedding(self) -> Embedding: + def build_unit_embedding(self) -> StandardEmbedding: """Build a unit embedding table.""" return StandardEmbedding( num_embeddings=self.config.unit_vocabulary_size, embedding_dim=self.config.model_dim, pad_idx=self.config.unit_pad_idx, - scaled=True, + init_fn=init_scaled_embedding, device=self.device, dtype=self.dtype, ) @@ -381,7 +379,7 @@ def build_nar_decoder_frontend(self, embed_unit: Embedding) -> NARDecoderFronten num_embeddings=self.config.nar_decoder_config.char_vocabulary_size, embedding_dim=self.config.model_dim, pad_idx=text_pad_idx, - scaled=True, + init_fn=init_scaled_embedding, device=self.device, dtype=self.dtype, )