MultiHeadAttention to return qk as well
jongwook committed Dec 30, 2022
1 parent 9323b25 commit 5380767
Showing 1 changed file with 7 additions and 6 deletions.
whisper/model.py (13 changes: 7 additions & 6 deletions)
@@ -82,8 +82,8 @@ def forward(
             k = kv_cache[self.key]
             v = kv_cache[self.value]

-        wv = self.qkv_attention(q, k, v, mask)
-        return self.out(wv)
+        wv, qk = self.qkv_attention(q, k, v, mask)
+        return self.out(wv), qk

     def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
         n_batch, n_ctx, n_state = q.shape
@@ -95,9 +95,10 @@ def qkv_attention(q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor]
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
+        qk = qk.float()

-        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+        w = F.softmax(qk, dim=-1).to(q.dtype)
+        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()


 class ResidualAttentionBlock(nn.Module):
@@ -121,9 +122,9 @@ def forward(
         mask: Optional[Tensor] = None,
         kv_cache: Optional[dict] = None,
     ):
-        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
         if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
         x = x + self.mlp(self.mlp_ln(x))
         return x

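After this change, MultiHeadAttention.forward returns a tuple of the projected attention output and the detached, float()-cast pre-softmax scores qk, and ResidualAttentionBlock indexes [0] to keep its residual path unchanged. The following is a minimal usage sketch of the new calling convention; the constructor signature MultiHeadAttention(n_state, n_head) and the tensor sizes are assumptions drawn from the rest of whisper/model.py, which is not shown in this diff.

import torch
from whisper.model import MultiHeadAttention

# Assumed constructor signature (defined elsewhere in whisper/model.py).
attn = MultiHeadAttention(n_state=512, n_head=8)
x = torch.randn(1, 100, 512)  # (batch, n_ctx, n_state), arbitrary example sizes

# forward() now returns both the projected output and the detached,
# float()-cast, pre-softmax attention scores qk.
out, qk = attn(x)
print(out.shape)  # torch.Size([1, 100, 512])
print(qk.shape)   # torch.Size([1, 8, 100, 100]), i.e. (batch, n_head, n_ctx, n_ctx)

# Callers that only need the output, as ResidualAttentionBlock does above,
# take the first element of the tuple:
out = attn(x)[0]

Because qk is detached before being returned, it can be inspected (for example, to look at attention patterns) without affecting gradient flow through the block.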