In [1]:
import json
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
from medcat.cat import CAT
from medcat.cdb import CDB
from medcat.vocab import Vocab
import os
import re
import numpy as np


In [None]:
# text preprocessing
def preprocess_note(raw_note: str) -> str:
    """Drop boilerplate lines from a raw note and merge the rest into one string.

    Every line is stripped and then discarded when it

    * is empty or consists only of periods,
    * starts with a list marker such as ``(1)``, ``1)``, ``[1]`` or ``1. ``,
    * contains one of the case-sensitive markers ``mg``, ``mL``, ``tab``,
      ``inj``, ``bag``, ``cap``, ``[IV]`` or ``[P.O.]``,
    * contains one of the case-insensitive terms ``tid``, ``bid``,
      ``q <n>hr``, ``daily``, ``NIBP``, ``SpO2``, ``EKG``, ``Foley`` or
      ``catheter``,
    * is shorter than four characters.

    The surviving lines are joined with ``". "``, every run of periods is
    collapsed to a single ``". "`` and whitespace runs become one space.

    Args:
        raw_note: Note text, possibly spanning several lines.

    Returns:
        The cleaned single-line note; an empty string when no line survives.
    """
    clean_lines = []
    for line in raw_note.splitlines():
        line = line.strip()
        if not line or re.fullmatch(r'\.*', line):
            continue
        if re.match(r"^\(\d+\)|^\d+\)|^\[\d+\]", line):
            continue
        if re.match(r"^\d+\. ", line):
            continue
        # NOTE(review): these markers are matched as plain substrings, so a line
        # containing e.g. "stable" (tab) or "injury" (inj) is dropped as well.
        if re.search(r"(mg|mL|tab|inj|bag|cap|\[IV\]|\[P\.O\.\])", line):
            continue
        if re.search(r"(tid|bid|q \d+hr|daily|NIBP|SpO2|EKG|Foley|catheter)", line, re.IGNORECASE):
            continue
        if len(line) < 4:
            continue
        clean_lines.append(line)

    merged_note = '. '.join(clean_lines)
    # Collapse whole runs of periods at once. Replacing them one by one turned a
    # line that already ended in "." into "text. . next" after the join above.
    # NOTE(review): a period inside a number or date ("3.5", "25.05.29") is still
    # followed by an inserted space, as before.
    merged_note = re.sub(r'(?:\s*\.)+\s*', '. ', merged_note)
    merged_note = re.sub(r'\s+', ' ', merged_note)
    return merged_note.strip()



# Load the MedCAT keyword extractor.

# The concept database and vocabulary are read from paths relative to the
# current working directory.
cdb = CDB.load("./cdb.dat")
vocab = Vocab.load("./vocab.dat")
# Module-level annotator; extract_keywords() calls cat.get_entities() on it.
cat = CAT(cdb=cdb, vocab=vocab)

# extract keywords
def extract_keywords(note):
    """Return the distinct concept names that the module-level ``cat`` finds in ``note``.

    Args:
        note: Text passed to ``cat.get_entities``.

    Returns:
        A list of the ``"pretty_name"`` values of the returned entities, without
        duplicates, in the order in which they are first encountered.
    """
    annotations = cat.get_entities(note)
    names = (ent["pretty_name"] for ent in annotations["entities"].values())
    # dict.fromkeys de-duplicates while keeping first-seen order. Going through a
    # set made the order depend on string hashing, so slices of the result
    # (e.g. the first five keywords) could change from one kernel run to the next.
    return list(dict.fromkeys(names))

In [None]:
# Additional helper functions for improving performance.
# ALPHA weights the embedding cosine similarity in the combined scores;
# (1 - ALPHA) weights the keyword-overlap prior.
ALPHA = 0.5
def calculate_normalized_prior(keywords, modality):
    """Share of a modality's related conditions that occur among ``keywords``.

    The comparison is case-insensitive. The modality entry is looked up in the
    module-level ``modality_dict`` and its ``"related_conditions"`` list is used.

    Args:
        keywords: Iterable of keyword strings.
        modality: Key into ``modality_dict``.

    Returns:
        Number of distinct case-insensitive matches divided by the length of the
        related-conditions list, or ``0.0`` when that list is missing or empty.
    """
    related = modality_dict[modality].get("related_conditions", [])
    if related:
        found = {term.lower() for term in keywords}
        wanted = {term.lower() for term in related}
        return len(found & wanted) / len(related)
    return 0.0

