In [1]:
from collections import defaultdict
import pickle as pkl
import pandas as pd
from sentence_transformers import SentenceTransformer, SimilarityFunction
from mlxtend.preprocessing import TransactionEncoder
from mlxtend.frequent_patterns import apriori
from mlxtend.frequent_patterns import association_rules

In [2]:
icd_df = pd.read_csv("data/raw/ICDCodeSet.csv")
icd_df['ICDCode']=icd_df['ICDCode'].str.strip()
icd_df['Description']=icd_df['Description'].str.strip()

# Filter on ICD10 on symptoms, injury, poisoning, and other causes
# df=df[df['ICDCode'].str.contains("^(R|S|T)", regex=True)].reset_index(drop=True)

# Filter on ICD10 symptoms that can be found without clinical and lab diag (R00-R69)
icd_df=icd_df[icd_df['ICDCode'].str.contains("^(R(0|1|2|3|4|5|6))", regex=True)].reset_index(drop=True)

# Part 1: Encode specific ICD10 Description in vector-based for similarity search

# embedding with "Qwen/Qwen3-Embedding-0.6B", 1024 vector-dimensions
embedding_model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B", similarity_fn_name=SimilarityFunction.COSINE)

documents = icd_df['Description'].values[:128]

document_embeddings = embedding_model.encode(documents, show_progress_bar=True, batch_size=32)

  icd_df=icd_df[icd_df['ICDCode'].str.contains("^(R(0|1|2|3|4|5|6))", regex=True)].reset_index(drop=True)


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
symptom_transaction_df = pd.read_csv("data/raw/[CONFIDENTIAL] AI symptom picker data (Agnos candidate assignment) - ai_symptom_picker.csv", usecols=['search_term'], nrows=30)

symptom_transaction_df['patient_idx']=range(symptom_transaction_df.shape[0])

symptom_transaction_df['search_term']=symptom_transaction_df['search_term'].str.strip()
symptom_transaction_df['search_term']=symptom_transaction_df['search_term'].str.split(',').apply(lambda list_x: [x.strip() for x in list_x if len(x.strip())>0])

In [4]:
symptom_transaction_df

Unnamed: 0,search_term,patient_idx
0,"[มีเสมหะ, ไอ]",0
1,"[ไอ, น้ำมูกไหล]",1
2,[ปวดท้อง],2
3,[น้ำมูกไหล],3
4,[ตาแห้ง],4
5,[ปวดกระดูก],5
6,"[น้ำมูกไหล, คันจมูกจามบ่อย, ไอ]",6
7,[ปวดท้อง],7
8,"[คันคอ, ไอ]",8
9,[ไอ],9


In [5]:
list_all_symptom_from_transaction = list(set(symptom_transaction_df['search_term'].explode().tolist()))
list_all_symptom_from_transaction

['ปวดกระดูก',
 'เจ็บคอ',
 'คันจมูกจามบ่อย',
 'ปวดข้อเท้า',
 'น้ำมูกไหล',
 'ไข้',
 'ปวดหลัง',
 'คันคอ',
 'บวม',
 'ปวดท้อง',
 'เสมหะไหลลงคอ',
 'Fever',
 'คันจมูก',
 'ปวดข้อมือ',
 'มีเสมหะ',
 'หายใจมีเสียงวี๊ด',
 'เสียงแหบ',
 'ไอ',
 'ท้องเสีย',
 'ถ่ายเป็นเลือดสด',
 'ตาแห้ง',
 'อาเจียน',
 'ปวดเมื่อยกล้ามเนื้อทั่วๆ',
 'หายใจหอบเหนื่อย']

In [6]:
query_embeddings = embedding_model.encode(list_all_symptom_from_transaction, prompt_name="query", show_progress_bar=True, batch_size=32)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [7]:
similarity = embedding_model.similarity(query_embeddings, document_embeddings)

In [8]:
similarity

tensor([[0.1852, 0.1807, 0.2117,  ..., 0.2033, 0.3342, 0.3473],
        [0.1733, 0.1802, 0.1866,  ..., 0.2603, 0.2556, 0.2279],
        [0.1834, 0.1635, 0.2115,  ..., 0.2370, 0.2365, 0.2263],
        ...,
        [0.2477, 0.2163, 0.2132,  ..., 0.2789, 0.2933, 0.2465],
        [0.1849, 0.1789, 0.1927,  ..., 0.2444, 0.3206, 0.3095],
        [0.2839, 0.3121, 0.3263,  ..., 0.2750, 0.2429, 0.2144]])

In [9]:
icd_df['ICDCode'].loc[similarity.argmin(dim=1)]

30       R064
113     R1900
59       R102
50       R093
77     R10826
31       R065
30       R064
77     R10826
72     R10821
30       R064
41       R072
116     R1903
59       R102
35      R0682
41       R072
57      R1012
64     R10811
55      R1010
8        R030
59       R102
7        R012
39       R070
50       R093
114     R1901
Name: ICDCode, dtype: object

In [10]:
dict_master_word_icd10 = dict(zip(list_all_symptom_from_transaction, icd_df['ICDCode'].loc[similarity.argmin(dim=1)]))
dict_master_word_icd10

