**“Explainable Lung Cancer Classification via Hypergraph Neural Networks Modeling Inter-Nodule Relationships”**

In [None]:
"""
===============================================================================
CELL 1: MOUNT DRIVE & SETUP
===============================================================================
"""

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import sys
sys.path.append('/content/drive/MyDrive/lung_cancer_urop')

print("=" * 80)
print("‚úì Drive mounted and path added")
print("=" * 80)

In [None]:
"""
===============================================================================
CELL 2: INSTALL DEPENDENCIES
===============================================================================
"""

!pip install -q nibabel SimpleITK torch-geometric wandb

print("=" * 80)
print("‚úì Dependencies installed")
print("=" * 80)

In [None]:
"""
===============================================================================
CELL 3: IMPORT MODULES
===============================================================================
"""

import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

from src.config import ExperimentConfig
from src.utils import set_global_seed, setup_directories
from src.preprocessing import AdvancedPreprocessor
from src.hypergraph import HypergraphConstructor
from src.models import HypergraphNeuralNetwork
from src.dataset import LungNoduleHypergraphDataset, collate_hypergraph_batch
from src.early_stopping import EarlyStopping
from src.trainer import HGNNTrainer
from src.visualization import ResultsVisualizer

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

print("=" * 80)
print("‚úì ALL MODULES IMPORTED SUCCESSFULLY")
print("=" * 80)
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print("=" * 80)

In [None]:
"""
===============================================================================
CELL 4: INITIALIZE CONFIGURATION
===============================================================================
"""

config = ExperimentConfig()
set_global_seed(config.random_seed)

print("=" * 80)
print("CONFIGURATION INITIALIZED")
print("=" * 80)
print(f"Experiment: {config.experiment_name}")
print(f"Random Seed: {config.random_seed}")
print(f"Num Patients: {config.num_patients}")
print(f"Batch Size: {config.batch_size}")
print(f"Learning Rate: {config.learning_rate}")
print(f"Num Epochs: {config.num_epochs}")
print("=" * 80)

In [None]:
"""
===============================================================================
CELL 5: SETUP DIRECTORIES & LOAD DATA
===============================================================================
"""

BASE_PATH = Path(config.base_path)
SUBSET_PATH = BASE_PATH / config.subset_name
ANNOTATIONS_PATH = BASE_PATH / config.annotations_file
METADATA_PATH = BASE_PATH / config.metadata_file

OUTPUT_PATH, MODELS_PATH, RESULTS_PATH, LOGS_PATH, CONFIG_PATH = setup_directories(
    BASE_PATH, config.experiment_name
)

print("=" * 80)
print("DIRECTORY STRUCTURE")
print("=" * 80)
print(f"‚úì Base Path: {BASE_PATH}")
print(f"‚úì Output Path: {OUTPUT_PATH}")
print("=" * 80)

# Verify
assert SUBSET_PATH.exists(), f"Subset path not found: {SUBSET_PATH}"
assert ANNOTATIONS_PATH.exists(), f"Annotations not found: {ANNOTATIONS_PATH}"
assert METADATA_PATH.exists(), f"Metadata not found: {METADATA_PATH}"

# Load CSVs
annotations_df = pd.read_csv(ANNOTATIONS_PATH)
metadata_df = pd.read_csv(METADATA_PATH)

print(f"\n‚úì Annotations: {annotations_df.shape}")
print(f"‚úì Metadata: {metadata_df.shape}")

# Collect the patient volumes (*.nii.gz) from every configured subset folder.
patient_files = []
subsets_to_load = ["subset01", "subset02", "subset03"]  # Add "subset04" later

# Patient-number range reported for each known subset (used for display only).
SUBSET_PATIENT_RANGES = {
    "subset01": "1-160",
    "subset02": "161-320",
    "subset03": "321-480",
}

print("\nLOADING SUBSETS:")
print("-" * 60)

for subset_name in subsets_to_load:
    subset_path = BASE_PATH / subset_name

    # Report a missing folder and move on to the next subset.
    if not subset_path.exists():
        print(f"‚ùå {subset_name}: NOT FOUND at {subset_path}")
        continue

    # Sorted so the file order is stable across runs.
    files = sorted(subset_path.glob("*.nii.gz"))
    patient_files += files

    patient_range = SUBSET_PATIENT_RANGES.get(subset_name, "unknown")
    print(f"‚úì {subset_name}: {len(files)} files (Patients {patient_range})")

print("-" * 60)
print(f"‚úì TOTAL LOADED: {len(patient_files)} patient files")

# Apply num_patients limit if needed (optional)
if config.num_patients < len(patient_files):
    patient_files = patient_files[:config.num_patients]
    print(f"‚ö†Ô∏è Limited to first {config.num_patients} patients (config setting)")

# Save config
config.save_config(str(CONFIG_PATH / "experiment_config.yaml"))
print(f"‚úì Config saved")
print("=" * 80)

In [None]:
"""
===============================================================================
CELL 6: INITIALIZE EXPERIMENT TRACKING
===============================================================================
"""

# Use local Colab storage instead of Drive for TensorBoard
tensorboard_local_dir = Path('/content/logs/tensorboard')
tensorboard_local_dir.mkdir(parents=True, exist_ok=True)
tensorboard_writer = SummaryWriter(log_dir=str(tensorboard_local_dir))

print("=" * 80)
print("EXPERIMENT TRACKING")
print("=" * 80)
print(f"‚úì TensorBoard: {tensorboard_local_dir} (local - avoiding Drive disconnects)")

# Weights & Biases is optional: fall back to TensorBoard-only logging when the
# package is missing, the login fails, or the run cannot be created.
try:
    import wandb
    wandb.login()
    wandb.init(
        project=config.project_name,
        name=config.experiment_name,
        config=vars(config),
        tags=["hgnn", "lung-cancer", "baseline"]
    )
    print(f"‚úì W&B: {wandb.run.url}")
    USING_WANDB = True
except Exception as wandb_error:
    # `except Exception` instead of a bare `except` so that a keyboard/kernel
    # interrupt still propagates; the reason is printed to make failures diagnosable.
    print("‚ö† W&B not available - using TensorBoard only")
    print(f"  Reason: {type(wandb_error).__name__}: {wandb_error}")
    USING_WANDB = False

print("=" * 80)

In [None]:
"""
===============================================================================
CELL 7: INITIALIZE PREPROCESSING & HYPERGRAPH
===============================================================================
"""

preprocessor = AdvancedPreprocessor(
    target_spacing=config.target_spacing,
    target_size=config.patch_size
)

hypergraph_constructor = HypergraphConstructor(
    k_neighbors=config.k_neighbors,
    spatial_threshold=config.spatial_threshold,
    feature_threshold=config.feature_similarity_threshold
)

print("=" * 80)
print("PREPROCESSING & HYPERGRAPH INITIALIZED")
print("=" * 80)
print(f"‚úì Target spacing: {config.target_spacing}")
print(f"‚úì Patch size: {config.patch_size}")
print(f"‚úì k-neighbors: {config.k_neighbors}")
print("=" * 80)

In [None]:
"""
=======================================================
CELL 7.5: VISUALIZE PREPROCESSING
=======================================================
"""
from pathlib import Path

# Subset folders on Drive (not used further in this cell; kept for manual inspection).
subset01 = Path("/content/drive/MyDrive/duke_lung_data/subset01")
subset02 = Path("/content/drive/MyDrive/duke_lung_data/subset02")
# Fixed folder name: the other subsets and the loading cell use the zero-padded "subset03".
subset03 = Path("/content/drive/MyDrive/duke_lung_data/subset03")

print("Total patient files:", len(patient_files))

# Every collected file must still be reachable; name the first few offenders if not.
missing_files = [p for p in patient_files if not p.exists()]
assert not missing_files, f"Missing patient files (first 5): {missing_files[:5]}"


In [None]:
"""
===============================================================================
CELL 8: ROBUST DATASET & DATALOADERS (COLAB-FRIENDLY)
===============================================================================
"""
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
import torch

print("=" * 80)
print("CREATING ROBUST DATASET")
print("=" * 80)

# Define collate function inline (no import needed)
def collate_hypergraph_batch(batch):
    """Collate a list of graph samples, skipping entries that are None.

    Returns None when no usable sample remains; otherwise the surviving
    samples merged with ``Batch.from_data_list``.
    """
    usable = list(filter(lambda item: item is not None, batch))
    if not usable:
        return None
    return Batch.from_data_list(usable)

# Create dataset
# NOTE(review): augment=True is set on the single dataset that is later split into
# train AND validation subsets, so validation samples appear to be augmented too —
# confirm whether LungNoduleHypergraphDataset is meant to be used that way.
dataset = LungNoduleHypergraphDataset(
    patient_files=patient_files,
    annotations_df=annotations_df,
    preprocessor=preprocessor,
    hypergraph_constructor=hypergraph_constructor,
    augment=True
)

# Filter valid samples
# Every index is loaded once; indices that return None or raise are excluded.
print("\nüîç Filtering valid samples...")
valid_indices = []
error_count = 0  # counts both None results and raised exceptions

for idx in range(len(dataset)):
    try:
        data = dataset[idx]
        if data is not None:
            valid_indices.append(idx)
        else:
            error_count += 1
    except Exception as e:
        error_count += 1
        # The threshold uses the combined counter, so earlier None results also
        # use up the "first 5" printing budget.
        if error_count <= 5:  # Print first 5 errors only
            print(f"‚ö†Ô∏è Sample {idx} failed: {str(e)[:100]}")

print(f"‚úì Valid samples: {len(valid_indices)} / {len(dataset)}")
print(f"‚úó Invalid/missing samples: {error_count}")

# Abort early: nothing downstream can work without at least one usable sample.
if len(valid_indices) == 0:
    raise RuntimeError("‚ùå No valid samples found! Check your data paths and files.")

