### 2. Define Prototype Features

***
#### PANTHER: Global Prototype Generation

This section executes the **PANTHER** workflow to extract representative tissue patterns (prototypes) from the CPTAC dataset.

##### 1. Configuration & Setup
* **Source**: Appends PANTHER scripts to `sys.path` and configures CPTAC feature paths.
* **Hyperparameters**: Generates **16 prototypes** (`n_proto`) using **K-means** clustering (`n_init=5`, `n_iter=50`) on up to **10,000 patches per prototype** (`n_proto_patches`, i.e. at most 160,000 patches in total) in a **1536-d** feature space.
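
As a quick check on the sampling budget, the cell output in this section reports a cap of 160,000 patches, which equals `n_proto × n_proto_patches`. The sketch below reproduces that arithmetic. The ceiling-division rule for the per-slide quota is an assumption inferred from the logged "97 each from 1662", not taken from the PANTHER source.

```python
# Sampling-budget arithmetic, using values from the configuration and the log.
n_proto = 16              # number of prototypes
n_proto_patches = 10_000  # patches sampled per prototype
n_slides = 1662           # training slides reported by the dataloader

total_patches = n_proto * n_proto_patches

# Assumption: the per-slide quota is a ceiling division (matches the logged value).
patches_per_slide = -(-total_patches // n_slides)

print(total_patches, patches_per_slide)
```

Because 97 × 1662 exceeds 160,000, the sampler presumably truncates to the cap, which would explain why the log reports exactly 160,000 aggregated patches.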

##### 2. Data Loading
* **Dataset**: Uses `WSIProtoDataset` to load UNI features from LSCC and LUAD H5 files based on pre-defined `train` splits.

##### 3. Execution & Saving
* **`cluster()`**: Samples patches across the cohort and performs K-means to find 16 centroids (tissue patterns).
* **Storage**: Saves the centroids as a `.pkl` file to be used as a reference for downstream analysis (e.g., local prototype counting or graph building).
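
To make the clustering step concrete, here is a minimal NumPy sketch of plain Lloyd's K-means on synthetic features. It is not the implementation behind PANTHER's `cluster()`. The function name `kmeans_prototypes`, the 32-d toy features, and the random initialisation are illustrative assumptions.

```python
import numpy as np

def kmeans_prototypes(feats, n_proto, n_iter=50, seed=1):
    """Plain Lloyd's K-means; returns an (n_proto, dim) array of centroids."""
    rng = np.random.default_rng(seed)
    # Initialise centroids from randomly chosen patches.
    centroids = feats[rng.choice(len(feats), size=n_proto, replace=False)].copy()
    for _ in range(n_iter):
        # Squared Euclidean distance from every patch to every centroid.
        dists = ((feats[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        for k in range(n_proto):
            members = feats[labels == k]
            if len(members) > 0:  # keep the old centroid if a cluster empties
                centroids[k] = members.mean(axis=0)
    return centroids

# Synthetic stand-in for sampled patch features (the real ones are 1536-d).
rng = np.random.default_rng(0)
toy_feats = rng.normal(size=(2000, 32)).astype(np.float32)
prototypes = kmeans_prototypes(toy_feats, n_proto=16)
print(prototypes.shape)  # (16, 32)
```

Each row of `prototypes` plays the role of one tissue-pattern centroid. The real run additionally restarts K-means several times (`n_init`), which this sketch omits.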

In [None]:
BASE_DIR = '/workspace/HDDX/Pathology_Graph'
split_n = 'split_0'

import sys
sys.path.append(f'{BASE_DIR}/github/PANTHER/src')

import argparse
import torch
from torch.utils.data import DataLoader
from wsi_datasets import WSIProtoDataset
from utils.utils import seed_torch, read_splits
from utils.file_utils import save_pkl
from utils.proto_utils import cluster

import os
from os.path import join as j_

def build_datasets(csv_splits, batch_size=1, num_workers=2, train_kwargs=None):
    """Wrap each split's DataFrame in a WSIProtoDataset and a DataLoader."""
    dataset_splits = {}
    for k, df in csv_splits.items():  # here: only 'train'
        # Avoid a mutable default argument; fall back to no extra kwargs
        dataset = WSIProtoDataset(df, **(train_kwargs or {}))

        # batch_size defaults to 1, i.e. one slide per batch
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
        dataset_splits[k] = dataloader
        print(f'split: {k}, n: {len(dataset)}')

    return dataset_splits

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1)
# model / loss fn args ###
parser.add_argument('--n_proto', type=int, default=16)
parser.add_argument('--n_proto_patches', type=int, default=10000)
parser.add_argument('--n_init', type=int, default=5)
parser.add_argument('--n_iter', type=int, default=50)
parser.add_argument('--in_dim', type=int, default=1536)
parser.add_argument('--mode', type=str, choices=['kmeans', 'faiss'], default='kmeans')

# dataset / split args ###
parser.add_argument('--data_source', type=str, default=None)
parser.add_argument('--split_dir', type=str, default=f'./splits/CPTAC/{split_n}/')
parser.add_argument('--split_names', type=str, default='train')
parser.add_argument('--num_workers', type=int, default=0)
args = parser.parse_args(args=[])

args.data_source = [f'{BASE_DIR}/datasource/CPTAC/LSCC_CLAM/patch_512/uni_features/feats_h5',
                    f'{BASE_DIR}/datasource/CPTAC/LUAD_CLAM/patch_512/uni_features/feats_h5']

# data loading
train_kwargs = dict(data_source=args.data_source)
seed_torch(args.seed)
csv_splits = read_splits(args)
print('\nsuccessfully read splits for: ', list(csv_splits.keys()))

# run Panther prototyping
dataset_splits = build_datasets(csv_splits,
                                batch_size=1,
                                num_workers=args.num_workers,
                                train_kwargs=train_kwargs)
print('\nInit Datasets...', end=' ')


os.makedirs(j_(args.split_dir, 'prototypes'), exist_ok=True)
loader_train = dataset_splits['train']

tmp = next(iter(loader_train))

_, weights = cluster(loader_train,
                     n_proto=args.n_proto,
                     n_iter=args.n_iter,
                     n_init=args.n_init,
                     feature_dim=args.in_dim,
                     mode=args.mode,
                     n_proto_patches=args.n_proto_patches,
                     use_cuda=torch.cuda.is_available())


save_fpath = j_(args.split_dir,
                'prototypes',
                f"prototypes_c{args.n_proto}_{args.data_source[0].split('/')[-2]}_{args.mode}_num_{args.n_proto_patches:.1e}.pkl")

save_pkl(save_fpath, {'prototypes': weights})


Using the following split names: ['train']

successfully read splits for:  ['train']
split: train, n: 1662

Init Datasets... Sampling maximum of 160000 patches: 97 each from 1662


100%|██████████| 1662/1662 [05:33<00:00,  4.98it/s]



Total of 160000 patches aggregated

Using Kmeans for clustering...

	Num of clusters 16, num of iter 50

Clustering took 289.72745847702026 seconds!


***
#### Feature Quantization: Bag-of-Prototypes Encoding

This script quantizes each Whole Slide Image (WSI) by assigning every patch feature to the most similar of the previously generated global prototypes.

##### 1. Resource Preparation
* **Prototypes**: Loads the 16 prototype centroids (each 1536-d) from the `.pkl` file and casts them to `float32` on the selected device (GPU if available, otherwise CPU).
* **Target Data**: Targets `.h5` files containing patch-level embeddings (e.g., UNI features).

##### 2. Similarity Mapping & Assignment
* **Normalization**: Applies L2-normalization to both patch features and prototypes to calculate **cosine similarity** via matrix multiplication (`torch.mm`).
* **Assignment**: Maps each patch to its most similar prototype using `torch.argmax`.
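
The normalization and assignment steps can be illustrated with a small NumPy sketch on toy 2-d vectors. The function name `assign_to_prototypes` and the `eps` guard against zero-norm rows are assumptions for illustration; the notebook itself uses `torch.nn.functional.normalize`, `torch.mm`, and `torch.argmax`.

```python
import numpy as np

def assign_to_prototypes(feats, protos, eps=1e-12):
    """Return, for each row of feats, the index of the most cosine-similar prototype."""
    # L2-normalize rows so the dot product equals cosine similarity.
    feats_n = feats / np.maximum(np.linalg.norm(feats, axis=1, keepdims=True), eps)
    protos_n = protos / np.maximum(np.linalg.norm(protos, axis=1, keepdims=True), eps)
    sim = feats_n @ protos_n.T  # (n_patches, n_protos) cosine similarities
    return sim.argmax(axis=1)

protos = np.array([[1.0, 0.0], [0.0, 1.0]])
feats = np.array([[3.0, 0.1], [0.2, 5.0], [10.0, 9.0]])
labels = assign_to_prototypes(feats, protos)
print(labels)  # [0 1 0]
```

The third patch has a large magnitude but is still assigned by direction alone, which is the practical effect of using cosine similarity rather than Euclidean distance.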

##### 3. Counting & Serialization
* **Aggregation**: Uses `torch.bincount` with `minlength=16` to count how many patches are assigned to each prototype (indices 0–15) within the slide, creating a **"Bag-of-Prototypes"** vector; prototypes that never occur get a count of zero.
* **Output**: Saves the resulting 16-dimensional count tensor (raw counts, not normalized) as a `.pt` file, which serves as a simplified explanatory representation (`expl`) of the slide's morphology.
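
The role of `minlength` is easiest to see on a toy example. The sketch below uses `np.bincount`, which takes the same `minlength` argument as `torch.bincount`; the label values are hypothetical.

```python
import numpy as np

NUM_PROTOTYPES = 16

# Hypothetical per-patch assignments for one small slide.
toy_labels = np.array([0, 3, 3, 15, 3, 0])
counts = np.bincount(toy_labels, minlength=NUM_PROTOTYPES).astype(np.float32)

# Without minlength, the output length depends on the largest label seen.
short_counts = np.bincount(np.array([0, 3, 3]))

print(counts.shape, short_counts.shape)  # (16,) (4,)
```

Fixing the length at 16 keeps every slide's vector aligned to the same prototype indices, even when a slide contains no patches for the highest-numbered prototypes.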

In [None]:
import h5py
import pickle
import os
import torch
from glob import glob
from tqdm import tqdm

# Settings
BASE_DIR = '/workspace/HDDX/Pathology_Graph'
split_n = 'split_0'
NUM_PROTOTYPES = 16 

H5_FILES = sorted(glob(f'{BASE_DIR}/datasource/CPTAC/*_CLAM/patch_512/uni_features/feats_h5/*.h5'))
PROTO_PATH = f'./splits/CPTAC/{split_n}/prototypes/prototypes_c16_uni_features_kmeans_num_1.0e+04.pkl'
SAVE_DIR = f'./splits/CPTAC/{split_n}/expl_16x16'

os.makedirs(SAVE_DIR, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Load prototypes
with open(PROTO_PATH, 'rb') as f:
    proto_data = pickle.load(f)
proto_tensor = torch.from_numpy(proto_data['prototypes'].squeeze()).to(device).float()

for h5_path in tqdm(H5_FILES):
    slide_id = os.path.basename(h5_path).replace('.h5', '')
    
    try:
        with h5py.File(h5_path, 'r') as h5:
            feats = torch.from_numpy(h5['features'][:]).to(device).float()
    
        # Both tensors are already float32; L2-normalize rows so that the
        # matrix product below yields cosine similarity
        norm_feats = torch.nn.functional.normalize(feats, dim=1)
        norm_protos = torch.nn.functional.normalize(proto_tensor, dim=1)
        
        # calculate cosine similarity
        sim = torch.mm(norm_feats, norm_protos.t()) 
        
        # assign each patch to the most similar prototype
        global_cluster_labels = torch.argmax(sim, dim=1)
        count_expl = torch.bincount(global_cluster_labels, minlength=NUM_PROTOTYPES).float()
        
        torch.save(count_expl.cpu(), f'{SAVE_DIR}/{slide_id}_expl.pt')
        
    except Exception as e:
        print(f"Error processing {slide_id}: {e}")
        continue


100%|██████████| 2162/2162 [02:05<00:00, 17.22it/s] 
