# Mol-LLM 데이터셋 검증 스크립트

In [1]:
import os
import re
import sys
from collections import defaultdict
import datasets
from datasets import load_from_disk, enable_progress_bar
import selfies as sf
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold

# Enable the Hugging Face `datasets` progress bars
enable_progress_bar()

# =============================================================================
# [Configuration]
# =============================================================================

# Tasks for which leakage must be caught strictly at the scaffold level too,
# not only as exact-molecule matches.
SCAFFOLD_SPLIT_TASKS = {
    "bace",
    "bbbp",
    "clintox",
    "esol",
    "freesolv",
    "hiv",
    "hopv",
    "lipo",
    "muv",
    "sider",
    "tox21",
    "toxcast",
}

# =============================================================================
# [Helper Functions]
# =============================================================================

def decode_and_get_info(batch):
    """
    Derive canonical SMILES, Murcko scaffold and a validity flag per molecule.

    Written as a batched ``Dataset.map`` callback: it receives a batch with
    one list per column and returns new columns of the same length.

    Each input string has every ``<...>`` tag removed, is decoded with
    ``selfies.decoder`` and parsed with RDKit. Anything that cannot be
    decoded or parsed is reported as invalid instead of raising, so a single
    bad row never aborts the whole map.

    Args:
        batch: Mapping with an ``"input_mol_string"`` entry holding the list
            of molecule strings of the batch. Empty / ``None`` entries are
            treated as invalid.

    Returns:
        dict with three lists aligned with the input order:
            ``"canon_smiles"``: canonical SMILES, ``""`` when invalid.
            ``"scaffold"``: Murcko scaffold SMILES, ``""`` when invalid or
                when RDKit returns ``None``.
            ``"valid"``: ``True`` only when both values above were computed.
    """
    canon_smiles_list = []
    scaffold_list = []
    valid_list = []

    for input_mol in batch["input_mol_string"]:
        res_smiles = ""
        res_scaffold = ""
        is_valid = False

        if input_mol:
            try:
                # Strip wrapper tags such as <TAG>...</TAG> before decoding.
                clean_str = re.sub(r"<[^>]+>", "", str(input_mol)).strip()
                smiles = sf.decoder(clean_str)
                mol = Chem.MolFromSmiles(smiles) if smiles else None

                if mol is not None:
                    canon = Chem.MolToSmiles(mol, canonical=True)
                    scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol)
                    # Commit the results only after both computations
                    # succeeded, so an invalid row never carries a
                    # half-filled SMILES.
                    res_smiles = canon
                    res_scaffold = scaffold or ""
                    is_valid = True
            except Exception:
                # Deliberate best-effort: the row stays marked invalid.
                # `Exception` (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate out of the worker.
                pass

        canon_smiles_list.append(res_smiles)
        scaffold_list.append(res_scaffold)
        valid_list.append(is_valid)

    return {
        "canon_smiles": canon_smiles_list,
        "scaffold": scaffold_list,
        "valid": valid_list,
    }

# =============================================================================
# [Main Cleaning Logic]
# =============================================================================

