In [1]:
# Show the GPUs visible to this kernel (model, driver/CUDA version, memory and utilisation).
!nvidia-smi

Thu Dec  4 13:28:48 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.65.06              Driver Version: 580.65.06      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L40S                    On  |   00000000:49:00.0 Off |                    0 |
| N/A   48C    P0            159W /  350W |    5465MiB /  46068MiB |     63%   E. Process |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA L40S                    On  |   00

In [3]:
import os
import pandas as pd
import numpy as np

print("="*70)
print("PANDA GRAPH TRANSFORMER - DATA PREPARATION (3-CLASS)")
print("="*70)

# ============================================================================
# STEP 1: Load your CSV splits
# ============================================================================
print("\n[1/5] Loading CSV splits...")

train_csv_path = "./data/splits/train_split.csv"
val_csv_path = "./data/splits/val_split.csv"
test_csv_path = "./data/splits/test_split.csv"

# Abort with an exception rather than exit(1): inside a Jupyter kernel, exit()
# requests a kernel shutdown instead of simply stopping this cell.
for csv_path in (train_csv_path, val_csv_path, test_csv_path):
    if not os.path.exists(csv_path):
        print(f"✗ ERROR: {csv_path} not found!")
        raise FileNotFoundError(f"{csv_path} not found")

train_df = pd.read_csv(train_csv_path)
val_df = pd.read_csv(val_csv_path)
test_df = pd.read_csv(test_csv_path)

print(f"✓ Loaded train CSV: {len(train_df)} samples")
print(f"✓ Loaded val CSV: {len(val_df)} samples")
print(f"✓ Loaded test CSV: {len(test_df)} samples")

id_column = 'image_id'
label_column = 'isup_grade'

# Fail early, with a readable message, if a split lacks a column used below.
for split_name, split_df in (("train", train_df), ("val", val_df), ("test", test_df)):
    missing_cols = [col for col in (id_column, label_column) if col not in split_df.columns]
    if missing_cols:
        raise KeyError(f"{split_name} CSV is missing required column(s): {missing_cols}")

# ============================================================================
# STEP 2: Define label remapping for 3-class problem
# ============================================================================
print(f"\n[2/5] Remapping ISUP grades to 3 classes...")

print("\nOriginal ISUP grade distribution (train):")
print(train_df[label_column].value_counts().sort_index())

# Define mapping: ISUP 0-5 -> 3 classes
# Strategy 1: Background, Benign, Cancerous
# ISUP 0 -> Class 0 (Background)
# ISUP 1 -> Class 1 (Benign/Low grade)
# ISUP 2-5 -> Class 2 (Cancerous/High grade)

def remap_labels(isup_grade):
    """
    Remap an ISUP grade (0-5) to one of 3 classes:
    - Class 0: Background (ISUP 0)
    - Class 1: Benign/Low-grade (ISUP 1)
    - Class 2: Cancerous/High-grade (ISUP 2-5)

    NOTE(review): the class names follow this notebook's wording; in the PANDA
    labelling ISUP 0 appears to be the no-cancer grade and ISUP 1 already a
    (low-grade) cancer grade — confirm the intended naming.

    Args:
        isup_grade: the grade as an int, an integral float, or a NumPy scalar.

    Returns:
        int: the 3-class label (0, 1 or 2).

    Raises:
        ValueError: if ``isup_grade`` is not one of 0-5 (for example NaN from
            an empty CSV cell, a negative number, a grade above 5 or a string).
            Such values used to fall through silently to Class 2.
    """
    if isup_grade == 0:
        return 0  # Background
    if isup_grade == 1:
        return 1  # Benign
    if isup_grade in (2, 3, 4, 5):
        return 2  # Cancerous
    raise ValueError(f"Unexpected ISUP grade {isup_grade!r}; expected an integer in 0-5")

# Attach the 3-class label to each split (adds a 'label_3class' column in place).
for split_df in (train_df, val_df, test_df):
    split_df['label_3class'] = split_df[label_column].apply(remap_labels)

