In [8]:
import torch
import torch.nn as nn

In [9]:
import onnxruntime as ort
import math

In [10]:
import copy

In [11]:
seed = 25

torch.manual_seed(seed)

<torch._C.Generator at 0x7f5d0fbdbb50>

In [12]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, mask=None):
        super(MultiHeadAttention, self).__init__()
        self.mask = mask
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        
        # Initialize weight matrices
        self.in_weights_q = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
        self.in_weights_k = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
        self.in_weights_v = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
        
        self.out_weights = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
        
        
    def forward(self, q, k , v):
        batch_size, seq_len, _ = q.shape
        
        # Project inputs to queries, keys and values
        q = torch.matmul(q, self.in_weights_q.T)   
        k = torch.matmul(k, self.in_weights_k.T)
        v = torch.matmul(v, self.in_weights_v.T)

        # Reshape for multi-head attention
        q = q.reshape(batch_size, seq_len * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.reshape(batch_size, seq_len * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.reshape(batch_size, seq_len * self.num_heads, self.head_dim).transpose(0, 1)
        
        # Calculate scaled dot-product attention
        scale = 1.0 / math.sqrt(self.head_dim)
        #print(q.shape, k.transpose(1, 2).shape)
        scores = torch.bmm(q, k.transpose(1, 2))
        attn = torch.softmax(scores * scale, dim=-1)
        #print(scores.shape)
        
        # Apply attention to values
        output = torch.bmm(attn, v)
        print(output.shape)
        
        # Reshape and project back to output dimension
        output = output.transpose(0, 1).reshape(seq_len * batch_size, self.embed_dim)
        output = output @ self.out_weights.T
        output = output.reshape(batch_size, seq_len, self.embed_dim)
        
        return output



In [13]:
# Example usage
embed_dim = 20
num_heads = 4
seq_len = 3
batch_size = 1
head_dim = int(embed_dim // num_heads)


input_tensor = torch.randn(batch_size, seq_len,  embed_dim)

multihead_attn = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads)
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=0, bias=False, add_bias_kv=False, add_zero_attn=False)

wq, wk, wv = torch.split(mha.in_proj_weight, [embed_dim, embed_dim, embed_dim], dim=0)
multihead_attn.in_weights_q.data = wq
multihead_attn.in_weights_k.data = wk
multihead_attn.in_weights_v.data = wv
multihead_attn.out_weights.data = mha.out_proj.weight


In [14]:
torch.abs(multihead_attn(input_tensor, input_tensor, input_tensor) / mha(input_tensor, input_tensor, input_tensor)[0]-1).max() * 100

torch.Size([12, 1, 5])


tensor(0., grad_fn=<MulBackward0>)

In [15]:
class mha_(nn.Module): 
    def __init__(self): 
        super(mha_, self).__init__()

        self.li = nn.Linear(embed_dim, embed_dim)

        self.ma = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads)

    def forward(self, x):
        li = self.li(x)
        q = li
        k = li
        v = li

        output = self.ma(q,k,v)
        return output
mha_c = mha_()

In [16]:
# Export the model to ONNX
torch.onnx.export(
    mha_c,                    
    input_tensor,             
    "mha_model.onnx",          
    export_params=True,      
    input_names=['input'],   
    output_names=['output'],  
)

print("Model successfully exported to ONNX format.")


torch.Size([12, 1, 5])
Model successfully exported to ONNX format.


In [17]:
class mha_(nn.Module): 
    def __init__(self): 
        super(mha_, self).__init__()

        self.li = nn.Linear(embed_dim, embed_dim)

        self.ma = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=0, bias=False, add_bias_kv=False, add_zero_attn=False)

    def forward(self, x):
        li = self.li(x)
        q = li
        k = li
        v = li

        output, _ = self.ma(q, k, v)
        return output
mha_c = mha_()

In [18]:
# Export the model to ONNX
torch.onnx.export(
    mha_c,                    
    input_tensor,             
    "std_mha_model.onnx",          
    export_params=True,      
    input_names=['input'],   
    output_names=['output'], )

print("Model successfully exported to ONNX format.")


Model successfully exported to ONNX format.
