In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchinfo import summary

In [2]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

mps


Experiment with exploiting symmetries through convolution filter rotations and reflections. Works pretty well :)

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

In [4]:
train_data = train_dataset.data.to(device).float() / 255.0
train_targets = train_dataset.targets.to(device)

test_data = test_dataset.data.to(device).float() / 255.0
test_targets = test_dataset.targets.to(device)

train_data = train_data.unsqueeze(1)
test_data = test_data.unsqueeze(1)

def get_batches(data, targets, batch_size):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size], targets[i:i + batch_size]

batch_size = 2000
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ExperimentalModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 4, kernel_size=2, stride=2, padding=1)       # 28x28 -> 15x15
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1)   # 15x15 -> 8x8
        self.conv3 = nn.Conv2d(8, 8, kernel_size=4, stride=1, padding=1)   # 8x8   -> 7x7
        self.conv4 = nn.Conv2d(8, 16, kernel_size=5, stride=1, padding=1)  # 7x7   -> 5x5
        self.conv5 = nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=0) # 5x5   -> 1x1

        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(16, 10)

    def _conv_relu(self, x: torch.Tensor, conv_layer: torch.nn.Conv2d):
        w   = conv_layer.weight
        s, p = conv_layer.stride, conv_layer.padding
    
        kernels = torch.cat([
            w,
            torch.rot90(w, 1, (2, 3)),
            torch.rot90(w, 2, (2, 3)),
            torch.rot90(w, 3, (2, 3)),
            torch.flip(w, (3,)),
            torch.flip(w, (2,)),
            w.transpose(2, 3),
            torch.flip(w.transpose(2, 3), (2, 3)),
        ], dim=0)
    
        y = F.conv2d(x, kernels, bias=conv_layer.bias.repeat(1), stride=s, padding=p)    
        return self.relu(y)

    def forward(self, x):
        x = self._conv_relu(x, self.conv1)
        x = self._conv_relu(x, self.conv2)
        x = self._conv_relu(x, self.conv3)
        x = self._conv_relu(x, self.conv4)
        x = self._conv_relu(x, self.conv5)

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=in_channels, bias=True)
        self.pointwise = nn.Conv2d(in_channels * 8, out_channels, kernel_size=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        w = self.depthwise.weight
        b = self.depthwise.bias
        s, p = self.depthwise.stride, self.depthwise.padding

        kernels = torch.cat([
            w,
            torch.rot90(w, 1, (2, 3)),
            torch.rot90(w, 2, (2, 3)),
            torch.rot90(w, 3, (2, 3)),
            torch.flip(w, (3,)),
            torch.flip(w, (2,)),
            w.transpose(2, 3),
            torch.flip(w.transpose(2, 3), (2, 3)),
        ], dim=0)

        x = F.conv2d(x, kernels, bias=b.repeat(8), stride=s, padding=p, groups=x.shape[1])
        x = self.relu(x)

        x = self.pointwise(x)
        return x

class ExperimentalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = DepthwiseSeparableConv(1, 8, kernel_size=2, stride=2, padding=1)   # 28x28 -> 15x15
        self.conv2 = DepthwiseSeparableConv(8, 16, kernel_size=3, stride=2, padding=1)  # 15x15 -> 8x8
        self.conv3 = DepthwiseSeparableConv(16, 32, kernel_size=4, stride=1, padding=1)  # 8x8 -> 7x7
        self.conv4 = DepthwiseSeparableConv(32, 32, kernel_size=5, stride=1, padding=1)  # 7x7 -> 5x5
        self.conv5 = DepthwiseSeparableConv(32, 64, kernel_size=5, stride=1, padding=0)  # 5x5 -> 1x1

        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


In [7]:
learning_rate = 0.001 * 1
epochs = 1000

model = ExperimentalModel().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [8]:
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()} params, requires_grad={param.requires_grad}")

total_params = sum(p.numel() for p in model.parameters())
print()
print(total_params)