# Filtered dataset wrapper
class FilteredDataset(torch.utils.data.Dataset):
    """View of ``base_dataset`` restricted to the positions in ``valid_indices``."""

    def __init__(self, base_dataset, valid_indices):
        self.base_dataset, self.valid_indices = base_dataset, valid_indices

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return the base-dataset sample stored at the idx-th retained position."""
        source_position = self.valid_indices[idx]
        return self.base_dataset[source_position]

filtered_dataset = FilteredDataset(dataset, valid_indices)

# Split into train/validation with a seeded generator so the partition is reproducible.
train_size = int(config.train_split * len(filtered_dataset))
val_size = len(filtered_dataset) - train_size

generator = torch.Generator().manual_seed(config.random_seed)
train_dataset, val_dataset = torch.utils.data.random_split(
    filtered_dataset, [train_size, val_size], generator=generator
)

print(f"\n‚úì Train samples: {train_size}")
print(f"‚úì Val samples: {val_size}")

# DataLoaders with custom collate.
# The name `DataLoader` in this cell refers to torch_geometric's loader, which
# discards any `collate_fn` it is given and always uses its own collater, so the
# None-filtering in collate_hypergraph_batch would never run. The plain PyTorch
# loader is used explicitly so the custom collate function is honoured.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
    collate_fn=collate_hypergraph_batch
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,  # keep validation order fixed
    num_workers=0,
    pin_memory=False,
    collate_fn=collate_hypergraph_batch
)

print(f"‚úì Train batches: {len(train_loader)}")
print(f"‚úì Val batches: {len(val_loader)}")
print("=" * 80)

In [None]:
"""
================================================================================
CELL 8.5: FORCE CLASS BALANCE (OVERSAMPLING) - ROBUST VERSION
Run this RIGHT AFTER loading your dataset (Cell 8) and BEFORE training (Cell 11).
================================================================================
"""
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from collections import Counter

print("=" * 80)
print("üöë APPLYING ROBUST OVERSAMPLING (Wrapper Method)")
print("=" * 80)

# 1. Extract Labels safely
# We iterate over the train_dataset as-is, treating it as a black box.
num_samples = len(train_dataset)
all_indices = list(range(num_samples))
train_labels = []

print(f"Scanning {num_samples} patients for labels...")

for i in all_indices:
    try:
        # Get the data object (graph)
        data = train_dataset[i]

        # Extract label safely (handle both 0-d and 1-d tensors)
        if hasattr(data, 'y'):
            label_val = data.y.item() if data.y.numel() == 1 else data.y[0].item()
            train_labels.append(label_val)
        else:
            # Fallback for weird edge cases
            train_labels.append(0)

    except Exception as e:
        print(f"‚ö†Ô∏è Warning: Could not read label for index {i}: {e}")
        train_labels.append(0) # Assume Benign if broken

train_labels = np.array(train_labels)

# 2. Check Distribution
unique, counts = np.unique(train_labels, return_counts=True)
print(f"\nOriginal Distribution:")
for cls, c in zip(unique, counts):
    print(f"  Class {cls}: {c} patients")

# 3. Calculate INDICES to repeat
# Each class is repeated until it reaches the size of the largest class.
class_counts = Counter(train_labels)
if class_counts:
    max_count = max(class_counts.values())
else:
    max_count = 0
    print("‚ùå CRITICAL: No labels found.")

final_indices = []

print(f"\nTargeting {max_count} samples per class...")

for cls in unique:
    # Positions (within train_dataset) of the samples carrying this label.
    cls_indices = np.flatnonzero(train_labels == cls).tolist()
    if not cls_indices:
        continue

    # Whole repetitions plus a partial one give exactly max_count entries.
    n_current = len(cls_indices)
    n_repeat, n_remainder = divmod(max_count, n_current)

    final_indices.extend(cls_indices * n_repeat)
    # The remainder is filled with the FIRST n_remainder indices (deterministic).
    final_indices.extend(cls_indices[:n_remainder])

    n_after = n_current * n_repeat + n_remainder
    print(f"  Class {cls}: Oversampled from {n_current} -> {n_after}")

# 4. Create Safe Balanced Loader
if len(final_indices) > 0:
    # üö® THE FIX: Wrap 'train_dataset' DIRECTLY. Do not use .dataset
    # This creates a "Subset of a Subset", which preserves all previous filters/splits.
    balanced_train_dataset = Subset(train_dataset, final_indices)

    # Overwrite the loader
    train_loader = DataLoader(
        balanced_train_dataset,
        batch_size=train_loader.batch_size,
        shuffle=True,                       # Shuffle is mandatory here
        num_workers=0,
        collate_fn=collate_hypergraph_batch
    )

    print("\n‚úÖ SUCCESS: Balanced Train Loader Ready.")
    print(f"  Total Training Samples: {len(balanced_train_dataset)}")

    # 5. Sanity Check (Test one batch)
    try:
        test_batch = next(iter(train_loader))
        print(f"  ‚úì Sanity Check Passed: Loaded a batch of {test_batch.num_graphs} graphs.")
    except Exception as e:
        print(f"  ‚ùå Sanity Check Failed: {e}")

else:
    print("\n‚ùå FAILED: No indices generated.")

print("=" * 80)

In [None]:
"""
===============================================================================
CELL 9: INITIALIZE MODEL
===============================================================================
"""

# Get feature dimension from the first non-empty training sample.
# train_dataset is indexed directly instead of reaching into `.dataset` /
# `.indices`: for a torch Subset this returns the same item, and it keeps
# working if train_dataset is ever replaced by a dataset that is not a Subset.
sample_data = None
for i in range(len(train_dataset)):
    data = train_dataset[i]
    if data is not None:
        sample_data = data
        break

assert sample_data is not None, "No valid samples found!"

in_channels = sample_data.x.shape[1]

print("=" * 80)
print("MODEL INITIALIZATION")
print("=" * 80)
print(f"‚úì Input features: {in_channels}")

device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
print(f"‚úì Device: {device}")

model = HypergraphNeuralNetwork(
    in_channels=in_channels,
    hidden_channels=config.hidden_channels,
    num_classes=config.num_classes,
    num_layers=config.num_layers,
    dropout=config.dropout
)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"‚úì Total parameters: {total_params:,}")
print(f"‚úì Trainable parameters: {trainable_params:,}")

if USING_WANDB:
    wandb.config.update({
        'total_parameters': total_params,
        'trainable_parameters': trainable_params,
        'input_features': in_channels
    })

print("=" * 80)

In [None]:
"""
===============================================================================
CELL 10: TRAINING SETUP (WITH MODULE RELOAD)
===============================================================================
"""

# Reload config to pick up changes
import importlib
import src.config
importlib.reload(src.config)
from src.config import ExperimentConfig

# Recreate config with new attributes
config = ExperimentConfig()

# Now initialize trainer
trainer = HGNNTrainer(
    model=model,
    device=device,
    output_dir=MODELS_PATH,
    patience=config.patience,
    config=config
)

trainer.setup_training(
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)

print("=" * 80)
print("TRAINER INITIALIZED")
print("=" * 80)
print(f"‚úì Optimizer: AdamW")
print(f"‚úì Learning rate: {config.learning_rate}")
print(f"‚úì LR factor: {config.lr_factor}")
print(f"‚úì LR patience: {config.lr_patience}")
print(f"‚úì Mixed precision: {config.use_mixed_precision}")
print(f"‚úì Early stopping patience: {config.patience}")
print("=" * 80)

In [None]:
"""
===============================================================================
FINAL TRAINING CELL: PROFESSIONAL SAMPLER + SAFE WEIGHTS  11
===============================================================================
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import copy

print("=" * 80)
print("üß™ PROFESSIONAL RUN: WEIGHTED SAMPLER + SAFE WEIGHTS")
print("=" * 80)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# ==============================================================================
# 1. EXTRACT TRAINING LABELS FOR SAMPLER
# ==============================================================================
print("\n‚öñÔ∏è Configuring Weighted Sampler...")

# Labels of the loadable training samples and their positions in train_dataset.
train_labels = []
valid_train_indices = []

for i in range(len(train_dataset)):
    data = train_dataset[i]
    if data is not None:
        train_labels.append(data.y.item())
        valid_train_indices.append(i)

if len(train_labels) == 0:
    raise ValueError("‚ùå No valid training samples found!")

class_counts = Counter(train_labels)
n_benign = class_counts[0]
n_cancer = class_counts[1]

# Guard the ratio: a split without cancer cases must not crash with ZeroDivisionError.
ratio_text = f"{n_benign/n_cancer:.2f}:1" if n_cancer else "n/a (no cancer samples)"

print(f"Training Distribution:")
print(f"  - Benign: {n_benign}")
print(f"  - Cancer: {n_cancer}")
print(f"  - Ratio: {ratio_text}")

# ==============================================================================
# 2. CREATE WEIGHTED SAMPLER
# ==============================================================================
# Inverse-frequency weights; a class that is absent gets weight 0 (it has no samples).
weight_benign = 1.0 / n_benign if n_benign else 0.0
weight_cancer = 1.0 / n_cancer if n_cancer else 0.0

# One weight per POSITION in train_dataset, because the sampler draws dataset
# positions. Positions whose sample could not be loaded keep weight 0 and are
# therefore never drawn; this also keeps weights aligned when some samples are None.
samples_weight = torch.zeros(len(train_dataset), dtype=torch.double)
for position, label in zip(valid_train_indices, train_labels):
    samples_weight[position] = weight_cancer if label == 1 else weight_benign

# Create sampler (draws with replacement, one epoch = number of valid samples)
sampler = WeightedRandomSampler(
    weights=samples_weight,
    num_samples=len(train_labels),
    replacement=True
)

print(f"‚úÖ Sampler configured for {len(train_labels)} samples")

# ==============================================================================
# 3. CREATE DATA LOADERS
# ==============================================================================
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    sampler=sampler,  # ‚Üê This replaces shuffle=True
    collate_fn=collate_hypergraph_batch,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=collate_hypergraph_batch,
    num_workers=2,
    pin_memory=True
)

print(f"‚úÖ Loaders created")
print(f"  - Train batches: {len(train_loader)}")
print(f"  - Val batches: {len(val_loader)}")

# ==============================================================================
# 4. INITIALIZE MODEL
# ==============================================================================
print("\nüèóÔ∏è Initializing Model...")

# Get feature dimensions from first batch; fall back to 128 if no batch can be read.
try:
    sample = next(iter(train_loader))
    in_feats = sample.x.shape[1]
    print(f"Input features: {in_feats}")
except Exception as e:
    print(f"‚ö†Ô∏è Could not detect features, using default 128")
    # Show why detection failed so a wrong default is not silently used.
    print(f"   Reason: {type(e).__name__}: {e}")
    in_feats = 128

model = HypergraphNeuralNetwork(
    in_channels=in_feats,
    hidden_channels=256,
    num_classes=2,
    dropout=0.5
).to(device)

print(f"‚úÖ Model initialized")
print(f"  - Parameters: {sum(p.numel() for p in model.parameters()):,}")

# ==============================================================================
# 5. SETUP LOSS & OPTIMIZER (SAFE WEIGHTS)
# ==============================================================================
# The weighted sampler already balances the classes, so the loss only gets a
# gentle nudge towards the cancer class instead of a large class weight.
BENIGN_LOSS_WEIGHT = 1.0
CANCER_LOSS_WEIGHT = 1.1
LEARNING_RATE = 1e-4

