In [1]:
!pip install optuna
import os
import gc
import cv2
import json
import torch
from math import inf
from posixpath import defpath
import wandb
from torchvision.models.detection.image_list import ImageList
import optuna
import shutil
import logging
import numpy as np
import torchvision
import torch.optim as optim
from torchvision.ops import nms
import torchvision.ops as ops
import matplotlib.pyplot as plt
import torch.nn.functional as F
import matplotlib.patches as patches
from torch import nn
from tqdm import tqdm
from PIL import Image
from torchvision.models.detection.rpn import AnchorGenerator
from collections import Counter
from torchsummary import summary
from sklearn.metrics import f1_score,accuracy_score,precision_score,recall_score,average_precision_score
from typing import Dict, Any, Optional
from torchvision import transforms , models
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights

# Pick the CUDA device when torch reports one is available, otherwise use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
# Mount Google Drive so that files stored there are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Source folder on Drive and destination directory for the copy below.
src = "/content/drive/MyDrive/kitti2012"
dst = "/content"
os.makedirs(dst, exist_ok=True)

# Recursively copy the dataset folder from `src` into `dst` (shell command).
!cp -r "$src" "$dst"

# Disabled alternative: the same copy for the "vkitti_sample (1)" folder.
#src = "/content/drive/MyDrive/vkitti_sample (1)"
#dst = "/content"
#os.makedirs(dst, exist_ok=True)

#!cp -r "$src" "$dst"

# Log in to Weights & Biases before any run is started.
wandb.login()



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


True

In [2]:
class KITTI_Dataset(Dataset):
    """KITTI dataset that pairs each image with per-object class, box and depth labels.

    In ``'test'`` mode every item is just the transformed image. In any other mode
    an item is ``(image, labels)`` where ``labels`` is a float32 tensor of shape
    ``[N, 6]`` holding ``[class_id, x1, y1, x2, y2, depth]`` per object. Box
    coordinates are normalised to ``[0, 1]`` and ``depth`` is the normalised depth
    (metres / ``max_depth``) sampled at the box centre; ``0.0`` means that no valid
    disparity measurement exists at that location.

    Args:
        data_path: Dataset root containing ``training/`` and/or ``testing/``.
        transform: Callable applied to the PIL image. Defaults to a resize to
            ``img_height x img_width`` followed by ``ToTensor``.
        mode: ``'train'``, ``'val'`` or ``'test'``.
    """

    def __init__(self, data_path, transform=None, mode='train'):
        self.data_path = data_path
        self.mode = mode  # 'train', 'val' or 'test'
        self.classes = ['Car', 'Van', 'Truck', 'Pedestrian', 'Cyclist', 'Tram', 'Misc']
        self.class_map = {cls: idx for idx, cls in enumerate(self.classes)}
        self.data = []

        # Model input resolution and nominal KITTI frame size (used for normalisation).
        self.img_width = 256
        self.img_height = 256
        self.original_width = 1242
        self.original_height = 375
        self.max_depth = 80.0  # depth is clipped to this many metres before normalising

        # Default transform: resize to the model input size and convert to a tensor.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((self.img_height, self.img_width)),
                transforms.ToTensor()
            ])
        self.transform = transform

        # File names are sorted so that dataset indices are reproducible across runs
        # (os.listdir returns entries in arbitrary order).
        if self.mode == 'test':
            # Test mode only needs the images.
            image_dir = os.path.join(data_path, 'testing', 'colored_0')
            for fname in sorted(os.listdir(image_dir)):
                if fname.endswith('.png'):
                    self.data.append(os.path.join(image_dir, fname))
        else:
            image_dir = os.path.join(data_path, 'training', 'colored_0')
            label_dir = os.path.join(data_path, 'training', 'label_2')
            disp_dir = os.path.join(data_path, 'training', 'disp_noc')

            for fname in sorted(os.listdir(image_dir)):
                if not fname.endswith('.png'):
                    continue
                # NOTE(review): every frame of a scene is paired with the scene's
                # `<id>.txt` label and `<id>_10.png` disparity -- confirm that this
                # is intended for frames other than `_10`.
                scene_id = fname.split('_')[0]
                img_path = os.path.join(image_dir, fname)
                label_path = os.path.join(label_dir, f'{scene_id}.txt')
                disp_path = os.path.join(disp_dir, f'{scene_id}_10.png')
                if os.path.exists(label_path) and os.path.exists(disp_path):
                    self.data.append((img_path, label_path, disp_path))

    def __len__(self):
        return len(self.data)

    def convert_labels(self, label_file, img_size=None):
        """Parse a KITTI label file into ``[class_id, x1, y1, x2, y2]`` rows.

        Args:
            label_file: Path of the text label file.
            img_size: Optional ``(width, height)`` of the image the boxes refer to.
                When omitted, the nominal ``original_width x original_height`` is used.

        Returns:
            A list of rows with box corners normalised by the image size. Categories
            that are not in ``self.classes`` (e.g. ``DontCare``) and blank or
            truncated lines are skipped.
        """
        if img_size is None:
            width, height = self.original_width, self.original_height
        else:
            width, height = img_size

        labels = []
        with open(label_file, 'r') as f:
            for line in f:
                parts = line.split()
                # Columns 4..7 hold the box, so shorter (or blank) lines carry no label.
                if len(parts) < 8:
                    continue
                category = parts[0]
                if category not in self.class_map:
                    continue
                x1, y1, x2, y2 = map(float, parts[4:8])
                labels.append([
                    self.class_map[category],
                    x1 / width, y1 / height, x2 / width, y2 / height,
                ])
        return labels

    def load_disparity(self, disp_path):
        """Load a disparity PNG and convert the stored values to pixels (value / 256).

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        raw = cv2.imread(disp_path, cv2.IMREAD_UNCHANGED)
        if raw is None:
            raise FileNotFoundError(f"Could not read disparity map: {disp_path}")
        return raw / 256.0

    def calculate_depth(self, disp_map):
        """Convert a disparity map into a normalised depth map at model resolution.

        Depth is ``baseline * focal_length / disparity``. The disparity values are
        expressed in pixels of the original image, so the original focal length is
        used unscaled. Pixels without a measurement (disparity <= 0) are kept at 0 so
        that callers can recognise them, and the map is resized with nearest-neighbour
        sampling so invalid and valid pixels are never blended.
        """
        baseline = 0.54
        focal_length = 721.5377

        disparity = np.asarray(disp_map, dtype=np.float64)
        valid = disparity > 0
        depth = np.zeros_like(disparity)
        depth[valid] = (baseline * focal_length) / disparity[valid]
        depth = np.clip(depth, 0, self.max_depth)
        normalized_depth = depth / self.max_depth
        # cv2.resize takes the target size as (width, height).
        return cv2.resize(normalized_depth, (self.img_width, self.img_height),
                          interpolation=cv2.INTER_NEAREST)

    def _to_pixel(self, x, y):
        """Map normalised (x, y) to integer pixel indices clamped to the model grid."""
        x_pixel = int(np.clip(int(x * self.img_width), 0, self.img_width - 1))
        y_pixel = int(np.clip(int(y * self.img_height), 0, self.img_height - 1))
        return x_pixel, y_pixel

    def get_depth_at_box(self, depth_map, x, y, w, h):
        """Return the depth stored at the normalised point ``(x, y)`` of ``depth_map``.

        ``w`` and ``h`` are accepted for interface compatibility but are not used.
        A return value of ``0.0`` means that the location holds no valid depth.
        """
        x_pixel, y_pixel = self._to_pixel(x, y)
        return float(depth_map[y_pixel, x_pixel])

    def get_disparity_at_box(self, disp_map, x, y, w, h):
        """Return the disparity at the normalised point ``(x, y)``, scaled to ``[0, 1]``.

        ``disp_map`` is indexed on the ``img_height x img_width`` grid. ``w`` and
        ``h`` are accepted for interface compatibility but are not used.
        """
        x_pixel, y_pixel = self._to_pixel(x, y)
        if disp_map[y_pixel, x_pixel] == 0:
            return 0.0

        # Normalise by a nominal maximum disparity of 300 px.
        max_disparity = 300.0
        normalized_disparity = disp_map[y_pixel, x_pixel] / max_disparity
        return np.clip(normalized_disparity, 0, 1)

    def __getitem__(self, idx):
        if self.mode == 'test':
            image = Image.open(self.data[idx]).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image

        img_path, label_path, disp_path = self.data[idx]
        image = Image.open(img_path).convert('RGB')
        # Normalise boxes by the actual frame size (PIL reports (width, height)).
        labels = self.convert_labels(label_path, img_size=image.size)
        depth_map = self.calculate_depth(self.load_disparity(disp_path))

        labels_with_depth = []
        for category_num, x1, y1, x2, y2 in labels:
            # Sample the depth at the centre of the bounding box.
            center_x = (x1 + x2) / 2
            center_y = (y1 + y2) / 2
            depth = self.get_depth_at_box(depth_map, center_x, center_y, x2 - x1, y2 - y1)
            labels_with_depth.append([category_num, x1, y1, x2, y2, depth])

        if self.transform:
            image = self.transform(image)
        # reshape keeps a [0, 6] shape for images without any labelled object.
        output_labels = torch.tensor(labels_with_depth, dtype=torch.float32).reshape(-1, 6)
        return image, output_labels

    # Original author's note: the loss during training is computed directly on the
    # stored values, while the metric distance can be derived at test time.

def test_collate_fn(batch):
    """Test mode için basit collate function"""
    images = batch  # batch sadece image tensor'larının listesi
    images = torch.stack(images, dim=0)  # [B, C, H, W]
    return images

def kitti_collate_fn(batch):
    """Collate ``(image, labels)`` samples into a batch.

    Images are stacked into one ``[B, C, H, W]`` tensor. The label entries are
    returned as a plain list, one per sample, in the original order.
    """
    image_list, label_list = [], []
    for sample in batch:
        image_list.append(sample[0])
        label_list.append(sample[1])
    return torch.stack(image_list, dim=0), label_list

def _collect_object_records(train_dataset):
    """Walk the dataset once and gather per-object classes, boxes and depths."""
    cls_labels, bboxes, depths, objects_per_image = [], [], [], []
    for i in range(len(train_dataset)):
        # Each label entry is a row [class, x1, y1, x2, y2, depth]; reshape keeps
        # images without objects as an empty [0, 6] table.
        rows = torch.as_tensor(train_dataset[i][1], dtype=torch.float32).reshape(-1, 6).tolist()
        objects_per_image.append(len(rows))
        for row in rows:
            # Plain Python numbers: tensors hash by identity and would make every
            # object look like its own class inside a Counter.
            cls_labels.append(int(row[0]))
            bboxes.append(row[1:5])
            depths.append(row[5])
    return cls_labels, bboxes, depths, objects_per_image


def _plot_dataset_overview(sorted_cls_distribution, objects_per_image, bbox_widths,
                           bbox_heights, bbox_areas, bbox_aspect_ratios, valid_depths):
    """Draw the 3x4 overview figure (class mix, object counts, box and depth stats)."""
    title_style = {'fontsize': 12, 'fontweight': 'bold'}
    bar_style = {'alpha': 0.7, 'edgecolor': 'black'}
    plt.figure(figsize=(18, 12))

    labels, values = zip(*sorted_cls_distribution)

    # Class distribution as a pie chart.
    plt.subplot(3, 4, 1)
    plt.pie(values, labels=labels, autopct='%1.1f%%', startangle=140)
    plt.title('Sınıf Dağılımı', **title_style)
    plt.axis('equal')

    # Class distribution as a bar chart (string labels give a categorical axis).
    plt.subplot(3, 4, 2)
    plt.bar([str(label) for label in labels], values, **bar_style)
    plt.title('Sınıf Dağılımı (Bar Chart)', **title_style)
    plt.xlabel('Sınıf')
    plt.ylabel('Örnek Sayısı')
    plt.xticks(rotation=45)

    # Objects per image; bins start at 0 so images without objects are counted too.
    plt.subplot(3, 4, 3)
    plt.hist(objects_per_image, bins=range(0, max(objects_per_image) + 2),
             color='skyblue', **bar_style)
    plt.title('Görsel Başına Obje Sayısı', **title_style)
    plt.xlabel('Obje Sayısı')
    plt.ylabel('Görsel Sayısı')
    plt.grid(True, alpha=0.3)

    # (subplot position, data, colour, title, x label)
    histograms = (
        (4, bbox_widths, 'orange', 'BBox Genişlik Dağılımı', 'Genişlik'),
        (5, bbox_heights, 'red', 'BBox Yükseklik Dağılımı', 'Yükseklik'),
        (6, bbox_areas, 'purple', 'BBox Alan Dağılımı', 'Alan'),
        (7, bbox_aspect_ratios, 'brown', 'BBox En/Boy Oranı', 'En/Boy Oranı'),
        (8, valid_depths, 'lightgreen', 'Depth Değerleri Dağılımı', 'Normalize Edilmiş Depth'),
    )
    for position, data, color, title, xlabel in histograms:
        plt.subplot(3, 4, position)
        if len(data) > 0:
            plt.hist(data, bins=50, color=color, **bar_style)
        plt.title(title, **title_style)
        plt.xlabel(xlabel)
        plt.ylabel('Frekans')
        plt.grid(True, alpha=0.3)

    # (subplot position, data, title, y label)
    boxplots = (
        (9, valid_depths, 'Depth Box Plot', 'Normalize Edilmiş Depth'),
        (10, objects_per_image, 'Obje Sayısı Box Plot', 'Obje Sayısı'),
        (11, bbox_areas, 'BBox Alan Box Plot', 'Alan'),
        (12, bbox_aspect_ratios, 'En/Boy Oranı Box Plot', 'En/Boy Oranı'),
    )
    for position, data, title, ylabel in boxplots:
        plt.subplot(3, 4, position)
        if len(data) > 0:
            plt.boxplot(data)
        plt.title(title, **title_style)
        plt.ylabel(ylabel)
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def analyze_dataset(train_dataset, device):
    """Analyse and visualise a labelled detection dataset.

    Args:
        train_dataset: Indexable dataset whose items are ``(image, labels)`` with
            ``labels`` rows of ``[class, x1, y1, x2, y2, depth]``.
        device: Device on which the returned ``class_weights`` tensor is placed.

    Returns:
        dict with the keys ``class_weights``, ``cls_distribution``, ``bbox_stats``,
        ``objects_per_image_stats``, ``depth_stats``, ``depth_validity_ratio``,
        ``total_images`` and ``total_objects``.

        NOTE: ``class_weights`` holds the raw per-class object counts, ordered by
        class id and covering only the classes that actually occur; it is not an
        inverse-frequency weighting.

    Raises:
        ValueError: If the dataset contains no labelled object at all.
    """
    # === 0. Single pass over the dataset ===
    print(f"Veri seti boyutu: {len(train_dataset)} örnek")
    cls_labels, bboxes, depth_maps, objects_per_image = _collect_object_records(train_dataset)

    total_objects = len(cls_labels)
    print(f"Toplam obje sayısı: {total_objects}")
    print(f"Toplam bbox sayısı: {len(bboxes)}")
    print(f"Toplam depth değeri: {len(depth_maps)}")
    if total_objects == 0:
        raise ValueError("analyze_dataset: the dataset contains no labelled objects")

    # === 1. Class distribution ===
    cls_distribution = Counter(cls_labels)
    sorted_cls_distribution = sorted(cls_distribution.items())
    class_weights = torch.tensor([count for _, count in sorted_cls_distribution],
                                 dtype=torch.float32).to(device)

    # === 2. Bounding boxes (vectorised), rows are [x1, y1, x2, y2] ===
    bboxes_array = np.asarray(bboxes, dtype=np.float64)
    bbox_widths = bboxes_array[:, 2] - bboxes_array[:, 0]
    bbox_heights = bboxes_array[:, 3] - bboxes_array[:, 1]
    bbox_areas = bbox_widths * bbox_heights
    # Boxes with zero height get an aspect ratio of 0 instead of a division error.
    bbox_aspect_ratios = np.divide(bbox_widths, bbox_heights,
                                   out=np.zeros_like(bbox_widths),
                                   where=bbox_heights != 0)

    # === 3. Depth values (NaN marks an invalid entry) ===
    depth_array = np.asarray(depth_maps, dtype=np.float64)
    valid_mask = ~np.isnan(depth_array)
    valid_depths = depth_array[valid_mask]
    invalid_depths = int(np.sum(~valid_mask))
    depth_validity_ratio = len(valid_depths) / len(depth_maps) * 100 if depth_maps else 0
    depth_stats = {}
    if valid_depths.size > 0:
        depth_stats = {
            'mean': float(np.mean(valid_depths)),
            'median': float(np.median(valid_depths)),
            'std': float(np.std(valid_depths)),
            'min': float(np.min(valid_depths)),
            'max': float(np.max(valid_depths)),
        }

    # === 4. Figures ===
    _plot_dataset_overview(sorted_cls_distribution, objects_per_image, bbox_widths,
                           bbox_heights, bbox_areas, bbox_aspect_ratios, valid_depths)

    # === 5. Text report ===
    print("=" * 80)
    print("DATASET ANALİZ RAPORU")
    print("=" * 80)

    print(f"\n📋 GENEL BİLGİLER:")
    print(f"Toplam görsel sayısı:    {len(train_dataset):6d}")
    print(f"Toplam obje sayısı:      {len(cls_labels):6d}")
    print(f"Toplam bbox sayısı:      {len(bboxes):6d}")

    print(f"\n SINIF DAĞILIMI:")
    print(f"Toplam sınıf sayısı: {len(cls_distribution)}")
    for label, count in sorted_cls_distribution:
        percentage = (count / total_objects) * 100
        print(f"  Sınıf {label}: {count:4d} örnek ({percentage:5.1f}%)")

    print(f"\n BBOX ANALİZ RAPORU:")
    print(f"Ortalama genişlik:       {np.mean(bbox_widths):8.2f}")
    print(f"Ortalama yükseklik:      {np.mean(bbox_heights):8.2f}")
    print(f"Ortalama alan:           {np.mean(bbox_areas):8.2f}")
    print(f"Ortalama en/boy oranı:   {np.mean(bbox_aspect_ratios):8.2f}")
    print(f"Min genişlik:            {np.min(bbox_widths):8.2f}")
    print(f"Max genişlik:            {np.max(bbox_widths):8.2f}")
    print(f"Min yükseklik:           {np.min(bbox_heights):8.2f}")
    print(f"Max yükseklik:           {np.max(bbox_heights):8.2f}")
    print(f"Min alan:                {np.min(bbox_areas):8.2f}")
    print(f"Max alan:                {np.max(bbox_areas):8.2f}")

    print(f"\n OBJE SAYISI İSTATİSTİKLERİ:")
    print(f"Ortalama obje sayısı: {np.mean(objects_per_image):6.2f}")
    print(f"Medyan obje sayısı:   {np.median(objects_per_image):6.2f}")
    print(f"Maksimum obje sayısı: {max(objects_per_image):6d}")
    print(f"Minimum obje sayısı:  {min(objects_per_image):6d}")
    print(f"Standart sapma:       {np.std(objects_per_image):6.2f}")

    print(f"\n DEPTH ANALİZ RAPORU:")
    print(f"Toplam depth değeri:     {len(depth_maps):6d}")
    print(f"Geçerli depth değeri:    {len(valid_depths):6d}")
    print(f"Geçersiz depth (NaN):    {invalid_depths:6d}")
    print(f"Depth geçerlilik oranı:  {depth_validity_ratio:6.1f}%")

    if depth_stats:
        print(f"\n DEPTH İSTATİSTİKLERİ:")
        for key, value in depth_stats.items():
            print(f"{key:>8}: {value:8.4f}")

    print("=" * 80)

    return {
        'class_weights': class_weights,
        'cls_distribution': sorted_cls_distribution,
        'bbox_stats': {
            'mean_width': np.mean(bbox_widths),
            'mean_height': np.mean(bbox_heights),
            'mean_area': np.mean(bbox_areas),
            'mean_aspect_ratio': np.mean(bbox_aspect_ratios),
            'width_std': np.std(bbox_widths),
            'height_std': np.std(bbox_heights),
            'area_std': np.std(bbox_areas)
        },
        'objects_per_image_stats': {
            'mean': np.mean(objects_per_image),
            'median': np.median(objects_per_image),
            'max': max(objects_per_image),
            'min': min(objects_per_image),
            'std': np.std(objects_per_image)
        },
        'depth_stats': depth_stats,
        'depth_validity_ratio': depth_validity_ratio,
        'total_images': len(train_dataset),
        'total_objects': len(cls_labels)
    }

def _accumulate_weighted(totals, batch_values, weight):
    """Add ``weight * batch_values[key]`` to ``totals[key]`` for every tracked key."""
    for key in totals:
        if key in batch_values:
            value = batch_values[key]
            if isinstance(value, torch.Tensor):
                value = value.item()
            totals[key] += value * weight


def _run_epoch(model, criterion, loader, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return per-sample averaged losses and metrics.

    With an ``optimizer`` the pass trains the model (train mode, gradients,
    clipping, optimizer step). Without one it evaluates: the model is switched to
    eval mode and gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)

    loss_totals = dict.fromkeys(
        ('total', 'classification', 'regression', 'depth', 'depth_map'), 0.0)
    metric_totals = dict.fromkeys(
        ('Accuracy', 'F1_score', 'MSE', 'RMSE', 'mAP', 'TotalLoss', 'ClsLoss'), 0.0)
    seen = 0  # samples processed so far; the weight base for all running averages

    pbar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(training):
        for batch_idx, (images, targets) in enumerate(pbar):
            images = images.to(device)
            batch_size = len(targets)

            if training:
                optimizer.zero_grad()

            # The model is always called with mode="train": that call returns the
            # raw prediction dictionary the criterion consumes.
            model_outputs = model(images, mode="train")
            values = {key: model_outputs[key]
                      for key in ('cls_preds', 'reg_preds', 'anchors', 'depth_pred')}
            values['targets'] = targets
            losses, metrics = criterion(values)

            if training:
                losses['total'].backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            # Weight by batch size so a smaller final batch does not skew the mean.
            _accumulate_weighted(loss_totals, losses, batch_size)
            _accumulate_weighted(metric_totals, metrics, batch_size)
            seen += batch_size

            # Periodically release cached GPU memory during training.
            if training and batch_idx % 10 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

            denom = max(seen, 1)
            pbar.set_postfix({
                'Acc': f"{metric_totals['Accuracy']/denom: .3f}",
                'ClsLoss': f"{loss_totals['classification']/denom:.3f}",
                'F1': f"{metric_totals['F1_score']/denom:.3f}",
                'RMSE': f"{metric_totals['RMSE']/denom:.3f}",
                'mAP': f"{metric_totals['mAP']/denom:.3f}",
                'TotalLoss': f"{loss_totals['total']/denom:.3f}"
            })

    denom = max(seen, 1)
    avg_losses = {key: total / denom for key, total in loss_totals.items()}
    avg_metrics = {key: total / denom for key, total in metric_totals.items()}
    return avg_losses, avg_metrics


def train_model(model, train_loader, val_loader, num_epochs=100,
                learning_rate=1e-4, device='cuda', save_path='model_checkpoint.pth',
                early_stop_patience=3,scheduler_patience=10, scheduler_factor=0.5, class_weights=None, task_weights=None,
                p_iou_threshold=0.5, n_iou_threshold=0.4):
    """Train ``model`` with early stopping, LR scheduling and W&B logging.

    Args:
        model: Network called as ``model(images, mode="train")`` returning a dict
            with ``cls_preds``, ``reg_preds``, ``anchors`` and ``depth_pred``.
        train_loader, val_loader: Loaders yielding ``(images, targets)`` batches.
        num_epochs: Maximum number of epochs.
        learning_rate: Initial AdamW learning rate.
        device: Device for the model, the criterion and the image batches.
        save_path: Where the best checkpoint (lowest validation loss) is written.
        early_stop_patience: Epochs without improvement before stopping.
        scheduler_patience, scheduler_factor: ``ReduceLROnPlateau`` settings.
        class_weights: Accepted for interface compatibility; currently unused.
        task_weights: Optional dict of per-task loss weights for the criterion.
        p_iou_threshold, n_iou_threshold: Positive / negative IoU thresholds
            forwarded to the criterion.

    Returns:
        The best (lowest) validation total loss, or ``inf`` if no epoch produced
        a finite improvement.
    """
    config = {
        "learning_rate": learning_rate,
        "architecture": "EfficientBasedMultiTask",
        "dataset": "KITTI-2012",
        "epochs": num_epochs,
    }

    # Record the task weights in the run configuration when they are given.
    if task_weights:
        config.update({
            "task_weights": task_weights,
            "classification_weight": task_weights.get("classification", 1.0),
            "regression_weight": task_weights.get("regression", 1.0),
            "detection_depth_weight": task_weights.get("detection_depth", 1.0),
            "depth_map_weight": task_weights.get("depth_map", 1.0)
        })

    wandb.init(
        entity="mehmeteminuludag-kirikkale-university",
        project="StajProjesi",
        config=config
    )

    # Start from +inf so the first epoch always counts as an improvement and a
    # checkpoint exists even when the loss never drops below 1.0.
    best_val_loss = float('inf')
    counter = 0

    try:
        model = model.to(device)
        criterion = MultiTaskCriterion(loss_weights=task_weights, pos_iou_threshold=p_iou_threshold, neg_iou_threshold=n_iou_threshold).to(device)

        # The criterion's own parameters are optimised together with the model.
        optimizer = optim.AdamW(
            list(model.parameters()) + list(criterion.parameters()),
            lr=learning_rate, weight_decay=1e-5, eps=1e-8
        )
        # `verbose` is deprecated in recent PyTorch releases; LR changes are
        # printed manually below instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=scheduler_factor,
            patience=scheduler_patience
        )

        for epoch in range(num_epochs):
            train_losses, train_metrics = _run_epoch(
                model, criterion, train_loader, device,
                desc=f'Epoch {epoch+1}/{num_epochs} Train', optimizer=optimizer)
            val_losses, val_metrics = _run_epoch(
                model, criterion, val_loader, device,
                desc=f'Epoch {epoch+1}/{num_epochs} Validation')

            previous_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_losses['total'])
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr < previous_lr:
                print(f"Reducing learning rate from {previous_lr:.3e} to {current_lr:.3e}")

            if float(val_losses['total']) < best_val_loss:
                best_val_loss = float(val_losses['total'])
                counter = 0  # improvement: reset the early-stopping counter
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'criterion_state_dict': criterion.state_dict()
                }, save_path)
            else:
                counter += 1
                print(f"No improvement in val loss for {counter} epochs.")

            wandb.log({
                "val/accuracy": val_metrics['Accuracy'],
                "val/classification_loss": val_losses['classification'],
                "val/total_loss": val_losses['total'],
                "val/f1_score": val_metrics['F1_score'],
                "val/rmse": val_metrics['RMSE'],
                "val/map": val_metrics['mAP'],
                "train/accuracy": train_metrics['Accuracy'],
                "train/classification_loss": train_losses['classification'],
                "train/total_loss": train_losses['total'],
                "train/f1_score": train_metrics['F1_score'],
                "train/rmse": train_metrics['RMSE'],
                "train/map": train_metrics['mAP'],
                "learning_rate": current_lr
            }, step=epoch+1)

            if counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
    finally:
        # Close the W&B run even when training is interrupted by an error.
        wandb.finish()
    return best_val_loss



    def _apply_memory_efficient_nms(self, boxes, scores, classes):
        if len(boxes) == 0:
            return torch.tensor([], dtype=torch.long, device=self.device)
        unique_classes = torch.unique(classes)
        all_keep_indices = []
        for cls in unique_classes:
            if cls == 8:  # Skip DontCare
                continue
            cls_mask = classes == cls
            if cls_mask.sum() == 0:
                continue
            cls_boxes = boxes[cls_mask]
            cls_scores = scores[cls_mask]
            cls_indices = torch.where(cls_mask)[0]
            abs_boxes = cls_boxes * self.img_size
            keep_cls = nms(abs_boxes, cls_scores, self.nms_threshold)
            all_keep_indices.append(cls_indices[keep_cls])
        if len(all_keep_indices) == 0:
            return torch.tensor([], dtype=torch.long, device=self.device)
        final_keep_indices = torch.cat(all_keep_indices)
        if len(final_keep_indices) > self.max_detections_per_image:
            keep_scores = scores[final_keep_indices]
            top_indices = torch.topk(keep_scores, self.max_detections_per_image)[1]
            final_keep_indices = final_keep_indices[top_indices]
        return final_keep_indices

    def _get_depth_values(self, depth_map, boxes):
        if depth_map is None:
            return [0.0] * len(boxes)
        H, W = depth_map.shape
        depths = []
        for box in boxes:
            x1, y1, x2, y2 = box
            x1_pix = int(x1 * W)
            y1_pix = int(y1 * H)
            x2_pix = int(x2 * W)
            y2_pix = int(y2 * H)
            x1_pix = max(0, min(x1_pix, W-1))
            y1_pix = max(0, min(y1_pix, H-1))
            x2_pix = max(0, min(x2_pix, W-1))
            y2_pix = max(0, min(y2_pix, H-1))
            center_x = (x1_pix + x2_pix) // 2
            center_y = (y1_pix + y2_pix) // 2
            depth_value = depth_map[center_y, center_x].item() if depth_map is not None else 0.0
            depths.append(depth_value)
        return depths

class SpatialAttention(nn.Module):
    """Spatial attention gate.

    The per-pixel mean and max over the channel dimension are stacked into a
    2-channel map, passed through one bias-free convolution and a sigmoid, and
    the resulting single-channel mask is multiplied onto the input. The returned
    tensor is therefore the already re-weighted input, not the mask.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # "Same" padding for an odd kernel keeps the spatial size unchanged.
        same_padding = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        pooled = torch.cat((channel_mean, channel_max), dim=1)
        mask = self.sigmoid(self.conv1(pooled))
        return x * mask

class EncoderBackBone(nn.Module):
    """EfficientNet-B3 feature extractor with shared spatial attention.

    Args:
        İsPretreained: When true (default) the ImageNet weights are loaded;
            otherwise the backbone starts from random initialisation.

    ``forward`` returns a list with the outputs of feature stages 3, 5 and 7,
    each bilinearly resized to 256x256.
    """

    def __init__(self,İsPretreained=True):
        super(EncoderBackBone,self).__init__()
        # Honour the flag instead of always loading the pretrained weights.
        weights = EfficientNet_B3_Weights.IMAGENET1K_V1 if İsPretreained else None
        efficient = models.efficientnet_b3(weights=weights)
        self.features = efficient.features
        self.SAttention = SpatialAttention()

    def forward(self, x):         # B,C,H,W
        outs = []
        for i, block in enumerate(self.features):
            x = block(x)
            if i > 2:
                # SpatialAttention already returns the re-weighted features
                # (x * mask); multiplying by x again would square the activations.
                x = self.SAttention(x)
            if i in (3, 5, 7):
                outs.append(F.interpolate(x, size=256, mode='bilinear', align_corners=False))
        return outs

class DepthwiseSeparableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution, BatchNorm and SiLU.

    The depthwise stage uses one filter group per input channel; the pointwise
    stage changes the channel count to ``out_channels``. Neither convolution has
    a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            stride=stride, padding=padding, groups=in_channels, bias=False,
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.swish = nn.SiLU()

    def forward(self, x):
        per_channel = self.depthwise(x)
        mixed = self.pointwise(per_channel)
        return self.swish(self.bn(mixed))

