# PREpiBind-ESMC-300M

In [None]:
#@title Prepare Dataset
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import os

# Clone the PREpiBind repository on first run only; the shell redirection
# discards all git output (including any error text).
if not os.path.exists("PREpiBind"):
    !git clone https://github.com/daylight-00/PREpiBind > /dev/null 2>&1
# Switch into the repository directory unless the notebook is already there,
# so the relative 'data/...' paths used below are looked up inside it.
if os.path.basename(os.getcwd()) != "PREpiBind":
    %cd PREpiBind
# Quietly install the `esm` package.
!pip install -q esm

# --- Initial data ---
# Two HLA mapping tables are available; the light one is the default and is
# swapped for the full one by the "Use full HLA" checkbox handler.
mhc_map_l, mhc_map_f = 'data/mhc_mapping_light.csv', 'data/mhc_mapping.csv'
mhc_map = mhc_map_l
df_map = pd.read_csv(mhc_map_l)
hla_emb_path = 'data/emb_hla_esmc_small_light_0601_fp16.h5'

# Empty input table that the UI below fills row by row.
df = pd.DataFrame(columns=['MHC', 'MHC_alpha', 'MHC_beta', 'Epitope'])

# --- Top checkbox ---
# `output` is the Output widget the checkbox handler writes into.
output = widgets.Output()
use_full_hla = widgets.Checkbox(
    description="Use full HLA",
    value=False,
    style={'description_width': 'initial'},
)
display(use_full_hla, output)

# --- Dropdown initialisation ---
# Created empty; their options are filled from the mapping table afterwards.
mhc_alpha_dropdown = widgets.Dropdown(
    description="MHC alpha:",
    options=[],
    style={'description_width': 'initial'},
)
mhc_beta_dropdown = widgets.Dropdown(
    description="MHC beta:",
    options=[],
    style={'description_width': 'initial'},
)

def update_hla_dropdowns():
    """Rebuild the MHC alpha/beta dropdown options from the current ``df_map``.

    Reads the module-level ``df_map`` (adding a ``Chain`` column to it),
    publishes the sorted, de-duplicated allele names as the globals
    ``hla_list_a`` and ``hla_list_b``, and assigns them to the two dropdowns.
    A dropdown keeps its current selection when that allele is still offered;
    otherwise it falls back to the first option (or ``None`` when the list is
    empty).
    """
    global hla_list_a, hla_list_b

    # Remove the literal 'HLA-' prefix (plain string, not a regex).
    df_map['Chain'] = df_map['HLA_Name'].str.replace('HLA-', '', regex=False)
    chain = df_map['Chain']

    # Classify on the prefix-stripped chain name: the 'HLA-' prefix itself
    # contains an 'A', so matching on HLA_Name would put every allele in the
    # alpha list. `na=False` keeps rows with a missing name out of both lists.
    # NOTE(review): assumes alpha chains carry an 'A' and beta chains a 'B'
    # after the prefix (e.g. DRA / DRB1) -- confirm against the mapping CSV.
    is_alpha = chain.str.contains('A', regex=False, na=False)
    is_beta = chain.str.contains('B', regex=False, na=False)
    hla_list_a = df_map.loc[is_alpha, 'HLA_Name'].sort_values().unique().tolist()
    hla_list_b = df_map.loc[is_beta, 'HLA_Name'].sort_values().unique().tolist()

    for dropdown, names in (
        (mhc_alpha_dropdown, hla_list_a),
        (mhc_beta_dropdown, hla_list_b),
    ):
        # Remember the selection before replacing the options.
        previous = dropdown.value
        dropdown.options = names
        if previous in names:
            dropdown.value = previous
        else:
            dropdown.value = names[0] if names else None

# Fill the dropdowns once from the mapping table loaded above.
update_hla_dropdowns()

