Skip to content

Commit

Permalink
Clean up comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dianaml0 committed Jan 6, 2023
1 parent 5a3f0bb commit 9401934
Showing 1 changed file with 0 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ def forward(
.transpose(0, 1)
.reshape(seq_len, bsz, num_heads * head_dim)
)
# TODO: Reshape q/k/v back to original?
else:
q = q.view(seq_len, -1, head_dim)
k = k.view(seq_len, -1, head_dim)
Expand Down Expand Up @@ -396,11 +395,6 @@ def backward(ctx, grad_output):

# recalculate attention
if xf_eff_attn:
# TODO: reshape q/k/v?
# q = q.view(seq_len, bsz, -1, head_dim).transpose(0, 1)
# k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1)
# v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1)

num_heads = embed_dim_per_partition // head_dim

attn, lse = xops.memory_efficient_attention_forward_requires_grad(
Expand Down

0 comments on commit 9401934

Please sign in to comment.