conv1.depthwise.weight: 4 params, requires_grad=True
conv1.depthwise.bias: 1 params, requires_grad=True
conv1.pointwise.weight: 64 params, requires_grad=True
conv1.pointwise.bias: 8 params, requires_grad=True
conv2.depthwise.weight: 72 params, requires_grad=True
conv2.depthwise.bias: 8 params, requires_grad=True
conv2.pointwise.weight: 1024 params, requires_grad=True
conv2.pointwise.bias: 16 params, requires_grad=True
conv3.depthwise.weight: 256 params, requires_grad=True
conv3.depthwise.bias: 16 params, requires_grad=True
conv3.pointwise.weight: 4096 params, requires_grad=True
conv3.pointwise.bias: 32 params, requires_grad=True
conv4.depthwise.weight: 800 params, requires_grad=True
conv4.depthwise.bias: 32 params, requires_grad=True
conv4.pointwise.weight: 8192 params, requires_grad=True
conv4.pointwise.bias: 32 params, requires_grad=True
conv5.depthwise.weight: 800 params, requires_grad=True
conv5.depthwise.bias: 32 params, requires_grad=True
conv5.pointwise.weight: 16384 params, req

In [9]:
patience = 20
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(200):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()


        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.2931
Epoch [1/1000], Validation Loss: 2.2249, Validation Accuracy: 37.24%
Output Summary: Max=0.1836, Min=-0.1928, Median=-0.0041, Mean=-0.0026

Epoch [2/1000], Training Loss: 1.5187
Epoch [2/1000], Validation Loss: 0.8940, Validation Accuracy: 67.27%
Output Summary: Max=24.1974, Min=-23.6897, Median=1.8750, Mean=0.2262

Epoch [3/1000], Training Loss: 0.7783
Epoch [3/1000], Validation Loss: 0.7291, Validation Accuracy: 73.09%
Output Summary: Max=21.0874, Min=-25.0418, Median=2.3647, Mean=0.2783

Epoch [4/1000], Training Loss: 0.6771
Epoch [4/1000], Validation Loss: 0.6646, Validation Accuracy: 74.56%
Output Summary: Max=21.5659, Min=-26.4458, Median=2.7635, Mean=0.2243

Epoch [5/1000], Training Loss: 0.6306
Epoch [5/1000], Validation Loss: 0.6291, Validation Accuracy: 76.51%
Output Summary: Max=22.2256, Min=-26.9921, Median=3.0525, Mean=0.1838

Epoch [6/1000], Training Loss: 0.6017
Epoch [6/1000], Validation Loss: 0.6076, Validation Accuracy: 77.25%
Out

KeyboardInterrupt: 

In [114]:
patience = 20
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(200):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()


        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.3012
Epoch [1/1000], Validation Loss: 2.2882, Validation Accuracy: 19.46%
Output Summary: Max=0.0703, Min=-0.1166, Median=-0.0117, Mean=-0.0130

Epoch [2/1000], Training Loss: 2.0032
Epoch [2/1000], Validation Loss: 1.4941, Validation Accuracy: 48.34%
Output Summary: Max=4.3050, Min=-8.2861, Median=0.1374, Mean=-0.0165

Epoch [3/1000], Training Loss: 1.0819
Epoch [3/1000], Validation Loss: 0.8928, Validation Accuracy: 66.38%
Output Summary: Max=23.2645, Min=-25.5696, Median=1.9330, Mean=-0.0560

Epoch [4/1000], Training Loss: 0.8255
Epoch [4/1000], Validation Loss: 0.7981, Validation Accuracy: 70.98%
Output Summary: Max=22.4251, Min=-27.6378, Median=1.9379, Mean=-0.1756

Epoch [5/1000], Training Loss: 0.7518
Epoch [5/1000], Validation Loss: 0.7370, Validation Accuracy: 72.64%
Output Summary: Max=23.8140, Min=-29.3738, Median=2.1945, Mean=-0.2213

Epoch [6/1000], Training Loss: 0.6978
Epoch [6/1000], Validation Loss: 0.6898, Validation Accuracy: 73.92%
O

In [None]:
patience = 20
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(200):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()


        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.2538
Epoch [1/1000], Validation Loss: 2.0663, Validation Accuracy: 29.65%
Output Summary: Max=1.5856, Min=-1.4284, Median=0.0883, Mean=0.1063

Epoch [2/1000], Training Loss: 1.5564
Epoch [2/1000], Validation Loss: 1.0894, Validation Accuracy: 58.70%
Output Summary: Max=16.0737, Min=-8.0528, Median=3.4470, Mean=3.3465

Epoch [3/1000], Training Loss: 0.9512
Epoch [3/1000], Validation Loss: 0.8822, Validation Accuracy: 66.97%
Output Summary: Max=20.5775, Min=-11.3086, Median=5.0378, Mean=4.6663

