In [2]:
import json
from tqdm import tqdm
# Root directory of the CogKG data; other paths in this notebook are built by
# appending to it without a separator, so the trailing slash is required.
CogKG_DIR = './CogKG_Neuro_eng/'
# JSON Lines file with the diagnosis records that read_data() loads.
diagnose_data_path = f'{CogKG_DIR}data/diagnose/aligned/diagnose_train.json'

In [3]:
def read_data(diagnose_data_path):
    """Load diagnosis records from a JSON Lines file.

    Every non-blank line is parsed as one JSON object that must provide a
    ``'symptoms'`` entry and a ``'disease'`` entry.

    Args:
        diagnose_data_path: Path of the UTF-8 encoded file to read.

    Returns:
        A ``(symptoms, diseases)`` tuple of two equally long lists, where
        ``symptoms[i]`` and ``diseases[i]`` come from the same record.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``'symptoms'`` or ``'disease'``.
    """
    symptoms, diseases = [], []
    with open(diagnose_data_path, 'r', encoding='utf-8') as f:
        # Iterate the handle directly so the file is streamed line by line
        # instead of being loaded in full by readlines().
        for line in f:
            if not line.strip():
                # Blank lines (e.g. extra newlines at the end of the file)
                # hold no record and would make json.loads raise.
                continue
            record = json.loads(line)
            symptoms.append(record['symptoms'])
            diseases.append(record['disease'])
    return (symptoms, diseases)

def process_raw_data_fpgrowth(raw_data):
    """Convert loaded records into item lists for frequent-pattern mining.

    Args:
        raw_data: Indexable pair where ``raw_data[0]`` is a sequence of
            ``{symptom: value}`` dicts and ``raw_data[1]`` is the sequence of
            matching disease labels.

    Returns:
        A list with one transaction per record: the symptoms whose value
        equals ``True`` or ``'True'`` (in dict order), followed by the
        record's disease label.
    """
    positive_values = (True, 'True')
    features, labels = raw_data[0], raw_data[1]

    transactions = []
    for symptom_map, disease in zip(features, labels):
        present = [
            name
            for name, value in symptom_map.items()
            if value in positive_values
        ]
        # The label is always the last item of the transaction.
        transactions.append(present + [disease])
    return transactions

In [4]:
# Itemset mining / frequent pattern mining with FP-Growth.
str2id_disease_path = CogKG_DIR + 'data/diagnose/aligned/disease2id.json'
# Open the mapping inside a context manager so the file handle is closed
# right after parsing instead of being left to the garbage collector.
with open(str2id_disease_path, 'r', encoding='utf-8') as disease_file:
    # filter_rule() uses the keys of this mapping to recognise disease items.
    dise2id = json.load(disease_file)

# Load the (symptoms, diseases) records and turn each one into a transaction.
raw_data = read_data(diagnose_data_path)
transactions = process_raw_data_fpgrowth(raw_data)

In [5]:
def show_rules(rules, show_num=None, file=None):
    """Print association rules and optionally export all of them as CSV.

    Args:
        rules: Mapping ``cause -> effect`` where ``effect[0][0]`` is the
            predicted item and ``effect[1]`` is the rule's confidence.
        show_num: Maximum number of rules to print. ``None`` prints every
            rule and ``0`` prints none. It does not limit the CSV export.
        file: Optional writable text file object. When given, a
            ``rule,confidence`` header and one row per rule are written to
            it with ``csv.writer``.

    Returns:
        None.
    """
    writer = None
    if file:
        # Write to the handle that was passed in, not to whatever global
        # happens to be named ``f`` in the calling cell.
        writer = csv.writer(file)
        writer.writerow(['rule', 'confidence'])

    for idx, (cause, effect) in enumerate(rules.items(), start=1):
        target = effect[0][0]
        # Confidence values above 1.0 are clamped to 1.0 for output.
        confidence = effect[1] if effect[1] <= 1.0 else 1.0

        if writer is not None:
            writer.writerow([f'IF {cause} THEN {target}', confidence])
        if show_num is None or idx <= show_num:
            print(f'Rule-{idx}: IF {cause} THEN--> {target}, confidence:{confidence}')


