In [1]:
import numpy as np
import json
import pandas as pd

In [160]:

def generate_dynamic_few_shot_id(exclude_id, k_shot=3, embeddings_path="patients_embeddings.json"):
    """Return the IDs of the ``k_shot`` patients most similar to ``exclude_id``.

    Similarity is the cosine similarity between precomputed patient embeddings.
    The query patient itself is never included in the result.

    Parameters
    ----------
    exclude_id : str
        Patient ID whose embedding is used as the query. Must be a key of the
        embeddings file.
    k_shot : int, default 3
        Maximum number of patient IDs to return.
    embeddings_path : str or path-like, default "patients_embeddings.json"
        JSON object mapping patient ID -> embedding (list of numbers).

    Returns
    -------
    list of str
        Up to ``k_shot`` patient IDs, most similar first. Ties keep the order
        in which the IDs appear in the embeddings file.

    Raises
    ------
    KeyError
        If ``exclude_id`` is not a key of the embeddings file.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(embeddings_path, "r", encoding="utf-8") as file:
        patient_embedding_map = json.load(file)

    query = np.asarray(patient_embedding_map[exclude_id], dtype=float)

    # dicts preserve insertion order, so patient_ids[i] lines up with row i.
    patient_ids = list(patient_embedding_map.keys())
    embeddings = np.asarray(list(patient_embedding_map.values()), dtype=float)

    # A raw dot product equals cosine similarity only for unit-length vectors,
    # so normalise explicitly. Zero-norm vectors are given norm 1 to avoid a
    # division by zero; their similarity is then 0.
    row_norms = np.linalg.norm(embeddings, axis=1)
    row_norms[row_norms == 0] = 1.0
    query_norm = np.linalg.norm(query)
    if query_norm == 0:
        query_norm = 1.0
    cosine_sim = (embeddings @ query) / (row_norms * query_norm)

    # Stable sort on negated scores: descending similarity, ties keep file order.
    ranked = np.argsort(-cosine_sim, kind="stable")
    similar_ids = [patient_ids[i] for i in ranked if patient_ids[i] != exclude_id]

    return similar_ids[:k_shot]


In [171]:
# Three nearest patients (by embedding similarity) to the query patient, which is itself excluded.
generate_dynamic_few_shot_id(
    exclude_id="<https://pubmed.ncbi.nlm.nih.gov/31069201?Patient>",
    k_shot=3,
)

['<https://pubmed.ncbi.nlm.nih.gov/27939403?C_II_2>',
 '<https://pubmed.ncbi.nlm.nih.gov/27939403?C_II_1>',
 '<https://pubmed.ncbi.nlm.nih.gov/26981933?Family_F_individual_F10>']

In [190]:
def get_name(ids, id2name_map=None, mapping_path="mapping/id2name_map(1).json"):
    """Translate a patient's ontology IDs into a disease name and phenotype names.

    Parameters
    ----------
    ids : iterable
        The patient's IDs. Items may be strings or lists of strings; exactly
        one level of list nesting is flattened.
    id2name_map : dict, optional
        Preloaded mapping from ID to human-readable name. When ``None`` it is
        read from ``mapping_path``. Passing it in lets callers load the mapping
        once instead of re-reading the JSON file for every patient.
    mapping_path : str or path-like, default "mapping/id2name_map(1).json"
        JSON file to load the mapping from when ``id2name_map`` is ``None``.

    Returns
    -------
    dict
        ``{'disease': name_or_None, 'phenotype': [names...]}``.
        - IDs containing "MONDO" set ``disease``; if several are present, the
          last one wins.
        - IDs containing "HP" or "PATO" are appended to ``phenotype`` in input
          order.
        - All other IDs are ignored.
        NOTE(review): IDs missing from the mapping yield ``None`` (and ``None``
        is appended to ``phenotype``); consider filtering if these names feed a
        prompt.
    """
    if id2name_map is None:
        with open(mapping_path, "r", encoding="utf-8") as file:
            id2name_map = json.load(file)

    names = {'disease': None, 'phenotype': []}

    # Flatten one level: list entries are expanded, scalar entries kept as-is.
    flat_ids = [item for entry in ids for item in (entry if isinstance(entry, list) else [entry])]

    for term_id in flat_ids:
        # "MONDO" is checked first, so an ID matching both tests counts as disease.
        if "MONDO" in term_id:
            names['disease'] = id2name_map.get(term_id)
        elif "HP" in term_id or "PATO" in term_id:
            names['phenotype'].append(id2name_map.get(term_id))

    return names


In [191]:
def create_examples(exclude_id, k_shot=3, patients_path="data/patients.json"):
    """Build few-shot examples from the patients most similar to ``exclude_id``.

    Parameters
    ----------
    exclude_id : str
        Query patient ID. It is passed to ``generate_dynamic_few_shot_id`` and
        is not itself part of the examples.
    k_shot : int, default 3
        Number of example patients to retrieve.
    patients_path : str or path-like, default "data/patients.json"
        JSON object keyed by patient ID. Each value is passed to ``get_name``;
        presumably it is that patient's list of ontology IDs (confirm against
        the data).

    Returns
    -------
    dict
        Patient ID -> ``get_name`` result (``{'disease': ..., 'phenotype': [...]}``),
        in the order returned by ``generate_dynamic_few_shot_id``.

    Raises
    ------
    KeyError
        If a retrieved patient ID is missing from the patients file.
    """
    similar_ids = generate_dynamic_few_shot_id(exclude_id, k_shot)

    with open(patients_path, "r", encoding="utf-8") as file:
        patients = json.load(file)

    # Select every record first, so a missing ID fails before any name lookup.
    # Dict comprehensions keep the retrieval order.
    selected = {patient_id: patients[patient_id] for patient_id in similar_ids}

    return {patient_id: get_name(record) for patient_id, record in selected.items()}


In [188]:
# Few-shot examples (disease + phenotype names) for the 3 patients most similar to this query patient.
create_examples(
    exclude_id="<https://pubmed.ncbi.nlm.nih.gov/31069201?Patient>",
    k_shot=3,
)

{'<https://pubmed.ncbi.nlm.nih.gov/27939403?C_II_2>': {'disease': 'pancytopenia due to IKZF1 mutations',
  'phenotype': ['B lymphocytopenia',
   'Decreased circulating IgG concentration']},
 '<https://pubmed.ncbi.nlm.nih.gov/27939403?C_II_1>': {'disease': 'pancytopenia due to IKZF1 mutations',
  'phenotype': ['B lymphocytopenia',
   'Decreased circulating IgG concentration']},
 '<https://pubmed.ncbi.nlm.nih.gov/26981933?Family_F_individual_F10>': {'disease': 'pancytopenia due to IKZF1 mutations',
  'phenotype': ['Decreased circulating IgG concentration',
   'Decreased circulating total IgM']}}

In [192]:
# Same retrieval for a second query patient.
create_examples(
    exclude_id="<https://pubmed.ncbi.nlm.nih.gov/38433265?index_case_patient_III_1>",
    k_shot=3,
)

{'<https://pubmed.ncbi.nlm.nih.gov/10077612?Family_B>': {'disease': 'Holt-Oram syndrome',
  'phenotype': ['Abnormal carpal morphology']},
 '<https://pubmed.ncbi.nlm.nih.gov/10077612?Family_B_IV_10>': {'disease': 'Holt-Oram syndrome',
  'phenotype': ['Abnormal carpal morphology']},
 '<https://pubmed.ncbi.nlm.nih.gov/32154675?Family_6_Patient_14>': {'disease': 'aneurysm-osteoarthritis syndrome',
  'phenotype': ['Abnormal sternum morphology',
   'Joint hypermobility',
   'Aortic aneurysm',
   'Soft skin']}}