class BiFPNBlock(nn.Module):
    """One bidirectional feature-pyramid fusion block over five levels (P3..P7).

    A top-down pass (P7 -> P3) is followed by a bottom-up pass (P3 -> P7). Every
    fusion is a weighted sum whose learnable weights are passed through ReLU and
    normalised by their sum plus ``epsilon``.

    NOTE: ``conv_p4``, ``conv_p5`` and ``conv_p6`` are each applied twice per
    forward pass (once on the top-down path, once on the output path), so those
    two uses share weights and BatchNorm statistics.

    Args:
        channels: Channel count of every pyramid level.
        epsilon: Small constant that keeps the weight normalisation stable.
    """

    def __init__(self, channels, epsilon=1e-4):
        super(BiFPNBlock, self).__init__()
        self.epsilon = epsilon
        self.channels = channels

        # One convolution per pyramid level.
        self.conv_p3 = DepthwiseSeparableConv(channels, channels)
        self.conv_p4 = DepthwiseSeparableConv(channels, channels)
        self.conv_p5 = DepthwiseSeparableConv(channels, channels)
        self.conv_p6 = DepthwiseSeparableConv(channels, channels)
        self.conv_p7 = DepthwiseSeparableConv(channels, channels)

        # Learnable fusion weights; the size equals the number of fused inputs.
        self.w1 = nn.Parameter(torch.ones(2))
        self.w2 = nn.Parameter(torch.ones(2))
        self.w3 = nn.Parameter(torch.ones(2))
        self.w4 = nn.Parameter(torch.ones(2))
        self.w5 = nn.Parameter(torch.ones(3))
        self.w6 = nn.Parameter(torch.ones(3))
        self.w7 = nn.Parameter(torch.ones(3))
        self.w8 = nn.Parameter(torch.ones(2))

    def _fuse(self, raw_weights, *features):
        """Weighted sum of ``features`` with ReLU-ed, sum-normalised weights."""
        weights = F.relu(raw_weights)
        mixed = sum(w * feat for w, feat in zip(weights, features))
        return mixed / (weights.sum() + self.epsilon)

    def forward(self, inputs):
        P3, P4, P5, P6, P7 = inputs

        # Top-down pathway: coarser levels are upsampled into finer ones.
        P6_td = self.conv_p6(self._fuse(self.w1, P6, self.up_sampling(P7, P6.shape[-2:])))
        P5_td = self.conv_p5(self._fuse(self.w2, P5, self.up_sampling(P6_td, P5.shape[-2:])))
        P4_td = self.conv_p4(self._fuse(self.w3, P4, self.up_sampling(P5_td, P4.shape[-2:])))
        P3_out = self.conv_p3(self._fuse(self.w4, P3, self.up_sampling(P4_td, P3.shape[-2:])))

        # Bottom-up pathway: finer outputs are pooled down into coarser levels.
        P4_out = self.conv_p4(
            self._fuse(self.w5, P4, P4_td, self.down_sampling(P3_out, P4.shape[-2:])))
        P5_out = self.conv_p5(
            self._fuse(self.w6, P5, P5_td, self.down_sampling(P4_out, P5.shape[-2:])))
        P6_out = self.conv_p6(
            self._fuse(self.w7, P6, P6_td, self.down_sampling(P5_out, P6.shape[-2:])))
        P7_out = self.conv_p7(
            self._fuse(self.w8, P7, self.down_sampling(P6_out, P7.shape[-2:])))

        return [P3_out, P4_out, P5_out, P6_out, P7_out]

    def up_sampling(self, x, target_size):
        """Nearest-neighbour resize of ``x`` to ``target_size`` (H, W)."""
        return F.interpolate(x, size=target_size, mode='nearest')

    def down_sampling(self, x, target_size):
        """Max-pool ``x`` down to exactly ``target_size`` (H, W).

        Adaptive pooling equals a ``k x k`` max pool with stride ``k`` when the
        input is an exact multiple of the target, and it also yields the right
        shape for non-divisible sizes (e.g. 25 -> 13), where an integer stride
        derived from the width alone would produce a mismatched tensor.
        """
        target_size = tuple(target_size)
        if tuple(x.shape[-2:]) == target_size:
            return x
        return F.adaptive_max_pool2d(x, target_size)

class BiFPN(nn.Module):
    """Stack of BiFPN blocks on top of projected backbone features.

    Every backbone map is projected to ``out_channels`` with a 1x1
    convolution, two extra levels (P6, P7) are derived with stride-2 3x3
    convolutions, and the resulting five-level pyramid is passed through
    ``num_blocks`` ``BiFPNBlock`` modules in sequence.
    """

    def __init__(self, in_channels_list, out_channels=256, num_blocks=3):
        super().__init__()
        self.out_channels = out_channels
        self.num_blocks = num_blocks

        # 1x1 lateral projections, one per backbone level.
        lateral = [nn.Conv2d(channels, out_channels, 1, bias=False)
                   for channels in in_channels_list]
        self.input_convs = nn.ModuleList(lateral)

        # P6 is computed from the raw last backbone map, P7 from P6.
        self.p6_conv = nn.Conv2d(in_channels_list[-1], out_channels, 3, stride=2, padding=1)
        self.p7_conv = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)

        blocks = [BiFPNBlock(out_channels) for _ in range(num_blocks)]
        self.bifpn_blocks = nn.ModuleList(blocks)

    def forward(self, inputs):
        """Return the fused pyramid as a list of five feature maps.

        Args:
            inputs: sequence of backbone feature maps, one per entry of
                ``in_channels_list`` and in the same order.
        """
        projected = [self.input_convs[idx](feat) for idx, feat in enumerate(inputs)]

        p6 = self.p6_conv(inputs[-1])
        p7 = self.p7_conv(p6)

        pyramid = projected + [p6, p7]
        for block in self.bifpn_blocks:
            pyramid = block(pyramid)
        return pyramid

class NNConv3UpBlock(nn.Module):
    """3x3 convolution followed by 2x nearest-neighbour upsampling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Return ``conv(x)`` with both spatial dimensions doubled."""
        convolved = self.conv(x)
        return F.interpolate(convolved, scale_factor=2, mode='nearest')

class FusionBlock(nn.Module):
    """Combine two feature maps either by element-wise addition or by channel concatenation.

    Args:
        fusion_type: ``'add'`` or ``'concat'``.
    """

    def __init__(self, fusion_type='add'):
        super().__init__()
        self.fusion_type = fusion_type

    def forward(self, high_level, low_level):
        """Fuse the two inputs according to ``self.fusion_type``.

        Returns:
            ``high_level + low_level`` for ``'add'``; the two tensors
            concatenated along dim 1 for ``'concat'``.

        Raises:
            ValueError: if ``fusion_type`` is neither ``'add'`` nor
                ``'concat'``. Previously this case silently returned ``None``
                and only failed later, far from the cause.
        """
        if self.fusion_type == 'add':
            return high_level + low_level
        if self.fusion_type == 'concat':
            return torch.cat([high_level, low_level], dim=1)
        raise ValueError(
            f"Unsupported fusion_type {self.fusion_type!r}; expected 'add' or 'concat'"
        )

class PredictionDecoder(nn.Module):
    """Two 3x3 convolutions mapping a feature map to a sigmoid-activated prediction.

    The first convolution halves the channel count and is followed by
    LeakyReLU(0.1); the second produces ``out_channels`` maps squashed to
    (0, 1) by a sigmoid.
    """

    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        mid_channels = in_channels // 2
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1)
        self.leaky_relu = nn.LeakyReLU(0.1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the prediction map for ``x`` with the same spatial size."""
        hidden = self.leaky_relu(self.conv1(x))
        logits = self.conv2(hidden)
        return self.sigmoid(logits)

class RTMonoDepthDecoder(nn.Module):
    """Multi-scale depth decoder over three encoder feature maps.

    Produces one sigmoid depth map per decoding level and returns them in a
    dict keyed ``'depth_2'``, ``'depth_1'`` and ``'depth_0'``.

    Notes:
        * ``upconv0`` is registered but not used by ``forward``; it is kept so
          the module's parameter set stays unchanged.
        * ``inference_mode`` is accepted by ``forward`` but has no effect.
    """

    def __init__(self, encoder_channels=[48, 136, 384], decoder_channels=[256, 128, 64, 32]):
        super().__init__()
        enc_f1, enc_f2, enc_f3 = encoder_channels[0], encoder_channels[1], encoder_channels[2]
        dec_a, dec_b, dec_c = decoder_channels[0], decoder_channels[1], decoder_channels[2]

        # Conv + 2x upsampling stages.
        self.upconv2 = NNConv3UpBlock(enc_f3, dec_a)   # applied to f3
        self.upconv1 = NNConv3UpBlock(dec_a, dec_b)    # applied after the first fusion
        self.upconv0 = NNConv3UpBlock(dec_b, dec_c)    # currently unused in forward

        # 1x1 projections so the skip features match the decoder widths.
        self.proj2 = nn.Conv2d(enc_f2, dec_a, 1)
        self.proj1 = nn.Conv2d(enc_f1, dec_b, 1)

        # Level 1 fuses by addition, level 0 by channel concatenation.
        self.fusion1 = FusionBlock('add')
        self.fusion0 = FusionBlock('concat')

        # One prediction head per level; level 0 sees the concatenation of
        # two dec_b-wide tensors, hence the doubled width.
        self.decoder2 = PredictionDecoder(dec_a)
        self.decoder1 = PredictionDecoder(dec_b)
        self.decoder0 = PredictionDecoder(dec_b + dec_b)

    def forward(self, features, inference_mode=False):
        """Decode depth maps from ``features = (f1, f2, f3)``.

        f3 is decoded first and then merged with f2 and f1 in turn, so
        presumably f1 is the highest-resolution map and f3 the lowest
        (the reverse of what the original inline comment stated) -- confirm
        against the encoder.

        Returns:
            dict with keys ``'depth_2'``, ``'depth_1'``, ``'depth_0'``.
        """
        f1, f2, f3 = features

        # Level 2: decode straight from f3.
        up2 = self.upconv2(f3)
        depth_2 = self.decoder2(up2)

        # Level 1: align up2 with the projected f2, add, upsample again.
        skip2 = self.proj2(f2)
        up2_aligned = F.interpolate(up2, size=skip2.shape[-2:], mode='bilinear', align_corners=False)
        up1 = self.upconv1(self.fusion1(up2_aligned, skip2))
        depth_1 = self.decoder1(up1)

        # Level 0: align up1 with the projected f1 and concatenate.
        skip1 = self.proj1(f1)
        up1_aligned = F.interpolate(up1, size=skip1.shape[-2:], mode='bilinear', align_corners=False)
        depth_0 = self.decoder0(self.fusion0(up1_aligned, skip1))

        return {'depth_2': depth_2, 'depth_1': depth_1, 'depth_0': depth_0}

class DepthHead2(nn.Module):
    """Refine up to three depth maps into a single sigmoid depth map.

    The maps are resized to a common size, concatenated along channels and
    passed through a small convolutional stack. ``in_channels`` is the
    channel count of each individual map; the stack's first convolution
    therefore takes ``in_channels * 3`` channels.

    NOTE(review): with ``in_channels=1`` the hidden width ``in_channels // 2``
    is 0, which makes the stack degenerate. It is left unchanged here because
    fixing it would alter parameter shapes of existing checkpoints.
    """

    def __init__(self, in_channels=256, out_channels=1):
        super().__init__()
        hidden = in_channels // 2
        self.refine = nn.Sequential(
            nn.Conv2d(in_channels * 3, in_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _match_size(depth_map, target):
        """Bilinearly resize ``depth_map`` to ``target`` unless it already matches."""
        if depth_map.shape[2:] == target:
            return depth_map
        return F.interpolate(depth_map, size=target, mode='bilinear', align_corners=False)

    def forward(self, depth_features):
        """Fuse the given depth maps.

        Args:
            depth_features: either a dict with any of the keys ``'depth_0'``,
                ``'depth_1'``, ``'depth_2'`` (the first key present defines
                the output size), or an iterable of maps (the first map
                defines the output size; only the first three are used).
                Fewer than three maps are padded by repeating the last one.

        Returns:
            The refined map produced by the convolutional stack.

        Raises:
            ValueError: if a dict is given that contains none of the
                expected keys.
        """
        if isinstance(depth_features, dict):
            present = [depth_features[key]
                       for key in ('depth_0', 'depth_1', 'depth_2')
                       if key in depth_features]
            if not present:
                raise ValueError("depth_features dict empty or unexpected keys")
            target = present[0].shape[2:]
            maps = [self._match_size(m, target) for m in present]
            while len(maps) < 3:
                maps.append(maps[-1])
        else:
            maps = list(depth_features)
            while len(maps) < 3:
                maps.append(maps[-1])
            target = maps[0].shape[2:]
            maps = [self._match_size(m, target) for m in maps[:3]]

        stacked = torch.cat(maps, dim=1)
        return self.refine(stacked)

class DepthHead(nn.Module):
    """Per-scale depth refinement followed by a learned softmax-weighted blend.

    Each incoming depth map goes through the same small convolutional stack
    (ending in a sigmoid); the results are resized to the size of the first
    one and mixed with weights ``softmax(self.weights)``.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, 1),
            nn.Sigmoid(),
        ]
        self.conv = nn.Sequential(*layers)
        # One blend weight per expected input map (depth_2, depth_1, depth_0).
        self.weights = nn.Parameter(torch.ones(3))

    def weighted_fusion(self, features, weights, target_size):
        """Blend ``features`` with ``softmax(weights)`` at ``target_size``.

        ``features`` and the weights are paired with ``zip``, so the shorter
        of the two decides how many maps are blended. Returns ``None`` when
        nothing is blended.
        """
        coefficients = F.softmax(weights, dim=0)
        blended = None
        for feat, coeff in zip(features, coefficients):
            if feat.shape[2:] != target_size:
                feat = F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)
            term = coeff * feat
            blended = term if blended is None else blended + term
        return blended

    def forward(self, features):
        """Refine every map in ``features`` and blend them at the first map's size."""
        refined = [self.conv(feat) for feat in features]
        reference_size = refined[0].shape[2:]
        return self.weighted_fusion(refined, self.weights, reference_size)