Epoch [4/1000], Training Loss: 0.8276
Epoch [4/1000], Validation Loss: 0.7972, Validation Accuracy: 69.94%
Output Summary: Max=22.5915, Min=-12.6064, Median=5.7124, Mean=5.2277

Epoch [5/1000], Training Loss: 0.7596
Epoch [5/1000], Validation Loss: 0.7475, Validation Accuracy: 71.02%
Output Summary: Max=23.6714, Min=-14.5858, Median=6.1398, Mean=5.5061

Epoch [6/1000], Training Loss: 0.7138
Epoch [6/1000], Validation Loss: 0.7141, Validation Accuracy: 72.58%
Output

In [60]:
patience = 20
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(200):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()


        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 1.7081
Epoch [1/1000], Validation Loss: 0.9695, Validation Accuracy: 64.26%
Output Summary: Max=18.9550, Min=-17.7916, Median=2.9532, Mean=2.6816

Epoch [2/1000], Training Loss: 0.8493
Epoch [2/1000], Validation Loss: 0.7686, Validation Accuracy: 71.64%
Output Summary: Max=21.4130, Min=-23.1725, Median=3.5562, Mean=2.6070

Epoch [3/1000], Training Loss: 0.7148
Epoch [3/1000], Validation Loss: 0.6822, Validation Accuracy: 74.29%
Output Summary: Max=20.8488, Min=-18.8119, Median=3.2916, Mean=2.4231

Epoch [4/1000], Training Loss: 0.6425
Epoch [4/1000], Validation Loss: 0.6331, Validation Accuracy: 76.68%
Output Summary: Max=20.0240, Min=-17.0283, Median=2.9606, Mean=2.2373

Epoch [5/1000], Training Loss: 0.5873
Epoch [5/1000], Validation Loss: 0.5805, Validation Accuracy: 78.63%
Output Summary: Max=21.2734, Min=-17.3467, Median=3.0032, Mean=2.3032

Epoch [6/1000], Training Loss: 0.5582
Epoch [6/1000], Validation Loss: 0.5695, Validation Accuracy: 79.12%
Out

In [61]:
patience = 20
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(200):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()


        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 0.2124
Epoch [1/1000], Validation Loss: 0.3223, Validation Accuracy: 88.62%
Output Summary: Max=37.2345, Min=-33.2681, Median=0.7298, Mean=0.7523

Epoch [2/1000], Training Loss: 0.2127
Epoch [2/1000], Validation Loss: 0.3344, Validation Accuracy: 88.15%
Output Summary: Max=37.0837, Min=-33.0560, Median=0.6679, Mean=0.7063

Epoch [3/1000], Training Loss: 0.2153
Epoch [3/1000], Validation Loss: 0.3228, Validation Accuracy: 88.87%
Output Summary: Max=34.9332, Min=-32.7047, Median=0.7557, Mean=0.7112

Epoch [4/1000], Training Loss: 0.2118
Epoch [4/1000], Validation Loss: 0.3239, Validation Accuracy: 88.56%
Output Summary: Max=34.5521, Min=-33.6289, Median=0.5247, Mean=0.5905

Epoch [5/1000], Training Loss: 0.2117
Epoch [5/1000], Validation Loss: 0.3217, Validation Accuracy: 88.77%
Output Summary: Max=34.6413, Min=-34.3847, Median=0.3629, Mean=0.5237

Epoch [6/1000], Training Loss: 0.2150
Epoch [6/1000], Validation Loss: 0.3180, Validation Accuracy: 88.92%
Out

In [24]:
patience = 20
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(200):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()


        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 1.8539
Epoch [1/1000], Validation Loss: 1.0479, Validation Accuracy: 63.64%
Output Summary: Max=13.4919, Min=-10.8509, Median=0.7589, Mean=0.8177

Epoch [2/1000], Training Loss: 0.9139
Epoch [2/1000], Validation Loss: 0.8332, Validation Accuracy: 69.59%
Output Summary: Max=15.8360, Min=-20.6593, Median=0.9988, Mean=0.5593

Epoch [3/1000], Training Loss: 0.7718
Epoch [3/1000], Validation Loss: 0.7377, Validation Accuracy: 73.09%
Output Summary: Max=17.5524, Min=-21.9330, Median=1.3413, Mean=0.4340

