In [1]:
import torch
import torch.nn as nn

In [2]:
# Define Self-Attention mechanism
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with learned projections.

    Three independent ``nn.Linear`` layers (query, key, value) each map
    ``input_dim`` features to ``input_dim`` features, so the output has the
    same shape as the input.

    Args:
        input_dim: Size of the last (feature) dimension of the input tensor.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Projections are created in query -> key -> value order.
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        # Normalises each row of the score matrix into attention weights.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Attend over the second-to-last dimension of ``x``.

        The last two dimensions of ``x`` are treated as (rows, features): the
        score matrix is rows x rows and the softmax runs over its last axis.
        For a 2-D input of shape ``(batch_size, input_dim)`` the rows are the
        batch entries, so every output row is a weighted combination of the
        value projections of *all* entries in the batch.
        NOTE(review): confirm that this cross-sample mixing is intended.

        Args:
            x: Tensor whose last dimension has size ``input_dim`` and which
                has at least two dimensions.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Scores are divided by sqrt(feature size) before the softmax.
        scale = x.shape[-1] ** 0.5

        queries, keys, values = self.query(x), self.key(x), self.value(x)

        # Pairwise similarity between every query row and every key row.
        scores = (queries @ keys.transpose(-2, -1)) / scale
        weights = self.softmax(scores)

        # Weighted sum of the value rows.
        return weights @ values

In [3]:
# Define MLP with Self-Attention
class SelfAttentionMLP(nn.Module):
    """A ``SelfAttention`` layer followed by a two-layer ReLU MLP.

    Args:
        input_dim: Feature size of the input; also the feature size consumed
            and produced by the attention layer.
        hidden_dim: Width of the MLP's hidden layer.
        output_dim: Feature size of the model output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.attention = SelfAttention(input_dim)

        # Linear -> ReLU -> Linear, mapping input_dim -> hidden_dim -> output_dim.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the attention layer, then through the MLP.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor whose last dimension has size ``output_dim``.
        """
        return self.mlp(self.attention(x))

In [4]:
# Demo: one forward pass of SelfAttentionMLP on a batch of random vectors.
# Input: batch of 3-dimensional vectors, shape (batch_size, 3)
batch_size = 10
input_dim = 3
hidden_dim = 64
output_dim = 5  # Predicting 5D vectors

# Seed the global RNG before building the model and sampling inputs so that
# both the weight initialisation and the random data are reproducible on
# Restart Kernel -> Run All.
seed = 0
torch.manual_seed(seed)

model = SelfAttentionMLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)

# Generate random input data
inputs = torch.randn(batch_size, input_dim)

# Forward pass through the model
outputs = model(inputs)

# Sanity check: one output_dim-sized prediction per input row.
assert outputs.shape == (batch_size, output_dim), outputs.shape

print("Input shape:", inputs.shape)
print("Output shape:", outputs.shape)


Input shape: torch.Size([10, 3])
Output shape: torch.Size([10, 5])
