In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchinfo import summary

In [3]:
# Pick the compute backend: CUDA if present, else Apple-silicon MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

mps


In [None]:
# ToTensor is the only transform: images become float tensors (scaled to [0, 1]).
# NOTE(review): the cells below read the raw `.data` tensors directly, so this
# transform only applies when the datasets themselves are indexed/iterated.
transform = transforms.Compose([transforms.ToTensor()])

# Build the train split first, then the test split (both downloaded to ./data if missing).
train_dataset, test_dataset = (
    datasets.FashionMNIST(root='./data', train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)

In [5]:
batch_size = 10000

# NOTE(review): these loaders appear unused — the loops batch the device tensors below instead.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Pre-load each split onto `device` once: pixel values scaled to [0, 1] as float,
# plus an explicit channel axis at dim 1 (presumably (N, 28, 28) -> (N, 1, 28, 28)).
train_data, test_data = (
    (split.data.to(device).float() / 255.0).unsqueeze(1)
    for split in (train_dataset, test_dataset)
)
train_targets, test_targets = (
    split.targets.to(device) for split in (train_dataset, test_dataset)
)

def get_batches(data, targets, batch_size):
    """Lazily yield aligned ``(data, targets)`` slices of up to ``batch_size`` rows.

    Slices are produced in order (no shuffling); the last one is shorter when
    ``len(data)`` is not a multiple of ``batch_size``. ``targets`` is cut with the
    same bounds as ``data`` and is assumed to be the same length — caller's duty.
    """
    yield from (
        (data[lo:lo + batch_size], targets[lo:lo + batch_size])
        for lo in range(0, len(data), batch_size)
    )

In [90]:
import torch.nn.functional as F

class ExperimentalModel(nn.Module):
    """CNN whose tail convs each see the running mean of all earlier feature maps.

    Stem: three 2x2 convolutions with padding=0 (1 -> 2 -> 4 -> 8 channels), each
    followed by ReLU; each one trims one pixel off every spatial side.
    Tail: ``num_blocks`` 3x3 convolutions with padding=1 (8 -> 8 channels, spatial
    size preserved), registered as ``conv4`` ... ``conv{3 + num_blocks}`` so that
    parameter names / state_dict keys match the original hand-unrolled layers.
    Each tail conv receives the element-wise mean of the stem output and every
    earlier tail output; the mean over all of them is flattened into one Linear.

    Args:
        in_size: side length of the square single-channel input (28 for the
            FashionMNIST tensors used in this notebook).
        num_classes: number of output logits.
        num_blocks: number of 3x3 tail convs (15 reproduces conv4..conv18).

    Raises:
        ValueError: if ``in_size`` is too small to survive the three 2x2 stem convs.
    """

    def __init__(self, in_size: int = 28, num_classes: int = 10, num_blocks: int = 15):
        super().__init__()
        if in_size < 4:
            raise ValueError(f"in_size must be >= 4 (three 2x2 valid convs), got {in_size}")

        self.conv1 = nn.Conv2d(1, 2, kernel_size=2, padding=0)
        self.conv2 = nn.Conv2d(2, 4, kernel_size=2, padding=0)
        self.conv3 = nn.Conv2d(4, 8, kernel_size=2, padding=0)

        # Register by attribute name (not nn.ModuleList) so conv4.weight, conv5.weight, ...
        # keep exactly the names and registration order of the unrolled original.
        self._block_names = [f"conv{i}" for i in range(4, 4 + num_blocks)]
        for block_name in self._block_names:
            setattr(self, block_name, nn.Conv2d(8, 8, kernel_size=3, padding=1))

        # Stem shrinks each side by 3 in total; the padded 3x3 tail keeps the size.
        # For in_size=28 this is 8 * 25 * 25 = 5000, the original hard-coded width.
        feat_side = in_size - 3
        self.fc = nn.Linear(8 * feat_side * feat_side, num_classes)

    def forward(self, x):
        """Return class logits; expects x shaped (N, 1, in_size, in_size)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        rolling = F.relu(self.conv3(x))  # running *sum* of feature maps
        count = 1  # how many maps have been summed into `rolling`
        for block_name in self._block_names:
            # Feed the mean of everything so far, then add this block's output to the sum.
            block_out = F.relu(getattr(self, block_name)(rolling / count))
            rolling = rolling + block_out
            count += 1

        pooled = rolling / count  # mean of stem output and all tail outputs
        return self.fc(pooled.flatten(1))


In [91]:
# Optimisation hyperparameters.
learning_rate = 1e-3
epochs = 1000  # nominal epoch budget, referenced by the training cell
SEED = 0

# Seed torch's global RNG before building the model so the weight initialisation
# (done on CPU, then moved to `device`) is reproducible on Restart & Run All.
torch.manual_seed(SEED)

model = ExperimentalModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [92]:
# Print one line per parameter tensor (size and trainability), then the overall count.
total_params = 0
for name, param in model.named_parameters():
    total_params += param.numel()
    print(f"{name}: {param.numel()} params, requires_grad={param.requires_grad}")

print()
print(total_params)

conv1.weight: 8 params, requires_grad=True
conv1.bias: 2 params, requires_grad=True
conv2.weight: 32 params, requires_grad=True
conv2.bias: 4 params, requires_grad=True
conv3.weight: 128 params, requires_grad=True
conv3.bias: 8 params, requires_grad=True
conv4.weight: 576 params, requires_grad=True
conv4.bias: 8 params, requires_grad=True
conv5.weight: 576 params, requires_grad=True
conv5.bias: 8 params, requires_grad=True
conv6.weight: 576 params, requires_grad=True
conv6.bias: 8 params, requires_grad=True
conv7.weight: 576 params, requires_grad=True
conv7.bias: 8 params, requires_grad=True
conv8.weight: 576 params, requires_grad=True
conv8.bias: 8 params, requires_grad=True
conv9.weight: 576 params, requires_grad=True
conv9.bias: 8 params, requires_grad=True
conv10.weight: 576 params, requires_grad=True
conv10.bias: 8 params, requires_grad=True
conv11.weight: 576 params, requires_grad=True
conv11.bias: 8 params, requires_grad=True
conv12.weight: 576 params, requires_grad=True
conv12.

In [93]:
# Early-stopping configuration.
patience = 20
best_val_loss = float('inf')
no_improvement_epochs = 0
best_state = None  # copy of the weights at the lowest validation loss seen so far
best_epoch = None

# Raw validation logits from every epoch, kept for later inspection.
all_outputs = []

# NOTE(review): early stopping monitors the *test* split, so the accuracy reported
# here is an optimistic estimate; a held-out slice of the training set avoids this.

n_train = train_data.size(0)

for epoch in range(epochs):
    # ---- training ----
    model.train()
    running_loss = 0.0
    seen = 0

    # Reshuffle every epoch so batch composition and order vary between epochs.
    perm = torch.randperm(n_train).to(train_data.device)
    for idx in perm.split(batch_size):
        data, target = train_data[idx], train_targets[idx]
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss defaults to a per-batch mean; weight by batch size so a
        # short final batch does not skew the epoch average.
        running_loss += loss.item() * target.size(0)
        seen += target.size(0)

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / seen:.4f}")

    # ---- evaluation ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item() * target.size(0)
            correct += (outputs.argmax(dim=1) == target).sum().item()
            total += target.size(0)
            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = all_outputs_tensor.max().item()
    min_val = all_outputs_tensor.min().item()
    median_val = all_outputs_tensor.median().item()
    mean_val = all_outputs_tensor.mean().item()

    accuracy = 100 * correct / total
    val_loss /= total
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

# Leave `model` at its best checkpoint rather than the last (up to `patience`
# epochs worse) one, so later cells evaluate the selected weights.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored weights from epoch {best_epoch} (validation loss {best_val_loss:.4f}).")

Epoch [1/1000], Training Loss: 2.3062
Epoch [1/1000], Validation Loss: 2.3045, Validation Accuracy: 10.00%
Output Summary: Max=0.1133, Min=-0.1266, Median=0.0067, Mean=-0.0032

Epoch [2/1000], Training Loss: 2.3030
Epoch [2/1000], Validation Loss: 2.3012, Validation Accuracy: 10.00%
Output Summary: Max=0.0618, Min=-0.0863, Median=-0.0363, Mean=-0.0177

Epoch [3/1000], Training Loss: 2.2994
Epoch [3/1000], Validation Loss: 2.2967, Validation Accuracy: 10.09%
Output Summary: Max=0.0488, Min=-0.0505, Median=-0.0032, Mean=-0.0068

Epoch [4/1000], Training Loss: 2.2935
Epoch [4/1000], Validation Loss: 2.2869, Validation Accuracy: 15.29%
Output Summary: Max=0.0887, Min=-0.0598, Median=-0.0167, Mean=-0.0097

Epoch [5/1000], Training Loss: 2.2783
Epoch [5/1000], Validation Loss: 2.2617, Validation Accuracy: 53.51%
Output Summary: Max=0.0953, Min=-0.1250, Median=-0.0109, Mean=-0.0102

Epoch [6/1000], Training Loss: 2.2377
Epoch [6/1000], Validation Loss: 2.1900, Validation Accuracy: 53.48%
Outp

KeyboardInterrupt: 