def calculate_normalized_prior_note(note_keywords, modality):
    """Keyword-overlap prior of a note for ``modality``.

    This used to be a line-for-line copy of ``calculate_normalized_prior``; it
    now delegates to that function so the two cannot drift apart.

    Args:
        note_keywords: Iterable of keyword strings extracted from a note.
        modality: Key into the module-level ``modality_dict``.

    Returns:
        Whatever ``calculate_normalized_prior(note_keywords, modality)`` returns.
    """
    return calculate_normalized_prior(note_keywords, modality)

def apply_length_penalty(score, text, scale=0.03):
    """Scale ``score`` down by a factor that shrinks with the token count of ``text``.

    The factor is ``1 / (1 + scale * log1p(n))`` where ``n`` is the length of
    ``tokenizer.encode(text, truncation=False)`` for the module-level tokenizer.

    Args:
        score: Value to penalise.
        text: Text whose token count drives the penalty.
        scale: Strength of the logarithmic penalty.

    Returns:
        ``score`` multiplied by the penalty factor.
    """
    n_tokens = len(tokenizer.encode(text, truncation=False))
    denominator = 1 + scale * np.log1p(n_tokens)
    return score * (1 / denominator)

def scaled_softmax(scores, temp=0.07):
    """Softmax of ``scores`` after dividing them by the temperature ``temp``.

    Args:
        scores: Sequence of numbers.
        temp: Temperature; smaller values sharpen the distribution.

    Returns:
        NumPy array of the same length as ``scores`` whose entries sum to 1.
    """
    logits = np.array(scores) / temp
    # Shift by the maximum before exponentiating for numerical stability.
    weights = np.exp(logits - logits.max())
    return weights / weights.sum()


# --------------------------------------------------------------
def calculate_prior_note(keywords, modality):
    """Count the entries of ``keywords`` that are on the built-in list for ``modality``.

    Matching is exact (case-sensitive) and repeated keywords are counted each
    time they occur. A modality without a built-in list yields ``0``.

    Args:
        keywords: Iterable of keyword strings.
        modality: One of the modality names used as keys below.

    Returns:
        The number of matching keywords as an ``int``.
    """
    curated_terms = {
        "Brain MRI": ["Glioma", "Brain", "Neoplasms", "Tumor", "Craniotomy", "MRI", "Hemiparesis"],
        "Brain CT": ["Stroke", "Hemorrhage", "CT", "Consciousness", "Trauma"],
        "Chest CT": ["Pulmonary Embolism", "PE", "CT", "Lung", "Nodule", "Mass", "Embolism"],
        "Chest X-ray": ["Pneumonia", "Chest", "Infiltrate", "X-ray", "Pulmonary", "Respiratory", "Effusion"]
    }
    hits = 0
    for kw in keywords:
        if kw in curated_terms.get(modality, []):
            hits += 1
    return hits

# 2. Application code
# Disease term -> modalities accepted for notes that mention that term.
# apply_disease_modality_constraint() looks the keys up as substrings of the
# lower-cased note.
MODALITY_CONSTRAINTS = {
    "glioma": ["Brain MRI", "Brain CT"],
    "pneumonia": ["Chest X-ray", "Chest CT"]
}
# Default trigger terms of compute_keyword_boost().
CRITICAL_KEYWORDS = ['pneumonia', 'glioma']

def apply_disease_modality_constraint(note, modality, penalty=-1):
    """Return a strong penalty when a note naming a disease does not fit ``modality``.

    For every disease term in ``MODALITY_CONSTRAINTS`` that occurs in the
    lower-cased note, ``modality`` has to be among the modalities listed for
    that term.

    Args:
        note: Note text.
        modality: Modality name being scored.
        penalty: Value returned when a constraint is violated.

    Returns:
        ``penalty`` if any matching disease term excludes ``modality``,
        otherwise ``0.0`` (no constraint applies).
    """
    lowered = note.lower()
    violated = any(
        disease in lowered and modality not in allowed
        for disease, allowed in MODALITY_CONSTRAINTS.items()
    )
    return penalty if violated else 0.0

def compute_keyword_boost(note, keywords=CRITICAL_KEYWORDS, boost_val=0.9):
    """Return ``boost_val`` when the note contains at least one of ``keywords``.

    Matching is a case-insensitive substring test against the whole note.

    Args:
        note: Note text.
        keywords: Terms to look for; defaults to ``CRITICAL_KEYWORDS``.
        boost_val: Value returned on a match.

    Returns:
        ``boost_val`` if any keyword occurs in the note, otherwise ``0.0``.
    """
    haystack = note.lower()
    matched = any(term.lower() in haystack for term in keywords)
    return boost_val if matched else 0.0

