From 8c0b98adfca6604fdb61999fdebdf39141da2512 Mon Sep 17 00:00:00 2001 From: Penn Date: Thu, 21 Jul 2022 11:35:03 -0700 Subject: [PATCH 1/3] fix inefficient attention computation --- clip/model.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/clip/model.py b/clip/model.py index 3121dd75d..b563413e2 100644 --- a/clip/model.py +++ b/clip/model.py @@ -67,10 +67,10 @@ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: def forward(self, x): x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( - query=x, key=x, value=x, + query=x[:1], key=x, value=x, embed_dim_to_check=x.shape[-1], num_heads=self.num_heads, q_proj_weight=self.q_proj.weight, @@ -88,8 +88,7 @@ def forward(self, x): training=self.training, need_weights=False ) - - return x[0] + return x.squeeze(0) class ModifiedResNet(nn.Module): From 5eb1828ead57f4dfa49e24eef48aa70bdc804423 Mon Sep 17 00:00:00 2001 From: Penn Date: Thu, 21 Jul 2022 11:53:46 -0700 Subject: [PATCH 2/3] remove erroneous formatting --- clip/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clip/model.py b/clip/model.py index b563413e2..f264b6350 100644 --- a/clip/model.py +++ b/clip/model.py @@ -67,8 +67,8 @@ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: def forward(self, x): x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( query=x[:1], key=x, value=x, embed_dim_to_check=x.shape[-1], From 2fc776c0f9e62a9f795ed1a0552c84fb1ac644db Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Thu, 21 Jul 2022 12:51:00 -0700 Subject: [PATCH 3/3] simplified flatten --- clip/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clip/model.py b/clip/model.py index f264b6350..808bf16f3 100644 --- a/clip/model.py +++ b/clip/model.py @@ -66,7 +66,7 @@ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: self.num_heads = num_heads def forward(self, x): - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward(