In [3]:

# train_frcnn_ssdlite.py
import os
import json
import time
import torch
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc
from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn,
    ssdlite320_mobilenet_v3_large
)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import numpy as np

# ---------- GPU Settings ----------
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# ---------- Settings ----------
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
EPOCHS = 2
BATCH_SIZE = 4
PATIENCE = 3
SAVE_DIR = Path("outputs")
SAVE_DIR.mkdir(exist_ok=True)

# ---------- KITTI Dataset ----------
class KITTIDataset(Dataset):
    def __init__(self, pairs, classes, transforms=None):
        self.pairs = pairs
        self.transforms = transforms
        self.classes = classes

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_path, label_path = self.pairs[idx]
        image = Image.open(img_path).convert("RGB")
        width, height = image.size

        boxes = []
        labels = []
        if label_path.exists():
            with open(label_path) as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 5:
                        cls = int(parts[0])
                        x_center, y_center, w, h = map(float, parts[1:5])
                        x1 = (x_center - w/2) * width
                        y1 = (y_center - h/2) * height
                        x2 = (x_center + w/2) * width
                        y2 = (y_center + h/2) * height
                        boxes.append([x1, y1, x2, y2])
                        labels.append(cls)

        target = {
            "boxes": torch.tensor(boxes, dtype=torch.float32),
            "labels": torch.tensor(labels, dtype=torch.int64)
        }

        if self.transforms:
            image = self.transforms(image)

        return image, target

# ---------- Data Preparation ----------
def get_data_loaders(transform, batch_size):
    base_dir = Path('kaggle/input/kitti-dataset')
    img_path = base_dir / 'data_object_image_2' / 'training' / 'image_2'
    label_path = Path('kaggle/input/kitti-dataset-yolo-format/labels')

    with open('kaggle/input/kitti-dataset-yolo-format/classes.json','r') as f:
        classes = json.load(f)

    ims = sorted(list(img_path.glob('*')))
    labels = sorted(list(label_path.glob('*')))
    pairs = list(zip(ims, labels))
    mini_pairs = pairs[:len(pairs)//10]  # Now using 1/4th of dataset
    train_pairs, test_pairs = train_test_split(mini_pairs, test_size=0.2, shuffle=True, random_state=42)

    train_dataset = KITTIDataset(train_pairs, classes, transform)
    test_dataset = KITTIDataset(test_pairs, classes, T.ToTensor())

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

    return train_loader, test_loader, classes

# ---------- Engine ----------
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10):
    model.train()
    total_loss = 0
    pbar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
    for images, targets in pbar:
        images = list(img.to(device) for img in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        total_loss += losses.item()
        pbar.set_postfix(loss=losses.item())
    return total_loss / len(data_loader)

def evaluate(model, data_loader, device, class_names, out_dir):
    model.eval()
    y_true, y_pred = [], []
    y_true_bin = []
    y_scores = []
    class_counts = [0 for _ in range(len(class_names))]

    with torch.no_grad():
        for images, targets in tqdm(data_loader, desc="Evaluating"):
            images = list(img.to(device) for img in images)
            outputs = model(images)
            for out, tgt in zip(outputs, targets):
                preds = out['labels'].cpu().tolist()
                scores = out['scores'].cpu().tolist()
                trues = tgt['labels'].cpu().tolist()
                num = min(len(preds), len(trues))

                y_true.extend(trues[:num])
                y_pred.extend(preds[:num])

                for label in trues:
                    class_counts[label] += 1

                for i in range(num):
                    onehot = [0] * len(class_names)
                    onehot[trues[i]] = 1
                    y_true_bin.append(onehot)

                    score_vec = [0] * len(class_names)
                    if preds[i] < len(class_names):
                        score_vec[preds[i]] = scores[i] if i < len(scores) else 0
                    y_scores.append(score_vec)

    # 🔸 Debug prints
    print(f"[DEBUG] Total true labels: {len(y_true)}")
    print(f"[DEBUG] Total predicted labels: {len(y_pred)}")

    report = classification_report(
        y_true,
        y_pred,
        labels=list(range(len(class_names))),
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot()
    plt.title("Confusion Matrix")
    plt.savefig(out_dir / "confusion_matrix.png")
    print("[SAVED] confusion_matrix.png")
    plt.close()

    plt.figure(figsize=(10, 4))
    plt.bar(list(class_names.values()), class_counts)
    plt.title("Class Distribution in Test Set")
    plt.ylabel("Frequency")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(out_dir / "class_distribution.png")
    print("[SAVED] class_distribution.png")
    plt.close()

    metrics_plot = out_dir / "metrics_plot.png"
    labels = list(report.keys())[:-3]
    precision = [report[cls]['precision'] for cls in labels]
    recall = [report[cls]['recall'] for cls in labels]
    f1 = [report[cls]['f1-score'] for cls in labels]

    x = np.arange(len(labels))
    width = 0.2
    plt.bar(x - width, precision, width, label='Precision')
    plt.bar(x, recall, width, label='Recall')
    plt.bar(x + width, f1, width, label='F1-Score')
    plt.xticks(x, labels, rotation=45, ha='right')
    plt.title("Metrics per Class")
    plt.legend()
    plt.tight_layout()
    plt.savefig(metrics_plot)
    print("[SAVED] metrics_plot.png")
    plt.close()

    try:
        y_true_bin = np.array(y_true_bin)
        y_scores = np.array(y_scores)
        for i in range(len(class_names)):
            fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_scores[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f"{class_names[i]} (AUC={roc_auc:.2f})")
        plt.plot([0, 1], [0, 1], 'k--')
        plt.title("AUC-ROC per Class")
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_dir / "auc_roc.png")
        print("[SAVED] auc_roc.png")
        plt.close()
    except Exception as e:
        print("[AUC ERROR]", e)

    return report

    
# ---------- Models ----------
def get_frcnn_model(num_classes):
    model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

def get_ssdlite_model(num_classes):
    model = ssdlite320_mobilenet_v3_large(weights="DEFAULT")
    return model

MODELS = {
    'fasterrcnn': get_frcnn_model,
    'ssdlite': get_ssdlite_model
}

# ---------- Training + Evaluation ----------
def run_training(model_name):
    print(f"\n--- Running {model_name.upper()} ---")
    out_dir = SAVE_DIR / model_name
    out_dir.mkdir(parents=True, exist_ok=True)
    weights_path = out_dir / "weights.pth"
    log_path = out_dir / "log.txt"
    metrics_path = out_dir / "metrics.json"
    plot_path = out_dir / "loss_plot.png"

    model = MODELS[model_name](num_classes=len(classes)).to(DEVICE)

    if weights_path.exists():
        print("Model already trained. Skipping training.")
        try:
            model.load_state_dict(torch.load(weights_path))
        except RuntimeError:
            print("Shape mismatch. Reinitializing model.")
            model = MODELS[model_name](num_classes=len(classes)).to(DEVICE)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
        train_losses = []
        best_loss = float('inf')
        patience_counter = 0

        with open(log_path, 'w') as log:
            for epoch in range(EPOCHS):
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                loss = train_one_epoch(model, optimizer, train_loader, DEVICE, epoch, print_freq=10)
                train_losses.append(loss)
                log.write(f"Epoch {epoch + 1}, Loss: {loss:.4f}\n")

                if loss < best_loss:
                    best_loss = loss
                    patience_counter = 0
                    torch.save(model.state_dict(), weights_path)
                else:
                    patience_counter += 1
                    if patience_counter >= PATIENCE:
                        print("Early stopping triggered.")
                        break

        plt.plot(train_losses)
        plt.title(f"Training Loss - {model_name}")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.savefig(plot_path)
        plt.close()

    print("Evaluating model...")
    model.load_state_dict(torch.load(weights_path))
    metrics = evaluate(model, test_loader, device=DEVICE, class_names=classes, out_dir=out_dir)
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=4)
    print(f"{model_name.upper()} Metrics:", metrics)

