In [1]:
!pip install datasets



In [7]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import Dataset
import pandas as pd
from datetime import datetime
import chardet

def calculate_age(birth_date_str):
    """Return the age in completed years for a ``YYYYMMDD`` birth date.

    Args:
        birth_date_str: Birth date as an eight-digit string such as
            ``"19900315"``. Surrounding whitespace is ignored.

    Returns:
        The age in full years as an ``int``, or ``None`` when the input is
        not a string, is not exactly eight digits, is not a real calendar
        date, or lies in the future.
    """
    if not isinstance(birth_date_str, str):
        return None

    cleaned = birth_date_str.strip()
    # strptime on its own is lenient ("2000101" parses as 2000-10-01), so
    # require exactly eight digits before handing the text to it.
    if len(cleaned) != 8 or not cleaned.isdigit():
        return None

    try:
        birth_date = datetime.strptime(cleaned, '%Y%m%d')
    except ValueError:
        return None

    today = datetime.now()
    # A birth date after today would otherwise produce a negative age.
    if birth_date > today:
        return None

    # Subtract one year when this year's birthday has not happened yet.
    birthday_pending = (today.month, today.day) < (birth_date.month, birth_date.day)
    return today.year - birth_date.year - (1 if birthday_pending else 0)

def get_age_appropriate_dosage(age, medication, symptoms_df):
    """Look up the age-group entry recorded for a medication.

    Args:
        age: Accepted for interface compatibility; the lookup does not use it.
        medication: Medication name searched for, as literal text, inside the
            ``'추천 의약품'`` column.
        symptoms_df: DataFrame holding the ``'추천 의약품'`` and ``'나이대'``
            columns.

    Returns:
        The ``'나이대'`` value of the first row whose ``'추천 의약품'`` cell
        contains ``medication``, or ``None`` when no row matches or the
        matching cell is missing or empty.
    """
    # regex=False: names may contain characters such as "(" or "+", which
    # would otherwise be interpreted as regular-expression syntax.
    matches = symptoms_df['추천 의약품'].str.contains(medication, na=False, regex=False)
    if not matches.any():
        return None

    dosage_info = symptoms_df.loc[matches, '나이대'].iloc[0]
    # NaN is truthy, so a plain truthiness test would leak it to the caller.
    if pd.isna(dosage_info) or not dosage_info:
        return None
    return dosage_info

def get_required_doctor_visit(disease, symptoms_df):
    """Return the "when to see a doctor" entry recorded for a disease.

    Args:
        disease: Disease name compared for exact equality against the
            ``'질환명'`` column.
        symptoms_df: DataFrame holding the ``'질환명'`` and
            ``'의사의 진료가 필요한 경우'`` columns.

    Returns:
        The ``'의사의 진료가 필요한 경우'`` value of the first matching row,
        or ``None`` when the disease is not listed or the cell is missing
        or empty.
    """
    matches = symptoms_df['질환명'] == disease
    if not matches.any():
        return None

    doctor_visit_info = symptoms_df.loc[matches, '의사의 진료가 필요한 경우'].iloc[0]
    # NaN is truthy, so a plain truthiness test would leak it to the caller.
    if pd.isna(doctor_visit_info) or not doctor_visit_info:
        return None
    return doctor_visit_info

def prepare_data(symptoms_df):
    """Extract (symptom text, disease name) training pairs from the table.

    A row is used only when its ``'증상'`` cell is a string that is non-empty
    after double quotes and surrounding whitespace are removed, and its
    ``'질환명'`` cell is present.

    Args:
        symptoms_df: DataFrame holding the ``'증상'`` and ``'질환명'`` columns.

    Returns:
        A ``(texts, labels)`` tuple of equal-length lists, where ``texts[i]``
        is the cleaned symptom text and ``labels[i]`` its disease name.

    Raises:
        ValueError: If no row yields a usable pair.
    """
    texts = []
    labels = []

    for symptom, disease in zip(symptoms_df['증상'], symptoms_df['질환명']):
        # Skip missing symptom cells, and rows without a disease name: a
        # missing label would otherwise become a class of its own.
        if not isinstance(symptom, str) or pd.isna(disease):
            continue
        symptom_text = symptom.replace('"', '').strip()
        if symptom_text:
            texts.append(symptom_text)
            labels.append(disease)

    if not texts:
        raise ValueError("No valid data found")

    return texts, labels

