In [1]:
import sys, os

sys.path.append(os.path.abspath('..'))


import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data._utils.collate import default_collate
import importlib
import src.feature_extraction.slide_level_features.slide_level_features  as slf
importlib.reload(slf)

from src.feature_extraction.slide_level_features.slide_level_features import calculate_basic_stats, full_slide_features, get_graph_feats, extract_distance_feats
from src.PROTAS_model.model import MLPClassifier


# Synthetic Data Folder

In [2]:
data_root = './synthetic_data'
cancer_masks = os.path.join(data_root, 'cancer_masks')
tissue_masks = os.path.join(data_root, 'tissue_masks')
synthetic_test_data = pd.read_csv(os.path.join(data_root, 'synthetic_test_data.csv'), index_col = 0)
synthetic_clinical = pd.read_csv(os.path.join(data_root, 'synthetic_clinical_data.csv'), index_col = 0)
demo_checkpoint = './model_weights.pth'

In [3]:
synthetic_test_data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 45884 entries, 0 to 45883
Data columns (total 7 columns):
 #   Column                   Non-Null Count  Dtype 
---  ------                   --------------  ----- 
 0   slide_name               45884 non-null  object
 1   x                        45884 non-null  int64 
 2   y                        45884 non-null  int64 
 3   in_cancer                45884 non-null  int64 
 4   distance_from_cancer_mm  45884 non-null  int64 
 5   stroma_label             45884 non-null  int64 
 6   feature_path             45884 non-null  object
dtypes: int64(5), object(2)
memory usage: 2.8+ MB


In [4]:
synthetic_clinical.info()

<class 'pandas.core.frame.DataFrame'>
Index: 5 entries, 0 to 4
Data columns (total 7 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   patient_id         5 non-null      object 
 1   slide_name         5 non-null      object 
 2   age                5 non-null      float64
 3   gleason_score      5 non-null      int64  
 4   gleason_primary    5 non-null      int64  
 5   gleason_secondary  5 non-null      int64  
 6   grade_group        5 non-null      int64  
dtypes: float64(1), int64(4), object(2)
memory usage: 320.0+ bytes


# Load model (synthetic model in this case)

### Initalize model

In [5]:
in_channels = 1024
l1 = 128
l2 = 64
dropout_prob = 0.25
forward_passes = 100
batch_size = 64
gpu_int = 0
device = torch.device(f'cuda:{gpu_int}')

model = MLPClassifier(
    in_channels = in_channels, 
    layers = [l1, l2], 
    out_channels = 1,
    dropout_prob = dropout_prob,
    initalize_weight = False)

def enable_dropout(model):
    for m in model.modules():
        if isinstance(m, torch.nn.Dropout):
            m.train()


### Load checkpoint

In [6]:
checkpoint = torch.load(demo_checkpoint)
print(checkpoint.keys())
model.load_state_dict(checkpoint['model_state_dict'])
model.cuda(device)

dict_keys(['model_state_dict', 'model_config', 'seed'])


MLPClassifier(
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.25, inplace=False)
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.25, inplace=False)
    (8): Linear(in_features=64, out_features=1, bias=True)
  )
)

### Set up dataset

In [7]:
class TestDataset(Dataset):
    def __init__(self,
        df,
        uni_feature_path,
        rank = 0 # for ddp
        ):
        self.df = df
        self.uni_feature_path = uni_feature_path
        

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        curr_row = self.df.loc[index]
        x, y = curr_row['x'], curr_row['y']
        slide_name = curr_row['slide_name']
        feat_path = os.path.join(self.uni_feature_path, slide_name, f'{x}_{y}.npy')
        try:
            feats = np.load(feat_path)
        except:
            return None, None
        try:
            label = curr_row['stroma_label']
        except:
            label = curr_row['label']
        return feats, label

In [8]:
test_dataset = TestDataset(
    df = synthetic_test_data,
    uni_feature_path = './synthetic_data/uni_features'
    )

In [9]:
ex = next(iter(test_dataset))
ex[0].shape, ex[1]

((1024,), 0)

In [10]:
def custom_collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    if len(batch) == 0:
        return []
    return default_collate(batch)

test_loader = DataLoader(
        test_dataset,
        batch_size = batch_size,
        shuffle = False,
        num_workers = 2,
        pin_memory = True,
        collate_fn = custom_collate_fn
    )

### Run MCD

In [11]:
def test(model, test_loader):
    
    correct = 0
    total = 0

    n_samples = len(test_loader.dataset)
    dropout_predictions = np.empty((0, n_samples, 1))


    for i in tqdm(range(forward_passes)):
        predictions = np.empty((0, 1))
        model.eval()
        enable_dropout(model)

        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                
                inputs = inputs.to(device)
                outputs = model(inputs)

                probs = torch.sigmoid(outputs)
                probs = probs.cpu().numpy()

                for pred in probs:
                    predictions = np.vstack((predictions, pred))
            
        dropout_predictions = np.vstack((
            dropout_predictions, predictions[np.newaxis, :, :]
        ))


    return dropout_predictions

In [12]:
dropout_predictions = test(model, test_loader)
dropout_predictions.shape

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [20:48<00:00, 12.49s/it]


(100, 45884, 1)

In [13]:
synthetic_test_data.shape

