diff --git a/src/transformers/models/deprecated/deta/modeling_deta.py b/src/transformers/models/deprecated/deta/modeling_deta.py index a5066958b6c6..c4f6f5c65ded 100644 --- a/src/transformers/models/deprecated/deta/modeling_deta.py +++ b/src/transformers/models/deprecated/deta/modeling_deta.py @@ -326,7 +326,7 @@ class DetaObjectDetectionOutput(ModelOutput): encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None encoder_attentions: Optional[tuple[torch.FloatTensor]] = None - enc_outputs_class = None + enc_outputs_class: Optional[torch.FloatTensor] = None enc_outputs_coord_logits: Optional[torch.FloatTensor] = None output_proposals: Optional[torch.FloatTensor] = None @@ -1857,7 +1857,7 @@ def __init__(self, config: DetaConfig): prior_prob = 0.01 bias_value = -math.log((1 - prior_prob) / prior_prob) - self.class_embed.bias.data = torch.ones(config.num_labels) * bias_value + self.class_embed.bias.data.fill_(bias_value) nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)