In [None]:
# Colab-only: mount Google Drive at /content/drive so the project folder
# referenced by BASE_DIR below is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

import os
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torchvision import datasets, transforms, models
from tqdm.auto import tqdm


# Project root on the mounted Drive; every dataset path below hangs off it.
BASE_DIR = "/content/drive/MyDrive/helmetguard"
print("BASE_DIR:", BASE_DIR)

# Synthetic images sit in a single folder; real images are pre-split on disk
# into "train" and "test" sub-folders.
SYNTH_DIR = os.path.join(BASE_DIR, "data_synth")
REAL_TRAIN_DIR, REAL_TEST_DIR = (
    os.path.join(BASE_DIR, "data_real", split) for split in ("train", "test")
)

# Listing each directory shows the class sub-folders and doubles as an early
# existence check (os.listdir raises if the path is missing).
for _label, _path in (
    ("SYNTH_DIR:", SYNTH_DIR),
    ("REAL_TRAIN_DIR:", REAL_TRAIN_DIR),
    ("REAL_TEST_DIR:", REAL_TEST_DIR),
):
    print(_label, _path, "->", os.listdir(_path))

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


Mounted at /content/drive
BASE_DIR: /content/drive/MyDrive/helmetguard
SYNTH_DIR: /content/drive/MyDrive/helmetguard/data_synth -> ['safe', 'unsafe']
REAL_TRAIN_DIR: /content/drive/MyDrive/helmetguard/data_real/train -> ['safe', 'unsafe']
REAL_TEST_DIR: /content/drive/MyDrive/helmetguard/data_real/test -> ['safe', 'unsafe']
Using device: cuda


In [None]:

# Tail shared by both pipelines: convert to a tensor, then normalise with the
# per-channel mean/std constants (the usual ImageNet statistics, in line with
# the ImageNet-pretrained ViT this notebook fine-tunes).
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]

# Training pipeline: resize, then random crop / flip / colour jitter.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    *_to_normalized_tensor,
])

# Evaluation pipeline: deterministic resize + centre crop, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    *_to_normalized_tensor,
])


# Two views over the same folder: one augments (training), one does not
# (validation). They must be separate dataset objects. random_split returns
# Subsets that share their parent dataset, so assigning
# `synth_val.dataset.transform` would also replace the transform used by
# `synth_train` and silently turn training augmentation off.
synth_full = datasets.ImageFolder(root=SYNTH_DIR, transform=train_transform)
synth_full_eval = datasets.ImageFolder(root=SYNTH_DIR, transform=val_transform)

# Both views are built from the same root, so index i must be the same file in
# each; verify that before reusing split indices across them.
assert synth_full.samples == synth_full_eval.samples, \
    "train/eval views of the synthetic folder are not index-aligned"

print("Synthetic classes:", synth_full.classes)
print("Total synthetic images:", len(synth_full))


val_ratio = 0.2
val_size = int(len(synth_full) * val_ratio)
train_size = len(synth_full) - val_size

# Seed the split so train/val membership is identical on every run.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
synth_train, _synth_val_split = random_split(
    synth_full, [train_size, val_size], generator=split_generator
)

# Re-point the validation indices at the non-augmenting view.
synth_val = torch.utils.data.Subset(synth_full_eval, _synth_val_split.indices)

print("Synthetic train size:", len(synth_train))
print("Synthetic val size:", len(synth_val))


Synthetic classes: ['safe', 'unsafe']
Total synthetic images: 614
Synthetic train size: 492
Synthetic val size: 122


In [None]:

# Real training photos are loaded with the augmenting train_transform.
real_train = datasets.ImageFolder(REAL_TRAIN_DIR, transform=train_transform)

# Training data = synthetic train split followed by the real training images.
# Validation uses only the held-out synthetic split; no real images are in it.
train_dataset = ConcatDataset((synth_train, real_train))
val_dataset = synth_val

print("Real train classes:", real_train.classes)
print("Real train size:", len(real_train))
print("Total train size (synth + real):", len(train_dataset))
print("Val size:", len(val_dataset))


Real train classes: ['safe', 'unsafe']
Real train size: 35
Total train size (synth + real): 527
Val size: 122


In [None]:
batch_size = 32

# Settings common to both loaders; only shuffling differs between them.
_loader_kwargs = dict(batch_size=batch_size, num_workers=0, pin_memory=False)

# Shuffle the training data each epoch; keep validation order fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print("Loaders ready.")



Loaders ready.


In [None]:

# Newer torchvision releases expose pretrained weights as enums; older ones only
# understand `pretrained=True`. Only a missing enum (AttributeError) should
# trigger the fallback. Download or runtime failures must surface instead of
# being retried through the legacy code path.
try:
    weights = models.ViT_B_16_Weights.IMAGENET1K_V1
except AttributeError as e:
    print("Fallback to older API, error was:", e)
    model_vit = models.vit_b_16(pretrained=True)
else:
    model_vit = models.vit_b_16(weights=weights)
    print("Loaded ViT-B/16 with pretrained ImageNet weights (new API).")


# Swap the classification head for a fresh 2-way linear layer
# (the datasets in this notebook have two classes: 'safe' and 'unsafe').
num_features = model_vit.heads.head.in_features
model_vit.heads.head = nn.Linear(num_features, 2)

model_vit = model_vit.to(device)

# All parameters (backbone and new head) are handed to the optimizer.
criterion_vit = nn.CrossEntropyLoss()
optimizer_vit = torch.optim.Adam(model_vit.parameters(), lr=1e-4)

print("ViT model ready on device:", device)


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:01<00:00, 194MB/s]


Loaded ViT-B/16 with pretrained ImageNet weights (new API).
ViT model ready on device: cuda


