In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchinfo import summary

In [360]:
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [361]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_data = train_dataset.data.to(device).float() / 255.0
train_targets = train_dataset.targets.to(device)

test_data = test_dataset.data.to(device).float() / 255.0
test_targets = test_dataset.targets.to(device)

train_data = train_data.unsqueeze(1)
test_data = test_data.unsqueeze(1)

def get_batches(data, targets, batch_size):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size], targets[i:i + batch_size]

transform = transforms.Compose([
    transforms.ToTensor(),
])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [420]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=3)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=4, stride=1, padding=3)
        self.fc1 = nn.Linear(64 * 6 * 6, 10)

        self.aug_pool = nn.AvgPool2d(kernel_size=2, stride=1, padding=1)
        self.enable_aug = True
        
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

        nn.init.constant_(self.fc1.weight, 0)
        nn.init.constant_(self.fc1.bias, 0)

    def forward(self, x, i):
        # if self.enable_aug:
        #     x = self.aug_pool(x)
        
        if i < 300:
            x = x + (1 - i / 300) * torch.randn_like(x)
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.relu(self.conv3(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 6 * 6)
        x = self.fc1(x)
        return x

batch_size = 6000
learning_rate = 0.001 * 0.5
epochs = 1000

transform = transforms.Compose([
    transforms.ToTensor(),
])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

print(len(train_loader.dataset))
print(len(test_loader.dataset))

60000
10000


In [421]:
model.enable_aug = False
print(model.enable_aug)

False


In [422]:
patience = 100
best_val_loss = float('inf')
no_improvement_epochs = 0

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data, epoch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data, 300)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.2914
Epoch [1/1000], Validation Loss: 2.2739, Validation Accuracy: 47.48%

Epoch [2/1000], Training Loss: 2.1885
Epoch [2/1000], Validation Loss: 2.1160, Validation Accuracy: 42.61%

Epoch [3/1000], Training Loss: 1.8772
Epoch [3/1000], Validation Loss: 1.8183, Validation Accuracy: 45.92%

Epoch [4/1000], Training Loss: 1.5527
Epoch [4/1000], Validation Loss: 1.5923, Validation Accuracy: 49.16%

Epoch [5/1000], Training Loss: 1.3982
Epoch [5/1000], Validation Loss: 1.4350, Validation Accuracy: 52.83%

Epoch [6/1000], Training Loss: 1.3074
Epoch [6/1000], Validation Loss: 1.2742, Validation Accuracy: 58.31%

Epoch [7/1000], Training Loss: 1.2410
Epoch [7/1000], Validation Loss: 1.1368, Validation Accuracy: 62.62%

Epoch [8/1000], Training Loss: 1.1827
Epoch [8/1000], Validation Loss: 1.0369, Validation Accuracy: 65.46%

Epoch [9/1000], Training Loss: 1.1295
Epoch [9/1000], Validation Loss: 0.9315, Validation Accuracy: 72.28%

Epoch [10/1000], Training Lo

KeyboardInterrupt: 

In [391]:
patience = 10
best_val_loss = float('inf')
no_improvement_epochs = 0

for epoch in range(epochs):
    model.train()
    running_loss = 0.0 
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/100], Training Loss: 0.0315
Epoch [1/100], Validation Loss: 0.0347, Validation Accuracy: 98.76%

Epoch [2/100], Training Loss: 0.0312
Epoch [2/100], Validation Loss: 0.0346, Validation Accuracy: 98.76%

Epoch [3/100], Training Loss: 0.0309
Epoch [3/100], Validation Loss: 0.0345, Validation Accuracy: 98.75%

Epoch [4/100], Training Loss: 0.0306
Epoch [4/100], Validation Loss: 0.0344, Validation Accuracy: 98.76%

Epoch [5/100], Training Loss: 0.0303
Epoch [5/100], Validation Loss: 0.0343, Validation Accuracy: 98.76%

Epoch [6/100], Training Loss: 0.0300
Epoch [6/100], Validation Loss: 0.0342, Validation Accuracy: 98.78%

