Description
Excuse me, I am learning the code of class StableAudioDiTModel, and I do not know what the argument global_states_input_dim is used for. It seems to be a required component that must be prepended to the hidden_states sequence, and its default dimension seems larger than the transformer inner_dim. What does that component mean? If it is used to take in additional conditions, that could seemingly be done in the encoder outside the model. And compared with concatenating along the sequence, I think it may be better to repeat the condition embedding to the sequence length and concatenate it on the hidden dim.
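To make the comparison concrete, here is a small numpy sketch of the two conditioning strategies the question contrasts. The names and dimensions (`global_dim`, `inner_dim`, the projection matrix) are hypothetical illustrations of the shape logic, not the actual diffusers implementation:

```python
import numpy as np

batch, seq_len, inner_dim = 2, 16, 8
global_dim = 12  # stands in for global_states_input_dim; may exceed inner_dim

hidden_states = np.zeros((batch, seq_len, inner_dim))
global_states = np.zeros((batch, global_dim))

# Strategy A: project the global condition to inner_dim and prepend it
# as one extra token on the sequence axis (what the question describes).
proj = np.random.randn(global_dim, inner_dim)  # stands in for a learned Linear
global_token = (global_states @ proj)[:, None, :]         # (batch, 1, inner_dim)
prepended = np.concatenate([global_token, hidden_states], axis=1)
assert prepended.shape == (batch, seq_len + 1, inner_dim)

# Strategy B: repeat the condition along the sequence and concatenate on
# the hidden axis (the alternative the question suggests); this grows the
# per-token feature size instead of the sequence length.
repeated = np.repeat(global_states[:, None, :], seq_len, axis=1)
concat_hidden = np.concatenate([hidden_states, repeated], axis=-1)
assert concat_hidden.shape == (batch, seq_len, inner_dim + global_dim)
```

Note the trade-off visible in the shapes: prepending adds one token and keeps the width at inner_dim, while repeat-and-concat keeps the length but widens every token to inner_dim + global_dim.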
Also, what is the sample_size: int = 1024 parameter used for in model creation? It does not seem to be used during the forward call.
The docstring of StableAudioDiTModel.forward says encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):. Why is the shape of encoder_attention_mask given as batch_size x sequence_len instead of batch_size x encoder_sequence_len, which would be consistent with the shape of the input encoder_hidden_states?
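The shape question can be checked with a small numpy sketch (all dimensions here are made-up illustrations): a cross-attention mask carries one entry per encoder token, so its second axis should match the sequence axis of encoder_hidden_states rather than that of hidden_states.

```python
import numpy as np

batch, encoder_seq_len, cross_dim = 2, 10, 8
seq_len = 16  # length of the hidden_states sequence, for contrast

encoder_hidden_states = np.zeros((batch, encoder_seq_len, cross_dim))

# One boolean per encoder token: shape (batch, encoder_seq_len).
encoder_attention_mask = np.ones((batch, encoder_seq_len), dtype=bool)

# The mask's trailing axis must line up with the encoder sequence axis,
# not with the hidden_states sequence length.
assert encoder_attention_mask.shape == encoder_hidden_states.shape[:2]
assert encoder_attention_mask.shape[1] != seq_len
```

This is why the docstring's `(batch_size, sequence_len)` reads as ambiguous: the tensor it describes is indexed by encoder positions.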
And why is the return value of this forward the raw (hidden_states,) rather than (hidden_states * attention_mask,)?