In [None]:
# TODO: Fix attention to more closely mimic standard or GPT transformer block 

In [90]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionTanhConv2d(nn.Module):
    def __init__(self, in_channels, patch_size=2, stride=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.stride = stride

        self.attn = nn.MultiheadAttention(embed_dim=in_channels, num_heads=1, batch_first=True)
        self.pos_embed = nn.Parameter(torch.randn(4, in_channels))

        self.summary_token = nn.Parameter(torch.randn(1, in_channels))
        self.fc1 = nn.Linear(in_channels, in_channels)

        self.ln = nn.LayerNorm(in_channels)

    def forward(self, x):
        B, C, H, W = x.shape

        # Unfold to get 2x2 patches: shape (B, C*4, L) where L = num patches
        patches = F.unfold(x, kernel_size=self.patch_size, stride=self.stride)
        L = patches.shape[-1]
        
        # Reshape to (B, L, 4, C)
        patches = patches.transpose(1, 2).reshape(B, L, self.patch_size * self.patch_size, C)
        
        # (B*L, 4, C) for attention
        tokens = patches.reshape(B * L, 4, C)
        tokens = tokens + self.pos_embed.unsqueeze(0)

        # Append summary token: (B*L, 5, C)
        summary = self.summary_token.expand(B * L, 1, C)
        tokens = torch.cat([summary, tokens], dim=1)

        # Apply multi-head attention
        attn_output, _ = self.attn(tokens, tokens, tokens)  # Still (B*L, 5, C)

        # Extract the summary output: (B*L, C)
        summary_output = attn_output[:, 0, :]

        out = self.fc1(summary_output)  # (B*L, out_channels)
        out = torch.relu(summary_output)
        out = self.ln(out)
        
        # Reshape to (B, out_channels, H_out, W_out)
        H_out = (H - self.patch_size) // self.stride + 1
        W_out = (W - self.patch_size) // self.stride + 1
        out = out.view(B, H_out, W_out, self.out_channels).permute(0, 3, 1, 2)

        return out

In [91]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [92]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

mps


In [93]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [94]:
train_data = train_dataset.data.to(device).float() / 255.0
train_targets = train_dataset.targets.to(device)

test_data = test_dataset.data.to(device).float() / 255.0
test_targets = test_dataset.targets.to(device)

train_data = train_data.unsqueeze(1)
test_data = test_data.unsqueeze(1)

def get_batches(data, targets, batch_size):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size], targets[i:i + batch_size]

batch_size = 2500
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [95]:
class MNISTAttentionGRUCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=2, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=2, stride=2, padding=1)
        self.attn_gru_conv = AttentionTanhConv2d(in_channels=64, patch_size=2, stride=1)

        self.fc1 = nn.Linear(64, 10)

    def forward(self, x): # (B, 1, 28, 28)
        x = self.conv1(x) # (B, 16, 15, 15)
        x = self.conv2(x) # (B, 16, 8, 8)
        x = self.conv3(x) # (B, 16, 5, 5)
        x = self.attn_gru_conv(x) # (B, 16, 4, 4)
        x = self.attn_gru_conv(x) # (B, 16, 4, 4)
        x = self.attn_gru_conv(x) # (B, 16, 2, 2)
        x = self.attn_gru_conv(x) # (B, 16, 1, 1)
        
        x = x.reshape(x.size(0), -1)  # Flatten
        x = self.fc1(x)
        return x

In [96]:
learning_rate = 0.001 * 1
epochs = 1000

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = MNISTAttentionGRUCNN().to(device)
# model = torch.compile(MNIST2DLSTMClassifier()).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [97]:
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()} params, requires_grad={param.requires_grad}")

total_params = sum(p.numel() for p in model.parameters())
print()
print(total_params)