def on_hla_checkbox_change(change):
    """Switch between the light and full HLA tables when the checkbox toggles.

    Args:
        change: traitlets change dict; ``change['new']`` is the new checkbox
            value (truthy selects the full HLA set).

    Updates the globals ``mhc_map``, ``hla_emb_path`` and ``df_map``, then
    rebuilds the dropdown options and redraws the UI.
    """
    global df_map, mhc_map, hla_emb_path

    if change['new']:
        selected_map = mhc_map_f
        selected_emb = 'data/emb_hla_esmc_small_0601_fp16.h5'
    else:
        selected_map = mhc_map_l
        selected_emb = 'data/emb_hla_esmc_small_light_0601_fp16.h5'

    with output:
        clear_output()
        # Load before committing anything: if the CSV cannot be read, the
        # three globals stay mutually consistent instead of `mhc_map` pointing
        # at one table while `df_map` still holds the other.
        loaded_map = pd.read_csv(selected_map)
        mhc_map, hla_emb_path, df_map = selected_map, selected_emb, loaded_map
        update_hla_dropdowns()
        refresh_ui()
# Invoke the handler whenever the checkbox's `value` trait changes.
use_full_hla.observe(on_hla_checkbox_change, names='value')

# --- Remaining widgets / functions ---
# Output widget that hosts the whole input UI and its status messages.
output_area = widgets.Output()

# Epitope entry and its live character counter.
epitope_msg = widgets.HTML("현재 0/15자 입력 중")
epitope_textbox = widgets.Text(
    description='Epitope:',
    value='',
    placeholder='최대 15자, 대문자로 입력',
    style={'description_width': 'initial'},
)

# CSV path shared by the load and export buttons.
file_path_box = widgets.Text(
    description='파일 경로:',
    value='data/dataset_demo.csv',
    placeholder='CSV 경로 입력 (예: ./dataset.csv)',
    style={'description_width': 'initial'},
)

# Action buttons; their click handlers are attached further down.
add_button = widgets.Button(description="추가")
load_button = widgets.Button(description="CSV 불러오기")
export_button = widgets.Button(description="CSV 내보내기")
reset_button = widgets.Button(description="초기화", button_style='danger')

def make_preferences_box():
    """Create the run-preference widgets and return them stacked in a VBox.

    The widgets are (re)created on every call and published as the globals
    ``num_workers``, ``batch_size``, ``plot_kde``, ``use_compile``,
    ``out_path`` and ``show_top_binders`` so that other cells can read their
    ``.value``.

    Returns:
        widgets.VBox: output folder, top-binders selector, KDE toggle,
        worker-count slider, batch-size slider and torch-compile toggle.
    """
    global num_workers, batch_size, plot_kde, use_compile, out_path, show_top_binders

    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to 1 so the comparison and the slider bound never see None.
    cpu_total = os.cpu_count() or 1
    row_count = len(df)

    num_workers = widgets.IntSlider(
        value=min(8, cpu_total),
        description="Num workers:",
        style={'description_width': 'initial'},
        min=0,
        max=cpu_total
    )
    batch_size = widgets.IntSlider(
        # Default to the whole table, capped at 512, but never below the
        # slider minimum (which an empty table would otherwise violate).
        value=max(1, min(512, row_count)),
        description="Batch size:",
        style={'description_width': 'initial'},
        min=1,
        max=max(1, row_count)
    )
    plot_kde = widgets.Checkbox(
        value=True,
        description="Plot KDE",
        style={'description_width': 'initial'}
    )
    show_top_binders = widgets.Dropdown(
        options=[('None', None), ('Top 5', 5), ('Top 10', 10)],
        value=5,
        description="Show top binders:",
        style={'description_width': 'initial'}
    )
    use_compile = widgets.Checkbox(
        value=False,
        description="Use torch compile",
        style={'description_width': 'initial'}
    )
    out_path = widgets.Text(
        value='outputs',
        placeholder='출력 폴더 경로 입력 (예: ./outputs)',
        description='출력 폴더:',
        style={'description_width': 'initial'}
    )
    return widgets.VBox([
        out_path, show_top_binders, plot_kde, num_workers, batch_size, use_compile
    ])