class DetectionHead(nn.Module):
    """Shared classification / box-regression head applied to every pyramid level.

    A single 3x3 convolution per task produces, for each spatial location,
    ``num_anchors * num_classes`` class logits and ``num_anchors * 4`` box
    deltas. No normalisation or activation is applied inside the head.
    """

    def __init__(self, in_channels=256, num_anchors=3, num_classes=8):
        super().__init__()
        self.num_anchors = num_anchors
        self.num_classes = num_classes

        # Class logits for every anchor at every location.
        self.cls_conv = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
        # Four box-regression values for every anchor at every location.
        self.reg_conv = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, padding=1)

    def _flatten_level(self, raw, values_per_anchor):
        """Reshape ``[B, A*K, H, W]`` to ``[B, H*W*A, K]`` (location-major, anchor-minor)."""
        batch, _, height, width = raw.shape
        grouped = raw.view(batch, self.num_anchors, values_per_anchor, height, width)
        grouped = grouped.permute(0, 3, 4, 1, 2).contiguous()   # B, H, W, A, K
        return grouped.view(batch, -1, values_per_anchor)

    def forward(self, features):
        """Predict classes and box deltas for all pyramid levels.

        Args:
            features: list of tensors, each ``[B, C, H, W]``.

        Returns:
            cls_preds: ``[B, total_anchors, num_classes]``
            reg_preds: ``[B, total_anchors, 4]``
            where levels are concatenated in input order.
        """
        cls_levels = []
        reg_levels = []
        for feat in features:
            cls_levels.append(self._flatten_level(self.cls_conv(feat), self.num_classes))
            reg_levels.append(self._flatten_level(self.reg_conv(feat), 4))

        cls_preds = torch.cat(cls_levels, dim=1)
        reg_preds = torch.cat(reg_levels, dim=1)
        return cls_preds, reg_preds

class MultiTaskHeads(nn.Module):
    """Bundle of the depth head(s) and the detection head.

    Args:
        num_classes: accepted for interface compatibility but not forwarded;
            the detection head keeps its own default class count.
            NOTE(review): the loss code may rely on that default (for example
            an extra background class) -- confirm before wiring this through.
        in_channels: channel count of the pyramid features; now forwarded to
            the detection head (it was previously ignored, so any pyramid
            width other than 256 failed inside the head's convolutions).
        num_anchors: anchors per spatial location for the detection head.
    """

    def __init__(self, num_classes=10, in_channels=256, num_anchors=3):
        super(MultiTaskHeads, self).__init__()
        # Both depth heads consume the single-channel maps of the depth decoder.
        self.depth_head1 = DepthHead(in_channels=1)
        # depth_head2 is registered but not used by forward(); it is kept so
        # that the module's parameter set (and saved checkpoints) stay the same.
        self.depth_head2 = DepthHead2(in_channels=1)
        self.detection_head = DetectionHead(in_channels=in_channels, num_anchors=num_anchors)

    def forward(self, bifpn_features, depth_features):
        """Run both task heads.

        Args:
            bifpn_features: list of pyramid feature maps, each ``[B, C, H, W]``.
            depth_features: dict with keys ``'depth_2'``, ``'depth_1'`` and
                ``'depth_0'``.

        Returns:
            dict with
              ``'depth'``: fused depth map from ``depth_head1``,
              ``'classification'``: ``[B, total_anchors, num_classes]``,
              ``'regression'``: ``[B, total_anchors, 4]``.
        """
        # Order matters: the first map defines the fused output size.
        depth_list = [depth_features['depth_2'], depth_features['depth_1'], depth_features['depth_0']]
        depth = self.depth_head1(depth_list)

        cls_preds, reg_preds = self.detection_head(bifpn_features)

        return {
            'depth': depth,
            'classification': cls_preds,
            'regression': reg_preds
        }

class PostProcessor:
    """Anchor generation helper for the detection branch.

    Builds a torchvision ``AnchorGenerator`` with three sizes per pyramid
    level and five aspect ratios (15 anchors per location).
    """

    def __init__(self, num_classes=7, num_anchors=15, strides=[8, 16, 32, 64, 128]):
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        # Honour the argument (it used to be overwritten by a hard-coded list
        # with the same default values) and copy it so the shared default
        # list can never be mutated through an instance.
        self.strides = list(strides)
        # Anchor sizes per pyramid level and the aspect ratios shared by all levels.
        self.anchor_scales = [[32, 48, 64], [64, 96, 128], [128, 192, 256], [256, 384, 512], [512, 768, 1024]]
        self.anchor_ratios = [[0.5, 1.0, 2.0, 0.33, 3.0]] * 5
        self.anchor_generator = AnchorGenerator(sizes=self.anchor_scales, aspect_ratios=self.anchor_ratios)

    def generate_anchors_for_level(self, feat, level_idx):
        """Placeholder kept for interface compatibility; not implemented.

        Args:
            feat: feature map tensor ``[B, C, H, W]``.
            level_idx: pyramid level index (0..4).

        Returns:
            Always ``None``.
        """
        return None

    def generate_all_anchors(self, bifpn_features, image_list):
        """Return the anchors of one image as a ``[N_total, 4]`` tensor.

        The anchor generator returns one tensor per image (each already the
        concatenation of all pyramid levels). All images of a batch share the
        same tensor size, so these per-image tensors are identical;
        concatenating them, as was done before, produced ``[B * N_total, 4]``
        and broke per-image indexing for batch sizes above one. For a batch of
        one the result is unchanged.
        """
        anchors_per_image = self.anchor_generator(feature_maps=bifpn_features, image_list=image_list)
        if len(anchors_per_image) == 0:
            return bifpn_features[0].new_zeros((0, 4))
        return anchors_per_image[0]

