In [1]:
# from datasets import load_from_disk
# from collections import Counter

# test_ds = load_from_disk('Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_512_Truncation')
# Counter(test_ds['task'])


In [2]:
"""
{
    'bace': o,
    'chebi-20-mol2text': o,
    'chebi-20-text2mol': o,
    'forward_reaction_prediction': o,
    'qm9_homo': o,
    'qm9_lumo': o,
    'qm9_homo_lumo_gap': o,
    'reagent_prediction': o,
    'retrosynthesis': o,
    'smol-property_prediction-clintox': o,
    'smol-property_prediction-bbbp': o,
    'smol-molecule_captioning': o,
    'smol-forward_synthesis': o,
    'smol-molecule_generation': o,
    'mol-name_conversion-i2s': o,
    'mol-name_conversion-s2i': o, 
    'smol-property_prediction-esol': o,
    'smol-property_prediction-hiv': o,
    'smol-property_prediction-lipo': o,
    'smol-property_prediction-sider': o,
    'smol-retrosynthesis': o})
"""

"\n{\n    'bace': o,\n    'chebi-20-mol2text': o,\n    'chebi-20-text2mol': o,\n    'forward_reaction_prediction': o,\n    'qm9_homo': o,\n    'qm9_lumo': o,\n    'qm9_homo_lumo_gap': o,\n    'reagent_prediction': o,\n    'retrosynthesis': o,\n    'smol-property_prediction-clintox': o,\n    'smol-property_prediction-bbbp': o,\n    'smol-molecule_captioning': o,\n    'smol-forward_synthesis': o,\n    'smol-molecule_generation': o,\n    'mol-name_conversion-i2s': o,\n    'mol-name_conversion-s2i': o, \n    'smol-property_prediction-esol': o,\n    'smol-property_prediction-hiv': o,\n    'smol-property_prediction-lipo': o,\n    'smol-property_prediction-sider': o,\n    'smol-retrosynthesis': o})\n"

In [None]:
from datasets import load_from_disk, concatenate_datasets
import os

# 경로 설정
BASE_DIR = "/app/Mol-LLM_Custom/dataset/filtered_dataset"
OUTPUT_DIR = "/app/Mol-LLM_Custom/dataset/merged_dataset"

# 데이터셋 이름과 경로
DATASETS = {
    "bace": f"{BASE_DIR}/bace",
    "chebi-20-mol2text": f"{BASE_DIR}/chebi-20-mol2text",
    "chebi-20-text2mol": f"{BASE_DIR}/chebi-20-text2mol",
    "qm9_homo": f"{BASE_DIR}/qm9_homo",
}

# split 종류
SPLITS = ["train", "test", "val"]

# 파일명 패턴
PREFIX = "GSAI-ML-LLaDA-8B-Instruct_string+graph_q32"
SUFFIX = "512_Truncation"

# 출력 파일명 suffix
OUTPUT_SUFFIX = "merged_bace_chebi_mol2text_chebi_text2mol_qm9_homo"


def merge_datasets_for_split(split: str):
    """특정 split에 대해 4개 데이터셋을 병합"""
    datasets_to_merge = []

    for dataset_name, dataset_path in DATASETS.items():
        # 각 데이터셋의 split별 경로 구성
        # 예: bace -> GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_512_Truncation_bace
        split_dir = f"{PREFIX}_{split}_{SUFFIX}_{dataset_name}"
        full_path = os.path.join(dataset_path, split_dir)

        if os.path.exists(full_path):
            print(f"  로딩: {dataset_name} ({split})")
            ds = load_from_disk(full_path)
            print(f"    - 샘플 수: {len(ds)}")
            datasets_to_merge.append(ds)
        else:
            print(f"  [경고] 경로 없음: {full_path}")

    if not datasets_to_merge:
        print(f"  [에러] {split}에 대해 병합할 데이터셋이 없습니다.")
        return None

    # 데이터셋 병합
    merged = concatenate_datasets(datasets_to_merge)
    print(f"  병합 완료: 총 {len(merged)} 샘플")

    return merged


