In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


ModuleNotFoundError: No module named 'torch'

In [None]:
# Hyperparameters and synthetic-data configuration.
num_classes = 10
batch_size = 16
num_samples = 128          # size of the fake dataset
image_shape = (3, 64, 64)  # (C, H, W); BaselineCNN's classifier is sized for 64x64 inputs
seed = 42                  # makes the fake data and the shuffle order reproducible

torch.manual_seed(seed)

# Dummy inputs: X is [N, C, H, W] random "images"; y holds N integer labels in [0, num_classes).
X = torch.randn(num_samples, *image_shape)
y = torch.randint(0, num_classes, (num_samples,))

dataset = TensorDataset(X, y)
# A dedicated, seeded generator makes the shuffle order reproducible regardless
# of how much of the global torch RNG other cells consume later.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(seed),
)


In [None]:
class BaselineCNN(nn.Module):
    """Small three-block convolutional classifier.

    Each block is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2, stride=2).
    The convolution keeps the spatial size and the pool halves it (floored).
    Channels go in_channels -> 16 -> 32 -> 64. The flattened feature map feeds
    a two-layer MLP head that returns raw logits (no softmax).

    Args:
        num_classes: number of output logits.
        in_channels: number of channels in the input images (3 for RGB).
        input_size: height/width of the square input images. The first
            Linear layer is sized from this, so inputs must match it.

    Raises:
        ValueError: if input_size < 8, because the three poolings would
            shrink the feature map to zero.

    Input:  [B, in_channels, input_size, input_size]
    Output: [B, num_classes] logits
    """

    def __init__(self, num_classes=10, in_channels=3, input_size=64):
        super().__init__()
        if input_size < 8:
            raise ValueError(
                f"input_size must be >= 8 (three 2x poolings), got {input_size}"
            )

        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Block 2
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Block 3
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Three stride-2 poolings floor-divide each spatial dim by 8
        # (64x64 -> 8x8 by default), so the head sees 64 * s * s features.
        feat_size = input_size // 8
        self.classifier = nn.Sequential(
            nn.Flatten(),                                   # [B, 64*s*s]
            nn.Linear(64 * feat_size * feat_size, 128),     # dense layer
            nn.ReLU(),
            nn.Linear(128, num_classes),                    # logits
        )

    def forward(self, x):
        """Map a batch of images to class logits of shape [B, num_classes]."""
        x = self.features(x)
        return self.classifier(x)

# Build the baseline network for the configured class count and print its
# layer-by-layer summary (the nn.Module repr).
model = BaselineCNN(num_classes)
print(repr(model))


In [None]:
# Loss and optimizer. CrossEntropyLoss takes the raw logits produced by the
# model together with integer class targets.
learning_rate = 1e-3  # Adam step size

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [None]:
# Train for `num_epochs` passes over `loader`. Print the loss of every
# `log_every`-th batch and the sample-weighted mean loss of each epoch.
num_epochs = 1  # full passes over the dataset
log_every = 5   # print the batch loss every N batches

model.train()  # put the module in training mode before updating weights

for epoch in range(num_epochs):
    running_loss = 0.0
    samples_seen = 0

    for batch_idx, (inputs, targets) in enumerate(loader):
        # 1. zero gradients from previous step
        optimizer.zero_grad()

        # 2. forward pass
        outputs = model(inputs)         # shape [B, num_classes]

        # 3. compute loss (criterion uses its default mean reduction over the batch)
        loss = criterion(outputs, targets)

        # 4. backward pass
        loss.backward()

        # 5. update weights
        optimizer.step()

        # Read the scalar once, and weight it by the batch size so a smaller
        # final batch does not skew the epoch mean.
        batch_loss = loss.item()
        running_loss += batch_loss * inputs.size(0)
        samples_seen += inputs.size(0)

        if batch_idx % log_every == 0:
            print(f"epoch {epoch}, batch {batch_idx}, loss = {batch_loss:.4f}")

    if samples_seen:  # guard against an empty loader
        print(f"epoch {epoch} done, mean loss = {running_loss / samples_seen:.4f}")