weights = torch.tensor([BENIGN_LOSS_WEIGHT, CANCER_LOSS_WEIGHT]).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)

optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-2
)

# mode='max' because the scheduler is stepped with a score where higher is better.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=3
)

# The summary is printed from the constants above so it cannot drift from the
# values that are actually used (it previously reported 1.5 and 1e-5).
print(f"\n‚úÖ Training Setup:")
print(f"  - Loss weights: Benign={BENIGN_LOSS_WEIGHT}, Cancer={CANCER_LOSS_WEIGHT}")
print(f"  - Learning rate: {LEARNING_RATE:g}")
print(f"  - Optimizer: AdamW")
print("=" * 80)

# ==============================================================================
# 6. TRAINING LOOP
# ==============================================================================
num_epochs = 50
best_f1 = 0.0
best_model_state = None

print(f"\nüöÄ Training for {num_epochs} epochs...")
print("=" * 80)

for epoch in range(1, num_epochs + 1):
    # ============= TRAINING =============
    model.train()
    train_loss = 0.0
    n_train_batches = 0  # batches actually optimised (the collate may yield None)

    for batch in train_loader:
        if batch is None:
            continue

        batch = batch.to(device)
        optimizer.zero_grad()

        # Prefer the (x, edge_index, batch) call; fall back to the two-argument
        # form only when the model rejects the extra argument. Catching just
        # TypeError keeps real failures (OOM, shape errors, Ctrl-C) visible.
        try:
            out = model(batch.x, batch.edge_index, batch.batch)
        except TypeError:
            out = model(batch.x, batch.edge_index)

        # Some models return (logits, extras); only the first element is used.
        if isinstance(out, tuple):
            out = out[0]

        loss = criterion(out, batch.y)
        loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        train_loss += loss.item()
        n_train_batches += 1

    # Average over the batches that contributed a loss; skipped batches would
    # otherwise deflate the mean, and an all-skipped epoch must not divide by 0.
    avg_train_loss = train_loss / max(n_train_batches, 1)

    # ============= VALIDATION =============
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in val_loader:
            if batch is None:
                continue

            batch = batch.to(device)

            try:
                out = model(batch.x, batch.edge_index, batch.batch)
            except TypeError:
                out = model(batch.x, batch.edge_index)

            if isinstance(out, tuple):
                out = out[0]

            preds = out.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch.y.cpu().numpy())

    # ============= METRICS =============
    if len(all_preds) > 0:
        # labels=[0, 1] guarantees a '1' entry in the report even when the cancer
        # class is missing from both labels and predictions in this epoch.
        report = classification_report(
            all_labels,
            all_preds,
            labels=[0, 1],
            output_dict=True,
            zero_division=0
        )

        cancer_f1 = report['1']['f1-score']
        cancer_recall = report['1']['recall']
        cancer_precision = report['1']['precision']
        # Accuracy is computed directly: with explicit `labels` the report may
        # expose a 'micro avg' entry instead of 'accuracy'.
        accuracy = float(np.mean(np.asarray(all_labels) == np.asarray(all_preds)))

        # Update scheduler
        scheduler.step(cancer_f1)

        # Print progress
        print(f"Epoch {epoch:02d}/{num_epochs} | "
              f"Loss: {avg_train_loss:.4f} | "
              f"Acc: {accuracy:.3f} | "
              f"Recall: {cancer_recall:.3f} | "
              f"Prec: {cancer_precision:.3f} | "
              f"F1: {cancer_f1:.3f}")

        # Save best model (ties keep the most recent epoch)
        if cancer_f1 >= best_f1:
            best_f1 = cancer_f1
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"  ‚úì Best F1: {best_f1:.3f}")

print("=" * 80)
print("‚úÖ Training Complete!")
print("=" * 80)

# ==============================================================================
# 7. LOAD BEST MODEL & FINAL EVALUATION
# ==============================================================================
if best_model_state is not None:
    # Restore the weights of the best validation epoch before saving/evaluating.
    model.load_state_dict(best_model_state)

    # Save model
    save_path = str(MODELS_PATH / "best_model_professional.pth")
    torch.save(model.state_dict(), save_path)
    print(f"\n‚úÖ Best model saved: {save_path}")

    # Final evaluation
    print("\n" + "=" * 80)
    print("üìä FINAL EVALUATION ON VALIDATION SET")
    print("=" * 80)

    model.eval()
    final_preds = []
    final_labels = []

    with torch.no_grad():
        for batch in val_loader:
            if batch is None:
                continue

            batch = batch.to(device)

            # Fall back to the two-argument call only when the model rejects the
            # third argument; other errors must surface instead of being retried.
            try:
                out = model(batch.x, batch.edge_index, batch.batch)
            except TypeError:
                out = model(batch.x, batch.edge_index)

            if isinstance(out, tuple):
                out = out[0]

            final_preds.extend(out.argmax(dim=1).cpu().numpy())
            final_labels.extend(batch.y.cpu().numpy())

    # Classification Report
    # labels=[0, 1] keeps the two target names valid even if only one class
    # occurs in the validation labels/predictions.
    print("\nClassification Report:")
    print(classification_report(
        final_labels,
        final_preds,
        labels=[0, 1],
        target_names=['Benign', 'Cancer'],
        zero_division=0
    ))

    # Confusion Matrix (always 2x2 so the tick labels match the cells)
    cm = confusion_matrix(final_labels, final_preds, labels=[0, 1])

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Greens',
        cbar=False,
        xticklabels=['Benign', 'Cancer'],
        yticklabels=['Benign', 'Cancer'],
        annot_kws={'size': 16}
    )
    plt.xlabel('Predicted', fontsize=12)
    plt.ylabel('Actual', fontsize=12)
    plt.title('Professional Model - Confusion Matrix', fontsize=14)
    plt.tight_layout()
    plt.show()

    print("=" * 80)
    print("üéâ PROFESSIONAL TRAINING COMPLETE!")
    print("=" * 80)
else:
    print("\n‚ö†Ô∏è No model was saved (no improvement detected)")

print("\nüí° Next Step: Apply threshold tuning (0.46) for even better recall!")

In [None]:
"""
FEATURE QUALITY DIAGNOSTIC
"""
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

print("üî¨ ANALYZING FEATURE QUALITY...")

# Collect features from validation set: one feature row per node, each node
# tagged with the label of the graph it belongs to.
all_features = []
all_labels = []

for i in range(len(val_dataset)):
    data = val_dataset[i]
    if data is not None:
        all_features.append(data.x.cpu().numpy())
        all_labels.extend([data.y.item()] * len(data.x))

# Fail with a clear message instead of np.vstack's generic empty-input error.
if not all_features:
    raise ValueError("No valid validation samples - cannot analyze features")

# Flatten to (n_samples, n_features)
features_flat = np.vstack(all_features)
labels_flat = np.array(all_labels)

print(f"Total nodules: {len(labels_flat)}")
print(f"  Benign: {np.count_nonzero(labels_flat == 0)}")
print(f"  Cancer: {np.count_nonzero(labels_flat == 1)}")

# Check for zero-variance features (columns that carry no information)
feature_stds = features_flat.std(axis=0)
zero_var = np.count_nonzero(feature_stds < 1e-6)
print(f"\n‚ö†Ô∏è Zero-variance features: {zero_var}/{features_flat.shape[1]}")

# Check for NaN/Inf
nan_count = np.isnan(features_flat).sum()
inf_count = np.isinf(features_flat).sum()
print(f"‚ö†Ô∏è NaN values: {nan_count}")
print(f"‚ö†Ô∏è Inf values: {inf_count}")

# PCA visualization: project onto the first two principal components
pca = PCA(n_components=2)
features_2d = pca.fit_transform(features_flat)

benign_mask = labels_flat == 0
cancer_mask = labels_flat == 1

plt.figure(figsize=(10, 6))
plt.scatter(features_2d[benign_mask, 0], features_2d[benign_mask, 1],
            alpha=0.5, label='Benign', s=30, c='blue')
plt.scatter(features_2d[cancer_mask, 0], features_2d[cancer_mask, 1],
            alpha=0.7, label='Cancer', s=50, c='red', marker='^')
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.title('Feature Space Visualization (PCA)')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

print(f"\nüìä PCA explained variance: {pca.explained_variance_ratio_.sum():.2%}")

In [None]:
#diagnosti
"""
DEBUG NODULE EXTRACTION
"""
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt

# Get a cancer case
cancer_row = annotations_df[annotations_df['Malignant_lbl'] == 1].iloc[0]
patient_id = str(cancer_row['patient-id']).strip()

print(f"Testing Cancer Patient: {patient_id}")
print(f"World coords: X={cancer_row['coordX']}, Y={cancer_row['coordY']}, Z={cancer_row['coordZ']}")
print(f"Diameter: {cancer_row['w']}mm")

# Find the file
patient_file = None
for f in patient_files:
    file_id = f.stem.replace('.nii', '')
    if file_id == patient_id:
        patient_file = f
        break

if patient_file is None:
    print(f"‚ùå File not found for {patient_id}")