print("\n✓ Remapping applied:")
for mapping_line in (
    "  ISUP 0     -> Class 0 (Background)",
    "  ISUP 1     -> Class 1 (Benign)",
    "  ISUP 2-5   -> Class 2 (Cancerous)",
):
    print(mapping_line)

# Per-split counts of the new labels, in train / val / test order.
for split_name, split_df in (("train", train_df), ("val", val_df), ("test", test_df)):
    print(f"\nNew 3-class distribution ({split_name}):")
    print(split_df['label_3class'].value_counts().sort_index())

# One line per original grade: how many training rows carry it and where it maps.
print("\nMapping examples:")
for grade in range(6):
    n_train = int((train_df[label_column] == grade).sum())
    print(f"  ISUP {grade} ({n_train:4d} samples) -> Class {remap_labels(grade)}")

# ============================================================================
# STEP 3: Verify graphs_all
# ============================================================================
print(f"\n[3/5] Verifying graphs_all...")

# NOTE(review): the evaluation cells build GraphDataset from
# './feature_extractor/graphs_all', not from this directory — confirm that the
# slides checked here are the ones actually loaded later.
graphs_path = "./feature_extractor/graphs_phikon/panda"
if not os.path.isdir(graphs_path):
    print(f"✗ ERROR: {graphs_path} not found!")
    # Raise instead of exit(1): exit() would shut the Jupyter kernel down.
    raise FileNotFoundError(f"{graphs_path} not found")

# Entry names in the directory; slide IDs from the CSVs are matched against these.
available_slides = set(os.listdir(graphs_path))
print(f"✓ Available slides in graphs_all: {len(available_slides)}")

# ============================================================================
# STEP 4: Create train_set.txt, val_set.txt, and test_set.txt
# ============================================================================
print(f"\n[4/5] Creating train_set.txt, val_set.txt, and test_set.txt...")

os.makedirs("./scripts", exist_ok=True)


def _write_split_file(split_df, out_path, available):
    """Write one ``panda/<slide_id>\\t<label>`` line per available slide.

    Args:
        split_df: DataFrame holding the ``id_column`` and 'label_3class' columns.
        out_path: text file to (over)write.
        available: collection of slide IDs that have a graph on disk.

    Returns:
        tuple: (number of lines written, list of slide IDs skipped because
        they are not in ``available``).
    """
    written = 0
    missing = []
    with open(out_path, 'w') as f:
        # Iterate the two needed columns directly; iterrows() builds a Series per row.
        for raw_id, raw_label in zip(split_df[id_column], split_df['label_3class']):
            slide_id = str(raw_id)
            if slide_id not in available:
                missing.append(slide_id)
                continue
            f.write(f"panda/{slide_id}\t{int(raw_label)}\n")  # Use remapped label
            written += 1
    return written, missing


def _report_split(split_name, count, missing):
    """Print the summary lines for one written split file."""
    print(f"✓ Created {split_name}_set.txt with {count} samples")
    if missing:
        print(f"  ⚠ {len(missing)} {split_name} slides missing from graphs_all")


# Write train_set.txt with remapped labels
train_count, missing_train = _write_split_file(train_df, "./scripts/train_set.txt", available_slides)
_report_split("train", train_count, missing_train)

# Write val_set.txt with remapped labels
val_count, missing_val = _write_split_file(val_df, "./scripts/val_set.txt", available_slides)
_report_split("val", val_count, missing_val)

# Write test_set.txt with remapped labels
test_count, missing_test = _write_split_file(test_df, "./scripts/test_set.txt", available_slides)
_report_split("test", test_count, missing_test)

# Verify class distribution in output files
print("\nVerifying class distribution in output files:")

from collections import Counter

# Split name -> file written in STEP 4 (kept in train / val / test order).
split_files = {
    "train": "./scripts/train_set.txt",
    "val": "./scripts/val_set.txt",
    "test": "./scripts/test_set.txt",
}

# Re-read every file and collect the integer label from the second tab-separated field.
labels_by_split = {}
for split_name, split_path in split_files.items():
    with open(split_path, 'r') as f:
        labels_by_split[split_name] = [int(line.strip().split('\t')[1]) for line in f]