def refresh_ui():
    """Clear ``output_area`` and redraw the input UI for the current ``df``.

    The selector row, the epitope counter and the file row are always shown.
    An empty table shows a placeholder message; a non-empty table shows its
    last five rows followed by a rule and a freshly built preferences box.
    """
    with output_area:
        clear_output()
        has_rows = not df.empty

        if has_rows:
            table_view = widgets.HTML(df.tail(5).to_html(index=False))
        else:
            table_view = widgets.HTML('<b>입력된 데이터가 없습니다.</b>')

        sections = [
            widgets.HBox([mhc_alpha_dropdown, mhc_beta_dropdown, epitope_textbox, add_button]),
            epitope_msg,
            widgets.HBox([file_path_box, load_button, export_button, reset_button]),
            table_view,
        ]
        # Preferences are only offered once there is data to run on.
        if has_rows:
            sections.append(widgets.HTML('<hr>'))
            sections.append(make_preferences_box())

        display(widgets.VBox(sections))

def add_row(_):
    """Append the selected alpha/beta pair and the typed epitope to ``df``.

    Args:
        _: the clicked button (unused).

    When any of the three inputs is missing, the UI is redrawn with a red
    validation message and nothing is added. Otherwise a row
    ``[alpha_beta, alpha, beta, epitope]`` is appended, the epitope box is
    cleared and the UI is redrawn.
    """
    alpha = mhc_alpha_dropdown.value
    beta = mhc_beta_dropdown.value
    # Store the same stripped text that is validated, so surrounding
    # whitespace never ends up in the epitope column.
    epitope = epitope_textbox.value.strip()

    if not (alpha and beta and epitope):
        # Redraw first: refresh_ui() clears output_area, so a message shown
        # before it would be erased immediately.
        refresh_ui()
        with output_area:
            display(widgets.HTML('<b style="color:red;">MHC alpha, MHC beta, Epitope 모두 입력해야 합니다.</b>'))
        return

    mhc = alpha + '_' + beta
    df.loc[len(df)] = [mhc, alpha, beta, epitope]
    epitope_textbox.value = ""
    refresh_ui()

def export_csv(_):
    """Write ``df`` to the path typed in ``file_path_box`` and report the result.

    Args:
        _: the clicked button (unused).

    A failed write is reported in red instead of propagating, so the UI is
    always redrawn.
    """
    path = file_path_box.value
    try:
        df.to_csv(path, index=False)
    except OSError as err:
        status = widgets.HTML(f'<b style="color:red;">파일 저장 실패: {err}</b>')
    else:
        # Report the path that was actually written.
        status = widgets.HTML(f'<b>저장되었습니다: {path}</b>')

    # Redraw first: refresh_ui() clears output_area, so a message shown before
    # it would be erased immediately.
    refresh_ui()
    with output_area:
        display(status)

def load_csv(_):
    """Replace ``df`` with the CSV at the path typed in ``file_path_box``.

    Args:
        _: the clicked button (unused).

    On success the global ``df`` is replaced and the file name plus the first
    three rows are shown; a missing or unreadable file leaves ``df`` untouched
    and shows a red error message. The UI is redrawn in every case.
    """
    global df
    path = file_path_box.value
    notices = []

    if not os.path.isfile(path):
        notices.append(widgets.HTML(f'<b style="color:red;">파일을 찾을 수 없습니다: {path}</b>'))
    else:
        try:
            loaded = pd.read_csv(path)
        except Exception as e:
            # UI boundary: any read/parse failure is reported to the user
            # rather than raised out of the click callback.
            notices.append(widgets.HTML(f'<b style="color:red;">파일 읽기 실패: {str(e)}</b>'))
        else:
            df = loaded
            notices.append(widgets.HTML(f'<b>불러온 파일: {path}<br>상위 3개 행:</b>'))
            notices.append(widgets.HTML(df.head(3).to_html(index=False)))

    # Redraw first: refresh_ui() clears output_area, so messages shown before
    # it would be erased immediately.
    refresh_ui()
    with output_area:
        for notice in notices:
            display(notice)

