In [1]:
!pip install wandb -qU

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import pandas as pd
import numpy as np
import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import seaborn as sns
pd.set_option('display.max_rows', None)
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from transformers import AdamW
import torch.nn as nn
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm
# Log in to your W&B account
import wandb
import random
import math

# main code

In [4]:
import wandb
# 처음 실행시 WandB 웹사이트에서 발급받은 API 키를 입력해야 합니다
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
df = pd.read_csv('/content/drive/MyDrive/University/4-2/정보기술학회/data/medical_data.csv', encoding='utf-8')
df.shape

(2891197, 11)

In [14]:
import logging
import wandb
import pandas as pd
import numpy as np
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

@dataclass
class DataConfig:
    """데이터 처리 관련 설정을 관리하는 클래스"""
    min_samples_per_class: int = 1000  # 클래스당 최소 샘플 수
    valid_department_threshold: int = 1000  # 유효한 진료과로 판단할 최소 샘플 수
    valid_disease_threshold: int = 10000  # 유효한 질병 코드로 판단할 최소 샘플 수
    test_size: float = 0.2
    random_state: int = 42
    text_column: str = '증상'
    dept_column: str = '진료과목코드'
    disease_column: str = '주상병코드'

class MedicalDataProcessor:
    """의료 데이터 전처리 및 균형화를 담당하는 클래스"""

    def __init__(self, config: DataConfig):
        self.config = config
        self.logger = self._setup_logger()
        self.dept_encoder = LabelEncoder()
        self.disease_encoder = LabelEncoder()
        wandb.init(project="medical-recommendation", config=config.__dict__)

    def _setup_logger(self) -> logging.Logger:
        """로깅 설정"""
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        ))
        logger.addHandler(handler)
        return logger

    def process_data(self, file_path: Path) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """전체 데이터 처리 파이프라인"""
        try:
            # 데이터 로드
            df = self._load_data(file_path)

            # 데이터 전처리
            df = self._preprocess_data(df)

            # 유효한 클래스만 선택
            df = self._filter_valid_classes(df)

            # 데이터 균형화
            df_balanced = self._balance_data(df)

            # 학습/테스트 분할
            train_df, test_df = self._split_data(df_balanced)

            # 처리 결과 로깅
            self._log_processing_results(df, df_balanced, train_df, test_df)

            return train_df, test_df

        except Exception as e:
            self.logger.error(f"Data processing failed: {str(e)}")
            raise

    def _load_data(self, file_path: Path) -> pd.DataFrame:
        """데이터 로드 및 기본 검증"""
        self.logger.info(f"Loading data from {file_path}")
        df = pd.read_csv(file_path)

        # 필수 컬럼 존재 확인
        required_columns = [self.config.text_column,
                          self.config.dept_column,
                          self.config.disease_column]

        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

        return df

    def _preprocess_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """데이터 전처리"""
        # 텍스트 전처리
        df[self.config.text_column] = df[self.config.text_column].apply(
            self._preprocess_text
        )

        # 레이블 인코딩
        df['dept_encoded'] = self.dept_encoder.fit_transform(
            df[self.config.dept_column]
        )
        df['disease_encoded'] = self.disease_encoder.fit_transform(
            df[self.config.disease_column]
        )

        return df

    def _preprocess_text(self, text: str) -> str:
        """텍스트 전처리"""
        # 기본적인 텍스트 클리닝
        text = text.lower().strip()

        # 불필요한 공백 제거
        text = ' '.join(text.split())

        return text

    def _filter_valid_classes(self, df: pd.DataFrame) -> pd.DataFrame:
        """유효한 클래스만 선택"""
        # 진료과 기준 필터링
        dept_counts = df[self.config.dept_column].value_counts()
        valid_depts = dept_counts[dept_counts >= self.config.valid_department_threshold].index

        # 질병 코드 기준 필터링
        disease_counts = df[self.config.disease_column].value_counts()
        valid_diseases = disease_counts[disease_counts >= self.config.valid_disease_threshold].index

        # 유효한 클래스만 선택
        mask = (df[self.config.dept_column].isin(valid_depts)) & \
               (df[self.config.disease_column].isin(valid_diseases))

        return df[mask].reset_index(drop=True)

    def _balance_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """클래스 별 샘플 수 균형화 - 개선된 버전"""
        # 최소 기준 샘플 수 결정
        min_samples = self.config.min_samples_per_class

        # 각 클래스별 현재 샘플 수 확인
        disease_counts = df[self.config.disease_column].value_counts()

        # 충분한 샘플을 가진 클래스만 선택
        valid_diseases = disease_counts[disease_counts >= min_samples].index

        balanced_dfs = []
        for disease in valid_diseases:
            disease_df = df[df[self.config.disease_column] == disease]

            # 정확히 min_samples만큼 샘플링
            sampled_df = disease_df.sample(
                n=min_samples,
                random_state=self.config.random_state
            )
            balanced_dfs.append(sampled_df)

        balanced_df = pd.concat(balanced_dfs, axis=0).reset_index(drop=True)

        # 결과 로깅
        self.logger.info(f"Original class distribution:\n{disease_counts}")
        self.logger.info(f"Balanced class distribution:\n{balanced_df[self.config.disease_column].value_counts()}")

        return balanced_df

    def _split_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """학습/테스트 데이터 분할"""
        train_df, test_df = train_test_split(
            df,
            test_size=self.config.test_size,
            stratify=df['dept_encoded'],
            random_state=self.config.random_state
        )

        return train_df, test_df

    def _log_processing_results(self, original_df: pd.DataFrame,
                              balanced_df: pd.DataFrame,
                              train_df: pd.DataFrame,
                              test_df: pd.DataFrame):
        """처리 결과를 WandB에 기록"""
        wandb.log({
            "data_processing": {
                "original_samples": len(original_df),
                "balanced_samples": len(balanced_df),
                "training_samples": len(train_df),
                "test_samples": len(test_df),
                "dept_distribution": train_df[self.config.dept_column].value_counts().to_dict(),
                "disease_distribution": train_df[self.config.disease_column].value_counts().to_dict(),
                "text_length_stats": train_df[self.config.text_column].str.len().describe().to_dict()
            }
        })