{'ปวดกระดูก': 'R064',
 'เจ็บคอ': 'R1900',
 'คันจมูกจามบ่อย': 'R102',
 'ปวดข้อเท้า': 'R093',
 'น้ำมูกไหล': 'R10826',
 'ไข้': 'R065',
 'ปวดหลัง': 'R064',
 'คันคอ': 'R10826',
 'บวม': 'R10821',
 'ปวดท้อง': 'R064',
 'เสมหะไหลลงคอ': 'R072',
 'Fever': 'R1903',
 'คันจมูก': 'R102',
 'ปวดข้อมือ': 'R0682',
 'มีเสมหะ': 'R072',
 'หายใจมีเสียงวี๊ด': 'R1012',
 'เสียงแหบ': 'R10811',
 'ไอ': 'R1010',
 'ท้องเสีย': 'R030',
 'ถ่ายเป็นเลือดสด': 'R102',
 'ตาแห้ง': 'R012',
 'อาเจียน': 'R070',
 'ปวดเมื่อยกล้ามเนื้อทั่วๆ': 'R093',
 'หายใจหอบเหนื่อย': 'R1901'}

In [11]:
dict_icd10_master_word = defaultdict(list)
for idx, val in dict_master_word_icd10.items():
    dict_icd10_master_word[val].append(idx)
dict_icd10_master_word

defaultdict(list,
            {'R064': ['ปวดกระดูก', 'ปวดหลัง', 'ปวดท้อง'],
             'R1900': ['เจ็บคอ'],
             'R102': ['คันจมูกจามบ่อย', 'คันจมูก', 'ถ่ายเป็นเลือดสด'],
             'R093': ['ปวดข้อเท้า', 'ปวดเมื่อยกล้ามเนื้อทั่วๆ'],
             'R10826': ['น้ำมูกไหล', 'คันคอ'],
             'R065': ['ไข้'],
             'R10821': ['บวม'],
             'R072': ['เสมหะไหลลงคอ', 'มีเสมหะ'],
             'R1903': ['Fever'],
             'R0682': ['ปวดข้อมือ'],
             'R1012': ['หายใจมีเสียงวี๊ด'],
             'R10811': ['เสียงแหบ'],
             'R1010': ['ไอ'],
             'R030': ['ท้องเสีย'],
             'R012': ['ตาแห้ง'],
             'R070': ['อาเจียน'],
             'R1901': ['หายใจหอบเหนื่อย']})

In [12]:
symptom_transaction_df['icd10_term'] = symptom_transaction_df['search_term'].map(lambda list_x: list(dict.fromkeys([dict_master_word_icd10.get(x) for x in list_x])))
symptom_transaction_df

Unnamed: 0,search_term,patient_idx,icd10_term
0,"[มีเสมหะ, ไอ]",0,"[R072, R1010]"
1,"[ไอ, น้ำมูกไหล]",1,"[R1010, R10826]"
2,[ปวดท้อง],2,[R064]
3,[น้ำมูกไหล],3,[R10826]
4,[ตาแห้ง],4,[R012]
5,[ปวดกระดูก],5,[R064]
6,"[น้ำมูกไหล, คันจมูกจามบ่อย, ไอ]",6,"[R10826, R102, R1010]"
7,[ปวดท้อง],7,[R064]
8,"[คันคอ, ไอ]",8,"[R10826, R1010]"
9,[ไอ],9,[R1010]


In [13]:
transformation_encoder_obj = TransactionEncoder()
icd10_terms_encoded = transformation_encoder_obj.fit_transform(symptom_transaction_df['icd10_term'])
icd10_terms_encoded = pd.DataFrame(icd10_terms_encoded, columns=transformation_encoder_obj.columns_)
icd10_terms_encoded

Unnamed: 0,R012,R030,R064,R065,R0682,R070,R072,R093,R1010,R1012,R102,R10811,R10821,R10826,R1900,R1901,R1903
0,False,False,False,False,False,False,True,False,True,False,False,False,False,False,False,False,False
1,False,False,False,False,False,False,False,False,True,False,False,False,False,True,False,False,False
2,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False
3,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False
4,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
5,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False
6,False,False,False,False,False,False,False,False,True,False,True,False,False,True,False,False,False
7,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False
8,False,False,False,False,False,False,False,False,True,False,False,False,False,True,False,False,False
9,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False


In [14]:
transformation_encoder_obj.columns_

['R012',
 'R030',
 'R064',
 'R065',
 'R0682',
 'R070',
 'R072',
 'R093',
 'R1010',
 'R1012',
 'R102',
 'R10811',
 'R10821',
 'R10826',
 'R1900',
 'R1901',
 'R1903']

In [26]:
frequent_itemsets = apriori(icd10_terms_encoded, min_support=0.01, use_colnames=True, max_len=10)

In [16]:
frequent_itemsets

Unnamed: 0,support,itemsets
0,0.033333,(R012)
1,0.033333,(R030)
2,0.166667,(R064)
3,0.033333,(R065)
4,0.033333,(R0682)
5,0.066667,(R070)
6,0.1,(R072)
7,0.066667,(R093)
8,0.233333,(R1010)
9,0.033333,(R1012)


In [27]:
rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=0)
rules = rules[['antecedents', 'consequents', 'support', 'confidence', 'lift']]
rules

Unnamed: 0,antecedents,consequents,support,confidence,lift
0,(R070),(R064),0.033333,0.5,3.0
1,(R064),(R070),0.033333,0.2,3.0
2,(R065),(R1900),0.033333,1.0,7.5
3,(R1900),(R065),0.033333,0.25,7.5
4,(R1010),(R072),0.066667,0.285714,2.857143
5,(R072),(R1010),0.066667,0.666667,2.857143
6,(R093),(R10821),0.033333,0.5,15.0
7,(R10821),(R093),0.033333,1.0,15.0
8,(R1010),(R102),0.033333,0.142857,1.428571
9,(R102),(R1010),0.033333,0.333333,1.428571
