# Day 3: Convolutional Neural Networks (CNNs) for Medical Imaging

Welcome to **Day 3** of our elective: **AI in White Coat**. Today, we’ll dive into **Convolutional Neural Networks (CNNs)** using **PyTorch**, focusing on the **ChestMNIST** dataset from [MedMNIST](https://medmnist.com/). CNNs are the foundation of many **state-of-the-art** medical imaging applications, such as:

- Automated X-ray classification (normal vs. pneumonia).
- Detection of lesions or tumors on CT/MRI.
- More advanced tasks like segmentation and object detection.

By the end of this notebook, you’ll:
1. Understand the basics of CNN layers and why they’re well-suited for images.
2. Use a *prompt-first* approach to set up a simple CNN in PyTorch.
3. Train and evaluate the CNN on the **ChestMNIST** dataset.
4. Observe how CNNs can outperform simple models on image tasks.

---
## 1. Recap & Prerequisites
So far:
- **Day 1**: HPC setup, logging in, environment checks, prompt-first approach.
- **Day 2**: Basic ML classification with logistic regression on **BloodMNIST**.

Today, we’ll leverage **PyTorch** for a deeper neural network approach. Make sure you have PyTorch installed on the HPC. If not, ask an LLM or your mentor for the correct install commands.


## 2. PyTorch Installation Check

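If PyTorch or the other libraries used today (Hugging Face `datasets`, scikit-learn, matplotlib) aren’t already installed, we can install them with:
```
!pip install torch torchvision datasets scikit-learn matplotlib
```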
But let’s do a quick check first. If you’re missing anything, prompt your LLM for an install command.

```python
# ====== Environment Check for PyTorch ======
try:
    import torch
    print("PyTorch version:", torch.__version__)
    if torch.cuda.is_available():
        print("CUDA is available! GPU ready.")
    else:
        print("No GPU detected. Training might be slower on CPU.")
except ImportError:
    print("PyTorch not found. Please install via pip or conda.")
```

## 3. About the ChestMNIST Dataset

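**ChestMNIST** is part of **MedMNIST** and is derived from the NIH ChestX-ray14 dataset. Each image is 28x28 pixels, grayscale (similar to classic MNIST style, but medical!), and carries **14 binary labels**, one per thoracic finding: atelectasis, cardiomegaly, effusion, infiltration, mass, nodule, pneumonia, pneumothorax, consolidation, edema, emphysema, fibrosis, pleural thickening, and hernia. An image can show several findings at once, or none.

- Dataset link: [Hugging Face: MedMNIST/ChestMNIST](https://huggingface.co/datasets/MedMNIST/chestmnist)
- Splits: train, validation, test (MedMNIST provides official splits).
- Task: **multi-label** classification. Instead of picking exactly one class per image (multi-class), the model makes 14 independent yes/no predictions, so we’ll use one output per finding and a sigmoid-based loss.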
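To make the multi-label setup concrete, here’s a small sketch with made-up logits for a single image. It shows that a multi-label target is a 14-element multi-hot vector, that `nn.BCEWithLogitsLoss` scores each finding independently, and that predictions come from thresholding each finding’s sigmoid probability at 0.5 (a common default, not a tuned value). The target and logits are illustrative, not taken from the dataset.

```python
import torch
import torch.nn as nn

NUM_LABELS = 14  # ChestMNIST: one binary label per finding

# Hypothetical target: this image is positive for findings 1 and 2, negative for the rest.
target = torch.zeros(1, NUM_LABELS)
target[0, [1, 2]] = 1.0

# Made-up raw model outputs (logits): confident "present" for finding 1, "absent" elsewhere.
logits = torch.full((1, NUM_LABELS), -3.0)
logits[0, 1] = 2.0

# BCEWithLogitsLoss applies a sigmoid per finding and averages the binary cross-entropies,
# so missing finding 2 is penalized independently of getting finding 1 right.
criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, target)

# Threshold each finding's probability independently.
probs = torch.sigmoid(logits)
preds = (probs > 0.5).float()

print("loss:", round(loss.item(), 4))
print("predicted findings:", preds.nonzero(as_tuple=True)[1].tolist())  # [1]
```

Contrast this with `nn.CrossEntropyLoss`, which applies a softmax across all outputs and therefore assumes exactly one correct class per image.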

### Prompt Example:
```
I'm working in a Jupyter notebook with PyTorch.
Please generate code to load ChestMNIST from 'MedMNIST/chestmnist'
using the 'datasets' library, then show me how to explore it.
```

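```python
# ====== Load ChestMNIST ======
from datasets import load_dataset

chestmnist = load_dataset("MedMNIST/chestmnist")
print(chestmnist)  # available splits and their sizes

# Check the column types: is 'image' a PIL image or an array? Is 'label' a list of 14 ints?
print(chestmnist['train'].features)

# Let's peek at one sample from the train split.
sample = chestmnist['train'][0]
sample
```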

## 4. Data Preprocessing & Dataloaders

In PyTorch, we typically create **Datasets** and **Dataloaders** to handle batching, shuffling, etc. For image data:
1. Convert each image to a **torch.Tensor**.
2. Normalize or scale pixel values if needed (e.g., from `[0, 255]` to `[0, 1]`).
3. Use `torch.utils.data.DataLoader` to batch and shuffle.
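The three steps above can be sketched on synthetic data, which lets you check shapes and value ranges before touching the real dataset. This assumes images arrive as 28x28 `uint8` arrays (a PIL image converts the same way via `np.asarray`); the random images and labels are stand-ins, not ChestMNIST samples.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

rng = np.random.default_rng(0)

# Stand-in data: 100 grayscale 28x28 images with pixel values in [0, 255],
# plus 14-element multi-hot label vectors.
raw_images = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
raw_labels = rng.integers(0, 2, size=(100, 14))

# Steps 1-2: convert to float tensors, scale to [0, 1], add a channel dim -> [N, 1, 28, 28].
images = torch.from_numpy(raw_images).float().div(255.0).unsqueeze(1)
labels = torch.from_numpy(raw_labels).float()

# Step 3: batch and shuffle with a DataLoader.
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

batch_images, batch_labels = next(iter(loader))
print(batch_images.shape, batch_labels.shape)  # [32, 1, 28, 28] and [32, 14]
print(batch_images.min().item(), batch_images.max().item())  # both within [0, 1]
```

With 100 images and a batch size of 32, the loader yields four batches; the last one holds the remaining 4 images.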

### Prompt Example:
```
Please generate PyTorch code to:
1. Convert the available splits (train, validation, test) from chestmnist into torch Datasets, scaling pixels to [0, 1].
2. Create DataLoaders with a batch size of 32.
3. Return (image_tensor, label_tensor) pairs, where the label is a 14-element float multi-hot vector.
```

```python
# ====== LLM-GENERATED CODE CELL: Create PyTorch DataLoaders ======
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class ChestMnistDataset(Dataset):
    """Wraps a Hugging Face split and returns (image, label) tensors."""
    def __init__(self, split_data):
        self.data = split_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # image: 28x28 grayscale; np.asarray handles both PIL images and nested lists.
        # Dividing by 255 scales pixel values from [0, 255] to [0, 1].
        image = torch.tensor(np.asarray(item['image']), dtype=torch.float32) / 255.0
        # Expand dims to [1, 28, 28] for a single-channel image
        image = image.unsqueeze(0)
        # ChestMNIST is multi-label: the label is a 14-element 0/1 vector.
        # BCEWithLogitsLoss expects float targets.
        label = torch.tensor(item['label'], dtype=torch.float32)
        return image, label

# Prepare train, validation, and test sets
train_ds = ChestMnistDataset(chestmnist['train'])
# The validation split may be named 'validation' or 'val' depending on the dataset version.
val_ds = None
for split_name in ('validation', 'val'):
    if split_name in chestmnist:
        val_ds = ChestMnistDataset(chestmnist[split_name])
        break
test_ds = ChestMnistDataset(chestmnist['test'])

# Create DataLoaders
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = None
if val_ds is not None:
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

print("DataLoaders ready!")
```

## 5. Defining a Simple CNN in PyTorch

A typical CNN has:
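- **Convolution layers**: Slide small learned filters (e.g., 3×3) across the image to detect local patterns such as edges and textures. Because the same filter is reused at every position (weight sharing), they need far fewer parameters than a fully connected layer on raw pixels.
- **Pooling**: Downsample the feature maps, keeping the strongest responses and making the features less sensitive to small shifts.
- **Fully connected**: Map the final features to the outputs.

We’ll build a small architecture with two conv layers. **ChestMNIST** images are 28×28 grayscale, so input channels = 1.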
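To see where the `32 * 7 * 7` in the model comes from, here’s a small helper applying the standard output-size formula for convolution and pooling layers, `(W + 2P - K) // S + 1` for input width `W`, kernel size `K`, padding `P`, and stride `S` (assuming dilation 1 and PyTorch’s default floor rounding). The layer settings mirror the two conv + pool stages described above; `conv_output_size` is a hypothetical helper, not a PyTorch function.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d or MaxPool2d layer (dilation = 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28
size = conv_output_size(size, kernel=3, padding=1)  # conv1: 28 -> 28 (padding preserves size)
size = conv_output_size(size, kernel=2, stride=2)   # pool:  28 -> 14
size = conv_output_size(size, kernel=3, padding=1)  # conv2: 14 -> 14
size = conv_output_size(size, kernel=2, stride=2)   # pool:  14 -> 7

out_channels = 32
print(size, out_channels * size * size)  # 7 1568
```

If you change the kernel size, padding, or number of pooling stages, rerunning this arithmetic tells you the new input size for the final linear layer.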

### Prompt Example:
```
Please generate a simple PyTorch CNN model for single-channel 28x28 images,
with two convolutional layers, ReLU, max pooling, and a final linear layer
with 14 outputs (one logit per label, for multi-label classification).
```

```python
# ====== LLM-GENERATED CODE CELL: Define CNN ======
import torch.nn as nn

NUM_LABELS = 14  # ChestMNIST: one binary label per finding

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=NUM_LABELS):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        # After two pooling operations, 28x28 -> 14x14 -> 7x7
        # With 32 channels, that becomes 32 * 7 * 7 = 1568
        # Output: one raw logit per label (the sigmoid is applied inside the loss).
        self.fc = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        # x shape: [batch_size, 1, 28, 28]
        x = self.relu(self.conv1(x))  # -> [batch_size, 16, 28, 28]
        x = self.pool(x)              # -> [batch_size, 16, 14, 14]
        x = self.relu(self.conv2(x))  # -> [batch_size, 32, 14, 14]
        x = self.pool(x)              # -> [batch_size, 32, 7, 7]
        x = x.view(x.size(0), -1)     # flatten -> [batch_size, 1568]
        x = self.fc(x)                # -> [batch_size, num_classes]
        return x

model = SimpleCNN(num_classes=NUM_LABELS)
print(model)
```

ChestMNIST has **14 labels** (one per finding), so the final layer outputs 14 logits. If you reuse this model for another MedMNIST dataset, set `num_classes` to match its label count.

## 6. Training Loop

We’ll write a standard PyTorch training loop. High-level steps:
1. Move data to GPU if available.
2. Forward pass → compute loss → backprop → update.
3. Track accuracy over the epoch.

### Prompt Example
```
Please generate a PyTorch training loop for the SimpleCNN model on multi-label data,
using BCEWithLogitsLoss and the Adam optimizer. One or two epochs is enough for a demo.
```

```python
# ====== LLM-GENERATED CODE CELL: Training ======
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Multi-label: an independent sigmoid + binary cross-entropy for each label.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 2  # We'll just do 2 for demonstration.

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_images = 0
    correct = 0
    total = 0

    for images, labels in train_loader:
        # BCEWithLogitsLoss requires float targets.
        images, labels = images.to(device), labels.to(device).float()
        optimizer.zero_grad()
        outputs = model(images)  # [batch_size, 14] logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        num_images += images.size(0)
        # Predict "present" when a label's probability exceeds 0.5.
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct += (predicted == labels).sum().item()
        total += labels.numel()  # counts every (image, label) pair

    epoch_loss = running_loss / num_images
    epoch_acc = correct / total  # per-label accuracy

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Per-label Acc: {epoch_acc:.4f}")
```

## 7. Validation & Testing

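**ChestMNIST** comes with an official validation split. Evaluating on it after each epoch lets you monitor progress and tune settings such as the number of epochs, while the **test set** is best kept for a single final evaluation. We’ll show a basic test loop below.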
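A reusable evaluation helper turns per-epoch validation into a one-line call. This is a sketch that assumes the loader yields `(images, multi_hot_labels)` batches and that the model outputs one logit per label, as in this notebook; `evaluate` is a hypothetical helper name, and the 0.5 threshold is a default rather than a tuned value.

```python
import torch
import torch.nn as nn

def evaluate(model, loader, device, threshold=0.5):
    """Return (mean BCE loss per image, per-label accuracy) over a DataLoader."""
    criterion = nn.BCEWithLogitsLoss()
    model.eval()  # disable training-only behavior such as dropout
    total_loss, num_images, correct, total = 0.0, 0, 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device).float()
            outputs = model(images)
            total_loss += criterion(outputs, labels).item() * images.size(0)
            num_images += images.size(0)
            predicted = (torch.sigmoid(outputs) > threshold).float()
            correct += (predicted == labels).sum().item()
            total += labels.numel()
    return total_loss / num_images, correct / total
```

At the end of each training epoch you could call `val_loss, val_acc = evaluate(model, val_loader, device)` whenever `val_loader` is not `None`. The training loop already calls `model.train()` at the start of every epoch, so the switch back from evaluation mode is handled.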

### Prompt Example
```
Please generate code to evaluate the trained multi-label CNN on the test_loader.
Threshold the sigmoid outputs at 0.5, compute per-label accuracy, and print the result.
```

```python
# ====== LLM-GENERATED CODE CELL: Testing ======
model.eval()
test_correct = 0
test_total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device).float()
        outputs = model(images)
        # Independent yes/no decision per label.
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        test_correct += (predicted == labels).sum().item()
        test_total += labels.numel()

test_acc = test_correct / test_total
print(f"Test per-label accuracy: {test_acc:.4f}")
```

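With only two epochs, treat the result as a starting point. Interpret per-label accuracy with care on ChestMNIST: most findings are absent from most images, so a model that always predicts “absent” already scores highly. That is one reason MedMNIST benchmarks report AUC alongside accuracy. Also, Day 2 used a different dataset (BloodMNIST), so a fair CNN vs. logistic-regression comparison means training both on the same data. More epochs or a more complex architecture can push performance higher.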
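Here is a sketch of two such checks: the per-label accuracy of a trivial all-negative baseline, and macro-averaged ROC AUC computed with scikit-learn’s `roc_auc_score`. It assumes you have collected `y_true` (multi-hot labels) and `y_score` (sigmoid probabilities) as NumPy arrays of shape `[num_images, 14]`; the tiny arrays below are made-up stand-ins, and both helper names are hypothetical.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def all_negative_accuracy(y_true: np.ndarray) -> float:
    """Per-label accuracy of a baseline that predicts 'absent' for every label."""
    return float((y_true == 0).mean())

def macro_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """ROC AUC per label, then averaged; each label needs both positives and negatives."""
    return float(roc_auc_score(y_true, y_score, average="macro"))

# Made-up example: 4 images, 2 labels.
y_true = np.array([[0, 1], [0, 0], [1, 0], [0, 0]])
y_score = np.array([[0.2, 0.9], [0.1, 0.3], [0.7, 0.4], [0.3, 0.2]])

print(all_negative_accuracy(y_true))  # 0.75
print(macro_auc(y_true, y_score))     # 1.0
```

On the real test set, you could fill these arrays by appending `labels.cpu().numpy()` and `torch.sigmoid(outputs).cpu().numpy()` inside the test loop, then stacking the batches with `np.concatenate`. Comparing your model’s per-label accuracy to the all-negative baseline shows whether it has learned anything beyond “mostly absent”.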

## 8. Confusion Matrix (Optional)

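Similar to Day 2, you might want a **confusion matrix** to see how the model performs. Because ChestMNIST is multi-label, there is no single matrix across classes; instead, each label gets its own 2×2 matrix counting true/false positives and negatives.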
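To read those per-label matrices correctly, here’s a toy example with scikit-learn’s `multilabel_confusion_matrix`. For each label it returns a 2×2 array laid out as `[[TN, FP], [FN, TP]]`; the ground truth and predictions are made up for illustration.

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

# Made-up ground truth and predictions: 4 images, 2 labels.
y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
y_pred = np.array([[1, 0], [1, 1], [0, 1], [0, 0]])

mcm = multilabel_confusion_matrix(y_true, y_pred)  # shape: (num_labels, 2, 2)
for label_idx, ((tn, fp), (fn, tp)) in enumerate(mcm):
    print(f"label {label_idx}: TN={tn} FP={fp} FN={fn} TP={tp}")
# label 0: TN=1 FP=1 FN=1 TP=1
# label 1: TN=2 FP=0 FN=0 TP=2
```

In a clinical framing, FN (a missed finding) and FP (a false alarm) usually carry different costs, which a single accuracy number hides.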

### Prompt Example
```
Generate code to compute per-label confusion matrices for the multi-label test set
(using sklearn's multilabel_confusion_matrix), store predictions and labels,
then visualize one label's matrix with matplotlib.
```

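```python
# ====== OPTIONAL: Per-label Confusion Matrices ======
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import multilabel_confusion_matrix

all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        predicted = (torch.sigmoid(outputs) > 0.5).int()
        all_preds.append(predicted.cpu().numpy())
        all_labels.append(labels.int().numpy())  # labels stay on the CPU here

all_preds = np.concatenate(all_preds)    # [num_test_images, 14]
all_labels = np.concatenate(all_labels)  # [num_test_images, 14]

# One 2x2 matrix per label, laid out as [[TN, FP], [FN, TP]].
mcm = multilabel_confusion_matrix(all_labels, all_preds)
for label_idx, cm in enumerate(mcm):
    print(f"Label {label_idx}:\n{cm}")

# Visualize one label's matrix (change label_idx to inspect another finding).
label_idx = 0
plt.matshow(mcm[label_idx], cmap=plt.cm.Blues)
plt.title(f"Confusion Matrix: label {label_idx}")
plt.colorbar()
plt.xticks([0, 1], ["Pred 0", "Pred 1"])
plt.yticks([0, 1], ["True 0", "True 1"])
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()
```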

## 9. Observations & Clinical Relevance

- **CNN Performance**: Even a small CNN often outperforms classical models (like logistic regression on flattened pixels) on image tasks, because convolutions exploit the spatial structure that flattening throws away.
- **Medical Tie-In**: Automated chest X-ray analysis has huge potential, though real-world models typically use larger images, more advanced architectures, and extensive data.
- **Next Steps**: If time permits, try advanced techniques (e.g., more layers, data augmentation, or transfer learning from a bigger pretrained model like ResNet).


## 10. Assignment #3: Extend & Experiment

**Task**:
1. Increase the number of epochs (e.g., 5 or 10) and observe if accuracy improves.
2. Explore **data augmentations** (random flips or rotations) to see if the model generalizes better.
3. Create a simple function to **visualize** a few **model predictions** vs. **true labels** on the test set.

### Bonus
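- Try **transfer learning** with a pretrained CNN (e.g., ResNet-18) by resizing images to 224x224 and repeating the grayscale channel three times (ImageNet-pretrained models expect RGB input), then replacing the final layer with a 14-output layer. This may require more HPC resources.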
- Reflect in your daily log: How could an automated chest X-ray classifier fit into a clinical workflow? What are the limitations and ethical considerations?


# End of Day 3 Notebook

Today, you learned how to:
- Load a **medical** imaging dataset (ChestMNIST) suitable for CNNs.
- Define and train a **Convolutional Neural Network** in **PyTorch**.
- Evaluate its performance and interpret results (per-label accuracy, per-label confusion matrices).

**Keep your daily log** updated with your progress, challenges, and reflections.

Happy CNN-ing! We’ll continue exploring advanced topics in the upcoming sessions.