def get_user_selection(symptoms_df):
    """Interactively ask the user to choose one symptom category.

    Prints the distinct non-missing values of the ``'대분류'`` column as a
    numbered menu and keeps prompting on standard input until a valid menu
    number is entered.

    Args:
        symptoms_df: DataFrame holding the ``'대분류'`` column.

    Returns:
        The category value the user selected.

    Raises:
        ValueError: If the ``'대분류'`` column has no non-missing values, in
            which case no input could ever be valid.
    """
    unique_categories = symptoms_df["대분류"].dropna().unique()
    category_count = len(unique_categories)
    if category_count == 0:
        raise ValueError("No symptom categories found in column '대분류'")

    print("\n=== 증상 분류 선택 ===")
    for i, category in enumerate(unique_categories, 1):
        print(f"{i}. {category}")

    # The advertised range follows the real number of menu entries.
    prompt = f"증상 분류 번호를 선택하세요 (1-{category_count}): "
    while True:
        try:
            category_idx = int(input(prompt)) - 1
        except ValueError:
            print("숫자를 입력하세요.")
            continue
        if 0 <= category_idx < category_count:
            return unique_categories[category_idx]
        print("올바른 번호를 입력하세요.")

def train_model(symptoms_df):
    """Fine-tune a sequence classifier that maps symptom text to a disease.

    Training pairs come from ``prepare_data``. Each distinct disease name
    becomes one class, and the ``klue/bert-base`` checkpoint is fine-tuned on
    the tokenized symptom texts with the Hugging Face ``Trainer``.

    Args:
        symptoms_df: DataFrame passed through to ``prepare_data``.

    Returns:
        A ``(model, tokenizer, label_map)`` tuple, where ``label_map`` maps
        each disease name to the integer class id used during training.

    Raises:
        Exception: Any error raised while preparing data, loading the
            checkpoint or training is printed and then re-raised unchanged.
    """
    try:
        texts, labels = prepare_data(symptoms_df)
        print(f"훈련 데이터 수: {len(texts)}")

        # dict.fromkeys de-duplicates in first-seen order, so class ids are
        # the same on every run. Iterating a set of strings is not stable
        # across interpreter runs because of hash randomisation.
        unique_labels = list(dict.fromkeys(labels))
        label_map = {label: i for i, label in enumerate(unique_labels)}
        encoded_labels = [label_map[label] for label in labels]

        model_name = "klue/bert-base"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=len(unique_labels)
        )

        # Pad to the longest text and cut anything beyond 128 tokens.
        tokenized = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )

        labels_tensor = torch.tensor(encoded_labels, dtype=torch.long)

        dataset = Dataset.from_dict({
            "input_ids": tokenized["input_ids"],
            "attention_mask": tokenized["attention_mask"],
            "labels": labels_tensor
        })

        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=5,
            per_device_train_batch_size=8,
            learning_rate=2e-5,
            weight_decay=0.01
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
        )

        trainer.train()
        return model, tokenizer, label_map

    except Exception as e:
        # Report the failure for the interactive user, then let it propagate.
        print(f"Error during model training: {str(e)}")
        raise