(45884, 7)

# Get confident predictions

In [14]:
mean_prediction = dropout_predictions.mean(axis = 0)
uncertainty = dropout_predictions.var(axis = 0)
synthetic_test_data['mean_pred'] = mean_prediction
synthetic_test_data['uncertainty'] = uncertainty

In [15]:
percentile_75 = np.percentile(uncertainty, 75)
confident_mask = uncertainty <= percentile_75

synthetic_test_data['confident'] = confident_mask.astype(int)

In [16]:
synthetic_test_data['confident'].value_counts()

confident
1    34413
0    11471
Name: count, dtype: int64

# Extract slide-level features

### subset to confident features OUTSIDE of tumor boundary

In [17]:
optimal_cutoff = 0.5
synthetic_test_data['rs_pred'] = (synthetic_test_data['mean_pred'] > optimal_cutoff).astype(int)
model_preds_no_cancer = synthetic_test_data[synthetic_test_data['in_cancer'] == 0]
confident_set = model_preds_no_cancer[model_preds_no_cancer['confident'] == 1]

In [18]:
model_preds_no_cancer.shape, confident_set.shape

((41066, 11), (30837, 11))

In [19]:
slide_names = synthetic_test_data['slide_name'].unique()
len(slide_names)

5

In [20]:
basic_stats_all = []

connected_comp_size = 3
hotspot_feats_all = []

centroids_dict = {}
region_probs_dict = {}
region_patch_counts_dict = {}

distance_feats_all = []

all_topo_feats = []



for slide_name in tqdm(slide_names):
    current_slide = confident_set[confident_set['slide_name'] == slide_name].copy()
    current_all_preds = model_preds_no_cancer[model_preds_no_cancer['slide_name'] == slide_name].copy()
    
    ## BASIC FEATURES
    basic_stats_feats = calculate_basic_stats(current_slide, current_all_preds)
    basic_stats_feats['slide_name'] = slide_name
    basic_stats_all.append(basic_stats_feats)
    
    
    ## HOTSPOT FEATURES
    features, region_patch_counts, region_probs, centroids = full_slide_features(
        current_slide, len(current_all_preds), connected_comp_size)
    features['slide_name'] = slide_name
    hotspot_feats_all.append(features)
    centroids_dict[slide_name] = centroids
    region_probs_dict[slide_name] = region_probs
    region_patch_counts_dict[slide_name] = region_patch_counts
    
    ## DISTANCE FEATURES
    distance_feats = extract_distance_feats(confident_set, distance_col = 'distance_from_cancer_mm')
    distance_feats['slide_name'] = slide_name
    distance_feats_all.append(distance_feats)
    
    ## GRAPH FEATURES
    topo_feats = get_graph_feats(confident_set, centroids_dict[slide_name])
    topo_feats['slide_name'] = slide_name
    all_topo_feats.append(topo_feats)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [02:50<00:00, 34.17s/it]


In [21]:
basic_stats_all = pd.DataFrame(basic_stats_all)
hotspot_feats_all = pd.DataFrame(hotspot_feats_all)
distance_feats_all = pd.DataFrame(distance_feats_all)
all_topo_feats = pd.DataFrame(all_topo_feats)

all_feats = basic_stats_all.merge(hotspot_feats_all, on = 'slide_name', how = 'outer')
all_feats = all_feats.merge(distance_feats_all, on = 'slide_name', how = 'outer')
all_feats = all_feats.merge(all_topo_feats, on = 'slide_name', how = 'outer')


all_feats = all_feats.merge(synthetic_clinical, on = 'slide_name', how = 'left')
all_feats.to_csv(os.path.join(data_root, 'all_features.csv'))



## All Features

In [23]:
all_feats

Unnamed: 0,mean_pos_prob,mean_neg_prob,mean_prob,median_pos_prob,median_neg_prob,median_prob,std_pos_prob,std_neg_prob,std_prob,entropy_pos_prob,...,mean_betweenness,max_betweenness,std_betweenness,high_betweenness_count,patient_id,age,gleason_score,gleason_primary,gleason_secondary,grade_group
0,0.826744,0.314419,0.768338,0.843264,0.321926,0.825559,0.106192,0.107062,0.194456,8.123558,...,0.010261,0.050263,0.009009,0,synthetic_p0,59.723587,6,3,3,1
1,0.829388,0.312026,0.771129,0.846205,0.31204,0.82915,0.105026,0.107842,0.194546,8.39231,...,0.010664,0.055323,0.00996,0,synthetic_p1,70.42174,7,3,4,2
2,0.829891,0.314632,0.774334,0.845004,0.313073,0.828499,0.103515,0.108041,0.190683,8.93208,...,9.2e-05,0.00242,0.000414,0,synthetic_p2,66.509132,7,4,3,3
3,0.829491,0.312457,0.773442,0.84565,0.31649,0.829632,0.104211,0.107802,0.191787,8.787983,...,0.020861,0.224166,0.031366,18,synthetic_p3,67.744723,6,3,3,1
4,0.827741,0.313178,0.77065,0.843201,0.310456,0.826961,0.104772,0.10803,0.192806,8.617762,...,0.009284,0.0552,0.008078,0,synthetic_p4,65.092791,8,5,3,4
