In [None]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no masking).

    Queries, keys and values are all projected from the same input ``x``,
    so every position attends to every other position.

    Args:
        input_size: Feature dimension ``D`` of the input and output; must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads ``H``; each head operates on
            ``D // H`` features.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not evenly
            divide ``input_size``.
    """

    def __init__(self, input_size, num_heads):
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if input_size % num_heads != 0:
            raise ValueError(
                f"input_size ({input_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_size = input_size // num_heads
        # Each projection maps the full D-dim features; heads are carved
        # out of the projected result afterwards.
        self.linear_q = nn.Linear(input_size, input_size)
        self.linear_k = nn.Linear(input_size, input_size)
        self.linear_v = nn.Linear(input_size, input_size)
        self.linear_out = nn.Linear(input_size, input_size)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape ``(B, S, D)`` into ``(B * H, S, head_size)`` for ``bmm``."""
        t = t.view(batch_size, seq_len, self.num_heads, self.head_size)
        # (B, S, H, hs) -> (B, H, S, hs) -> (B*H, S, hs); reshape copies as
        # needed because the transpose makes the tensor non-contiguous.
        return t.transpose(1, 2).reshape(
            batch_size * self.num_heads, seq_len, self.head_size
        )

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, seq_len, input_size)``.
        """
        batch_size, seq_len, input_size = x.size()

        # Project on the full feature dimension first, then split into heads.
        # (Splitting before projecting would feed head_size-wide features to
        # Linear layers that expect input_size-wide ones.)
        q = self._split_heads(self.linear_q(x), batch_size, seq_len)
        k = self._split_heads(self.linear_k(x), batch_size, seq_len)
        v = self._split_heads(self.linear_v(x), batch_size, seq_len)

        # Scaled dot-product scores per head: (B*H, S, S).
        scores = torch.bmm(q, k.transpose(1, 2)) / (self.head_size ** 0.5)
        attn = torch.softmax(scores, dim=-1)

        # Weighted sum of values, then merge heads back to (B, S, D).
        out = torch.bmm(attn, v)
        out = out.view(batch_size, self.num_heads, seq_len, self.head_size)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, input_size)
        return self.linear_out(out)
