# Deduplication Dataset 

In [1]:
import os
import re
import pandas as pd
from collections import defaultdict, Counter
from tqdm import tqdm
import selfies as sf
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from datasets import load_from_disk, Dataset, enable_progress_bar
from tabulate import tabulate
import multiprocessing

# Turn on Hugging Face `datasets` progress bars (load / map / save operations).
enable_progress_bar()

# =============================================================================
# [Configuration]
# =============================================================================

# 1. Tasks that are merged into one shared group, written as group -> member tasks.
#    These merged groups are the only ones on which input/output deduplication runs.
_MERGED_TASK_MEMBERS = {
    "molecule_captioning_merged": ("chebi-20-mol2text", "smol-molecule_captioning"),
    "molecule_generation_merged": ("chebi-20-text2mol", "smol-molecule_generation"),
    "retrosynthesis_merged": ("retrosynthesis", "smol-retrosynthesis"),
    "forward_reaction_merged": ("forward_reaction_prediction", "smol-forward_synthesis"),
}

# Flattened task -> merged-group lookup (a task missing here is its own group).
STRICT_TASK_GROUP = {
    task: group
    for group, members in _MERGED_TASK_MEMBERS.items()
    for task in members
}

# 2. Group IDs whose rows are deduplicated when input and output are identical.
#    Property-prediction tasks are not merged groups, so deduplication skips them.
DEDUPLICATION_ENABLED_GROUPS = set(_MERGED_TASK_MEMBERS)

# 3. Tasks whose train/val rows are also checked for scaffold overlap with test.
#    Property-prediction tasks are not deduplicated, but they still need protection
#    against test-set leakage, which is why they are listed here.
_PROPERTY_TASKS = (
    "bace", "bbbp", "clintox", "tox21", "toxcast", "sider",
    "hiv", "muv", "esol", "freesolv", "lipo", "hopv",
)
# Every task except "hopv" also has a "smol-property_prediction-<task>" variant.
SCAFFOLD_SPLIT_GROUPS = set(_PROPERTY_TASKS) | {
    f"smol-property_prediction-{task}" for task in _PROPERTY_TASKS if task != "hopv"
}

# Row order of categories in the final summary table.
CATEGORY_ORDER = [
    "Property Prediction (Regression)",
    "Property Prediction (Classification)",
    "Forward Reaction Prediction",
    "Retrosynthesis",
    "Reagent Prediction",
    "Molecule Captioning",
    "Description-Guided Molecule Generation",
    "Name Conversion",
]

# Category -> task names counted under that category in the summary table.
DISPLAY_MAPPING = {
    "Property Prediction (Regression)": [
        "qm9_homo", "qm9_lumo", "qm9_homo_lumo_gap", "qm9_additional_label",
        "smol-property_prediction-esol", "smol-property_prediction-lipo",
        "smol-property_prediction-freesolv", "esol", "lipo", "freesolv",
    ],
    "Property Prediction (Classification)": [
        "bace", "tox21", "toxcast", "clintox", "bbbp", "hiv", "sider", "muv", "hopv",
        "smol-property_prediction-bbbp", "smol-property_prediction-clintox",
        "smol-property_prediction-hiv", "smol-property_prediction-sider",
        "smol-property_prediction-tox21", "smol-property_prediction-toxcast",
        "smol-property_prediction-muv",
    ],
    "Forward Reaction Prediction": ["forward_reaction_prediction", "smol-forward_synthesis"],
    "Retrosynthesis": ["retrosynthesis", "smol-retrosynthesis"],
    "Reagent Prediction": ["reagent_prediction"],
    "Molecule Captioning": ["chebi-20-mol2text", "smol-molecule_captioning"],
    "Description-Guided Molecule Generation": ["chebi-20-text2mol", "smol-molecule_generation"],
    "Name Conversion": [
        "smol-name_conversion-i2s", "smol-name_conversion-s2i",
        "smol-name_conversion-i2f", "smol-name_conversion-s2f",
    ],
}

