diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c46422784e13..1687a9c7b024 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -94,6 +94,7 @@ download_url, extract_commit_hash, has_file, + init_empty_weights, is_accelerate_available, is_bitsandbytes_available, is_flash_attn_2_available, diff --git a/src/transformers/models/deprecated/deta/modeling_deta.py b/src/transformers/models/deprecated/deta/modeling_deta.py index a5066958b6c6..708e635e0799 100644 --- a/src/transformers/models/deprecated/deta/modeling_deta.py +++ b/src/transformers/models/deprecated/deta/modeling_deta.py @@ -1857,7 +1857,10 @@ def __init__(self, config: DetaConfig): prior_prob = 0.01 bias_value = -math.log((1 - prior_prob) / prior_prob) - self.class_embed.bias.data = torch.ones(config.num_labels) * bias_value + self.class_embed.bias.data = torch.ones(config.num_labels, + device=self.class_embed.bias.data.device, + dtype=self.class_embed.bias.data.dtype, + ) * bias_value nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)