Epoch [7/100], Training Loss: 0.0297
Epoch [7/100], Validation Loss: 0.0342, Validation Accuracy: 98.78%

Epoch [8/100], Training Loss: 0.0294
Epoch [8/100], Validation Loss: 0.0341, Validation Accuracy: 98.78%

Epoch [9/100], Training Loss: 0.0291
Epoch [9/100], Validation Loss: 0.0340, Validation Accuracy: 98.78%

Epoch [10/100], Training Loss: 0.0288
Epoch [1

In [385]:
patience = 10
best_val_loss = float('inf')
no_improvement_epochs = 0

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/100], Training Loss: 2.2795
Epoch [1/100], Validation Loss: 2.2230, Validation Accuracy: 39.66%

Epoch [2/100], Training Loss: 2.1031
Epoch [2/100], Validation Loss: 1.8653, Validation Accuracy: 69.79%

Epoch [3/100], Training Loss: 1.5604
Epoch [3/100], Validation Loss: 1.1186, Validation Accuracy: 74.33%

Epoch [4/100], Training Loss: 0.9090
Epoch [4/100], Validation Loss: 0.6861, Validation Accuracy: 79.13%

Epoch [5/100], Training Loss: 0.6496
Epoch [5/100], Validation Loss: 0.5461, Validation Accuracy: 83.10%

Epoch [6/100], Training Loss: 0.5338
Epoch [6/100], Validation Loss: 0.4625, Validation Accuracy: 86.27%

Epoch [7/100], Training Loss: 0.4561
Epoch [7/100], Validation Loss: 0.3981, Validation Accuracy: 88.23%

Epoch [8/100], Training Loss: 0.3978
Epoch [8/100], Validation Loss: 0.3481, Validation Accuracy: 89.52%

Epoch [9/100], Training Loss: 0.3518
Epoch [9/100], Validation Loss: 0.3085, Validation Accuracy: 90.79%

Epoch [10/100], Training Loss: 0.3159
Epoch [1

In [386]:
model.enable_aug = False

In [387]:
patience = 10
best_val_loss = float('inf')
no_improvement_epochs = 0

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/100], Training Loss: 0.0503
Epoch [1/100], Validation Loss: 0.0402, Validation Accuracy: 98.58%

Epoch [2/100], Training Loss: 0.0421
Epoch [2/100], Validation Loss: 0.0413, Validation Accuracy: 98.58%

Epoch [3/100], Training Loss: 0.0389
Epoch [3/100], Validation Loss: 0.0354, Validation Accuracy: 98.74%

Epoch [4/100], Training Loss: 0.0367
Epoch [4/100], Validation Loss: 0.0332, Validation Accuracy: 98.81%

Epoch [5/100], Training Loss: 0.0348
Epoch [5/100], Validation Loss: 0.0329, Validation Accuracy: 98.79%

Epoch [6/100], Training Loss: 0.0339
Epoch [6/100], Validation Loss: 0.0328, Validation Accuracy: 98.81%

Epoch [7/100], Training Loss: 0.0330
Epoch [7/100], Validation Loss: 0.0325, Validation Accuracy: 98.84%

Epoch [8/100], Training Loss: 0.0323
Epoch [8/100], Validation Loss: 0.0321, Validation Accuracy: 98.81%

Epoch [9/100], Training Loss: 0.0315
Epoch [9/100], Validation Loss: 0.0320, Validation Accuracy: 98.83%

Epoch [10/100], Training Loss: 0.0309
Epoch [1

In [341]:
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=1, padding=3)
        self.fc1 = nn.Linear(32 * 5 * 5, 10)
        
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

        nn.init.constant_(self.fc1.weight, 0)
        nn.init.constant_(self.fc1.bias, 0)

    def forward(self, x):
        # x = x.view(x.size(0), 1, 7, 4, 7, 4).mean(dim=(3, 5))
        x = x.view(x.size(0), 1, 7, 4, 7, 4).max(dim=3)[0].max(dim=4)[0]

        x = self.relu(self.conv1(x))
        x = self.pool(x)

        x = x.view(-1, 32 * 5 * 5)
        x = self.fc1(x)
        return x