In [63]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import xgboost as xgb
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import wandb

@dataclass
class ModelConfig:
    """모델 구성을 위한 설정 클래스"""
    # 기본 설정
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # BERT 트랙 설정
    bert_model_name: str = "madatnlp/km-bert"
    tokenizer_name: str = "snunlp/KR-BERT-char16424"
    max_length: int = 512
    bert_batch_size: int = 32
    bert_learning_rate: float = 2e-5
    bert_epochs: int = 10
    warmup_ratio: float = 0.1

    # 학습 설정
    patience: int = 3  # Early stopping patience
    epochs: int = 10  # 전체 학습 에포크 수
    learning_rate: float = 1e-4  # 기본 학습률
    weight_decay: float = 0.01  # 가중치 감쇠
    gradient_clip_val: float = 1.0  # 그래디언트 클리핑 값

    # XGBoost 트랙 설정
    xgb_params: Dict = field(default_factory=lambda: {
        'objective': 'multi:softprob',
        'eval_metric': ['mlogloss', 'merror'],
        'eta': 0.1,
        'max_depth': 6,
        'min_child_weight': 1,
        'subsample': 0.8,
        'colsample_bytree': 0.8,
        'tree_method': 'gpu_hist'  # GPU 활용
    })

    # 스태킹 설정
    stacking_folds: int = 5
    use_probabilities: bool = True  # 확률값 사용 여부

class TextBertTrack(nn.Module):
    """텍스트 데이터를 처리하는 BERT 트랙"""

    def __init__(self, config: ModelConfig, num_classes: int):
        super().__init__()
        self.config = config
        self.num_classes = num_classes

        # 토크나이저 초기화
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.tokenizer_name,
            max_length=config.max_length
        )

        # BERT 모델 초기화
        self.bert = AutoModel.from_pretrained(config.bert_model_name)

        # 드롭아웃 적용
        self.dropout = nn.Dropout(0.1)

        # 계층적 분류기 구조
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.LayerNorm(512),  # 정규화 추가
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_ids, attention_mask, return_features=False):
        # BERT 출력 획득
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True  # 모든 레이어의 hidden states 획득
        )

        # 마지막 4개 레이어의 [CLS] 토큰 표현을 결합
        last_four_layers = outputs.hidden_states[-4:]
        cls_embeddings = torch.stack([layer[:, 0] for layer in last_four_layers], dim=1)
        pooled_output = torch.mean(cls_embeddings, dim=1)

        # 드롭아웃 적용
        pooled_output = self.dropout(pooled_output)

        # 분류
        logits = self.classifier(pooled_output)

        if return_features:
            return logits, pooled_output
        return logits

