# Wire Color Detection - MobileNetV2 Training
Fine-tune an ImageNet-pretrained MobileNetV2 (a small CNN) to classify wire colors: green, red, white, and yellow.

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt

## Configuration

In [None]:
# Paths and training hyper-parameters
DATA_DIR = '<data_dir>'
OUTPUT_DIR = './checkpoints'
BATCH_SIZE = 32
EPOCHS = 50
LEARNING_RATE = 0.001
IMG_SIZE = 224

# Class names are the sub-folder names inside each split directory
# (listed in sorted order, which is how ImageFolder indexes them).
CLASSES = ['green', 'red', 'white', 'yellow']
NUM_CLASSES = len(CLASSES)

# Prefer Apple MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    device, backend_msg = torch.device('mps'), 'Using MPS (Apple Silicon)'
elif torch.cuda.is_available():
    device, backend_msg = torch.device('cuda'), 'Using CUDA'
else:
    device, backend_msg = torch.device('cpu'), 'Using CPU'
print(backend_msg)

## Data Transforms

In [None]:
# ImageNet channel statistics expected by the pretrained MobileNetV2 backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training augmentation. Hue jitter is deliberately left out: the label *is*
# the hue, and hue=0.1 rotates it by up to +/-36 degrees, enough to push a
# yellow wire toward orange or green and silently corrupt the labels.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE + 32, IMG_SIZE + 32)),
    transforms.RandomCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# Evaluation / inference: same resize as training, then a deterministic
# centre crop, so objects appear at the same scale the network trained on.
eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE + 32, IMG_SIZE + 32)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

## Load Dataset

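Expected folder structure (the `valid/` and `test/` splits are optional, but every split that is present must contain the same class folders):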
```
DATA_DIR/
  train/
    green/
    red/
    white/
    yellow/
  valid/
    green/
    red/
    white/
    yellow/
  test/
    green/
    red/
    white/
    yellow/
```

In [None]:
train_dir = Path(DATA_DIR) / 'train'
valid_dir = Path(DATA_DIR) / 'valid'
test_dir = Path(DATA_DIR) / 'test'

# load datasets; the valid/test splits are optional
train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
valid_dataset = datasets.ImageFolder(valid_dir, transform=eval_transform) if valid_dir.exists() else None
test_dataset = datasets.ImageFolder(test_dir, transform=eval_transform) if test_dir.exists() else None

# ImageFolder derives label indices from the sorted sub-folder names of each
# split independently. A split with a missing or extra class folder would get
# shifted indices and silently disagree with training labels, so fail loudly.
if train_dataset.classes != CLASSES:
    raise ValueError(
        f'Training class folders {train_dataset.classes} do not match CLASSES {CLASSES}')
for split_name, split in (('valid', valid_dataset), ('test', test_dataset)):
    if split is not None and split.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f'{split_name} class folders {split.classes} do not match '
            f'training class folders {train_dataset.classes}')

# create data loaders (num_workers=0 works best with MPS)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0) if valid_dataset is not None else None
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0) if test_dataset is not None else None

print(f'Training samples: {len(train_dataset)}')
print(f'Classes: {train_dataset.classes}')
if valid_dataset is not None:
    print(f'Validation samples: {len(valid_dataset)}')
if test_dataset is not None:
    print(f'Test samples: {len(test_dataset)}')

## Create Model

In [None]:
def create_model(num_classes=NUM_CLASSES, pretrained=True):
    """Build a MobileNetV2 whose classifier head outputs ``num_classes`` logits.

    Args:
        num_classes: size of the new output layer.
        pretrained: when True, start from the torchvision ``IMAGENET1K_V1``
            weights; otherwise initialise randomly.

    Returns:
        The MobileNetV2 module with its ``classifier`` replaced by
        Dropout(0.2) -> Linear(in_features, num_classes).
    """
    backbone_weights = 'IMAGENET1K_V1' if pretrained else None
    net = models.mobilenet_v2(weights=backbone_weights)

    # Reuse the feature width of the stock head's Linear layer (index 1).
    head_inputs = net.classifier[1].in_features
    net.classifier = nn.Sequential(nn.Dropout(0.2), nn.Linear(head_inputs, num_classes))
    return net

# Size the head from the discovered class folders and move it to the device.
model = create_model(num_classes=len(train_dataset.classes)).to(device)
total_params = sum(param.numel() for param in model.parameters())
print(f'Model created with {total_params} parameters')

## Training Functions

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device=None):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        dataloader: iterable of ``(images, labels)`` batches.
        criterion: loss function; assumed to return the batch *mean*
            (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved; defaults to the device of the
            model's first parameter.

    Returns:
        ``(mean_loss, accuracy_percent)``, with the loss averaged per sample.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        # Send batches to wherever the model's weights already live.
        device = next(model.parameters()).device
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion returns a batch mean; weight it by batch size so a short
        # final batch does not skew the epoch average.
        running_loss += loss.item() * batch_size
        correct += outputs.argmax(1).eq(labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError('dataloader yielded no samples')
    return running_loss / total, 100. * correct / total


def validate(model, dataloader, criterion, device=None):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        dataloader: iterable of ``(images, labels)`` batches.
        criterion: loss function; assumed to return the batch *mean*
            (the default reduction of ``nn.CrossEntropyLoss``).
        device: where each batch is moved; defaults to the device of the
            model's first parameter.

    Returns:
        ``(mean_loss, accuracy_percent)``, with the loss averaged per sample.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        # Send batches to wherever the model's weights already live.
        device = next(model.parameters()).device
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so a short final batch
            # does not skew the average.
            running_loss += loss.item() * batch_size
            correct += outputs.argmax(1).eq(labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError('dataloader yielded no samples')
    return running_loss / total, 100. * correct / total

## Training Loop

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
checkpoint_path = f'{OUTPUT_DIR}/best_model.pt'

# Start below any reachable accuracy so the first validated epoch is always
# saved, even if its accuracy is 0%.
best_acc = float('-inf')
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer)
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)

    log = f'Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%'

    if valid_loader is not None:
        val_loss, val_acc = validate(model, valid_loader, criterion)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        log += f' | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%'

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'accuracy': val_acc,
                'classes': train_dataset.classes
            }, checkpoint_path)
            log += ' [BEST]'

    print(log)
    scheduler.step()

if valid_loader is not None:
    print(f'\nTraining complete! Best validation accuracy: {best_acc:.2f}%')
else:
    # Without a validation split there is no "best" epoch. Save the final
    # weights so the checkpoint loaded by the inference step always exists.
    torch.save({
        'epoch': EPOCHS - 1,
        'model_state_dict': model.state_dict(),
        'accuracy': None,
        'classes': train_dataset.classes
    }, checkpoint_path)
    print('\nTraining complete! No validation set found; saved final-epoch weights.')

# Evaluate the saved checkpoint on the test set, i.e. the same weights used for
# inference, rather than whatever the last epoch happened to produce.
if test_loader is not None:
    saved = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(saved['model_state_dict'])
    test_loss, test_acc = validate(model, test_loader, criterion)
    print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%')

## Plot Training History

In [None]:
# Plot loss and accuracy curves. The x-axis is 1-based to match the
# "Epoch N/EPOCHS" numbering printed during training.
epoch_axis = range(1, len(history['train_loss']) + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

for ax, metric, title in ((ax1, 'loss', 'Loss'), (ax2, 'acc', 'Accuracy')):
    ax.plot(epoch_axis, history[f'train_{metric}'], label='Train')
    # Validation curves exist only when a valid/ split was present.
    if history[f'val_{metric}']:
        ax.plot(epoch_axis, history[f'val_{metric}'], label='Val')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.legend()

plt.tight_layout()
plt.show()

## Inference Function

In [None]:
def predict(image_path, model, classes, device=None):
    """Predict the wire color in a single image.

    Args:
        image_path: path to an image file that PIL can open.
        model: trained classifier whose output index ``i`` corresponds to
            ``classes[i]``.
        classes: class names in the model's output order (e.g. the
            ``'classes'`` list stored in the checkpoint).
        device: device to run inference on; defaults to the device of the
            model's first parameter.

    Returns:
        dict with ``'class'`` (predicted name), ``'confidence'`` (its softmax
        probability), and ``'probabilities'`` (class name -> probability).
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()

    # Use a context manager so the file handle is released promptly;
    # convert() loads the pixel data before the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = eval_transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        confidence, predicted = probabilities.max(1)

    return {
        'class': classes[predicted.item()],
        'confidence': confidence.item(),
        'probabilities': {c: p.item() for c, p in zip(classes, probabilities[0])}
    }

## Test Prediction

In [None]:
# load best model and test
checkpoint = torch.load(f'{OUTPUT_DIR}/best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
classes = checkpoint['classes']

# Test on an image (change path as needed)
result = predict('<path_to_test_image>', model, classes)
print(f"Predicted: {result['class']} ({result['confidence']:.2%})")