In [19]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
import boto3
import os
from botocore.exceptions import ClientError
import json

In [21]:
print(os.getcwd())

/Users/spandan/Projects/dxhub/immunization-indicator-classifier


In [None]:
REPO_NAME = "immunization-indicator-classifier"
while os.getcwd().split("/")[-1] != REPO_NAME:
    os.chdir("..")

In [None]:
os.environ["AWS_PROFILE"] = "profile_name"
os.environ["AWS_DEFAULT_REGION"] = "us-west-2"

In [None]:
sts = boto3.client("sts")
BEDROCK_ROLE_ARN = "role_arn"
sts.assume_role(
    RoleArn=BEDROCK_ROLE_ARN,
    RoleSessionName="session-name",
)

In [25]:
bedrock = boto3.client("bedrock-runtime")

In [26]:
TABLE = "diseases-attributes"
dynamodb = boto3.client("dynamodb")


def get_csdi_objects(codes: list[str]) -> list[dict]:
    try:
        response = dynamodb.batch_get_item(
            RequestItems={
                TABLE: {
                    "Keys": [{"csdi_code": {"S": code}} for code in codes],
                },
            },
        )
        return response["Responses"][TABLE]
    except ClientError as e:
        print(e)
        print("Error getting items from table")
        raise e

In [27]:
def semantic_contextualize(code_obj: dict) -> str:
    """
    Takes a JSON object representing a CSDi code and constructs a paragraph around it.
    """
    name = code_obj["disease_name"]["S"]
    medications = [med["S"] for med in code_obj["medications"]["M"]["medications"]["L"]]
    observations = [
        obs["S"]
        for obs in code_obj["observations / symptoms"]["M"]["observations / symptoms"][
            "L"
        ]
    ]
    disorders = [dis["S"] for dis in code_obj["disorders"]["M"]["disorders"]["L"]]

    paragraph = f"""The patient has the condition {name}.
    The patient has the following observations and symptoms: {", ".join(observations)}.
    They have the following disorders: {", ".join(disorders)}.
    They are taking the medications {", ".join(medications)}.
    """
    return paragraph

In [28]:
test_codes = [str(i) for i in range(114, 120)]
test_objects = get_csdi_objects(test_codes)

In [29]:
test_objects[0].keys()

dict_keys(['csdi_code', 'disorders', 'medications', 'disease_name', 'observations / symptoms'])

In [30]:
paragraph = semantic_contextualize(test_objects[0])

In [31]:
def embed_doc_attn(doc: str) -> list[float]:
    """
    Takes a document and embeds it using a Bedrock embedding model.
    """
    EMBED_MODEL = "amazon.titan-embed-text-v2:0"
    try:
        native_request = {
            "inputText": doc,
            "dimensions": 1024,
            "normalize": True,
            "embeddingTypes": ["float"],
        }
        request = json.dumps(native_request)
        response = bedrock.invoke_model(
            modelId=EMBED_MODEL,
            body=request,
            accept="application/json",
            contentType="application/json",
        )
        vec = json.loads(response["body"].read())["embeddingsByType"]["float"]
        return vec
    except ClientError as e:
        print(e)
        print("Error invoking bedrock endpoint")
        raise e

In [32]:
from collections import Counter
import re
import nltk
from nltk.corpus import stopwords
import string
from sklearn.feature_extraction.text import TfidfVectorizer

# Download stopwords if not already downloaded
nltk.download("stopwords")


def clean_text(text: str) -> str:
    """
    Cleans the input text by removing punctuation and stop words.
    """
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    stop_words = set(stopwords.words("english"))
    text = " ".join(word for word in text.split() if word not in stop_words)
    return text


