Skip to content

Commit

Permalink
split pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
gpucce committed Jun 28, 2023
1 parent fb72f4d commit dce72e8
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions src/open_clip/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,13 @@ def __init__(

self.global_average_pool = global_average_pool
if attentional_pool:
self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
self.attn_pool_cls = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=1)
self.attn_pool_tokens = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
self.ln_post = norm_layer(output_dim)
self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
else:
self.attn_pool = None
self.attn_pool_cls = None
self.attn_pool_tokens = None
self.ln_post = norm_layer(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

Expand Down Expand Up @@ -486,20 +488,19 @@ def forward(self, x: torch.Tensor):
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD

if self.attn_pool is not None:
x = self.attn_pool(x)
x = self.ln_post(x)
pooled, tokens = self._global_pool(x)
else:
pooled, tokens = self._global_pool(x)
pooled = self.ln_post(pooled)
pooled, tokens = self._global_pool(x)
if self.attn_pool_cls is not None:
pooled = self.attn_pool_cls(pooled.unsqueeze(1)).squeeze(1)
tokens = self.attn_pool_tokens(tokens)

pooled = self.ln_post(pooled)

if self.proj is not None:
pooled = pooled @ self.proj

if self.output_tokens:
return pooled, tokens

return pooled


Expand Down

0 comments on commit dce72e8

Please sign in to comment.