# Phikon Feature Extraction for PANDA Prostate Cancer Dataset

Extract 768-dimensional pathology-specific features from WSI patches using the Phikon foundation model

---

## Overview

### Background

Our baseline GTP (Graph Transformer) model used **ResNet50** pretrained on **ImageNet** (natural images like cats, dogs, cars) to extract features from histopathology patches. While this achieved strong results (QWK: 0.7568, AUPRC: 0.8129), ImageNet features are not optimized for medical tissue analysis.

### Motivation

**Phikon** is a Vision Transformer (ViT-Base) pretrained on **40+ million histopathology images** - the same type of H&E-stained tissue images we're analyzing. This domain-specific pretraining should provide features that better capture:

- Cellular morphology patterns
- Gland architecture
- Tissue texture specific to pathology
- Cancer-relevant visual features

### What We're Doing

1. **Replace ImageNet features with pathology-specific features**
   - ResNet50 (ImageNet): 2048-dim general features
   - Phikon (Pathology): 768-dim domain-specific features

2. **Extract features from all WSI patches**
   - ~8,500 Whole Slide Images
   - ~150 patches per WSI (average)
   - ~1.2 million total patches

3. **Build spatial graphs** (same as baseline)
   - 8-connectivity adjacency
   - Preserve tissue architecture

4. **Train GTP with new features**
   - Same architecture, different input features
   - Compare with ResNet50 baseline

### Expected Outcome (Yet to run - 12/6)

| Model | Feature Dim | QWK | AUPRC |
|-------|-------------|-----|-------|
| ResNet50 (ImageNet) | 2048 | 0.7568 | 0.8129 |
| Phikon (Pathology) | 768 | 0.78-0.82 | 0.84-0.87 |

**Expected improvement: +3-8% QWK**

### Why This Matters

Using a pathology foundation model demonstrates:
- Understanding of domain-specific transfer learning
- Awareness of recent advances in medical AI
- Ability to improve upon standard baselines
- Satisfies professor's requirement to try an H&E foundation model

In [2]:
#!/usr/bin/env python3
"""
Extract Phikon Features - TRAINING + VALIDATION TILES
"""
import torch
from transformers import ViTModel, ViTImageProcessor
from PIL import Image
import os
import glob
from tqdm import tqdm

print("="*70)
print("EXTRACTING PHIKON FEATURES FROM WSI PATCHES")
print("="*70)
print()

# ============================================================
# CONFIGURATION
# ============================================================
TILES_DIRS = [
    # Training tiles
    "/projectnb/ec500kb/projects/Project_1_Team_1/PANDA_DATA_MANNY/tiles_01",
    "/projectnb/ec500kb/projects/Project_1_Team_1/PANDA_DATA_MANNY/tiles_02",
    "/projectnb/ec500kb/projects/Project_1_Team_1/PANDA_DATA_MANNY/tiles_03",
    "/projectnb/ec500kb/projects/Project_1_Team_1/PANDA_DATA_MANNY/tiles_04",
    "/projectnb/ec500kb/projects/Project_1_Team_1/PANDA_DATA_MANNY/tiles_05",
    "/projectnb/ec500kb/projects/Project_1_Team_1/PANDA_DATA_MANNY/tiles_06",
    "/projectnb/ec500kb/projects/Project_1_Team_1/PANDA_DATA_MANNY/tiles_07",
    "/projectnb/ec500kb/projects/Project_1_Team_1/PANDA_DATA_MANNY/tiles_08",
    "/projectnb/ec500kb/projects/Project_1_Team_1/PANDA_DATA_MANNY/tiles_09",
    "/projectnb/ec500kb/projects/Project_1_Team_1/PANDA_DATA_MANNY/tiles_10",
    # Validation tiles (single directory)
    "/projectnb/ec500kb/projects/Project_1_Team_1/PANDA_DATA_MANNY/val_tiles",
]