else:
    print(f"‚úì Found file: {patient_file.name}")

    # Load image
    image_sitk = sitk.ReadImage(str(patient_file))
    image_array = sitk.GetArrayFromImage(image_sitk)

    print(f"\nImage shape: {image_array.shape} (Z, Y, X)")

    # Try coordinate transform
    try:
        point_world = (float(cancer_row['coordX']),
                       float(cancer_row['coordY']),
                       float(cancer_row['coordZ']))

        idx_voxel = image_sitk.TransformPhysicalPointToIndex(point_world)
        x_voxel, y_voxel, z_voxel = idx_voxel

        print(f"\nTransformed to voxel coords:")
        print(f"  X: {x_voxel} (max: {image_array.shape[2]-1})")
        print(f"  Y: {y_voxel} (max: {image_array.shape[1]-1})")
        print(f"  Z: {z_voxel} (max: {image_array.shape[0]-1})")

        # Check if within bounds
        if (0 <= x_voxel < image_array.shape[2] and
            0 <= y_voxel < image_array.shape[1] and
            0 <= z_voxel < image_array.shape[0]):

            print("‚úì Coordinates within bounds")

            # Extract a patch of up to 32x32x32 voxels centred on the nodule.
            # `patch_size` is the HALF-width: 16 voxels are taken on each side of
            # the centre. The bounds are clamped to the volume, so a nodule near
            # the border yields a smaller patch.
            patch_size = 16
            z_start = max(0, z_voxel - patch_size)
            z_end = min(image_array.shape[0], z_voxel + patch_size)
            y_start = max(0, y_voxel - patch_size)
            y_end = min(image_array.shape[1], y_voxel + patch_size)
            x_start = max(0, x_voxel - patch_size)
            x_end = min(image_array.shape[2], x_voxel + patch_size)

            # image_array is indexed (Z, Y, X), matching the shape printed above.
            patch = image_array[z_start:z_end, y_start:y_end, x_start:x_end]

            print(f"\nExtracted patch shape: {patch.shape}")
            print(f"Patch HU range: [{patch.min():.1f}, {patch.max():.1f}]")
            print(f"Patch mean: {patch.mean():.1f}")
            print(f"Patch std: {patch.std():.1f}")

            # Visualize center slice
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))

            # Axial slice (XY plane at nodule center)
            axes[0].imshow(image_array[z_voxel, :, :], cmap='gray', vmin=-1000, vmax=400)
            axes[0].plot(x_voxel, y_voxel, 'r+', markersize=15, markeredgewidth=2)
            axes[0].set_title(f'Axial Slice (Z={z_voxel})')
            axes[0].set_xlabel('X')
            axes[0].set_ylabel('Y')

            # Coronal slice (XZ plane)
            axes[1].imshow(image_array[:, y_voxel, :], cmap='gray', vmin=-1000, vmax=400)
            axes[1].plot(x_voxel, z_voxel, 'r+', markersize=15, markeredgewidth=2)
            axes[1].set_title(f'Coronal Slice (Y={y_voxel})')
            axes[1].set_xlabel('X')
            axes[1].set_ylabel('Z')

            # Sagittal slice (YZ plane)
            axes[2].imshow(image_array[:, :, x_voxel], cmap='gray', vmin=-1000, vmax=400)
            axes[2].plot(y_voxel, z_voxel, 'r+', markersize=15, markeredgewidth=2)
            axes[2].set_title(f'Sagittal Slice (X={x_voxel})')
            axes[2].set_xlabel('Y')
            axes[2].set_ylabel('Z')

            plt.tight_layout()
            plt.show()

            # Check if the patch looks like a nodule
            if patch.std() < 50:
                print("\n‚ùå WARNING: Patch has very low variation - might be empty space!")
            elif patch.mean() < -500:
                print("\n‚ùå WARNING: Patch is mostly air (mean HU < -500)")
            else:
                print("\n‚úì Patch looks reasonable")

        else:
            print("‚ùå Coordinates OUT OF BOUNDS!")

    except Exception as e:
        print(f"‚ùå Transform failed: {e}")

In [None]:
"""
CHECK ANNOTATION VALIDITY diagnostic
"""
# Spot-check the first few malignant annotations: does the annotated centre
# land on tissue-like intensities in the corresponding volume?
print("üîç Checking multiple cancer cases...\n")

cancer_cases = annotations_df[annotations_df['Malignant_lbl'] == 1].head(5)

for idx, row in cancer_cases.iterrows():
    patient_id = str(row['patient-id']).strip()
    diameter = row['w']

    print(f"Patient {patient_id}: Diameter {diameter:.2f}mm")

    # Find file: "<id>.nii.gz" has stem "<id>.nii", so strip the inner suffix too.
    patient_file = None
    for f in patient_files:
        if f.stem.replace('.nii', '') == patient_id:
            patient_file = f
            break

    if patient_file is None:
        # Previously silent; report it so a missing scan is not mistaken for success.
        print(f"  ‚ùå File not found for {patient_id}")
    else:
        image_sitk = sitk.ReadImage(str(patient_file))
        image_array = sitk.GetArrayFromImage(image_sitk)

        try:
            point_world = (float(row['coordX']), float(row['coordY']), float(row['coordZ']))
            idx_voxel = image_sitk.TransformPhysicalPointToIndex(point_world)
            x_v, y_v, z_v = idx_voxel

            # The transform can return indices outside the volume. Negative ones
            # would silently wrap around under numpy indexing and read the wrong
            # voxel, so reject anything out of bounds explicitly.
            in_bounds = (0 <= x_v < image_array.shape[2] and
                         0 <= y_v < image_array.shape[1] and
                         0 <= z_v < image_array.shape[0])

            if not in_bounds:
                print("  ‚ùå Coordinates OUT OF BOUNDS!")
            else:
                # Check HU value at exact coordinate (array is indexed Z, Y, X)
                hu_at_center = image_array[z_v, y_v, x_v]

                # Check 5x5x5 region around it (the lower bound is clamped; slicing
                # truncates the upper bound at the volume edge by itself)
                patch_3d = image_array[
                    max(0, z_v-2):z_v+3,
                    max(0, y_v-2):y_v+3,
                    max(0, x_v-2):x_v+3
                ]

                print(f"  HU at center: {hu_at_center:.0f}")
                print(f"  5x5x5 region mean: {patch_3d.mean():.0f}")

                # Nodules should be -400 to +100 HU range
                if hu_at_center < -500:
                    print(f"  ‚ùå CENTER IS AIR!")
                elif -500 <= hu_at_center <= 100:
                    print(f"  ‚úì Looks like tissue/nodule")
                else:
                    print(f"  ‚ö†Ô∏è Unusual HU value")

        except Exception as e:
            print(f"  ‚ùå Error: {e}")

    print()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy
from pathlib import Path
import os
import numpy as np
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

print("üß™ STARTING FINE-TUNING EXPERIMENT (Weight: 3.75)...")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _forward_logits(net, batch):
    """Return the class logits ``net`` produces for one batch.

    The three-argument call (with ``batch.batch``) is tried first and the
    two-argument call is the fallback, because the forward signature of the
    model living in the session is not known here.  A tuple result is reduced
    to its first element, which the rest of this cell treats as the logits.
    """
    try:
        out = net(batch.x, batch.edge_index, batch.batch)
    except Exception:
        # Was a bare ``except:``, which also trapped KeyboardInterrupt, so an
        # interrupt during the forward pass was swallowed and the pass retried.
        out = net(batch.x, batch.edge_index)
    return out[0] if isinstance(out, tuple) else out


def _predict(net, loader):
    """Evaluate ``net`` on ``loader``; return ``(predictions, labels)`` lists."""
    net.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            preds.extend(_forward_logits(net, batch).argmax(dim=1).cpu().numpy())
            labels.extend(batch.y.cpu().numpy())
    return preds, labels


# ==============================================================================
# 1. LOAD MODEL (Reset to Original)
# ==============================================================================
target_path = Path("/content/drive/MyDrive/duke_lung_data/outputs/HGNN_LungCancer_MultiClass_v1.0/models/best_model.pth")
if not target_path.exists():
    # Fall back to the most recently modified best_model.pth under the data root.
    found = list(Path("/content/drive/MyDrive/duke_lung_data").rglob("best_model.pth"))
    best_model_path = max(found, key=os.path.getmtime) if found else Path("/content/checkpoints/best_model.pth")
else:
    best_model_path = target_path

print(f"üîÑ Resetting weights from: {best_model_path.name}")
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)

# The checkpoint is either a wrapper dict or a bare state dict; resolve it once
# so the primary and the fallback loading paths accept the same files.
state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint

try:
    model.load_state_dict(state_dict)
    model.to(device)
except Exception as load_err:
    # No usable ``model`` in the session (undefined or incompatible): rebuild it.
    print(f"‚ö†Ô∏è Could not load into the in-session model ({load_err}); rebuilding it")
    try:
        in_feats = dataset[0].x.shape[1]
    except Exception:
        in_feats = 128
    # Positional size arguments: the old ``num_features``/``hidden_dim`` keywords
    # do not match the ``in_channels``/``hidden_channels`` names the other cells
    # use for this class, so the recovery path itself raised TypeError.
    model = HypergraphNeuralNetwork(in_feats, 256, num_classes=2, dropout=0.5).to(device)
    model.load_state_dict(state_dict)

# ==============================================================================
# 2. THE CONFIGURATION (Weight 3.75)
# ==============================================================================
# A. Use STANDARD Loader (No Oversampling) to protect Accuracy
active_loader = train_loader
print(f"üìâ Using STANDARD Imbalanced Loader ({len(train_loader.dataset)} samples)")

# B. Use MANUAL Weight (3.75) - Testing the Edge
weights = torch.tensor([1.0, 3.75]).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)

optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-2)

# ==============================================================================
# 3. TRAINING LOOP
# ==============================================================================
print(f"\n‚ñ∂Ô∏è Fine-tuning for 20 Epochs with Weight 3.75...")
best_f1 = 0.0
final_model_state = None

for epoch in range(1, 21):
    model.train()
    for batch in active_loader: # Standard Loader
        batch = batch.to(device)
        optimizer.zero_grad()
        out = _forward_logits(model, batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()

    # Validation
    all_preds, all_labels = _predict(model, val_loader)

    report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
    cancer_f1 = report['1']['f1-score']
    cancer_recall = report['1']['recall']
    current_acc = report['accuracy']

    print(f"Epoch {epoch:02d} | Acc: {current_acc:.2f} | Recall: {cancer_recall:.2f} | F1: {cancer_f1:.2f}")

    # Save best F1 (ties go to the later epoch)
    if cancer_f1 >= best_f1:
        best_f1 = cancer_f1
        final_model_state = copy.deepcopy(model.state_dict())

# 4. Save & Plot
if final_model_state is not None:
    model.load_state_dict(final_model_state)
    save_path = "/content/drive/MyDrive/duke_lung_data/best_model_w3_75.pth"
    torch.save(model.state_dict(), save_path)
    print(f"\n‚úÖ Model (W=3.75) saved to: {save_path}")

    print("\nüìä GENERATING CONFUSION MATRIX (W=3.75)...")
    final_preds, final_labels = _predict(model, val_loader)

    cm = confusion_matrix(final_labels, final_preds)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                xticklabels=['Benign', 'Cancer'],
                yticklabels=['Benign', 'Cancer'])
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Weight 3.75 Model')
    plt.show()

    print(classification_report(final_labels, final_preds, target_names=['Benign', 'Cancer']))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy
from pathlib import Path
import os
import numpy as np
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

# Single source of truth for the cancer-class loss weight.  The banner and the
# plot title used to say "3.75" while the loss actually used 6; every label is
# now derived from this value so they cannot drift apart again.
CANCER_WEIGHT = 6.0