def derive_vocabulary(csdi_objects: list[dict]) -> set[str]:
    """
    Derives a vocabulary from a list of CSDi JSON objects.
    """
    vocabulary = set()
    for obj in csdi_objects:
        fields = [
            obj["disease_name"]["S"],
            *[med["S"] for med in obj["medications"]["M"]["medications"]["L"]],
            *[
                obs["S"]
                for obs in obj["observations / symptoms"]["M"][
                    "observations / symptoms"
                ]["L"]
            ],
            *[dis["S"] for dis in obj["disorders"]["M"]["disorders"]["L"]],
        ]
        for field in fields:
            cleaned_text = clean_text(field)
            words = re.findall(r"\b\w+\b", cleaned_text)
            vocabulary.update(words)
    return vocabulary


def tfidf_embedding(
    vocabulary: set[str], csdi_objects: list[dict]
) -> list[list[float]]:
    """
    Embeds a list of CSDi JSON objects using TF-IDF with respect to the given vocabulary.
    """
    vectorizer = TfidfVectorizer(vocabulary=vocabulary)
    documents = [semantic_contextualize(obj) for obj in csdi_objects]
    tfidf_matrix = vectorizer.fit_transform(documents)
    return tfidf_matrix.toarray()


vocabulary = derive_vocabulary(test_objects)

tfidf_embeddings = tfidf_embedding(vocabulary, test_objects)
vocabulary = derive_vocabulary(test_objects)
print("Vocabulary:", vocabulary)

