In [6]:
from datasets import load_dataset
import numpy as np
import faiss
import random


# Function to perform similarity search
def search(query_embedding, index, k=10):
    """Look up the ``k`` nearest neighbours of a single query vector.

    The query is copied into a new float32 NumPy array and reshaped to a
    one-row batch of shape ``(1, -1)``, which is the 2-D layout
    ``index.search`` is called with.

    Args:
        query_embedding: Array-like query vector; all of its values are
            flattened into one row.
        index: Object exposing ``search(batch, k)`` (e.g. a FAISS index).
        k: Number of neighbours to request.

    Returns:
        The ``(distances, indices)`` pair returned by ``index.search``.
    """
    query_batch = np.array(query_embedding, dtype=np.float32)
    query_batch = query_batch.reshape(1, -1)
    neighbour_distances, neighbour_ids = index.search(query_batch, k)
    return neighbour_distances, neighbour_ids


# Load the clinical dataset
clinical_dataset = load_dataset("Lab-Rasool/TCGA", "clinical", split="gatortron")
# Each record stores its embedding as raw float32 bytes plus a separate shape;
# decode it, validate it against that shape, and flatten to a single row.
clinical_embeddings = [
    np.frombuffer(record.get("embedding"), dtype=np.float32)
    .reshape(record.get("embedding_shape"))
    .flatten()
    for record in clinical_dataset
]
clinical_embeddings_array = np.vstack(clinical_embeddings)
print(f"Clinical embeddings array shape: {clinical_embeddings_array.shape}")

# Flat (exhaustive) L2 index over the stacked clinical vectors.
dimension = clinical_embeddings_array.shape[1]
index_clinical = faiss.IndexFlatL2(dimension)
index_clinical.add(clinical_embeddings_array)

# Load the WSI dataset
wsi_dataset = load_dataset("Lab-Rasool/TCGA", "wsi", split="uni")
# Flatten every record's patch embeddings into one list of row vectors, and keep
# two parallel lists so each flat row can be traced back to where it came from:
#   wsi_patient_ids[n] -> PatientID of the record that produced row n
#   wsi_indices[n]     -> (record index, patch index within that record)
# Because records are visited in order and patches in order, wsi_indices is
# sorted ascending.
# Assumes each record's embedding_shape is (num_patches, dim) -- TODO confirm.
wsi_all_embeddings = []
wsi_patient_ids = []
wsi_indices = []
for i, item in enumerate(wsi_dataset):
    embedding = np.frombuffer(item.get("embedding"), dtype=np.float32).reshape(
        item.get("embedding_shape")
    )
    num_patches = embedding.shape[0]
    wsi_all_embeddings.extend(embedding)  # one list entry per patch row
    wsi_patient_ids.extend([item["PatientID"]] * num_patches)
    wsi_indices.extend([(i, j) for j in range(num_patches)])
wsi_all_embeddings_array = np.vstack(wsi_all_embeddings)
print(f"WSI embeddings array shape: {wsi_all_embeddings_array.shape}")
# Build the FAISS index for the WSI embeddings. Size it from the WSI array
# itself instead of reusing the clinical `dimension`: the two datasets are
# decoded independently and only happen to share a width in the captured run.
wsi_dimension = wsi_all_embeddings_array.shape[1]
index_wsi = faiss.IndexFlatL2(wsi_dimension)  # flat (exhaustive) L2 index
index_wsi.add(wsi_all_embeddings_array)

# -----------------------------------------------------------------------------------------------

# Randomly pick one patch embedding from a random patient in the WSI dataset
# Randomly pick one WSI record, then one patch within it, and use that patch
# embedding as the query. Note: random_patient_index is a dataset row index
# (it indexes wsi_dataset), not a PatientID.
# NOTE(review): randint(0, -1) raises ValueError if the chosen record has zero
# patches -- confirm the dataset never contains empty records.
random_patient_index = random.randint(0, len(wsi_dataset) - 1)
random_patch_index = random.randint(
    0, wsi_dataset[random_patient_index]["embedding_shape"][0] - 1
)
# Resolve the (record, patch) pair to its row in the flattened arrays once;
# list.index is a linear scan over every patch, so don't repeat it.
query_flat_index = wsi_indices.index((random_patient_index, random_patch_index))
query_embedding = wsi_all_embeddings_array[query_flat_index]

# Perform similarity search on WSI embeddings. The query patch is itself in the
# index, so the first hit is normally the query at distance 0.
distances, indices = search(query_embedding, index_wsi, k=10)

# PatientID of the original query embedding
original_patient_id = wsi_patient_ids[query_flat_index]

# Build a PatientID -> project_id map once, reading only the two needed columns,
# instead of re-iterating the whole clinical dataset for every lookup.
# setdefault keeps the FIRST matching row, matching a "find first, then break"
# search.
project_by_patient = {}
for case_id, project in zip(
    clinical_dataset["case_submitter_id"], clinical_dataset["project_id"]
):
    project_by_patient.setdefault(case_id, project)

original_patient_project = project_by_patient.get(original_patient_id)
if original_patient_project is None:
    print(f"Project ID not found for PatientID: {original_patient_id}")

print(
    f"Original PatientID: {original_patient_id}, Original Patient Project: {original_patient_project}, Random Patient Index: {random_patient_index}, Random Patch Index: {random_patch_index}"
)
# FAISS fills missing results with -1 when fewer than k vectors are available;
# skip those rather than letting -1 silently select the last patient.
similar_patient_ids = [wsi_patient_ids[idx] for idx in indices[0] if idx >= 0]

# Exactly one entry per similar patient (None when the patient has no clinical
# record) so the two lists stay aligned in the zip below.
project_ids = [project_by_patient.get(patient_id) for patient_id in similar_patient_ids]

print("Similar PatientIDs and their Project IDs:")
for patient_id, project_id in zip(similar_patient_ids, project_ids):
    print(f"PatientID: {patient_id}, ProjectID: {project_id}")
# IndexFlatL2 reports squared L2 distances.
print(f"Distances to similar items: {distances}")

Clinical embeddings array shape: (11428, 1024)
WSI embeddings array shape: (456387, 1024)
Original PatientID: TCGA-KN-8430, Original Patient Project: TCGA-KICH, Random Patient Index: 10492, Random Patch Index: 6
Similar PatientIDs and their Project IDs:
PatientID: TCGA-KN-8430, ProjectID: TCGA-KICH
PatientID: TCGA-RY-A83Z, ProjectID: TCGA-LGG
PatientID: TCGA-E2-A9RU, ProjectID: TCGA-BRCA
PatientID: TCGA-LQ-A4E4, ProjectID: TCGA-BRCA
PatientID: TCGA-ZJ-AAX4, ProjectID: TCGA-CESC
PatientID: TCGA-C8-A131, ProjectID: TCGA-BRCA
PatientID: TCGA-B0-4707, ProjectID: TCGA-KIRC
PatientID: TCGA-CV-A45Y, ProjectID: TCGA-HNSC
PatientID: TCGA-12-3649, ProjectID: TCGA-GBM
PatientID: TCGA-CV-5443, ProjectID: TCGA-HNSC
Distances to similar items: [[0.         0.03129907 0.03277805 0.03298555 0.03298659 0.03384599
  0.03439792 0.03646088 0.03675169 0.03732942]]
