In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionGRUConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size=2, stride=1, heads=2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.stride = stride
        self.heads = heads
        self.hidden_size = out_channels

        self.input_dim = patch_size * patch_size * in_channels
        self.attn_input_dim = in_channels

        # Positional encoding for 2x2 grid
        self.pos_embed = nn.Parameter(torch.randn(4, in_channels))  # 4 corners of the 2x2 patch

        # Multi-head attention for 2x2 pixels
        self.attn = nn.MultiheadAttention(embed_dim=in_channels, num_heads=heads, batch_first=True)

        # GRU cell to generate output feature vector
        self.gru = nn.GRU(input_size=self.input_dim, hidden_size=self.hidden_size, batch_first=True)

    def forward(self, x):
        B, C, H, W = x.shape

        # Use unfold to get 2x2 patches
        patches = F.unfold(x, kernel_size=self.patch_size, stride=self.stride)  # (B, C*K*K, L)
        L = patches.shape[-1]  # Number of patches
        patches = patches.transpose(1, 2)  # (B, L, C*K*K)

        # Get 4 corners for attention: reshape to (B*L, 4, C)
        corners = patches.view(B, L, self.patch_size * self.patch_size, C)
        corners = corners[:, :, [0, 1, 2, 3]]  # 2x2 order; assumed flattened row-wise
        corners = corners.reshape(B * L, 4, C)

        # Add positional encoding
        corners_pe = corners + self.pos_embed.unsqueeze(0)  # (B*L, 4, C)

        # Apply MHA: output will be (B*L, 4, C)
        attn_output, _ = self.attn(corners_pe, corners_pe, corners_pe)  # Self-attn

        # Aggregate attention outputs into a vector to use as GRU hidden state
        init_hidden = attn_output.mean(dim=1).unsqueeze(0)  # (1, B*L, C)

        # Flatten patch for input to GRU
        patch_inputs = patches.reshape(B * L, 1, -1)  # (B*L, 1, C*4)

        # Run GRU
        _, h_n = self.gru(patch_inputs, init_hidden)  # h_n: (1, B*L, hidden_size)

        # Reshape to (B, out_channels, H_out, W_out)
        output = h_n.squeeze(0).view(B, L, self.out_channels).transpose(1, 2)  # (B, out_channels, L)
        H_out = (H - self.patch_size) // self.stride + 1
        W_out = (W - self.patch_size) // self.stride + 1
        output = output.view(B, self.out_channels, H_out, W_out)

        return output

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [3]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

mps


In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [10]:
train_data = train_dataset.data.to(device).float() / 255.0
train_targets = train_dataset.targets.to(device)

test_data = test_dataset.data.to(device).float() / 255.0
test_targets = test_dataset.targets.to(device)

train_data = train_data.unsqueeze(1)
test_data = test_data.unsqueeze(1)

def get_batches(data, targets, batch_size):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size], targets[i:i + batch_size]

batch_size = 2000
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [38]:
class MNISTAttentionGRUCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=2, stride=2, padding=1)
        self.attn_gru_conv = AttentionGRUConv2d(in_channels=16, out_channels=16, patch_size=2, stride=1, heads=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=2, padding=1)
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=2, padding=1)

        self.fc1 = nn.Linear(16, 10)

    def forward(self, x):
        x = self.conv1(x) # (B, 16, 15, 15)
        x = self.attn_gru_conv(x)  # (B, 16, 14, 14)
        x = self.conv2(x) # (B, 16, 8, 8)
        x = self.conv3(x) # (B, 16, 5, 5)
        x = self.attn_gru_conv(x) # (B, 16, 4, 4)
        x = self.conv4(x) # (B, 16, 3, 3)
        x = self.attn_gru_conv(x) # (B, 16, 2, 2)
        x = self.attn_gru_conv(x) # (B, 16, 1, 1)
        
        x = x.reshape(x.size(0), -1)  # Flatten
        x = self.fc1(x)
        print("here")
        return x

In [39]:
learning_rate = 0.001 * 1
epochs = 1000

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = MNISTAttentionGRUCNN().to(device)
# model = torch.compile(MNIST2DLSTMClassifier()).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [40]:
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()} params, requires_grad={param.requires_grad}")

total_params = sum(p.numel() for p in model.parameters())
print()
print(total_params)

conv1.weight: 64 params, requires_grad=True
conv1.bias: 16 params, requires_grad=True
attn_gru_conv.pos_embed: 64 params, requires_grad=True
attn_gru_conv.attn.in_proj_weight: 768 params, requires_grad=True
attn_gru_conv.attn.in_proj_bias: 48 params, requires_grad=True
attn_gru_conv.attn.out_proj.weight: 256 params, requires_grad=True
attn_gru_conv.attn.out_proj.bias: 16 params, requires_grad=True
attn_gru_conv.gru.weight_ih_l0: 3072 params, requires_grad=True
attn_gru_conv.gru.weight_hh_l0: 768 params, requires_grad=True
attn_gru_conv.gru.bias_ih_l0: 48 params, requires_grad=True
attn_gru_conv.gru.bias_hh_l0: 48 params, requires_grad=True
conv2.weight: 1024 params, requires_grad=True
conv2.bias: 16 params, requires_grad=True
conv3.weight: 1024 params, requires_grad=True
conv3.bias: 16 params, requires_grad=True
conv4.weight: 1024 params, requires_grad=True
conv4.bias: 16 params, requires_grad=True
fc1.weight: 160 params, requires_grad=True
fc1.bias: 10 params, requires_grad=True

8458

In [34]:
patience = 1000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

torch.Size([2000, 16, 1, 1])
here


TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not NoneType