train_labels = labels_by_split["train"]
val_labels = labels_by_split["val"]
test_labels = labels_by_split["test"]

train_counter = Counter(train_labels)
val_counter = Counter(val_labels)
test_counter = Counter(test_labels)

for title, counter, labels in (
    ("Train", train_counter, train_labels),
    ("Val", val_counter, val_labels),
    ("Test", test_counter, test_labels),
):
    print(f"\n{title} set class distribution:")
    for cls in sorted(counter):
        pct = 100 * counter[cls] / len(labels)
        print(f"  Class {cls}: {counter[cls]:5d} samples ({pct:.1f}%)")

# Show examples
print("\n" + "="*70)
print("SAMPLE OUTPUT")
print("="*70)
for split_name, split_path in split_files.items():
    print(f"\nFirst 5 lines of {split_name}_set.txt:")
    with open(split_path, 'r') as f:
        for line_no, line in enumerate(f):
            if line_no >= 5:
                break
            print(f"  {line.strip()}")

# ============================================================================
# STEP 5: Create output directories
# ============================================================================
print("\n[5/5] Creating output directories...")

# Directories are created first and reported afterwards.
output_dirs = ("./graph_transformer/saved_models", "./graph_transformer/runs")
for out_dir in output_dirs:
    os.makedirs(out_dir, exist_ok=True)

for out_dir in output_dirs:
    print(f"✓ Created {out_dir}/")

# ============================================================================
# FINAL SUMMARY
# ============================================================================
import json

# Total number of lines written across the three split files.
total_count = train_count + val_count + test_count


def _split_pct(count):
    """Return ``count`` as a percentage of all written samples (0.0 if none).

    Guards the summary against ZeroDivisionError when no slide from any split
    was found in the graphs directory.
    """
    return 100 * count / total_count if total_count else 0.0


print("\n" + "="*70)
print("SETUP COMPLETE - READY TO TRAIN!")
print("="*70)
if total_count == 0:
    print("⚠ WARNING: no samples were written to any split file - check the graphs directory.")
print(f"\nDataset Split:")
print(f"  Training samples:   {train_count:,} ({_split_pct(train_count):.0f}%)")
print(f"  Validation samples: {val_count:,} ({_split_pct(val_count):.0f}%)")
print(f"  Test samples:       {test_count:,} ({_split_pct(test_count):.0f}%)")
print(f"  Total:              {total_count:,}")

print(f"\nClasses (n_class): 3")
print(f"  - Class 0: Background (ISUP 0)")
print(f"  - Class 1: Benign (ISUP 1)")
print(f"  - Class 2: Cancerous (ISUP 2-5)")

print("\n" + "="*70)
print("OUTPUT FILES:")
print("="*70)
print("  ./scripts/train_set.txt")
print("  ./scripts/val_set.txt")
print("  ./scripts/test_set.txt")
print("  ./scripts/label_mapping_info.json")

# Save mapping info for reference
# (json.dump writes the integer class keys of the distributions as strings.)
mapping_info = {
    'original_classes': 6,
    'new_classes': 3,
    'mapping': {
        'ISUP_0': 'Class_0_Background',
        'ISUP_1': 'Class_1_Benign',
        'ISUP_2-5': 'Class_2_Cancerous'
    },
    'train_distribution': dict(train_counter),
    'val_distribution': dict(val_counter),
    'test_distribution': dict(test_counter),
    'total_samples': {
        'train': train_count,
        'val': val_count,
        'test': test_count,
        'total': total_count
    }
}

with open('./scripts/label_mapping_info.json', 'w') as f:
    json.dump(mapping_info, f, indent=2)

print("\n✓ Saved label mapping info to ./scripts/label_mapping_info.json")

PANDA GRAPH TRANSFORMER - DATA PREPARATION (3-CLASS)

[1/5] Loading CSV splits...
✓ Loaded train CSV: 8492 samples
✓ Loaded val CSV: 1062 samples
✓ Loaded test CSV: 1062 samples