print(f"üß™ STARTING STRATEGY B TEST (Weight {CANCER_WEIGHT:g}, NO Oversampling)...")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _forward_logits(net, batch):
    """Return the class logits ``net`` produces for one batch.

    The three-argument call (with ``batch.batch``) is tried first and the
    two-argument call is the fallback, because the forward signature of the
    model living in the session is not known here.  A tuple result is reduced
    to its first element, which the rest of this cell treats as the logits.
    """
    try:
        out = net(batch.x, batch.edge_index, batch.batch)
    except Exception:
        # Was a bare ``except:``, which also trapped KeyboardInterrupt, so an
        # interrupt during the forward pass was swallowed and the pass retried.
        out = net(batch.x, batch.edge_index)
    return out[0] if isinstance(out, tuple) else out


def _predict(net, loader):
    """Evaluate ``net`` on ``loader``; return ``(predictions, labels)`` lists."""
    net.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            preds.extend(_forward_logits(net, batch).argmax(dim=1).cpu().numpy())
            labels.extend(batch.y.cpu().numpy())
    return preds, labels


# ==============================================================================
# 1. LOAD MODEL (Reset to Original)
# ==============================================================================
target_path = Path("/content/drive/MyDrive/duke_lung_data/outputs/HGNN_LungCancer_MultiClass_v1.0/models/best_model.pth")
if not target_path.exists():
    # Fall back to the most recently modified best_model.pth under the data root.
    found = list(Path("/content/drive/MyDrive/duke_lung_data").rglob("best_model.pth"))
    best_model_path = max(found, key=os.path.getmtime) if found else Path("/content/checkpoints/best_model.pth")
else:
    best_model_path = target_path

print(f"üîÑ Resetting weights from: {best_model_path.name}")
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)

# The checkpoint is either a wrapper dict or a bare state dict; resolve it once
# so the primary and the fallback loading paths accept the same files.
state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint

try:
    model.load_state_dict(state_dict)
    model.to(device)
except Exception as load_err:
    # No usable ``model`` in the session (undefined or incompatible): rebuild it.
    print(f"‚ö†Ô∏è Could not load into the in-session model ({load_err}); rebuilding it")
    try:
        in_feats = dataset[0].x.shape[1]
    except Exception:
        in_feats = 128
    # Positional size arguments: the old ``num_features``/``hidden_dim`` keywords
    # do not match the ``in_channels``/``hidden_channels`` names the other cells
    # use for this class, so the recovery path itself raised TypeError.
    model = HypergraphNeuralNetwork(in_feats, 256, num_classes=2, dropout=0.5).to(device)
    model.load_state_dict(state_dict)

# ==============================================================================
# 2. THE CRITICAL SETUP (Strategy B)
# ==============================================================================
# A. Use STANDARD Loader (Real, Imbalanced Data)
# This prevents the "Double Penalty" that caused 11% accuracy.
active_loader = train_loader
print(f"üìâ Using STANDARD Imbalanced Loader ({len(train_loader.dataset)} samples)")

# B. Use CALCULATED Weight (6)
# This forces the model to learn from the imbalance mathematically.
weights = torch.tensor([1.0, CANCER_WEIGHT]).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)

optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-2)

# ==============================================================================
# 3. TRAINING LOOP
# ==============================================================================
print(f"\n‚ñ∂Ô∏è Fine-tuning for 20 Epochs with Weight {CANCER_WEIGHT:g}...")
best_f1 = 0.0
final_model_state = None

for epoch in range(1, 21):
    model.train()
    for batch in active_loader: # Standard Loader
        batch = batch.to(device)
        optimizer.zero_grad()
        out = _forward_logits(model, batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()

    # Validation
    all_preds, all_labels = _predict(model, val_loader)

    report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
    cancer_f1 = report['1']['f1-score']
    cancer_recall = report['1']['recall']
    current_acc = report['accuracy']

    print(f"Epoch {epoch:02d} | Acc: {current_acc:.2f} | Recall: {cancer_recall:.2f} | F1: {cancer_f1:.2f}")

    # Save best F1 (Balanced Metric); ties go to the later epoch
    if cancer_f1 >= best_f1:
        best_f1 = cancer_f1
        final_model_state = copy.deepcopy(model.state_dict())

# 4. Save & Plot
if final_model_state is not None:
    model.load_state_dict(final_model_state)
    save_path = "/content/drive/MyDrive/duke_lung_data/best_model_strategy_b_test.pth"
    torch.save(model.state_dict(), save_path)
    print(f"\n‚úÖ Strategy B Test model saved to: {save_path}")

    print("\nüìä GENERATING CONFUSION MATRIX...")
    final_preds, final_labels = _predict(model, val_loader)

    cm = confusion_matrix(final_labels, final_preds)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Purples', cbar=False,
                xticklabels=['Benign', 'Cancer'],
                yticklabels=['Benign', 'Cancer'])
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title(f'Strategy B: Weight {CANCER_WEIGHT:g} + Standard Data')
    plt.show()

    print(classification_report(final_labels, final_preds, target_names=['Benign', 'Cancer']))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy
from pathlib import Path
import os
import numpy as np
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

print("üß™ STARTING ADJUSTED RUN (Weight 3.75, LR 5e-5)...")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _forward_logits(net, batch):
    """Return the class logits ``net`` produces for one batch.

    The three-argument call (with ``batch.batch``) is tried first and the
    two-argument call is the fallback, because the forward signature of the
    model living in the session is not known here.  A tuple result is reduced
    to its first element, which the rest of this cell treats as the logits.
    """
    try:
        out = net(batch.x, batch.edge_index, batch.batch)
    except Exception:
        # Was a bare ``except:``, which also trapped KeyboardInterrupt, so an
        # interrupt during the forward pass was swallowed and the pass retried.
        out = net(batch.x, batch.edge_index)
    return out[0] if isinstance(out, tuple) else out


def _predict(net, loader):
    """Evaluate ``net`` on ``loader``; return ``(predictions, labels)`` lists."""
    net.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            preds.extend(_forward_logits(net, batch).argmax(dim=1).cpu().numpy())
            labels.extend(batch.y.cpu().numpy())
    return preds, labels


# ==============================================================================
# 1. LOAD MODEL
# ==============================================================================
target_path = Path("/content/drive/MyDrive/duke_lung_data/outputs/HGNN_LungCancer_MultiClass_v1.0/models/best_model.pth")
if not target_path.exists():
    # Fall back to the most recently modified best_model.pth under the data root.
    found = list(Path("/content/drive/MyDrive/duke_lung_data").rglob("best_model.pth"))
    best_model_path = max(found, key=os.path.getmtime) if found else Path("/content/checkpoints/best_model.pth")
else:
    best_model_path = target_path

print(f"üîÑ Resetting weights from: {best_model_path.name}")
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)

# The checkpoint is either a wrapper dict or a bare state dict; resolve it once
# so the primary and the fallback loading paths accept the same files.
state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint

try:
    model.load_state_dict(state_dict)
    model.to(device)
except Exception as load_err:
    # No usable ``model`` in the session (undefined or incompatible): rebuild it.
    print(f"‚ö†Ô∏è Could not load into the in-session model ({load_err}); rebuilding it")
    try:
        in_feats = dataset[0].x.shape[1]
    except Exception:
        in_feats = 128
    # Positional size arguments: the old ``num_features``/``hidden_dim`` keywords
    # do not match the ``in_channels``/``hidden_channels`` names the other cells
    # use for this class, so the recovery path itself raised TypeError.
    model = HypergraphNeuralNetwork(in_feats, 256, num_classes=2, dropout=0.5).to(device)
    model.load_state_dict(state_dict)

# ==============================================================================
# 2. CONFIGURATION (The Fix)
# ==============================================================================
active_loader = train_loader

# Weight 3.75
weights = torch.tensor([1.0, 3.75]).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)

# OPTIMIZED LEARNING RATE: 5e-5 (0.00005)
# This is the "Sweet Spot"
optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-2)

# mode='max' because the monitored quantity (cancer F1) should increase.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

# ==============================================================================
# 3. TRAINING LOOP
# ==============================================================================
print(f"\n‚ñ∂Ô∏è Fine-tuning for 20 Epochs...")
best_f1 = 0.0
final_model_state = None

for epoch in range(1, 21):
    model.train()
    for batch in active_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = _forward_logits(model, batch)
        loss = criterion(out, batch.y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Keep safety clip
        optimizer.step()

    # Validation
    all_preds, all_labels = _predict(model, val_loader)

    report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
    cancer_f1 = report['1']['f1-score']
    cancer_recall = report['1']['recall']
    current_acc = report['accuracy']

    scheduler.step(cancer_f1)

    print(f"Epoch {epoch:02d} | Acc: {current_acc:.2f} | Recall: {cancer_recall:.2f} | F1: {cancer_f1:.2f}")

    # Keep the best-F1 weights (ties go to the later epoch)
    if cancer_f1 >= best_f1:
        best_f1 = cancer_f1
        final_model_state = copy.deepcopy(model.state_dict())

# 4. Save
if final_model_state is not None:
    model.load_state_dict(final_model_state)
    save_path = "/content/drive/MyDrive/duke_lung_data/best_model_optimized.pth"
    torch.save(model.state_dict(), save_path)
    print(f"\n‚úÖ Optimized model saved to: {save_path}")

    print("\nüìä GENERATING CONFUSION MATRIX...")
    final_preds, final_labels = _predict(model, val_loader)

    cm = confusion_matrix(final_labels, final_preds)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Oranges', cbar=False,
                xticklabels=['Benign', 'Cancer'],
                yticklabels=['Benign', 'Cancer'])
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Optimized LR (5e-5) + Weight 3.75')
    plt.show()

    print(classification_report(final_labels, final_preds, target_names=['Benign', 'Cancer']))

In [None]:
import os
import shutil
from pathlib import Path

# Root of the project data on the mounted Drive
root_path = Path("/content/drive/MyDrive/duke_lung_data")

# Step 1 (optional hygiene): drop the output tree left by the v1.0 run
old_folder = root_path.joinpath("outputs", "HGNN_LungCancer_v1.0")
if old_folder.exists():
    shutil.rmtree(old_folder)
    print(f"üóëÔ∏è Deleted old output folder: {old_folder.name}")

# Step 2 (the important one): remove every cached dataset file so that the
# dataset is rebuilt from scratch on the next run
processed_folder = root_path.joinpath("processed")
found_cache = False

if not processed_folder.exists():
    print("‚ö†Ô∏è Could not find 'processed' folder. Check if it's inside 'data' or another subfolder.")
