# TP: CNN for Binary Classification (1 vs 8)

_From [Dataflowr Module 6](https://dataflowr.github.io/website/modules/6-convolutional-neural-network/) by Marc Lelarge_

In this practical, you will build a Convolutional Neural Network (CNN) that learns filter weights to classify MNIST digits (1 vs 8).

## Setup

In [29]:
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets
from matplotlib import pyplot as plt

device = (
    "cuda:0"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {device}")

Using device: mps


In [30]:
# plot multiple images
def plots(ims, interp=False, titles=None):
    if isinstance(ims, torch.Tensor):
        ims = ims.cpu().numpy()
    elif isinstance(ims, list) and len(ims) > 0 and isinstance(ims[0], torch.Tensor):
        ims = torch.stack(ims).cpu().numpy()
    mn, mx = ims.min(), ims.max()
    f = plt.figure(figsize=(12, 24))
    for i in range(len(ims)):
        sp = f.add_subplot(1, len(ims), i + 1)
        if not titles is None:
            sp.set_title(titles[i], fontsize=18)
        plt.imshow(ims[i], interpolation=None if interp else "none", vmin=mn, vmax=mx)


# plot a single image
def plot(im, interp=False):
    if isinstance(im, torch.Tensor):
        im = im.cpu().numpy()
    f = plt.figure(figsize=(3, 6), frameon=True)
    plt.imshow(im, interpolation=None if interp else "none")


plt.gray()
plt.close()

## Load MNIST Data

In [31]:
root_dir = "./data/MNIST/"
train_set = torchvision.datasets.MNIST(root=root_dir, train=True, download=True)

# Extract images and labels
images = train_set.data.float() / 255
labels = train_set.targets

print(f"Images shape: {images.shape}")
print(f"Labels shape: {labels.shape}")

Images shape: torch.Size([60000, 28, 28])
Labels shape: torch.Size([60000])


The following lines of code implement a data loader for the train set and the test set. No modification is needed.

In [32]:
eights = torch.stack([i for (i, l) in zip(images, labels) if l == 8])
ones = torch.stack([i for (i, l) in zip(images, labels) if l == 1])

In [33]:
bs = 64

l8 = torch.tensor(0, dtype=torch.long)
eights_dataset = [[e.unsqueeze(0), l8] for e in eights]
l1 = torch.tensor(1, dtype=torch.long)
ones_dataset = [[e.unsqueeze(0), l1] for e in ones]
train_dataset = eights_dataset[1000:] + ones_dataset[1000:]
test_dataset = eights_dataset[:1000] + ones_dataset[:1000]

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=True)

## Task: Building a CNN

You will build a neural network that learns the weights of convolutional filters.

### Architecture:

The network should have:
1. **Convolutional layer**: 8 filters of size 3×3 (input: 1 channel, output: 8 channels)
2. **Max Pooling layer**: kernel size 7×7, stride 7 (reduces 28×28 to 4×4)
3. **Flatten**: converts [batch, 8, 4, 4] to [batch, 128]
4. **Linear layer**: maps 128 features to 2 classes (1 vs 8)

### TODO:

Complete the CNN class below. You'll need to:
- Set the correct `padding` value for the convolutional layer
- Implement the `forward` method with the correct sequence of operations

**Hints:**
- Use `F.max_pool2d(x, kernel_size=7, stride=7)` for max pooling
- Use `F.log_softmax(x, dim=1)` for the final output (works with NLLLoss)
- Don't forget to flatten between pooling and the linear layer!

**Documentation:**
- [nn.Conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)
- [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)
- [F.max_pool2d](https://pytorch.org/docs/stable/generated/torch.nn.functional.max_pool2d.html)

In [34]:
class classifier(nn.Module):
    
    def __init__(self):
        super(classifier, self).__init__()
        # TODO: fill the missing padding value
        # Hint: what padding keeps the spatial dimensions unchanged?
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.fc = nn.Linear(in_features=128, out_features=2)
        
    def forward(self, x):
        # TODO: Implement your network here
        x = self.conv1(x)
        x = F.max_pool2d(x, kernel_size=7, stride=7)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = F.log_softmax(x, dim=1)
        
        return x

## Test your model

Create an instance and test with a batch of 3 images.

In [35]:
conv_class = classifier()
print(conv_class)

classifier(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc): Linear(in_features=128, out_features=2, bias=True)
)


In [36]:
# Test with a batch of 3 images
batch_3images = train_set.data[0:3].float().unsqueeze(1)
output = conv_class(batch_3images)
print(f"Input shape: {batch_3images.shape}")
print(f"Output shape: {output.shape}")
print(f"Output:\n{output}")

Input shape: torch.Size([3, 1, 28, 28])
Output shape: torch.Size([3, 2])
Output:
tensor([[  0.0000, -71.7371],
        [  0.0000, -57.7277],
        [-27.3115,   0.0000]], grad_fn=<LogSoftmaxBackward0>)


## Implement the Training Loop

Complete the training function below.

**Hints:**
- Use `model.train(True)` to set the model to training mode
- For each batch:
  1. Zero the gradients: `optimizer.zero_grad()`
  2. Forward pass: `outputs = model(inputs)`
  3. Compute loss: `loss = loss_fn(outputs, labels)`
  4. Backward pass: `loss.backward()`
  5. Update weights: `optimizer.step()`
- Track running loss and accuracy

In [37]:
def train(model, data_loader, loss_fn, optimizer, n_epochs=1):
    model.train(True)
    loss_train = torch.zeros(n_epochs)
    acc_train = torch.zeros(n_epochs)

    for epoch_num in range(n_epochs):
        running_corrects = 0.0
        running_loss = 0.0
        size = 0

        for data in data_loader:
            inputs, labels = data
            bs = labels.size(0)

            # TODO: Implement the training step
            # 1. Zero gradients
            optimizer.zero_grad()

            # 2. Forward pass
            outputs = model(inputs)

            # 3. Compute loss
            loss = loss_fn(outputs, labels)

            # 4. Backward pass
            loss.backward()

            # 5. Optimizer step
            optimizer.step()

            # 6. Track statistics (running_loss and running_corrects)
            # Hint: use torch.max(outputs, 1) to get predictions
            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * bs
            running_corrects += torch.sum(preds == labels.data).item()

            size += bs

        epoch_loss = running_loss / size
        epoch_acc = running_corrects / size
        loss_train[epoch_num] = epoch_loss
        acc_train[epoch_num] = epoch_acc
        print(
            f"Epoch {epoch_num+1}/{n_epochs} - Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}"
        )

    return loss_train, acc_train

## Setup Loss and Optimizer

Choose appropriate loss function and optimizer:
- For multi-class classification with log_softmax output, what do we need?
- Start with SGD optimizer with learning rate 1e-3

In [38]:
conv_class = classifier()

# TODO: Choose the appropriate loss function
loss_fn = nn.NLLLoss()

# TODO: Create SGD optimizer with learning_rate=1e-3
learning_rate = 1e-3
optimizer_cl = torch.optim.SGD(conv_class.parameters(), lr=learning_rate)

Train for 10 epochs

In [39]:
l_t, a_t = train(conv_class, train_loader, loss_fn, optimizer_cl, n_epochs=10)

Epoch 1/10 - Loss: 0.6738 Acc: 0.5783
Epoch 2/10 - Loss: 0.6334 Acc: 0.8213
Epoch 3/10 - Loss: 0.5986 Acc: 0.8688
Epoch 4/10 - Loss: 0.5664 Acc: 0.8900
Epoch 5/10 - Loss: 0.5363 Acc: 0.8978
Epoch 6/10 - Loss: 0.5081 Acc: 0.9072
Epoch 7/10 - Loss: 0.4816 Acc: 0.9093
Epoch 8/10 - Loss: 0.4567 Acc: 0.9152
Epoch 9/10 - Loss: 0.4334 Acc: 0.9209
Epoch 10/10 - Loss: 0.4117 Acc: 0.9261


## Test Function

Evaluate the model on the test set.

**Hints:**
- Use `model.train(False)` or `model.eval()` to set evaluation mode
- Use `torch.no_grad()` to disable gradient computation
- Calculate test loss and accuracy

In [45]:
def test(model, data_loader, loss_fn):
    model.train(False)

    running_corrects = 0.0
    running_loss = 0.0
    size = 0

    with torch.no_grad():
        for data in data_loader:
            inputs, labels = data
            bs = labels.size(0)

            # TODO: Implement testing
            # 1. Forward pass
            outputs = model(inputs)

            # 2. Compute loss
            loss = loss_fn(outputs, labels)

            # 3. Track statistics (running_loss and running_corrects)
            # Hint: use torch.max(outputs, 1) to get predictions
            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * bs
            running_corrects += torch.sum(preds == labels.data).item()

            size += bs

    print(
        f"Test - Loss: {running_loss / size:.4f} Acc: {running_corrects / size:.4f}"
    )

In [46]:
test(conv_class, test_loader, loss_fn)

Test - Loss: 0.3981 Acc: 0.9305


Try [Adam](https://docs.pytorch.org/docs/stable/generated/torch.optim.Adam.html) optimizer instead of SGD.

Reset the model and try training with Adam optimizer instead of SGD.

In [None]:
# TODO: Create a new model and train with Adam optimizer



How many parameters did your network learn?

## Analyze learned filters

Let's visualize what the network learned!

In [None]:
# Count parameters
total_params = sum(p.numel() for p in conv_class.parameters())
print(f"Total parameters: {total_params}")

# Break down by layer
for name, param in conv_class.named_parameters():
    print(f"{name}: {param.numel()} parameters, shape {param.shape}")

In [None]:
# View learned filters
for m in conv_class.children():
    print("Weights:", m.weight.data)
    print("Bias:", m.bias.data)

In [None]:
# Extract and visualize the 8 learned filters
T_w = conv_class.conv1.weight.data
T_b = conv_class.conv1.bias.data

# Plot the 8 learned 3x3 filters
plots([T_w[i][0] for i in range(8)], titles=[f"Filter {i}" for i in range(8)])