Epoch [4/1000], Training Loss: 0.6963
Epoch [4/1000], Validation Loss: 0.7229, Validation Accuracy: 73.43%
Output Summary: Max=18.3551, Min=-22.0460, Median=1.6436, Mean=0.5397

Epoch [5/1000], Training Loss: 0.6553
Epoch [5/1000], Validation Loss: 0.7225, Validation Accuracy: 73.15%
Output Summary: Max=19.3050, Min=-23.2021, Median=1.9817, Mean=0.6588

Epoch [6/1000], Training Loss: 0.6309
Epoch [6/1000], Validation Loss: 0.6526, Validation Accuracy: 75.31%
Out

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ExperimentalModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 4 * 8, kernel_size=2, stride=2)       # 28x28 -> 14x14
        self.conv2 = nn.Conv2d(8 * 4, 8 * 8, kernel_size=2, stride=2)   # 14x14 -> 7x7
        self.conv3 = nn.Conv2d(8 * 8, 8 * 8, kernel_size=2, stride=1)   # 7x7   -> 6x6
        self.conv4 = nn.Conv2d(8 * 8, 16 * 8, kernel_size=2, stride=2)  # 6x6   -> 3x3
        self.conv5 = nn.Conv2d(8 * 16, 16 * 8, kernel_size=2, stride=1) # 3x3   -> 2x2
        self.conv6 = nn.Conv2d(8 * 16, 32 * 8, kernel_size=2, stride=1) # 2x2   -> 1x1

        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(32 * 8, 10)

    def _conv_relu(self, x, conv_layer):
        return self.relu(conv_layer(x))

    def forward(self, x):
        x = self._conv_relu(x, self.conv1)
        x = self._conv_relu(x, self.conv2)
        x = self._conv_relu(x, self.conv3)
        x = self._conv_relu(x, self.conv4)
        x = self._conv_relu(x, self.conv5)
        x = self._conv_relu(x, self.conv6)

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [26]:
learning_rate = 0.001 * 1
epochs = 1000

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ExperimentalModel().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [27]:
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()} params, requires_grad={param.requires_grad}")

total_params = sum(p.numel() for p in model.parameters())
print()
print(total_params)

conv1.weight: 128 params, requires_grad=True
conv1.bias: 32 params, requires_grad=True
conv2.weight: 8192 params, requires_grad=True
conv2.bias: 64 params, requires_grad=True
conv3.weight: 16384 params, requires_grad=True
conv3.bias: 64 params, requires_grad=True
conv4.weight: 32768 params, requires_grad=True
conv4.bias: 128 params, requires_grad=True
conv5.weight: 65536 params, requires_grad=True
conv5.bias: 128 params, requires_grad=True
conv6.weight: 131072 params, requires_grad=True
conv6.bias: 256 params, requires_grad=True
fc.weight: 2560 params, requires_grad=True
fc.bias: 10 params, requires_grad=True

257322


In [28]:
patience = 20
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(1000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()


        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 1.6043
Epoch [1/1000], Validation Loss: 0.8719, Validation Accuracy: 67.78%
Output Summary: Max=13.6333, Min=-18.7842, Median=0.8420, Mean=-0.0718

Epoch [2/1000], Training Loss: 0.7601
Epoch [2/1000], Validation Loss: 0.7028, Validation Accuracy: 73.42%
Output Summary: Max=15.4859, Min=-22.1852, Median=1.0870, Mean=-0.4387

Epoch [3/1000], Training Loss: 0.6564
Epoch [3/1000], Validation Loss: 0.6429, Validation Accuracy: 74.93%
Output Summary: Max=17.0843, Min=-23.9776, Median=1.4523, Mean=-0.4943

Epoch [4/1000], Training Loss: 0.6023
Epoch [4/1000], Validation Loss: 0.6106, Validation Accuracy: 76.33%
Output Summary: Max=18.0315, Min=-25.8195, Median=1.4430, Mean=-0.5481

Epoch [5/1000], Training Loss: 0.5609
Epoch [5/1000], Validation Loss: 0.5582, Validation Accuracy: 79.34%
Output Summary: Max=18.4508, Min=-26.6314, Median=1.3929, Mean=-0.5886

Epoch [6/1000], Training Loss: 0.5211
Epoch [6/1000], Validation Loss: 0.5224, Validation Accuracy: 80.69