In [3]:
from math import e
import numpy as np
from sklearn.metrics.pairwise import cosine_distances

def robustness_index(embeddings, biological_classes, medical_centers, k=50):
    """
    Calculate the robustness index R_k for a given dataset.

    For every sample and every neighbourhood size j = 1..k, the j nearest
    neighbours (cosine distance, the sample itself excluded) are inspected and
    two counts are accumulated over all samples:

    - numerator[j-1]:   neighbours sharing the sample's biological class
    - denominator[j-1]: neighbours sharing the sample's medical center

    The index is mean(numerator) / mean(denominator).

    Parameters:
    - embeddings: array of shape (n_samples, n_features), the embeddings of the samples.
    - biological_classes: array-like of shape (n_samples,), the biological class label of each sample.
    - medical_centers: array-like of shape (n_samples,), the medical center label of each sample.
    - k: int, the MAXIMUM number of nearest neighbors to consider (default is 50).
      If k exceeds n_samples - 1, the larger neighbourhoods simply contain all
      other samples.

    Returns:
    - R_k: float, the robustness index, or 0 when no neighbour shares a medical
      center with its query sample (the ratio would be undefined).

    Raises:
    - ValueError: if k is negative or the label arrays do not have n_samples entries.
    """
    # Accept lists / pandas Series as well as numpy arrays for the labels.
    biological_classes = np.asarray(biological_classes)
    medical_centers = np.asarray(medical_centers)
    n_samples = embeddings.shape[0]

    if k < 0:
        raise ValueError("k must be non-negative")
    if len(biological_classes) != n_samples or len(medical_centers) != n_samples:
        raise ValueError(
            "biological_classes and medical_centers must have one label per embedding"
        )

    # Compute pairwise cosine distances
    distances = cosine_distances(embeddings)

    # A sample has at most n_samples - 1 neighbours other than itself.
    n_neighbors = max(0, min(k, n_samples - 1))

    # Initialize counters for numerator and denominator
    numerator = np.zeros(k)
    denominator = np.zeros(k)

    for i in range(n_samples):
        # Sort once per sample (not once per neighbourhood size). A stable sort
        # makes tie-breaking deterministic.
        order = np.argsort(distances[i], kind="stable")

        # Remove the sample itself by index. Dropping position 0 is wrong when
        # another sample is at distance 0 (duplicate embeddings), because the
        # sample is then not guaranteed to be sorted first.
        neighbors = order[order != i][:n_neighbors]
        if neighbors.size == 0:
            continue

        # Entry j-1 of a cumulative sum is the number of matches among the
        # j nearest neighbours.
        bio_hits = np.cumsum(biological_classes[neighbors] == biological_classes[i])
        center_hits = np.cumsum(medical_centers[neighbors] == medical_centers[i])

        numerator[:n_neighbors] += bio_hits
        denominator[:n_neighbors] += center_hits

        # Neighbourhood sizes beyond the number of available neighbours contain
        # every other sample, so they all receive the final count.
        numerator[n_neighbors:] += bio_hits[-1]
        denominator[n_neighbors:] += center_hits[-1]

    # Calculate the robustness index
    if np.sum(denominator) != 0:
        R_k = np.mean(numerator) / np.mean(denominator)
    else:
        R_k = 0

    return R_k

In [5]:
import pickle

# Location of the pickled BCSS table (sample metadata plus one embedding
# column per model / stain-normalisation variant).
bcss_pkl_path = '/home/hai/Desktop/trials/_Efficient_Annotation/UNI_github/notebooks/bcss_all_stainnorm_emb.pkl'

# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(bcss_pkl_path, 'rb') as pkl_file:
    bcss_df = pickle.load(pkl_file)

# Preview the first rows.
bcss_df.head()

