# In [None]:
import os
import glob
import random
import numpy as np
import torch
import torch.nn as nn
import nibabel as nib
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold
from scipy import ndimage
from skimage.filters import threshold_otsu
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import torch.nn.functional as F

# ---------------------------------------
# CUDA settings
# ---------------------------------------
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Let cuDNN benchmark convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True

# Multi-GPU settings.
# Per-device batch size (chosen for T4 GPUs). The global batch size scales with
# the number of visible GPUs and falls back to one "device" when running on CPU,
# so BATCH_SIZE_PER_GPU is always defined regardless of the hardware.
BATCH_SIZE_PER_GPU = 12
NUM_GPUS = torch.cuda.device_count()
BATCH_SIZE = BATCH_SIZE_PER_GPU * max(1, NUM_GPUS)

if NUM_GPUS > 1:
    print(f"Using {NUM_GPUS} GPUs!")
elif NUM_GPUS == 1:
    print("Using single GPU")
else:
    # Previously this case was also reported as "single GPU".
    print("No GPU detected, running on CPU")

# ---------------------------------------
# Training parameters
# ---------------------------------------
num_epochs = 20      # epochs per fold (reduced to 20)
patience = 5         # early stopping: tolerated epochs without improvement (reduced)
min_delta = 0.0005   # early stopping: minimum change that counts as an improvement

# Starting hyperparameters (used as-is for the first fold).
initial_hyperparams = dict(lr=1e-3, weight_decay=1e-4, dropout=0.3)

# Containers for per-fold results; fold1_metrics is filled in after fold 1.
fold_results = list()
fold1_metrics = None

# ---------------------------------------
# 1) Data path and classes
# ---------------------------------------
DATA_DIR = '/kaggle/input/adniunpreprocessed'
classes = {'CN': 0, 'MCI': 1, 'AD': 2}
MAX_PER_CLASS = 300   # cap on the number of scans kept per class
SUBSET_SEED = 42      # fixed seed so the sampled subset is reproducible

# Collect candidate NIfTI files. The list is sorted so that the result does not
# depend on the order in which the filesystem happens to return entries.
all_files = sorted(glob.glob(os.path.join(DATA_DIR, '**', '*.nii*'), recursive=True))
class_paths = {c: [] for c in classes}
for fp in all_files:
    # The class label is the name of the directory that directly contains the file.
    parent = os.path.basename(os.path.dirname(fp))
    if parent not in classes or os.path.getsize(fp) == 0:
        continue
    try:
        # Used only as a readability check; the loaded object is discarded.
        nib.load(fp)
    except Exception as exc:
        # Best effort: skip files that cannot be opened, but say so instead of
        # swallowing the error silently.
        print(f"Skipping unreadable file {fp}: {exc}")
        continue
    class_paths[parent].append(fp)

# Keep at most MAX_PER_CLASS randomly chosen files per class. A dedicated RNG is
# used so the selection is repeatable without touching the global `random`
# state that the augmentation code relies on.
subset_rng = random.Random(SUBSET_SEED)
limited_paths, limited_labels = [], []
for cls, paths in class_paths.items():
    subset_rng.shuffle(paths)
    selected = paths[:MAX_PER_CLASS]
    limited_paths.extend(selected)
    limited_labels.extend([classes[cls]] * len(selected))

limited_paths = np.array(limited_paths)
limited_labels = np.array(limited_labels)

# ---------------------------------------
# 2) Dataset and simple augmentation
# ---------------------------------------
PATCH_SIZE = 112