class TabularXGBoostTrack:
    """테이블 데이터를 처리하는 XGBoost 트랙"""

    def __init__(self, config: ModelConfig):
        self.config = config
        self.model = None
        self.feature_importance = None

    def train(self, X, y, eval_set=None):
        dtrain = xgb.DMatrix(X, label=y)

        # 검증 세트가 있는 경우
        evals = [(dtrain, 'train')]
        if eval_set:
            deval = xgb.DMatrix(eval_set[0], label=eval_set[1])
            evals.append((deval, 'eval'))

        # 학습 진행
        self.model = xgb.train(
            self.config.xgb_params,
            dtrain,
            num_boost_round=1000,
            evals=evals,
            early_stopping_rounds=50,
            verbose_eval=100
        )

        # 특성 중요도 저장
        self.feature_importance = self.model.get_score(importance_type='gain')

    def predict(self, X, return_probabilities=True):
        dtest = xgb.DMatrix(X)
        predictions = self.model.predict(dtest)

        if not return_probabilities:
            predictions = predictions.argmax(axis=1)
        return predictions

class StackEnsemble(nn.Module):
    """스태킹 앙상블 모델"""

    def __init__(self, config: ModelConfig, num_classes: int):
        super().__init__()  # nn.Module 초기화
        self.config = config
        self.num_classes = num_classes
        self.bert_track = TextBertTrack(config, num_classes).to(config.device)
        self.xgb_track = TabularXGBoostTrack(config)
        self.meta_model = None

        # 성능 측정을 위한 메트릭 초기화
        self.best_score = 0
        self.best_epoch = 0
        self.training = False  # 학습 모드 플래그 추가

    def forward(self, input_ids, attention_mask):
        """Forward pass 구현"""
        # BERT 트랙의 출력 획득
        bert_output = self.bert_track(input_ids, attention_mask)
        return bert_output

    def train(self, mode=True):
        """학습/평가 모드 설정 (PyTorch 기본 메서드 오버라이드)"""
        self.training = mode
        self.bert_track.train(mode)
        return self

    def eval(self):
        """평가 모드 설정"""
        return self.train(False)

    def fit(self, text_data, tabular_data, labels, eval_data=None):
        """전체 학습 프로세스"""
        # BERT 트랙 학습 및 특성 추출
        bert_features = self._train_bert_track(text_data, labels, eval_data)

        # XGBoost 트랙 학습 및 특성 추출
        xgb_features = self._train_xgb_track(tabular_data, labels, eval_data)

        # 메타 특성 생성
        meta_features = self._create_meta_features(bert_features, xgb_features)

        # 최종 메타 모델 학습
        self._train_meta_model(meta_features, labels)

        # WandB에 학습 결과 기록
        self._log_training_results()

    def _train_bert_track(self, text_data, labels, eval_data=None):
        """BERT 트랙 학습"""
        # BERT 모델 학습 로직 구현
        pass

    def _train_xgb_track(self, tabular_data, labels, eval_data=None):
        """XGBoost 트랙 학습"""
        # XGBoost 모델 학습 로직 구현
        pass

    def _train_meta_model(self, meta_features, labels):
        """메타 모델 학습"""
        # 메타 모델 학습 로직 구현
        pass

    def _create_meta_features(self, bert_features, xgb_features):
        """메타 특성 생성"""
        if self.config.use_probabilities:
            return np.concatenate([bert_features, xgb_features], axis=1)
        return np.concatenate([
            bert_features.argmax(axis=1).reshape(-1, 1),
            xgb_features.argmax(axis=1).reshape(-1, 1)
        ], axis=1)

    def predict(self, text_data, tabular_data):
        """예측 수행"""
        # 각 트랙의 예측 획득
        bert_preds = self._predict_bert_track(text_data)
        xgb_preds = self.xgb_track.predict(tabular_data)

        # 메타 특성 생성
        meta_features = self._create_meta_features(bert_preds, xgb_preds)

        # 최종 예측
        final_predictions = self.meta_model.predict(xgb.DMatrix(meta_features))
        return final_predictions

    def _predict_bert_track(self, text_data):
        """BERT 트랙 예측"""
        # BERT 모델 예측 로직 구현
        pass

    def _log_training_results(self):
        """학습 결과를 WandB에 기록"""
        wandb.log({
            "best_score": self.best_score,
            "best_epoch": self.best_epoch,
            "bert_feature_dim": self.bert_track.bert.config.hidden_size,
            "xgb_feature_importance": self.xgb_track.feature_importance
        })

