In [2]:
import os
import re
from datasets import load_from_disk
from transformers import AutoTokenizer
from collections import Counter
import multiprocessing as mp

# ==========================================
# [설정] 임계값 및 경로 설정
# ==========================================
NUM_PROC = 64           # 병렬 프로세스 수
THRESHOLD = 512         # 최대 토큰 길이 (Prompt + Target)
MODEL_ID = "GSAI-ML/LLaDA-8B-Instruct"

# 검사할 데이터셋 경로 리스트
TARGET_PATHS = [
    "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_validation_3.3M_0415",
    "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_3.3M_0415",
    "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_3.3M_0415"
]

# SELFIES 등 특수 토큰 사전 경로 (길이 계산 정확도 향상용)
SELFIES_DICT_PATH = "/home/jovyan/CHJ/Mol-LLM_Custom/model/selfies_dict.txt"

# 멀티프로세싱 설정
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ==========================================
# [1] Special Tokens 정의 (길이 계산용)
# ==========================================
# Mol-LLM Custom Tokens
CUSTOM_SPECIAL_TOKENS = [
    "<BOOLEAN>", "</BOOLEAN>", "<FLOAT>", "</FLOAT>", "<DESCRIPTION>", "</DESCRIPTION>",
    "<SELFIES>", "</SELFIES>", "<GRAPH>", "</GRAPH>", "<3D_CONFORMER>", "</3D_CONFORMER>",
    "<mol>", "<INSTRUCTION>", "</INSTRUCTION>", "|>>|", "<IUPAC>", "</IUPAC>", "<MOLFORMULA>", "</MOLFORMULA>"
]
# 숫자 토큰 추가
CUSTOM_SPECIAL_TOKENS.extend([f"<|{i}|>" for i in range(10)] + ["<|+|>", "<|-|>", "<|.|>"])

# 전역 토크나이저 변수 (Worker 프로세스용)
global_tokenizer = None

def load_selfies_tokens(path):
    if not os.path.exists(path):
        return []
    with open(path, 'r') as f:
        tokens = f.read().splitlines()
    return [t.strip() for t in tokens if t.strip()]

def init_worker():
    """
    각 워커 프로세스에서 토크나이저를 로드하고 특수 토큰을 등록합니다.
    (이렇게 해야 정확한 토큰 개수를 셀 수 있습니다.)
    """
    global global_tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            
        # 1. Custom Tags 추가
        tokens_to_add = list(set(CUSTOM_SPECIAL_TOKENS))
        
        # 2. SELFIES Dictionary 추가
        selfies_tokens = load_selfies_tokens(SELFIES_DICT_PATH)
        tokens_to_add.extend(selfies_tokens)
        
        # 3. 토크나이저에 추가 (중복 제거)
        existing_vocab = set(tokenizer.get_vocab().keys())
        final_tokens = [t for t in set(tokens_to_add) if t not in existing_vocab]
        
        if final_tokens:
            tokenizer.add_tokens(final_tokens)
            
        global_tokenizer = tokenizer
    except Exception as e:
        print(f"[Worker Error] Failed to load tokenizer: {e}")

def check_llada_format(prompt, target):
    """
    LLaDA 포맷 검증 (생성 코드 기준)
    """
    # 1. 텍스트 존재 여부 확인
    if not prompt or not target:
        return False, "Empty Text"

    # 2. LLaDA 프롬프트 포맷 검사 (수정됨)
    # 생성 코드: "<|startoftext|>", "<|start_header_id|>system...", "<|eot_id|>"
    required_tags = [
        "<|startoftext|>",  # [수정] <|begin_of_text|> -> <|startoftext|>
        "<|start_header_id|>system<|end_header_id|>",
        "<|start_header_id|>user<|end_header_id|>",
        "<|start_header_id|>assistant<|end_header_id|>"
    ]
    
    for tag in required_tags:
        if tag not in prompt:
            return False, f"Missing Tag: {tag}"
    
    # 3. Target EOS 검사 (수정됨)
    # 생성 코드: formatted_target_text = ... + "<|eot_id|>"
    if "<|eot_id|>" not in target: # [수정] <|end_of_text|> -> <|eot_id|>
        return False, "Missing Target EOS (<|eot_id|>)"

    return True, "OK"