batch_size = 6000
learning_rate = 0.001 * 5
epochs = 100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [342]:
patience = 10
best_val_loss = float('inf')
no_improvement_epochs = 0

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/100], Training Loss: 1.8729
Epoch [1/100], Validation Loss: 1.3403, Validation Accuracy: 69.92%

Epoch [2/100], Training Loss: 1.0685
Epoch [2/100], Validation Loss: 0.7984, Validation Accuracy: 75.89%

Epoch [3/100], Training Loss: 0.7187
Epoch [3/100], Validation Loss: 0.6105, Validation Accuracy: 81.45%

Epoch [4/100], Training Loss: 0.5833
Epoch [4/100], Validation Loss: 0.5233, Validation Accuracy: 83.99%

Epoch [5/100], Training Loss: 0.5124
Epoch [5/100], Validation Loss: 0.4765, Validation Accuracy: 85.15%

Epoch [6/100], Training Loss: 0.4697
Epoch [6/100], Validation Loss: 0.4443, Validation Accuracy: 86.02%

Epoch [7/100], Training Loss: 0.4417
Epoch [7/100], Validation Loss: 0.4215, Validation Accuracy: 86.56%

Epoch [8/100], Training Loss: 0.4209
Epoch [8/100], Validation Loss: 0.4052, Validation Accuracy: 86.98%

Epoch [9/100], Training Loss: 0.4051
Epoch [9/100], Validation Loss: 0.3929, Validation Accuracy: 87.51%