In [64]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup
import numpy as np
import pandas as pd
from typing import Dict, Optional, Tuple
from pathlib import Path
import logging
from tqdm.auto import tqdm
import wandb
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import gc

class MedicalDataset(Dataset):
    """의료 데이터를 효율적으로 처리하는 데이터셋 클래스"""

    def __init__(self,
                 data_path: Path,
                 tokenizer,
                 config: ModelConfig,
                 chunk_size: int = 1000):
        """
        Parameters:
            data_path: 데이터 파일 경로
            tokenizer: BERT 토크나이저
            config: 모델 설정
            chunk_size: 한 번에 처리할 데이터 크기
        """
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.config = config
        self.chunk_size = chunk_size

        # 데이터 인덱스 초기화
        self._initialize_data_index()

    def _initialize_data_index(self):
        """데이터 인덱스 구성 및 기본 검증"""
        try:
            # 데이터 기본 정보 읽기
            self.total_rows = sum(1 for _ in open(self.data_path)) - 1

            # 청크 단위로 데이터 읽어오기 위한 인덱스 생성
            self.chunk_starts = list(range(0, self.total_rows, self.chunk_size))

            logging.info(f"Total rows: {self.total_rows}")
            logging.info(f"Number of chunks: {len(self.chunk_starts)}")

        except Exception as e:
            logging.error(f"Data initialization failed: {str(e)}")
            raise

    def _load_chunk(self, chunk_idx: int) -> pd.DataFrame:
        """청크 단위로 데이터 로드"""
        start_idx = self.chunk_starts[chunk_idx]

        try:
            chunk = pd.read_csv(
                self.data_path,
                skiprows=range(1, start_idx + 1),
                nrows=self.chunk_size
            )
            return self._preprocess_chunk(chunk)
        except Exception as e:
            logging.error(f"Chunk loading failed at index {chunk_idx}: {str(e)}")
            raise

    def _preprocess_chunk(self, chunk: pd.DataFrame) -> pd.DataFrame:
        """데이터 청크 전처리"""
        # 누락된 값 처리
        chunk['증상'] = chunk['증상'].fillna('')

        # 텍스트 데이터 정제
        chunk['증상'] = chunk['증상'].apply(self._clean_text)

        # 수치형 특성 정규화
        numeric_columns = chunk.select_dtypes(include=[np.number]).columns
        chunk[numeric_columns] = chunk[numeric_columns].fillna(0)

        return chunk

    def _clean_text(self, text: str) -> str:
        """텍스트 데이터 정제"""
        text = str(text).lower().strip()
        # 추가적인 텍스트 정제 규칙 적용
        return text

    def __len__(self):
        return self.total_rows

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        try:
            chunk = self._load_chunk(idx // self.chunk_size)
            row = chunk.iloc[idx % self.chunk_size]

            # 텍스트 토크나이징
            encoding = self._tokenize_text(row['증상'])

            return {
                'input_ids': encoding['input_ids'],
                'attention_mask': encoding['attention_mask'],
                'label': torch.tensor(row['주상병코드'], dtype=torch.long)  # dtype 명시
            }
        except Exception as e:
            logging.error(f"Error in __getitem__ at index {idx}: {str(e)}")
            raise

    def _tokenize_text(self, text: str) -> Dict[str, torch.Tensor]:
        """텍스트 토크나이징"""
        try:
            encoding = self.tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=self.config.max_length,
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0)
            }

        except Exception as e:
            logging.error(f"Tokenization failed: {str(e)}")
            raise

