In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchinfo import summary

Experiment with exploiting symmetries through convolution filter rotations and reflections. Works pretty well :)

In [1]:
!pip install -U -q triton

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 43841482.92it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1163300.97it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 10659561.31it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4175000.83it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_data = train_dataset.data.to(device).float() / 255.0
train_targets = train_dataset.targets.to(device)

test_data = test_dataset.data.to(device).float() / 255.0
test_targets = test_dataset.targets.to(device)

train_data = train_data.unsqueeze(1)
test_data = test_data.unsqueeze(1)

def get_batches(data, targets, batch_size):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size], targets[i:i + batch_size]

batch_size = 2000
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ExperimentalModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 4, kernel_size=2, stride=2)       # 28x28 -> 14x14
        self.conv2 = nn.Conv2d(8 * 4, 8, kernel_size=2, stride=2)   # 14x14 -> 7x7
        self.conv3 = nn.Conv2d(8 * 8, 8, kernel_size=2, stride=1)   # 7x7   -> 6x6
        self.conv4 = nn.Conv2d(8 * 8, 16, kernel_size=2, stride=2)  # 6x6   -> 3x3
        self.conv5 = nn.Conv2d(8 * 16, 16, kernel_size=2, stride=1) # 3x3   -> 2x2
        self.conv6 = nn.Conv2d(8 * 16, 32, kernel_size=2, stride=1) # 2x2   -> 1x1

        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(32 * 8, 10)

    def _conv_relu(self, x, conv_layer):
        weight = conv_layer.weight
        bias = conv_layer.bias
        stride = conv_layer.stride
        padding = conv_layer.padding

        def conv_with(kernel):
            return F.conv2d(x, kernel, bias=bias, stride=stride, padding=padding)

        kernels = [
            weight,
            torch.rot90(weight, k=1, dims=[2, 3]),
            torch.rot90(weight, k=2, dims=[2, 3]),
            torch.rot90(weight, k=3, dims=[2, 3]),
            torch.flip(weight, dims=[3]),
            torch.flip(weight, dims=[2]),
            weight.transpose(2, 3),
            torch.flip(weight.transpose(2, 3), dims=[2, 3])
        ]

        outputs = [conv_with(k) for k in kernels]
        y = torch.cat(outputs, dim=1)
        return self.relu(y)

    def _conv_relu(self, x: torch.Tensor, conv_layer: torch.nn.Conv2d):
        w   = conv_layer.weight
        s, p = conv_layer.stride, conv_layer.padding
    
        kernels = torch.cat([
            w,
            torch.rot90(w, 1, (2, 3)),
            torch.rot90(w, 2, (2, 3)),
            torch.rot90(w, 3, (2, 3)),
            torch.flip(w, (3,)),
            torch.flip(w, (2,)),
            w.transpose(2, 3),
            torch.flip(w.transpose(2, 3), (2, 3)),
        ], dim=0)
    
        y = F.conv2d(x, kernels, bias=conv_layer.bias.repeat(8), stride=s, padding=p)    
        return self.relu(y)

    def forward(self, x):
        x = self._conv_relu(x, self.conv1)
        x = self._conv_relu(x, self.conv2)
        x = self._conv_relu(x, self.conv3)
        x = self._conv_relu(x, self.conv4)
        x = self._conv_relu(x, self.conv5)
        x = self._conv_relu(x, self.conv6)

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [22]:
learning_rate = 0.001 * 1
epochs = 1000

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ExperimentalModel().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [23]:
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()} params, requires_grad={param.requires_grad}")

total_params = sum(p.numel() for p in model.parameters())
print()
print(total_params)

conv1.weight: 16 params, requires_grad=True
conv1.bias: 4 params, requires_grad=True
conv2.weight: 1024 params, requires_grad=True
conv2.bias: 8 params, requires_grad=True
conv3.weight: 2048 params, requires_grad=True
conv3.bias: 8 params, requires_grad=True
conv4.weight: 4096 params, requires_grad=True
conv4.bias: 16 params, requires_grad=True
conv5.weight: 8192 params, requires_grad=True
conv5.bias: 16 params, requires_grad=True
conv6.weight: 16384 params, requires_grad=True
conv6.bias: 32 params, requires_grad=True
fc.weight: 2560 params, requires_grad=True
fc.bias: 10 params, requires_grad=True

34414


In [24]:
patience = 20
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(200):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()


        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.2464
Epoch [1/1000], Validation Loss: 2.0184, Validation Accuracy: 38.88%
Output Summary: Max=3.5927, Min=-1.6119, Median=0.0114, Mean=0.0279

Epoch [2/1000], Training Loss: 1.3496
Epoch [2/1000], Validation Loss: 0.7923, Validation Accuracy: 74.37%
Output Summary: Max=18.9871, Min=-13.3122, Median=2.1739, Mean=1.7690

Epoch [3/1000], Training Loss: 0.6346
Epoch [3/1000], Validation Loss: 0.5064, Validation Accuracy: 85.13%
Output Summary: Max=19.1920, Min=-11.6659, Median=2.1465, Mean=2.0111

Epoch [4/1000], Training Loss: 0.4763
Epoch [4/1000], Validation Loss: 0.4288, Validation Accuracy: 86.99%
Output Summary: Max=20.5189, Min=-14.0178, Median=2.1201, Mean=2.0524

Epoch [5/1000], Training Loss: 0.4275
Epoch [5/1000], Validation Loss: 0.3959, Validation Accuracy: 88.01%
Output Summary: Max=20.7456, Min=-14.6761, Median=1.8109, Mean=1.8743

Epoch [6/1000], Training Loss: 0.3866
Epoch [6/1000], Validation Loss: 0.3471, Validation Accuracy: 89.67%
Outpu