Epoch [10/100], Training Loss: 0.3921
Epoch [1

In [343]:
for param in model.parameters():
    param.requires_grad = False

model.eval()

SimpleCNN(
  (conv1): Conv2d(1, 32, kernel_size=(4, 4), stride=(1, 1), padding=(3, 3))
  (fc1): Linear(in_features=800, out_features=10, bias=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
)

In [356]:
class SimpleCNN_2(nn.Module):
    def __init__(self):
        super(SimpleCNN_2, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=2)
        self.fc1 = nn.Linear(32 * 8 * 8, 10)

        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.aug_pool = nn.AvgPool2d(kernel_size=2, stride=1, padding=1)

        nn.init.constant_(self.fc1.weight, 0)
        nn.init.constant_(self.fc1.bias, 0)
        
    def forward(self, x):
        with torch.no_grad():
            to_add = model(x)
            
        # x = x.view(x.size(0), 1, 14, 2, 14, 2).mean(dim=(3, 5))
        # x = x.view(x.size(0), 1, 14, 2, 14, 2).max(dim=3)[0].max(dim=4)[0]
        
        x = self.relu(self.conv1(x))
        x = self.pool(x)

        x = x.view(-1, 32 * 8 * 8)
        x = self.fc1(x)

        return x + to_add

batch_size = 6000
learning_rate = 0.001 * 1
epochs = 100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_2 = SimpleCNN_2().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_2.parameters(), lr=learning_rate)

In [357]:
patience = 10
best_val_loss = float('inf')
no_improvement_epochs = 0

for epoch in range(epochs):
    model_2.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model_2(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model_2(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/100], Training Loss: 0.2326
Epoch [1/100], Validation Loss: 0.2577, Validation Accuracy: 91.66%

Epoch [2/100], Training Loss: 0.2236
Epoch [2/100], Validation Loss: 0.2466, Validation Accuracy: 91.93%

Epoch [3/100], Training Loss: 0.2141
Epoch [3/100], Validation Loss: 0.2375, Validation Accuracy: 92.18%

Epoch [4/100], Training Loss: 0.2043
Epoch [4/100], Validation Loss: 0.2263, Validation Accuracy: 92.67%

Epoch [5/100], Training Loss: 0.1946
Epoch [5/100], Validation Loss: 0.2160, Validation Accuracy: 93.05%

Epoch [6/100], Training Loss: 0.1851
Epoch [6/100], Validation Loss: 0.2059, Validation Accuracy: 93.50%

Epoch [7/100], Training Loss: 0.1761
Epoch [7/100], Validation Loss: 0.1964, Validation Accuracy: 93.78%

Epoch [8/100], Training Loss: 0.1678
Epoch [8/100], Validation Loss: 0.1879, Validation Accuracy: 94.02%

Epoch [9/100], Training Loss: 0.1604
Epoch [9/100], Validation Loss: 0.1800, Validation Accuracy: 94.25%

Epoch [10/100], Training Loss: 0.1538
Epoch [1

In [208]:
for name, param in model_2.named_parameters():
    print(name)
    print(param)

conv1.weight
Parameter containing:
tensor([[[[-0.2534,  0.1643,  0.2416],
          [ 0.2667,  0.3875, -0.0827],
          [ 0.4322, -0.0697, -0.3061]]],


        [[[-0.0567, -0.0488,  0.2234],
          [-0.0501, -0.4290, -0.2631],
          [ 0.4309,  0.2795,  0.0157]]],


        [[[ 0.0650,  0.1302,  0.0565],
          [-0.2735, -0.1238,  0.3030],
          [ 0.0147, -0.4208,  0.2622]]],


        [[[ 0.3591, -0.2588,  0.2400],
          [-0.0104, -0.2382,  0.0424],
          [ 0.2966, -0.4755, -0.0138]]],


        [[[ 0.1702, -0.0504, -0.2481],
          [-0.3263, -0.2601,  0.3250],
          [ 0.0112,  0.3293,  0.0563]]],


        [[[-0.1517, -0.1362,  0.3226],
          [ 0.1773,  0.2505, -0.1789],
          [-0.2960,  0.3635,  0.1897]]],


        [[[ 0.0162, -0.2020, -0.7532],
          [ 0.3767, -0.1904, -0.7365],
          [ 0.2741, -0.0306, -0.5199]]],


        [[[-0.2556,  0.2600, -0.2233],
          [-0.0364,  0.4169, -0.2390],
          [ 0.0252, -0.0921, -0.1873]]],

In [202]:
for param in model_2.parameters():
    param.requires_grad = False

model_2.eval()

SimpleCNN_2(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=576, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=10, bias=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
)

In [203]:
class SimpleCNN_3(nn.Module):
    def __init__(self):
        super(SimpleCNN_3, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

        nn.init.constant_(self.fc2.weight, 0)
        nn.init.constant_(self.fc2.bias, 0)

    def forward(self, x):
        with torch.no_grad():
            to_add = model_2(x)
            
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 7 * 7)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return to_add + x

In [204]:
batch_size = 6000
learning_rate = 0.001
epochs = 100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_3 = SimpleCNN_3().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_3.parameters(), lr=learning_rate)

In [205]:
patience = 10
best_val_loss = float('inf')
no_improvement_epochs = 0

for epoch in range(epochs):
    model_3.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model_3(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model_3.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model_3(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/100], Training Loss: 0.0249
Epoch [1/100], Validation Loss: 0.0436, Validation Accuracy: 98.40%

Epoch [2/100], Training Loss: 0.0247
Epoch [2/100], Validation Loss: 0.0435, Validation Accuracy: 98.42%

Epoch [3/100], Training Loss: 0.0246
Epoch [3/100], Validation Loss: 0.0435, Validation Accuracy: 98.41%

Epoch [4/100], Training Loss: 0.0246
Epoch [4/100], Validation Loss: 0.0435, Validation Accuracy: 98.41%

Epoch [5/100], Training Loss: 0.0246
Epoch [5/100], Validation Loss: 0.0435, Validation Accuracy: 98.42%

Epoch [6/100], Training Loss: 0.0245
Epoch [6/100], Validation Loss: 0.0434, Validation Accuracy: 98.43%

Epoch [7/100], Training Loss: 0.0244
Epoch [7/100], Validation Loss: 0.0432, Validation Accuracy: 98.43%

Epoch [8/100], Training Loss: 0.0243
Epoch [8/100], Validation Loss: 0.0430, Validation Accuracy: 98.46%

Epoch [9/100], Training Loss: 0.0240
Epoch [9/100], Validation Loss: 0.0427, Validation Accuracy: 98.49%

Epoch [10/100], Training Loss: 0.0237
Epoch [1