In [1]:
import os
import re
import torch
from datasets import load_from_disk
from transformers import AutoTokenizer
from tqdm import tqdm
from collections import Counter

# ==========================================
# [Config] Thresholds and paths
# ==========================================
# Number of parallel worker processes (tune to the available CPU cores)
NUM_PROC = 32

# Maximum token length to keep (Prompt + Target); longer samples are dropped
THRESHOLD = 512 

# List of dataset paths to check
TARGET_PATHS = [
    "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_validation_3.3M_0415",
    "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_3.3M_0415",
    "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_3.3M_0415"
]

# Tokenizer path (handed to AutoTokenizer.from_pretrained in main)
TOKENIZER_PATH = "GSAI-ML/LLaDA-8B-Instruct" 

# Prevent tokenizer-internal parallelism from clashing with multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def check_llada_format(example):
    """
    Validate one sample against the LLaDA instruction layout.

    The checks run in a fixed order and the first failure wins:
    non-empty prompt/target text, required header tags in the prompt,
    the end-of-text marker in the target, and presence of the graph
    fields ``x`` and ``edge_index``.

    Args:
        example: Mapping for a single sample; read via ``.get`` for the
            keys ``prompt_text``, ``target_text``, ``x`` and ``edge_index``.

    Returns:
        A ``(is_valid, reason)`` tuple. ``reason`` is ``"OK"`` on success,
        otherwise a short label describing the first failed check.
    """
    prompt_text = example.get('prompt_text', '')
    target_text = example.get('target_text', '')

    # Both texts have to be present and non-empty.
    if not (prompt_text and target_text):
        return False, "Empty Text"

    # Header tags the prompt must contain, reported in this order.
    header_tags = (
        "<|begin_of_text|>",
        "<|start_header_id|>system<|end_header_id|>",
        "<|start_header_id|>user<|end_header_id|>",
        "<|start_header_id|>assistant<|end_header_id|>",
    )
    first_missing = next(
        (tag for tag in header_tags if tag not in prompt_text),
        None,
    )
    if first_missing is not None:
        return False, f"Missing Tag: {first_missing}"

    # The target must carry the end-of-text marker.
    if "<|end_of_text|>" not in target_text:
        return False, "Missing Target EOS"

    # Graph fields only need to be present (not None); an empty value
    # is still accepted here.
    graph_fields = (
        ('x', "Empty Graph Node Features (x)"),
        ('edge_index', "Empty Graph Edge Index"),
    )
    for field_name, failure in graph_fields:
        if example.get(field_name) is None:
            return False, failure

    return True, "OK"

def process_batch_verify(batch, tokenizer, threshold):
    """
    Decide, for every row of a batch, whether it should be kept.

    Each row is first run through ``check_llada_format``. Rows that pass
    are then tokenized together (prompt + target concatenated, no special
    tokens added) and rejected if their token count exceeds ``threshold``.

    Args:
        batch: Column-oriented mapping (column name -> list of values);
            must contain ``prompt_text`` and ``target_text``.
        tokenizer: Callable that accepts a list of strings and returns a
            mapping with an ``input_ids`` entry (one id list per string).
        threshold: Maximum allowed token count; rows strictly above it
            are dropped.

    Returns:
        Dict with two lists aligned to the batch rows: ``keep`` (bool) and
        ``drop_reason`` (``"OK"`` or the reason the row was rejected).
    """
    prompt_col = batch['prompt_text']
    target_col = batch['target_text']

    keeps = []
    reasons = []
    pending_texts = []  # concatenated texts awaiting the length check
    pending_rows = []   # batch row index for each pending text

    # Pass 1: format check, row by row.
    for row in range(len(prompt_col)):
        # Rebuild a single-sample dict from the column lists.
        sample = {column: values[row] for column, values in batch.items()}
        passed, verdict = check_llada_format(sample)

        if not passed:
            keeps.append(False)
            reasons.append(verdict)
            continue

        # Provisionally kept; the length check below may still reject it.
        keeps.append(True)
        reasons.append("OK")
        pending_texts.append(prompt_col[row] + target_col[row])
        pending_rows.append(row)

    # Pass 2: length check, tokenizing all surviving rows in one call.
    if pending_texts:
        encoded = tokenizer(pending_texts, add_special_tokens=False)
        token_counts = [len(ids) for ids in encoded['input_ids']]

        for row, count in zip(pending_rows, token_counts):
            if count > threshold:
                keeps[row] = False
                reasons[row] = "Length Exceeded"

    return {"keep": keeps, "drop_reason": reasons}

def filter_dataset(dataset, tokenizer, threshold):
    """
    Validate a dataset in parallel and return only the rows that pass.

    Args:
        dataset: Dataset object exposing ``map``, ``filter``, column access
            by name and ``remove_columns``.
        tokenizer: Tokenizer forwarded to ``process_batch_verify``.
        threshold: Maximum token length forwarded to ``process_batch_verify``.

    Returns:
        ``(final_dataset, stats)`` where ``final_dataset`` has the original
        columns and only kept rows, and ``stats`` is a ``Counter`` of drop
        reasons with the ``"OK"`` entry removed.
    """
    print(f"  > Verifying and Checking Length (Num Proc: {NUM_PROC})...")

    # Step 1: annotate every row with 'keep' and 'drop_reason'.
    annotated = dataset.map(
        process_batch_verify,
        batched=True,
        num_proc=NUM_PROC,
        fn_kwargs=dict(tokenizer=tokenizer, threshold=threshold),
        desc=" analyzing",
    )

    # Step 2: tally the reasons; rows that passed are not a drop reason.
    stats = Counter(annotated['drop_reason'])
    stats.pop("OK", None)

    # Step 3: keep only the rows flagged keep=True.
    survivors = annotated.filter(
        lambda row: row['keep'],
        num_proc=NUM_PROC,
        desc=" filtering",
    )

    # Step 4: strip the helper columns so the original schema is restored.
    return survivors.remove_columns(['keep', 'drop_reason']), stats