Unnamed: 0,FileName,Major_labs,Tissue,SiteID,PatientID,Hoptimus0,Hoptimus0_macenko,Hoptimus0_reinhard,UNI,UNI_macenko,...,Virchow_reinhard,ProvGigaPath,ProvGigaPath_macenko,ProvGigaPath_reinhard,phikon2_reinhard,phikon2,phikon2_macenko,phikon,phikon_macenko,phikon_reinhard
0,patch_1,3,Inflammation,A2,A1G6,"[9.386892318725586, 2.6558210849761963, 5.2113...","[0.2757183015346527, 7.881718158721924, 0.0712...","[-0.021643538028001785, 9.49809455871582, 8.82...","[9.217123985290527, 4.81818962097168, 5.033939...","[0.22901754081249237, 8.738898277282715, 1.424...",...,"[9.305919647216797, 10.870243072509766, 2.9676...","[9.4022798538208, 4.9569926261901855, 4.942555...","[0.2955218255519867, 9.455558776855469, 9.4714...","[0.08865882456302643, 2.310716390609741, 0.055...","[9.733214378356934, 3.191523790359497, 5.13570...","[8.198899269104004, 3.805241584777832, 4.99574...","[9.899100303649902, 0.03361562639474869, 0.976...","[7.694921970367432, 4.723465919494629, 5.06652...","[10.116778373718262, 7.506051540374756, 8.9885...","[9.574226379394531, 2.9870800971984863, 5.0316..."
1,patch_10,3,Inflammation,A2,A1G6,"[9.387246131896973, 2.6558589935302734, 5.2112...","[0.27751803398132324, 7.901925086975098, 0.071...","[-0.021869109943509102, 9.497689247131348, 8.8...","[9.215758323669434, 4.817876815795898, 5.03385...","[0.2274865061044693, 8.7388334274292, 1.424104...",...,"[9.305421829223633, 10.863147735595703, 2.9705...","[9.401104927062988, 4.956780433654785, 4.94249...","[0.29420167207717896, 9.457695007324219, 9.471...","[0.09160616248846054, 2.3105783462524414, 0.05...","[9.733199119567871, 3.1916344165802, 5.1351733...","[8.197966575622559, 3.8045284748077393, 4.9961...","[9.898354530334473, 0.0345156267285347, 0.9793...","[7.695714950561523, 4.7228875160217285, 5.0671...","[10.119156837463379, 7.505539894104004, 8.9828...","[9.57403564453125, 2.9870951175689697, 5.03193..."
2,patch_100,2,Stroma,A2,A0D0,"[9.602423667907715, 2.661569833755493, 5.20872...","[0.4720607101917267, 4.370136737823486, 0.0754...","[-0.38453924655914307, 9.493077278137207, 8.78...","[9.40522289276123, 4.843436241149902, 5.053972...","[1.6307872533798218, 2.1803367137908936, 0.865...",...,"[8.729555130004883, 11.917633056640625, 2.2772...","[9.341516494750977, 4.909373760223389, 4.97259...","[1.170862078666687, 6.4747490882873535, 9.8488...","[-0.5134817361831665, 2.2605855464935303, 0.04...","[9.715886116027832, 3.3313395977020264, 4.9930...","[5.693242073059082, 7.669175624847412, 4.98756...","[9.87010383605957, 0.1895584762096405, 8.38606...","[4.394604206085205, 7.5308613777160645, 3.5564...","[10.074995994567871, 7.703090190887451, 11.388...","[9.501676559448242, 3.1050126552581787, 4.9997..."
3,patch_1000,1,Tumor,AO,A0J2,"[9.654601097106934, 2.6842236518859863, 5.1010...","[0.582181453704834, 9.384984016418457, 0.09836...","[0.16584688425064087, 9.381376266479492, 8.756...","[9.662715911865234, 4.871389865875244, 5.10533...","[0.30600500106811523, 9.311566352844238, 1.508...",...,"[9.075523376464844, 5.24211311340332, 2.713519...","[9.569384574890137, 4.982851505279541, 4.94853...","[0.01773758791387081, 10.196998596191406, 9.38...","[0.35758697986602783, 2.2831103801727295, 0.04...","[9.792506217956543, 3.2389402389526367, 4.8827...","[4.088003635406494, 6.0379533767700195, 8.7737...","[9.88654613494873, 0.023365242406725883, 1.188...","[10.001880645751953, 5.141750335693359, 5.0529...","[9.989755630493164, 7.569756984710693, 9.13021...","[9.549050331115723, 3.1495110988616943, 5.1261..."
4,patch_10000,1,Tumor,BH,A0BL,"[9.662805557250977, 2.615037441253662, 5.09154...","[0.4039507508277893, 3.662231206893921, 0.0921...","[0.3497996926307678, 9.493809700012207, 8.7036...","[9.608992576599121, 4.822984218597412, 5.01442...","[1.7786369323730469, 1.3074615001678467, 0.943...",...,"[9.321929931640625, 1.196954607963562, 6.85257...","[9.507142066955566, 4.903769493103027, 4.98371...","[0.8053656816482544, 6.800408363342285, 9.7126...","[0.062062233686447144, 2.268545150756836, -0.0...","[9.763350486755371, 3.1787168979644775, 4.9173...","[9.635196685791016, 5.023486137390137, 5.00341...","[9.662620544433594, 0.19142776727676392, 7.633...","[10.011054992675781, 5.168505668640137, 5.0156...","[10.08000659942627, 7.560948371887207, 10.9823...","[9.59817886352539, 3.1327221393585205, 4.96545..."