def epitope_textbox_change(change):
    """Keep the epitope box upper-case and at most 15 characters long.

    Args:
        change: traitlets change dict; ``change['new']`` is the text just
            entered.

    If the text needs normalising, the box value is rewritten and the counter
    is left alone (the rewrite is expected to trigger this observer again).
    Otherwise only the character counter message is updated.
    """
    normalized = change['new'].upper()[:15]
    if normalized == epitope_textbox.value:
        epitope_msg.value = f"현재 {len(normalized)}/15자 입력 중"
    else:
        epitope_textbox.value = normalized

def reset_df(_):
    """Replace ``df`` with an empty table and confirm the reset to the user.

    Args:
        _: the clicked button (unused).
    """
    global df
    df = pd.DataFrame(columns=['MHC', 'MHC_alpha', 'MHC_beta', 'Epitope'])
    # Redraw first: refresh_ui() clears output_area, so a message shown before
    # it would be erased immediately.
    refresh_ui()
    with output_area:
        display(widgets.HTML('<b style="color:green;">DataFrame이 초기화되었습니다.</b>'))

# Attach each button to its click handler.
for _button, _handler in (
    (add_button, add_row),
    (export_button, export_csv),
    (load_button, load_csv),
    (reset_button, reset_df),
):
    _button.on_click(_handler)

# Normalise the epitope text and update its counter on every edit.
epitope_textbox.observe(epitope_textbox_change, names='value')

# Show the UI container, then draw its initial contents.
display(output_area)
refresh_ui()


In [None]:
#@title Run Prediction

# The preference widgets read below only exist once the input table has rows,
# so stop early on an empty table.
if df.empty:
    raise ValueError("Please add data before running the prediction.")

import os
import sys

# Make the repository's `code` directory importable; the membership check
# keeps re-running this cell from stacking duplicate sys.path entries.
utils_path = os.path.abspath('code')
if utils_path not in sys.path:
    sys.path.insert(0, utils_path)
from inference import main, load_config

from huggingface_hub import hf_hub_download

# Fetch the model weights; the full HLA embedding file is only fetched when
# the "Use full HLA" checkbox is ticked.
hf_hub_download(repo_id='daylight00/esmc-300m-2024-12', filename="esmc_300m_2024_12_v0_fp16.pth", local_dir="models")
hf_hub_download(repo_id='daylight00/prepibind-esmc-300m', filename="prepi_esmc_small_e5_s128_f4_fp16.pth", local_dir="models")
if use_full_hla.value:
    hf_hub_download(repo_id='daylight00/prepibind-esmc-300m', filename="emb_hla_esmc_small_0601_fp16.h5", local_dir="data")

# Write the input table into the output folder and point the config at it.
os.makedirs(out_path.value, exist_ok=True)
test_path = f'{out_path.value}/dataset.csv'
df.to_csv(test_path, index=False)
config_path = 'config_demo.py'

config = load_config(
    config_path,
    num_workers=num_workers.value,
    batch_size=batch_size.value,
    use_compile=use_compile.value,
    test_path=test_path,
    hla_path=mhc_map,
    plot=plot_kde.value,
    out_path=out_path.value,
    hla_emb_path=hla_emb_path,
)

df_out = main(config)
cols = ['Score', 'Logits']
# Coerce to numbers so the ranking below is numeric; unparseable cells become NaN.
df_out[cols] = df_out[cols].apply(pd.to_numeric, errors='coerce')

# The dropdown's 'None' option yields None, which DataFrame.nlargest cannot
# accept as a row count.
# NOTE(review): 'None' is treated here as "do not show a table" -- confirm
# that this is the intended meaning of the option.
top_n = show_top_binders.value
if top_n is not None:
    df_out = df_out.nlargest(top_n, 'Score')
    for col in cols:
        df_out[col] = df_out[col].map(lambda x: f"{x:.5f}")
    display(df_out)

In [None]:
#@title Export Results
from google.colab import files

# Example: download the prediction.csv file from the output folder.
files.download(f"{out_path.value}/prediction.csv")

# The plot is only offered when "Plot KDE" is ticked.
if plot_kde.value:
    files.download(f"{out_path.value}/plot.png")