class ADNI_Dataset(Dataset):
    """Load NIfTI volumes and return a cubic, normalized patch with its label.

    Each item is produced by: percentile clipping and rescaling, a mask built
    from an Otsu threshold plus morphological clean-up, a PATCH_SIZE^3 crop
    centered on the mask's bounding box (zero-padded when the volume is
    smaller), z-score normalization and, optionally, random augmentation.

    Args:
        paths: Indexable collection of file paths readable by ``nib.load``.
        labels: Indexable collection of integer class labels, parallel to ``paths``.
        augment: When True, random augmentation is applied to every item.

    Note:
        The code indexes three spatial axes, i.e. it assumes 3-D volumes.
    """

    def __init__(self, paths, labels, augment=False):
        self.paths = paths
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _brain_mask(img):
        """Return a boolean foreground mask (largest connected component only)."""
        mask = img > threshold_otsu(img)
        structure = np.ones((5, 5, 5))  # larger kernel for closing/dilation
        mask = ndimage.binary_closing(mask, structure=structure)
        mask = ndimage.binary_fill_holes(mask)
        mask = ndimage.binary_dilation(mask, structure=structure)
        labeled, n_components = ndimage.label(mask)
        if n_components > 1:
            counts = np.bincount(labeled.ravel())
            counts[0] = 0  # ignore the background label
            mask = labeled == counts.argmax()
        return mask

    @staticmethod
    def _crop_bounds(center, shape, size=PATCH_SIZE):
        """Return per-axis (starts, ends) of a ``size`` window around ``center``.

        The window is shifted to stay inside the volume. When an axis is
        shorter than ``size`` the whole axis is selected (start 0, end dim)
        instead of producing a negative start index, which would silently
        wrap around when slicing.
        """
        starts, ends = [], []
        for c, dim in zip(center, shape):
            start = int(c) - size // 2
            start = max(0, min(start, dim - size))
            starts.append(start)
            ends.append(min(start + size, dim))
        return starts, ends

    @staticmethod
    def _fit_to_patch(vol, size=PATCH_SIZE, fill=0.0):
        """Center-crop axes longer than ``size`` and center-pad shorter ones with ``fill``."""
        crop = tuple(
            slice((s - size) // 2, (s - size) // 2 + size) if s > size else slice(None)
            for s in vol.shape
        )
        vol = vol[crop]
        if vol.shape == (size,) * 3:
            return vol
        out = np.full((size,) * 3, fill, dtype=np.float32)
        sd, sh, sw = [(size - s) // 2 for s in vol.shape]
        pd, ph, pw = vol.shape
        out[sd:sd + pd, sh:sh + ph, sw:sw + pw] = vol
        return out

    def _augment(self, patch):
        """Apply random flips, rotation, brightness, contrast, noise and zoom."""
        # Random flip along every axis.
        for axis in (0, 1, 2):
            if random.random() < 0.5:
                patch = np.flip(patch, axis=axis).copy()

        # Random rotation in the (0, 1) plane. Exposed corners are filled with
        # the patch minimum: after z-scoring, the default fill of 0 is the mean
        # intensity rather than the (masked) background level.
        if random.random() < 0.5:
            angle = random.uniform(-15, 15)
            patch = ndimage.rotate(
                patch, angle, axes=(0, 1), reshape=False,
                mode='constant', cval=float(patch.min()),
            )

        # Random brightness.
        if random.random() < 0.5:
            patch = patch * random.uniform(0.7, 1.3)

        # Random contrast around the patch mean.
        if random.random() < 0.5:
            patch = (patch - patch.mean()) * random.uniform(0.7, 1.3) + patch.mean()

        # Additive Gaussian noise.
        if random.random() < 0.5:
            noise_level = random.uniform(0.01, 0.03)
            patch = patch + np.random.normal(0, noise_level, size=patch.shape).astype(np.float32)

        # Random zoom. The rescaled volume is center-cropped / padded back to
        # the patch size; resampling it back to PATCH_SIZE (as was done before)
        # undoes the scale change and only blurs the image.
        if random.random() < 0.3:
            zoom_factor = random.uniform(0.9, 1.1)
            zoomed = ndimage.zoom(patch, zoom_factor, order=1)
            # The patch minimum approximates the background level.
            patch = self._fit_to_patch(zoomed, fill=float(patch.min()))

        return patch

    def __getitem__(self, idx):
        """Return ``(tensor, label)``: a (1, P, P, P) float32 tensor and a long scalar."""
        img = nib.load(self.paths[idx]).get_fdata().astype(np.float32)

        # Robust intensity normalization: clip to the 1st-99th percentile range
        # and rescale so that range maps to roughly [0, 1].
        p1, p99 = np.percentile(img, (1, 99))
        img = np.clip(img, p1, p99)
        img = (img - p1) / (p99 - p1 + 1e-6)

        # Masking and clean-up. If thresholding leaves nothing (e.g. a constant
        # volume), fall back to the whole volume instead of failing on an empty
        # bounding box.
        mask = self._brain_mask(img)
        if not mask.any():
            mask = np.ones(img.shape, dtype=bool)

        # Apply the mask.
        img = img * mask

        # Patch crop centered on the mask's bounding box.
        coords = np.array(np.where(mask))
        mins, maxs = coords.min(axis=1), coords.max(axis=1)
        center = (mins + maxs) // 2
        starts, ends = self._crop_bounds(center, img.shape)
        patch = img[starts[0]:ends[0], starts[1]:ends[1], starts[2]:ends[2]]

        # Zero-pad (centered) when the volume is smaller than the patch.
        patch = self._fit_to_patch(patch)

        # Per-patch z-score normalization.
        patch = (patch - patch.mean()) / (patch.std() + 1e-6)

        if self.augment:
            patch = self._augment(patch)

        # Force a contiguous float32 buffer so the tensor dtype does not depend
        # on NumPy's scalar promotion rules or on views created while cropping.
        patch = np.ascontiguousarray(patch, dtype=np.float32)
        tensor = torch.from_numpy(patch).unsqueeze(0)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return tensor, label

# ---------------------------------------
# 3) Model Definition (Inception3D + SE + Attention)
# ---------------------------------------
class SEBlock(nn.Module):
    """Channel gate: global average -> two bias-free linear layers -> sigmoid.

    The input must be a 5-D tensor ``(b, c, d, h, w)``; the output has the same
    shape, with each channel scaled by its gate value.

    Args:
        channels: Number of input (and output) channels.
        reduction: Divisor for the hidden width (``channels // reduction``).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        self.fc1 = nn.Linear(channels, hidden, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden, channels, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch, chans, _, _, _ = x.size()
        # Squeeze: average over all spatial positions, one value per channel.
        squeezed = x.view(batch, chans, -1).mean(dim=2)
        # Excite: bottleneck MLP followed by a sigmoid gate.
        hidden = self.relu(self.fc1(squeezed))
        gate = self.sigmoid(self.fc2(hidden))
        return x * gate.view(batch, chans, 1, 1, 1)

class AttentionBlock(nn.Module):
    """Channel gate built from a global average pool and two 1x1x1 convolutions.

    The pooled descriptor goes through ``conv1`` (to ``in_channels // 8``
    channels), ReLU, ``conv2`` (back to ``in_channels``) and a sigmoid; the
    result multiplies the input, broadcasting over the spatial axes.

    Args:
        in_channels: Number of input (and output) channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        squeezed = in_channels // 8
        self.conv1 = nn.Conv3d(in_channels, squeezed, 1)
        self.conv2 = nn.Conv3d(squeezed, in_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Pool with a kernel covering the full spatial extent -> (b, c, 1, 1, 1).
        spatial_extent = x.size()[2:]
        pooled = F.avg_pool3d(x, spatial_extent)
        hidden = F.relu(self.conv1(pooled))
        gate = self.sigmoid(self.conv2(hidden))
        return x * gate

class Inception3D_SE_Attention(nn.Module):
    """3-D Inception module followed by an SE block and an attention block.

    Four parallel branches process the same input and are concatenated along
    the channel axis: a 1x1x1 conv (``b1``), a 1x1x1 -> 3x3x3 conv (``b3``),
    a 1x1x1 -> 5x5x5 conv (``b5``) and a max-pool -> 1x1x1 conv (``bp``).
    Every convolution is bias-free and followed by BatchNorm3d and ReLU.

    Args:
        in_ch: Input channels.
        c1: Output channels of the 1x1x1 branch.
        c3r: Reduction channels before the 3x3x3 conv.
        c3: Output channels of the 3x3x3 branch.
        c5r: Reduction channels before the 5x5x5 conv.
        c5: Output channels of the 5x5x5 branch.
        pp: Output channels of the pooling branch.
    """

    def __init__(self, in_ch, c1, c3r, c3, c5r, c5, pp):
        super().__init__()

        def conv_bn_relu(cin, cout, kernel, pad=0):
            # One conv -> batch-norm -> ReLU unit, returned as a list of layers
            # so several units can be spliced into a single nn.Sequential.
            return [
                nn.Conv3d(cin, cout, kernel, padding=pad, bias=False),
                nn.BatchNorm3d(cout),
                nn.ReLU(True),
            ]

        self.b1 = nn.Sequential(*conv_bn_relu(in_ch, c1, 1))
        self.b3 = nn.Sequential(
            *conv_bn_relu(in_ch, c3r, 1),
            *conv_bn_relu(c3r, c3, 3, pad=1),
        )
        self.b5 = nn.Sequential(
            *conv_bn_relu(in_ch, c5r, 1),
            *conv_bn_relu(c5r, c5, 5, pad=2),
        )
        self.bp = nn.Sequential(
            nn.MaxPool3d(3, 1, 1),
            *conv_bn_relu(in_ch, pp, 1),
        )

        out_ch = c1 + c3 + c5 + pp
        self.se = SEBlock(out_ch)
        self.attention = AttentionBlock(out_ch)

    def forward(self, x):
        branch_outputs = [branch(x) for branch in (self.b1, self.b3, self.b5, self.bp)]
        merged = torch.cat(branch_outputs, dim=1)
        return self.attention(self.se(merged))

class Inception3DNet_SE_Attention(nn.Module):
    """GoogLeNet-style 3-D classifier built from Inception3D_SE_Attention modules.

    Layout: two convolutional stem stages, two inception modules, max-pool,
    three inception modules, max-pool, global average pool, dropout and a
    linear classifier on 512 features.

    Args:
        num_classes: Size of the output logit vector.
        dropout_rate: Dropout probability applied before the classifier.
    """

    def __init__(self, num_classes=3, dropout_rate=0.3):
        super().__init__()
        # Stem stage 1: single-channel input -> 64 channels.
        self.conv1 = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(True),
            nn.MaxPool3d(kernel_size=3, stride=2, padding=1),
        )
        # Stem stage 2: 64 -> 192 channels.
        self.conv2 = nn.Sequential(
            nn.Conv3d(64, 64, kernel_size=1, bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(True),
            nn.Conv3d(64, 192, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(192),
            nn.ReLU(True),
            nn.MaxPool3d(kernel_size=3, stride=2, padding=1),
        )
        # Inception stages; each module's output width (c1 + c3 + c5 + pp)
        # equals the next module's input width.
        self.i3a = Inception3D_SE_Attention(192, 64, 96, 128, 16, 32, 32)     # -> 256
        self.i3b = Inception3D_SE_Attention(256, 128, 128, 192, 32, 96, 64)   # -> 480
        self.mp3 = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.i3c = Inception3D_SE_Attention(480, 192, 96, 208, 16, 48, 64)    # -> 512
        self.i3d = Inception3D_SE_Attention(512, 160, 112, 224, 24, 64, 64)   # -> 512
        self.i3e = Inception3D_SE_Attention(512, 128, 128, 256, 24, 64, 64)   # -> 512
        self.mp4 = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.avgp = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.drop = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(512, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for convolutions; unit scale / zero shift for norm layers."""
        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if not isinstance(module, (nn.BatchNorm3d, nn.GroupNorm)):
                continue
            if getattr(module, 'weight', None) is not None:
                nn.init.constant_(module.weight, 1)
            if getattr(module, 'bias', None) is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        feature_stages = (
            self.conv1, self.conv2,
            self.i3a, self.i3b, self.mp3,
            self.i3c, self.i3d, self.i3e, self.mp4,
            self.avgp,
        )
        for stage in feature_stages:
            x = stage(x)
        features = self.drop(x.flatten(1))
        return self.fc(features)

# ---------------------------------------
# 4) 2-Fold Cross-Validation -> 20 epochs per fold
# ---------------------------------------
skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)

for fold, (tr_idx, va_idx) in enumerate(skf.split(limited_paths, limited_labels), 1):
    print(f"\n=== Fold {fold}/2 ===")

    # Hyperparameters: fold 1 uses the initial values; later folds adjust them
    # based on fold 1's results. Always work on a copy so that neither
    # `initial_hyperparams` nor the values already recorded for earlier folds
    # are modified in place.
    if fold == 1 or fold1_metrics is None:
        current_hyperparams = dict(initial_hyperparams)
    else:
        current_hyperparams = dict(current_hyperparams)

        # Adjust the learning rate from fold 1's validation accuracy.
        if fold1_metrics['val_acc'] < 0.7:
            current_hyperparams['lr'] *= 1.5  # low accuracy: raise the learning rate
        elif fold1_metrics['val_acc'] > 0.85:
            current_hyperparams['lr'] *= 0.8  # high accuracy: lower the learning rate

        # Adjust dropout from fold 1's validation loss.
        if fold1_metrics['val_loss'] > 1.0:
            # High loss is treated as overfitting: more dropout (capped at 0.5).
            current_hyperparams['dropout'] = min(0.5, current_hyperparams['dropout'] + 0.05)
        elif fold1_metrics['val_loss'] < 0.5:
            # Low loss is treated as underfitting: less dropout (floored at 0.2).
            current_hyperparams['dropout'] = max(0.2, current_hyperparams['dropout'] - 0.05)

        # Adjust weight decay from accuracy and loss together.
        if fold1_metrics['val_acc'] < 0.7 and fold1_metrics['val_loss'] > 1.0:
            current_hyperparams['weight_decay'] *= 1.5  # poor performance: stronger regularization
        elif fold1_metrics['val_acc'] > 0.85 and fold1_metrics['val_loss'] < 0.5:
            current_hyperparams['weight_decay'] *= 0.8  # strong performance: weaker regularization

    print(f"\nFold {fold} Hyperparameters:")
    print(f"Learning Rate: {current_hyperparams['lr']}")
    print(f"Weight Decay: {current_hyperparams['weight_decay']}")
    print(f"Dropout: {current_hyperparams['dropout']}")

    # Data split
    tr_p, tr_l = limited_paths[tr_idx], limited_labels[tr_idx]
    va_p, va_l = limited_paths[va_idx], limited_labels[va_idx]

    # Inverse-frequency class weights. The vector always has one entry per
    # class in `classes` (matching the model's outputs); a class that is
    # missing from the training split gets weight 0 instead of a KeyError.
    class_counts = Counter(tr_l)
    total_samples = len(tr_l)
    n_classes = len(classes)
    class_weights = torch.tensor(
        [
            total_samples / (n_classes * class_counts[i]) if class_counts[i] else 0.0
            for i in range(n_classes)
        ],
        dtype=torch.float32,
    ).to(DEVICE)

    print("\nTrain set class distribution:")
    print(Counter(tr_l))
    print("\nValidation set class distribution:")
    print(Counter(va_l))
    print("\nClass weights:", class_weights)

    # Dataset and DataLoader
    train_ds = ADNI_Dataset(tr_p, tr_l, augment=True)
    val_ds   = ADNI_Dataset(va_p, va_l, augment=False)

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
        persistent_workers=True
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
        persistent_workers=True
    )

    # Model, loss, optimizer
    model = Inception3DNet_SE_Attention(3, dropout_rate=current_hyperparams['dropout']).to(DEVICE)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # Weighted cross-entropy loss with label smoothing
    criterion = CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

    optimizer = torch.optim.AdamW(model.parameters(), lr=current_hyperparams['lr'], weight_decay=current_hyperparams['weight_decay'])

    # Cosine annealing with warm restarts
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=5,  # first restart period
        T_mult=2,  # period doubles after every restart
        eta_min=current_hyperparams['lr'] * 0.01  # minimum learning rate
    )

    scaler = GradScaler()

    # Training metrics
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    # Start below any reachable accuracy so the first epoch always saves a
    # checkpoint and sets best_epoch; the later evaluation code loads that
    # file unconditionally.
    best_val_acc = float('-inf')
    best_epoch = 0
    no_improve = 0

    # Training loop
    for epoch in range(1, num_epochs + 1):
        # Train
        model.train()
        run_loss = corr = tot = 0
        for imgs, labels in tqdm(train_loader, desc=f"Fold {fold} Train E{epoch}"):
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()

            with autocast():
                logits = model(imgs)
                loss = criterion(logits, labels)

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the
            # true gradients rather than the loss-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            run_loss += loss.item()
            preds = logits.argmax(1)
            corr += (preds==labels).sum().item()
            tot += labels.size(0)

        train_loss = run_loss/len(train_loader)
        train_acc = corr/tot
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        print(f"Epoch {epoch}/{num_epochs}: Train Loss {train_loss:.4f}, Train Acc {train_acc:.4f}")

        # Validation
        model.eval()
        val_loss = corr = tot = 0
        with torch.no_grad():
            for imgs, labels in tqdm(val_loader, desc=f"Fold {fold} Val"):
                imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
                with autocast():
                    logits = model(imgs)
                    loss = criterion(logits, labels)
                val_loss += loss.item()
                preds = logits.argmax(1)
                corr += (preds==labels).sum().item()
                tot += labels.size(0)

        val_loss = val_loss/len(val_loader)
        val_acc = corr/tot
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        print(f"Val Loss {val_loss:.4f}, Val Acc {val_acc:.4f}")

        # Learning rate scheduler update (once per epoch)
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Current learning rate: {current_lr:.6f}")

        # Early stopping check
        if val_acc > best_val_acc + min_delta:
            best_val_acc = val_acc
            best_epoch = epoch
            no_improve = 0
            # Save the bare module's weights so the checkpoint keys carry no
            # DataParallel "module." prefix.
            if isinstance(model, nn.DataParallel):
                torch.save(model.module.state_dict(), f'best_model_fold_{fold}.pth')
            else:
                torch.save(model.state_dict(), f'best_model_fold_{fold}.pth')
            print(f"Yeni en iyi model kaydedildi (Val Acc: {val_acc:.4f})")
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"\nEarly stopping triggered! No improvement for {patience} epochs.")
                print(f"Best validation accuracy: {best_val_acc:.4f} at epoch {best_epoch}")
                break

    # Record this fold's results. The hyperparameters are stored as a copy so
    # later folds cannot alter what was recorded here.
    fold_metrics = {
        'fold': fold,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc,
        'best_epoch': best_epoch,
        'hyperparams': dict(current_hyperparams)
    }
    fold_results.append(fold_metrics)

    # Keep fold 1's best metrics; they drive the hyperparameter adjustment above.
    if fold == 1:
        fold1_metrics = {
            'val_acc': best_val_acc,
            # best_epoch is 1-based; fall back to the last epoch if none was marked.
            'val_loss': val_losses[best_epoch - 1] if best_epoch else val_losses[-1]
        }

    # Visualize this fold's curves
    plt.figure(figsize=(15, 5))

    # Loss plot
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses, label='Val')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Training and Validation Loss - Fold {fold}')
    plt.legend()

    # Accuracy plot
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train')
    plt.plot(val_accs, label='Val')
    # The curves are plotted against 0-based indices while best_epoch is
    # 1-based, so shift the marker by one to land on the right point.
    plt.axvline(x=best_epoch - 1, color='r', linestyle='--', label=f'Best Epoch ({best_epoch})')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title(f'Training and Validation Accuracy - Fold {fold}')
    plt.legend()

    plt.tight_layout()
    plt.show()

# Final results
fold_accs = [result['best_val_acc'] for result in fold_results]
print(f"\n2-Fold Sonuçları:")
for i, result in enumerate(fold_results, 1):
    hp = result['hyperparams']
    print(f"Fold {i}: {result['best_val_acc']:.4f} (lr={hp['lr']}, wd={hp['weight_decay']}, dropout={hp['dropout']})")
print(f"Ortalama Acc: {np.mean(fold_accs):.4f} ± {np.std(fold_accs):.4f}")

# Final evaluation of the best fold's checkpoint
print("\nFinal Test Results on Test Set:")

# Pick the best fold (1-based)
best_fold = int(np.argmax(fold_accs)) + 1
print(f"Using best model from Fold {best_fold}")

# Evaluation samples: the held-out (validation) split of the best fold.
# Sampling at random from the full data set, as was done before, mixes in
# samples the selected model was trained on and inflates every metric.
# `skf` has a fixed random_state, so calling split() again reproduces the
# splits used during training.
# NOTE(review): this split was also used for checkpoint selection, so the
# numbers are still optimistic; a truly independent test set has to be set
# aside before cross-validation.
_, test_indices = list(skf.split(limited_paths, limited_labels))[best_fold - 1]
test_size = len(test_indices)
test_paths = limited_paths[test_indices]
test_labels = limited_labels[test_indices]

# Build the test dataset
test_ds = ADNI_Dataset(test_paths, test_labels, augment=False)

# DataLoader for the test set
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True
)

# Load the best checkpoint. The checkpoint holds the bare module's state_dict
# (no "module." prefix), so it must be loaded BEFORE wrapping the model in
# DataParallel; loading into the wrapper fails with missing/unexpected keys.
test_model = Inception3DNet_SE_Attention(3, dropout_rate=fold_results[best_fold-1]['hyperparams']['dropout']).to(DEVICE)
test_model.load_state_dict(torch.load(f'best_model_fold_{best_fold}.pth', map_location=DEVICE))
if torch.cuda.device_count() > 1:
    test_model = nn.DataParallel(test_model)
test_model.eval()

# Dedicated loss for evaluation instead of reusing whatever `criterion` the
# last training fold left behind (its class weights belong to that fold).
# Label smoothing is kept so the value stays comparable to the training logs.
test_criterion = CrossEntropyLoss(label_smoothing=0.1)

# Evaluate on the test set
all_test_targets = []
all_test_preds = []
all_test_probs = []
all_test_losses = []
test_batch_sizes = []

with torch.no_grad():
    for imgs, labels in tqdm(test_loader, desc="Testing on Test Set"):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        with autocast():
            logits = test_model(imgs)
            probs = F.softmax(logits, dim=1)
            loss = test_criterion(logits, labels)

        preds = logits.argmax(1)
        all_test_targets.extend(labels.cpu().numpy())
        all_test_preds.extend(preds.cpu().numpy())
        all_test_probs.extend(probs.float().cpu().numpy())
        all_test_losses.append(loss.item())
        test_batch_sizes.append(labels.size(0))

# Test-set metrics. Batch losses are weighted by batch size so a smaller final
# batch does not skew the mean.
test_loss = np.average(all_test_losses, weights=test_batch_sizes)
targets_arr = np.array(all_test_targets)
preds_arr = np.array(all_test_preds)
probs_arr = np.array(all_test_probs)
test_acc = np.mean(preds_arr == targets_arr)
print(f"\nTest Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

# Per-class metrics. Passing `labels` explicitly keeps the report and the
# confusion matrix at 3x3 even if a class is missing from the targets or
# predictions.
print("\nPer-class Metrics:")
class_names = ["CN", "MCI", "AD"]
label_ids = list(range(len(class_names)))
report_text = classification_report(
    all_test_targets, all_test_preds,
    labels=label_ids, target_names=class_names, zero_division=0
)
print(report_text)

# ROC curves and one-vs-rest AUC per class
auc_scores = []
plt.figure(figsize=(10, 8))
for i, label in enumerate(class_names):
    is_class = targets_arr == i
    if is_class.all() or not is_class.any():
        # AUC is undefined when only one outcome is present.
        auc_scores.append(float('nan'))
        continue
    auc = roc_auc_score(is_class, probs_arr[:, i])
    auc_scores.append(auc)
    fpr, tpr, _ = roc_curve(is_class, probs_arr[:, i])
    plt.plot(fpr, tpr, label=f'{label} (AUC = {auc:.3f})')

plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves - Test Set')
plt.legend()
plt.savefig('roc_curves.png')
plt.close()

# Confusion matrix
cm = confusion_matrix(all_test_targets, all_test_preds, labels=label_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    xticklabels=class_names,
    yticklabels=class_names,
    cmap='Blues'
)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Test Set Confusion Matrix")
plt.savefig('confusion_matrix.png')
plt.close()

# Per-class accuracy (recall); classes without samples get 0 instead of NaN.
class_support = cm.sum(axis=1)
class_accuracies = np.divide(
    cm.diagonal(), class_support,
    out=np.zeros(len(class_names), dtype=float), where=class_support > 0
)
plt.figure(figsize=(10, 6))
plt.bar(class_names, class_accuracies)
plt.title('Per-class Accuracy')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
for i, acc in enumerate(class_accuracies):
    plt.text(i, acc + 0.02, f'{acc:.3f}', ha='center')
plt.savefig('per_class_accuracy.png')
plt.close()

# Learning curves for the best fold
best_fold_metrics = fold_results[best_fold-1]
plt.figure(figsize=(15, 5))

# Loss plot
plt.subplot(1, 2, 1)
plt.plot(best_fold_metrics['train_losses'], label='Train')
plt.plot(best_fold_metrics['val_losses'], label='Val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title(f'Learning Curves - Best Fold ({best_fold})')
plt.legend()

# Accuracy plot
plt.subplot(1, 2, 2)
plt.plot(best_fold_metrics['train_accs'], label='Train')
plt.plot(best_fold_metrics['val_accs'], label='Val')
# Curves use 0-based x positions; best_epoch is 1-based, hence the shift.
plt.axvline(x=best_fold_metrics['best_epoch'] - 1, color='r', linestyle='--',
            label=f'Best Epoch ({best_fold_metrics["best_epoch"]})')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title(f'Accuracy Curves - Best Fold ({best_fold})')
plt.legend()

plt.tight_layout()
plt.savefig('best_fold_curves.png')
plt.close()

# Final report
print("\nFinal Test Results Summary:")
print(f"Best Fold: {best_fold}")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
print("\nPer-class AUC Scores:")
for label, auc in zip(class_names, auc_scores):
    print(f"{label}: {auc:.4f}")
print("\nPer-class Accuracies:")
for label, acc in zip(class_names, class_accuracies):
    print(f"{label}: {acc:.4f}")

# Save the results together with the hyperparameters of the best fold
with open('test_results.txt', 'w', encoding='utf-8') as f:
    f.write("Test Results Summary\n")
    f.write("===================\n\n")
    f.write(f"Best Fold: {best_fold}\n")
    f.write(f"Test Loss: {test_loss:.4f}\n")
    f.write(f"Test Accuracy: {test_acc:.4f}\n\n")
    f.write("Hyperparameters:\n")
    hp = fold_results[best_fold-1]['hyperparams']
    f.write(f"Learning Rate: {hp['lr']}\n")
    f.write(f"Weight Decay: {hp['weight_decay']}\n")
    f.write(f"Dropout: {hp['dropout']}\n\n")
    f.write("Per-class Metrics:\n")
    f.write(report_text)
    f.write("\nPer-class AUC Scores:\n")
    for label, auc in zip(class_names, auc_scores):
        f.write(f"{label}: {auc:.4f}\n")
    f.write("\nPer-class Accuracies:\n")
    for label, acc in zip(class_names, class_accuracies):
        f.write(f"{label}: {acc:.4f}\n")