nanoGPT/model.py, manual implementation of attention: is it correct to modify it like I did? #478

Open
wmx-github opened this issue Apr 27, 2024 · 1 comment

Comments

@wmx-github

I modified the manual attention path in `CausalSelfAttention.forward` to build the causal mask inside the forward pass (marked with the `# wmx add` comment):

```python
class CausalSelfAttention(nn.Module):

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # if False:  # swap with the line above to force the manual path while testing
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention

            # wmx add: causal mask of shape (B, nh, T, T); needs `import numpy as np`
            mask = torch.tensor(np.tril(np.ones((B, self.n_head, T, T))))

            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(mask[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
```
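For comparison, the stock nanoGPT `model.py` builds this causal mask once in `__init__` and registers it as the `bias` buffer, so it already has the model's device and dtype. Building it per forward pass with NumPy, as above, yields a float64 tensor on the CPU, which will typically raise a device-mismatch error once the model runs on a GPU (and adds a host-to-device copy every step). Below is a minimal sketch, not from the issue and with illustrative shapes and names, of constructing the same lower-triangular mask directly with `torch.tril` on the device of the attention scores:

```python
# Minimal sketch (not from the issue): an equivalent causal mask built with torch only,
# created on the same device as the attention scores. B, n_head, T and `att` are
# illustrative stand-ins for the values inside forward().
import torch

B, n_head, T = 2, 4, 8
att = torch.randn(B, n_head, T, T)            # pretend (q @ k^T) / sqrt(hs) scores

# (T, T) lower-triangular boolean mask; broadcasts over the (B, n_head) dims
causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=att.device))
att = att.masked_fill(~causal, float('-inf')) # block attention to future positions
att = torch.softmax(att, dim=-1)
```

On CPU this produces the same masking as the NumPy version above; the flash path (`scaled_dot_product_attention(..., is_causal=True)`) applies the same causal constraint internally.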

@Flecart

Flecart commented May 4, 2024

Can you please format it, add Python syntax highlighting, and show what you modified?
