<a href="https://colab.research.google.com/github/mega-317/medical-specialty-prediction/blob/main/disease_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# 라이브러리 임포트
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.metrics.pairwise import cosine_similarity

# 1. Load KM-BERT (tokenizer + encoder) via `from_pretrained`.
model_name = "madatnlp/km-bert"

# Pick the compute device up front: GPU when CUDA is available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Inference only: place the encoder on the chosen device and switch it to
# evaluation mode (no training happens in this cell).
bert_model = AutoModel.from_pretrained(model_name).to(device).eval()

# 2. Candidate disease names that a symptom sentence is compared against.
disease_list = [
    "감기",
    "독감",
    "당뇨병",
    "고혈압",
    "지주막하출혈",
    "추간판탈출증",
    "알츠하이머",
]

# 3. Turn text into a KM-BERT embedding vector
def get_sentence_embedding(sentence):
    """Return the [CLS] embedding of ``sentence`` as a CPU tensor.

    Args:
        sentence: Text handed to the module-level ``tokenizer``.

    Returns:
        ``last_hidden_state[:, 0, :]`` of the encoder output (the hidden
        state at position 0 of each encoded sequence), moved to the CPU.
    """
    # This tokenizer has no predefined maximum length, so `truncation=True`
    # on its own is ignored with a warning. Cap the input explicitly at the
    # encoder's positional limit (512 as a fallback if the config lacks it).
    max_length = getattr(bert_model.config, "max_position_embeddings", 512)
    inputs = tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    # Inputs must live on the same device as the model.
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():  # inference only; skip gradient tracking
        outputs = bert_model(**inputs)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # [CLS] token
    return cls_embedding.cpu()

# Embed every disease name and concatenate the per-name embeddings along
# dim 0, giving one row per entry of disease_list:
# (len(disease_list), hidden_size).
disease_embeddings = torch.cat(
    [get_sentence_embedding(disease_name) for disease_name in disease_list],
    dim=0,
)

# 4. Predict the closest disease for a symptom sentence
def predict_disease(symptom_sentence):
    """Return the disease whose name embedding is closest to the symptom text.

    Prints the cosine similarity between the symptom embedding and every
    entry of ``disease_list`` before returning.

    Args:
        symptom_sentence: Symptom description to embed and compare.

    Returns:
        Tuple ``(predicted_disease, confidence)``: the best-matching name from
        ``disease_list`` and its cosine similarity score.
    """
    symptom_emb = get_sentence_embedding(symptom_sentence)  # (1, hidden_size)

    # Cosine similarity: one row per query, one column per disease,
    # i.e. (1, len(disease_list)) for a single symptom sentence.
    similarities = cosine_similarity(symptom_emb.numpy(), disease_embeddings.numpy())
    scores = similarities[0]

    # Print each similarity individually.
    print("\n[각 질병에 대한 유사도]")
    for disease, score in zip(disease_list, scores):
        print(f"{disease}: {score:.4f}")

    # argmax() without an axis indexes the *flattened* matrix; take it on the
    # query row so the result is always a valid index into disease_list.
    best_idx = int(scores.argmax())
    predicted_disease = disease_list[best_idx]
    confidence = scores[best_idx]

    return predicted_disease, confidence

# 5. Smoke test: run the predictor on one sample symptom sentence.
user_symptom = "허리 쪽에 갑자기 통증이 심해져서 허리를 못 굽히겠어요"
# predict_disease also prints the per-disease similarity table as a side effect.
predicted_disease, confidence = predict_disease(user_symptom)

# Only the predicted name is shown; `confidence` is unpacked but not printed.
print(f"예측된 질병: {predicted_disease}")


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.



[각 질병에 대한 유사도]
감기: 0.6790
독감: 0.6178
당뇨병: 0.6838
고혈압: 0.6825
지주막하출혈: 0.5874
추간판탈출증: 0.7020
알츠하이머: 0.6023
예측된 질병: 추간판탈출증