OUTPUT_DIR = "/projectnb/ec500kb/projects/Project_1_Team_1/Official_GTP_PANDAS/feature_extractor/graphs_phikon/panda"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================================================
# LOAD PHIKON
# ============================================================
print("Loading Phikon model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = ViTImageProcessor.from_pretrained("owkin/phikon")
model = ViTModel.from_pretrained("owkin/phikon")
model = model.to(device)
model.eval()

for param in model.parameters():
    param.requires_grad = False

print(f"✓ Phikon loaded on {device}")
print(f"  Feature dimension: 768")
print()

# ============================================================
# FIND ALL WSI DIRECTORIES
# ============================================================
print("Finding all WSI directories...")

all_wsi_dirs = []
train_count = 0
val_count = 0

for tiles_dir in TILES_DIRS:
    if not os.path.exists(tiles_dir):
        print(f"  ⚠ Skipping {tiles_dir} (not found)")
        continue
    
    # Pattern: tiles_XX/wsi_id/
    wsi_dirs = [d for d in glob.glob(f"{tiles_dir}/*/") if os.path.isdir(d)]
    all_wsi_dirs.extend(wsi_dirs)
    
    dir_name = os.path.basename(tiles_dir)
    if "val" in dir_name:
        val_count += len(wsi_dirs)
    else:
        train_count += len(wsi_dirs)
    
    print(f"  {dir_name}: {len(wsi_dirs)} WSIs")

print(f"\n✓ Total WSIs found: {len(all_wsi_dirs)}")
print(f"  Training: {train_count}")
print(f"  Validation: {val_count}")
print()

# ============================================================
# EXTRACT FEATURES
# ============================================================
print("="*70)
print("Starting feature extraction...")
print("="*70)
print()

processed = 0
skipped = 0
errors = 0

for wsi_dir in tqdm(all_wsi_dirs, desc="Extracting features"):
    wsi_id = os.path.basename(wsi_dir.rstrip('/'))
    
    # Output directory for this WSI
    wsi_output_dir = os.path.join(OUTPUT_DIR, wsi_id)
    
    # Skip if already processed
    if os.path.exists(os.path.join(wsi_output_dir, 'features.pt')):
        skipped += 1
        continue
    
    # Get all patches (JPEG files!)
    patches = glob.glob(f"{wsi_dir}/*.jpeg") + glob.glob(f"{wsi_dir}/*.jpg") + glob.glob(f"{wsi_dir}/*.JPEG") + glob.glob(f"{wsi_dir}/*.JPG")
    
    if len(patches) == 0:
        continue
    
    try:
        features_list = []
        coords_list = []
        
        for patch_path in patches:
            try:
                # Get coordinates from filename (format: row_col.jpeg)
                filename = os.path.basename(patch_path)
                name_parts = filename.replace('.jpeg', '').replace('.jpg', '').replace('.JPEG', '').replace('.JPG', '').split('_')
                
                if len(name_parts) >= 2:
                    row, col = int(name_parts[0]), int(name_parts[1])
                else:
                    continue
                
                # Load and process image
                img = Image.open(patch_path).convert('RGB')
                inputs = processor(images=img, return_tensors="pt").to(device)
                
                # Extract features
                with torch.no_grad():
                    outputs = model(**inputs)
                    features = outputs.last_hidden_state[:, 0]  # CLS token (768-dim)
                
                features_list.append(features.cpu())
                coords_list.append([row, col])
                
            except Exception as e:
                continue
        
        # Build adjacency matrix and save
        if len(features_list) > 0:
            os.makedirs(wsi_output_dir, exist_ok=True)
            
            features_tensor = torch.cat(features_list, dim=0)  # [N, 768]
            N = features_tensor.shape[0]
            
            # Build adjacency matrix (8-connectivity)
            adj = torch.zeros(N, N)
            for i in range(N):
                for j in range(i+1, N):
                    row_diff = abs(coords_list[i][0] - coords_list[j][0])
                    col_diff = abs(coords_list[i][1] - coords_list[j][1])
                    if row_diff <= 1 and col_diff <= 1:
                        adj[i, j] = 1
                        adj[j, i] = 1
            
            # Save
            torch.save(features_tensor, os.path.join(wsi_output_dir, 'features.pt'))
            torch.save(adj, os.path.join(wsi_output_dir, 'adj_s.pt'))
            
            with open(os.path.join(wsi_output_dir, 'c_idx.txt'), 'w') as f:
                for coord in coords_list:
                    f.write(f"{coord[0]}\t{coord[1]}\n")
            
            processed += 1
            
    except Exception as e:
        errors += 1
        continue

# ============================================================
# SUMMARY
# ============================================================
print()
print("="*70)
print("FEATURE EXTRACTION COMPLETE")
print("="*70)
print()
print(f"Total WSIs: {len(all_wsi_dirs)}")
print(f"Processed: {processed}")
print(f"Skipped (already done): {skipped}")
print(f"Errors: {errors}")
print()
print(f"Output saved to: {OUTPUT_DIR}")
print(f"Feature dimension: 768")
print()

# Verify
output_count = len(glob.glob(f"{OUTPUT_DIR}/*/features.pt"))
print(f"Verified: {output_count} graphs created")
print()
print("Next step: Train GTP with --n_features 768")
print("="*70)

EXTRACTING PHIKON FEATURES FROM WSI PATCHES

Loading Phikon model...
✓ Phikon loaded on cuda
  Feature dimension: 768

Finding all WSI directories...
  tiles_01: 850 WSIs
  tiles_02: 850 WSIs
  tiles_03: 850 WSIs
  tiles_04: 850 WSIs
  tiles_05: 850 WSIs
  tiles_06: 850 WSIs
  tiles_07: 850 WSIs
  tiles_08: 850 WSIs
  tiles_09: 850 WSIs
  tiles_10: 842 WSIs
  val_tiles: 2124 WSIs

✓ Total WSIs found: 10616
  Training: 8492
  Validation: 2124

Starting feature extraction...



Extracting features: 100%|██████████| 10616/10616 [1:47:52<00:00,  1.64it/s]



FEATURE EXTRACTION COMPLETE

Total WSIs: 10616
Processed: 2123
Skipped (already done): 8492
Errors: 0

Output saved to: /projectnb/ec500kb/projects/Project_1_Team_1/Official_GTP_PANDAS/feature_extractor/graphs_phikon/panda
Feature dimension: 768

Verified: 10615 graphs created

Next step: Train GTP with --n_features 768


In [2]:
#!/usr/bin/env python3
"""
Build Graphs from Phikon Features
"""
import torch
import os
import glob
from tqdm import tqdm
import numpy as np

print("="*70)
print("BUILDING GRAPHS FROM PHIKON FEATURES")
print("="*70)
print()

# ============================================================
# CONFIGURATION
# ============================================================
FEATURES_DIR = "/projectnb/ec500kb/projects/Project_1_Team_1/Official_GTP_PANDAS/feature_extractor/phikon_features"
OUTPUT_DIR = "/projectnb/ec500kb/projects/Project_1_Team_1/Official_GTP_PANDAS/feature_extractor/graphs_phikon/panda"
os.makedirs(OUTPUT_DIR, exist_ok=True)

CONNECTIVITY = 8  # 8-connectivity (patches within 1 grid position)

# ============================================================
# BUILD GRAPHS
# ============================================================
print("Building graphs with 8-connectivity...")
print()

feature_files = glob.glob(f"{FEATURES_DIR}/*.pt")
print(f"Found {len(feature_files)} WSIs with Phikon features")
print()

processed = 0
for feature_file in tqdm(feature_files):
    wsi_id = os.path.basename(feature_file).replace('_phikon.pt', '')
    
    # Output directory for this WSI
    wsi_output_dir = os.path.join(OUTPUT_DIR, wsi_id)
    os.makedirs(wsi_output_dir, exist_ok=True)
    
    # Skip if already processed
    if os.path.exists(os.path.join(wsi_output_dir, 'features.pt')):
        continue
    
    # Load features and coordinates
    data = torch.load(feature_file)
    features = data['features']  # Shape: [N, 768]
    coords = data['coords']  # Shape: [N, 2]
    
    N = features.shape[0]
    
    # Build adjacency matrix (8-connectivity)
    adj = torch.zeros(N, N)
    
    for i in range(N):
        for j in range(i+1, N):
            # Check if within 1 grid position (8-connectivity)
            row_diff = abs(coords[i, 0] - coords[j, 0])
            col_diff = abs(coords[i, 1] - coords[j, 1])
            
            if row_diff <= 1 and col_diff <= 1:
                adj[i, j] = 1
                adj[j, i] = 1
    
    # Save graph
    torch.save(features, os.path.join(wsi_output_dir, 'features.pt'))
    torch.save(adj, os.path.join(wsi_output_dir, 'adj_s.pt'))
    
    # Save coordinates as text
    with open(os.path.join(wsi_output_dir, 'c_idx.txt'), 'w') as f:
        for coord in coords:
            f.write(f"{coord[0]} {coord[1]}\n")
    
    processed += 1

print()
print("="*70)
print("GRAPH CONSTRUCTION COMPLETE")
print("="*70)
print()
print(f"Processed: {processed} WSIs")
print(f"Graphs saved to: {OUTPUT_DIR}")
print(f"Feature dimension: 768")
print()
print("Next step: Train GTP with --n_features 768")
print("="*70)

BUILDING GRAPHS FROM PHIKON FEATURES

Building graphs with 8-connectivity...

Found 0 WSIs with Phikon features



0it [00:00, ?it/s]


GRAPH CONSTRUCTION COMPLETE

Processed: 0 WSIs
Graphs saved to: /projectnb/ec500kb/projects/Project_1_Team_1/Official_GTP_PANDAS/feature_extractor/graphs_phikon/panda
Feature dimension: 768

Next step: Train GTP with --n_features 768



