<a href="https://colab.research.google.com/github/mohamedalaaaz/testpytroch/blob/main/f-35.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import os, math, time, copy, json, random
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms, models
from torcheval.metrics import MulticlassAccuracy, MulticlassF1Score

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def seed_all(seed=42):
    random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False
seed_all(42)

def build_transforms(img_size=256, crop=224):
    train_tf = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomResizedCrop(crop, scale=(0.7, 1.0), ratio=(0.75, 1.33)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.15, hue=0.05),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])
    eval_tf = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(crop),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])
    return train_tf, eval_tf

def make_datasets(data_dir):
    train_tf, eval_tf = build_transforms()
    train_ds = datasets.ImageFolder(Path(data_dir)/"train", transform=train_tf)
    val_ds   = datasets.ImageFolder(Path(data_dir)/"val",   transform=eval_tf)
    test_ds  = datasets.ImageFolder(Path(data_dir)/"test",  transform=eval_tf)
    return train_ds, val_ds, test_ds

def make_sampler(dataset):
    # Handle class imbalance with WeightedRandomSampler
    targets = [y for _, y in dataset.samples]
    class_count = torch.bincount(torch.tensor(targets))
    class_weights = 1.0 / torch.clamp(class_count.float(), min=1.0)
    sample_weights = [class_weights[t] for t in targets]
    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

def make_loaders(train_ds, val_ds, test_ds, batch_size=32, num_workers=4):
    sampler = make_sampler(train_ds)
    train_ld = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=True)
    val_ld   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    test_ld  = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    return train_ld, val_ld, test_ld

def build_model(num_classes, freeze_backbone=True):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    if freeze_backbone:
        for p in model.parameters(): p.requires_grad = False
    in_feat = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(in_feat, num_classes)
    )
    return model

@torch.no_grad()
def evaluate(model, loader, num_classes):
    model.eval()
    acc = MulticlassAccuracy(num_classes=num_classes).to(DEVICE)
    f1  = MulticlassF1Score(num_classes=num_classes, average="macro").to(DEVICE)
    total, correct = 0, 0
    for x, y in loader:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        logits = model(x)
        preds = logits.argmax(1)
        acc.update(preds, y)
        f1.update(preds, y)
        total += y.size(0)
        correct += (preds == y).sum().item()
    return {"acc": acc.compute().item(), "f1_macro": f1.compute().item(), "raw_acc": correct/total}

def train(
    data_dir="data",
    out_dir="runs/aircraft_cls",
    epochs=20,
    batch_size=32,
    lr=1e-3,
    weight_decay=1e-4,
    patience=5,
    freeze_backbone=True,
    num_workers=4
):
    os.makedirs(out_dir, exist_ok=True)
    train_ds, val_ds, test_ds = make_datasets(data_dir)
    class_names = train_ds.classes
    num_classes = len(class_names)
    with open(Path(out_dir)/"classes.json", "w") as f:
        json.dump(class_names, f, indent=2)

    train_ld, val_ld, test_ld = make_loaders(train_ds, val_ds, test_ds, batch_size, num_workers)
    model = build_model(num_classes, freeze_backbone).to(DEVICE)

    # If backbone is frozen, only optimize the head
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_wts = copy.deepcopy(model.state_dict())
    best_val = -1.0
    bad_epochs = 0

    ce = nn.CrossEntropyLoss()

    for epoch in range(1, epochs+1):
        model.train()
        running_loss = 0.0
        n = 0
        t0 = time.time()
        for x, y in train_ld:
            x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            logits = model(x)
            loss = ce(logits, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * y.size(0)
            n += y.size(0)

        scheduler.step()
        train_loss = running_loss / max(n, 1)

        val_metrics = evaluate(model, val_ld, num_classes)
        val_score = val_metrics["f1_macro"]  # monitor macro F1
        elapsed = time.time() - t0

        print(f"Epoch {epoch:02d}/{epochs} | "
              f"train_loss={train_loss:.4f} | "
              f"val_acc={val_metrics['acc']:.4f} | "
              f"val_f1={val_metrics['f1_macro']:.4f} | "
              f"time={elapsed:.1f}s")

        # Checkpoint
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "class_names": class_names
        }, Path(out_dir)/"last.pt")

        if val_score > best_val:
            best_val = val_score
            best_wts = copy.deepcopy(model.state_dict())
            torch.save(best_wts, Path(out_dir)/"best_weights.pt")
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                print("Early stopping triggered.")
                break

    # Load best and evaluate on test
    model.load_state_dict(best_wts)
    torch.save(model.state_dict(), Path(out_dir)/"best_weights.pt")
    test_metrics = evaluate(model, test_ld, num_classes)
    with open(Path(out_dir)/"test_metrics.json", "w") as f:
        json.dump(test_metrics, f, indent=2)
    print("Test:", test_metrics)
    print("Classes:", class_names)
    return model, class_names

if __name__ == "__main__":
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_dir", type=str, default="data")
    ap.add_argument("--out_dir", type=str, default="runs/aircraft_cls")
    ap.add_argument("--epochs", type=int, default=20)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--weight_decay", type=float, default=1e-4)
    ap.add_argument("--patience", type=int, default=5)
    ap.add_argument("--freeze_backbone", action="store_true")
    ap.add_argument("--no_freeze_backbone", dest="freeze_backbone", action="store_false")
    ap.set_defaults(freeze_backbone=True)
    ap.add_argument("--num_workers", type=int, default=4)
    args = ap.parse_args()

    train(
        data_dir=args.data_dir,
        out_dir=args.out_dir,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        weight_decay=args.weight_decay,
        patience=args.patience,
        freeze_backbone=args.freeze_backbone,
        num_workers=args.num_workers
    )


usage: colab_kernel_launcher.py [-h] [--data_dir DATA_DIR] [--out_dir OUT_DIR]
                                [--epochs EPOCHS] [--batch_size BATCH_SIZE]
                                [--lr LR] [--weight_decay WEIGHT_DECAY]
                                [--patience PATIENCE] [--freeze_backbone]
                                [--no_freeze_backbone]
                                [--num_workers NUM_WORKERS]
colab_kernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-6648930d-3886-4c96-9516-43ef7e8f14fc.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
import json
from pathlib import Path
import torch
from torchvision import transforms, models
from PIL import Image

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def build_transform(img_size=256, crop=224):
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(crop),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])

def load_model(weights_path, classes_path):
    class_names = json.load(open(classes_path))
    model = models.resnet18(weights=None)
    in_feat = model.fc.in_features
    model.fc = torch.nn.Sequential(torch.nn.Dropout(0.2), torch.nn.Linear(in_feat, len(class_names)))
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    model.eval().to(DEVICE)
    return model, class_names

def predict(img_path, weights="runs/aircraft_cls/best_weights.pt", classes="runs/aircraft_cls/classes.json"):
    tf = build_transform()
    img = Image.open(img_path).convert("RGB")
    x = tf(img).unsqueeze(0).to(DEVICE)
    model, class_names = load_model(weights, classes)
    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1)[0]
        conf, idx = torch.max(probs, dim=0)
    return class_names[idx.item()], conf.item()

if __name__ == "__main__":
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("image", type=str)
    ap.add_argument("--weights", type=str, default="runs/aircraft_cls/best_weights.pt")
    ap.add_argument("--classes", type=str, default="runs/aircraft_cls/classes.json")
    args = ap.parse_args()

    label, conf = predict(args.image, args.weights, args.classes)
    print(f"Prediction: {label} (confidence {conf:.3f})")