In [6]:
# Embedding columns are all columns that are not sample metadata. Selecting by
# name (rather than slicing off the first five columns) keeps the selection
# correct if the column order changes; Index.drop raises KeyError if one of the
# metadata columns is missing. To study a single model family, filter the
# result by substring (e.g. keep names containing 'Virchow').
metadata_columns = ['FileName', 'Major_labs', 'Tissue', 'SiteID', 'PatientID']
all_med = bcss_df.columns.drop(metadata_columns).values
all_med

array(['Hoptimus0', 'Hoptimus0_macenko', 'Hoptimus0_reinhard', 'UNI',
       'UNI_macenko', 'UNI_reinhard', 'Virchow2', 'Virchow2_macenko',
       'Virchow2_reinhard', 'UNI2h', 'UNI2h_macenko', 'UNI2h_reinhard',
       'Virchow', 'Virchow_macenko', 'Virchow_reinhard', 'ProvGigaPath',
       'ProvGigaPath_macenko', 'ProvGigaPath_reinhard',
       'phikon2_reinhard', 'phikon2', 'phikon2_macenko', 'phikon',
       'phikon_macenko', 'phikon_reinhard'], dtype=object)

In [7]:
# Neighbourhood size and the two label vectors shared by every embedding variant.
k = 50
biological_classes = bcss_df['Major_labs'].to_numpy()
medical_centers = bcss_df['SiteID'].to_numpy()

# One robustness score per embedding column, stored in the order of `all_med`.
all_Rk = np.zeros(len(all_med))
for i, med in enumerate(all_med):
    # Each cell of the column holds one embedding vector; stack them into (n_samples, n_features).
    embeddings = np.array(bcss_df[med].tolist())
    R_k = robustness_index(embeddings, biological_classes, medical_centers, k)
    all_Rk[i] = R_k

In [8]:
# Print the raw array of R_k scores; entries follow the order of `all_med`.
print(f"Robustness Index R_k: {all_Rk}")
# Display how many scores were computed.
len(all_Rk)

Robustness Index R_k: [1.07677801 1.19396234 1.14270723 1.0786237  1.18755799 1.19313763
 1.20782095 1.32781398 1.34752304 1.11693221 1.16563545 1.16218609
 1.124292   1.34901705 1.2868736  1.06977897 1.20486982 1.18105545
 1.17473458 1.24608761 1.14051629 1.21200848 1.1386981  1.17886861]


24

In [9]:
import pandas as pd
# Pair each embedding column name with its score and order the table
# alphabetically by name so variants of the same model sit together.
Rk_all = (
    pd.DataFrame({'Med': all_med, 'Rk': all_Rk})
    .sort_values(by='Med')
)
Rk_all

Unnamed: 0,Med,Rk
0,Hoptimus0,1.076778
1,Hoptimus0_macenko,1.193962
2,Hoptimus0_reinhard,1.142707
15,ProvGigaPath,1.069779
16,ProvGigaPath_macenko,1.20487
17,ProvGigaPath_reinhard,1.181055
3,UNI,1.078624
9,UNI2h,1.116932
10,UNI2h_macenko,1.165635
11,UNI2h_reinhard,1.162186