else:
    # First pass: *.pth caches
    for file in processed_folder.glob("*.pth"):
        file.unlink()
        found_cache = True
        print(f"üî• DELETED CACHE FILE: {file.name} (Now the code will see 480 patients!)")
    # Second pass: *.pt caches
    for file in processed_folder.glob("*.pt"):
        file.unlink()
        found_cache = True
        print(f"üî• DELETED CACHE FILE: {file.name}")

if found_cache is False:
    print("‚ÑπÔ∏è No cache file found. You might be ready to run dataset.py immediately.")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

print("üß™ STARTING ADJUSTED RUN (Weight 4.1)...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _forward_logits(net, batch):
    """Return the class logits ``net`` produces for one batch.

    The three-argument call (with ``batch.batch``) is tried first and the
    two-argument call is the fallback, because the forward signature of the
    model class is not visible from this cell.  A tuple result is reduced to
    its first element, which the rest of this cell treats as the logits.
    """
    try:
        out = net(batch.x, batch.edge_index, batch.batch)
    except Exception:
        # Was a bare ``except:``, which also trapped KeyboardInterrupt, so an
        # interrupt during the forward pass was swallowed and the pass retried.
        out = net(batch.x, batch.edge_index)
    return out[0] if isinstance(out, tuple) else out


def _predict(net, loader):
    """Evaluate ``net`` on ``loader``; return ``(predictions, labels)`` lists."""
    net.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            preds.extend(_forward_logits(net, batch).argmax(dim=1).cpu().numpy())
            labels.extend(batch.y.cpu().numpy())
    return preds, labels


# ==============================================================================
# 1. INITIALIZE FRESH MODEL
# ==============================================================================
# Infer the node-feature width from one training batch; 128 is the fallback
# when no batch can be drawn.
try:
    sample = next(iter(train_loader))
    in_feats = sample.x.shape[1]
except Exception as probe_err:
    print(f"‚ö†Ô∏è Could not infer feature size ({probe_err}); falling back to 128")
    in_feats = 128

model = HypergraphNeuralNetwork(
    in_channels=in_feats,
    hidden_channels=256,
    num_classes=2,
    dropout=0.5
).to(device)

# ==============================================================================
# 2. CONFIGURATION (Weight 4.1)
# ==============================================================================
# Specific request: Weight 4.1
weights = torch.tensor([1.0, 4.1]).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)

# Standard Safe Learning Rate
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-2)
# mode='max' because the monitored quantity (cancer F1) should increase.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

# ==============================================================================
# 3. TRAINING LOOP
# ==============================================================================
print(f"\n‚ñ∂Ô∏è Training for 20 Epochs...")
best_f1 = 0.0
final_model_state = None

for epoch in range(1, 21):
    model.train()
    # Uses whichever ``train_loader`` is currently defined in the session; this
    # cell does not build one itself.
    # NOTE(review): the old comment called it "the Balanced Loader" - confirm
    # which loader is actually live before comparing results.
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = _forward_logits(model, batch)

        loss = criterion(out, batch.y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    # Validation
    all_preds, all_labels = _predict(model, val_loader)

    # Metrics
    report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
    cancer_f1 = report['1']['f1-score']
    cancer_recall = report['1']['recall']
    current_acc = report['accuracy']

    scheduler.step(cancer_f1)

    print(f"Epoch {epoch:02d} | Acc: {current_acc:.2f} | Recall: {cancer_recall:.2f} | F1: {cancer_f1:.2f}")

    # Keep the best-F1 weights (ties go to the later epoch)
    if cancer_f1 >= best_f1:
        best_f1 = cancer_f1
        final_model_state = copy.deepcopy(model.state_dict())

# ==============================================================================
# 4. SAVE & REPORT
# ==============================================================================
if final_model_state is not None:
    model.load_state_dict(final_model_state)
    save_path = "/content/drive/MyDrive/duke_lung_data/best_model_w4_1.pth"
    torch.save(model.state_dict(), save_path)
    print(f"\n‚úÖ Model (Weight 4.1) saved to: {save_path}")

    print("\nüìä FINAL MATRIX (Weight 4.1):")
    final_preds, final_labels = _predict(model, val_loader)

    cm = confusion_matrix(final_labels, final_preds)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Oranges', cbar=False,
                xticklabels=['Benign', 'Cancer'],
                yticklabels=['Benign', 'Cancer'])
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Weight 4.1 Results')
    plt.show()

    print(classification_report(final_labels, final_preds, target_names=['Benign', 'Cancer']))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler, Subset
from sklearn.metrics import recall_score, f1_score, accuracy_score
from collections import Counter
import copy
import matplotlib.pyplot as plt
import seaborn as sns

# Make sure a HypergraphNeuralNetwork class exists: prefer the project's own
# implementation and only define the stand-in below when that import fails.
try:
    from src.models import HypergraphNeuralNetwork
except ImportError:

    class HypergraphNeuralNetwork(nn.Module):
        """Stand-in model: two linear layers, mean pooling, linear classifier.

        ``edge_index`` is accepted for call compatibility but is not used, and
        ``num_layers`` is accepted but does not change the architecture.
        """

        def __init__(self, in_channels, hidden_channels, num_classes, num_layers=3, dropout=0.5):
            super().__init__()
            self.conv1 = nn.Linear(in_channels, hidden_channels)
            self.conv2 = nn.Linear(hidden_channels, hidden_channels)
            self.classifier = nn.Linear(hidden_channels, num_classes)
            self.dropout = dropout

        def forward(self, x, edge_index, batch=None):
            hidden = self.conv1(x).relu()
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
            hidden = self.conv2(hidden).relu()

            if batch is None:
                # No assignment vector: pool every row into a single output row.
                pooled = hidden.mean(dim=0, keepdim=True)
            else:
                # Imported lazily so the stand-in only needs torch_geometric
                # when a batch vector is actually supplied.
                from torch_geometric.nn import global_mean_pool
                pooled = global_mean_pool(hidden, batch)

            return self.classifier(pooled)

print("‚ò¢Ô∏è STARTING CLINICAL-GRADE RUN (Focal-Tversky + OneCycle)...")
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ==============================================================================
# 1. DEFINE FOCAL-TVERSKY LOSS
# ==============================================================================
class FocalTverskyLoss(nn.Module):
    """Class-weighted Focal-Tversky loss on softmax probabilities.

    For every class the soft true-positive, false-positive and false-negative
    totals are accumulated over the batch, turned into a Tversky index
    ``tp / (tp + alpha * fn + (1 - alpha) * fp + 1e-7)``, and the per-class
    term ``(1 - tversky) ** gamma`` is multiplied by that class's weight.  The
    loss is the mean of the weighted per-class terms.

    Args:
        alpha: Weight on false negatives in the Tversky denominator; false
            positives get ``1 - alpha``.
        gamma: Exponent applied to ``1 - tversky``.
        class_weights: One multiplier per class.  The default ``(1.0, 5.0)``
            reproduces the weights that used to be hard-coded in ``forward``.

    Raises:
        ValueError: from ``forward`` when the number of logit columns differs
            from ``len(class_weights)``.
    """

    def __init__(self, alpha=0.7, gamma=2.0, class_weights=(1.0, 5.0)):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Built once instead of on every forward call.  Non-persistent so that
        # the module's state_dict stays empty, as it was before.
        self.register_buffer(
            "class_weights",
            torch.as_tensor(class_weights, dtype=torch.float32),
            persistent=False,
        )

    def forward(self, inputs, targets):
        """Return the scalar loss for logits ``inputs`` and integer ``targets``.

        ``inputs`` is softmaxed along dim 1 and ``targets`` is one-hot encoded
        to the same number of columns.
        """
        num_classes = inputs.shape[1]
        if self.class_weights.numel() != num_classes:
            raise ValueError(
                f"class_weights has {self.class_weights.numel()} entries but "
                f"inputs has {num_classes} classes"
            )

        probs = F.softmax(inputs, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=num_classes).float()

        # Soft confusion counts, summed over the batch -> one value per class
        tp = (probs * targets_one_hot).sum(dim=0)
        fp = (probs * (1 - targets_one_hot)).sum(dim=0)
        fn = ((1 - probs) * targets_one_hot).sum(dim=0)

        tversky = tp / (tp + self.alpha * fn + (1 - self.alpha) * fp + 1e-7)
        focal_tversky = (1 - tversky) ** self.gamma

        # The criterion itself may never have been moved with ``.to(device)``,
        # so follow the logits' device here.
        weights = self.class_weights.to(inputs.device)
        loss = (focal_tversky * weights).mean()
        return loss

# ==============================================================================
# 2. RESTORE DATA & SPLITS
# ==============================================================================
if 'dataset' not in locals():
    print("üîÑ Reloading dataset...")
    # NOTE(review): LungCancerDataset is not imported by this cell; it must
    # already be defined in the session or this line raises NameError - confirm.
    dataset = LungCancerDataset(root_dir='/content/drive/MyDrive/duke_lung_data')

# Sequential 80/20 split: the first 80% of indices train, the remainder validates.
train_size = int(0.8 * len(dataset))
train_indices = list(range(train_size))
val_indices = list(range(train_size, len(dataset)))

train_subset = Subset(dataset, train_indices)
val_subset = Subset(dataset, val_indices)
print(f"‚úÖ Data Ready: {len(train_subset)} Train / {len(val_subset)} Val")

# ==============================================================================
# 3. AGGRESSIVE SAMPLER SETUP
# ==============================================================================
train_labels = [train_subset[i].y.item() for i in range(len(train_subset))]
counts = Counter(train_labels)
n_benign, n_cancer = counts[0], counts[1]

# Make cancer appear 10x more often: each class's per-sample weight is its
# target share divided by its count, so the cancer class as a whole carries
# about ten times the sampling mass of the benign class.  The 1e-6 only
# guards against division by zero when a class is absent.
cancer_weight = 10.0 / (n_cancer + 1e-6)
benign_weight = 1.0 / (n_benign + 1e-6)
sample_weights = torch.tensor(
    [benign_weight if label == 0 else cancer_weight for label in train_labels]
)

# Draws twice as many samples per epoch as there are training items.
sampler = WeightedRandomSampler(
    weights=sample_weights.type(torch.DoubleTensor),
    num_samples=len(sample_weights) * 2,
    replacement=True
)

train_loader = DataLoader(train_subset, batch_size=16, sampler=sampler, collate_fn=collate_hypergraph_batch)
val_loader = DataLoader(val_subset, batch_size=16, shuffle=False, collate_fn=collate_hypergraph_batch)