def process_batch(batch):
    """
    배치 단위로 포맷 검사 및 길이 계산 수행
    """
    # 워커 초기화 확인 (global_tokenizer)
    if global_tokenizer is None:
        init_worker()
        
    prompts = batch['prompt_text']
    targets = batch['target_text']
    # 그래프 데이터(x)가 없으면 None으로 채움 (Optional 처리)
    xs = batch.get('x', [None] * len(prompts)) 
    
    keeps = []
    reasons = []
    
    # 1. 포맷 및 그래프 데이터 검사
    valid_indices = []
    texts_to_tokenize = []
    
    for i in range(len(prompts)):
        p = prompts[i]
        t = targets[i]
        
        # A. 포맷 검사
        is_valid_fmt, reason = check_llada_format(p, t)
        
        # B. 그래프 데이터 검사 (x가 None이면 비정상)
        # 데이터셋에 x 컬럼이 아예 없으면 위에서 None 리스트로 처리되어 여기서 걸림
        # x 컬럼이 있지만 값이 None인 경우도 처리
        if xs[i] is None and 'x' in batch: 
             is_valid_fmt = False
             reason = "Empty Graph Node Features (x)"

        if is_valid_fmt:
            valid_indices.append(i)
            texts_to_tokenize.append(p + t) # 길이 계산용 텍스트
            keeps.append(True) # 임시 True (길이 검사 전)
            reasons.append("OK")
        else:
            keeps.append(False)
            reasons.append(reason)
            
    # 2. 길이 검사 (유효한 포맷인 것만 토큰화하여 성능 최적화)
    if texts_to_tokenize:
        # padding=False로 실제 길이 측정
        tokenized = global_tokenizer(texts_to_tokenize, add_special_tokens=False)
        lengths = [len(ids) for ids in tokenized['input_ids']]
        
        for idx, length in zip(valid_indices, lengths):
            if length > THRESHOLD:
                keeps[idx] = False
                reasons[idx] = f"Length Exceeded ({length} > {THRESHOLD})"
                
    return {
        "keep": keeps,
        "drop_reason": reasons
    }

def main():
    print(f"=== LLaDA Dataset Filtering (Threshold: {THRESHOLD}) ===")
    
    # 메인 프로세스에서 토크나이저 테스트
    init_worker()
    if global_tokenizer:
        print(f"Tokenizer loaded. Vocab size: {len(global_tokenizer)}")
    else:
        print("Failed to load tokenizer in main process.")
        return

    for path in TARGET_PATHS:
        if not os.path.exists(path):
            print(f"\n[Skip] Path not found: {path}")
            continue
            
        print(f"\nProcessing: {path}")
        
        try:
            dataset = load_from_disk(path)
            original_size = len(dataset)
            
            # 1. Map: 검증 및 필터링 플래그 생성
            # (load_from_cache_file=False로 설정하여 이전의 잘못된 캐시 사용 방지)
            processed = dataset.map(
                process_batch,
                batched=True,
                num_proc=NUM_PROC,
                desc="Verifying & Calculating",
                load_from_cache_file=False
            )
            
            # 2. 통계 집계
            drop_reasons = processed['drop_reason']
            stats = Counter(drop_reasons)
            if "OK" in stats: del stats["OK"]
            
            # 3. Filter: keep=True인 것만 남김
            filtered_dataset = processed.filter(
                lambda x: x['keep'],
                num_proc=NUM_PROC,
                desc="Filtering"
            )
            
            # 4. 임시 컬럼 제거
            final_dataset = filtered_dataset.remove_columns(['keep', 'drop_reason'])
            
            filtered_size = len(final_dataset)
            dropped_count = original_size - filtered_size
            
            print(f"  ------------------------------------------------")
            print(f"  Original Size   : {original_size}")
            print(f"  Filtered Size   : {filtered_size}")
            print(f"  Dropped Samples : {dropped_count}")
            if dropped_count > 0:
                # 상위 5개 드랍 사유 출력
                print(f"  Top Drop Reasons: {stats.most_common(5)}")
            print(f"  ------------------------------------------------")
            
            # 5. 저장
            if dropped_count > 0 or original_size > 0:
                save_path = path.rstrip('/') + f"_verified_filtered_{THRESHOLD}"
                print(f"  Saving to: {save_path}")
                final_dataset.save_to_disk(save_path)
            else:
                print("  No data to save.")
                
        except Exception as e:
            print(f"  [Error] Failed to process {path}: {e}")
            import traceback
            traceback.print_exc()

