In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

In [2]:
# Data Preparation
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [None]:
# Custom Multi-Headed Attention Layer
class MultiHeadedAttentionLayer(nn.Module):
    def __init__(self, num_heads, embed_dim, MHA_arch):
        super(MultiHeadedAttentionLayer, self).__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.MHA_arch = MHA_arch

    def forward(self, qkv):
        # qkv: (batch_size, seq_len, 3, num_heads, head_dim)
        
        # pad qkv to seqlen % 32 == 0
        seq_len = qkv.size(1)
        if seq_len % 32 != 0:
            pad_len = 32 - seq_len % 32
            qkv = F.pad(qkv, (0, 0, 0, 0, 0, 0, 0, pad_len), "constant", 0)

        output = self.MHA_arch.apply(qkv)
        return output

# Model for MNIST
class MHAForMNIST(nn.Module):
    def __init__(self, input_dim=28, num_heads=8, embed_dim=1024, num_classes=10, MHA_arch = None):
        super(MHAForMNIST, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.embedding = nn.Linear(input_dim, embed_dim)
        self.mha = MultiHeadedAttentionLayer(num_heads, embed_dim, MHA_arch)
        self.linear_projection = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # num_rows = 28 for MNIST
        # Reshape input (batch_size, 1, num_rows, num_rows) -> (batch_size, num_rows, embed_dim)
        # print(f"basic input: {x.shape}")
        x = x.view(-1, 28, 28)
        x = self.embedding(x)

        """Apply MHA"""
        # x: (batch_size, num_rows, embed_dim)
        qkv = torch.stack([x, x, x], dim=2)
        qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.head_dim)
        # qkv: (batch_size, seq_len, 3, num_heads, head_dim)
        
        # Apply multi-headed attention (batch_size, seq_len, num_heads, head_dim)
        x = self.mha(qkv)
        
        """Post MHA cleanup"""
        # x: (batch_sz, num_rows, num_heads, head_dim)
        x = x.reshape(x.size(0), x.size(1), self.embed_dim)  # Collapse num_heads and head_dim
        x = x.mean(dim = 1) # average num_rows
        

        """Project and softmax"""
        # x: (batch_size, num_rows, num_classes)
        x = self.linear_projection(x)
        x = F.log_softmax(x, dim=-1)
        return x

In [None]:
def forward(self, x):
    # num_rows = 28 for MNIST
    # Reshape input (batch_size, 1, num_rows, num_rows) 
    # -> (batch_size, num_rows, embed_dim)
    x = x.view(-1, 28, 28)
    x = self.embedding(x)

    """Apply MHA"""
    # x: (batch_size, num_rows, embed_dim)
    qkv = torch.stack([x, x, x], dim=2)
    qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.head_dim)
    # qkv: (batch_size, seq_len, 3, num_heads, head_dim)
    
    # Apply multi-headed attention (batch_size, seq_len, num_heads, head_dim)
    x = self.mha(qkv)
    
    """Post MHA cleanup"""
    # x: (batch_sz, num_rows, num_heads, head_dim)
    x = x.reshape(x.size(0), x.size(1), self.embed_dim)  # Collapse num_heads and head_dim
    x = x.mean(dim = 1) # average num_rows
    

    """Project and softmax"""
    # x: (batch_size, num_rows, num_classes)
    x = self.linear_projection(x)
    x = F.log_softmax(x, dim=-1)
    return x

In [None]:
from wrappers.naive_mha_wrapped import naive_AttnQKVPackedFunc
from wrappers.triton_FWDfp32_wrapped import triton_AttnQKVPackedFunc
from wrappers.triton_FWDbf16_wrapped import triton_FA2_AttnQKVPackedFunc
from wrappers.triton_FWDbf16_BWDbf16_wrapped import triton_both_AttnQKVPackedFunc

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MHAForMNIST(MHA_arch=naive_AttnQKVPackedFunc).to(device)
# model = MHAForMNIST(MHA_arch=triton_AttnQKVPackedFunc).to(device)
# model = MHAForMNIST(MHA_arch=triton_FA2_AttnQKVPackedFunc).to(device)
# model = MHAForMNIST(MHA_arch=triton_both_AttnQKVPackedFunc).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [21]:
import time

def run_val(model, val_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    val_accuracy = 100 * correct / total
    return val_accuracy

epoch_data = {}
val_accuracy = run_val(model, val_loader)
epoch_data[-1] = {"val_accuracy": val_accuracy}
print(f"Initial Validation Accuracy: {val_accuracy:.2f}%")

losses = []
for epoch in range(5):
    start_time = time.time()  # Record start time of the epoch

    # Training
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if (torch.isnan(loss)):
            print(outputs)

            raise Exception("Loss is NaN")
        losses.append(loss.item())
    
    # Validation
    val_accuracy = run_val(model, val_loader)    

    epoch_time = time.time() - start_time  # Calculate elapsed time for the epoch

    print(f"Epoch {epoch} completed, Loss: {total_loss / len(train_loader):.4f}, "
          f"Validation Accuracy: {val_accuracy:.2f}%, Time: {epoch_time:.2f} seconds")

    epoch_data[epoch] = {
        "loss": total_loss / len(train_loader),
        "val_accuracy": val_accuracy,
        "epoch_time": epoch_time
    }


Initial Validation Accuracy: 9.94%
Epoch 0 completed, Loss: 1.4263, Validation Accuracy: 64.21%, Time: 11.61 seconds
Epoch 1 completed, Loss: 0.9274, Validation Accuracy: 72.60%, Time: 11.39 seconds
Epoch 2 completed, Loss: 0.8047, Validation Accuracy: 75.78%, Time: 11.40 seconds
Epoch 3 completed, Loss: 0.7406, Validation Accuracy: 75.83%, Time: 11.59 seconds
Epoch 4 completed, Loss: 0.7113, Validation Accuracy: 74.40%, Time: 11.69 seconds
