# CNN Classifier Training (Supervised Baseline)

Trains a lightweight CNN for supervised defect classification on the **NEU Surface Defect** dataset.

**Key Info:**
- 6 defect classes: crazing, inclusion, patches, pitted_surface, rolled-in_scale, scratches
- Supervised classification (not anomaly detection)
- Used as a baseline for comparison with unsupervised methods

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.insert(0, 'F:/Thesis')

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

from src.config import DEVICE, MODELS_DIR, FIGURES_DIR, NEU_CATEGORIES, ensure_dirs
from src.data import NEUDataset
from src.models import create_cnn_classifier
from src.training import get_optimizer, get_scheduler

# ensure_dirs is project-defined (src.config); presumably it creates the output
# directories such as MODELS_DIR / FIGURES_DIR — confirm against src.config.
ensure_dirs()
# Echo the runtime setup so the notebook output records the device and label set.
print(f"Device: {DEVICE}")
print(f"Classes: {NEU_CATEGORIES}")

## Configuration & Data Loading

In [None]:
# Hyperparameters for the supervised CNN baseline. This dict is also stored in
# the checkpoint written at the end of the notebook.
CONFIG = dict(
    batch_size=16,
    num_epochs=50,
    learning_rate=1e-3,
    weight_decay=1e-5,
    num_classes=6,
)

# Train / validation splits of the NEU dataset.
train_dataset = NEUDataset(split='train')
val_dataset = NEUDataset(split='validation')

# Both loaders use the same batch size and load in the main process
# (num_workers=0); only the training loader shuffles.
_loader_kwargs = {'batch_size': CONFIG['batch_size'], 'num_workers': 0}
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

## Create Model

In [None]:
# Loss for multi-class classification over the defect categories.
criterion = nn.CrossEntropyLoss()

# Build the classifier and move it to the configured device.
model = create_cnn_classifier(num_classes=CONFIG['num_classes'])
model = model.to(DEVICE)

# Optimizer and scheduler come from the project's training helpers; the
# scheduler is later stepped with the validation loss.
optimizer = get_optimizer(
    model,
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)
scheduler = get_scheduler(optimizer, patience=3, factor=0.5)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## Training Loop

In [None]:
# Train for CONFIG['num_epochs'] epochs, validating after every epoch.
#
# history   : per-epoch train loss, validation loss and validation accuracy.
# best_acc  : highest validation accuracy seen so far.
# best_state: copy of the model weights from the epoch that achieved best_acc.
#             It is restored after the loop so that the model left in memory
#             (and anything evaluated or saved afterwards) matches best_acc
#             instead of whatever the last epoch produced.
history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
best_acc = 0.0
best_state = None

for epoch in tqdm(range(1, CONFIG['num_epochs'] + 1), desc='Training'):
    # Train
    model.train()
    train_loss, train_seen = 0.0, 0
    for img, label in train_loader:
        img, label = img.to(DEVICE), label.to(DEVICE)
        optimizer.zero_grad()
        logits = model(img)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        # The criterion yields a batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        batch_n = label.size(0)
        train_loss += loss.item() * batch_n
        train_seen += batch_n

    # Validate
    model.eval()
    val_loss, correct, val_seen = 0.0, 0, 0
    with torch.no_grad():
        for img, label in val_loader:
            img, label = img.to(DEVICE), label.to(DEVICE)
            logits = model(img)
            batch_n = label.size(0)
            val_loss += criterion(logits, label).item() * batch_n
            correct += (logits.argmax(1) == label).sum().item()
            val_seen += batch_n

    # Per-sample averages over the samples actually seen this epoch.
    avg_train = train_loss / train_seen
    avg_val = val_loss / val_seen
    acc = correct / val_seen

    history['train_loss'].append(avg_train)
    history['val_loss'].append(avg_val)
    history['val_acc'].append(acc)

    # The scheduler is driven by the validation loss.
    scheduler.step(avg_val)

    if acc > best_acc:
        best_acc = acc
        # Clone the tensors: state_dict() returns references that later
        # optimizer steps would otherwise overwrite in place.
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Train Loss={avg_train:.4f}, Val Loss={avg_val:.4f}, Val Acc={acc:.4f}")

# Put the best-performing weights back into the model; best_state stays None
# only if no epoch ever exceeded 0.0 accuracy.
if best_state is not None:
    model.load_state_dict(best_state)

print(f"\nBest validation accuracy: {best_acc:.4f}")

## Training Curves

In [None]:
# Two side-by-side panels: loss curves on the left, validation accuracy on the right.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, acc_ax = axes

# 1-based epoch numbers for the x axis.
epochs = range(1, len(history['train_loss']) + 1)

# Left panel: training vs. validation loss.
loss_ax.plot(epochs, history['train_loss'], 'b-', label='Train')
loss_ax.plot(epochs, history['val_loss'], 'r-', label='Validation')
loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Loss Curves')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

# Right panel: validation accuracy per epoch.
acc_ax.plot(epochs, history['val_acc'], 'g-', linewidth=2)
acc_ax.set(xlabel='Epoch', ylabel='Accuracy', title='Validation Accuracy')
acc_ax.grid(True, alpha=0.3)

fig.suptitle('CNN Classifier Training - NEU Surface Defect', fontsize=14)
fig.tight_layout()
fig.savefig(FIGURES_DIR / 'cnn_training_curves.png', dpi=150)
plt.show()

## Confusion Matrix & Classification Report

In [None]:
# Collect predictions and ground-truth labels over the whole validation set.
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for img, label in val_loader:
        logits = model(img.to(DEVICE))
        preds = logits.argmax(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(label.numpy())

# Pass the class indices explicitly so the matrix and the report always have
# one row per category, even when a class is absent from both the validation
# labels and the predictions. Otherwise the matrix shrinks and no longer lines
# up with the tick labels, and classification_report rejects target_names.
# NOTE(review): assumes label index i corresponds to NEU_CATEGORIES[i], as the
# tick labels below already do.
class_ids = list(range(len(NEU_CATEGORIES)))

# Confusion Matrix
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=NEU_CATEGORIES, yticklabels=NEU_CATEGORIES, cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix - CNN Classifier')
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'cnn_confusion_matrix.png', dpi=150)
plt.show()

# Classification Report
print("\n=== Classification Report ===")
print(classification_report(all_labels, all_preds, labels=class_ids, target_names=NEU_CATEGORIES))

## Save Model

In [None]:
# Destination of the final checkpoint inside the models directory.
save_path = MODELS_DIR / 'cnn_classifier_final.pth'

# Bundle the current model weights with the run configuration, the per-epoch
# history and the best validation accuracy recorded during training.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'config': CONFIG,
    'history': history,
    'accuracy': best_acc,
}
torch.save(checkpoint, save_path)

print(f"Model saved to: {save_path}")
print(f"\n=== Training Summary ===")
print(f"Epochs: {CONFIG['num_epochs']}")
print(f"Best Accuracy: {best_acc:.4f}")