# =============================================================================
# [Helper Functions]
# =============================================================================
def get_strict_task_group(task_name):
    """Map a task name to its merged group ID.

    Tasks listed in ``STRICT_TASK_GROUP`` resolve to their merged group; any
    other task name is returned unchanged, i.e. it forms its own group.
    """
    if task_name in STRICT_TASK_GROUP:
        return STRICT_TASK_GROUP[task_name]
    return task_name

def decode_and_get_info(batch):
    input_mols = batch["input_mol_string"]
    canon_smiles_list, scaffold_list, valid_list = [], [], []
    for input_mol in input_mols:
        res_smiles, res_scaffold, is_valid = "", "", False
        try:
            if input_mol:
                clean_str = re.sub(r"<[^>]+>", "", str(input_mol)).strip()
                smiles = sf.decoder(clean_str)
                if smiles:
                    mol = Chem.MolFromSmiles(smiles)
                    if mol:
                        res_smiles = Chem.MolToSmiles(mol, canonical=True)
                        res_scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol) or ""
                        is_valid = True
        except: pass
        canon_smiles_list.append(res_smiles)
        scaffold_list.append(res_scaffold)
        valid_list.append(is_valid)
    return {"canon_smiles": canon_smiles_list, "scaffold": scaffold_list, "valid": valid_list}

# =============================================================================
# [Main Pipeline]
# =============================================================================
def _load_and_parse_split(name, path, num_cores):
    """Load one split from disk and append parsed-molecule and task-group columns.

    Returns a DataFrame holding every original column plus ``canon_smiles``,
    ``scaffold``, ``valid`` (from ``decode_and_get_info``) and ``task_group``.
    """
    print(f" -> {name.upper()} 스플릿 처리 중...")
    ds = load_from_disk(path)

    # 1. Parse molecule info in parallel worker processes.
    parsed_info = ds.select_columns(['task', 'input_mol_string']).map(
        decode_and_get_info, batched=True, batch_size=2000, num_proc=num_cores, desc=f"Parsing {name}"
    )

    df_raw = ds.to_pandas()
    df_info = parsed_info.to_pandas()[['canon_smiles', 'scaffold', 'valid']]
    # Positional concat: relies on Dataset.map keeping the original row order.
    df = pd.concat([df_raw.reset_index(drop=True), df_info.reset_index(drop=True)], axis=1)
    df['task_group'] = df['task'].apply(get_strict_task_group)
    return df


def _classify_leakage(df, test_black_smiles, test_black_scaf, desc):
    """Return a Series labelling each row "Keep", "Drop: Exact Match" or "Drop: Scaffold Match".

    Rows that failed molecule parsing are always kept.
    """
    def check_status(row):
        t_group = row['task_group']
        s, scaf, v = row['canon_smiles'], row['scaffold'], row['valid']

        if not v:
            return "Keep"

        # 1. Exact-match check (applies to every task group).
        if s in test_black_smiles.get(t_group, set()):
            return "Drop: Exact Match"

        # 2. Scaffold-match check (only for SCAFFOLD_SPLIT_GROUPS, i.e. property prediction).
        # NOTE(review): molecules whose scaffold is "" share one bucket, so a single
        # such test molecule in a group drops every such train/val molecule in that
        # group. Confirm this strict behaviour is intended.
        if t_group in SCAFFOLD_SPLIT_GROUPS:
            if scaf in test_black_scaf.get(t_group, set()):
                return "Drop: Scaffold Match"

        return "Keep"

    tqdm.pandas(desc=desc)
    return df.progress_apply(check_status, axis=1)


