In [1]:
import os
import pandas as pd
from datasets import load_from_disk
from transformers import AutoTokenizer

# =============================================================================
# [설정] 경로 및 변수
# =============================================================================
# Hugging Face model id whose tokenizer is loaded in get_custom_tokenizer().
MODEL_ID = "GSAI-ML/LLaDA-8B-Instruct"
# Number of worker processes passed to Dataset.filter.
NUM_PROC = 64
# Maximum allowed prompt + target token count for a sample to be kept.
MAX_LENGTH = 512

# Input paths (final cleaned datasets produced in Step 1), keyed by split name.
INPUT_PATHS = {
    "train": "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_FINAL_CLEANED",
    "val": "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_val_FINAL_CLEANED",
    "test": "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_FINAL_CLEANED"
}

# Output directory that the filtered datasets are saved under.
SAVE_DIR = "/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/"

# Path to the SELFIES token dictionary (read one token per line).
SELFIES_DICT_PATH = "/home/jovyan/CHJ/Mol-LLM_Custom/model/selfies_dict.txt"

# =============================================================================
# [1] 토크나이저 준비 (Special Tokens 포함)
# =============================================================================
def get_custom_tokenizer():
    """Load the tokenizer for MODEL_ID and register the custom vocabulary.

    The custom vocabulary is the fixed list of markup/number tokens below,
    followed by every non-blank line of the file at SELFIES_DICT_PATH (when
    that file exists). Duplicates are removed while keeping first-seen order,
    so the sequence handed to ``add_tokens`` is the same on every run.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained`` after
        ``add_tokens`` has been called with the custom vocabulary.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # Base custom tokens (markup tags, digit/sign tokens, reaction separator).
    custom_tokens = [
        "<BOOLEAN>", "</BOOLEAN>", "<FLOAT>", "</FLOAT>", "<DESCRIPTION>", "</DESCRIPTION>",
        "<SELFIES>", "</SELFIES>", "<GRAPH>", "</GRAPH>", "<3D_CONFORMER>", "</3D_CONFORMER>",
        "<mol>", "<|0|>", "<|1|>", "<|2|>", "<|3|>", "<|4|>", "<|5|>", "<|6|>", "<|7|>",
        "<|8|>", "<|9|>", "<|+|>", "<|-|>", "<|.|>", "<INSTRUCTION>", "</INSTRUCTION>",
        "|>>|", "<IUPAC>", "</IUPAC>", "<MOLFORMULA>", "</MOLFORMULA>"
    ]

    if os.path.exists(SELFIES_DICT_PATH):
        # One token per line; blank lines are ignored.
        with open(SELFIES_DICT_PATH, 'r', encoding='utf-8') as f:
            custom_tokens.extend(line.strip() for line in f if line.strip())
    else:
        # Without the SELFIES vocabulary, token counts measured downstream will
        # not reflect the intended tokenization, so make the omission visible.
        print(f"[Warning] SELFIES dictionary not found at {SELFIES_DICT_PATH}; "
              "continuing with base custom tokens only")

    # dict.fromkeys de-duplicates while preserving insertion order; a set would
    # iterate in a hash-dependent order that can differ between interpreter runs.
    # NOTE(review): if another script builds the same vocabulary, confirm it uses
    # the same ordering so that the added tokens receive matching ids.
    tokenizer.add_tokens(list(dict.fromkeys(custom_tokens)))
    return tokenizer

# Build the tokenizer once at module level; filter_by_length reads it as a global.
tokenizer = get_custom_tokenizer()

# =============================================================================
# [2] 필터링 함수 정의
# =============================================================================
def filter_by_length(batch):
    """Batched filter predicate: keep samples that fit within MAX_LENGTH tokens.

    Args:
        batch: Mapping with 'prompt_text' and 'target_text' entries, each passed
            as-is to the module-level tokenizer (used here in batched mode).

    Returns:
        A list of booleans, one per sample, that is True when the prompt token
        count plus the target token count is at most MAX_LENGTH.
    """
    # No truncation argument is passed, so the full token count is measured
    # (assumes the tokenizer's default is not to truncate - TODO confirm).
    # add_special_tokens=False: only the text's own tokens are counted.
    prompt_ids = tokenizer(batch['prompt_text'], add_special_tokens=False)['input_ids']
    target_ids = tokenizer(batch['target_text'], add_special_tokens=False)['input_ids']

    return [
        len(prompt) + len(target) <= MAX_LENGTH
        for prompt, target in zip(prompt_ids, target_ids)
    ]

# =============================================================================
# [3] 메인 실행 루프
# =============================================================================
def run_truncation_pipeline():
    """Filter every split in INPUT_PATHS by token length and save the result.

    For each (split, path) pair:
      1. Skip the split (with a message) when the path does not exist.
      2. Load the dataset from disk and drop samples rejected by
         filter_by_length, using NUM_PROC worker processes.
      3. Print the original, kept, and dropped counts.
      4. Save the filtered dataset under SAVE_DIR.

    Despite the "Truncation" suffix in the output name, over-length samples
    are removed entirely; nothing is truncated.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    for split, path in INPUT_PATHS.items():
        if not os.path.exists(path):
            print(f"[Skip] {split} dataset not found at {path}")
            continue

        print(f"\n>>> Processing {split.upper()} Split...")
        ds = load_from_disk(path)
        original_count = len(ds)

        # batched=True lets the tokenizer process many samples per call.
        ds_filtered = ds.filter(
            filter_by_length,
            batched=True,
            batch_size=1000,
            num_proc=NUM_PROC,
            desc=f"Filtering {split} (> {MAX_LENGTH} tokens)"
        )

        filtered_count = len(ds_filtered)
        dropped_count = original_count - filtered_count
        # Guard against an empty split, which would otherwise divide by zero.
        dropped_pct = (dropped_count / original_count) * 100 if original_count else 0.0

        print(f" - Original: {original_count:,}")
        print(f" - Filtered: {filtered_count:,}")
        print(f" - Dropped:  {dropped_count:,} ({dropped_pct:.2f}%)")

        # Build the save path from MAX_LENGTH so the name always matches the
        # limit actually applied, e.g.
        # GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_512_Truncation
        save_name = f"GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_{split}_{MAX_LENGTH}_Truncation"
        save_full_path = os.path.join(SAVE_DIR, save_name)

        ds_filtered.save_to_disk(save_full_path)
        print(f"[Success] Saved to: {save_full_path}")