def clean_and_verify(train_path, test_path, save_path, num_cores=32):
    """
    Remove test-set leaks from the train set, print a report and save.

    Pipeline:
      1. Load the train and test datasets from disk.
      2. Parse every molecule of both datasets with ``decode_and_get_info``.
      3. Build per-task blocklists (canonical SMILES and scaffolds) from the
         valid test rows.
      4. Drop train rows that collide with the blocklist of their own task:
         exact SMILES match for every task, scaffold match additionally for
         tasks whose name contains an entry of ``SCAFFOLD_SPLIT_TASKS``.
         Rows that could not be parsed (``valid`` is False) are kept.
    Afterwards a per-task leak table (filtered train vs. test) is printed and
    the filtered train set, without the helper columns, is written to
    ``save_path``. The dataset is saved regardless of the report status.

    Args:
        train_path: Directory of the train dataset (``load_from_disk``).
        test_path: Directory of the test dataset (``load_from_disk``).
        save_path: Directory the decontaminated train dataset is saved to.
        num_cores: Passed as ``num_proc`` to the ``map`` and ``filter`` calls.

    Returns:
        None. If loading fails, the error is printed and the function returns
        early without writing anything.
    """
    print("=== Mol-LLM Dataset Decontamination & Verification ===")

    # 1. Load Datasets
    print("\n[Step 1/4] Loading Datasets...")
    try:
        train_ds = load_from_disk(train_path)
        test_ds = load_from_disk(test_path)
        print(f" -> Original TRAIN size: {len(train_ds):,}")
        print(f" -> TEST size:           {len(test_ds):,}")
    except Exception as e:
        print(f"[Error] Load Failed: {e}")
        return

    # 2. Parsing (Map)
    cols = ["task", "input_mol_string"]
    print("\n[Step 2/4] Parsing Molecules to Identify Leaks...")

    # load_from_cache_file=False avoids reusing stale cached map results.
    train_parsed = train_ds.select_columns(cols).map(
        decode_and_get_info, batched=True, batch_size=1000, num_proc=num_cores,
        desc="Parsing Train", load_from_cache_file=False
    )
    test_parsed = test_ds.select_columns(cols).map(
        decode_and_get_info, batched=True, batch_size=1000, num_proc=num_cores,
        desc="Parsing Test", load_from_cache_file=False
    )

    # 3. Build Blocklist from Test Set
    print("\n[Step 3/4] Building Blocklist from Test Set...")

    # Per task: exact molecules (canonical SMILES) and scaffolds seen in test.
    test_smiles_map = defaultdict(set)
    test_scaffolds_map = defaultdict(set)

    for row in test_parsed:
        if row['valid']:
            t = row['task']
            if row['canon_smiles']:
                test_smiles_map[t].add(row['canon_smiles'])
            if row['scaffold']:
                test_scaffolds_map[t].add(row['scaffold'])

    print(f" -> Blocklist created for {len(test_smiles_map)} tasks.")

    # 4. Filter Train Set (Decontamination)
    print("\n[Step 4/4] Filtering Train Set (Removing Leaks)...")

    def filter_leaks(batch):
        """Return one bool per row of the batch: True keeps it, False drops it."""
        keep_mask = []

        rows = zip(
            batch['task'], batch['canon_smiles'], batch['scaffold'], batch['valid']
        )
        for task, smi, scaf, valid in rows:
            keep = True

            # Unparsable rows are not considered leaks and are kept as-is.
            if valid:
                # `.get` instead of indexing: looking up a task that is absent
                # from the test set must not insert a new empty entry into
                # the shared defaultdict blocklists.
                if smi in test_smiles_map.get(task, ()):
                    # Rule 1: exact match against the same task's test set
                    # (applies to every task).
                    keep = False
                elif (
                    any(st in task for st in SCAFFOLD_SPLIT_TASKS)
                    and scaf in test_scaffolds_map.get(task, ())
                ):
                    # Rule 2: scaffold overlap, only for scaffold-split tasks
                    # (task name contains one of SCAFFOLD_SPLIT_TASKS).
                    keep = False
                # Rule 3: otherwise (random-split task or no overlap) keep;
                # scaffold overlap is allowed there.

            keep_mask.append(keep)

        return keep_mask

    # The parsed datasets only carry the "task"/"input_mol_string" columns
    # plus the parsed ones, so the parsed columns are attached to the original
    # train dataset and the filter is applied there; that keeps all original
    # columns in the result. (Author's note: the extra memory was judged
    # acceptable for roughly 3M rows.)

    # 4-1. Attach the parsed columns to the original train dataset
    train_ds_with_info = train_ds.add_column("canon_smiles", train_parsed["canon_smiles"])
    train_ds_with_info = train_ds_with_info.add_column("scaffold", train_parsed["scaffold"])
    train_ds_with_info = train_ds_with_info.add_column("valid", train_parsed["valid"])

    # 4-2. Run the filter
    filtered_train_ds = train_ds_with_info.filter(
        filter_leaks,
        batched=True,
        batch_size=1000,
        num_proc=num_cores,
        desc="Decontaminating"
    )

    # 4-3. Drop the temporary helper columns before saving
    final_clean_ds = filtered_train_ds.remove_columns(["canon_smiles", "scaffold", "valid"])

    removed_count = len(train_ds) - len(final_clean_ds)
    print("\n>>> Decontamination Complete!")
    print(f" -> Removed: {removed_count:,} examples")
    print(f" -> Final Train Size: {len(final_clean_ds):,}")

    # =========================================================================
    # [Verification Report Generation]
    # =========================================================================
    print("\n[Verification] Generating Final Report based on Cleaned Data...")

    # Per-task statistics: filtered train vs. original test.
    task_stats = defaultdict(lambda: {
        "train_smiles": set(), "test_smiles": set(),
        "train_scaffolds": set(), "test_scaffolds": set()
    })

    # Test side: reuse the already parsed test rows.
    for row in test_parsed:
        if row['valid']:
            t = row['task']
            task_stats[t]["test_smiles"].add(row['canon_smiles'])
            if row['scaffold']:
                task_stats[t]["test_scaffolds"].add(row['scaffold'])

    # Train side: walk the filtered dataset, restricted to the columns the
    # report needs.
    final_check_cols = filtered_train_ds.select_columns(["task", "canon_smiles", "scaffold", "valid"])

    for row in final_check_cols:
        if row['valid']:
            t = row['task']
            task_stats[t]["train_smiles"].add(row['canon_smiles'])
            if row['scaffold']:
                task_stats[t]["train_scaffolds"].add(row['scaffold'])

    # Print the table
    print("\n" + "="*90)
    print(f"{'TASK NAME':<40} | {'TYPE':<10} | {'LEAK(Exact)':<12} | {'LEAK(Scaf)':<12} | {'STATUS'}")
    print("="*90)

    for task, stats in sorted(task_stats.items()):
        if any(x in task for x in SCAFFOLD_SPLIT_TASKS):
            split_type = "SCAFFOLD"
        else:
            split_type = "RANDOM"

        leak_exact = len(stats["train_smiles"].intersection(stats["test_smiles"]))
        leak_scaf = len(stats["train_scaffolds"].intersection(stats["test_scaffolds"]))

        # Any remaining exact leak is a failure; a scaffold leak only fails
        # scaffold-split tasks (and then overrides the exact-leak label).
        status = "[PASS]"
        if leak_exact > 0:
            status = "[FAIL:Leak]"
        if split_type == "SCAFFOLD" and leak_scaf > 0:
            status = "[FAIL:Scaf]"

        scaf_display = f"{leak_scaf}"
        if split_type == "RANDOM" and leak_scaf > 0:
            scaf_display += " (OK)"

        print(f"{task:<40} | {split_type:<10} | {leak_exact:<12} | {scaf_display:<12} | {status}")

    print("="*90)

    print(f"Saving Cleaned Dataset to: {save_path}")
    final_clean_ds.save_to_disk(save_path)
    print("Done.")