def _clean_split(name, df, test_black_smiles, test_black_scaf, drop_stats):
    """Remove test leakage from one split, then deduplicate the merged task groups.

    Adds a ``status`` column to ``df`` in place, records per-task drop counts in
    ``drop_stats`` and returns the cleaned DataFrame (original row order kept).
    """
    # NOTE(review): if none of these columns exist, the "output" fallback makes
    # drop_duplicates raise KeyError below.
    out_col = next((c for c in ['label', 'output_string', 'output', 'target'] if c in df.columns), "output")

    df['status'] = _classify_leakage(df, test_black_smiles, test_black_scaf, f"Scanning {name}")

    # Leakage (exact/scaffold match) removal applies to every task.
    df_kept = df[df['status'] == "Keep"].copy()

    # Conditional dedup: only the merged groups are deduplicated; everything else
    # (e.g. property prediction) is preserved as-is.
    mask_dedup = df_kept['task_group'].isin(DEDUPLICATION_ENABLED_GROUPS)
    df_to_dedup = df_kept[mask_dedup]
    df_deduped = df_to_dedup.drop_duplicates(subset=['task_group', 'input_mol_string', out_col], keep='first')
    df_preserved = df_kept[~mask_dedup]
    df_clean = pd.concat([df_deduped, df_preserved]).sort_index()

    prefix = name.upper()
    # 1. Leakage drops, counted per (task, status) in one pass.
    leaked = df['status'] != "Keep"
    for (task, status), count in Counter(zip(df.loc[leaked, 'task'], df.loc[leaked, 'status'])).items():
        drop_stats[task][f"{prefix} {status}"] += count
    # 2. Duplicate drops (dedup groups only); the concat'd index is a unique RangeIndex.
    dup_removed = ~df_to_dedup.index.isin(df_deduped.index)
    for task, count in Counter(df_to_dedup.loc[dup_removed, 'task']).items():
        drop_stats[task][f"{prefix} Drop: Task Group Dup"] += count

    return df_clean


def _print_report(dfs, final_dfs, drop_stats):
    """Print the per-task drop table and the per-category split-size summary."""
    report_data = []
    # Include tasks from both cleaned splits so val-only tasks' drops are not hidden.
    all_tasks = sorted(set(dfs["train"]['task']) | set(dfs["val"]['task']))
    for task in all_tasks:
        row = {"Task Name": task}
        row.update(drop_stats[task])
        report_data.append(row)

    if report_data:
        report_df = pd.DataFrame(report_data).fillna(0)
        cols = ["Task Name"] + sorted([c for c in report_df.columns if c != "Task Name"])
        print("\n[Step-by-Step Drop Statistics]")
        print(tabulate(report_df[cols], headers="keys", tablefmt="grid", floatfmt=".0f"))

    summary_data = []
    for cat in CATEGORY_ORDER:
        tasks = DISPLAY_MAPPING.get(cat, [])
        counts = {s: int(final_dfs[s]['task'].isin(tasks).sum()) for s in ["train", "val", "test"]}
        summary_data.append({
            "Category": cat,
            "Train": f"{counts['train']:,}", "Val": f"{counts['val']:,}", "Test": f"{counts['test']:,}",
            # Total covers all three splits (previously val was left out).
            "Total": f"{counts['train'] + counts['val'] + counts['test']:,}"
        })
    print("\n[Category Summary]")
    print(tabulate(summary_data, headers="keys", tablefmt="github", stralign="right"))


def _save_splits(final_dfs, base_save_dir):
    """Drop the helper columns and save every split as a Hugging Face dataset."""
    for name, df in final_dfs.items():
        save_path = os.path.join(base_save_dir, f"GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_{name}_FINAL_CLEANED")
        out_df = df.drop(columns=['canon_smiles', 'scaffold', 'valid', 'task_group', 'status'], errors='ignore')
        Dataset.from_pandas(out_df, preserve_index=False).save_to_disk(save_path)
        print(f"[Saved] {save_path}")