if __name__ == "__main__":
    # Multiprocessing 시작 방식 설정
    try:
        mp.set_start_method('spawn', force=True)
    except RuntimeError:
        pass
    
    main()

=== LLaDA Dataset Filtering (Threshold: 512) ===
Tokenizer loaded. Vocab size: 129325

Processing: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_validation_3.3M_0415


Verifying & Calculating (num_proc=64): 100%|██████████| 70906/70906 [00:01<00:00, 48506.30 examples/s]
Filtering (num_proc=64): 100%|██████████| 70906/70906 [00:01<00:00, 41417.21 examples/s]


  ------------------------------------------------
  Original Size   : 70906
  Filtered Size   : 70565
  Dropped Samples : 341
  Top Drop Reasons: [('Length Exceeded (559 > 512)', 7), ('Length Exceeded (549 > 512)', 6), ('Length Exceeded (602 > 512)', 6), ('Length Exceeded (674 > 512)', 6), ('Length Exceeded (878 > 512)', 6)]
  ------------------------------------------------
  Saving to: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_validation_3.3M_0415_verified_filtered_512


Saving the dataset (1/1 shards): 100%|██████████| 70565/70565 [00:01<00:00, 38700.35 examples/s]



Processing: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_3.3M_0415


Verifying & Calculating (num_proc=64): 100%|██████████| 70906/70906 [00:01<00:00, 48420.20 examples/s]
Filtering (num_proc=64): 100%|██████████| 70906/70906 [00:01<00:00, 48154.94 examples/s]


  ------------------------------------------------
  Original Size   : 70906
  Filtered Size   : 70565
  Dropped Samples : 341
  Top Drop Reasons: [('Length Exceeded (559 > 512)', 7), ('Length Exceeded (549 > 512)', 6), ('Length Exceeded (602 > 512)', 6), ('Length Exceeded (674 > 512)', 6), ('Length Exceeded (878 > 512)', 6)]
  ------------------------------------------------
  Saving to: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_3.3M_0415_verified_filtered_512


Saving the dataset (1/1 shards): 100%|██████████| 70565/70565 [00:01<00:00, 37885.16 examples/s]



Processing: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_3.3M_0415


Verifying & Calculating (num_proc=64): 100%|██████████| 3450540/3450540 [00:46<00:00, 74325.36 examples/s] 
Filtering (num_proc=64): 100%|██████████| 3450540/3450540 [00:55<00:00, 61881.54 examples/s] 


  ------------------------------------------------
  Original Size   : 3450540
  Filtered Size   : 3442043
  Dropped Samples : 8497
  Top Drop Reasons: [('Length Exceeded (513 > 512)', 59), ('Length Exceeded (520 > 512)', 50), ('Length Exceeded (531 > 512)', 49), ('Length Exceeded (522 > 512)', 49), ('Length Exceeded (524 > 512)', 48)]
  ------------------------------------------------
  Saving to: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_3.3M_0415_verified_filtered_512


Saving the dataset (45/45 shards): 100%|██████████| 3442043/3442043 [01:32<00:00, 37056.00 examples/s]