# Entry point: run the filtering pipeline when this code executes as the main module.
if __name__ == "__main__":
    run_truncation_pipeline()

  import pynvml  # type: ignore[import]



>>> Processing TRAIN Split...


Loading dataset from disk:   0%|          | 0/44 [00:00<?, ?it/s]

Filtering train (> 512 tokens) (num_proc=64):   0%|          | 0/3313489 [00:00<?, ? examples/s]

 - Original: 3,313,489
 - Filtered: 3,303,537
 - Dropped:  9,952 (0.30%)


Saving the dataset (0/43 shards):   0%|          | 0/3303537 [00:00<?, ? examples/s]

[Success] Saved to: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_512_Truncation

>>> Processing VAL Split...


Filtering val (> 512 tokens) (num_proc=64):   0%|          | 0/35199 [00:00<?, ? examples/s]

 - Original: 35,199
 - Filtered: 35,042
 - Dropped:  157 (0.45%)


Saving the dataset (0/1 shards):   0%|          | 0/35042 [00:00<?, ? examples/s]

[Success] Saved to: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_val_512_Truncation

>>> Processing TEST Split...


Filtering test (> 512 tokens) (num_proc=64):   0%|          | 0/32822 [00:00<?, ? examples/s]

 - Original: 32,822
 - Filtered: 32,595
 - Dropped:  227 (0.69%)


Saving the dataset (0/1 shards):   0%|          | 0/32595 [00:00<?, ? examples/s]

[Success] Saved to: /home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_512_Truncation


In [4]:
from datasets import load_from_disk
from collections import Counter
# Sanity check: reload the length-filtered train split written by the pipeline
# and count how many samples remain for each value of the 'task' column.
train_ds = load_from_disk('/home/jovyan/CHJ/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_512_Truncation')
Counter(train_ds['task'])


Loading dataset from disk:   0%|          | 0/43 [00:00<?, ?it/s]

Counter({'smol-forward_synthesis': 957064,
         'smol-retrosynthesis': 858460,
         'smol-name_conversion-i2s': 296011,
         'smol-name_conversion-s2i': 295989,
         'reagent_prediction': 121846,
         'forward_reaction_prediction': 121795,
         'retrosynthesis': 120557,
         'qm9_lumo': 117708,
         'qm9_homo': 117660,
         'qm9_homo_lumo_gap': 117539,
         'smol-molecule_captioning': 39195,
         'smol-property_prediction-hiv': 32864,
         'chebi-20-mol2text': 26113,
         'chebi-20-text2mol': 25887,
         'smol-molecule_generation': 24843,
         'smol-property_prediction-sider': 21986,
         'smol-property_prediction-lipo': 3341,
         'smol-property_prediction-bbbp': 1450,
         'bace': 1210,
         'smol-property_prediction-clintox': 1131,
         'smol-property_prediction-esol': 888})