In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and report loss and accuracy.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted average loss and the
        fraction of correct predictions. Returns ``(0.0, 0.0)`` when the loader
        yields no samples, instead of raising ``ZeroDivisionError``.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(loader, leave=False):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so a smaller final batch does not
        # skew the epoch average (assumes the criterion returns a per-batch mean).
        n_batch = labels.size(0)
        running_loss += loss.item() * n_batch
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += n_batch

    # Same empty-input convention as evaluate_on_loader: report zeros.
    if total == 0:
        return 0.0, 0.0

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


@torch.no_grad()
def eval_one_epoch(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted average loss and the
        fraction of correct predictions. Returns ``(0.0, 0.0)`` when the loader
        yields no samples, instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Weight by batch size so the epoch figure is a per-sample average
        # (assumes the criterion returns a per-batch mean).
        n_batch = labels.size(0)
        running_loss += loss.item() * n_batch
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += n_batch

    # Same empty-input convention as evaluate_on_loader: report zeros.
    if total == 0:
        return 0.0, 0.0

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


In [None]:
num_epochs = 3

for epoch in range(num_epochs):
    print(f"\n[ViT] Epoch {epoch+1}/{num_epochs}")

    # One optimisation pass over the training loader...
    train_loss, train_acc = train_one_epoch(
        model=model_vit,
        loader=train_loader,
        optimizer=optimizer_vit,
        criterion=criterion_vit,
        device=device,
    )
    # ...then score the updated weights on the validation loader.
    val_loss, val_acc = eval_one_epoch(
        model=model_vit,
        loader=val_loader,
        criterion=criterion_vit,
        device=device,
    )

    print(f"[ViT] Train  Loss: {train_loss:.4f} | Acc: {train_acc*100:.2f}%")
    print(f"[ViT] Val    Loss: {val_loss:.4f} | Acc: {val_acc*100:.2f}%")



[ViT] Epoch 1/3


  0%|          | 0/17 [00:00<?, ?it/s]

[ViT] Train  Loss: 0.4177 | Acc: 80.27%
[ViT] Val    Loss: 0.2916 | Acc: 88.52%

[ViT] Epoch 2/3


  0%|          | 0/17 [00:00<?, ?it/s]

[ViT] Train  Loss: 0.1239 | Acc: 94.88%
[ViT] Val    Loss: 0.0968 | Acc: 95.90%

[ViT] Epoch 3/3


  0%|          | 0/17 [00:00<?, ?it/s]

[ViT] Train  Loss: 0.0528 | Acc: 98.10%
[ViT] Val    Loss: 0.2026 | Acc: 95.90%


In [None]:
# Checkpoints go in a "models" folder next to the data on Drive.
# NOTE(review): the file name says "oversampled", but no oversampling step is
# visible in this notebook. Confirm the name is intended.
MODEL_DIR = os.path.join(BASE_DIR, "models")
vit_path = os.path.join(MODEL_DIR, "vit_b16_synth_plus_real_oversampled.pt")

# Create the folder on first use, then persist the weights only (state_dict).
os.makedirs(MODEL_DIR, exist_ok=True)
torch.save(model_vit.state_dict(), vit_path)

print("Saved ViT model to:", vit_path)


Saved ViT model to: /content/drive/MyDrive/helmetguard/models/vit_b16_synth_plus_real_oversampled.pt


In [None]:
from torchvision import datasets

# Real test images, loaded with the deterministic (non-augmenting) pipeline
# and served in a fixed order.
test_dataset = datasets.ImageFolder(REAL_TEST_DIR, transform=val_transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

print("Test classes:", test_dataset.classes)
print("Test size:", len(test_dataset))


Test classes: ['safe', 'unsafe']
Test size: 44


In [None]:
@torch.no_grad()
def evaluate_on_loader(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is put in eval mode and each ``(inputs, labels)`` batch is moved
    to ``device``. The predicted class is the index of the largest output along
    dimension 1. An empty loader yields ``0``.
    """
    model.eval()
    hits, seen = 0, 0

    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        predicted = torch.max(model(batch_x), 1)[1]

        hits += (predicted == batch_y).sum().item()
        seen += batch_y.size(0)

    # Avoid dividing by zero when nothing was evaluated.
    if seen > 0:
        return hits / seen
    return 0


from collections import Counter

@torch.no_grad()
def evaluate_per_class(model, loader, device, class_names):
    """Print, for every entry of ``class_names``, how many of its samples were classified correctly.

    Labels are matched to names by position: label ``i`` belongs to
    ``class_names[i]``. A class with no samples in ``loader`` is reported as
    ``0/0`` with 0.00% accuracy. Nothing is returned.
    """
    model.eval()
    hits = Counter()
    seen = Counter()

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        predicted = torch.max(model(batch_x), 1)[1]

        truth = batch_y.cpu().tolist()
        guess = predicted.cpu().tolist()

        # Tally every sample under its true label, and the correct ones again.
        seen.update(truth)
        hits.update(t for t, g in zip(truth, guess) if t == g)

    for idx, name in enumerate(class_names):
        total = seen[idx]
        correct = hits[idx]
        if total > 0:
            acc = correct / total
        else:
            acc = 0
        print(f"Class '{name}': {correct}/{total} correct ({acc*100:.2f}%)")


In [None]:
# Overall accuracy on the real test images, followed by the per-class breakdown.
vit_test_acc = evaluate_on_loader(model=model_vit, loader=test_loader, device=device)
print(f"[ViT] Real test accuracy: {vit_test_acc*100:.2f}%")

evaluate_per_class(
    model=model_vit,
    loader=test_loader,
    device=device,
    class_names=test_dataset.classes,
)


[ViT] Real test accuracy: 63.64%
Class 'safe': 22/30 correct (73.33%)
Class 'unsafe': 6/14 correct (42.86%)