class ModelTrainer:
    """모델 학습 및 평가를 관리하는 클래스"""

    def __init__(self,
                 config: ModelConfig,
                 model: StackEnsemble,
                 experiment_name: str):
        self.config = config
        self.model = model
        self.experiment_name = experiment_name

        # 로깅 설정
        self._setup_logging()

        # WandB 초기화
        self._initialize_wandb()

    def _setup_logging(self):
        """로깅 설정"""
        log_dir = Path('logs')
        log_dir.mkdir(parents=True, exist_ok=True)

        log_path = log_dir / f"{self.experiment_name}.log"

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_path),
                logging.StreamHandler()
            ]
        )

    def _initialize_wandb(self):
        """WandB 설정"""
        wandb.init(
            project="medical-classification",
            name=self.experiment_name,
            config=vars(self.config),
            resume=True
        )

    def _setup_optimizer(self):
        """옵티마이저 설정"""
        # 학습 가능한 파라미터만 선택
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        # 가중치 감쇠 적용 여부에 따라 파라미터 그룹화
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]

        # AdamW 옵티마이저 생성
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.config.bert_learning_rate
        )

        return optimizer

    def _setup_scheduler(self, optimizer, num_training_steps):
        """학습률 스케줄러 설정"""
        # 웜업 스텝 계산
        warmup_steps = int(num_training_steps * self.config.warmup_ratio)

        # 선형 스케줄러 생성
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps
        )

        return scheduler

    def _process_batch(self, batch, optimizer=None):
        """배치 처리"""
        try:
            # 배치 데이터가 이미 텐서인지 확인하고 변환
            input_ids = batch['input_ids'].to(self.config.device) if isinstance(batch['input_ids'], torch.Tensor) else torch.tensor(batch['input_ids']).to(self.config.device)
            attention_mask = batch['attention_mask'].to(self.config.device) if isinstance(batch['attention_mask'], torch.Tensor) else torch.tensor(batch['attention_mask']).to(self.config.device)
            labels = batch['label'].to(self.config.device) if isinstance(batch['label'], torch.Tensor) else torch.tensor(batch['label']).to(self.config.device)

            # Forward pass
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            # 손실 계산
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(outputs, labels)

            # 학습 모드인 경우
            if optimizer is not None:
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # 예측값 반환
            predictions = torch.argmax(outputs, dim=1)

            return loss.item(), predictions

        except Exception as e:
            logging.error(f"Batch processing failed: {str(e)}")
            raise

    def train(self,
              train_loader: DataLoader,
              val_loader: Optional[DataLoader] = None) -> Dict[str, float]:  # 반환 타입 명시
        """모델 학습 수행"""
        try:
            # 옵티마이저 설정
            optimizer = self._setup_optimizer()
            scheduler = self._setup_scheduler(optimizer, len(train_loader))

            best_val_score = float('-inf')
            early_stopping_counter = 0
            best_metrics = None  # 최고 성능의 메트릭 저장

            for epoch in range(self.config.bert_epochs):
                # 학습
                train_metrics = self._train_epoch(train_loader, optimizer, scheduler)

                # 검증
                if val_loader is not None:
                    val_metrics = self._validate(val_loader)
                    current_val_score = val_metrics['f1_score']

                    # 모델 저장 및 조기 종료 확인
                    if current_val_score > best_val_score:
                        best_val_score = current_val_score
                        best_metrics = {  # 최고 성능 메트릭 저장
                            'val_accuracy': val_metrics['accuracy'],
                            'val_f1': val_metrics['f1_score'],
                            'val_precision': val_metrics['precision'],
                            'val_recall': val_metrics['recall'],
                            'val_loss': val_metrics['loss']
                        }
                        self._save_model(epoch, val_metrics)
                        early_stopping_counter = 0
                    else:
                        early_stopping_counter += 1

                    if early_stopping_counter >= self.config.patience:
                        logging.info("Early stopping triggered")
                        break

                # 메모리 정리
                gc.collect()
                torch.cuda.empty_cache()

            return best_metrics  # 최고 성능의 메트릭 반환

        except Exception as e:
            logging.error(f"Training failed: {str(e)}")
            raise

    def _train_epoch(self,
                    train_loader: DataLoader,
                    optimizer: torch.optim.Optimizer,
                    scheduler) -> Dict[str, float]:
        """한 에포크 학습"""
        self.model.train()  # 이제 PyTorch의 train() 메서드 사용
        total_loss = 0
        predictions = []
        labels = []

        progress_bar = tqdm(train_loader, desc='Training')

        for batch in progress_bar:
            try:
                loss, batch_preds = self._process_batch(batch, optimizer)
                total_loss += loss

                predictions.extend(batch_preds.cpu().numpy())
                labels.extend(batch['label'].cpu().numpy())

                scheduler.step()

                # 진행상황 업데이트
                progress_bar.set_postfix({'loss': loss})

            except Exception as e:
                logging.error(f"Batch processing failed: {str(e)}")
                continue

        # 메트릭 계산
        metrics = self._calculate_metrics(predictions, labels)
        metrics['loss'] = total_loss / len(train_loader)

        # WandB 로깅
        wandb.log({f'train_{k}': v for k, v in metrics.items()})

        return metrics

    def _validate(self, val_loader: DataLoader) -> Dict[str, float]:
        """검증 수행"""
        self.model.eval()
        total_loss = 0
        predictions = []
        labels = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validation'):
                try:
                    loss, batch_preds = self._process_batch(batch)
                    total_loss += loss

                    predictions.extend(batch_preds.cpu().numpy())
                    labels.extend(batch['label'].cpu().numpy())

                except Exception as e:
                    logging.error(f"Validation batch processing failed: {str(e)}")
                    continue

        # 메트릭 계산
        metrics = self._calculate_metrics(predictions, labels)
        metrics['loss'] = total_loss / len(val_loader)

        # WandB 로깅
        wandb.log({f'val_{k}': v for k, v in metrics.items()})

        return metrics

    def _calculate_metrics(self,
                         predictions: np.ndarray,
                         labels: np.ndarray) -> Dict[str, float]:
        """성능 메트릭 계산"""
        return {
            'accuracy': accuracy_score(labels, predictions),
            'f1_score': f1_score(labels, predictions, average='weighted'),
            'precision': precision_score(labels, predictions, average='weighted'),
            'recall': recall_score(labels, predictions, average='weighted')
        }

    def _save_model(self,
                   epoch: int,
                   metrics: Dict[str, float]) -> None:
        """모델 저장"""
        save_path = Path(f'models/{self.experiment_name}')
        save_path.mkdir(parents=True, exist_ok=True)

        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'metrics': metrics
        }, save_path / f'model_epoch_{epoch}.pt')