# ---------- Run ----------
train_transform = T.Compose([
    T.RandomHorizontalFlip(0.5),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    T.ToTensor()
])
train_loader, test_loader, classes = get_data_loaders(train_transform, BATCH_SIZE)
run_training("fasterrcnn")
run_training("ssdlite")



--- Running FASTERRCNN ---
Model already trained. Skipping training.
Evaluating model...


Evaluating:   0%|          | 0/38 [00:00<?, ?it/s]

[DEBUG] Total true labels: 738
[DEBUG] Total predicted labels: 738
[SAVED] confusion_matrix.png
[SAVED] class_distribution.png
[SAVED] metrics_plot.png
[AUC ERROR] 0
FASTERRCNN Metrics: {'Car': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 533.0}, 'Pedestrian': {'precision': 0.24154589371980675, 'recall': 0.847457627118644, 'f1-score': 0.37593984962406013, 'support': 59.0}, 'Van': {'precision': 0.0948905109489051, 'recall': 0.7222222222222222, 'f1-score': 0.16774193548387098, 'support': 54.0}, 'Cyclist': {'precision': 0.09090909090909091, 'recall': 0.19230769230769232, 'f1-score': 0.12345679012345678, 'support': 26.0}, 'Truck': {'precision': 0.18, 'recall': 0.42857142857142855, 'f1-score': 0.2535211267605634, 'support': 21.0}, 'Misc': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 24.0}, 'Tram': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 12.0}, 'Person_sitting': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 9.0}, 'accura

Epoch 1:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/150 [00:00<?, ?it/s]

Evaluating model...


Evaluating:   0%|          | 0/38 [00:00<?, ?it/s]

[DEBUG] Total true labels: 789
[DEBUG] Total predicted labels: 789
[SAVED] confusion_matrix.png
[SAVED] class_distribution.png
[SAVED] metrics_plot.png
[AUC ERROR] 0
SSDLITE Metrics: {'Car': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 577.0}, 'Pedestrian': {'precision': 0.11063829787234042, 'recall': 0.8666666666666667, 'f1-score': 0.19622641509433963, 'support': 60.0}, 'Van': {'precision': 0.06172839506172839, 'recall': 0.17543859649122806, 'f1-score': 0.091324200913242, 'support': 57.0}, 'Cyclist': {'precision': 0.015873015873015872, 'recall': 0.03571428571428571, 'f1-score': 0.02197802197802198, 'support': 28.0}, 'Truck': {'precision': 0.0625, 'recall': 0.14285714285714285, 'f1-score': 0.08695652173913043, 'support': 21.0}, 'Misc': {'precision': 0.3333333333333333, 'recall': 0.04, 'f1-score': 0.07142857142857142, 'support': 25.0}, 'Tram': {'precision': 0.11428571428571428, 'recall': 0.3333333333333333, 'f1-score': 0.1702127659574468, 'support': 12.0}, 'Person_sitti