In [12]:
import torch
import numpy as np
import os
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Question 2: Classification Using CNN

### Data Loading and Preprocessing

In [14]:
def load_mnist_data(path):
    splits = ['train', 'val', 'test']
    data = {'train': [], 'val': [], 'test': []}
    labels = {'train': [], 'val': [], 'test': []}
    for split in splits:
        split_path = os.path.join(path, split)
        for label in os.listdir(split_path):
            label_path = os.path.join(split_path, label)
            if (int(label) == 0):
                cur_label = 0
            else:
                cur_label = len(label)
            if os.path.isdir(label_path):
                for image_name in os.listdir(label_path):
                    image_path = os.path.join(label_path, image_name)
                    try:
                        image = Image.open(image_path).convert('L')
                        image_array = np.array(image)
                        data[split].append(image_array)
                        labels[split].append(cur_label)
                    except Exception as e:
                        print(f"Error loading image {image_name}: {e}")
    
    return data['train'], labels['train'], data['val'], labels['val'], data['test'], labels['test']



In [15]:
data_path = "./../../data/external/double_mnist"

train_data, train_labels, val_data, val_labels, test_data, test_labels = load_mnist_data(data_path)

In [16]:
class MultiMNISTDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        image = torch.tensor(image, dtype=torch.float32) / 255.0
        image = image.unsqueeze(0)

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(label, dtype=torch.long).to(device)
        return image, label


In [17]:
train_data = MultiMNISTDataset(train_data, train_labels)
val_data = MultiMNISTDataset(val_data, val_labels)
test_data = MultiMNISTDataset(test_data, test_labels)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

### Implement the CNN Class

In [18]:
class CNN(nn.Module):
    def __init__(self, task='classification', num_classes=10):
        super(CNN, self).__init__()
        
        self.task = task
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        
        self.fc1 = None
        self.fc2 = nn.Linear(128, 64)
        
        if task == 'classification':
            self.fc3 = nn.Linear(64, num_classes)
        elif task == 'regression':
            self.fc3 = nn.Linear(64, 1)

    def _initialize_fc(self, input_shape, device):
        dummy_input = torch.zeros(1, *input_shape).to(device)
        with torch.no_grad():
            output = self._forward_conv(dummy_input)
        flattened_size = output.view(-1).shape[0]
        self.fc1 = nn.Linear(flattened_size, 128).to(device)

    def _forward_conv(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        return x

    def forward(self, x):
        if self.fc1 is None:
            self._initialize_fc(x.shape[1:], x.device)

        x = self._forward_conv(x)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    
    def predict(self, x):
        self.eval()
        with torch.no_grad():
            y_pred = self.forward(x)
        return y_pred
    
    def get_accuracy(self, y_pred, y_true):
        if self.task == 'classification':
            y_pred = torch.argmax(y_pred, dim=1)
            return (y_pred == y_true).float().mean()
        elif self.task == 'regression':
            return F.mse_loss(y_pred, y_true)

    
    def loss(self, y_pred, y_true):
        if self.task == 'classification':
            return F.cross_entropy(y_pred, y_true)
        elif self.task == 'regression':
            return F.mse_loss(y_pred, y_true)

    


In [23]:
from tqdm import tqdm

def train(model, optimizer, train_loader, val_loader, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        train_progress = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}")
        
        for i, (x, y) in train_progress:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = model.loss(y_pred, y)
            loss.backward()
            optimizer.step()
            train_progress.set_postfix({"Loss": f"{loss.item():.4f}"})
        
        model.eval()
        total_accuracy = 0
        total_loss = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                y_pred = model(x)
                total_loss += model.loss(y_pred, y).item()
                total_accuracy += model.get_accuracy(y_pred, y).item()
                
            avg_loss = total_loss / len(val_loader)
            accuracy = total_accuracy / len(val_loader)
            print(f"Epoch {epoch + 1}, Validation Accuracy: {accuracy*100:.2f}%, Validation Loss: {avg_loss:.6f}")


In [24]:
model = CNN(task='classification', num_classes=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
train(model, optimizer, train_loader, val_loader, num_epochs=10)

Epoch 1/10: 100%|██████████| 394/394 [00:06<00:00, 60.74it/s, Loss=0.0490]


Epoch 1, Validation Accuracy: 99.10%, Validation Loss: 0.059389


Epoch 2/10: 100%|██████████| 394/394 [00:06<00:00, 60.39it/s, Loss=0.0190]


Epoch 2, Validation Accuracy: 99.34%, Validation Loss: 0.037673


Epoch 3/10: 100%|██████████| 394/394 [00:05<00:00, 65.72it/s, Loss=0.1635]


Epoch 3, Validation Accuracy: 99.63%, Validation Loss: 0.040604


Epoch 4/10: 100%|██████████| 394/394 [00:05<00:00, 68.60it/s, Loss=0.0057]


Epoch 4, Validation Accuracy: 99.50%, Validation Loss: 0.020367


Epoch 5/10: 100%|██████████| 394/394 [00:05<00:00, 69.68it/s, Loss=0.0043]


Epoch 5, Validation Accuracy: 99.63%, Validation Loss: 0.016056


Epoch 6/10: 100%|██████████| 394/394 [00:05<00:00, 70.36it/s, Loss=0.0398]


Epoch 6, Validation Accuracy: 99.77%, Validation Loss: 0.011964


Epoch 7/10: 100%|██████████| 394/394 [00:06<00:00, 64.82it/s, Loss=0.0131]


Epoch 7, Validation Accuracy: 99.73%, Validation Loss: 0.032205


Epoch 8/10: 100%|██████████| 394/394 [00:05<00:00, 70.79it/s, Loss=0.0074]


Epoch 8, Validation Accuracy: 99.77%, Validation Loss: 0.015004


Epoch 9/10: 100%|██████████| 394/394 [00:05<00:00, 70.70it/s, Loss=0.0100]


Epoch 9, Validation Accuracy: 99.53%, Validation Loss: 0.020547


Epoch 10/10: 100%|██████████| 394/394 [00:07<00:00, 51.32it/s, Loss=0.0177]


Epoch 10, Validation Accuracy: 99.57%, Validation Loss: 0.029821