class CompleteMultiTaskModel(nn.Module):
    """Joint object-detection and monocular-depth model.

    Pipeline: encoder -> BiFPN -> detection head, and encoder -> depth
    decoder -> depth head. In ``"inference"`` mode the raw predictions are
    decoded, filtered, passed through class-wise NMS and annotated with a
    per-box depth value.

    Args:
        İsPretreained: forwarded to the encoder backbone.
        num_classes: number of object classes.
        bifpn_channels: channel width of the BiFPN pyramid.
        bifpn_blocks: number of stacked BiFPN blocks.
        confidence_threshold: minimum class probability for a candidate.
        max_detections: maximum number of boxes returned per image.
        num_anchors: anchors per spatial location.
        MAX_CANDIDATES: maximum number of candidates decoded before NMS.
    """

    # Factor applied to the sigmoid depth output when reporting per-box depth
    # (presumably the maximum range in metres -- confirm against the dataset).
    DEPTH_SCALE = 80.0

    def __init__(self, İsPretreained=True, num_classes=7, bifpn_channels=256, bifpn_blocks=3,
                 confidence_threshold=0.1, max_detections=100, num_anchors=15, MAX_CANDIDATES=1000):
        super(CompleteMultiTaskModel, self).__init__()

        # Model components
        self.encoder = EncoderBackBone(İsPretreained)
        in_channels_list = [48, 136, 384]
        self.bifpn = BiFPN(in_channels_list, bifpn_channels, bifpn_blocks)
        self.depth_decoder = RTMonoDepthDecoder()
        self.multi_head = MultiTaskHeads(num_classes, bifpn_channels, num_anchors)

        # Parameters
        self.num_classes = num_classes
        self.conf_thresh = confidence_threshold
        self.max_detections = max_detections
        self.num_anchors = num_anchors
        self.strides = [8, 16, 32, 64, 128]
        self.MAX_CANDIDATES = MAX_CANDIDATES
        self.img_size = 256  # nominal resized image size

        # Post processor (anchor generation)
        self.post_processor = PostProcessor(num_classes, num_anchors, self.strides)

    def forward(self, images, targets=None, mode="train"):
        """Run the model.

        Note that this call switches the module itself to ``eval()`` when
        ``mode == "inference"`` and to ``train()`` otherwise, and enables
        gradients only for ``mode == "train"``.

        Args:
            images: batched image tensor ``[B, C, H, W]``.
            targets: passed through unchanged in the training output.
            mode: ``"train"`` or ``"inference"``.

        Returns:
            ``"train"``: dict with ``cls_preds``, ``reg_preds``, ``anchors``,
            ``depth_pred`` and ``targets``.
            ``"inference"``: list with one result dict per image (``boxes`` in
            input-image pixels, ``scores``, ``labels``, ``depth``).
            Any other mode returns ``None``.
        """
        if mode == "inference":
            self.eval()
        else:
            self.train()

        with torch.set_grad_enabled(mode == "train"):
            backbone_features = self.encoder(images)
            bifpn_features = self.bifpn(backbone_features)
            depth_maps = self.depth_decoder(backbone_features, inference_mode=(mode != "train"))
            raw_preds = self.multi_head(bifpn_features, depth_maps)

            image_sizes = [tuple(img.shape[-2:]) for img in images]  # (H, W) per image
            image_list = ImageList(images, image_sizes=image_sizes)
            anchors_all = self.post_processor.generate_all_anchors(bifpn_features, image_list)

            cls_preds = raw_preds['classification']
            reg_preds = raw_preds['regression']
            depth_pred = raw_preds['depth']

            # The anchor generator yields one identical anchor set per image.
            # If those sets arrive concatenated ([B * N, 4]), keep a single
            # copy so the anchors line up with the N predictions per image.
            batch_size = images.shape[0]
            num_preds = cls_preds.shape[1]
            if batch_size > 1 and anchors_all.shape[0] == batch_size * num_preds:
                anchors_all = anchors_all[:num_preds]

            if mode == "train":
                return {
                    "cls_preds": cls_preds,
                    "reg_preds": reg_preds,
                    "anchors": anchors_all,
                    "depth_pred": depth_pred,
                    "targets": targets
                }

            elif mode == "inference":
                return self._inference_postprocess_fixed(cls_preds, reg_preds, depth_pred, anchors_all, images)

    def _inference_postprocess_fixed(self, cls_preds, reg_preds, depth_pred, anchors_all, images):
        """Turn raw predictions into per-image detection dicts.

        A failure while processing one image is reported on stdout and yields
        an empty result for that image instead of aborting the whole batch.
        """
        device = images.device
        image_size = (images.shape[2], images.shape[3])  # (H, W) of the model input

        results = []
        for b in range(images.shape[0]):
            try:
                depth_map = depth_pred[b] if depth_pred is not None else None
                results.append(self._postprocess_single_image(
                    cls_preds[b], reg_preds[b], depth_map, anchors_all, image_size, device
                ))
            except Exception as e:
                print(f"Error processing batch {b}: {e}")
                results.append(self._get_empty_result_single(device))

        return results

    def _postprocess_single_image(self, batch_cls, batch_reg, depth_map, anchors_all, image_size, device):
        """Filter, decode, validate and suppress the predictions of one image."""
        img_h, img_w = image_size

        # 1. Confidence filtering on the most likely class of every anchor.
        cls_probs = torch.softmax(batch_cls, dim=-1)
        max_probs, pred_labels = torch.max(cls_probs, dim=-1)
        conf_mask = max_probs > self.conf_thresh
        if conf_mask.sum() == 0:
            return self._get_empty_result_single(device)

        scores = max_probs[conf_mask]
        labels = pred_labels[conf_mask]
        deltas = batch_reg[conf_mask]
        anchors = anchors_all[conf_mask]

        # 2. Keep only the best candidates to bound memory use.
        if len(scores) > self.MAX_CANDIDATES:
            scores, top_idx = torch.topk(scores, self.MAX_CANDIDATES)
            labels = labels[top_idx]
            deltas = deltas[top_idx]
            anchors = anchors[top_idx]

        # 3. Decode to pixel boxes, then normalise by the actual input size
        #    (previously a hard-coded 256) and clip to the image.
        decoded = self._decode_boxes_corrected(anchors, deltas)
        scale = decoded.new_tensor([img_w, img_h, img_w, img_h])
        normalized = (decoded / scale).clamp(0, 1)

        # 4. Validation works on normalised coordinates, so it has to run
        #    before the boxes are scaled back to pixels.
        valid_boxes, valid_scores, valid_labels = self._validate_boxes_fixed(
            normalized, scores, labels
        )
        if len(valid_scores) == 0:
            return self._get_empty_result_single(device)
        pixel_boxes = valid_boxes * scale

        # 5. Class-wise NMS.
        final_boxes, final_scores, final_labels = self._apply_class_wise_nms(
            pixel_boxes, valid_scores, valid_labels, iou_threshold=0.3
        )

        # 6. Enforce the configured detection budget (highest scores first).
        if len(final_scores) > self.max_detections:
            final_scores, top_idx = torch.topk(final_scores, self.max_detections)
            final_boxes = final_boxes[top_idx]
            final_labels = final_labels[top_idx]

        # 7. Depth lookup at each box centre; best effort, zeros on failure.
        depth_values = None
        if depth_map is not None:
            try:
                depth_values = self._extract_depth_values(
                    depth_map, final_boxes, image_size=(img_h, img_w)
                )
            except Exception as exc:
                print(f"Depth lookup failed, reporting zeros: {exc}")
                depth_values = [0.0] * len(final_boxes)

        return {
            "boxes": final_boxes,  # pixel coordinates of the model input
            "scores": final_scores,
            "labels": final_labels,
            "depth": depth_values
        }

    def _extract_depth_values(self, depth_map, final_boxes_pixel, image_size=None):
        """Sample the depth map at the centre of every box.

        Args:
            depth_map: ``[H, W]`` or ``[1, H, W]`` depth prediction.
            final_boxes_pixel: iterable of ``(x1, y1, x2, y2)`` boxes.
            image_size: ``(H, W)`` of the frame the boxes live in. When given,
                box centres are rescaled to the depth map's resolution; when
                ``None`` the boxes are assumed to be in depth-map pixels.

        Returns:
            list of floats: sampled depth multiplied by ``DEPTH_SCALE``.
        """
        if depth_map.dim() == 3:
            depth_map = depth_map[0]
        map_h, map_w = depth_map.shape[-2], depth_map.shape[-1]

        if image_size is None:
            scale_x = scale_y = 1.0
        else:
            img_h, img_w = image_size
            scale_x = map_w / float(img_w)
            scale_y = map_h / float(img_h)

        depth_values = []
        for box in final_boxes_pixel:
            x1, y1, x2, y2 = (float(v) for v in box)
            center_x = int((x1 + x2) / 2 * scale_x)
            center_y = int((y1 + y2) / 2 * scale_y)

            # Clamp to the valid index range (the old upper bound of -1
            # forced every lookup to pixel 0).
            center_x = max(0, min(center_x, map_w - 1))
            center_y = max(0, min(center_y, map_h - 1))

            depth_values.append(depth_map[center_y, center_x].item() * self.DEPTH_SCALE)
        return depth_values

    def _apply_class_wise_nms(self, boxes, scores, labels, iou_threshold=0.2):
        """Run NMS independently for every label present in ``labels``.

        Returns:
            ``(boxes, scores, labels)`` of the kept detections, grouped by
            label. Empty outputs live on the same device as the inputs.
        """
        kept_boxes = []
        kept_scores = []
        kept_labels = []

        for lbl in labels.unique():
            cls_mask = labels == lbl
            cls_boxes = boxes[cls_mask]
            cls_scores = scores[cls_mask]

            keep = nms(cls_boxes, cls_scores, iou_threshold)

            kept_boxes.append(cls_boxes[keep])
            kept_scores.append(cls_scores[keep])
            kept_labels.append(labels[cls_mask][keep])

        if not kept_boxes:
            return (boxes.new_zeros((0, 4)),
                    scores.new_zeros((0,)),
                    torch.zeros((0,), dtype=torch.long, device=boxes.device))

        return torch.cat(kept_boxes), torch.cat(kept_scores), torch.cat(kept_labels)

    def _decode_boxes_corrected(self, anchors, deltas):
        """Apply ``(tx, ty, tw, th)`` deltas to ``(x1, y1, x2, y2)`` anchors.

        Centre offsets are scaled by the anchor size and the log-size deltas
        are clamped to at most 4.0 before exponentiation so the predicted
        width and height cannot explode.

        Returns:
            ``[N, 4]`` boxes in corner format, in the anchors' units.
        """
        if anchors.numel() == 0:
            return torch.zeros((0, 4), device=anchors.device)

        # Anchor sizes and centres
        anchor_widths = anchors[:, 2] - anchors[:, 0]
        anchor_heights = anchors[:, 3] - anchors[:, 1]
        anchor_ctr_x = anchors[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchors[:, 1] + 0.5 * anchor_heights

        dx, dy, dw, dh = deltas[:, 0], deltas[:, 1], deltas[:, 2], deltas[:, 3]
        dw = torch.clamp(dw, max=4.0)
        dh = torch.clamp(dh, max=4.0)

        pred_ctr_x = dx * anchor_widths + anchor_ctr_x
        pred_ctr_y = dy * anchor_heights + anchor_ctr_y
        pred_w = torch.exp(dw) * anchor_widths
        pred_h = torch.exp(dh) * anchor_heights

        # Back to corner format
        x1 = pred_ctr_x - 0.5 * pred_w
        y1 = pred_ctr_y - 0.5 * pred_h
        x2 = pred_ctr_x + 0.5 * pred_w
        y2 = pred_ctr_y + 0.5 * pred_h

        return torch.stack((x1, y1, x2, y2), dim=1)

    def _validate_boxes_fixed(self, boxes, scores, labels):
        """Drop degenerate boxes; expects boxes in normalised [0, 1] coordinates.

        A box is kept when ``x2 > x1`` and ``y2 > y1``, when both sides lie
        between 0.001 and 1, and when it stays within a 5% margin around the
        unit square. Kept boxes are clamped to [0, 1].

        Returns:
            ``(boxes, scores, labels)`` restricted to the valid entries.
        """
        if len(boxes) == 0:
            return self._get_empty_tensors(boxes.device)

        # Geometric validation
        valid_geom = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])

        # Size validation
        box_w = boxes[:, 2] - boxes[:, 0]
        box_h = boxes[:, 3] - boxes[:, 1]
        min_size = 0.001  # 0.1% of the image side
        max_size = 1      # the full image side
        valid_size = (box_w >= min_size) & (box_h >= min_size) & \
                     (box_w <= max_size) & (box_h <= max_size)

        # Bounds validation
        margin = 0.05  # 5% margin
        valid_bounds = (boxes[:, 0] >= -margin) & (boxes[:, 1] >= -margin) & \
                       (boxes[:, 2] <= 1 + margin) & (boxes[:, 3] <= 1 + margin)

        valid_mask = valid_geom & valid_size & valid_bounds

        if valid_mask.sum() == 0:
            return self._get_empty_tensors(boxes.device)

        boxes = boxes[valid_mask]
        scores = scores[valid_mask]
        labels = labels[valid_mask]
        boxes = torch.clamp(boxes, 0, 1)

        return boxes, scores, labels

    def _get_empty_tensors(self, device):
        """Return empty ``(boxes, scores, labels)`` tensors on ``device``."""
        return (torch.zeros((0, 4), device=device),
                torch.zeros((0,), device=device),
                torch.zeros((0,), dtype=torch.long, device=device))

    def _get_empty_result_single(self, device):
        """Return the result dict for an image without detections."""
        return {
            "boxes": torch.zeros((0, 4), device=device),
            "scores": torch.zeros((0,), device=device),
            # Integer labels, matching the dtype of non-empty results.
            "labels": torch.zeros((0,), dtype=torch.long, device=device),
            "depth": None
        }