# ==============================================================================
# 4. MODEL & OPTIMIZER SETUP (FIXED)
# ==============================================================================
# Infer the node-feature width from one training batch.  The handler used to be
# a bare ``except:`` (which also traps KeyboardInterrupt) and fell back to 128
# without saying so; the fallback is kept but is now reported.
try:
    sample = next(iter(train_loader))
    in_feats = sample.x.shape[1]
except Exception as probe_err:
    print(f"‚ö†Ô∏è Could not infer feature size ({probe_err}); falling back to 128")
    in_feats = 128

# Explicitly naming arguments to avoid TypeError
model = HypergraphNeuralNetwork(
    in_channels=in_feats,
    hidden_channels=256,
    num_classes=2,
    num_layers=3,      # Standard depth
    dropout=0.5
).to(device)

criterion = FocalTverskyLoss(alpha=0.7, gamma=2.0)
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-3)

# The schedule length is epochs * steps_per_epoch, so the training loop must
# call scheduler.step() once per batch for exactly 50 epochs.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-4,
    epochs=50,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,
    anneal_strategy='cos'
)

# ==============================================================================
# 5. TRAINING LOOP
# ==============================================================================
def find_optimal_threshold(model, loader, min_recall=0.60, thresholds=None):
    """Pick the decision threshold with the best F1 among those meeting a recall floor.

    The model is run over ``loader`` in eval mode, the softmax probability of
    class 1 is collected for every sample, and each candidate threshold is
    scored with recall and F1 for class 1.

    Args:
        model: Network to evaluate; called with ``(x, edge_index, batch)`` and,
            if that fails, with ``(x, edge_index)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``to(device)``.
        min_recall: Thresholds whose recall is below this value are ignored.
        thresholds: Candidate thresholds; defaults to
            ``np.arange(0.1, 0.9, 0.05)``.

    Returns:
        ``(best_thresh, best_f1)``.  When no candidate reaches ``min_recall``
        with a positive F1 (or the loader yields no samples) the defaults
        ``(0.5, 0)`` are returned.
    """
    if thresholds is None:
        thresholds = np.arange(0.1, 0.9, 0.05)

    model.eval()
    all_probs = []
    all_labels = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            try:
                out = model(batch.x, batch.edge_index, batch.batch)
            except Exception:
                # Was a bare ``except:``, which also swallowed KeyboardInterrupt.
                out = model(batch.x, batch.edge_index)
            if isinstance(out, tuple):
                out = out[0]

            probs = F.softmax(out, dim=1)[:, 1]
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(batch.y.cpu().numpy())

    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)

    best_f1 = 0
    best_thresh = 0.5

    # Nothing to score: keep the defaults instead of handing empty arrays to
    # the metric functions.
    if all_labels.size == 0:
        return best_thresh, best_f1

    for t in thresholds:
        preds = (all_probs >= t).astype(int)
        recall = recall_score(all_labels, preds, zero_division=0)

        # Prioritize recall: only thresholds that reach the floor compete on F1
        if recall < min_recall:
            continue

        f1 = f1_score(all_labels, preds, zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_thresh = t

    return best_thresh, best_f1

print(f"\nüöÄ Training for 50 Epochs...")
best_val_recall = 0.0
final_state = None

for epoch in range(1, 51):
    model.train()
    train_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        try:
            out = model(batch.x, batch.edge_index, batch.batch)
        except Exception:
            # Was a bare ``except:``, which also trapped KeyboardInterrupt, so an
            # interrupt during the forward pass was swallowed and the pass retried.
            out = model(batch.x, batch.edge_index)
        if isinstance(out, tuple):
            out = out[0]

        loss = criterion(out, batch.y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # The scheduler is stepped once per batch, after the optimizer.
        scheduler.step()
        train_loss += loss.item()

    # Validation is only run every fifth epoch.
    if epoch % 5 == 0:
        thresh, f1 = find_optimal_threshold(model, val_loader)

        model.eval()
        probs, labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                try:
                    out = model(batch.x, batch.edge_index, batch.batch)
                except Exception:
                    out = model(batch.x, batch.edge_index)
                if isinstance(out, tuple):
                    out = out[0]
                probs.extend(F.softmax(out, dim=1)[:, 1].cpu().numpy())
                labels.extend(batch.y.cpu().numpy())

        final_preds = (np.array(probs) >= thresh).astype(int)
        recall = recall_score(labels, final_preds, zero_division=0)
        acc = accuracy_score(labels, final_preds)
        # Report F1 at the threshold actually applied.  The value returned by
        # the search is 0 whenever it falls back to the default threshold,
        # which did not match the Acc/Recall printed on the same line.
        f1 = f1_score(labels, final_preds, zero_division=0)

        print(f"Epoch {epoch:02d} | Loss: {train_loss/len(train_loader):.4f} | Optimal Thresh: {thresh:.2f} | Acc: {acc:.2f} | Recall: {recall:.2f} | F1: {f1:.2f}")

        # Checkpoint on recall (ties go to the later epoch), but only when
        # accuracy is above 0.50.
        if recall >= best_val_recall and acc > 0.50:
            best_val_recall = recall
            final_state = copy.deepcopy(model.state_dict())
            print(f"   üî• New Best Model Saved (Recall: {recall:.2f})")

if final_state is not None:
    model.load_state_dict(final_state)
    torch.save(model.state_dict(), "/content/drive/MyDrive/duke_lung_data/best_model_clinical.pth")
    print("\n‚úÖ Clinical Model Saved.")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler, Subset
from sklearn.metrics import recall_score, f1_score, accuracy_score, classification_report
from collections import Counter
import copy
import matplotlib.pyplot as plt

# Announce the run, then select the GPU when one is available and the CPU otherwise
print("‚öñÔ∏è STARTING BALANCED RUN (Sampler + Neutral Loss)...")
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ==============================================================================
# 1. NEUTRAL FOCAL-TVERSKY LOSS (The Fix)
# ==============================================================================
class NeutralFocalTverskyLoss(nn.Module):
    """Two-class Focal-Tversky loss with equal class weights.

    ``alpha`` weights false negatives and ``1 - alpha`` weights false positives
    in the Tversky denominator (0.5 treats both alike); ``gamma`` is the
    exponent applied to ``1 - tversky``.
    """

    def __init__(self, alpha=0.5, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets):
        class_probs = inputs.softmax(dim=1)
        hot = F.one_hot(targets, num_classes=2).float()

        # Soft confusion counts, summed over the batch -> one value per class
        true_pos = (class_probs * hot).sum(dim=0)
        false_pos = (class_probs * (1 - hot)).sum(dim=0)
        false_neg = ((1 - class_probs) * hot).sum(dim=0)

        # Tversky index, then the focal transform
        denominator = true_pos + self.alpha * false_neg + (1 - self.alpha) * false_pos + 1e-7
        per_class = (1 - true_pos / denominator) ** self.gamma

        # Neutral weights [1.0, 1.0]: the sampler already rebalances the
        # batches, so no class gets extra emphasis here.
        unit_weights = torch.ones(2, device=inputs.device)
        return (per_class * unit_weights).mean()

# ==============================================================================
# 2. RESTORE DATA (Standard)
# ==============================================================================
if 'dataset' not in locals():
    print("üîÑ Reloading dataset...")
    # NOTE(review): LungCancerDataset is not imported by this cell; it must
    # already be defined in the session or this line raises NameError - confirm.
    dataset = LungCancerDataset(root_dir='/content/drive/MyDrive/duke_lung_data')

# Sequential 80/20 split: the first 80% of indices train, the remainder validates.
train_size = int(0.8 * len(dataset))
train_indices = list(range(train_size))
val_indices = list(range(train_size, len(dataset)))

train_subset = Subset(dataset, train_indices)
val_subset = Subset(dataset, val_indices)

# ==============================================================================
# 3. SAMPLER (Keep this!)
# ==============================================================================
train_labels = [train_subset[i].y.item() for i in range(len(train_subset))]
counts = Counter(train_labels)
n_benign, n_cancer = counts[0], counts[1]

# Make cancer appear 10x more often in the batches: each class's per-sample
# weight is its target share divided by its count, so the cancer class as a
# whole carries about ten times the sampling mass of the benign class.  The
# 1e-6 only guards against division by zero when a class is absent.
cancer_weight = 10.0 / (n_cancer + 1e-6)
benign_weight = 1.0 / (n_benign + 1e-6)
sample_weights = torch.tensor(
    [benign_weight if label == 0 else cancer_weight for label in train_labels]
)

# Draws twice as many samples per epoch as there are training items.
sampler = WeightedRandomSampler(
    weights=sample_weights.type(torch.DoubleTensor),
    num_samples=len(sample_weights) * 2,
    replacement=True
)

train_loader = DataLoader(train_subset, batch_size=16, sampler=sampler, collate_fn=collate_hypergraph_batch)
val_loader = DataLoader(val_subset, batch_size=16, shuffle=False, collate_fn=collate_hypergraph_batch)

# ==============================================================================
# 4. TRAINING SETUP
# ==============================================================================
# Infer the node-feature width from one training batch.  The handler used to be
# a bare ``except:`` (which also traps KeyboardInterrupt) and fell back to 128
# without saying so; the fallback is kept but is now reported.
try:
    sample = next(iter(train_loader))
    in_feats = sample.x.shape[1]
except Exception as probe_err:
    print(f"‚ö†Ô∏è Could not infer feature size ({probe_err}); falling back to 128")
    in_feats = 128

model = HypergraphNeuralNetwork(in_feats, 256, 2, num_layers=3, dropout=0.5).to(device)

# Use the Neutral Loss
criterion = NeutralFocalTverskyLoss(alpha=0.5, gamma=2.0)

# Faster LR to escape local minima
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-3)
# The schedule length is epochs * steps_per_epoch, so the training loop must
# call scheduler.step() once per batch for exactly 50 epochs.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-4,
    epochs=50,
    steps_per_epoch=len(train_loader),
    pct_start=0.3
)