def final_integrated_cleanup(train_path, val_path, test_path, base_save_dir, num_cores=24):
    """Remove test leakage from train/val, deduplicate merged task groups, report and save.

    Args:
        train_path, val_path, test_path: ``load_from_disk`` paths of the three splits.
        base_save_dir: Directory where the cleaned splits are written.
        num_cores: ``num_proc`` used for the parallel molecule parsing.

    The test split is saved unchanged. Train/val rows are dropped when their
    canonical SMILES (any group), or their scaffold (SCAFFOLD_SPLIT_GROUPS only),
    appears in the test split of the same task group. Rows in
    DEDUPLICATION_ENABLED_GROUPS are then deduplicated on input and output.
    """
    print("=== [Step 2] HF Multiprocessing을 이용한 분자 파싱 ===")
    splits = {"train": train_path, "val": val_path, "test": test_path}
    dfs = {name: _load_and_parse_split(name, path, num_cores) for name, path in splits.items()}

    print("\n=== [Step 3] 그룹 기반 오염 제거 및 조건부 중복 제거 ===")

    # Per-group blacklists built from the valid test molecules.
    valid_test = dfs["test"][dfs["test"]['valid']]
    test_black_smiles = valid_test.groupby('task_group')['canon_smiles'].apply(set).to_dict()
    test_black_scaf = valid_test.groupby('task_group')['scaffold'].apply(set).to_dict()

    drop_stats = defaultdict(Counter)
    final_dfs = {"test": dfs["test"]}
    for name in ["train", "val"]:
        df = dfs[name]
        df_clean = _clean_split(name, df, test_black_smiles, test_black_scaf, drop_stats)
        final_dfs[name] = df_clean
        print(f" -> {name.upper()}: 최종 {len(df_clean):,} (제거됨: {len(df) - len(df_clean):,})")

    print("\n=== [Step 4] 최종 상세 리포트 및 저장 ===")
    _print_report(dfs, final_dfs, drop_stats)
    _save_splits(final_dfs, base_save_dir)

if __name__ == "__main__":
    # Request the 'spawn' start method for worker processes; ignore the error if
    # the start method cannot be changed at this point.
    try:
        multiprocessing.set_start_method('spawn', force=True)
    except RuntimeError:
        pass

    # All raw inputs and cleaned outputs live in the same directory; the raw
    # split paths differ only in the split name.
    save_dir = "Mol-LLM_Custom/dataset/train_official/"
    raw_paths = {
        split: f"{save_dir}GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_{split}_3.3M_0415_raw"
        for split in ("train", "val", "test")
    }

    final_integrated_cleanup(
        raw_paths["train"],
        raw_paths["val"],
        raw_paths["test"],
        save_dir,
        num_cores=32,
    )

=== [Step 2] HF Multiprocessing을 이용한 분자 파싱 ===
 -> TRAIN 스플릿 처리 중...


Loading dataset from disk:   0%|          | 0/45 [00:00<?, ?it/s]

 -> VAL 스플릿 처리 중...
 -> TEST 스플릿 처리 중...

=== [Step 3] 그룹 기반 오염 제거 및 조건부 중복 제거 ===


Scanning train: 100%|██████████| 3465790/3465790 [00:43<00:00, 80364.04it/s]


 -> TRAIN: 최종 3,313,489 (제거됨: 152,301)


Scanning val: 100%|██████████| 36016/36016 [00:00<00:00, 81488.40it/s]


 -> VAL: 최종 35,199 (제거됨: 817)

=== [Step 4] 최종 상세 리포트 및 저장 ===

[Step-by-Step Drop Statistics]
+----+----------------------------------+---------------------------+------------------------------+------------------------------+-------------------------+----------------------------+----------------------------+
|    | Task Name                        |   TRAIN Drop: Exact Match |   TRAIN Drop: Scaffold Match |   TRAIN Drop: Task Group Dup |   VAL Drop: Exact Match |   VAL Drop: Scaffold Match |   VAL Drop: Task Group Dup |
|  0 | bace                             |                         0 |                            0 |                            0 |                       0 |                          0 |                          0 |
+----+----------------------------------+---------------------------+------------------------------+------------------------------+-------------------------+----------------------------+----------------------------+
|  1 | chebi-20-mol2text                |

Saving the dataset (0/1 shards):   0%|          | 0/32822 [00:00<?, ? examples/s]

[Saved] Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_FINAL_CLEANED


Saving the dataset (0/44 shards):   0%|          | 0/3313489 [00:00<?, ? examples/s]

[Saved] Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_FINAL_CLEANED


Saving the dataset (0/1 shards):   0%|          | 0/35199 [00:00<?, ? examples/s]

[Saved] Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_val_FINAL_CLEANED