In [65]:
import argparse
import yaml
from pathlib import Path
import logging
import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
import wandb
from datetime import datetime

class ExperimentRunner:
    def __init__(self, config_path: str, train_data: pd.DataFrame, test_data: pd.DataFrame):
        """
        Args:
            config_path: YAML 설정 파일 경로
            train_data: 학습 데이터
            test_data: 테스트 데이터
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = yaml.safe_load(f)
        except Exception as e:
            raise RuntimeError(f"Failed to load config file: {str(e)}")

        self.train_data = train_data
        self.test_data = test_data
        self.experiment_name = f"medical_classification_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # 경로 설정
        self._setup_paths()
        # 로깅 설정
        self._setup_logging()

        # wandb 초기화 - 기본값 설정 추가
        wandb_config = self.config.get('wandb', {})
        wandb.init(
            project=wandb_config.get('project', 'medical-recommendation'),  # 기본값 설정
            name=self.experiment_name,
            config=self.config
        )

    def _setup_paths(self):
        """필요한 디렉터리 생성"""
        paths = ['logs', 'models', 'results', 'submissions']
        for path in paths:
            Path(path).mkdir(parents=True, exist_ok=True)

    def _setup_logging(self):
        """로깅 설정"""
        log_dir = Path('logs')
        log_dir.mkdir(parents=True, exist_ok=True)

        log_path = log_dir / f"{self.experiment_name}.log"

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_path),
                logging.StreamHandler()
            ]
        )

    def _load_config(self, config_path: str) -> dict:
        """YAML 설정 파일 로드"""
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
            return config
        except Exception as e:
            raise RuntimeError(f"Failed to load config: {str(e)}")

    def prepare_data(self, data_path: str):
        """데이터 준비 및 전처리

        Args:
            data_path: 데이터 파일 경로
        Returns:
            처리된 데이터셋
        """
        logging.info("Loading and preparing data...")

        try:
            # 데이터 로드
            df = pd.read_csv(data_path)

            # 기본적인 데이터 검증
            self._validate_data(df)

            # 레이블 인코딩
            df['dept_encoded'] = self._encode_labels(df['진료과목코드'])
            df['disease_encoded'] = self._encode_labels(df['주상병코드'])

            # 결과 저장
            processed_path = Path('data') / 'processed' / f'{Path(data_path).stem}_processed.csv'
            df.to_csv(processed_path, index=False)

            logging.info(f"Data preparation completed. Shape: {df.shape}")
            return df

        except Exception as e:
            logging.error(f"Data preparation failed: {str(e)}")
            raise

    def _initialize_model(self):
        """모델 초기화"""
        try:
            # 클래스 수 계산
            num_classes = len(self.train_data['dept_encoded'].unique())

            # 모델 설정 가져오기
            model_config = ModelConfig(**self.config['model'])

            # 스태킹 앙상블 모델 초기화
            model = StackEnsemble(
                config=model_config,
                num_classes=num_classes
            )

            logging.info(f"Model initialized with {num_classes} classes")
            return model

        except Exception as e:
            logging.error(f"Model initialization failed: {str(e)}")
            raise

    def run_experiment(self):
        """전체 실험 실행"""
        try:
            # K-폴드 교차 검증 설정
            skf = StratifiedKFold(
                n_splits=self.config['training']['n_folds'],
                shuffle=True,
                random_state=self.config['data']['random_state']
            )

            # 각 폴드에 대해 학습 진행
            fold_results = []
            for fold, (train_idx, val_idx) in enumerate(skf.split(self.train_data, self.train_data['dept_encoded'])):
                logging.info(f"\nStarting Fold {fold + 1}")

                # 데이터 분할
                train_split = self.train_data.iloc[train_idx]
                val_split = self.train_data.iloc[val_idx]

                # 모델 초기화 및 학습
                model = self._initialize_model()
                trainer = ModelTrainer(
                    config=ModelConfig(**self.config['model']),
                    model=model,
                    experiment_name=f"{self.experiment_name}_fold_{fold + 1}"
                )

                # 학습 실행 및 결과 저장
                results = trainer.train(train_split, val_split)
                if results is not None:  # 결과가 있는 경우에만 처리
                    fold_results.append(results)
                    self._save_fold_results(results, fold)

            # 최종 결과 분석
            if fold_results:  # 결과가 있는 경우에만 분석
                self._analyze_results(fold_results)

        except Exception as e:
            logging.error(f"Experiment failed: {str(e)}")
            raise
        finally:
            wandb.finish()

    def _create_data_loader(self, data: dict, is_training: bool):
        """데이터 로더 생성"""
        dataset = MedicalDataset(
            data['texts'],
            data['tabular_features'],
            data['labels'],
            self.tokenizer,
            self.config['model']
        )

        return DataLoader(
            dataset,
            batch_size=self.config['training']['batch_size'],
            shuffle=is_training,
            num_workers=self.config['training']['num_workers']
        )

    def _save_fold_results(self, results: dict, fold: int):
        """폴드별 결과 저장"""
        results_path = Path('results') / self.experiment_name / f'fold_{fold + 1}'
        results_path.mkdir(parents=True, exist_ok=True)

        # 결과 저장
        pd.DataFrame(results).to_csv(results_path / 'metrics.csv', index=False)

        # WandB에 로깅
        wandb.log({
            f'fold_{fold + 1}_val_accuracy': results['val_accuracy'],
            f'fold_{fold + 1}_val_f1': results['val_f1']
        })

    def _analyze_results(self, fold_results: list):
        """전체 실험 결과 분석"""
        # 평균 및 표준편차 계산
        metrics = pd.DataFrame(fold_results)
        summary = {
            'mean': metrics.mean(),
            'std': metrics.std()
        }

        # 결과 저장
        summary_path = Path('results') / self.experiment_name / 'summary.csv'
        pd.DataFrame(summary).to_csv(summary_path)

        # 결과 로깅
        logging.info("\nExperiment Summary:")
        logging.info(f"Mean Validation Accuracy: {summary['mean']['val_accuracy']:.4f} ± {summary['std']['val_accuracy']:.4f}")
        logging.info(f"Mean Validation F1 Score: {summary['mean']['val_f1']:.4f} ± {summary['std']['val_f1']:.4f}")

In [66]:
# 설정 객체 생성
config = DataConfig(
    min_samples_per_class=2000,  # 더 작은 값으로 조정
    valid_department_threshold=2000,  # 더 작은 값으로 조정
    valid_disease_threshold=2000  # 더 작은 값으로 조정
)

# 데이터 처리기 초기화
processor = MedicalDataProcessor(config)

# 데이터 처리 실행
train_df, test_df = processor.process_data(Path("/content/drive/MyDrive/University/4-2/정보기술학회/data/medical_data.csv"))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112426188891631, max=1.0…



CommError: Run initialization has timed out after 90.0 sec. Please try increasing the timeout with the `init_timeout` setting: `wandb.init(settings=wandb.Settings(init_timeout=120))`.

In [61]:
# Experiment 설정 및 실행
experiment_config_path = "/content/drive/MyDrive/University/4-2/정보기술학회/data/experiment.yaml"

# experiment.yaml /Users/ham-yanghun/Desktop/University/University/24-2/ML/code/experiment.yaml파일이 존재하는지 확인
if not Path(experiment_config_path).exists():
    raise FileNotFoundError(f"Config file not found: {experiment_config_path}")

# Experiment 실행
runner = ExperimentRunner(experiment_config_path, train_df, test_df)
runner.run_experiment()

Training:   0%|          | 0/49920 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indic

Validation:   0%|          | 0/12480 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: str

Training:   0%|          | 0/49920 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indic

Validation:   0%|          | 0/12480 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: str

Training:   0%|          | 0/49920 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indic

Validation:   0%|          | 0/12480 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: str

0,1
train_loss,▁▁▁
val_loss,▁▁▁

0,1
train_accuracy,
train_f1_score,
train_loss,0.0
train_precision,
train_recall,
val_accuracy,
val_f1_score,
val_loss,0.0
val_precision,
val_recall,


Training:   0%|          | 0/49920 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indic

Validation:   0%|          | 0/12480 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: str

Training:   0%|          | 0/49920 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indic

Validation:   0%|          | 0/12480 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: str

Training:   0%|          | 0/49920 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indic

Validation:   0%|          | 0/12480 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: str

0,1
train_loss,▁▁▁
val_loss,▁▁▁

0,1
train_accuracy,
train_f1_score,
train_loss,0.0
train_precision,
train_recall,
val_accuracy,
val_f1_score,
val_loss,0.0
val_precision,
val_recall,


Training:   0%|          | 0/49920 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indic

Validation:   0%|          | 0/12480 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: str

Training:   0%|          | 0/49920 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indic

Validation:   0%|          | 0/12480 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: str

Training:   0%|          | 0/49920 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indic

Validation:   0%|          | 0/12480 [00:00<?, ?it/s]

ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: string indices must be integers
ERROR:root:Batch processing failed: string indices must be integers
ERROR:root:Validation batch processing failed: str

0,1
train_loss,▁▁▁
val_loss,▁▁▁

0,1
train_accuracy,
train_f1_score,
train_loss,0.0
train_precision,
train_recall,
val_accuracy,
val_f1_score,
val_loss,0.0
val_precision,
val_recall,


KeyboardInterrupt: 

In [None]:
import sys
from argparse import ArgumentParser

def main():
    """메인 실행 함수"""
    parser = ArgumentParser(description='Medical Classification Experiment Runner')
    parser.add_argument('--config', type=str, required=True, help='Path to config file')
    parser.add_argument('--data', type=str, required=True, help='Path to data file')
    args = parser.parse_args()

    # 실험 실행
    runner = ExperimentRunner(args.config)
    runner.run_experiment(args.data)

if __name__ == '__main__':
    # 코랩 환경에서 인자 전달
    sys.argv = ['main.py', '--config', 'configs/experiment.yaml', '--data', 'data/medical_data.csv']
    main()

usage: colab_kernel_launcher.py [-h] --config CONFIG --data DATA
colab_kernel_launcher.py: error: the following arguments are required: --config, --data


SystemExit: 2