def deredund_rule(rules):
    """Drop rules whose cause is contained in an already kept rule's cause.

    Rules are visited in the mapping's iteration order, so the result depends
    on that order: a rule is discarded when every item of its cause also
    appears in the cause of some rule kept before it.

    NOTE(review): this removes the *shorter* (subset) cause when a longer one
    was seen first -- confirm that this direction is the intended one.

    Args:
        rules: Mapping ``cause -> effect`` with iterable causes.

    Returns:
        A new dict holding the surviving ``cause -> effect`` pairs.
    """
    kept = {}
    for cause, effect in rules.items():
        cause_items = set(cause)
        covered = any(cause_items <= set(previous) for previous in kept)
        if not covered:
            kept[cause] = effect
    return kept

def restrict_symtom(rules, symptom_list):
    """Keep only the rules whose cause uses items from ``symptom_list``.

    The allowed list is printed first, then every rule whose cause contains
    an item that is not in ``symptom_list`` is discarded.

    Args:
        rules: Mapping ``cause -> effect`` with iterable causes.
        symptom_list: Container of the items a cause may consist of.

    Returns:
        A new dict with the rules that passed the check, in original order.
    """
    print(symptom_list)
    return {
        cause: effect
        for cause, effect in rules.items()
        if all(item in symptom_list for item in cause)
    }


def filter_rule(rules, disease_names=None):
    """Split mined association rules into disease rules and symptom rules.

    Only rules with exactly one item in ``effect[0]`` and no disease item in
    the cause are considered:

    * If that single item is a disease, the rule is a disease rule. It is
      kept only when the cause has more than one item and mentions at most
      one of the three fever-grade items.
    * Otherwise the rule is a symptom rule and is kept as is.

    Args:
        rules: Mapping ``cause -> effect`` where ``cause`` is a tuple of
            items, ``effect[0]`` is the tuple of predicted items and the
            remaining entries of ``effect`` are passed through untouched.
        disease_names: Optional iterable of the items that count as diseases
            (a dict works; its keys are used). Defaults to the notebook-level
            ``dise2id`` mapping.

    Returns:
        ``(disease_rule_dict, symptom_rule_dict)``; each maps the predicted
        item to a ``{cause: effect}`` dict of its rules.
    """
    if disease_names is None:
        disease_names = dise2id
    # Build the lookup set once instead of rebuilding and scanning a key list
    # several times for every rule.
    disease_set = set(disease_names)
    # Fever grades: low-grade fever, high fever, moderate fever. A cause that
    # combines more than one of them is rejected for disease rules.
    fever_grades = {'低热', '高热', '伴中度发热'}

    disease_rule_dict = {}
    symptom_rule_dict = {}
    for cause, effect in rules.items():
        consequent = effect[0]
        # Both rule kinds need a single predicted item and a disease-free cause.
        if len(consequent) != 1 or not disease_set.isdisjoint(cause):
            continue
        target = consequent[0]

        if target in disease_set:
            # mining rules for diseases
            if len(cause) <= 1:
                continue
            if len(set(cause) & fever_grades) > 1:
                continue
            disease_rule_dict.setdefault(target, {})[cause] = effect
        else:
            # mining rules for symptoms
            symptom_rule_dict.setdefault(target, {})[cause] = effect
    return disease_rule_dict, symptom_rule_dict

In [6]:
# Support threshold handed to pyfpgrowth.find_frequent_patterns below
# (presumably an absolute transaction count -- confirm against pyfpgrowth).
MINI_SUPPORT = 3
# Confidence threshold handed to pyfpgrowth.generate_association_rules below.
MINI_CONFIDENCE = 1.0

In [7]:
import csv
import os