In [None]:
# -------------------------------
# 1. Load Bio_ClinicalBERT
# -------------------------------
model_name = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use the GPU when one is available and fall back to the CPU otherwise, so that
# loading no longer fails on machines without CUDA. Tensors fed to the model
# have to live on the same device as the model.
model = AutoModel.from_pretrained(model_name).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

def get_embedding(text, chunk_size=128):
    """Embed ``text`` as the mean of the per-chunk first-token vectors of the model.

    The text is tokenised without special tokens, cut into chunks, and every
    chunk is wrapped in its own ``[CLS] ... [SEP]`` pair before it is passed to
    the module-level ``model``. Wrapping each chunk matters because position 0
    of the model output is used as the chunk embedding: when the special tokens
    are added only once around the whole text, every chunk after the first
    starts with an ordinary word piece instead of ``[CLS]``.

    Args:
        text: Text to embed. Empty text is embedded as a bare ``[CLS] [SEP]``
            sequence.
        chunk_size: Maximum number of tokens per chunk, including the two
            special tokens. Values below 3 still leave room for one content
            token per chunk.

    Returns:
        NumPy array with one row holding the averaged embedding.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False, truncation=False)
    # Reserve two positions per chunk for [CLS] and [SEP].
    body_len = max(1, chunk_size - 2)
    bodies = [token_ids[i:i + body_len] for i in range(0, len(token_ids), body_len)]
    if not bodies:
        # No content tokens at all: embed the special tokens alone.
        bodies = [[]]

    # Follow the model's own device instead of assuming a GPU.
    device = model.device
    cls_vectors = []
    with torch.no_grad():
        for body in bodies:
            input_ids = torch.tensor(
                [[tokenizer.cls_token_id, *body, tokenizer.sep_token_id]],
                device=device,
            )
            # First element of the model output, position 0 ([CLS]) of the chunk.
            cls_vectors.append(model(input_ids)[0][:, 0, :].cpu())
    return torch.cat(cls_vectors, dim=0).mean(dim=0, keepdim=True).numpy()

# -------------------------------
# 2. Load modality descriptions
# -------------------------------
# Read the file as UTF-8 explicitly; without an encoding argument open() uses
# the platform's locale encoding, which can fail on non-ASCII descriptions.
with open('./modal_descrip2.json', 'r', encoding='utf-8') as f:
    modality_dict = json.load(f)
# The text embedded for each modality is the "formal" field of its entry.
modality_texts = {k: v["formal"] for k, v in modality_dict.items()}

In [None]:
# The notes to be scored still need to be loaded; the samples below are
# hard-coded in the meantime.
# 5. Sample notes
notes = [
    # Chest X-ray
    "Patient presents with shortness of breath and pleuritic chest pain. Suspected pulmonary embolism.",

    # Spinal MRI
    "Lower back pain with leg numbness. Suspected disc herniation.",

    # Brain - Glioma
    "High-Grade Glioma, IDH-wt, NEC, WHO Gr 3, Rt. T-O~CC Splenium (TERTpC228T)- s/p Navi-Guided Bx ('25.05.29)- s/p Craniotomy and tumor removal ('25.06.27)",
    "2025-06-27 Craniotomy and tumor removal  Brain tumor Brain tumor",
    "E4M6V5 Orientaiton T/P/P +-/+/+ Lt homonymous hemianopsia on confrontation test Motor grossly V/V, V/V",
    "Glioma MR (c/w '25.05.26) : Slightly Decreased Extent? w/ Improved Perilesional Edema",
    "Chief Complaint 1. tumor, brain ( brain tumor )- for Op Present Illness #. High-Grade Glioma, IDH-wt, NEC, WHO Gr 3, Rt. T-O~CC Splenium (TERTpC228T) - s/p Navi-Guided Bx ('25.05.29) Patient had weakness in left arm and legs from early May but symptoms were not severe. Patient had two episodes where she got lost while finding the way home. '25.05.25 Sudden vomiting and severe headache. Visited Emergency ward due to Lt. Hemiplegia. r/o PCNSL or HGG, Lt. T-P-O. '25.05.29 Navi-Guided Bx. Pathology reported as high grade glioma. Hospitalized to undergo operation. Stopped steroids after moving to rehabilitaion hospital. Restated steroids due to Confused Mentality and worsened Lt. Hemiparesis. No meaningful neurological symptoms compared to discharged day. Cannot walk without wheelchair. Rt. Handed",
    
    # Chest - Pneumonia
    "Patient was followed up regularly due to M. abcessus PD, BE, NTM was negative for a few years. pseudomonas colonizer.Patient under went total esophagogastrectomy due to corrosive esophagitis. Had multiple episodes of aspiration pneumonia during enteral feeding by esojejunostomy. Patient was hospitalized for pneumonia treatment in January and February. Patient was hospitalized in the MICU due to pneumonia withtype 1 respiratory failure in March.After discharge, patient visitied hospital 25.06.10. Patient complained aggravated cough and purulent sputum from 5 days ago. Patient wasclinically diagnosed with pneumonia and is hospitalized for treatment.",
    "CRPA(+) in sputum. Patient has used pip/tazo for initial intibiotics instead of zerbaxa in severe presentation and ICU settings. Patient showed relieved symptoms at initial administration of medication. However had fever and desaturation event after some time. Patient was dosed with zerbaxa from 6/24 due to RLL haziness development, and symptoms are relieving. Consider using colistin in case of symptom aggravation.",
    "# r/o Aspiration pneumonia - Pip/taz 6/10~6/23 - CRPA: Ceftolozane/tazo 6/24 -"
    ]

In [None]:
# ----------------------------------
# Note -> score per modality
# ----------------------------------
# A modality description does not depend on the note, so embed each one once
# here instead of once per (note, modality) pair inside the loop.
modality_embeddings = {
    modality: get_embedding(desc) for modality, desc in modality_texts.items()
}

note_embeddings = []  # one embedding per note, in the order of `notes`
note_to_modality_results = []
for note in notes:
    note_emb = get_embedding(note)
    note_embeddings.append(note_emb)
    keywords = extract_keywords(note)
    raw_scores = {
        modality: cosine_similarity(note_emb, mod_emb)[0][0]
        for modality, mod_emb in modality_embeddings.items()
    }
    # Blend embedding similarity with the keyword-overlap prior.
    scored_with_prior = {
        modality: ALPHA * score + (1 - ALPHA) * calculate_normalized_prior(keywords, modality)
        for modality, score in raw_scores.items()
    }
    top2 = sorted(scored_with_prior.items(), key=lambda x: x[1], reverse=True)[:2]
    note_to_modality_results.append({"note": note, "keywords": keywords, "top2_modalities": top2})


# -------------------------------
# Modality -> score per note
# -------------------------------
# The extracted keywords and the keyword boost depend only on the note, so
# compute them once per note instead of once per (modality, note) pair.
# Only the first five extracted keywords feed the prior.
note_keyword_lists = [extract_keywords(note)[:5] for note in notes]
# Boost for notes that mention one of the critical keywords (glioma, pneumonia).
note_keyword_boosts = [compute_keyword_boost(note, boost_val=0.9) for note in notes]

modality_to_note_results = []
for modality, desc in modality_texts.items():
    mod_emb = get_embedding(desc)
    scores = []
    for i, note in enumerate(notes):
        # Reuse the note embedding computed in the note -> modality pass.
        note_emb = note_embeddings[i]
        cosine_score = cosine_similarity(mod_emb, note_emb)[0][0]

        # Prior (keyword-based)
        prior = calculate_normalized_prior_note(note_keyword_lists[i], modality)

        # Penalty from the disease/modality constraints
        constraint_penalty = apply_disease_modality_constraint(note, modality)

        # Final score
        final_score = (
            ALPHA * cosine_score +
            (1 - ALPHA) * prior +
            note_keyword_boosts[i] +
            constraint_penalty
        )

        scores.append({
            "note": note,
            "score": final_score
        })

    # Keep the two best-scoring notes
    top_notes = sorted(scores, key=lambda x: x["score"], reverse=True)[:2]
    modality_to_note_results.append({
        "modality": modality,
        "top_notes": top_notes
    })




In [None]:
# Print both rankings.
print("### 🔍 NOTE → Top-2 MODALITIES\n")
for r in note_to_modality_results:
    print(f"Note: {r['note']}")
    print(f"Keywords: {r['keywords']}")
    for modality, score in r['top2_modalities']:
        print(f"  - {modality}: {round(score, 4)}")
    print()


print("\n### 🔍 MODALITY → Top-2 NOTES\n")
preview_len = 200  # characters of each note shown in the report
for r in modality_to_note_results:
    print(f"Modality: {r['modality']}")
    for rank, n in enumerate(r['top_notes'], start=1):
        note_text = n['note']
        # Add the ellipsis only when the note was actually cut off.
        preview = note_text[:preview_len] + ("..." if len(note_text) > preview_len else "")
        # Round to 4 places, matching the note -> modality report above.
        print(f"  Top-{rank} Note (score={round(n['score'], 4)}): {preview}")
    print()