Vocabulary: {'severe', 'hives', 'protamine', 'context', 'kidney', 'cramps', 'finding', 'meningococcal', 'arbs', 'dizziness', 'phosphate', 'decreased', 'previous', 'fatigue', 'difficulty', 'allergy', 'cholera', 'b', 'wheezing', 'procedure', 'muscle', 'erythropoietin', 'dependence', 'dose', 'nausea', 'vomiting', 'hypertension', 'therapy', 'epinephrine', 'disorder', 'throat', 'prednisone', 'pain', 'warfarin', 'adverse', 'medications', 'bone', 'lips', 'diphtheria', 'group', 'abdominal', 'diabetes', 'swelling', 'mellitus', 'anticoagulants', 'substance', 'hemodialysis', 'diphenhydramine', 'angioedema', 'face', 'disease', 'insulin', 'eg', 'respiratory', 'binders', 'end', 'disorders', 'allergic', 'continuous', 'cognitive', 'heparin', 'antihypertensive', 'reaction', 'edema', 'tachycardia', 'anaphylaxis', 'ace', 'corticosteroids', 'toxoid', 'stage', 'itching', 'diarrhea', 'distress', 'antihistamines', 'chronic', 'urticaria', 'mineral', 'cardiovascular', 'vaccine', 'caused', 'inhibitors', 'rash',

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/spandan/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [36]:
def embed_doc_stat(objs: list[str], vocabulary: set[list]) -> list[float]:
    """
    Uses a TF-IDF model to embed a document in the TF-IDF space derived from the CSDi codes.
    """

    return tfidf_embedding(vocabulary, objs)

In [37]:
import numpy as np


def embed_csdi_codes(
    codes: list[str], vocabulary: set[str], return_objs: bool = False
) -> list[dict]:
    """
    Embeds a list of CSDi codes using the embeding space of a Bedrock model and
    a pure (paradigmatic semantic) TF-IDF model.
    """
    code_objects = get_csdi_objects(codes)
    paragraphs = [semantic_contextualize(obj) for obj in code_objects]
    attn_embeddings = [embed_doc_attn(paragraph) for paragraph in paragraphs]
    stat_embeddings = embed_doc_stat(code_objects, vocabulary)
    if return_objs:
        return np.array(attn_embeddings), np.array(stat_embeddings), code_objects
    return np.array(attn_embeddings), np.array(stat_embeddings)


embed_csdi_codes.col_name = "vect"

In [40]:
attn_embed, stat_embed = embed_csdi_codes(test_codes, vocabulary=vocabulary)
print(attn_embed)
print(stat_embed)

[[-0.00403074  0.02435845 -0.03966947 ... -0.03850954  0.06495586
   0.00991737]
 [-0.04367695 -0.0319647  -0.04465298 ... -0.00982121  0.07564165
   0.0046666 ]
 [ 0.00022379  0.01963892 -0.0518641  ...  0.00781217  0.04340092
   0.01779438]
 [-0.00460978  0.01658431 -0.03316863 ... -0.00796484  0.05106223
  -0.03185934]
 [ 0.02005979  0.02082294 -0.10247938 ...  0.03008969  0.03532268
  -0.0185335 ]
 [-0.01146274 -0.00603447 -0.10228288 ...  0.0301999   0.0414422
   0.02083132]]
[[0.         0.         0.2199518  0.137421   0.         0.137421
  0.         0.15912944 0.         0.137421   0.         0.
  0.         0.         0.         0.         0.18569843 0.
  0.         0.         0.2199518  0.53645846 0.         0.
  0.         0.15912944 0.         0.137421   0.         0.
  0.         0.         0.         0.18569843 0.         0.
  0.         0.15912944 0.11906678 0.         0.         0.
  0.2199518  0.         0.         0.         0.137421   0.
  0.         0.         0.  

In [45]:
import pandas as pd

attn_embed_df = pd.DataFrame(attn_embed, index=test_codes)
stat_embed_df = pd.DataFrame(stat_embed, index=test_codes)

In [46]:
attn_embed_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
114,-0.004031,0.024358,-0.039669,0.051733,0.051269,-0.012063,0.046397,-0.028534,-0.019719,0.011831,...,-0.011425,0.029926,0.039669,-0.020647,-0.012179,0.011541,0.006699,-0.03851,0.064956,0.009917
115,-0.043677,-0.031965,-0.044653,-0.012444,0.011102,-0.011224,0.000812,-0.000183,0.003218,-0.08589,...,-0.040505,0.012871,-0.002638,-0.047337,0.042945,0.040505,-0.026353,-0.009821,0.075642,0.004667
116,0.000224,0.019639,-0.051864,0.006212,0.024413,-0.023219,0.040797,-0.059025,-0.034938,-0.044486,...,-0.030815,0.001709,0.026149,-0.032985,-0.050128,0.003106,-0.024088,0.007812,0.043401,0.017794
117,-0.00461,0.016584,-0.033169,0.053026,0.056299,0.004937,0.02564,-0.063719,-0.012384,-0.015602,...,-0.01162,0.04517,0.047571,-0.026295,0.012111,-0.008565,-0.030768,-0.007965,0.051062,-0.031859
118,0.02006,0.020823,-0.102479,0.037721,0.041428,-0.046225,0.050149,-0.040556,0.010248,-0.071518,...,-0.003707,0.002548,0.009648,-0.015263,-0.010411,-0.011829,0.030308,0.03009,0.035323,-0.018534
119,-0.011463,-0.006034,-0.102283,0.022705,0.086411,0.006448,0.012675,-0.017745,0.02513,-0.032184,...,-0.015761,0.023917,0.011793,-0.029318,-0.000599,0.009809,0.021713,0.0302,0.041442,0.020831


In [48]:
stat_embed_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,87,88,89,90,91,92,93,94,95,96
114,0.0,0.0,0.219952,0.137421,0.0,0.137421,0.0,0.159129,0.0,0.137421,...,0.0,0.0,0.0,0.0,0.0,0.0,0.219952,0.0,0.0,0.219952
115,0.0,0.118666,0.0,0.0,0.0,0.0,0.118666,0.0,0.118666,0.0,...,0.0,0.118666,0.0,0.0,0.0,0.0,0.0,0.082154,0.118666,0.0
116,0.0,0.0,0.302053,0.094358,0.0,0.094358,0.0,0.109264,0.0,0.094358,...,0.0,0.0,0.0,0.0,0.0,0.151027,0.302053,0.0,0.0,0.0
117,0.0,0.0,0.0,0.126896,0.0,0.126896,0.0,0.0,0.0,0.126896,...,0.0,0.0,0.247686,0.247686,0.203106,0.0,0.0,0.171476,0.0,0.0
118,0.0,0.0,0.0,0.188456,0.0,0.188456,0.0,0.218226,0.0,0.188456,...,0.367843,0.0,0.0,0.0,0.301637,0.0,0.0,0.0,0.0,0.301637
119,0.138728,0.0,0.0,0.071074,0.277457,0.071074,0.0,0.082302,0.0,0.071074,...,0.0,0.0,0.0,0.0,0.0,0.113759,0.0,0.096043,0.0,0.0


In [None]:
code_objs

Cosine similarity seems like a better metric to use as compared to simple Euclidean distance:

In [None]:
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics.pairwise import cosine_similarity


def dist_metrics_dfs(embedding_df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Computes the pairwise cosine similarity and Euclidean distance matrices for a given embedding DataFrame.
    Returns the results as DataFrames indexed by the CSDi code.
    """
    cosine_sim = pd.DataFrame(
        cosine_similarity(embedding_df),
        index=embedding_df.index,
        columns=embedding_df.index,
    )
    euclidean_dist = pd.DataFrame(
        squareform(pdist(embedding_df, metric="euclidean")),
        index=embedding_df.index,
        columns=embedding_df.index,
    )
    return cosine_sim, euclidean_dist

In [59]:
attn_cos_sim, attn_euc_dist = dist_metrics_dfs(attn_embed_df)
stat_cos_sim, stat_euc_dist = dist_metrics_dfs(stat_embed_df)

distances = {
    "attn": {
        "cosine_similarity": attn_cos_sim,
        "euclidean_distance": attn_euc_dist,
    },
    "stat": {
        "cosine_similarity": stat_cos_sim,
        "euclidean_distance": stat_euc_dist,
    },
}

In [60]:
distances

{'attn': {'cosine_similarity':           114       115       116       117       118       119
  114  1.000000  0.298031  0.600751  0.568322  0.540199  0.510626
  115  0.298031  1.000000  0.288192  0.283026  0.243350  0.386240
  116  0.600751  0.288192  1.000000  0.509444  0.481716  0.520846
  117  0.568322  0.283026  0.509444  1.000000  0.588259  0.478278
  118  0.540199  0.243350  0.481716  0.588259  1.000000  0.496233
  119  0.510626  0.386240  0.520846  0.478278  0.496233  1.000000,
  'euclidean_distance':           114       115       116       117       118       119
  114  0.000000  1.184879  0.893587  0.929170  0.958959  0.989317
  115  1.184879  0.000000  1.193154  1.197475  1.230163  1.107935
  116  0.893587  1.193154  0.000000  0.990511  1.018120  0.978933
  117  0.929170  1.197475  0.990511  0.000000  0.907459  1.021491
  118  0.958959  1.230163  1.018120  0.907459  0.000000  1.003760
  119  0.989317  1.107935  0.978933  1.021491  1.003760  0.000000},
 'stat': {'cosine_simi

In [64]:
def distances_var(adj: pd.DataFrame) -> float:
    """
    Computes the variance of document adjacency matrices.
    """
    # Get just the upper diagonal
    flattened = adj.where(np.triu(np.ones(adj.shape), k=1).astype(bool)).stack()
    return np.var(flattened)


variances = {
    "attn": {
        "cosine_similarity": distances_var(attn_cos_sim),
        "euclidean_distance": distances_var(attn_euc_dist),
    },
    "stat": {
        "cosine_similarity": distances_var(stat_cos_sim),
        "euclidean_distance": distances_var(stat_euc_dist),
    },
}

variances

{'attn': {'cosine_similarity': np.float64(0.013596601043872062),
  'euclidean_distance': np.float64(0.01191713020044291)},
 'stat': {'cosine_similarity': np.float64(0.024104877546204127),
  'euclidean_distance': np.float64(0.01684039348925827)}}

In [66]:
from form import Form

In [67]:
form = Form("shortened_example.txt")
form.print_form()

Adelina682 Julissa825 Mante251
Race:                White
Ethnicity:           Non-Hispanic
Gender:              F
Age:                 DECEASED
Birth Date:          1931-11-26
Marital Status:      M
--------------------------------------------------------------------------------
ALLERGIES:
No Known Allergies
--------------------------------------------------------------------------------
MEDICATIONS:
  2013-06-06[STOPPED] : sodium fluoride 0.0272 MG/MG Oral Gel for Patient referral for dental care (procedure)
  2012-05-31[STOPPED] : sodium fluoride 0.0272 MG/MG Oral Gel for Gingivitis (disorder)
  2011-06-23[STOPPED] : Acetaminophen 21.7 MG/ML / Dextromethorphan Hydrobromide 1 MG/ML / doxylamine succinate 0.417 MG/ML Oral Solution for Acute bronchitis (disorder)
  2011-05-26[STOPPED] : sodium fluoride 0.0272 MG/MG Oral Gel for Gingivitis (disorder)
  2011-03-03[STOPPED] : sodium fluoride 0.0272 MG/MG Oral Gel for Gingivitis (disorder)
  2010-04-25[CURRENT] : Memantine hydrochloride 2 

In [68]:
conditions_df = form.process_conditions_section()
conditions_df.head()

Unnamed: 0,start,end,description,type
0,2014-04-03,,Viral sinusitis,disorder
1,2014-03-27,,Part-time employment,finding
2,2014-03-27,2014-03-27,Medication review due,situation
3,2013-05-23,,Stress,finding
4,2013-05-23,2014-03-27,Not in labor force,finding


In [69]:
observations_df = form.process_observations_section()
observations_df.head()

Unnamed: 0,date,content
0,2014-04-24,Cause of Death [US Standard Certificate of Dea...
1,2014-03-27,Patient Health Questionnaire 2 item (PHQ-2) to...
2,2014-03-27,Total score [HARK] 0.0 {...
3,2014-03-27,Fall risk level [Morse Fall Scale] High ...
4,2014-03-27,Fall risk total [Morse Fall Scale] 101.0...


In [71]:
medications_df = form.process_medications_section()
medications_df.head()

Unnamed: 0,date,status,medication,reason,type
0,2013-06-06,STOPPED,sodium fluoride 0.0272 MG/MG Oral Gel,Patient referral for dental care,procedure
1,2012-05-31,STOPPED,sodium fluoride 0.0272 MG/MG Oral Gel,Gingivitis,disorder
2,2011-06-23,STOPPED,Acetaminophen 21.7 MG/ML / Dextromethorphan Hy...,Acute bronchitis,disorder
3,2011-05-26,STOPPED,sodium fluoride 0.0272 MG/MG Oral Gel,Gingivitis,disorder
4,2011-03-03,STOPPED,sodium fluoride 0.0272 MG/MG Oral Gel,Gingivitis,disorder


In [86]:
# put conditions that haven't ended at the beginning
ongoing_conditions = conditions_df[conditions_df["end"].isna()].sort_values("start", ascending=False)
ended_conditions = conditions_df[~conditions_df["end"].isna()].sort_values("end", ascending=False)
conditions_df = pd.concat([ongoing_conditions, ended_conditions])

stopped_mask = medications_df["status"] == "STOPPED"
stopped_medications = medications_df[stopped_mask].sort_values("date", ascending=False)
ongoing_medications = medications_df[~stopped_mask].sort_values("date", ascending=False)
medications_df = pd.concat([ongoing_medications, stopped_medications])

observations_df.sort_values("date", ascending=True)

Unnamed: 0,date,content
29,2011-02-10,Protocol for Responding to and Assessing Patie...
27,2013-05-23,Fall risk total [Morse Fall Scale] 12.0 {#}
26,2013-05-23,Fall risk level [Morse Fall Scale] Low R...
25,2013-05-23,Patient Health Questionnaire 2 item (PHQ-2) to...
28,2013-05-23,Protocol for Responding to and Assessing Patie...
24,2013-10-03,Total score [MMSE] 4.7 {...
23,2014-03-27,Hemoglobin A1c/Hemoglobin.total in Blood 6.4 %
22,2014-03-27,Body Height 159.2 cm
21,2014-03-27,Pain severity - 0-10 verbal numeric rating [Sc...
20,2014-03-27,Body Weight 69.9 kg


In [82]:
conditions_df

Unnamed: 0,start,end,description,type
0,2014-04-03,,Viral sinusitis,disorder
1,2014-03-27,,Part-time employment,finding
3,2013-05-23,,Stress,finding
19,2009-04-30,,Alzheimer's disease,disorder
2,2014-03-27,2014-03-27,Medication review due,situation
4,2013-05-23,2014-03-27,Not in labor force,finding
5,2013-05-23,2013-05-23,Medication review due,situation
7,2012-05-17,2013-05-23,Reports of violence in the environment,finding
13,2011-05-12,2013-05-23,Full-time employment,finding
6,2012-05-17,2012-05-31,Gingivitis,disorder


In [87]:
medications_df

Unnamed: 0,date,status,medication,reason,type
5,2010-04-25,CURRENT,Memantine hydrochloride 2 MG/ML Oral Solution,Alzheimer's disease,disorder
0,2013-06-06,STOPPED,sodium fluoride 0.0272 MG/MG Oral Gel,Patient referral for dental care,procedure
1,2012-05-31,STOPPED,sodium fluoride 0.0272 MG/MG Oral Gel,Gingivitis,disorder
2,2011-06-23,STOPPED,Acetaminophen 21.7 MG/ML / Dextromethorphan Hy...,Acute bronchitis,disorder
3,2011-05-26,STOPPED,sodium fluoride 0.0272 MG/MG Oral Gel,Gingivitis,disorder
4,2011-03-03,STOPPED,sodium fluoride 0.0272 MG/MG Oral Gel,Gingivitis,disorder
6,2009-04-30,STOPPED,Galantamine 4 MG Oral Tablet,Alzheimer's disease,disorder
7,2008-08-21,STOPPED,Amoxicillin 250 MG / Clavulanate 125 MG Oral T...,Acute bacterial sinusitis,disorder
8,2007-05-03,STOPPED,sodium fluoride 0.0272 MG/MG Oral Gel,Patient referral for dental care,procedure
9,2006-07-06,STOPPED,Acetaminophen 325 MG Oral Tablet,Acute bronchitis,disorder


In [94]:
med_group_df = medications_df[["medication"]]
med_group_df["native_index"] = med_group_df.index
med_group_df

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  med_group_df["native_index"] = med_group_df.index


Unnamed: 0,medication,native_index
5,Memantine hydrochloride 2 MG/ML Oral Solution,5
0,sodium fluoride 0.0272 MG/MG Oral Gel,0
1,sodium fluoride 0.0272 MG/MG Oral Gel,1
2,Acetaminophen 21.7 MG/ML / Dextromethorphan Hy...,2
3,sodium fluoride 0.0272 MG/MG Oral Gel,3
4,sodium fluoride 0.0272 MG/MG Oral Gel,4
6,Galantamine 4 MG Oral Tablet,6
7,Amoxicillin 250 MG / Clavulanate 125 MG Oral T...,7
8,sodium fluoride 0.0272 MG/MG Oral Gel,8
9,Acetaminophen 325 MG Oral Tablet,9


In [None]:
from sklearn.cluster import SpectralClustering


def medication_sentence_maker(medication_name: str, reason: str) -> str:
    """
    A little processing is necessary to express medicines before TF-IDF embedding
    in the CSDi vocabulary.
    """
    return f"The patient is taking {medication_name} for {reason}."


def embed_and_group(
    medications_df: pd.DataFrame,
    conditions_df: pd.DataFrame,
    observations_df: pd.DataFrame,
    csdi_vocabulary: set[str],
) -> pd.DataFrame:
    """
    Mutates dataframes to assign clusters based on cosine similarity over some
    threshold.
    """
    medications_df["doc"] = medications_df.apply(
        lambda row: medication_sentence_maker(row["medication"], row["reason"]), axis=1
    )

    med_group_df = medications_df[["doc", "native_index"]]
    med_group_df["record_type"] = "m"

    cond_group_df = conditions_df[["description", "native_index"]]
    cond_group_df = cond_group_df.rename(columns={"description": "doc"})
    cond_group_df["record_type"] = "c"

    obs_group_df = observations_df[["content", "native_index"]]
    obs_group_df = obs_group_df.rename(columns={"content": "doc"})
    obs_group_df["record_type"] = "o"

    # Create a dataframe of documents that coalesces the medications, observations, and conditions.
    # Embed all of the documents, then group based on cosine similarity over some threshold.
    documents_df = pd.concat(
        [med_group_df, cond_group_df, obs_group_df],
        ignore_index=True,
        axis=0,
    )

    documents_df["vec"] = documents_df["doc"].apply(clean_text)
    documents_df["vec"] = documents_df["vec"].apply(
        lambda x: " ".join(re.findall(r"\b\w+\b", x))
    )

    vectorizer = TfidfVectorizer(vocabulary=csdi_vocabulary)
    tfidf_matrix = vectorizer.fit_transform(documents_df["vec"])
    documents_df["vec"] = tfidf_matrix.toarray().tolist()

    # Perform spectral clustering
    n_clusters = 5  # You can adjust the number of clusters as needed
    spectral = SpectralClustering(n_clusters=n_clusters, affinity="precomputed")
    cosine_sim_matrix = cosine_similarity(tfidf_matrix)
    documents_df["cluster"] = spectral.fit_predict(cosine_sim_matrix)

    docs_med = documents_df[documents_df["record_type"] == "m"]
    docs_cond = documents_df[documents_df["record_type"] == "c"]
    docs_obs = documents_df[documents_df["record_type"] == "o"]

    medications_df = medications_df.merge(
        docs_med, on="native_index", suffixes=("", "_doc")
    )
    conditions_df = conditions_df.merge(
        docs_cond, on="native_index", suffixes=("", "_doc")
    )
    observations_df = observations_df.merge(
        docs_obs, on="native_index", suffixes=("", "_doc")
    )
    return medications_df, conditions_df, observations_df

    # # compute the 95th percentile of cosine similarity
    # cosine_sim = cosine_similarity(tfidf_matrix)
    # cosine_sim_flat = cosine_sim[np.triu_indices(cosine_sim.shape[0], k=1)]
    # threshold = np.percentile(cosine_sim_flat, 95)
    # del cosine_sim_flat

    # group_mask = cosine_sim > threshold

    # # get the row and column of entries such that group_mask
    # group_indices = np.where(group_mask)

    # # this creates a (hopefully relatively sparse?) adjacency list
    # r1, r2 = group_indices

    # # create a dataframe of the groups. the cols should be the native_index and record_type
    # # of the original documents
    # groups_df = pd.DataFrame(
    #     {
    #         "native_index": documents_df["native_index"].iloc[r1].tolist(),
    #         "record_type": documents_df["record_type"].iloc[r1].tolist(),
    #     }
    # )

    # assign groups based on the adjacencies between documents (clique finding)

    # return groups_df

In [189]:
groups = embed_and_group(medications_df, conditions_df, observations_df, vocabulary)
groups

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  med_group_df["record_type"] = "m"


(          date   status                                         medication  \
 0   2010-04-25  CURRENT      Memantine hydrochloride 2 MG/ML Oral Solution   
 1   2010-04-25  CURRENT      Memantine hydrochloride 2 MG/ML Oral Solution   
 2   2010-04-25  CURRENT      Memantine hydrochloride 2 MG/ML Oral Solution   
 3   2013-06-06  STOPPED              sodium fluoride 0.0272 MG/MG Oral Gel   
 4   2013-06-06  STOPPED              sodium fluoride 0.0272 MG/MG Oral Gel   
 5   2013-06-06  STOPPED              sodium fluoride 0.0272 MG/MG Oral Gel   
 6   2012-05-31  STOPPED              sodium fluoride 0.0272 MG/MG Oral Gel   
 7   2012-05-31  STOPPED              sodium fluoride 0.0272 MG/MG Oral Gel   
 8   2012-05-31  STOPPED              sodium fluoride 0.0272 MG/MG Oral Gel   
 9   2011-06-23  STOPPED  Acetaminophen 21.7 MG/ML / Dextromethorphan Hy...   
 10  2011-06-23  STOPPED  Acetaminophen 21.7 MG/ML / Dextromethorphan Hy...   
 11  2011-06-23  STOPPED  Acetaminophen 21.7 MG/ML /

In [190]:
medications_df["cluster"]

KeyError: 'cluster'

In [116]:
test_mat = np.random.rand(10, 10)
test_mat

array([[0.52855603, 0.39353799, 0.87400479, 0.86396869, 0.1902567 ,
        0.39317449, 0.03861874, 0.39942849, 0.52255281, 0.63999519],
       [0.67531274, 0.44963277, 0.70710816, 0.97496196, 0.08686605,
        0.41582403, 0.10589334, 0.47632367, 0.44571601, 0.13606867],
       [0.77452144, 0.62257026, 0.10321121, 0.00374146, 0.6486583 ,
        0.79309495, 0.29555079, 0.76823478, 0.37146544, 0.1309319 ],
       [0.00555091, 0.74950111, 0.6737189 , 0.02258497, 0.60234826,
        0.38650864, 0.39188961, 0.98045164, 0.82183001, 0.27645484],
       [0.61422659, 0.26991787, 0.46684592, 0.04228864, 0.74713263,
        0.02828127, 0.28106418, 0.44796535, 0.67457211, 0.4813338 ],
       [0.7663723 , 0.27428195, 0.29283645, 0.74751111, 0.15758178,
        0.8321921 , 0.25085573, 0.04589311, 0.22005681, 0.38063593],
       [0.06910535, 0.89218739, 0.0253104 , 0.29454438, 0.38572411,
        0.13252751, 0.50647225, 0.23205673, 0.77756757, 0.5417445 ],
       [0.84846453, 0.33123915, 0.1936128

In [117]:
cond_mask = test_mat > 0.3
np.where(cond_mask)

(array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3,
        3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7,
        7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9]),
 array([0, 1, 2, 3, 5, 7, 8, 9, 0, 1, 2, 3, 5, 7, 8, 0, 1, 4, 5, 7, 8, 1,
        2, 4, 5, 6, 7, 8, 0, 2, 4, 7, 8, 9, 0, 3, 5, 9, 1, 4, 6, 8, 9, 0,
        1, 3, 5, 6, 7, 8, 9, 0, 2, 3, 6, 7, 8, 9, 1, 4, 5, 7, 8]))

In [114]:
groups

Unnamed: 0,doc,record_type,native_index,vec
0,The patient is taking Memantine hydrochloride ...,m,5,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,The patient is taking sodium fluoride 0.0272 M...,m,0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,The patient is taking sodium fluoride 0.0272 M...,m,1,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,The patient is taking Acetaminophen 21.7 MG/ML...,m,2,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
4,The patient is taking sodium fluoride 0.0272 M...,m,3,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
...,...,...,...,...
59,Patient Health Questionnaire 2 item (PHQ-2) to...,o,25,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
60,Fall risk level [Morse Fall Scale] Low R...,o,26,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
61,Fall risk total [Morse Fall Scale] 12.0 {#},o,27,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
62,Protocol for Responding to and Assessing Patie...,o,28,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