# !pip install pyfpgrowth
import pyfpgrowth

disease_rule_dir = CogKG_DIR + 'data/rule/disease_rule/'
# Create the output directory when it is missing so the cell also runs on a
# fresh checkout.
os.makedirs(disease_rule_dir, exist_ok=True)

# Remove rule files written by a previous run. Match on the suffix so that
# only files named like the ones written below are deleted.
for file_name in os.listdir(disease_rule_dir):
    if file_name.endswith('_rules.csv'):
        os.remove(os.path.join(disease_rule_dir, file_name))

# Thresholds: support, confidence.
patterns = pyfpgrowth.find_frequent_patterns(transactions, MINI_SUPPORT)
rules = pyfpgrowth.generate_association_rules(patterns, MINI_CONFIDENCE)
# print(f'patterns:{len(patterns)}')
# print(f'rules:{len(rules)}')

disease_rule_dict, symptom_rule_dict = filter_rule(rules)

# Use a separate name per disease so the full mined rule set stays available
# in ``rules`` after the loop.
for disease, disease_rules in disease_rule_dict.items():
    # disease_rules = deredund_rule(disease_rules) # filter 1
    # disease_rules = restrict_symtom(disease_rules, list(dise2symps_filter[disease].keys())) # filter 2

    avg_len = sum(len(cause) for cause in disease_rules) / len(disease_rules)
    print(f'{disease}-rules_num:{len(disease_rules)}, avg_len:{avg_len}')
    # show_rules(disease_rules, show_num=None)

    # Every rule goes to the CSV file; only the first one is printed.
    with open(os.path.join(disease_rule_dir, disease + '_rules.csv'), 'w', encoding='utf-8', newline='') as f:
        show_rules(disease_rules, show_num=1, file=f)

# print(f'symptom_num:{len(symptom_rule_dict)}\nsymptom_rules:{sum([len(i) for i in symptom_rule_dict.values()])}')

# cnt = 0
# for symptom in symptom_rule_dict:
#     rules = symptom_rule_dict[symptom]
#     rules = deredund_rule(rules)
#     print(f'{symptom}-rules_num:{len(rules)}')

#     if symptom not in ['C反应蛋白', '中性粒细胞', '化验', '听诊','大便常规','白细胞','红细胞','血小板','血常规','验血']:
#         if len(rules) <= 50:
#             with open('symp_rule/' + symptom + '_rules.txt', 'w') as f:
#                 show_rules(rules, show_num=0, file=f)
#                 cnt += len(rules)

# print(cnt)

C1279369-rules_num:15, avg_len:2.7333333333333334
Rule-1: IF ('C0232292', 'C0344232') THEN--> C1279369, confidence:1.0
C0011615-rules_num:3, avg_len:2.3333333333333335
Rule-1: IF ('C0033774', 'C0332563') THEN--> C0011615, confidence:1.0
C0009763-rules_num:9, avg_len:2.4444444444444446
Rule-1: IF ('C0022281', 'C0151827') THEN--> C0009763, confidence:1.0
C0040147-rules_num:48, avg_len:2.8541666666666665
Rule-1: IF ('C0030252', 'C0270996') THEN--> C0040147, confidence:1.0
C0035455-rules_num:11, avg_len:2.5454545454545454
Rule-1: IF ('C0027424', 'C0037384') THEN--> C0035455, confidence:1.0
C0014335-rules_num:27, avg_len:2.5925925925925926
Rule-1: IF ('C0015967', 'C0239978') THEN--> C0014335, confidence:1.0
C0014868-rules_num:23, avg_len:2.391304347826087
Rule-1: IF ('C0085624', 'C0542301') THEN--> C0014868, confidence:1.0
C0024894-rules_num:16, avg_len:2.75
Rule-1: IF ('C0015672', 'C0234238') THEN--> C0024894, confidence:1.0
C0004096-rules_num:6, avg_len:3.6666666666666665
Rule-1: IF ('C00