if __name__ == "__main__":
    # Input datasets: the train split to clean and the test split it is
    # checked against.
    train_path = "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_3.3M_0415_deduplicate"
    test_path = "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_3.3M_0415_deduplicate"

    # Output location of the cleaned train split.
    save_path = "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_CLEANED"

    num_cores = 32

    clean_and_verify(
        train_path=train_path,
        test_path=test_path,
        save_path=save_path,
        num_cores=num_cores,
    )
# Takes more than 5 minutes.

=== Mol-LLM Dataset Decontamination & Verification ===

[Step 1/4] Loading Datasets...


Loading dataset from disk:   0%|          | 0/45 [00:00<?, ?it/s]

 -> Original TRAIN size: 3,465,790
 -> TEST size:           32,822

[Step 2/4] Parsing Molecules to Identify Leaks...


Parsing Train (num_proc=32):   0%|          | 0/3465790 [00:00<?, ? examples/s]

[14:06:03] Unusual charge on atom 29 number of radical electrons set to zero


Parsing Test (num_proc=32):   0%|          | 0/32822 [00:00<?, ? examples/s]




[Step 3/4] Building Blocklist from Test Set...
 -> Blocklist created for 17 tasks.

[Step 4/4] Filtering Train Set (Removing Leaks)...

>>> Decontamination Complete!
 -> Removed: 1,007 examples
 -> Final Train Size: 3,464,783

[Verification] Generating Final Report based on Cleaned Data...

TASK NAME                                | TYPE       | LEAK(Exact)  | LEAK(Scaf)   | STATUS
bace                                     | SCAFFOLD   | 0            | 0            | [PASS]
chebi-20-mol2text                        | RANDOM     | 0            | 856 (OK)     | [PASS]
forward_reaction_prediction              | RANDOM     | 0            | 494 (OK)     | [PASS]
qm9_homo                                 | RANDOM     | 0            | 301 (OK)     | [PASS]
qm9_homo_lumo_gap                        | RANDOM     | 0            | 311 (OK)     | [PASS]
qm9_lumo                                 | RANDOM     | 0            | 277 (OK)     | [PASS]
reagent_prediction                       | RANDOM     | 

Saving the dataset (0/45 shards):   0%|          | 0/3464783 [00:00<?, ? examples/s]

Done.


In [3]:
from datasets import load_from_disk

train_path = "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_3.3M_0415_deduplicate_CLEANED"
test_path = "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_3.3M_0415_deduplicate_CLEANED"
val_path = "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_validation_3.3M_0415_deduplicate_CLEANED"

# Load the three splits in train / test / validation order.
train_ds, test_ds, val_ds = (
    load_from_disk(split_path)
    for split_path in (train_path, test_path, val_path)
)

# Cell result: row counts as (train, test, validation).
tuple(len(split) for split in (train_ds, test_ds, val_ds))

Loading dataset from disk:   0%|          | 0/45 [00:00<?, ?it/s]

(3464783, 32822, 36016)