def main():
    # 출력 디렉토리 생성
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("=" * 60)
    print("데이터셋 병합 시작")
    print("=" * 60)

    for split in SPLITS:
        print(f"\n[{split.upper()}] 병합 중...")

        merged_dataset = merge_datasets_for_split(split)

        if merged_dataset is not None:
            # 출력 경로 설정
            # 예: GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_512_Truncation_merged_bace_chebi_mol2text_chebi_text2mol_qm9_homo
            output_name = f"{PREFIX}_{split}_{SUFFIX}_{OUTPUT_SUFFIX}"
            output_path = os.path.join(OUTPUT_DIR, output_name)

            # 저장
            print(f"  저장 중: {output_path}")
            merged_dataset.save_to_disk(output_path)
            print(f"  저장 완료!")

    print("\n" + "=" * 60)
    print("모든 병합 완료!")
    print("=" * 60)

    # 결과 확인
    print("\n[결과 확인]")
    for split in SPLITS:
        output_name = f"{PREFIX}_{split}_{SUFFIX}_{OUTPUT_SUFFIX}"
        output_path = os.path.join(OUTPUT_DIR, output_name)
        if os.path.exists(output_path):
            ds = load_from_disk(output_path)
            print(f"  {split}: {len(ds)} 샘플")


if __name__ == "__main__":
    main()


In [1]:
#!/usr/bin/env python3
"""
SELFIES 토큰 등장 빈도 분석 스크립트

Train dataset에서 SELFIES 토큰의 등장 빈도를 분석합니다:
1. Input (Prompt) 부분의 SELFIES 토큰 등장 빈도
2. Response (Output) 부분의 SELFIES 토큰 등장 빈도
3. Total (Input + Response) 합산 등장 빈도

Usage:
    python utils/analyze_selfies_frequency.py --dataset_path <path> --selfies_dict <path> --top_k 100
"""

import argparse
import re
from collections import Counter
from pathlib import Path
from typing import Dict, List, Tuple

from datasets import load_from_disk


def load_selfies_dict(selfies_dict_path: str) -> set:
    """SELFIES 토큰 사전 로드"""
    with open(selfies_dict_path, 'r') as f:
        tokens = {line.strip() for line in f.readlines() if line.strip()}
    return tokens


def extract_selfies_tokens(text: str, selfies_dict: set) -> List[str]:
    """
    텍스트에서 SELFIES 토큰 추출

    SELFIES 토큰은 [...]  형태로 되어 있음
    예: [C], [N], [=Branch1], [C@@H], [Ring1] 등
    """
    # [...] 패턴 매칭 (SELFIES 토큰 형태)
    pattern = r'\[[^\]]+\]'
    matches = re.findall(pattern, text)

    # selfies_dict에 있는 토큰만 필터링
    selfies_tokens = [m for m in matches if m in selfies_dict]
    return selfies_tokens


def analyze_dataset(
    dataset_path: str,
    selfies_dict_path: str,
    top_k: int = 100
) -> Tuple[Counter, Counter, Counter]:
    """
    데이터셋 분석

    Returns:
        input_counter: Input(Prompt)에서의 토큰 빈도
        response_counter: Response(Output)에서의 토큰 빈도
        total_counter: 합산 빈도
    """
    print(f"Loading dataset from: {dataset_path}")
    dataset = load_from_disk(dataset_path)

    print(f"Loading SELFIES dictionary from: {selfies_dict_path}")
    selfies_dict = load_selfies_dict(selfies_dict_path)
    print(f"  - SELFIES dictionary size: {len(selfies_dict)}")

    input_counter = Counter()
    response_counter = Counter()

    print(f"\nAnalyzing {len(dataset)} samples...")

    for i, sample in enumerate(dataset):
        # Input (prompt_text) 분석
        prompt_text = sample.get('prompt_text', '')
        input_tokens = extract_selfies_tokens(prompt_text, selfies_dict)
        input_counter.update(input_tokens)

        # Response (target_text) 분석
        target_text = sample.get('target_text', '')
        response_tokens = extract_selfies_tokens(target_text, selfies_dict)
        response_counter.update(response_tokens)

        if (i + 1) % 10000 == 0:
            print(f"  Processed {i + 1}/{len(dataset)} samples...")

    # Total = Input + Response
    total_counter = input_counter + response_counter

    return input_counter, response_counter, total_counter


def print_top_k_tokens(
    counter: Counter,
    title: str,
    top_k: int = 100
) -> None:
    """상위 K개 토큰 출력"""
    total_count = sum(counter.values())

    print(f"\n{'='*80}")
    print(f"{title}")
    print(f"{'='*80}")
    print(f"Total token occurrences: {total_count:,}")
    print(f"Unique tokens: {len(counter):,}")
    print(f"\nTop {top_k} tokens:")
    print(f"{'Rank':<6} {'Token':<20} {'Count':>12} {'Ratio':>10} {'Cumulative':>12}")
    print(f"{'-'*60}")

    cumulative = 0
    for rank, (token, count) in enumerate(counter.most_common(top_k), 1):
        ratio = count / total_count * 100 if total_count > 0 else 0
        cumulative += ratio
        print(f"{rank:<6} {token:<20} {count:>12,} {ratio:>9.2f}% {cumulative:>11.2f}%")


