In [None]:
import torch
import os
import random

import numpy as np

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
if DEVICE == 'cpu':
    # os.cpu_count() returns None when the count is unknown, which the `or 4` fallback handles.
    # (multiprocessing.cpu_count() raises NotImplementedError instead, so the fallback never fired.)
    torch.set_num_threads(os.cpu_count() or 4)

# Seed every RNG the notebook draws from (weight init, DataLoader shuffling, torchvision
# augmentation) so Restart & Run All reproduces the same run, up to nondeterministic GPU kernels.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print('CUDA:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))

In [None]:

from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
import numpy as np

class CarColorDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs, decoding images lazily from disk.

    Args:
        image_paths: sequence of image file paths.
        labels: sequence of labels, parallel to ``image_paths``.
        transform: optional callable applied to the RGB PIL image (e.g. a torchvision pipeline).

    Raises:
        ValueError: if ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # A length mismatch would otherwise surface much later as an IndexError in a worker.
        if len(image_paths) != len(labels):
            raise ValueError(
                f'image_paths ({len(image_paths)}) and labels ({len(labels)}) must have the same length')
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # convert() returns a new image, so the source file can be closed right away.
        with Image.open(self.image_paths[idx]) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[idx]

def collect_dataset(data_dir='data', min_samples=50):
    data_dir = Path(data_dir)
    all_images = list(data_dir.rglob('*.jpg')) + list(data_dir.rglob('*.jpeg'))
    color_to_images = {}
    for img_path in all_images:
        filename = img_path.name
        if '$$' in filename:
            parts = filename.split('$$')
            if len(parts) >= 4:
                color = parts[3].lower()
                if color == 'grey': color = 'gray'
                if color in ['unlisted', 'multicolour']: continue
                if color not in color_to_images: color_to_images[color] = []
                color_to_images[color].append(str(img_path))
    color_to_images = {c: imgs for c, imgs in color_to_images.items() if len(imgs) >= min_samples}
    valid_colors = set(color_to_images.keys())
    class_names = sorted(valid_colors)
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    image_paths, labels = [], []
    for color, imgs in color_to_images.items():
        if color in valid_colors:
            image_paths.extend(imgs)
            labels.extend([class_to_idx[color]] * len(imgs))
    return image_paths, labels, class_names

# Prefer ./data and fall back to ./confirmed_fronts.
data_dir = 'data' if Path('data').exists() else 'confirmed_fronts'
MIN_SAMPLES_PER_CLASS = 600  # colours with fewer images than this are dropped
image_paths, labels, class_names = collect_dataset(data_dir, min_samples=MIN_SAMPLES_PER_CLASS)
num_classes = len(class_names)
# A colour classifier needs at least two classes. Stop here with a clear message rather than
# failing later in the stratified split or training on a degenerate label set.
assert num_classes >= 2, (
    f'Expected at least 2 colour classes with >= {MIN_SAMPLES_PER_CLASS} images under '
    f'{data_dir!r}, found {num_classes} ({len(image_paths)} images)')
print('Classes:', num_classes, 'Imgs:', len(image_paths))


In [None]:
import torch
# Runtime / data-loading configuration.
ON_CPU = not torch.cuda.is_available()  # used below to turn off pin_memory without a GPU
IMG_SIZE = 224     # square network input resolution, in pixels
BATCH_SIZE = 128
NUM_WORKERS = 0    # images are decoded in the main process
print(f'CPU mode: {ON_CPU} | img_size: {IMG_SIZE} | batch: {BATCH_SIZE}')

# 70 / 15 / 15 stratified split: hold out 30% first, then halve the holdout into val and test.
(train_paths, temp_paths,
 train_labels, temp_labels) = train_test_split(
    image_paths, labels,
    stratify=labels, test_size=0.3, random_state=42)
(val_paths, test_paths,
 val_labels, test_labels) = train_test_split(
    temp_paths, temp_labels,
    stratify=temp_labels, test_size=0.5, random_state=42)