# ==============================================================================
# 5. TRAINING LOOP
# ==============================================================================
def find_best_threshold(model, loader):
    """Pick the decision threshold that maximises (recall + accuracy) / 2.

    The model is switched to eval mode and run over every batch of ``loader``
    without gradients. The softmax probability of class index 1 is collected
    per sample, then each candidate threshold from
    ``np.arange(0.1, 0.9, 0.05)`` is scored.

    Args:
        model: network called as ``model(x, edge_index, batch)`` or, failing
            that, ``model(x, edge_index)``; may return a tuple whose first
            element holds the logits.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``.

    Returns:
        ``(best_thresh, best_score)``. Defaults to ``(0.5, 0)`` when the
        loader yields no samples or no candidate scores above zero.

    Note:
        The model is left in eval mode; callers that continue training must
        call ``model.train()`` again.
    """
    model.eval()
    all_probs = []
    all_labels = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            # Best-effort support for both forward signatures. `Exception`
            # (not a bare except) so KeyboardInterrupt/SystemExit propagate.
            try:
                out = model(batch.x, batch.edge_index, batch.batch)
            except Exception:
                out = model(batch.x, batch.edge_index)
            if isinstance(out, tuple):
                out = out[0]

            probs = F.softmax(out, dim=1)[:, 1]
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(batch.y.cpu().numpy())

    best_thresh = 0.5
    best_score = 0
    if not all_labels:
        # Nothing to score: keep the neutral defaults instead of handing
        # empty arrays to the metric functions.
        return best_thresh, best_score

    # Convert once; the threshold sweep only re-thresholds this array.
    prob_array = np.asarray(all_probs)

    # Optimize for a balance of Recall and Acc
    for t in np.arange(0.1, 0.9, 0.05):
        preds = (prob_array >= t).astype(int)
        recall = recall_score(all_labels, preds, zero_division=0)
        acc = accuracy_score(all_labels, preds)
        # Score = Average of Recall and Accuracy
        score = (recall + acc) / 2

        # Strict '>' keeps the lowest threshold among ties.
        if score > best_score:
            best_score = score
            best_thresh = t

    return best_thresh, best_score

print("üöÄ Training for 50 Epochs (Neutral Weights)...")
best_metric = 0.0
final_state = None

for epoch in range(1, 51):
    model.train()
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        try: out = model(batch.x, batch.edge_index, batch.batch)
        except: out = model(batch.x, batch.edge_index)
        if isinstance(out, tuple): out = out[0]

        loss = criterion(out, batch.y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

    if epoch % 5 == 0:
        thresh, score = find_best_threshold(model, val_loader)

        # Calculate final stats at this threshold
        model.eval()
        probs, labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                try: out = model(batch.x, batch.edge_index, batch.batch)
                except: out = model(batch.x, batch.edge_index)
                if isinstance(out, tuple): out = out[0]
                probs.extend(F.softmax(out, dim=1)[:, 1].cpu().numpy())
                labels.extend(batch.y.cpu().numpy())

        final_preds = (np.array(probs) >= thresh).astype(int)
        recall = recall_score(labels, final_preds, zero_division=0)
        acc = accuracy_score(labels, final_preds)

        print(f"Epoch {epoch:02d} | Best Thresh: {thresh:.2f} | Acc: {acc:.2f} | Recall: {recall:.2f}")

        # We want balanced performance
        if score > best_metric:
            best_metric = score
            final_state = copy.deepcopy(model.state_dict())
            print(f"   üî• New Best Model (Balanced Score: {score:.2f})")

if final_state:
    model.load_state_dict(final_state)
    torch.save(model.state_dict(), "/content/drive/MyDrive/duke_lung_data/best_model_balanced.pth")
    print("\n‚úÖ Balanced Model Saved.")

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import os
from pathlib import Path
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm

# ==============================================================================
# 1. DEFINE HYPERGRAPH CONSTRUCTOR (Dependency)
# ==============================================================================
class HypergraphConstructor:
    """Turn a feature vector into node features plus dense pairwise connectivity."""

    def __init__(self, k_neighbors=5, m_clusters=10):
        # Kept as configuration; construct_graph below does not read them.
        self.k, self.m = k_neighbors, m_clusters

    def construct_graph(self, features):
        """Build node features and an all-pairs edge index for one sample.

        Every entry along the first axis of ``features`` becomes a node, and
        every ordered node pair (self-pairs included) becomes an edge.

        Args:
            features: array-like exposing ``.shape``; its first dimension is
                the node count.

        Returns:
            ``(x, edge_index)``: ``x`` is ``features`` as a float tensor with
            an extra axis inserted at position 1; ``edge_index`` is a long
            tensor of shape ``[2, num_nodes ** 2]`` ordered source-major.
        """
        node_ids = range(features.shape[0])

        # Source-major enumeration of all ordered pairs: (0,0), (0,1), ...
        sources = [src for src in node_ids for _ in node_ids]
        targets = [dst for _ in node_ids for dst in node_ids]
        edge_index = torch.tensor([sources, targets], dtype=torch.long)

        # One scalar feature per node for a 1-D input, e.g. [num_nodes, 1].
        x = torch.tensor(features, dtype=torch.float).unsqueeze(1)

        return x, edge_index

# ==============================================================================
# 2. DEFINE DATASET CLASS (The Missing Piece)
# ==============================================================================
class LungCancerDataset(InMemoryDataset):
    """In-memory graph dataset built from annotated lung CT volumes.

    Each row of ``Annotation_Boxes.csv`` under ``root_dir`` is turned into one
    ``Data`` object: the matching NIfTI volume is loaded and normalised, a
    patch around the annotated box is extracted, radiomics features are
    computed on it and ``HypergraphConstructor`` turns them into a graph. The
    collated result is cached in ``duke_lung_dataset.pth``.

    Args:
        root_dir: dataset root containing the CSV and the image folders.
        transform: forwarded to ``InMemoryDataset``.
        pre_transform: forwarded to ``InMemoryDataset`` (not applied by
            ``process`` itself).

    Raises:
        NameError: if ``AdvancedPreprocessor`` is not defined in the notebook.
    """

    def __init__(self, root_dir, transform=None, pre_transform=None):
        self.root_dir = Path(root_dir)
        # Check if AdvancedPreprocessor exists in memory, otherwise error
        if 'AdvancedPreprocessor' not in globals():
            raise NameError("‚ùå AdvancedPreprocessor not found! Please run the Preprocessor cell first.")

        # The helpers are created before the base-class constructor runs
        # because process() needs them and is triggered from there when the
        # processed file is missing (see the PyG InMemoryDataset docs).
        self.preprocessor = AdvancedPreprocessor()
        self.graph_constructor = HypergraphConstructor()

        super().__init__(root_dir, transform, pre_transform)
        # The cache holds pickled Data objects, not just tensors. PyTorch >= 2.6
        # defaults to weights_only=True, which rejects them, so opt out
        # explicitly. The file is produced by process() below, i.e. trusted.
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    @property
    def raw_file_names(self):
        # No raw-file bookkeeping: images are discovered by globbing root_dir.
        return []

    @property
    def processed_file_names(self):
        return ['duke_lung_dataset.pth']

    def download(self):
        pass # Data is already local

    def _build_sample(self, row):
        """Convert one annotation row into a ``Data`` object.

        Returns ``None`` when no image matches the patient ID or the image
        fails to load. Any other failure propagates to the caller.
        """
        pid = str(row['Patient_ID'])
        # Handle path variations (00001 vs 1).
        # NOTE(review): `*{pid}*` is a substring match, so pid "1" also matches
        # a folder named "10", and the first hit wins - confirm the folder
        # naming makes this unambiguous.
        matches = list(self.root_dir.glob(f"**/*{pid}*/*.nii.gz"))
        if not matches:
            return None
        img_path = matches[0]

        # A. Load Image
        img_array, spacing, _origin, _, _ = self.preprocessor.load_nifti(img_path)
        if img_array is None:
            return None

        # B. Normalize
        img_norm = self.preprocessor.normalize_hu(img_array)

        # C. Extract Patch
        # Box centre, ordered (z, y, x) to index a numpy volume.
        center = (
            (row['Start_z'] + row['End_z']) // 2,
            (row['Start_y'] + row['End_y']) // 2,
            (row['Start_x'] + row['End_x']) // 2,
        )

        # Physical box width along x. The former 10 mm default was
        # unreachable: 'End_x' is already required for the centre above.
        # NOTE(review): assumes spacing[0] is the x-axis spacing - confirm
        # against load_nifti's return order.
        diameter = (row['End_x'] - row['Start_x']) * spacing[0]

        patch = self.preprocessor.extract_nodule_patch(
            img_norm,
            center,
            diameter_mm=diameter,
            spacing=spacing
        )

        # D. Extract Features
        # All-ones mask: the whole cropped patch is treated as the region.
        mask = np.ones_like(patch)
        features = self.preprocessor.extract_radiomics_features(patch, mask)

        # E. Construct Graph
        x, edge_index = self.graph_constructor.construct_graph(features)

        # F. Label
        y = torch.tensor([int(row['Label'])], dtype=torch.long)

        # G. Data Object
        return Data(x=x, edge_index=edge_index, y=y)

    def process(self):
        """Build every sample from the annotation CSV and cache the result.

        Rows that cannot be processed are skipped, but counted and summarised
        rather than dropped silently.

        Raises:
            FileNotFoundError: if ``Annotation_Boxes.csv`` is missing.
            RuntimeError: if no row could be processed.
        """
        print(f"üè≠ PROCESSING DATASET from {self.root_dir}...")

        # 1. Load Labels
        csv_path = self.root_dir / "Annotation_Boxes.csv"
        if not csv_path.exists():
            raise FileNotFoundError(f"CSV not found at {csv_path}")

        df = pd.read_csv(csv_path)
        data_list = []
        skipped = []  # (row index, reason) for every row that yields no sample

        # 2. Iterate and Process
        print(f"   - Found {len(df)} patients in CSV")

        for idx, row in tqdm(df.iterrows(), total=len(df), desc="Extracting Patches"):
            try:
                data = self._build_sample(row)
            except Exception as e:
                # Best-effort: one bad row must not abort the whole build,
                # but the reason is kept for the summary below.
                skipped.append((idx, f"{type(e).__name__}: {e}"))
                continue

            if data is None:
                skipped.append((idx, "image not found or failed to load"))
                continue
            data_list.append(data)

        # 3. Save to Disk
        print(f"   - Successfully processed {len(data_list)} samples.")
        if skipped:
            print(f"   - Skipped {len(skipped)} rows; first reasons:")
            for row_idx, reason in skipped[:5]:
                print(f"       row {row_idx}: {reason}")

        if not data_list:
            # Fail with a clear message instead of an obscure error in collate.
            raise RuntimeError(
                f"No samples could be processed from {csv_path}; "
                "check image paths and CSV columns."
            )

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

# ==============================================================================
# 3. TRIGGER BUILD
# ==============================================================================
print("\nüöÄ Starting Dataset Generation...")
dataset = LungCancerDataset(root_dir='/content/drive/MyDrive/duke_lung_data')
print(f"‚úÖ DONE! Dataset loaded with {len(dataset)} samples.")