# In [ ]:
# Cell 1: 라이브러리 임포트 및 설정
import mlflow
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from est_utils import FONT_PROP, BACKBONE_CLASSES

# Point the MLflow client at the tracking server.
# NOTE(review): 0.0.0.0 is a bind-all address; using it as a *client* target
# may not work on every platform (e.g. Windows) — consider 127.0.0.1/localhost
# if connections fail.
mlflow.set_tracking_uri("http://0.0.0.0:5000")

# MLflow experiment (name or ID) to analyze; not referenced elsewhere in this section.
EXPERIMENT_NAME = "FER Fine-tuning with Soft Labels"

# Run to analyze. To compare several training runs instead, use a list such as
#   run_ids = ["run_id_1", "run_id_2"]
TARGET_RUN_ID = "c88cb1374eed41dc9d164a91a5f4e956"  # put the run_id you want to inspect here

# Metric keys to look up on the target run.
METRICS_TO_ANALYZE = [
    "train_loss",
    "val_loss",
    "general_val_loss",
    "general_val_accuracy",
    "test_accuracy",
    "test_f1_macro",
]

# File name of the confusion-matrix image saved by 08_finetuning_refactored.py.
ARTIFACT_NAME = "08_test_confusion_matrix.png"

# In [ ]:
# Cell 2: MLflow Run 정보 및 지표 가져오기
def get_run_metrics(run_id, metrics_list):
    """Look up selected metrics, the params, and the name of one MLflow run.

    Args:
        run_id: ID of the MLflow run to query.
        metrics_list: Metric keys to pick out of the run's metrics. A key the
            run does not have maps to ``None``.

    Returns:
        A tuple ``(metrics, params, run_name)``: ``metrics`` maps each
        requested key to its value (or ``None``), ``params`` is
        ``run.data.params``, and ``run_name`` is ``run.info.run_name``.
    """
    run = mlflow.tracking.MlflowClient().get_run(run_id)
    logged_metrics = run.data.metrics
    selected = {}
    for key in metrics_list:
        selected[key] = logged_metrics.get(key)
    return selected, run.data.params, run.info.run_name

# Single-run analysis: print the run's params and selected metrics, plot the
# val_loss history, and display the logged confusion-matrix image.
if TARGET_RUN_ID:
    metrics, params, run_name = get_run_metrics(TARGET_RUN_ID, METRICS_TO_ANALYZE)
    # Previously the "\n" here was a raw line break inside the f-string and the
    # closing parenthesis was missing, which made the whole cell a SyntaxError.
    print(f"\n--- Analysis for Run: {run_name} ({TARGET_RUN_ID}) ---")
    print("Parameters:")
    for k, v in params.items():
        print(f"  {k}: {v}")
    print("Metrics:")
    for k, v in metrics.items():
        # get_run_metrics uses dict.get, so a metric that was never logged on
        # this run comes back as None; formatting None with ":.4f" raises TypeError.
        if v is None:
            print(f"  {k}: N/A (not logged)")
        else:
            print(f"  {k}: {v:.4f}")

    # Time-series metric: one entry per logged step (plotted as "Epoch" below).
    client = mlflow.tracking.MlflowClient()
    history_metrics = client.get_metric_history(TARGET_RUN_ID, "val_loss")
    if history_metrics:
        # Order by step (ties broken by timestamp) so the line plot connects
        # points in sequence rather than in whatever order the store returned.
        history_metrics = sorted(history_metrics, key=lambda m: (m.step, m.timestamp))
        epochs = [m.step for m in history_metrics]
        val_losses = [m.value for m in history_metrics]
        plt.figure(figsize=(10, 6))
        plt.plot(epochs, val_losses, marker='o')
        plt.title(f'Validation Loss History for {run_name}', fontproperties=FONT_PROP)
        plt.xlabel('Epoch', fontproperties=FONT_PROP)
        plt.ylabel('Validation Loss', fontproperties=FONT_PROP)
        plt.grid(True)
        plt.show()

    # Download and show the confusion-matrix image artifact. This is best-effort:
    # a missing artifact or a read failure is reported, not raised.
    try:
        artifact_path = client.download_artifacts(TARGET_RUN_ID, ARTIFACT_NAME)
        print(f"Artifact downloaded to: {artifact_path}")
        img = plt.imread(artifact_path)
    except Exception as e:
        print(f"Error downloading or displaying artifact {ARTIFACT_NAME}: {e}")
    else:
        plt.figure(figsize=(8, 8))
        plt.imshow(img)
        plt.axis('off')
        plt.title(f'Confusion Matrix for {run_name}', fontproperties=FONT_PROP)
        plt.show()
else:
    print("분석할 TARGET_RUN_ID를 설정해 주세요.")