# CNN Classifier Training (Supervised Baseline)

Trains a lightweight CNN for supervised defect classification on the NEU Surface Defect dataset.

In [None]:
import sys
# Put the project directory first on the module search path; presumably this is
# what makes the `src.*` imports below resolvable.
# NOTE(review): hard-coded absolute Windows path — adjust when running elsewhere.
sys.path.insert(0, 'F:/Thesis')

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

from src.config import DEVICE, MODELS_DIR, FIGURES_DIR, NEU_CATEGORIES, ensure_dirs
from src.data import NEUDataset
from src.models import create_cnn_classifier
from src.training import get_optimizer, get_scheduler

# Run the project's directory setup helper before anything is written to disk.
# NOTE(review): presumably creates MODELS_DIR / FIGURES_DIR if missing — confirm in src.config.
ensure_dirs()

In [None]:
# Hyperparameters for the supervised baseline run.
CONFIG = {
    'batch_size': 16,
    'num_epochs': 50,
    'learning_rate': 1e-3,
    'num_classes': 6,
}

# One dataset object per split used in this notebook.
train_dataset = NEUDataset(split='train')
val_dataset = NEUDataset(split='validation')

# Only the training loader shuffles; the validation loader keeps the default order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
)

# Build the classifier and move it to the configured device, then set up
# the loss function and the optimizer that updates the model's parameters.
model = create_cnn_classifier(num_classes=CONFIG['num_classes'])
model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = get_optimizer(model, lr=CONFIG['learning_rate'])

print(f'Train: {len(train_dataset)}, Val: {len(val_dataset)}')
print(f'Classes: {NEU_CATEGORIES}')

In [None]:
# Per-epoch metrics; each list gains exactly one entry per epoch.
#   train_loss / val_loss: mean loss per sample over the epoch
#   val_acc: fraction of validation samples classified correctly
history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

for epoch in tqdm(range(1, CONFIG['num_epochs'] + 1)):
    # Train
    model.train()
    train_loss, train_seen = 0.0, 0
    for img, label in train_loader:
        img, label = img.to(DEVICE), label.to(DEVICE)
        optimizer.zero_grad()
        logits = model(img)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        # The criterion yields a per-batch mean, so weight it by the batch
        # size; otherwise a short final batch counts as much as a full one.
        batch_n = label.size(0)
        train_loss += loss.item() * batch_n
        train_seen += batch_n

    # Validate
    model.eval()
    val_loss, val_seen, correct = 0.0, 0, 0
    with torch.no_grad():
        for img, label in val_loader:
            img, label = img.to(DEVICE), label.to(DEVICE)
            logits = model(img)
            batch_n = label.size(0)
            val_loss += criterion(logits, label).item() * batch_n
            val_seen += batch_n
            # argmax over dim 1 picks the highest-scoring class per sample.
            correct += (logits.argmax(1) == label).sum().item()

    # Divide by the number of samples actually processed so all three
    # metrics are true per-sample averages.
    history['train_loss'].append(train_loss / train_seen)
    history['val_loss'].append(val_loss / val_seen)
    history['val_acc'].append(correct / val_seen)

print(f'Final val accuracy: {history["val_acc"][-1]:.4f}')

In [None]:
# Confusion matrix and per-class report on the validation set.
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for img, label in val_loader:
        # NOTE(review): model.predict presumably returns class indices — confirm in src.models.
        preds = model.predict(img.to(DEVICE))
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(label.numpy())

# Pin the label set to one index per category name. Without this, a class
# that never appears in the labels or predictions is dropped from the matrix,
# which misaligns the tick labels and makes classification_report reject
# target_names because the lengths no longer match.
# Assumes class index i corresponds to NEU_CATEGORIES[i], as the tick labels already do.
class_ids = list(range(len(NEU_CATEGORIES)))

cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=NEU_CATEGORIES, yticklabels=NEU_CATEGORIES, cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
# Save before show(): the original order is kept so the file is written from the live figure.
plt.savefig(FIGURES_DIR / 'cnn_confusion_matrix.png', dpi=150)
plt.show()

print(classification_report(all_labels, all_preds, labels=class_ids, target_names=NEU_CATEGORIES))

In [None]:
# Bundle the trained weights with the run configuration and the
# last-epoch validation accuracy, then write them as one checkpoint file.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'config': CONFIG,
    'accuracy': history['val_acc'][-1],
}
torch.save(checkpoint, MODELS_DIR / 'cnn_classifier_final.pth')
print('Model saved!')