class MultiTaskCriterion(nn.Module):
    def __init__(self, num_classes=7, loss_weights=None, device='cuda',
                 pos_iou_threshold=0.5, neg_iou_threshold=0.3, img_size=256):
        super(MultiTaskCriterion, self).__init__()
        self.device = device
        self.img_size = img_size
        self.num_classes = num_classes

        self.loss_weights = loss_weights if loss_weights else {
            'classification': 1.0,
            'regression': 1.0,
            'depth': 1.0,
            'depth_map': 0.1
        }

        self.cls_criterion = nn.CrossEntropyLoss(reduction='none')
        self.reg_criterion = nn.SmoothL1Loss(reduction='none')
        self.depth_criterion = nn.MSELoss(reduction='none')

        self.pos_iou_threshold = pos_iou_threshold
        self.neg_iou_threshold = neg_iou_threshold

    def forward(self, model_output):
        """Compute the combined loss and object-level metrics for one batch.

        Args:
            model_output: dict with the keys ``'cls_preds'``, ``'reg_preds'``,
                ``'anchors'``, ``'depth_pred'`` and ``'targets'``.
                ``targets`` is either a list of tensors (one per sample) or a
                container indexed by the batch index. In each target tensor,
                column 0 is read as the class id, columns 1:5 as the box and
                column 5 (when present) as the depth.
                # presumably boxes and depths are normalised to [0, 1] -- confirm against the dataset

        Returns:
            ``(losses, metrics)``:
              * ``losses['total']`` is the weighted sum over all samples plus
                the weighted smoothness term (it is not divided by the number
                of samples); the per-task entries are plain floats averaged
                over the samples that were processed without an exception.
              * ``metrics`` holds object-level accuracy / precision / recall /
                F1, depth MSE / RMSE and raw counters.

        Samples without targets are skipped. An exception raised while
        processing a sample is printed and the sample is skipped; loss terms
        already added for that sample before the exception remain in the
        total.
        """
        cls_preds = model_output['cls_preds']
        reg_preds = model_output['reg_preds']
        anchors = model_output['anchors']
        depth_pred = model_output['depth_pred']
        targets = model_output['targets']
        batch_size = cls_preds.size(0)

        total_loss = 0.0
        cls_loss_sum = 0.0
        reg_loss_sum = 0.0
        depth_loss_sum = 0.0
        depth_map_loss_val = 0.0

        # Counters for the object-based metrics
        total_objects = 0
        detected_objects = 0
        correct_detections = 0
        depth_errors = []
        valid_samples = 0

        for batch_idx in range(batch_size):
            try:
                # Fetch this sample's targets; skip the sample if there are none.
                if isinstance(targets, list):
                    if batch_idx >= len(targets):
                        continue
                    target = targets[batch_idx]
                    if isinstance(target, torch.Tensor):
                        if target.numel() == 0 or target.size(0) == 0:
                            continue
                        target = target.to(self.device)
                    else:
                        continue
                else:
                    if batch_idx not in targets or len(targets[batch_idx]) == 0:
                        continue
                    target = targets[batch_idx].to(self.device)

                gt_classes = target[:, 0].long()
                gt_boxes = target[:, 1:5]
                gt_depths = target[:, 5] if target.size(1) > 5 else None

                batch_cls_preds = cls_preds[batch_idx]
                batch_reg_preds = reg_preds[batch_idx]
                batch_depth_preds = depth_pred[batch_idx] if depth_pred is not None else None

                # Anchor assignment (needed by the losses below)
                pos_indices, neg_indices, matched_gt_indices = self._assign_targets_to_anchors(
                    anchors, gt_boxes, gt_classes
                )

                # Classification loss (anchor-based), only when positives exist
                if len(pos_indices) > 0:
                    cls_loss = self._compute_classification_loss(
                        batch_cls_preds, gt_classes, pos_indices, neg_indices, matched_gt_indices
                    )
                    if cls_loss is not None:
                        total_loss += cls_loss * self.loss_weights['classification']
                        cls_loss_sum += cls_loss.detach().item()

                # Regression loss (anchor-based), only when positives exist
                if len(pos_indices) > 0:
                    reg_loss = self._compute_regression_loss(
                        batch_reg_preds, gt_boxes, anchors, pos_indices, matched_gt_indices
                    )
                    if reg_loss is not None:
                        total_loss += reg_loss * self.loss_weights['regression']
                        reg_loss_sum += reg_loss.detach().item()

                # Depth loss, computed from the ground-truth boxes and depths
                if gt_depths is not None and batch_depth_preds is not None:
                    depth_loss = self._compute_depth_loss_normalized(
                        batch_depth_preds, gt_depths, gt_boxes
                    )
                    if depth_loss is not None:
                        total_loss += depth_loss * self.loss_weights['depth']
                        depth_loss_sum += depth_loss.detach().item()

                # Object-based evaluation (metrics only, no gradient)
                obj_metrics = self._evaluate_object_detection(
                    batch_cls_preds, batch_reg_preds, anchors,
                    gt_classes, gt_boxes, gt_depths, batch_depth_preds
                )

                total_objects += obj_metrics['total_objects']
                detected_objects += obj_metrics['detected_objects']
                correct_detections += obj_metrics['correct_detections']
                if obj_metrics['depth_errors']:
                    depth_errors.extend(obj_metrics['depth_errors'])

                valid_samples += 1

            except Exception as e:
                # Best effort: report the failing sample and carry on with the rest.
                print(f"Error in batch {batch_idx}: {e}")
                continue

        # Depth map smoothness loss, computed once over the whole batch
        if depth_pred is not None:
            depth_map_loss = self._compute_depth_smoothness_loss(depth_pred)
            total_loss += depth_map_loss * self.loss_weights['depth_map']
            depth_map_loss_val = depth_map_loss.detach().item()

        # Object-based metrics. Accuracy and recall use the same formula
        # (correct detections / ground-truth objects).
        object_accuracy = correct_detections / max(total_objects, 1)
        object_precision = correct_detections / max(detected_objects, 1)
        object_recall = correct_detections / max(total_objects, 1)
        object_f1 = 2 * (object_precision * object_recall) / max(object_precision + object_recall, 1e-6)

        # depth_errors holds squared errors, so their mean is the MSE.
        avg_depth_error = np.mean(depth_errors) if depth_errors else 0.0
        rmse_depth = np.sqrt(avg_depth_error) if depth_errors else 0.0

        losses = {
            'total': total_loss,
            'classification': cls_loss_sum / max(valid_samples, 1),
            'regression': reg_loss_sum / max(valid_samples, 1),
            'depth': depth_loss_sum / max(valid_samples, 1),
            'depth_map': depth_map_loss_val
        }

        metrics = {
            'Accuracy': object_accuracy,  # object-based
            'F1_score': object_f1,        # object-based
            'MSE': avg_depth_error,
            'RMSE': rmse_depth,
            'mAP': object_precision,      # simplification: reports object precision, not a true mAP
            'TotalLoss': total_loss.detach().item() if isinstance(total_loss, torch.Tensor) else total_loss,
            'ClsLoss': cls_loss_sum / max(valid_samples, 1),
            'Precision': object_precision,
            'Recall': object_recall,
            'TotalObjects': total_objects,
            'DetectedObjects': detected_objects,
            'CorrectDetections': correct_detections
        }

        return losses, metrics

    def _evaluate_object_detection(self, cls_preds, reg_preds, anchors, gt_classes, gt_boxes, gt_depths=None, depth_pred=None):
        """
        DÜZELTME: Object-based detection evaluation
        Her GT object için en iyi anchor'u bulur ve değerlendirir
        """
        with torch.no_grad():
            total_objects = len(gt_boxes)
            detected_objects = 0
            correct_detections = 0
            depth_errors = []

            if total_objects == 0:
                return {
                    'total_objects': 0,
                    'detected_objects': 0,
                    'correct_detections': 0,
                    'depth_errors': []
                }

            # Her GT object için en iyi anchor'u bul
            anchors_norm = anchors.clone()
            anchors_norm[:, [0,2]] /= 1242
            anchors_norm[:, [1,3]] /= 375

            ious = self.bbox_iou(anchors_norm, gt_boxes)  # [N_anchors, N_gt]

            for gt_idx in range(len(gt_boxes)):
                # Bu GT object için en yüksek IoU'ya sahip anchor
                gt_ious = ious[:, gt_idx]
                best_anchor_idx = torch.argmax(gt_ious)
                best_iou = gt_ious[best_anchor_idx]

                # Eğer IoU yeterince yüksekse detection var sayıyoruz
                if best_iou >= self.pos_iou_threshold:
                    detected_objects += 1

                    # Classification doğruluğunu kontrol et
                    pred_cls = torch.argmax(cls_preds[best_anchor_idx])
                    gt_cls = gt_classes[gt_idx]

                    if pred_cls == gt_cls:
                        correct_detections += 1

                    # Depth error hesapla
                    if gt_depths is not None and depth_pred is not None:
                        gt_box = gt_boxes[gt_idx]
                        gt_depth = gt_depths[gt_idx]

                        # Depth map'ten center pixel'ı sample et
                        if depth_pred.dim() == 3:
                            depth_map = depth_pred[0]
                        else:
                            depth_map = depth_pred

                        H, W = depth_map.shape
                        center_x = int((gt_box[0] + gt_box[2]) / 2 * W)
                        center_y = int((gt_box[1] + gt_box[3]) / 2 * H)
                        center_x = max(0, min(center_x, W-1))
                        center_y = max(0, min(center_y, H-1))

                        pred_depth = depth_map[center_y, center_x].clamp(0, 1)
                        depth_error = F.mse_loss(pred_depth, gt_depth.clamp(0, 1)).item()
                        depth_errors.append(depth_error)

            return {
                'total_objects': total_objects,
                'detected_objects': detected_objects,
                'correct_detections': correct_detections,
                'depth_errors': depth_errors
            }

    def _compute_classification_loss(self, cls_preds, gt_classes, pos_indices, neg_indices, matched_gt_indices):
        """Classification loss - anchor based (loss için)"""
        if len(pos_indices) == 0:
            return None

        pos_cls_preds = cls_preds[pos_indices]
        pos_gt_classes = gt_classes[matched_gt_indices]
        pos_loss = self.focal_loss(pos_cls_preds, pos_gt_classes)

        if len(neg_indices) > 0:
            max_neg_samples = min(len(neg_indices), len(pos_indices) * 3)
            neg_cls_preds = cls_preds[neg_indices]
            neg_scores = torch.max(torch.softmax(neg_cls_preds, dim=1), dim=1)[0]
            _, hard_neg_indices = torch.topk(neg_scores, max_neg_samples, largest=True)

            selected_neg_indices = neg_indices[hard_neg_indices]
            selected_neg_cls_preds = cls_preds[selected_neg_indices]
            neg_gt_classes = torch.zeros(len(selected_neg_indices), dtype=torch.long, device=self.device)
            neg_loss = self.focal_loss(selected_neg_cls_preds, neg_gt_classes)
            total_cls_loss = pos_loss + neg_loss
        else:
            total_cls_loss = pos_loss

        return total_cls_loss

    def _compute_regression_loss(self, reg_preds, gt_boxes, anchors, pos_indices, matched_gt_indices):
        """Regression loss - anchor based (loss için)"""
        if len(pos_indices) == 0:
            return None

        pos_anchors = anchors[pos_indices]
        pos_reg_preds = reg_preds[pos_indices]
        pos_gt_boxes = gt_boxes[matched_gt_indices]

        # Normalize anchors
        pos_anchors_norm = pos_anchors.clone()
        pos_anchors_norm[:, [0,2]] /= 1242
        pos_anchors_norm[:, [1,3]] /= 375

        pos_gt_encoded = self._encode_boxes(pos_anchors_norm, pos_gt_boxes)
        reg_loss = self.reg_criterion(pos_reg_preds, pos_gt_encoded).mean()

        return reg_loss

    def _compute_depth_loss_normalized(self, depth_pred, gt_depths, gt_boxes):
        """Depth loss - anchor based (loss için)"""
        if len(gt_depths) == 0:
            return None

        if depth_pred.dim() == 3:
            depth_map = depth_pred[0]
        else:
            depth_map = depth_pred

        H, W = depth_map.shape
        sampled_depths = []
        target_depths = []

        for gt_box, gt_depth in zip(gt_boxes, gt_depths):
            x1, y1, x2, y2 = gt_box

            x1_pix = int(x1 * W)
            y1_pix = int(y1 * H)
            x2_pix = int(x2 * W)
            y2_pix = int(y2 * H)

            x1_pix = max(0, min(x1_pix, W-1))
            y1_pix = max(0, min(y1_pix, H-1))
            x2_pix = max(0, min(x2_pix, W-1))
            y2_pix = max(0, min(y2_pix, H-1))

            center_y = (y1_pix + y2_pix) // 2
            center_x = (x1_pix + x2_pix) // 2

            sampled_depth = depth_map[center_y, center_x]
            sampled_depth = torch.clamp(sampled_depth, 0, 1)
            sampled_depths.append(sampled_depth)
            target_depths.append(gt_depth.clamp(0, 1))

        if len(sampled_depths) == 0:
            return None

        sampled_depths = torch.stack(sampled_depths)
        target_depths = torch.stack(target_depths)
        depth_loss = self.depth_criterion(sampled_depths, target_depths).mean()

        return depth_loss

    def _compute_depth_smoothness_loss(self, depth_pred):
        """Compute depth map smoothness loss"""
        grad_x = torch.abs(depth_pred[:, :, :, :-1] - depth_pred[:, :, :, 1:])
        grad_y = torch.abs(depth_pred[:, :, :-1, :] - depth_pred[:, :, 1:, :])
        smoothness_loss = grad_x.mean() + grad_y.mean()
        return smoothness_loss

    def _assign_targets_to_anchors(self, anchors, gt_boxes, gt_classes):
        """Assign ground truth to anchors based on IoU"""
        if len(gt_boxes) == 0:
            return torch.tensor([], dtype=torch.long, device=self.device), \
                   torch.tensor([], dtype=torch.long, device=self.device), \
                   torch.tensor([], dtype=torch.long, device=self.device)

        anchors_norm = anchors.clone()
        anchors_norm[:, [0,2]] /= 1242
        anchors_norm[:, [1,3]] /= 375
        ious = self.bbox_iou(anchors_norm, gt_boxes)
        max_ious, matched_gt_indices = torch.max(ious, dim=1)

        pos_mask = max_ious >= self.pos_iou_threshold
        neg_mask = max_ious < self.neg_iou_threshold

        pos_indices = torch.where(pos_mask)[0]
        neg_indices = torch.where(neg_mask)[0]
        matched_gt_indices = matched_gt_indices[pos_indices]

        return pos_indices, neg_indices, matched_gt_indices

    def _encode_boxes(self, anchors, gt_boxes):
        """Box encoding - normalize koordinatlarda"""
        anchor_widths = anchors[:, 2] - anchors[:, 0]
        anchor_heights = anchors[:, 3] - anchors[:, 1]
        anchor_ctr_x = anchors[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchors[:, 1] + 0.5 * anchor_heights

        gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0]
        gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1]
        gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_widths
        gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_heights

        dx = (gt_ctr_x - anchor_ctr_x) / (anchor_widths + 1e-6)
        dy = (gt_ctr_y - anchor_ctr_y) / (anchor_heights + 1e-6)
        dw = torch.log(gt_widths / (anchor_widths + 1e-6))
        dh = torch.log(gt_heights / (anchor_heights + 1e-6))

        return torch.stack((dx, dy, dw, dh), dim=1)

    def bbox_iou(self, box1, box2):
        """Pairwise IoU between two sets of (x1, y1, x2, y2) boxes.

        Args:
            box1: tensor of N boxes.
            box2: tensor of M boxes.

        Returns:
            Tensor of shape [N, M] with the IoU of every box1/box2 pair.

        The union is clamped to a tiny positive value so that a pair of
        zero-area boxes yields an IoU of 0 instead of NaN (0 / 0); NaN values
        would silently break the threshold comparisons and arg-max done by the
        callers.
        """
        area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
        area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])

        # Corners of the intersection rectangle for every pair.
        lt = torch.max(box1[:, None, :2], box2[:, :2])
        rb = torch.min(box1[:, None, 2:], box2[:, 2:])

        # Non-overlapping pairs get a zero-sized intersection.
        wh = (rb - lt).clamp(min=0)
        inter = wh[:, :, 0] * wh[:, :, 1]

        union = area1[:, None] + area2 - inter
        return inter / union.clamp(min=1e-9)

    def focal_loss(self, inputs, targets, alpha=0.25, gamma=2.0):
        """Focal loss on class logits.

        The per-sample cross entropy is down-weighted by ``(1 - p_t) ** gamma``
        where ``p_t = exp(-cross_entropy)``, scaled by ``alpha``, and averaged
        over all samples.
        """
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        true_class_prob = (-per_sample_ce).exp()
        modulation = (1 - true_class_prob).pow(gamma)
        weighted = alpha * modulation * per_sample_ce
        return weighted.mean()

