# In [1]:
import sys
import torch
from torch.utils.data import DataLoader
from torch import nn
import matplotlib.pyplot as plt

import torchvision.datasets as datasets
from torchvision.transforms import ToTensor

# Fashion-MNIST train/test splits, downloaded into ./data if not already present.
# ToTensor() converts each image to a float tensor for the model.
mnist_train = datasets.FashionMNIST(root='./data', download=True, train=True, transform=ToTensor())
mnist_test = datasets.FashionMNIST(root='./data', download=True, train=False, transform=ToTensor())

BATCH_SIZE = 32

# Shuffle only the training data: reshuffling each epoch helps SGD, while a
# fixed test order keeps evaluation runs reproducible (order does not change
# aggregate test metrics, so shuffling it buys nothing).
train_dataloader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)

# Small convolutional classifier for single-channel 28x28 images:
#   Conv2d 1 -> 3 channels, 3x3 kernel, padding=1 (reflect) so the
#     spatial size stays 28x28
#   -> ReLU -> Flatten (3 * 28 * 28 features)
#   -> Linear(3 * 28 * 28 -> 100) -> ReLU -> Linear(100 -> 10)
# The final 10 outputs are one logit per class.
layers = [
    nn.Conv2d(1, 3, kernel_size=3, padding=1, padding_mode="reflect"),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(3 * 28 * 28, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
]
model = nn.Sequential(*layers)

# Smoke-test the model on one training example: unpack the first
# (image, label) pair and add a batch dimension so the input is (1, 1, 28, 28).
first_image, _first_label = mnist_train[0]
image = first_image.reshape(1, 1, 28, 28)

# Forward pass; the result holds one row of 10 logits for the single image.
output = model(image)

# Expected: torch.Size([1, 10])
print(output.shape)

# Output: torch.Size([1, 10])
