In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [3]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Device: {device}")

Device: mps


In [None]:
class AlexNetScratch(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNetScratch, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

if __name__ == "__main__":
    model = AlexNetScratch(num_classes=10)
    print(model)

In [None]:
batch_size = 64

# Transformaciones para el dataset CIFAR-10
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Cargar el dataset CIFAR-10
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [12]:
learning_rate = 0.001
num_epochs = 10

model = AlexNetScratch(num_classes=10).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

def train(model, train_loader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 100 == 99:  # print every 100 batches
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}')
            running_loss = 0.0

def evaluate(model, test_loader, criterion):
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

for epoch in range(num_epochs):
    train(model, train_loader, optimizer, criterion, epoch)
    evaluate(model, test_loader, criterion)

Epoch [1/10], Step [100/782], Loss: 2.1736
Epoch [1/10], Step [200/782], Loss: 1.9335
Epoch [1/10], Step [300/782], Loss: 1.7618
Epoch [1/10], Step [400/782], Loss: 1.6946
Epoch [1/10], Step [500/782], Loss: 1.6338
Epoch [1/10], Step [600/782], Loss: 1.5789
Epoch [1/10], Step [700/782], Loss: 1.5516
Test Loss: 0.0232, Accuracy: 46.65%
Epoch [2/10], Step [100/782], Loss: 1.4646
Epoch [2/10], Step [200/782], Loss: 1.4358
Epoch [2/10], Step [300/782], Loss: 1.4316
Epoch [2/10], Step [400/782], Loss: 1.4077
Epoch [2/10], Step [500/782], Loss: 1.3881
Epoch [2/10], Step [600/782], Loss: 1.3502
Epoch [2/10], Step [700/782], Loss: 1.3721
Test Loss: 0.0198, Accuracy: 54.94%
Epoch [3/10], Step [100/782], Loss: 1.2665
Epoch [3/10], Step [200/782], Loss: 1.2795
Epoch [3/10], Step [300/782], Loss: 1.2732
Epoch [3/10], Step [400/782], Loss: 1.2797
Epoch [3/10], Step [500/782], Loss: 1.2845
Epoch [3/10], Step [600/782], Loss: 1.2130
Epoch [3/10], Step [700/782], Loss: 1.2564
Test Loss: 0.0186, Accura

In [13]:
model_path = "alexnet_scratch_pytorch.pt"
torch.save(model.state_dict(), model_path)