def test_model_corrected(model, test_loader, device='cuda', save_predictions=True):
    """Düzeltilmiş test fonksiyonu"""
    model.eval()  # IMPORTANT: Model'i eval mode'a al

    results = []

    with torch.no_grad():
        for batch_idx, images in enumerate(test_loader):
            images = images.to(device)

            # DÜZELTME: mode="inference" kullan
            batch_results = model(images, mode="inference")

            for i, result in enumerate(batch_results):
                image_result = {
                    'image_idx': batch_idx * len(images) + i,
                    'boxes': result['boxes'].cpu().numpy(),
                    'scores': result['scores'].cpu().numpy(),
                    'labels': result['labels'].cpu().numpy(),
                    'depth': result['depth'] if result['depth'] is not None else None
                }
                results.append(image_result)

            if batch_idx % 10 == 0:
                print(f"Processed {batch_idx}/{len(test_loader)} batches")

    return results

def visualize_single_prediction(image_path, prediction, show_depth=True, img_size=256, index=379):
    """Draw predicted boxes, class names and optional depth on an image.

    Args:
        image_path: path of the image file.
        prediction: either a single result dict from ``test_model_corrected``
            (keys 'boxes', 'labels', 'depth'), or the whole list of results,
            in which case ``prediction[index]`` is shown.
        show_depth: append the per-box depth value to the label when available.
        img_size: side length of the square the image is resized to before
            plotting; normalized boxes are scaled by the same value.
        index: entry to show when ``prediction`` is a list. The default keeps
            the entry that was previously hard-coded.
    """
    # Accept both the documented single dict and the full result list.
    entry = prediction if isinstance(prediction, dict) else prediction[index]

    # Load the image and resize it to the square used for the box coordinates.
    image = Image.open(image_path).convert("RGB").resize((img_size, img_size))
    image_np = np.array(image)

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.imshow(image_np)

    boxes = np.asarray(entry['boxes'], dtype=float).reshape(-1, 4)
    labels = entry['labels']
    depth_values = entry['depth']

    # Heuristic: boxes whose largest coordinate is around 1 are treated as
    # normalized and scaled exactly once. The limit is slightly above 1.0
    # because predicted boxes may overshoot the unit range a little; boxes
    # already in pixel units are far above it.
    # NOTE(review): confirm the coordinate convention of the inference output.
    normalized_limit = 2.0
    if boxes.size > 0 and boxes.max() <= normalized_limit:
        boxes = boxes * img_size

    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                             fill=False, color='red', linewidth=2)
        ax.add_patch(rect)

        # Fall back to the raw index for labels missing from the name table.
        class_id = int(labels[i])
        label_str = kitti_class_dict.get(class_id, str(class_id))
        text = f"{label_str}"
        if show_depth and depth_values is not None and i < len(depth_values):
            text += f", {float(depth_values[i]):.2f} metre"
        ax.text(x1, y1, text, color='yellow', fontsize=10)

    plt.axis('off')
    plt.show()

# Class index -> display name. The order mirrors KITTI_Dataset.classes
# (['Car', 'Van', 'Truck', 'Pedestrian', 'Cyclist', 'Tram', 'Misc']) so that a
# predicted label is shown with the same name it was trained with.
# NOTE(review): assumes the model's label indices equal the dataset's class
# indices (no background offset) - confirm against the inference code.
kitti_class_dict = {
    0: "Car",
    1: "Van",
    2: "Truck",
    3: "Pedestrian",
    4: "Cyclist",
    5: "Tram",
    6: "Misc"
}

In [3]:
# Dataset location. The active path is the Colab copy; switch to the
# commented line for a local run.
data_path = "/content/kitti2012" #colab
#data_path = "C:/Users/Mehmet/Desktop/kitti2012" #local

image_size=256
batch_size=1
# Fixed seed so the train/validation split is identical on every run;
# without it, validation losses of different runs/trials are not comparable.
split_seed = 42

# Resize every image to a square and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor()
])
# Tensor conversion only, no resizing.
transform2 = transforms.Compose([
    transforms.ToTensor()
])

# 80 / 20 split of the labelled data into training and validation subsets.
train_val_dataset = KITTI_Dataset(data_path=data_path ,transform=transform, mode='train')
train_size = int(0.8 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size
train_dataset, val_dataset = random_split(
    train_val_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed)
)

test_dataset  = KITTI_Dataset(data_path,transform=transform,mode='test')

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=0,collate_fn=kitti_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,pin_memory=True, num_workers=0,collate_fn=test_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=0,collate_fn=kitti_collate_fn)

# Inverse-frequency class weights, normalized to sum to 1.
# NOTE(review): the counts are presumably per-class object counts in dataset
# class order - confirm against the output of the dataset analysis.
class_counts = [1201,113,39,161,76,36,35]
inverse_frequencies = [1 / count for count in class_counts]
inverse_total = sum(inverse_frequencies)
class_weights = torch.tensor([inv / inverse_total for inv in inverse_frequencies], dtype=torch.float32).to(device)

In [None]:
# Run the dataset analysis helper on the training split and keep its result.
results = analyze_dataset(train_dataset, device)

In [None]:
def objective(trial):
    """Optuna objective: train one model and return its best validation loss.

    The trial only searches the positive and negative IoU thresholds used for
    anchor matching; every other training setting is fixed here.
    """
    # Fixed per-task loss weights.
    task_weights = {
          'classification': 1.0,
          'regression': 2.0,
          'depth': 0.3,
          'depth_map': 0.3}

    # IoU thresholds (the searched hyper-parameters).
    positive_iou = trial.suggest_categorical('p_iou_threshold', [0.35, 0.45])
    negative_iou = trial.suggest_categorical('n_iou_threshold', [0.2, 0.3])

    # A fresh model for every trial.
    model = CompleteMultiTaskModel(
        num_classes=7,
        max_detections=15,
        num_anchors=15,
        MAX_CANDIDATES=3000,
        confidence_threshold=0.25
    ).to(device)

    training_config = dict(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=30,  # short schedule for the hyper-parameter search
        learning_rate=1e-3,
        device=device,
        save_path=f'weights_{trial.number}.pth',
        class_weights=class_weights,
        task_weights=task_weights,
        scheduler_patience=3,
        scheduler_factor=0.2,
        p_iou_threshold=positive_iou,
        n_iou_threshold=negative_iou,
        early_stop_patience=5
    )

    return train_model(**training_config)
# Minimize the best validation loss returned by `objective`.
study = optuna.create_study(direction="minimize")
# gc_after_trial=True runs the garbage collector after every trial so memory
# held by the previous trial's model (each trial builds a new one on `device`)
# is released before the next trial starts.
study.optimize(objective, n_trials=2, gc_after_trial=True)

print("En iyi parametreler:", study.best_params)

[I 2025-08-13 22:08:56,189] A new study created in memory with name: no-name-b2d5fc17-598c-40b9-9c81-4ec75abc669f
[34m[1mwandb[0m: Currently logged in as: [33mmehmeteminuludag[0m ([33mmehmeteminuludag-kirikkale-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/30 Train:  39%|███▉      | 122/310 [01:11<01:54,  1.64it/s, Acc=0.127, ClsLoss=0.127, F1=0.172, RMSE=0.159, mAP=0.339, TotalLoss=0.233]