[2/5] Remapping ISUP grades to 3 classes...

Original ISUP grade distribution (train):
isup_grade
0    2323
1    2180
2    1074
3     976
4     997
5     942
Name: count, dtype: int64

✓ Remapping applied:
  ISUP 0     -> Class 0 (Background)
  ISUP 1     -> Class 1 (Benign)
  ISUP 2-5   -> Class 2 (Cancerous)

New 3-class distribution (train):
label_3class
0    2323
1    2180
2    3989
Name: count, dtype: int64

New 3-class distribution (val):
label_3class
0    284
1    243
2    535
Name: count, dtype: int64

New 3-class distribution (test):
label_3class
0    285
1    243
2    534
Name: count, dtype: int64

Mapping examples:
  ISUP 0 (2323 samples) -> Class 0
  ISUP 1 (2180 samples) -> Class 1
  ISUP 2 (1074 samples) -> Class 2
  ISUP 3 ( 976 samples) -> Class 2
  ISUP 4 ( 997 samples) -> Class 2
  ISUP 5 ( 9

In [5]:
import torch
import numpy as np
from sklearn.metrics import cohen_kappa_score, confusion_matrix, classification_report
from torch.utils.data import DataLoader
from utils.dataset import GraphDataset
from helper import collate
from models.GraphTransformer import Classifier
from collections import OrderedDict

# Number of target classes; used for the class weights and the classifier head.
N_CLASSES = 3

device = torch.device("cuda")

# Load data
val_file = './scripts/val_set.txt'
with open(val_file, 'r') as f:
    val_ids = [line.strip() for line in f if line.strip()]

# NOTE(review): the data-preparation cell checks slide availability under
# './feature_extractor/graphs_phikon/panda', while graphs are read from
# './feature_extractor/graphs_all' here — confirm both refer to the same features.
dataset_val = GraphDataset('./feature_extractor/graphs_all', val_ids, site='panda')
dataloader_val = DataLoader(dataset_val, batch_size=1, collate_fn=collate, num_workers=0, shuffle=False)

# Calculate class weights
# Only lines with exactly two whitespace-separated fields ("<id> <label>") are used.
train_file = './scripts/train_set.txt'
with open(train_file, 'r') as f:
    train_labels = [int(parts[1]) for parts in (line.split() for line in f) if len(parts) == 2]

if not train_labels:
    raise ValueError(f"No '<id> <label>' lines found in {train_file}")

# minlength keeps one weight per class even when the highest class is absent.
class_counts = np.bincount(train_labels, minlength=N_CLASSES)
if len(class_counts) != N_CLASSES:
    raise ValueError(
        f"{train_file} contains labels outside 0..{N_CLASSES - 1}: counts={class_counts.tolist()}"
    )

# Inverse-frequency weights, rescaled so that they sum to N_CLASSES.
class_weights = 1.0 / (class_counts + 1e-6)
class_weights = class_weights / class_weights.sum() * N_CLASSES
class_weights = torch.FloatTensor(class_weights).to(device)

# Load model WITH class weights
model = Classifier(n_class=N_CLASSES, n_features=512, class_weights=class_weights)
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load('./graph_transformer/saved_models/GraphCAM_PANDA_WEIGHTED_v2.pth', map_location=device)
# Strip only a leading 'module.' (the DataParallel wrapper prefix). A plain
# str.replace would also mangle keys that merely contain that substring,
# e.g. 'submodule.weight' -> 'subweight'.
dp_prefix = 'module.'
new_state_dict = OrderedDict(
    (k[len(dp_prefix):] if k.startswith(dp_prefix) else k, v) for k, v in state_dict.items()
)
model.load_state_dict(new_state_dict)
model = model.to(device).eval()

# Hook to capture logits
logits_captured = None  # copy of the latest matching layer output, or None


def hook_fn(module, input, output):
    """Forward hook that stores a copy of any 2-D output with 3 columns.

    Outputs without a ``shape`` attribute, or with any other shape, leave the
    module-level ``logits_captured`` untouched.
    """
    global logits_captured
    if not hasattr(output, 'shape'):
        return
    dims = output.shape
    if len(dims) != 2 or dims[1] != 3:
        return
    logits_captured = output.clone()

# Register the capture hook on every Linear layer; the last layer whose output
# has 3 columns during a forward pass is what ends up in logits_captured.
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        module.register_forward_hook(hook_fn)

print("="*70)
print("GTP BASELINE EVALUATION - WEIGHTED MODEL")
print("="*70)
print(f"Samples: {len(dataset_val)}\n")

all_preds = []
all_labels = []
failed_samples = []     # (batch index, error text) for samples whose processing raised
uncaptured_samples = 0  # forward passes during which the hook captured nothing

with torch.no_grad():
    for i, sample in enumerate(dataloader_val):
        if sample is None:
            continue

        try:
            # Batch fields may come back as one-element lists; unwrap them.
            img = sample["image"][0] if isinstance(sample["image"], list) else sample["image"]
            adj = sample["adj_s"][0] if isinstance(sample["adj_s"], list) else sample["adj_s"]
            label = sample["label"][0] if isinstance(sample["label"], list) else sample["label"]

            img = img.unsqueeze(0).float().to(device)
            adj = adj.unsqueeze(0).float().to(device)
            mask = torch.ones(1, img.size(1)).to(device)
            label_tensor = torch.tensor([label], dtype=torch.long).to(device)

            # Reset so a stale capture from the previous sample is never reused.
            logits_captured = None
            model(img, label_tensor, adj, mask)

            if logits_captured is not None:
                pred = logits_captured.argmax(1).item()
                all_preds.append(pred)
                all_labels.append(label)
            else:
                uncaptured_samples += 1

            if (i + 1) % 200 == 0:
                print(f"  {i+1}/{len(dataloader_val)}")
        except Exception as exc:
            # Keep evaluating, but remember the failure instead of hiding it.
            # (A bare except would also swallow KeyboardInterrupt.)
            failed_samples.append((i, f"{type(exc).__name__}: {exc}"))
            continue

print(f"\n✓ Predictions: {len(all_preds)}\n")

if failed_samples:
    first_idx, first_err = failed_samples[0]
    print(f"⚠ {len(failed_samples)} samples raised an error and were skipped "
          f"(first: sample {first_idx}: {first_err})\n")
if uncaptured_samples:
    print(f"⚠ {uncaptured_samples} samples produced no captured logits and were skipped\n")
if not all_preds:
    raise RuntimeError("No predictions were produced; cannot compute evaluation metrics.")

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

acc = np.mean(all_preds == all_labels)
qwk = cohen_kappa_score(all_labels, all_preds, weights='quadratic')

print("="*70)
print("BASELINE RESULTS")
print("="*70)
print(f"Accuracy: {acc:.4f} ({acc*100:.2f}%)")
print(f"\n★★★ BASELINE QWK: {qwk:.4f} ★★★")

cm = confusion_matrix(all_labels, all_preds, labels=[0,1,2])
print(f"\nConfusion Matrix:")
print(f"         Pred0  Pred1  Pred2")
print(f"Actual 0: {cm[0][0]:4d}   {cm[0][1]:4d}   {cm[0][2]:4d}")
print(f"Actual 1: {cm[1][0]:4d}   {cm[1][1]:4d}   {cm[1][2]:4d}")
print(f"Actual 2: {cm[2][0]:4d}   {cm[2][1]:4d}   {cm[2][2]:4d}")

print(f"\n{classification_report(all_labels, all_preds, target_names=['C0','C1','C2'], digits=3, zero_division=0)}")

# Targets are the baseline QWK scaled by +5% and +10%.
print("="*70)
print("GRADE TARGETS")
print("="*70)
print(f"★ BASELINE: {qwk:.4f}")
print(f"  B+ (5%):  {qwk*1.05:.4f}")
print(f"  A (10%):  {qwk*1.10:.4f}")
print("="*70)

Classifier initialized with class weights: tensor([1.13, 1.21, 0.66], device='cuda:0')
GTP BASELINE EVALUATION - WEIGHTED MODEL
Samples: 2123

  200/2123
  400/2123
  600/2123
  800/2123
  1000/2123
  1200/2123
  1400/2123
  1600/2123
  1800/2123
  2000/2123

✓ Predictions: 2123

BASELINE RESULTS
Accuracy: 0.5040 (50.40%)

★★★ BASELINE QWK: 0.0018 ★★★

Confusion Matrix:
         Pred0  Pred1  Pred2
Actual 0:    1      0    567
Actual 1:    0      0    486
Actual 2:    0      0   1069

              precision    recall  f1-score   support

          C0      1.000     0.002     0.004       568
          C1      0.000     0.000     0.000       486
          C2      0.504     1.000     0.670      1069

    accuracy                          0.504      2123
   macro avg      0.501     0.334     0.225      2123
weighted avg      0.521     0.504     0.338      2123

GRADE TARGETS
★ BASELINE: 0.0018
  B+ (5%):  0.0019
  A (10%):  0.0020


In [None]:
import torch
import numpy as np
from sklearn.metrics import cohen_kappa_score, confusion_matrix, classification_report
from torch.utils.data import DataLoader
from utils.dataset import GraphDataset
from helper import collate
from models.GraphTransformer import Classifier
from collections import OrderedDict

# NOTE(review): this cell repeats the previous evaluation cell (In [5]) verbatim
# and has not been executed — consider deleting one of the two.

# Number of target classes; used for the class weights and the classifier head.
N_CLASSES = 3

device = torch.device("cuda")

# Load data
val_file = './scripts/val_set.txt'
with open(val_file, 'r') as f:
    val_ids = [line.strip() for line in f if line.strip()]

# NOTE(review): the data-preparation cell checks slide availability under
# './feature_extractor/graphs_phikon/panda', while graphs are read from
# './feature_extractor/graphs_all' here — confirm both refer to the same features.
dataset_val = GraphDataset('./feature_extractor/graphs_all', val_ids, site='panda')
dataloader_val = DataLoader(dataset_val, batch_size=1, collate_fn=collate, num_workers=0, shuffle=False)

# Calculate class weights
# Only lines with exactly two whitespace-separated fields ("<id> <label>") are used.
train_file = './scripts/train_set.txt'
with open(train_file, 'r') as f:
    train_labels = [int(parts[1]) for parts in (line.split() for line in f) if len(parts) == 2]

if not train_labels:
    raise ValueError(f"No '<id> <label>' lines found in {train_file}")

# minlength keeps one weight per class even when the highest class is absent.
class_counts = np.bincount(train_labels, minlength=N_CLASSES)
if len(class_counts) != N_CLASSES:
    raise ValueError(
        f"{train_file} contains labels outside 0..{N_CLASSES - 1}: counts={class_counts.tolist()}"
    )

# Inverse-frequency weights, rescaled so that they sum to N_CLASSES.
class_weights = 1.0 / (class_counts + 1e-6)
class_weights = class_weights / class_weights.sum() * N_CLASSES
class_weights = torch.FloatTensor(class_weights).to(device)

# Load model WITH class weights
model = Classifier(n_class=N_CLASSES, n_features=512, class_weights=class_weights)
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load('./graph_transformer/saved_models/GraphCAM_PANDA_WEIGHTED_v2.pth', map_location=device)
# Strip only a leading 'module.' (the DataParallel wrapper prefix). A plain
# str.replace would also mangle keys that merely contain that substring,
# e.g. 'submodule.weight' -> 'subweight'.
dp_prefix = 'module.'
new_state_dict = OrderedDict(
    (k[len(dp_prefix):] if k.startswith(dp_prefix) else k, v) for k, v in state_dict.items()
)
model.load_state_dict(new_state_dict)
model = model.to(device).eval()

# Hook to capture logits
logits_captured = None  # copy of the latest matching layer output, or None


def hook_fn(module, input, output):
    """Forward hook that stores a copy of any 2-D output with 3 columns.

    Outputs without a ``shape`` attribute, or with any other shape, leave the
    module-level ``logits_captured`` untouched.
    """
    global logits_captured
    if not hasattr(output, 'shape'):
        return
    dims = output.shape
    if len(dims) != 2 or dims[1] != 3:
        return
    logits_captured = output.clone()

# Register the capture hook on every Linear layer; the last layer whose output
# has 3 columns during a forward pass is what ends up in logits_captured.
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        module.register_forward_hook(hook_fn)

print("="*70)
print("GTP BASELINE EVALUATION - WEIGHTED MODEL")
print("="*70)
print(f"Samples: {len(dataset_val)}\n")

all_preds = []
all_labels = []
failed_samples = []     # (batch index, error text) for samples whose processing raised
uncaptured_samples = 0  # forward passes during which the hook captured nothing

with torch.no_grad():
    for i, sample in enumerate(dataloader_val):
        if sample is None:
            continue

        try:
            # Batch fields may come back as one-element lists; unwrap them.
            img = sample["image"][0] if isinstance(sample["image"], list) else sample["image"]
            adj = sample["adj_s"][0] if isinstance(sample["adj_s"], list) else sample["adj_s"]
            label = sample["label"][0] if isinstance(sample["label"], list) else sample["label"]

            img = img.unsqueeze(0).float().to(device)
            adj = adj.unsqueeze(0).float().to(device)
            mask = torch.ones(1, img.size(1)).to(device)
            label_tensor = torch.tensor([label], dtype=torch.long).to(device)

            # Reset so a stale capture from the previous sample is never reused.
            logits_captured = None
            model(img, label_tensor, adj, mask)

            if logits_captured is not None:
                pred = logits_captured.argmax(1).item()
                all_preds.append(pred)
                all_labels.append(label)
            else:
                uncaptured_samples += 1

            if (i + 1) % 200 == 0:
                print(f"  {i+1}/{len(dataloader_val)}")
        except Exception as exc:
            # Keep evaluating, but remember the failure instead of hiding it.
            # (A bare except would also swallow KeyboardInterrupt.)
            failed_samples.append((i, f"{type(exc).__name__}: {exc}"))
            continue

print(f"\n✓ Predictions: {len(all_preds)}\n")

if failed_samples:
    first_idx, first_err = failed_samples[0]
    print(f"⚠ {len(failed_samples)} samples raised an error and were skipped "
          f"(first: sample {first_idx}: {first_err})\n")
if uncaptured_samples:
    print(f"⚠ {uncaptured_samples} samples produced no captured logits and were skipped\n")
if not all_preds:
    raise RuntimeError("No predictions were produced; cannot compute evaluation metrics.")

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

acc = np.mean(all_preds == all_labels)
qwk = cohen_kappa_score(all_labels, all_preds, weights='quadratic')

print("="*70)
print("BASELINE RESULTS")
print("="*70)
print(f"Accuracy: {acc:.4f} ({acc*100:.2f}%)")
print(f"\n★★★ BASELINE QWK: {qwk:.4f} ★★★")

cm = confusion_matrix(all_labels, all_preds, labels=[0,1,2])
print(f"\nConfusion Matrix:")
print(f"         Pred0  Pred1  Pred2")
print(f"Actual 0: {cm[0][0]:4d}   {cm[0][1]:4d}   {cm[0][2]:4d}")
print(f"Actual 1: {cm[1][0]:4d}   {cm[1][1]:4d}   {cm[1][2]:4d}")
print(f"Actual 2: {cm[2][0]:4d}   {cm[2][1]:4d}   {cm[2][2]:4d}")

print(f"\n{classification_report(all_labels, all_preds, target_names=['C0','C1','C2'], digits=3, zero_division=0)}")

# Targets are the baseline QWK scaled by +5% and +10%.
print("="*70)
print("GRADE TARGETS")
print("="*70)
print(f"★ BASELINE: {qwk:.4f}")
print(f"  B+ (5%):  {qwk*1.05:.4f}")
print(f"  A (10%):  {qwk*1.10:.4f}")
print("="*70)