conv1.weight: 64 params, requires_grad=True
conv1.bias: 16 params, requires_grad=True
conv2.weight: 4608 params, requires_grad=True
conv2.bias: 32 params, requires_grad=True
conv3.weight: 8192 params, requires_grad=True
conv3.bias: 64 params, requires_grad=True
attn_gru_conv.pos_embed: 256 params, requires_grad=True
attn_gru_conv.summary_token: 64 params, requires_grad=True
attn_gru_conv.attn.in_proj_weight: 12288 params, requires_grad=True
attn_gru_conv.attn.in_proj_bias: 192 params, requires_grad=True
attn_gru_conv.attn.out_proj.weight: 4096 params, requires_grad=True
attn_gru_conv.attn.out_proj.bias: 64 params, requires_grad=True
attn_gru_conv.fc1.weight: 4096 params, requires_grad=True
attn_gru_conv.fc1.bias: 64 params, requires_grad=True
attn_gru_conv.ln.weight: 64 params, requires_grad=True
attn_gru_conv.ln.bias: 64 params, requires_grad=True
fc1.weight: 640 params, requires_grad=True
fc1.bias: 10 params, requires_grad=True

34874


In [98]:
patience = 10000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.1728
Epoch [1/1000], Validation Loss: 1.6802, Validation Accuracy: 56.44%
Output Summary: Max=2.3652, Min=-2.7602, Median=0.0619, Mean=0.0334

Epoch [2/1000], Training Loss: 1.2055
Epoch [2/1000], Validation Loss: 0.7723, Validation Accuracy: 84.77%
Output Summary: Max=4.2014, Min=-2.8849, Median=-0.0332, Mean=0.1487

Epoch [3/1000], Training Loss: 0.5983
Epoch [3/1000], Validation Loss: 0.4247, Validation Accuracy: 90.35%
Output Summary: Max=5.0898, Min=-3.8560, Median=-0.2723, Mean=0.0734

Epoch [4/1000], Training Loss: 0.3736
Epoch [4/1000], Validation Loss: 0.3020, Validation Accuracy: 92.28%
Output Summary: Max=5.7756, Min=-4.4066, Median=-0.3293, Mean=0.0801

Epoch [5/1000], Training Loss: 0.2840
Epoch [5/1000], Validation Loss: 0.2441, Validation Accuracy: 93.43%
Output Summary: Max=6.2063, Min=-4.4058, Median=-0.3786, Mean=0.0622

Epoch [6/1000], Training Loss: 0.2379
Epoch [6/1000], Validation Loss: 0.2234, Validation Accuracy: 93.91%
Output Su

KeyboardInterrupt: 

In [89]:
patience = 10000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.3023
Epoch [1/1000], Validation Loss: 2.3016, Validation Accuracy: 15.66%
Output Summary: Max=0.1028, Min=-0.0965, Median=0.0467, Mean=0.0335

Epoch [2/1000], Training Loss: 2.2845
Epoch [2/1000], Validation Loss: 2.1563, Validation Accuracy: 27.01%
Output Summary: Max=0.7537, Min=-0.6424, Median=0.0407, Mean=0.0154

Epoch [3/1000], Training Loss: 1.7491
Epoch [3/1000], Validation Loss: 1.2844, Validation Accuracy: 61.41%
Output Summary: Max=3.8563, Min=-3.6543, Median=0.1253, Mean=0.0692

Epoch [4/1000], Training Loss: 1.0003
Epoch [4/1000], Validation Loss: 0.7231, Validation Accuracy: 80.32%
Output Summary: Max=5.1553, Min=-4.6620, Median=-0.1656, Mean=0.0501

Epoch [5/1000], Training Loss: 0.6128
Epoch [5/1000], Validation Loss: 0.4861, Validation Accuracy: 86.23%
Output Summary: Max=6.4259, Min=-4.8845, Median=-0.1543, Mean=0.0677

Epoch [6/1000], Training Loss: 0.4516
Epoch [6/1000], Validation Loss: 0.3873, Validation Accuracy: 88.88%
Output Summ

KeyboardInterrupt: 

In [81]:
patience = 10000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
Epoch [1/1000], Training Loss: 2.3031
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
Epoch [1/1000], Validation Loss: 2.3011, Validation Accuracy: 10.28%
Output Summary: Max=0.0767, Min=-0.0834, Median=0.0046, Mean=0.0008

torch.Size([

KeyboardInterrupt: 