In [None]:
from datasets import load_dataset

# Exploratory load: log in first (e.g. `huggingface-cli login`) to access this dataset.
# With no `split=` argument every split comes back together, which is useful for
# checking what SPLIT_NAME should be; the processing cell below reloads a single split.
ds = load_dataset(path="duongttr/chebi-20-new")

In [None]:
import os
import pandas as pd
from datasets import load_dataset
from rdkit import Chem
import selfies as sf

# ===== User settings =====
DATASET_NAME = "duongttr/chebi-20-new"  # dataset id passed to load_dataset
# Output directory for the CSVs (the raw/ path under download_dataset.py's raw_data_root).
OUT_ROOT = "/app/Mol-LLM/dataset/train"
# Split to load. Usually the dataset has a single split; if it is already split, adjust this.
SPLIT_NAME = "train"
SEED = 42  # seed for the train/valid/test shuffling
# =========================

def pick(cols, cands):
    """Return the first name in `cands` that appears in `cols`, or None if none do.

    Candidates are tried in order, so earlier entries in `cands` take priority.
    """
    return next((name for name in cands if name in cols), None)

def canon_smiles(s):
    """Return RDKit's canonical SMILES for `s`, or None if it cannot be used.

    Missing, non-string, and blank inputs yield None instead of reaching RDKit:
    MolFromSmiles raises on None, and "" parses to an empty molecule that would
    canonicalize to "" and masquerade as a valid structure.
    """
    if not isinstance(s, str) or not s.strip():
        return None
    mol = Chem.MolFromSmiles(s.strip())
    # MolFromSmiles signals a parse failure by returning None.
    return Chem.MolToSmiles(mol) if mol is not None else None

# Load the single split named by SPLIT_NAME.
# NOTE(review): if the hub dataset already ships validation/test splits, loading only
# SPLIT_NAME and re-splitting below discards them -- confirm via load_dataset(DATASET_NAME).
ds = load_dataset(DATASET_NAME, split=SPLIT_NAME)

# Resolve the caption / SELFIES / SMILES columns from lists of known aliases.
cols = set(ds.column_names)
cap_col = pick(cols, ["description", "caption", "text", "molecular_caption", "molecular_captions"])
selfies_col = pick(cols, ["SELFIES", "selfies"])
smiles_col = pick(cols, ["SMILES", "smiles", "smi"])

# Explicit raises rather than `assert`, which is stripped under `python -O`.
if cap_col is None:
    raise ValueError("캡션/설명 열을 찾지 못했습니다. (description/caption/text 등 후보 확인)")


def _smiles_to_selfies(smi):
    """Convert one SMILES string to SELFIES via its canonical form; "" on any failure.

    "" (rather than None) marks failures so the mapped column stays string-typed;
    those rows are dropped by the filter further down.
    """
    if not isinstance(smi, str) or not smi.strip():
        return ""
    canon = canon_smiles(smi)
    if canon is None:
        return ""
    try:
        # `or ""` also covers selfies versions that return None instead of raising.
        return sf.encoder(canon) or ""
    except sf.EncoderError:
        # Molecules the encoder rejects are skipped instead of aborting the whole map.
        return ""


# Guarantee a "SELFIES" column: encode from SMILES when the dataset has none.
if selfies_col is None:
    if smiles_col is None:
        raise ValueError("SELFIES가 없으므로 SMILES 열이 필요합니다.")
    ds = ds.map(lambda x: {"SELFIES": _smiles_to_selfies(x[smiles_col])})
elif selfies_col != "SELFIES":
    ds = ds.rename_column(selfies_col, "SELFIES")

# Standardize the caption column name to "description".
if cap_col != "description":
    ds = ds.rename_column(cap_col, "description")

# Keep only the output columns, then drop rows where either value is missing/empty
# (including SMILES that failed to convert above).
keep = ["SELFIES", "description"]
ds = ds.remove_columns([c for c in ds.column_names if c not in keep])
ds = ds.filter(lambda x: bool(x["SELFIES"]) and bool(x["description"]))

# 80/10/10 split. Hold out the test fraction first, so the second split's fraction
# is relative to the remaining 90%: 0.1 / 0.9 of it is 10% of the total.
# (Splitting off 20% first and then 0.111 of the remaining 80% gives ~71/9/20.)
TEST_FRAC = 0.1
VALID_FRAC = 0.1
splits = ds.train_test_split(test_size=TEST_FRAC, seed=SEED)
tmp = splits["train"].train_test_split(test_size=VALID_FRAC / (1 - TEST_FRAC), seed=SEED)
train, valid, test = tmp["train"], tmp["test"], splits["test"]

os.makedirs(OUT_ROOT, exist_ok=True)
for part_name, part_ds in (("train", train), ("valid", valid), ("test", test)):
    out_path = os.path.join(OUT_ROOT, f"BioT5_chebi20_{part_name}.csv")
    part_ds.to_pandas()[keep].to_csv(out_path, index=False)
    print(f"{part_name}: {len(part_ds)} rows -> {out_path}")
print("CSV saved to:", OUT_ROOT)