From 8997ccf3f402a0960fc0aaafbb1c768c607189c9 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 27 Nov 2023 11:57:45 -0700 Subject: [PATCH 01/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 3f828c6d0a..83891e324f 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -42,7 +42,7 @@ def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activati self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) - self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) @@ -64,7 +64,7 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: tgt = tgt + self.dropout1(tgt) tgt = self.norm1(tgt) - tgt2 = self.multihead_attn(tgt, memory, memory)[0] + tgt2 = self.self_attn(tgt, memory, memory)[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) From 69f3634ee4790df2806ea25f5d426c56c9935372 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 27 Nov 2023 12:13:56 -0700 Subject: [PATCH 02/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 83891e324f..8e5d85e42f 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -59,9 +59,9 @@ def __setstate__(self, state): super(TransformerDecoderLayerOptimal, self).__setstate__(state) def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: + memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None, + memory_is_causal: bool = False) -> Tensor: tgt = tgt + self.dropout1(tgt) tgt = self.norm1(tgt) tgt2 = self.self_attn(tgt, memory, memory)[0] From 92f0b7fcabbce206e5d6f6a58ab424f9bc183ea4 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Fri, 8 Dec 2023 13:17:23 -0700 Subject: [PATCH 03/34] Update beit.py --- timm/models/beit.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/timm/models/beit.py b/timm/models/beit.py index 663dcc4bd4..953dd2602b 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -49,6 +49,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table +from timm.layers import NormMlpClassifierHead from ._builder import build_model_with_cfg @@ -335,12 +336,22 @@ def __init__( window_size=self.patch_embed.grid_size if use_rel_pos_bias else None, ) for i in range(depth)]) - - use_fc_norm = self.global_pool == 'avg' - self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) - self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() - self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(embed_dim, 
num_classes) if num_classes > 0 else nn.Identity() + + # pre_norm if using cls token features, fc_norm otherwise + # cls -> norm -> ... + # pool -> norm -> ... + #use_fc_norm = self.global_pool == 'avg' + #self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) + #self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + #self.head_drop = nn.Dropout(drop_rate) + #self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + norm_layer=norm_layer, + ) self.apply(self._init_weights) if self.pos_embed is not None: @@ -417,12 +428,9 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - if self.global_pool: - x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] - x = self.fc_norm(x) - x = self.head_drop(x) - return x if pre_logits else self.head(x) - + # feed in token outputs if pooling, otherwise take cls token features + x = x[:, self.num_prefix_tokens:] if self.global_pool else x[:, 0] + return self.head(x, pre_logits=True) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) x = self.forward_head(x) From d10f002120feff98c1976cbb8d93fbe36dfe6717 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Fri, 8 Dec 2023 13:33:02 -0700 Subject: [PATCH 04/34] Update beit.py --- timm/models/beit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/beit.py b/timm/models/beit.py index 953dd2602b..e57e908798 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -424,7 +424,7 @@ def forward_features(self, x): x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) else: x = blk(x, shared_rel_pos_bias=rel_pos_bias) - x = self.norm(x) + #x = self.norm(x) return x def forward_head(self, x, pre_logits: bool = False): From f59c1f16812806e4b6cd4f5614ceff5a49adef7e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Fri, 8 Dec 2023 13:41:46 -0700 Subject: [PATCH 05/34] Update beit.py --- timm/models/beit.py | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/timm/models/beit.py b/timm/models/beit.py index e57e908798..663dcc4bd4 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -49,7 +49,6 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table -from timm.layers import NormMlpClassifierHead from ._builder import build_model_with_cfg @@ -336,22 +335,12 @@ def __init__( window_size=self.patch_embed.grid_size if use_rel_pos_bias else None, ) for i in range(depth)]) - - # pre_norm if using cls token features, fc_norm otherwise - # cls -> norm -> ... - # pool -> norm -> ... 
- #use_fc_norm = self.global_pool == 'avg' - #self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) - #self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() - #self.head_drop = nn.Dropout(drop_rate) - #self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() - self.head = NormMlpClassifierHead( - self.num_features, - num_classes, - pool_type=global_pool, - drop_rate=drop_rate, - norm_layer=norm_layer, - ) + + use_fc_norm = self.global_pool == 'avg' + self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) if self.pos_embed is not None: @@ -424,13 +413,16 @@ def forward_features(self, x): x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) else: x = blk(x, shared_rel_pos_bias=rel_pos_bias) - #x = self.norm(x) + x = self.norm(x) return x def forward_head(self, x, pre_logits: bool = False): - # feed in token outputs if pooling, otherwise take cls token features - x = x[:, self.num_prefix_tokens:] if self.global_pool else x[:, 0] - return self.head(x, pre_logits=True) if pre_logits else self.head(x) + if self.global_pool: + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + def forward(self, x): x = self.forward_features(x) x = self.forward_head(x) From 318369a9bf905efa9f261a441a6dec95bdc7fbf1 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 10 Dec 2023 05:32:29 -0700 Subject: [PATCH 06/34] Update davit.py --- timm/models/davit.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index eb6492f962..feeb2a3f89 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -568,11 +568,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - x = self.head.global_pool(x) - x = self.head.norm(x) - x = self.head.flatten(x) - x = self.head.drop(x) - return x if pre_logits else self.head.fc(x) + return self.head(x, pre_logits=True) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) From 03606ab8627c3248cf6c371f292de3d0ada654a2 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 10 Dec 2023 15:18:54 -0700 Subject: [PATCH 07/34] Update edgenext.py --- timm/models/edgenext.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 2d511e6cbc..4c095041b8 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -375,13 +375,23 @@ def __init__( self.stages = nn.Sequential(*stages) self.num_features = dims[-1] - self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() - self.head = nn.Sequential(OrderedDict([ - ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), - ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), - ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), - ('drop', nn.Dropout(self.drop_rate)), - ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) + if head_norm_first: + self.norm_pre = norm_layer(self.num_features) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, 
+ drop_rate=self.drop_rate, + ) + else: + self.norm_pre = nn.Identity() + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + norm_layer=norm_layer, + ) named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) @@ -418,12 +428,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( - x = self.head.global_pool(x) - x = self.head.norm(x) - x = self.head.flatten(x) - x = self.head.drop(x) - return x if pre_logits else self.head.fc(x) + return self.head(x, pre_logits=True) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) From 4f8898a033f336577913082d6a6c05a2603f7cec Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 10 Dec 2023 15:18:54 -0700 Subject: [PATCH 08/34] Update edgenext.py --- timm/models/edgenext.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index d90471fb1b..3baf54767f 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -410,13 +410,23 @@ def __init__( self.stages = nn.Sequential(*stages) self.num_features = dims[-1] - self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() - self.head = nn.Sequential(OrderedDict([ - ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), - ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), - ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), - ('drop', nn.Dropout(self.drop_rate)), - ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) + if head_norm_first: + self.norm_pre = norm_layer(self.num_features) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + ) + else: + self.norm_pre = nn.Identity() + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + norm_layer=norm_layer, + ) named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) @@ -453,12 +463,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( - x = self.head.global_pool(x) - x = self.head.norm(x) - x = self.head.flatten(x) - x = self.head.drop(x) - return x if pre_logits else self.head.fc(x) + return self.head(x, pre_logits=True) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) From c0206b8fe89725b2ffd69ba7f0aa7fb0d6fe3c9f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 10 Dec 2023 15:23:38 -0700 Subject: [PATCH 09/34] Update edgenext.py --- timm/models/edgenext.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 3baf54767f..6d2ab8d5cb 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -18,6 +18,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d +from timm.layers import NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module from ._manipulate import named_apply, checkpoint_seq @@ 
-451,10 +452,7 @@ def get_classifier(self): return self.head.fc def reset_classifier(self, num_classes=0, global_pool=None): - if global_pool is not None: - self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() - self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) From 95bd9dcb3930a6b5d484f0dcd212c2865adc8090 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 13 Dec 2023 23:33:55 -0800 Subject: [PATCH 10/34] vectorize GroupFC --- timm/layers/ml_decoder.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 8e5d85e42f..ff7d498680 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -88,6 +88,7 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None # out = out.view((h.shape[0], self.group_size * self.num_queries)) # return out +''' @torch.jit.script class GroupFC(object): def __init__(self, embed_len_decoder: int): @@ -98,7 +99,19 @@ def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: h_i = h[:, i, :] w_i = duplicate_pooling[i, :, :] out_extrap[:, i, :] = torch.matmul(h_i, w_i) - +''' +class GroupFC(object): + def __init__(self, embed_len_decoder: int): + self.embed_len_decoder = embed_len_decoder + ''' + def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor): # [B, K, C], [K, C, N/K], [B, K, N/K] + for i in range(self.embed_len_decoder): + h_i = h[:, i, :] # [B, 1, C] + w_i = duplicate_pooling[i, :, :] # [1, C, N/K] + out_extrap[:, i, :] = torch.matmul(h_i, w_i) # [B, 1, N/K] + ''' + def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor): + out_extrap = (h.permute(1, 0, 2) @ duplicate_pooling).permute(1,0,2) class MLDecoder(nn.Module): def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048): From 1448b7dcee0173ef6e6c1864ba2589b63ecce2ce Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 13 Dec 2023 23:36:58 -0800 Subject: [PATCH 11/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index ff7d498680..57d2b442ce 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -161,8 +161,9 @@ def forward(self, x): h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1)) # [embed_len_decoder, batch, 768] h = h.transpose(0, 1) - out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype) - self.group_fc(h, self.duplicate_pooling, out_extrap) + #out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype) + #self.group_fc(h, self.duplicate_pooling, out_extrap) + out_extrap = (h.permute(1, 0, 2) @ duplicate_pooling).permute(1,0,2) # [B, K, N/K] h_out = out_extrap.flatten(1)[:, :self.num_classes] h_out += self.duplicate_pooling_bias logits = h_out From ab3eacbc860e1b1ed581d3091aa6ad29e3ba4051 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 13 Dec 2023 23:38:07 -0800 Subject: [PATCH 12/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 
57d2b442ce..8e72a68b51 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -163,7 +163,7 @@ def forward(self, x): #out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype) #self.group_fc(h, self.duplicate_pooling, out_extrap) - out_extrap = (h.permute(1, 0, 2) @ duplicate_pooling).permute(1,0,2) # [B, K, N/K] + out_extrap = (h.permute(1, 0, 2) @ self.duplicate_pooling).permute(1,0,2) # [B, K, N/K] h_out = out_extrap.flatten(1)[:, :self.num_classes] h_out += self.duplicate_pooling_bias logits = h_out From b4afe6cfc205ec3ec8f6989804069a5c9b0a9e62 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 14 Dec 2023 15:42:52 -0800 Subject: [PATCH 13/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 38 +++++++++++++++----------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 8e72a68b51..185818b7d2 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -88,34 +88,22 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None # out = out.view((h.shape[0], self.group_size * self.num_queries)) # return out -''' + @torch.jit.script class GroupFC(object): def __init__(self, embed_len_decoder: int): self.embed_len_decoder = embed_len_decoder - def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor): - for i in range(self.embed_len_decoder): - h_i = h[:, i, :] - w_i = duplicate_pooling[i, :, :] - out_extrap[:, i, :] = torch.matmul(h_i, w_i) -''' -class GroupFC(object): - def __init__(self, embed_len_decoder: int): - self.embed_len_decoder = embed_len_decoder - ''' def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor): # [B, K, C], [K, C, N/K], [B, K, N/K] for i in range(self.embed_len_decoder): h_i = h[:, i, :] # [B, 1, C] w_i = duplicate_pooling[i, :, :] # [1, C, N/K] out_extrap[:, i, :] = torch.matmul(h_i, w_i) # [B, 1, N/K] - ''' - def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor): - out_extrap = (h.permute(1, 0, 2) @ duplicate_pooling).permute(1,0,2) -class MLDecoder(nn.Module): - def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048): - super(MLDecoder, self).__init__() + +class MLDecoderLegacy(nn.Module): + def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048, simple_group_fc = True): + super(MLDecoderLegacy, self).__init__() embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups if embed_len_decoder > num_classes: embed_len_decoder = num_classes @@ -137,6 +125,7 @@ def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial self.query_embed.requires_grad_(False) # group fully-connected + self.simple_group_fc = simple_group_fc self.num_classes = num_classes self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999) self.duplicate_pooling = torch.nn.Parameter( @@ -144,7 +133,7 @@ def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes)) torch.nn.init.xavier_normal_(self.duplicate_pooling) torch.nn.init.constant_(self.duplicate_pooling_bias, 0) - self.group_fc = GroupFC(embed_len_decoder) + self.group_fc = None if simple_group_fc else GroupFC(embed_len_decoder) def forward(self, x): if len(x.shape) == 4: # [bs,2048, 7,7] @@ -159,11 +148,14 @@ def 
forward(self, x): # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1) tgt = query_embed.unsqueeze(1).expand(-1, bs, -1) # no allocation of memory with expand h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1)) # [embed_len_decoder, batch, 768] - h = h.transpose(0, 1) - - #out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype) - #self.group_fc(h, self.duplicate_pooling, out_extrap) - out_extrap = (h.permute(1, 0, 2) @ self.duplicate_pooling).permute(1,0,2) # [B, K, N/K] + + if(self.simple_group_fc): + out_extrap = (h @ self.duplicate_pooling).permute(1,0,2) # [B, K, N/K] + else: + h = h.transpose(0, 1) + out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype) + self.group_fc(h, self.duplicate_pooling, out_extrap) + h_out = out_extrap.flatten(1)[:, :self.num_classes] h_out += self.duplicate_pooling_bias logits = h_out From dd8d2310dec1132b2126c45857ef3f0c2ddc04bc Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 18:06:23 -0800 Subject: [PATCH 14/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 146 +++++++++++++++++++++++++++++++++++++- 1 file changed, 145 insertions(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 185818b7d2..2530596dba 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -1,9 +1,10 @@ from typing import Optional import torch -from torch import nn from torch import nn, Tensor +import torch.nn.functional as F from torch.nn.modules.transformer import _get_activation_fn +from timm.layers import Mlp def add_ml_decoder_head(model): @@ -160,3 +161,146 @@ def forward(self, x): h_out += self.duplicate_pooling_bias logits = h_out return logits + +class GroupLinear(nn.Module): + def __init__( + self, + dim + num_classes, + num_groups, + ): + super().__init__() + self.num_classes = num_classes + duplicate_factor = int(num_classes / num_groups + 0.999) + self.weight = nn.Parameter(torch.Tensor(num_groups, dim, duplicate_factor)) + self.bias = nn.Parameter(torch.Tensor(num_classes)) + nn.init.xavier_normal_(self.weight) + nn.init.constant_(self.bias, 0) + + def forward(self, x): + x = (x @ self.weight).permute(1, 0, 2).flatten(1)[:, :self.num_classes] + x += self.bias + return x + +class MLDecoder(nn.Module): + def __init__( + self, + in_features: int, + num_classes: int, + dim: int = 768, + num_groups: int = 100, + num_heads: int = 8, + embed_drop: float = 0.1, + embed_norm: bool = True, + k_norm: bool = False, + attn_drop: float = 0.1, + mlp_ratio: float = 8/3, + proj_drop: float = 0.1, + norm_layer: nn.Module = nn.LayerNorm, + act_layer: nn.Module = nn.GELU, + + ): + super().__init__() + + + # non-learnable queries + self.query_embed = nn.Embedding(num_groups, dim) + self.query_embed.requires_grad_(False) + self.embed_drop = nn.Dropout(embed_drop) + self.embed_norm = norm_layer(dim) + + self.norm1 = norm_layer(dim) + + + self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop) + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) +] + + def forward(self, x): + # BCHW to BNC + if(len(x.shape) = 4): + x = x.flatten(2).transpose(1, 2) + + + q = self.embed_norm(self.embed_drop(self.query_embed.weight)) + xN = self.norm1(x) + x = x + self.attn(q, xN, xN)[0] + x = x + self.mlp(self.norm2(x)) + + + +class CrossAttention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + ''' + q = self.embed_norm(self.embed_drop(self.query_embed.weight)) + q = self.q(q).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + else: + kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + ''' \ No newline at end of file From 51c85bd4f6547f68be430baa580e58e1e9ba70c2 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 19:53:11 -0800 Subject: [PATCH 15/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 2530596dba..9f27582d32 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -220,6 +220,7 @@ def __init__( act_layer=act_layer, drop=proj_drop, ) + self.fc = GroupLinear(dim, num_classes, num_groups) ] def forward(self, x): @@ -232,6 +233,7 @@ def forward(self, x): xN = self.norm1(x) x = x + self.attn(q, xN, xN)[0] x = x + self.mlp(self.norm2(x)) + x = self.fc(x) From 36d69a45905a1ba5cf1885a1a240f9b5178eb25b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 19:55:52 -0800 Subject: [PATCH 16/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 9f27582d32..66e9e8dee7 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -221,7 +221,6 @@ def __init__( drop=proj_drop, ) self.fc = GroupLinear(dim, num_classes, num_groups) -] def forward(self, x): # BCHW to BNC From 4161159715316d800294d201b92623ed5eccf689 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 19:57:07 -0800 Subject: [PATCH 17/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 66e9e8dee7..e47513d541 100644 --- a/timm/layers/ml_decoder.py 
+++ b/timm/layers/ml_decoder.py @@ -165,7 +165,7 @@ def forward(self, x): class GroupLinear(nn.Module): def __init__( self, - dim + dim, num_classes, num_groups, ): From aa19600082d67f678fd457908e7d41952a2170ef Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 19:58:08 -0800 Subject: [PATCH 18/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index e47513d541..b8a46214d4 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -224,7 +224,7 @@ def __init__( def forward(self, x): # BCHW to BNC - if(len(x.shape) = 4): + if(len(x.shape) == 4): x = x.flatten(2).transpose(1, 2) From 6b99c9129265ea8be23c56fd246f29462b44346f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 19:59:22 -0800 Subject: [PATCH 19/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index b8a46214d4..31d2b0cc2b 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -234,7 +234,7 @@ def forward(self, x): x = x + self.mlp(self.norm2(x)) x = self.fc(x) - +''' class CrossAttention(nn.Module): fused_attn: Final[bool] @@ -286,7 +286,7 @@ def forward(self, q, x) -> torch.Tensor: x = self.proj(x) x = self.proj_drop(x) return x - ''' + q = self.embed_norm(self.embed_drop(self.query_embed.weight)) q = self.q(q).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) else: @@ -304,4 +304,5 @@ def forward(self, q, x) -> torch.Tensor: attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = attn @ v - ''' \ No newline at end of file + +''' \ No newline at end of file From db24aea61ee7c39aad492eff21c40edf99a9ab1f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 20:04:59 -0800 Subject: [PATCH 20/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 31d2b0cc2b..fd93c1817d 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -209,6 +209,8 @@ def __init__( self.embed_drop = nn.Dropout(embed_drop) self.embed_norm = norm_layer(dim) + self.proj = nn.Linear(in_features, dim) + self.act = act_layer() self.norm1 = norm_layer(dim) @@ -227,7 +229,7 @@ def forward(self, x): if(len(x.shape) == 4): x = x.flatten(2).transpose(1, 2) - + x = self.act(self.proj(x)) q = self.embed_norm(self.embed_drop(self.query_embed.weight)) xN = self.norm1(x) x = x + self.attn(q, xN, xN)[0] From 09749be33474ff4d2a7e8cc68a936c5b1a186921 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 22:44:51 -0800 Subject: [PATCH 21/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 46 +++++++++++++-------------------------- 1 file changed, 15 insertions(+), 31 deletions(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index fd93c1817d..39fd9b2248 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -177,7 +177,7 @@ def __init__( nn.init.xavier_normal_(self.weight) nn.init.constant_(self.bias, 0) - def forward(self, x): + def forward(self, x): # [B, K, C] x = (x @ self.weight).permute(1, 0, 2).flatten(1)[:, :self.num_classes] x += self.bias return x @@ -214,7 +214,7 @@ def __init__( self.norm1 = norm_layer(dim) - self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop) + self.attn = CrossAttention(dim, 
num_heads=num_heads) self.norm2 = norm_layer(dim) self.mlp = Mlp( in_features=dim, @@ -231,12 +231,11 @@ def forward(self, x): x = self.act(self.proj(x)) q = self.embed_norm(self.embed_drop(self.query_embed.weight)) - xN = self.norm1(x) - x = x + self.attn(q, xN, xN)[0] + x = x + self.attn(q, self.norm1(x)) x = x + self.mlp(self.norm2(x)) x = self.fc(x) -''' + class CrossAttention(nn.Module): fused_attn: Final[bool] @@ -247,8 +246,8 @@ def __init__( num_heads: int = 8, qkv_bias: bool = True, qk_norm: bool = False, - attn_drop: float = 0., - proj_drop: float = 0., + attn_drop: float = 0.1, + proj_drop: float = 0.1, norm_layer: nn.Module = nn.LayerNorm, ) -> None: super().__init__() @@ -267,9 +266,11 @@ def __init__( self.proj_drop = nn.Dropout(proj_drop) def forward(self, q, x) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) + K, _ = q.shape # [K, C] + B, N, C = x.shape # [B, N, C] + q = self.q(q).reshape(1, K, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # [1, n_h, K, d_h] + kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [2, B, n_h, N, d_h] + k, v = kv.unbind(0) q, k = self.q_norm(q), self.k_norm(k) if self.fused_attn: @@ -279,32 +280,15 @@ def forward(self, q, x) -> torch.Tensor: ) else: q = q * self.scale - attn = q @ k.transpose(-2, -1) + attn = q @ k.transpose(-2, -1) # [B, n_h, K, N] attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = attn @ v + x = attn @ v # [B, n_h, K, d_h] - x = x.transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).reshape(B, K, C) x = self.proj(x) x = self.proj_drop(x) return x - q = self.embed_norm(self.embed_drop(self.query_embed.weight)) - q = self.q(q).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) - else: - kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - k, v = kv.unbind(0) - - if self.fused_attn: - x = F.scaled_dot_product_attention( - q, k, v, - dropout_p=self.attn_drop.p if self.training else 0., - ) - else: - q = q * self.scale - attn = q @ k.transpose(-2, -1) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = attn @ v + -''' \ No newline at end of file From 2a05a28a2323dfeff4092e4bb75878f72978f61d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 22:47:17 -0800 Subject: [PATCH 22/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 39fd9b2248..cc3dfff512 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -4,7 +4,9 @@ from torch import nn, Tensor import torch.nn.functional as F from torch.nn.modules.transformer import _get_activation_fn -from timm.layers import Mlp +from torch.jit import Final + +from timm.layers import Mlp, use_fused_attn def add_ml_decoder_head(model): From 50fe44f00dc9ba97289450a9fcfc40a08d5698a7 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 22:52:02 -0800 Subject: [PATCH 23/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index cc3dfff512..75db63df8b 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -233,7 +233,7 @@ def forward(self, x): x = self.act(self.proj(x)) q = self.embed_norm(self.embed_drop(self.query_embed.weight)) - x = x + self.attn(q, self.norm1(x)) + 
x = self.attn(q, self.norm1(x)) x = x + self.mlp(self.norm2(x)) x = self.fc(x) From 565a57adb112a9c02247e6cff70afb5bd72fca9f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 22:59:26 -0800 Subject: [PATCH 24/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 75db63df8b..3c5c756d71 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -287,7 +287,7 @@ def forward(self, q, x) -> torch.Tensor: attn = self.attn_drop(attn) x = attn @ v # [B, n_h, K, d_h] - x = x.transpose(1, 2).reshape(B, K, C) + x = x.permute(2, 0, 1, 3).reshape(K, B, C) x = self.proj(x) x = self.proj_drop(x) return x From 95a24fc020a31d0a547b8d2746e066a258784f78 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 23:00:45 -0800 Subject: [PATCH 25/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 3c5c756d71..a4376084fd 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -236,6 +236,7 @@ def forward(self, x): x = self.attn(q, self.norm1(x)) x = x + self.mlp(self.norm2(x)) x = self.fc(x) + return x From 7ebfb85990fbbf98e1f5ada8a049321a4c978841 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 23:11:37 -0800 Subject: [PATCH 26/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index a4376084fd..ab6839380f 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -105,7 +105,7 @@ def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: class MLDecoderLegacy(nn.Module): - def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048, simple_group_fc = True): + def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, in_features=2048, simple_group_fc = True): super(MLDecoderLegacy, self).__init__() embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups if embed_len_decoder > num_classes: From b2917f1d3959d7ccf9573fa1e056fb6873316b14 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 16 Dec 2023 23:16:50 -0800 Subject: [PATCH 27/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index ab6839380f..70be6a4ee3 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -113,7 +113,7 @@ def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, in_feat # switching to 768 initial embeddings decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding - self.embed_standart = nn.Linear(initial_num_features, decoder_embedding) + self.embed_standart = nn.Linear(in_features, decoder_embedding) # decoder decoder_dropout = 0.1 From e0c9f1448c7030fb6cd46f309bb95aa59d2851a9 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 17 Dec 2023 00:31:57 -0800 Subject: [PATCH 28/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 108 ++++++++++++++++++++------------------ 1 file changed, 56 insertions(+), 52 deletions(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 70be6a4ee3..fc32d831c0 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -164,6 +164,61 @@ def forward(self, x): logits = h_out 
return logits + +class CrossAttention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0.1, + proj_drop: float = 0.1, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x) -> torch.Tensor: + K, _ = q.shape # [K, C] + B, N, C = x.shape # [B, N, C] + q = self.q(q).reshape(1, K, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # [1, n_h, K, d_h] + kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [2, B, n_h, N, d_h] + k, v = kv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) # [B, n_h, K, N] + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v # [B, n_h, K, d_h] + + x = x.permute(2, 0, 1, 3).reshape(K, B, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + class GroupLinear(nn.Module): def __init__( self, @@ -233,65 +288,14 @@ def forward(self, x): x = self.act(self.proj(x)) q = self.embed_norm(self.embed_drop(self.query_embed.weight)) - x = self.attn(q, self.norm1(x)) + x = self.attn(q, self.norm1(x)) + q x = x + self.mlp(self.norm2(x)) x = self.fc(x) return x -class CrossAttention(nn.Module): - fused_attn: Final[bool] - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = True, - qk_norm: bool = False, - attn_drop: float = 0.1, - proj_drop: float = 0.1, - norm_layer: nn.Module = nn.LayerNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.fused_attn = use_fused_attn() - - self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, q, x) -> torch.Tensor: - K, _ = q.shape # [K, C] - B, N, C = x.shape # [B, N, C] - q = self.q(q).reshape(1, K, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # [1, n_h, K, d_h] - kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [2, B, n_h, N, d_h] - k, v = kv.unbind(0) - q, k = self.q_norm(q), self.k_norm(k) - - if self.fused_attn: - x = F.scaled_dot_product_attention( - q, k, v, - dropout_p=self.attn_drop.p if self.training else 0., - ) - else: - q = q * self.scale - attn = q @ k.transpose(-2, -1) # [B, n_h, K, N] - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = attn @ v # [B, n_h, K, d_h] - - x = x.permute(2, 0, 1, 
3).reshape(K, B, C) - x = self.proj(x) - x = self.proj_drop(x) - return x From 0124d0f4848dbb2f7cd9ad4ce58978545d7ff18c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 18:20:00 -0800 Subject: [PATCH 29/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index fc32d831c0..731f0804db 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -288,7 +288,7 @@ def forward(self, x): x = self.act(self.proj(x)) q = self.embed_norm(self.embed_drop(self.query_embed.weight)) - x = self.attn(q, self.norm1(x)) + q + x = self.attn(q, self.norm1(x))# + q.unsqueeze(1) x = x + self.mlp(self.norm2(x)) x = self.fc(x) return x From 5a39bb310914aadb2ac876fe10a90784e6feea7a Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 13:24:39 -0800 Subject: [PATCH 30/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 731f0804db..8066c61dfd 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -242,8 +242,8 @@ def forward(self, x): # [B, K, C] class MLDecoder(nn.Module): def __init__( self, - in_features: int, num_classes: int, + in_features: int, dim: int = 768, num_groups: int = 100, num_heads: int = 8, From b680c4811740120d23bbc69af6490b9ad61dd243 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 13:32:00 -0800 Subject: [PATCH 31/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 8066c61dfd..731f0804db 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -242,8 +242,8 @@ def forward(self, x): # [B, K, C] class MLDecoder(nn.Module): def __init__( self, - num_classes: int, in_features: int, + num_classes: int, dim: int = 768, num_groups: int = 100, num_heads: int = 8, From 3b506e9bfb9ce9da58e88b9b97a7f8370ac58e23 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 16:11:09 -0800 Subject: [PATCH 32/34] Update ml_decoder.py --- timm/layers/ml_decoder.py | 106 +++++++++++++++++++++++++++++++++----- 1 file changed, 93 insertions(+), 13 deletions(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 731f0804db..f6c0e48f38 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -7,29 +7,109 @@ from torch.jit import Final from timm.layers import Mlp, use_fused_attn +from timm.layers.classifier import _create_pool +class MLDecoderHead(nn.Module): + """MLDecoder wrapper with forward compatible with ClassifierHead""" + + def __init__(self, in_features, num_classes, pool_type='avg', use_conv=False, input_fmt='NCHW'): + super(MLDecoderHead, self).__init__() + self.in_features = in_features + self.use_conv = use_conv + self.input_fmt = input_fmt + + self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv, input_fmt=input_fmt) + self.head = MLDecoder(in_features=in_features, num_classes=num_classes) + self.flatten = nn.Flatten(1) if pool_type else nn.Identity() + + + def reset(self, num_classes, global_pool=None): + if global_pool is not None: + if global_pool != self.global_pool.pool_type: + self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv) + self.flatten = nn.Flatten(1) if self.use_conv and 
global_pool else nn.Identity() + num_pooled_features = self.in_features * self.global_pool.feat_mult() + self.head = MLDecoder(in_features=in_features, num_classes=num_classes) + + + def forward(self, x, pre_logits: bool = False): + # pool for compatibility with ClassifierHead + if self.input_fmt == 'NHWC': + x = x.permute(0, 3, 1, 2) + if pre_logits: + x = self.global_pool(x) + return x.flatten(1) + else: + x = self.head(x) + return self.flatten(x) + def add_ml_decoder_head(model): + + # ignore CoaT, crossvit + # ignore distillation models: deit_distilled, efficientformer V2 + num_classes = model.num_classes + num_features = model.num_features + + assert num_classes > 0, "MLDecoder requires a model to have num_classes > 0" + if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50 model.global_pool = nn.Identity() del model.fc - num_classes = model.num_classes - num_features = model.num_features - model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + + model.fc = MLDecoder(num_classes=num_classes, in_features=num_features) + + elif hasattr(model, 'fc_norm') or 'Cait' in model._get_name(): # ViT, BEiT, EVA + model.global_pool = None # disable any pooling, model instantiation leaves 1 norm layer after features, [B, n + K x K, C] + if hasattr(model, 'attn_pool'): + model.attn_pool = None + model.head_drop = nn.Identity() + model.head = MLDecoder(num_classes=num_classes, in_features=num_features) + + elif 'MetaFormer' in model._get_name(): + if hasattr(model.head, 'flatten'): # ConvNext case + model.head.flatten = nn.Identity() + model.head.global_pool = nn.Identity() + model.head.drop = nn.Identity() + del model.head.fc + model.head.fc = MLDecoder(num_classes=num_classes, in_features=num_features) + + # maybe and isinstance(model.head, (NormMlpClassifierHead, ClassifierHead) ? 
+ elif hasattr(model, 'head'): # ClassifierHead, nn.Sequential + input_fmt = getattr(model.head, 'input_fmt', 'NCHW') + model.head = MLDecoderHead(num_features, num_classes) + if hasattr(model, 'global_pool'): + if(isinstance(model.global_pool, nn.Module)): + model.global_pool = nn.Identity() + else: + model.global_pool = None + if hasattr(model, 'head_drop'): + model.head_drop = nn.Identity() + + elif 'MobileNetV3' in model._get_name(): # mobilenetv3 - conflict with efficientnet + + model.flatten = nn.Identity() + del model.classifier + model.classifier = MLDecoder(num_classes=num_classes, in_features=num_features) + elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet model.global_pool = nn.Identity() del model.classifier - num_classes = model.num_classes - num_features = model.num_features - model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features) - elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head') - del model.head - num_classes = model.num_classes - num_features = model.num_features - model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + model.classifier = MLDecoder(num_classes=num_classes, in_features=num_features) + elif hasattr(model, 'global_pool') and hasattr(model, 'last_linear'): # InceptionV4 + model.global_pool = nn.Identity() + del model.last_linear + model.last_linear = MLDecoder(num_classes=num_classes, in_features=num_features) + + elif hasattr(model, 'global_pool') and hasattr(model, 'classif'): # InceptionResnetV2 + model.global_pool = nn.Identity() + del model.classif + model.classif = MLDecoder(num_classes=num_classes, in_features=num_features) + else: - print("Model code-writing is not aligned currently with ml-decoder") - exit(-1) + raise Exception("Model code-writing is not aligned currently with ml-decoder") + + # FIXME does not work if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout model.drop_rate = 0 return model From 4f1b76bcd42de5855a44fcee6c73e52dd7792d9a Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 16:18:17 -0800 Subject: [PATCH 33/34] tests --- tests/test_layers.py | 15 +++++++++++++++ timm/layers/ml_decoder.py | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index 92f6b683d3..e76737dec7 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -76,3 +76,18 @@ def test_hard_swish_grad(): def test_hard_mish_grad(): for _ in range(100): _run_act_layer_grad('hard_mish') + + +MLDECODER_EXCLUDE_FILTERS = [ + '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', + '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', + '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', '*huge*', '*giant*', '*gigantic*', + '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560'] + +def test_ml_decoder(): + for modelName in timm.list_models(pretrained=False, exclude_filters = MLDECODER_EXCLUDE_FILTERS): + model = timm.create_model(modelName, num_classes=1000) + model = add_ml_decoder_head(model) + model.eval() + with torch.set_grad_enabled(False): + model(torch.randn([1,*model.default_cfg['input_size']])) \ No newline at end of file diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index f6c0e48f38..66fd2af5d5 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -45,7 +45,7 @@ def forward(self, x, pre_logits: bool = 
False): return self.flatten(x) def add_ml_decoder_head(model): - + #FIXME toggle between implementations? # ignore CoaT, crossvit # ignore distillation models: deit_distilled, efficientformer V2 num_classes = model.num_classes From 1bf6fbdf60b92aa82bc3cac77f5c1f672082b09e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 1 Jul 2024 20:23:05 -0700 Subject: [PATCH 34/34] Update ml_decoder.py allow external class embed (ex text embeddings of class descriptions), head version toggle --- timm/layers/ml_decoder.py | 105 +++++++++++++++++++++++++++----------- 1 file changed, 76 insertions(+), 29 deletions(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 66fd2af5d5..e409bd35a3 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -13,14 +13,14 @@ class MLDecoderHead(nn.Module): """MLDecoder wrapper with forward compatible with ClassifierHead""" - def __init__(self, in_features, num_classes, pool_type='avg', use_conv=False, input_fmt='NCHW'): + def __init__(self, head, in_features, num_classes, pool_type='avg', use_conv=False, input_fmt='NCHW'): super(MLDecoderHead, self).__init__() self.in_features = in_features self.use_conv = use_conv self.input_fmt = input_fmt self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv, input_fmt=input_fmt) - self.head = MLDecoder(in_features=in_features, num_classes=num_classes) + self.head = head self.flatten = nn.Flatten(1) if pool_type else nn.Identity() @@ -30,6 +30,7 @@ def reset(self, num_classes, global_pool=None): self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv) self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity() num_pooled_features = self.in_features * self.global_pool.feat_mult() + # TODO fix this it is incorrect, need to impl a reset for mldecoder itself i think self.head = MLDecoder(in_features=in_features, num_classes=num_classes) @@ -44,12 +45,18 @@ def forward(self, x, pre_logits: bool = False): x = self.head(x) return self.flatten(x) -def add_ml_decoder_head(model): - #FIXME toggle between implementations? 
+def add_ml_decoder_head(model, head_version='new', **kwargs): # ignore CoaT, crossvit # ignore distillation models: deit_distilled, efficientformer V2 num_classes = model.num_classes num_features = model.num_features + + if head_version == 'old': + head_fn = MLDecoderLegacy + else: + head_fn = MLDecoder + + head = head_fn(num_features, num_classes, **kwargs) assert num_classes > 0, "MLDecoder requires a model to have num_classes > 0" @@ -57,14 +64,14 @@ def add_ml_decoder_head(model): model.global_pool = nn.Identity() del model.fc - model.fc = MLDecoder(num_classes=num_classes, in_features=num_features) + model.fc = head elif hasattr(model, 'fc_norm') or 'Cait' in model._get_name(): # ViT, BEiT, EVA model.global_pool = None # disable any pooling, model instantiation leaves 1 norm layer after features, [B, n + K x K, C] if hasattr(model, 'attn_pool'): model.attn_pool = None model.head_drop = nn.Identity() - model.head = MLDecoder(num_classes=num_classes, in_features=num_features) + model.head = head elif 'MetaFormer' in model._get_name(): if hasattr(model.head, 'flatten'): # ConvNext case @@ -72,12 +79,12 @@ def add_ml_decoder_head(model): model.head.global_pool = nn.Identity() model.head.drop = nn.Identity() del model.head.fc - model.head.fc = MLDecoder(num_classes=num_classes, in_features=num_features) + model.head.fc = head # maybe and isinstance(model.head, (NormMlpClassifierHead, ClassifierHead) ? elif hasattr(model, 'head'): # ClassifierHead, nn.Sequential input_fmt = getattr(model.head, 'input_fmt', 'NCHW') - model.head = MLDecoderHead(num_features, num_classes) + model.head = MLDecoderHead(head, num_features, num_classes) if hasattr(model, 'global_pool'): if(isinstance(model.global_pool, nn.Module)): model.global_pool = nn.Identity() @@ -90,21 +97,21 @@ def add_ml_decoder_head(model): model.flatten = nn.Identity() del model.classifier - model.classifier = MLDecoder(num_classes=num_classes, in_features=num_features) + model.classifier = head elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet model.global_pool = nn.Identity() del model.classifier - model.classifier = MLDecoder(num_classes=num_classes, in_features=num_features) + model.classifier = head elif hasattr(model, 'global_pool') and hasattr(model, 'last_linear'): # InceptionV4 model.global_pool = nn.Identity() del model.last_linear - model.last_linear = MLDecoder(num_classes=num_classes, in_features=num_features) + model.last_linear = head elif hasattr(model, 'global_pool') and hasattr(model, 'classif'): # InceptionResnetV2 model.global_pool = nn.Identity() del model.classif - model.classif = MLDecoder(num_classes=num_classes, in_features=num_features) + model.classif = head else: raise Exception("Model code-writing is not aligned currently with ml-decoder") @@ -185,26 +192,33 @@ def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: class MLDecoderLegacy(nn.Module): - def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, in_features=2048, simple_group_fc = True): + def __init__( + self, + in_features: int, + num_classes: int, + dim: int = 768, + num_groups: int = 100, + simple_group_fc = True, + ): super(MLDecoderLegacy, self).__init__() - embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups + embed_len_decoder = 100 if num_groups < 0 else num_groups if embed_len_decoder > num_classes: embed_len_decoder = num_classes # switching to 768 initial embeddings - decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding - 
self.embed_standart = nn.Linear(in_features, decoder_embedding) + dim = 768 if dim < 0 else dim + self.embed_standart = nn.Linear(in_features, dim) # decoder decoder_dropout = 0.1 num_layers_decoder = 1 dim_feedforward = 2048 - layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding, + layer_decode = TransformerDecoderLayerOptimal(d_model=dim, dim_feedforward=dim_feedforward, dropout=decoder_dropout) self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder) # non-learnable queries - self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding) + self.query_embed = nn.Embedding(embed_len_decoder, dim) self.query_embed.requires_grad_(False) # group fully-connected @@ -212,7 +226,7 @@ def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, in_feat self.num_classes = num_classes self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999) self.duplicate_pooling = torch.nn.Parameter( - torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor)) + torch.Tensor(embed_len_decoder, dim, self.duplicate_factor)) self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes)) torch.nn.init.xavier_normal_(self.duplicate_pooling) torch.nn.init.constant_(self.duplicate_pooling_bias, 0) @@ -251,6 +265,7 @@ class CrossAttention(nn.Module): def __init__( self, dim: int, + query_dim: Optional[int] = None, num_heads: int = 8, qkv_bias: bool = True, qk_norm: bool = False, @@ -264,8 +279,9 @@ def __init__( self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() + self.query_dim = dim if query_dim is None else query_dim - self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.q = nn.Linear(query_dim, dim, bias=qkv_bias) self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() @@ -274,7 +290,7 @@ def __init__( self.proj_drop = nn.Dropout(proj_drop) def forward(self, q, x) -> torch.Tensor: - K, _ = q.shape # [K, C] + K, _ = q.shape # [K, C_q] B, N, C = x.shape # [B, N, C] q = self.q(q).reshape(1, K, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # [1, n_h, K, d_h] kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [2, B, n_h, N, d_h] @@ -327,6 +343,10 @@ def __init__( dim: int = 768, num_groups: int = 100, num_heads: int = 8, + class_embed: Optional[torch.Tensor] = None, + concat_class_embed: bool = True, + learnable_embed: bool = False, + learnable_class_embed: bool = False, embed_drop: float = 0.1, embed_norm: bool = True, k_norm: bool = False, @@ -338,20 +358,46 @@ def __init__( ): super().__init__() - - - # non-learnable queries - self.query_embed = nn.Embedding(num_groups, dim) - self.query_embed.requires_grad_(False) + have_class_embed = class_embed is not None + self.concat_class_embed = have_class_embed and concat_class_embed + self.class_embed = None + self.query_embed = None + self.query_dim = 0 + if have_class_embed: + assert len(class_embed) == num_classes, 'ML-Decoder got class_embed where dim 0 != num_classes' + class_embed = class_embed.clone().detach() # copy instead of reference, detach gradient flow + self.query_dim += class_embed.shape[1] + duplicate_factor = int(num_classes / num_groups + 0.999) + class_embed_pad_length = (duplicate_factor - num_classes % duplicate_factor) % duplicate_factor + + # pad and reshape into groups + class_embed = torch.cat([class_embed, 
torch.zeros(class_embed_pad_length, class_embed.shape[1])]) + class_embed = class_embed.reshape(num_groups, duplicate_factor, -1) + + # reduce each group to a single embed with mean + class_embed = class_embed.mean(1) + self.class_embed = nn.Embedding.from_pretrained(class_embed) + + + # TODO can use tensor instead of nn.Embedding and simply register as either a parameter or a buffer for learnability + self.class_embed.requires_grad_(learnable_class_embed) + + # case no class embed or using both + if not have_class_embed or concat_class_embed: + self.query_dim += dim + self.query_embed = nn.Embedding(num_groups, dim) + # TODO can use tensor instead of nn.Embedding and simply register as either a parameter or a buffer for learnability + self.query_embed.requires_grad_(learnable_embed) + self.embed_drop = nn.Dropout(embed_drop) - self.embed_norm = norm_layer(dim) + self.embed_norm = norm_layer(self.query_dim) self.proj = nn.Linear(in_features, dim) self.act = act_layer() self.norm1 = norm_layer(dim) - self.attn = CrossAttention(dim, num_heads=num_heads) + self.attn = CrossAttention(dim, query_dim=self.query_dim, num_heads=num_heads) self.norm2 = norm_layer(dim) self.mlp = Mlp( in_features=dim, @@ -367,7 +413,8 @@ def forward(self, x): x = x.flatten(2).transpose(1, 2) x = self.act(self.proj(x)) - q = self.embed_norm(self.embed_drop(self.query_embed.weight)) + q = torch.cat([x.weight for x in [self.query_embed, self.class_embed] if x is not None], dim=1) + q = self.embed_norm(self.embed_drop(q)) x = self.attn(q, self.norm1(x))# + q.unsqueeze(1) x = x + self.mlp(self.norm2(x)) x = self.fc(x)
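
Usage sketch (not prescribed by the patches themselves): the snippet below mirrors the test added in [PATCH 33/34] and the head_version / class_embed options introduced in [PATCH 34/34]. The direct timm.layers.ml_decoder import path, the resnet50 backbone, and the my_text_embeds name are assumptions for illustration only.

    import torch
    import timm
    from timm.layers.ml_decoder import add_ml_decoder_head  # assumed module path (timm/layers/ml_decoder.py)

    # create any supported backbone, then swap its classifier for the ML-Decoder head
    model = timm.create_model('resnet50', num_classes=1000)
    model = add_ml_decoder_head(model)  # reworked head (default, head_version='new')
    # model = add_ml_decoder_head(model, head_version='old')          # legacy GroupFC-style implementation
    # model = add_ml_decoder_head(model, class_embed=my_text_embeds)  # optional external [num_classes, C] class embeddings (hypothetical tensor)

    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, *model.default_cfg['input_size']))
    print(logits.shape)  # torch.Size([1, 1000])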