# Resize a little larger than the network input so the crops below actually crop. Resizing
# straight to IMG_SIZE made RandomCrop(IMG_SIZE) / CenterCrop(IMG_SIZE) no-ops.
RESIZE_SIZE = int(round(IMG_SIZE * 256 / 224))  # 256 for IMG_SIZE=224 (usual ImageNet ratio)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# The label IS the colour, so augmentation must not change hue or remove colour. Hue jitter of
# 0.2 (±72° on the hue wheel) and RandomGrayscale both turn e.g. a red car into an orange/magenta
# or grey one while it keeps the "red" label. Lighting-style jitter (brightness/contrast/
# saturation) is kept.
train_tf = transforms.Compose([
    transforms.Resize((RESIZE_SIZE, RESIZE_SIZE)),
    transforms.RandomCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
# Deterministic eval pipeline, geometrically matched to training (same resize, centred crop).
val_tf = transforms.Compose([
    transforms.Resize((RESIZE_SIZE, RESIZE_SIZE)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Only the training split gets random augmentation; val and test share the deterministic val_tf.
train_ds = CarColorDataset(train_paths, train_labels, train_tf)
val_ds = CarColorDataset(val_paths, val_labels, val_tf)
test_ds = CarColorDataset(test_paths, test_labels, val_tf)

# Settings shared by all three loaders; pinned host memory only helps for copies to a GPU.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=not ON_CPU)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)
print(f'Train: {len(train_ds)} Val: {len(val_ds)} Test: {len(test_ds)}')

In [5]:
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv+BN layers plus an identity or 1x1-projection shortcut.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        stride: stride of the first 3x3 conv (and of the projection); >1 downsamples.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # Main branch. Convs have no bias because each is followed by BatchNorm.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # Project the input only when its shape differs from the main branch's output.
        needs_projection = stride != 1 or in_ch != out_ch
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch))
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        branch = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(branch + self.shortcut(x))

class ResNet(nn.Module):
    """From-scratch residual classifier built from ``BasicBlock``.

    Layout: 7x7 stride-2 stem conv + BN + ReLU, 3x3 stride-2 max-pool, four stages of two blocks
    (64, 128, 256, 512 channels; stages 2-4 start with stride 2), global average pooling and a
    linear head producing ``num_classes`` logits.
    """

    def __init__(self, num_classes=15):
        super().__init__()
        # Stem.
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        # Residual stages.
        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)
        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        # He (fan_out) init for convs; BatchNorm starts as the identity affine map.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, in_ch, out_ch, blocks, stride):
        """Stack ``blocks`` BasicBlocks; only the first one changes channels / stride."""
        return nn.Sequential(
            BasicBlock(in_ch, out_ch, stride),
            *(BasicBlock(out_ch, out_ch) for _ in range(blocks - 1)),
        )

    def forward(self, x):
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        features = self.avgpool(x).flatten(1)
        return self.fc(features)

In [None]:
from sklearn.metrics import f1_score
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
import datetime

LOG_PATH = 'training.log'
# Mode 'w': every execution of this cell starts a fresh log. The test-set cell closes it.
log_file = open(LOG_PATH, 'w', encoding='utf-8')


def log(msg):
    """Print ``msg`` and append it as one line to ``log_file``.

    The file is flushed after each line so its contents stay current while training runs.
    ``msg`` must be a ``str``.
    """
    print(msg)
    line = msg + '\n'
    log_file.write(line)
    log_file.flush()


log(f'Started {datetime.datetime.now().isoformat()}')

# Defaults used by train_model().
NUM_EPOCHS = 10
LR = 0.001

def train_epoch(model, loader, criterion, optimizer, device, max_grad_norm=1.0):
    """Run one optimisation pass over ``loader``.

    Args:
        model: network to train (put into train mode here).
        loader: iterable of ``(images, labels)`` tensor batches.
        criterion: loss callable ``criterion(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        max_grad_norm: gradients are clipped to this total norm before each step;
            ``None`` disables clipping.

    Returns:
        ``(mean_loss, accuracy)``: the batch losses averaged with batch-size weights, and the
        fraction of correctly classified samples, both over the whole pass.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for images, lbl in loader:
        images = images.to(device, non_blocking=True)
        lbl = lbl.to(device, non_blocking=True)
        optimizer.zero_grad()
        out = model(images)
        loss = criterion(out, lbl)
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        batch_size = lbl.size(0)
        loss_sum += loss.item() * batch_size
        n_correct += (out.argmax(1) == lbl).sum().item()
        n_seen += batch_size
    if n_seen == 0:
        raise ValueError('train_epoch: loader yielded no samples')
    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, loader, device):
    """Return the macro-averaged F1 of ``model``'s argmax predictions over all of ``loader``.

    Switches the model to eval mode and disables autograd. With ``zero_division=0``, a class
    whose precision or recall is undefined (e.g. never predicted) scores 0 without a warning.
    """
    model.eval()
    predicted, actual = [], []
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            logits = model(batch_images.to(device))
            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            actual.extend(batch_labels.numpy())
    return f1_score(actual, predicted, average='macro', zero_division=0)

def train_model(model, name, epochs=NUM_EPOCHS, lr=LR):
    """Train ``model`` and checkpoint the weights with the best validation macro-F1.

    Uses class-weighted ('balanced') cross-entropy with 0.1 label smoothing, AdamW
    (weight_decay=0.01) and cosine annealing over ``epochs``. Each epoch logs the training loss
    and accuracy, the macro-F1 on an un-augmented view of the training split, and the
    validation macro-F1.

    Args:
        model: network to train; moved to ``DEVICE``.
        name: run name used as the log-line prefix and checkpoint stem (``f'{name}_best.pth'``).
        epochs: number of epochs (also the cosine schedule period).
        lr: initial learning rate.

    Returns:
        Best validation macro-F1 observed (0.0 when ``epochs`` is 0).
    """
    model = model.to(DEVICE)
    weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
    weights = torch.FloatTensor(weights).to(DEVICE)
    criterion = nn.CrossEntropyLoss(weight=weights, label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # train_loader applies random augmentation. Measure train F1 on the deterministic val_tf
    # view of the same images so it is directly comparable with val F1.
    train_eval_loader = DataLoader(
        CarColorDataset(train_paths, train_labels, val_tf),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=not ON_CPU)

    best_f1 = None
    for ep in range(epochs):
        loss, acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
        f1_val = evaluate(model, val_loader, DEVICE)
        f1_train = evaluate(model, train_eval_loader, DEVICE)
        scheduler.step()
        # Always save after the first epoch, so '<name>_best.pth' exists for the test cell even
        # if validation F1 never rises above 0.
        if best_f1 is None or f1_val > best_f1:
            best_f1 = f1_val
            torch.save(model.state_dict(), f'{name}_best.pth')
        log(f'{name} ep{ep+1}/{epochs} loss={loss:.3f} acc={acc:.3f} f1_train={f1_train:.3f} val_f1={f1_val:.3f}')
    return best_f1 if best_f1 is not None else 0.0

In [None]:
# Baseline: the randomly initialised ResNet defined above, trained with the default LR.
model_custom = ResNet(num_classes=num_classes)
f1_custom = train_model(model_custom, 'custom_resnet')
log('Custom ResNet best val F1: %s' % (f1_custom,))

In [None]:
from torchvision import models

def get_pretrained_resnet18(num_classes):
    """Return torchvision's ImageNet-pretrained ResNet-18 with ``fc`` swapped for a new
    ``num_classes``-way linear layer (the rest of the network keeps its pretrained weights)."""
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    feature_dim = backbone.fc.in_features
    backbone.fc = nn.Linear(feature_dim, num_classes)
    return backbone

model_resnet18 = get_pretrained_resnet18(num_classes)
# Fine-tune the pretrained backbone with a 10x smaller LR than the default (LR = 0.001).
f1_resnet18 = train_model(model_resnet18, 'pretrained_resnet18', lr=1e-4)
log('ResNet18 best val F1: %s' % (f1_resnet18,))

In [None]:
def get_pretrained_mobilenetv2(num_classes):
    """Return torchvision's ImageNet-pretrained MobileNetV2 with entry 1 of its ``classifier``
    replaced by a new ``num_classes``-way linear layer of width ``last_channel``."""
    net = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
    new_head = nn.Linear(net.last_channel, num_classes)
    net.classifier[1] = new_head
    return net

model_mobilenet = get_pretrained_mobilenetv2(num_classes)
# Same fine-tuning LR as the pretrained ResNet18.
f1_mobilenet = train_model(model_mobilenet, 'pretrained_mobilenetv2', lr=1e-4)
log('MobileNetV2 best val F1: %s' % (f1_mobilenet,))

In [None]:
def test_f1(model):
    """Return ``model``'s macro-F1 on the held-out test split (``test_loader``), run on ``DEVICE``.

    Delegates to ``evaluate`` rather than duplicating its loop, so validation and test scores
    are guaranteed to be computed the same way.
    """
    return evaluate(model, test_loader, DEVICE)

# Restore each model's best-validation checkpoint, then score all three on the test split.
_checkpoints = (
    (model_custom, 'custom_resnet_best.pth'),
    (model_resnet18, 'pretrained_resnet18_best.pth'),
    (model_mobilenet, 'pretrained_mobilenetv2_best.pth'),
)
for _net, _ckpt in _checkpoints:
    _net.load_state_dict(torch.load(_ckpt, map_location=DEVICE))

t_custom, t_resnet18, t_mobilenet = (test_f1(_net) for _net, _ in _checkpoints)

# Insertion order matters: on ties, max() keeps the first (earlier-listed) model.
test_scores = {'Custom ResNet': t_custom, 'ResNet18': t_resnet18, 'MobileNetV2': t_mobilenet}
log('Test F1:')
for _label, _score in test_scores.items():
    log('  ' + _label + ': ' + str(round(_score, 4)))
best_name = max(test_scores, key=test_scores.get)
log('Best: ' + best_name + ' | F1>0.8: ' + str(max(test_scores.values()) > 0.8))
log_file.close()

In [15]:
import matplotlib.pyplot as plt
import re

_EPOCH_LINE_RE = re.compile(
    r'ep(\d+)/(\d+) loss=([0-9.]+) acc=([0-9.]+) f1_train=([0-9.]+) val_f1=([0-9.]+)'
)
_METRIC_KEYS = ('epoch', 'loss', 'acc', 'train_f1', 'val_f1')


def _parse_log_metrics(logfile, model_name=None):
    """Collect per-epoch metrics from the epoch lines of ``logfile``.

    Only lines matching ``epN/M loss=.. acc=.. f1_train=.. val_f1=..`` (the format
    ``train_model`` logs) are used. When ``model_name`` is given, a line must also start with
    ``'<model_name> '``.

    Returns:
        dict of equal-length lists keyed by 'epoch', 'loss', 'acc', 'train_f1', 'val_f1'.
    """
    metrics = {key: [] for key in _METRIC_KEYS}
    prefix = None if model_name is None else f'{model_name} '
    with open(logfile, encoding='utf-8') as f:
        for line in f:
            if prefix is not None and not line.startswith(prefix):
                continue
            match = _EPOCH_LINE_RE.search(line)
            if match is None:
                continue
            metrics['epoch'].append(int(match.group(1)))
            metrics['loss'].append(float(match.group(3)))
            metrics['acc'].append(float(match.group(4)))
            metrics['train_f1'].append(float(match.group(5)))
            metrics['val_f1'].append(float(match.group(6)))
    return metrics


def plot_log_metrics(logfile, model_name=None):
    """Plot per-epoch training curves parsed from a training log.

    Left panel: training loss, accuracy and train F1. Right panel: validation F1.

    Args:
        logfile: path to a log containing epoch lines in the format written by ``train_model``.
        model_name: if given, only lines starting with ``'<model_name> '`` are plotted, so a
            single run can be selected from a log holding several (e.g. ``'custom_resnet'``).
            With ``None`` (default) every matching line is used; runs sharing one log are then
            drawn as a single series.

    Prints a message and draws nothing when no epoch lines match.
    """
    m = _parse_log_metrics(logfile, model_name)
    if not m['epoch']:
        where = logfile if model_name is None else f'{logfile} for {model_name!r}'
        print(f'No epoch metrics found in {where}')
        return

    title_prefix = '' if model_name is None else f'{model_name}: '
    fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(15, 6))

    ax_train.plot(m['epoch'], m['loss'], label='Loss')
    ax_train.plot(m['epoch'], m['acc'], label='Accuracy')
    ax_train.plot(m['epoch'], m['train_f1'], label='Train F1')
    ax_train.set(xlabel='Epoch', ylabel='Value',
                 title=f'{title_prefix}Train Loss, Accuracy, Train F1')
    ax_train.legend()

    ax_val.plot(m['epoch'], m['val_f1'], label='Validation F1', color='purple')
    ax_val.set(xlabel='Epoch', ylabel='F1', title=f'{title_prefix}Validation F1')
    ax_val.legend()

    fig.tight_layout()
    plt.show()


In [None]:
import os
import tempfile

# The training cells write every run into LOG_PATH, prefixing each epoch line with the run name.
# The per-model *_logs.log files are not written by this notebook. They are used only as a
# fallback when LOG_PATH holds no epoch lines for that run, and they are never overwritten.
RUNS_TO_PLOT = [
    ('pretrained_resnet18', 'resnet18_logs.log'),
    ('custom_resnet', 'resnet_custom_logs.log'),
    ('pretrained_mobilenetv2', 'mobilenetv2_logs.log'),
]


def _plot_run(run_name, fallback_log):
    """Plot one run's curves from LOG_PATH, or from ``fallback_log`` if LOG_PATH lacks the run."""
    run_lines = []
    if os.path.exists(LOG_PATH):
        with open(LOG_PATH, encoding='utf-8') as src:
            run_lines = [line for line in src if line.startswith(f'{run_name} ep')]
    print(f'== {run_name} ==')
    if not run_lines:
        if os.path.exists(fallback_log):
            plot_log_metrics(fallback_log)
        else:
            print(f'No training log found for {run_name}')
        return
    # plot_log_metrics reads from a path, so give it a temp file holding just this run's lines.
    fd, tmp_path = tempfile.mkstemp(suffix='.log')
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as tmp:
            tmp.writelines(run_lines)
        plot_log_metrics(tmp_path)
    finally:
        os.remove(tmp_path)


for _run_name, _fallback_log in RUNS_TO_PLOT:
    _plot_run(_run_name, _fallback_log)