def predict_top_3_diseases_in_category(input_text, model, tokenizer, label_map, symptoms_df, selected_category):
    """Return the three most probable diseases within one symptom category.

    The model scores every class for ``input_text``. Only diseases whose
    first ``'대분류'`` entry in ``symptoms_df`` equals ``selected_category``
    are kept. Probabilities are the softmax over all classes and are not
    re-normalised after filtering.

    Args:
        input_text: Free-text symptom description.
        model: Sequence-classification model; it is moved to CUDA when
            available, otherwise to the CPU.
        tokenizer: Tokenizer matching ``model``.
        label_map: Mapping from disease name to class id.
        symptoms_df: DataFrame holding the ``'질환명'`` and ``'대분류'``
            columns.
        selected_category: Category value used as the filter.

    Returns:
        Up to three ``(disease, probability)`` tuples sorted by descending
        probability; an empty list when no class belongs to the category.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # A freshly trained model is normally still in training mode. eval()
    # turns off dropout so the same input always gets the same scores.
    model.eval()

    inputs = tokenizer(input_text, padding=True, truncation=True, max_length=128, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits[0]
        probabilities = torch.softmax(logits, dim=0).tolist()

    # Build the disease -> category lookup once, rather than filtering the
    # whole table per class. The first row listed for a disease wins.
    listed = symptoms_df.dropna(subset=['질환명']).drop_duplicates(subset='질환명')
    category_by_disease = dict(zip(listed['질환명'], listed['대분류']))
    reverse_label_map = {v: k for k, v in label_map.items()}

    filtered_probs = []
    for idx, prob in enumerate(probabilities):
        disease = reverse_label_map[idx]
        if disease not in category_by_disease:
            continue
        if category_by_disease[disease] == selected_category:
            filtered_probs.append((disease, prob))

    filtered_probs.sort(key=lambda x: x[1], reverse=True)
    return filtered_probs[:3]

def _detect_csv_encoding(file_path):
    """Guess the text encoding of ``file_path`` with chardet."""
    with open(file_path, "rb") as f:
        encoding = chardet.detect(f.read())['encoding']

    if not encoding:
        # Detection failed; use the encoding pandas would assume anyway.
        return "utf-8"
    if encoding.upper() == "EUC-KR":
        # CP949 is a superset of EUC-KR, so it decodes the same files and
        # also accepts the extended Hangul syllables plain EUC-KR rejects.
        return "cp949"
    return encoding


def main():
    """Run the interactive symptom-to-disease prediction flow.

    Loads ``./증상.csv``, asks for a birth date and a symptom category,
    trains the classifier, then reads a symptom description and prints the
    top predicted diseases within the chosen category.
    """
    # Detect the file encoding, then read the data with it.
    file_path = "./증상.csv"
    symptoms_df = pd.read_csv(file_path, encoding=_detect_csv_encoding(file_path))

    # 1. Ask for the birth date until it parses.
    while True:
        birth_date = input("생년월일을 입력하세요 (YYYYMMDD 형식): ")
        age = calculate_age(birth_date)
        if age is not None:
            break
        print("올바른 형식으로 다시 입력해주세요.")

    # 2. Choose the top-level symptom category.
    selected_category = get_user_selection(symptoms_df)
    print(f"\n선택된 카테고리: {selected_category}")

    # 3. Train the model.
    print("\n모델 학습 시작...")
    model, tokenizer, label_map = train_model(symptoms_df)
    print("모델 학습 완료!")

    # 4. Read the symptoms and predict.
    input_text = input("\n증상을 입력하세요: ")
    diseases_in_category = predict_top_3_diseases_in_category(
        input_text, model, tokenizer, label_map, symptoms_df, selected_category
    )

    if not diseases_in_category:
        print("\n선택한 카테고리에서 증상과 일치하는 질병을 찾을 수 없습니다.")
        return

    print("\n=== 분석 결과 ===")
    print(f"연령: {age}세")
    print(f"선택된 증상 분류: {selected_category}")

    print("\n예측된 질병:")
    for i, (disease, probability) in enumerate(diseases_in_category, 1):
        print(f"{i}. {disease} ({probability*100:.2f}%)")

# Start the interactive flow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()



=== 증상 분류 선택 ===
1. 복통
2. 감기
3. 비염
4. 천식

선택된 카테고리: 복통

모델 학습 시작...
훈련 데이터 수: 21


tokenizer_config.json:   0%|          | 0.00/289 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/425 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/248k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/495k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/445M [00:00<?, ?B/s]

KeyboardInterrupt: 