def save_results_to_file(
    input_counter: Counter,
    response_counter: Counter,
    total_counter: Counter,
    output_path: str,
    top_k: int = 100
) -> None:
    """결과를 파일로 저장"""
    with open(output_path, 'w') as f:
        for counter, title in [
            (total_counter, "TOTAL (Input + Response)"),
            (input_counter, "INPUT (Prompt) Only"),
            (response_counter, "RESPONSE (Output) Only"),
        ]:
            total_count = sum(counter.values())

            f.write(f"\n{'='*80}\n")
            f.write(f"{title}\n")
            f.write(f"{'='*80}\n")
            f.write(f"Total token occurrences: {total_count:,}\n")
            f.write(f"Unique tokens: {len(counter):,}\n")
            f.write(f"\nTop {top_k} tokens:\n")
            f.write(f"{'Rank':<6} {'Token':<20} {'Count':>12} {'Ratio':>10} {'Cumulative':>12}\n")
            f.write(f"{'-'*60}\n")

            cumulative = 0
            for rank, (token, count) in enumerate(counter.most_common(top_k), 1):
                ratio = count / total_count * 100 if total_count > 0 else 0
                cumulative += ratio
                f.write(f"{rank:<6} {token:<20} {count:>12,} {ratio:>9.2f}% {cumulative:>11.2f}%\n")

    print(f"\nResults saved to: {output_path}")


def main():
    parser = argparse.ArgumentParser(description="Analyze SELFIES token frequency in dataset")
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_512_Truncation_merged_bace_chebi_mol2text_chebi_text2mol_qm9_homo",
        help="Path to the dataset"
    )
    parser.add_argument(
        "--selfies_dict",
        type=str,
        default="Mol-LLM_Custom/model/selfies_dict.txt",
        help="Path to SELFIES dictionary file"
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=100,
        help="Number of top tokens to display"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output file path (optional, prints to stdout if not specified)"
    )

    args = parser.parse_args()

    # 분석 실행
    input_counter, response_counter, total_counter = analyze_dataset(
        args.dataset_path,
        args.selfies_dict,
        args.top_k
    )

    # 결과 출력
    print_top_k_tokens(total_counter, "TOTAL (Input + Response)", args.top_k)
    print_top_k_tokens(input_counter, "INPUT (Prompt) Only", args.top_k)
    print_top_k_tokens(response_counter, "RESPONSE (Output) Only", args.top_k)

    # 파일로 저장 (옵션)
    if args.output:
        save_results_to_file(
            input_counter, response_counter, total_counter,
            args.output, args.top_k
        )

    # 추가 통계
    print(f"\n{'='*80}")
    print("SUMMARY STATISTICS")
    print(f"{'='*80}")

    input_total = sum(input_counter.values())
    response_total = sum(response_counter.values())
    total_total = sum(total_counter.values())

    print(f"Input tokens:    {input_total:>12,} ({input_total/total_total*100:.1f}%)")
    print(f"Response tokens: {response_total:>12,} ({response_total/total_total*100:.1f}%)")
    print(f"Total tokens:    {total_total:>12,}")

    # Input에만 있는 토큰 vs Response에만 있는 토큰
    input_only = set(input_counter.keys()) - set(response_counter.keys())
    response_only = set(response_counter.keys()) - set(input_counter.keys())
    both = set(input_counter.keys()) & set(response_counter.keys())

    print(f"\nUnique token distribution:")
    print(f"  - Input only:    {len(input_only):>6} tokens")
    print(f"  - Response only: {len(response_only):>6} tokens")
    print(f"  - Both:          {len(both):>6} tokens")


if __name__ == "__main__":
    main()


  from .autonotebook import tqdm as notebook_tqdm
usage: ipykernel_launcher.py [-h] [--dataset_path DATASET_PATH]
                             [--selfies_dict SELFIES_DICT] [--top_k TOP_K]
                             [--output OUTPUT]
ipykernel_launcher.py: error: unrecognized arguments: --f=/root/.local/share/jupyter/runtime/kernel-v39bdc3123549b22d3e3491388678e219ef98e1ef4.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [1]:
from datasets import load_from_disk
from collections import Counter

train_ds = load_from_disk('/app/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_train_512_Truncation_chebi-20-mol2text') 
test_ds = load_from_disk('/app/Mol-LLM_Custom/dataset/train_official/GSAI-ML-LLaDA-8B-Instruct_string+graph_q32_test_512_Truncation_chebi-20-mol2text') 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
Counter(train_ds['task'])

Counter({'chebi-20-mol2text': 26113})

In [3]:
Counter(test_ds['task'])

Counter({'chebi-20-mol2text': 327})