def main():
    """
    Verify and length-filter every dataset listed in ``TARGET_PATHS``.

    Loads the tokenizer once, then for each existing path loads the
    dataset, filters it, prints a summary, and saves the filtered copy
    next to the original (suffix ``_verified_filtered_<THRESHOLD>``) when
    at least one sample was dropped. A failure on one dataset is reported
    with a traceback and does not stop the remaining ones.
    """
    print(f"=== LLaDA Dataset Verification & Filtering (Threshold: {THRESHOLD}, Process: {NUM_PROC}) ===")

    print(f"Loading Tokenizer from: {TOKENIZER_PATH}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
    except Exception as load_error:
        # Nothing can be checked without a tokenizer; bail out.
        print(f"[Error] Failed to load tokenizer: {load_error}")
        return

    divider = "  ------------------------------------------------"

    for dataset_path in TARGET_PATHS:
        if not os.path.exists(dataset_path):
            print(f"\n[Skipping] Path not found: {dataset_path}")
            continue

        print(f"\nProcessing Dataset: {dataset_path}")

        try:
            source = load_from_disk(dataset_path)
            size_before = len(source)

            # Run validation + length filtering.
            kept, drop_stats = filter_dataset(source, tokenizer, THRESHOLD)

            size_after = len(kept)
            n_dropped = size_before - size_after
            anything_dropped = n_dropped > 0

            report = [
                divider,
                f"  Original Size   : {size_before}",
                f"  Filtered Size   : {size_after}",
                f"  Dropped Samples : {n_dropped}",
            ]
            if anything_dropped:
                report.append(f"  Drop Reasons    : {dict(drop_stats)}")
            report.append(divider)
            for line in report:
                print(line)

            # Only write a new copy when filtering actually removed rows.
            if not anything_dropped:
                print("  No samples dropped. Skipping save.")
            else:
                save_path = f"{dataset_path.rstrip('/')}_verified_filtered_{THRESHOLD}"
                print(f"  Saving filtered dataset to: {save_path}")
                kept.save_to_disk(save_path)

        except Exception as failure:
            # Report and move on so one bad dataset does not abort the run.
            print(f"  [Fatal Error] Failed to process {dataset_path}: {failure}")
            import traceback
            traceback.print_exc()

# Run the verification pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()

  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


=== LLaDA Dataset Verification & Filtering (Threshold: 512, Process: 32) ===
Loading Tokenizer from: GSAI-ML/LLaDA-8B-Instruct

Processing Dataset: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_validation_3.3M_0415
  > Verifying and Checking Length (Num Proc: 32)...


 analyzing (num_proc=32): 100%|██████████| 79884/79884 [00:06<00:00, 13107.89 examples/s]
 filtering (num_proc=32): 100%|██████████| 79884/79884 [00:02<00:00, 33910.09 examples/s]


  ------------------------------------------------
  Original Size   : 79884
  Filtered Size   : 76164
  Dropped Samples : 3720
  Drop Reasons    : {'Length Exceeded': 3720}
  ------------------------------------------------
  Saving filtered dataset to: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_validation_3.3M_0415_verified_filtered_512


Saving the dataset (1/1 shards): 100%|██████████| 76164/76164 [00:01<00:00, 38670.05 examples/s]



Processing Dataset: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_3.3M_0415
  > Verifying and Checking Length (Num Proc: 32)...


 analyzing (num_proc=32): 100%|██████████| 79884/79884 [00:06<00:00, 12711.72 examples/s]
 filtering (num_proc=32): 100%|██████████| 79884/79884 [00:02<00:00, 32842.58 examples/s]


  ------------------------------------------------
  Original Size   : 79884
  Filtered Size   : 76164
  Dropped Samples : 3720
  Drop Reasons    : {'Length Exceeded': 3720}
  ------------------------------------------------
  Saving filtered dataset to: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_3.3M_0415_verified_filtered_512


Saving the dataset (1/1 shards): 100%|██████████| 76164/76164 [00:01<00:00, 39963.77 examples/s]



Processing Dataset: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_3.3M_0415
  > Verifying and Checking Length (Num Proc: 32)...


 analyzing (num_proc=32): 100%|██████████| 3450540/3450540 [02:37<00:00, 21886.27 examples/s]
 filtering (num_proc=32): 100%|██████████| 3450540/3450540 [01:34<00:00, 36338.53 examples/s]


  ------------------------------------------------
  Original Size   : 3450540
  Filtered Size   : 3275725
  Dropped Samples : 174815
  Drop Reasons    : {'Length Exceeded': 174815}
  ------------------------------------------------
  Saving filtered dataset to: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_3.3M_0415_verified_filtered_512


Saving the dataset (43/43 shards): 100%|██████████| 3275725/3275725 [01:36<00:00, 34087.00 examples/s]
