In [None]:
# create a simple attention matrix and perform softmax using PyTorch

import torch
import torch.nn.functional as F

# Create a 3x4 matrix with some sample values (read it as 3 query rows x 4 key scores)
matrix = torch.tensor([
    [1.0, 2.0, 0.5, 3.0],  # First row
    [0.1, 0.8, 5.0, 1.2],  # Second row
    [3.0, 2.0, 1.0, 0.5]   # Third row
])

print("Original Matrix:")
print(matrix)

# Apply softmax row-wise (dim=1): each row becomes a probability distribution
softmax_result = F.softmax(matrix, dim=1)

print("\nAfter Softmax (row-wise):")
print(softmax_result)

# Verify that each row sums to 1
row_sums = torch.sum(softmax_result, dim=1)
print("\nRow sums (should all be 1):")
print(row_sums)
# Printing only lets a human eyeball the result; assert so Restart & Run All fails loudly if it is wrong.
assert torch.allclose(row_sums, torch.ones_like(row_sums)), "softmax rows must sum to 1"

  device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),


Original Matrix:
tensor([[1.0000, 2.0000, 0.5000, 3.0000],
        [0.1000, 0.8000, 5.0000, 1.2000],
        [3.0000, 2.0000, 1.0000, 0.5000]])

After Softmax (row-wise):
tensor([[0.0854, 0.2321, 0.0518, 0.6308],
        [0.0071, 0.0144, 0.9571, 0.0214],
        [0.6308, 0.2321, 0.0854, 0.0518]])

Row sums (should all be 1):
tensor([1., 1., 1.])


In [None]:
# Recompute the same row-wise softmax with a custom "online" implementation (running max + rescaled running sum), the technique fused attention kernels rely on

import torch

def stable_softmax(x):
    """Numerically stable softmax over the last dimension, computed "online".

    First pass: walk the columns once, tracking a running row maximum ``m``
    and a running normaliser ``l`` (sum of ``exp(x - m)``). Whenever ``m``
    grows, ``l`` is rescaled by ``exp(prev_m - m)``.
    Second pass: write ``exp(x - m) / l`` column by column.

    Args:
        x: tensor of scores. Softmax is taken along ``dim=-1``; for a 2-D
            input this is the same as ``dim=1``. Any leading dimensions are
            treated as independent rows.

    Returns:
        Tensor with the same shape and dtype as ``x``. Every row that
        contains at least one finite score sums to 1. As with
        ``F.softmax``, a row that is entirely ``-inf`` comes out as NaN.
    """
    neg_inf = float('-inf')
    row_shape = x.shape[:-1]
    seq_len = x.shape[-1]

    # Running statistics, one entry per row (default float dtype; they promote
    # to a wider dtype automatically if x is float64).
    m = torch.full(row_shape, neg_inf, device=x.device)  # running max
    l = torch.zeros(row_shape, device=x.device)          # running sum of exp(x - m)

    # First loop: compute maximum and (rescaled) sum for each row
    for i in range(seq_len):
        column = x[..., i]  # current column for every row

        # torch.maximum returns a new tensor, so prev_m is safe without a clone.
        prev_m = m
        m = torch.maximum(prev_m, column)

        # While a row has only seen -inf values, m is still -inf and
        # (prev_m - m) would be (-inf) - (-inf) = NaN, poisoning l forever.
        # Subtracting 0 there instead makes both exponentials exp(-inf) = 0,
        # so l correctly stays 0 until the first finite score arrives.
        m_safe = m.masked_fill(m == neg_inf, 0.0)
        scale_factor = torch.exp(prev_m - m_safe)
        update = torch.exp(column - m_safe)
        l = l * scale_factor + update

    # Second loop: compute final softmax values
    result = torch.zeros_like(x)
    for k in range(seq_len):
        result[..., k] = torch.exp(x[..., k] - m) / l

    return result

# Reference implementation used for the comparison at the end of this cell
import torch.nn.functional as F

# Test with the example matrix
matrix = torch.tensor([
    [1.0, 2.0, 0.5, 3.0],
    [0.1, 0.8, 5.0, 1.2],
    [3.0, 2.0, 1.0, 0.5]
])

print("Original matrix:")
print(matrix)

print("\nStable softmax result:")
result = stable_softmax(matrix)
print(result)

print("\nRow sums (should be 1):")
print(torch.sum(result, dim=1))

# Compare with PyTorch's softmax
print("\nPyTorch's built-in softmax:")
torch_result = F.softmax(matrix, dim=1)
print(torch_result)

print("\nDifference:")
print(torch.abs(result - torch_result).max())

# Fail loudly instead of relying on eyeballing the printed difference: the custom
# implementation must match the reference up to float round-off.
torch.testing.assert_close(result, torch_result)

Original matrix:
tensor([[1.0000, 2.0000, 0.5000, 3.0000],
        [0.1000, 0.8000, 5.0000, 1.2000],
        [3.0000, 2.0000, 1.0000, 0.5000]])

Stable softmax result:
tensor([[0.0854, 0.2321, 0.0518, 0.6308],
        [0.0071, 0.0144, 0.9571, 0.0214],
        [0.6308, 0.2321, 0.0854, 0.0518]])

Row sums (should be 1):
tensor([1.0000, 1.0000, 1.0000])

PyTorch's built-in softmax:
tensor([[0.0854, 0.2321, 0.0518, 0.6308],
        [0.0071, 0.0144, 0.9571, 0.0214],
        [0.6308, 0.2321, 0.0854, 0.0518]])

Difference:
tensor(1.1921e-07)
