Your task is to modify the custom implementation of MultiHeadAttention. At present, every token in a sequence can attend to every other token.  


Your job is to change this behavior in a specific way.
Let $S$ be our input sequence of length $2 \cdot k$:
- tokens at positions $i \le k$ should attend to the prefix of $S$ of length $k$ ($S[:k]$)
- tokens at positions $i \gt k$ should attend to the prefix of $S$ of length $i$ ($S[:i]$)

(Note: You can assume the sequence length is always an even number)

In [None]:
import torch
import math
import torch.nn.functional as F
class MultiHeadAttention(torch.nn.Module):
    """Multi-head self-attention with a prefix-then-causal attention mask.

    With ``k = seq_len // 2`` and 0-based indices, query ``i`` may attend to
    key ``j`` only when ``j < max(i + 1, k)``:

    - queries in the first half (``i < k``) see the whole first half
      (keys ``0 .. k - 1``);
    - queries in the second half (``i >= k``) see keys ``0 .. i``
      (themselves included).
    """

    def __init__(self, d_model, num_heads, d_head):
        """Create the query, key, value and output projections.

        Args:
            d_model: Size of the last dimension of the input and the output.
            num_heads: Number of attention heads.
            d_head: Size of each head's query/key/value vectors.
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_head

        self.W_Q = torch.nn.Linear(d_model, num_heads * d_head, bias=True)
        self.W_K = torch.nn.Linear(d_model, num_heads * d_head, bias=True)
        self.W_V = torch.nn.Linear(d_model, num_heads * d_head, bias=True)
        self.W_O = torch.nn.Linear(num_heads * d_head, d_model, bias=True)

    def forward(self, x):
        """Apply masked multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(seq_len, batch_size, d_model)``.

        Returns:
            A tuple ``(result, weights)`` where ``result`` has shape
            ``(seq_len, batch_size, d_model)`` and ``weights`` holds the
            attention probabilities with shape
            ``(batch_size, num_heads, seq_len, seq_len)``.
        """
        seq_len, batch_size, _ = x.shape
        head_shape = (seq_len, batch_size, self.num_heads, self.d_head)

        Q = self.W_Q(x).reshape(head_shape)
        K = self.W_K(x).reshape(head_shape)
        V = self.W_V(x).reshape(head_shape)

        # shape of scaled_QK is (batch_size, num_heads, seq_len, seq_len)
        scaled_QK = torch.einsum("ibhd,jbhd->bhij", Q, K) / math.sqrt(self.d_head)

        # Build the mask on the input's device so the layer also works when
        # the module and its input live on an accelerator.
        k = seq_len // 2
        positions = torch.arange(seq_len, device=x.device)
        # Number of leading keys each query may see: k for the first half,
        # i + 1 (causal, including itself) for the second half.
        visible = torch.clamp(positions + 1, min=k)
        # blocked[i, j] is True where query i must NOT attend to key j.
        blocked = positions.unsqueeze(0) >= visible.unsqueeze(1)
        scaled_QK = scaled_QK.masked_fill(blocked, float("-inf"))

        weights = F.softmax(scaled_QK, -1)
        attention = torch.einsum("bhij,jbhd->ibhd", weights, V)

        result = self.W_O(
            attention.reshape(seq_len, batch_size, self.num_heads * self.d_head)
        )

        return result, weights

In [None]:
# Smoke test: run the attention layer on one random batch and print what it
# returns. The input sequence has 2 * k tokens.
d_model, num_heads, d_head = 4, 4, 2
k, batch_size = 3, 3

with torch.no_grad():
    # The layer is created before the input is drawn, so both consume the
    # random number generator in this order.
    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads, d_head=d_head)
    batched_x = torch.randn(2 * k, batch_size, d_model)
    result, weights = mha(batched_x)
print("Result:", result)
print("Weights:", weights)

Result: tensor([[[-1.2155e-01,  2.6393e-01,  2.8529e-01,  8.7492e-03],
         [-3.0860e-01,  7.6529e-02,  7.5335e-01, -9.3537e-02],
         [-3.2243e-01,  3.0948e-01,  4.6269e-01,  1.4608e-03]],

        [[-9.7908e-02,  2.5062e-01,  3.2149e-01,  1.7905e-03],
         [-5.0542e-01,  1.9742e-01,  5.0122e-01, -2.4212e-03],
         [-2.9994e-01,  3.1374e-01,  4.0071e-01,  1.3043e-02]],

        [[-8.6501e-02,  2.4332e-01,  3.4093e-01, -4.7499e-03],
         [-5.8198e-01, -5.3442e-03,  8.7805e-01, -1.1195e-01],
         [-3.1591e-01,  2.6517e-01,  4.2930e-01,  3.7990e-03]],

        [[-6.2012e-02,  2.6314e-01,  3.1483e-01,  4.7588e-04],
         [-4.3939e-01,  1.0091e-01,  6.7119e-01, -6.6481e-02],
         [ 5.2178e-02,  3.4111e-01,  5.5215e-02,  2.5473e-02]],

        [[-1.3085e-01,  3.0381e-01,  2.6836e-01,  2.4589e-02],
         [-5.6368e-01,  7.1615e-02,  7.3330e-01, -6.9056e-02],
         [ 1.3761e-02,  3.3741e-01,  2.3844e-01,  1.4618e-02]],

        [[-1.2987e-01,  3.1723e-01,  