In [None]:
import os
import json
import gzip
from pathlib import Path
from functools import partial
from typing import List, Dict, Any

import fiddle as fdl
import lightning.pytorch as p
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset
import torch
from transformers import AutoTokenizer

from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform


In [10]:
class JapaneseWikiDataset(Dataset):
    """日本語WikipediaデータセットのためのDatasetクラス"""

    def __init__(self, data_files: List[str], tokenizer, seq_length: int = 512,
                 instruction_template: str = None):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.examples = []
        self.instruction_template = instruction_template or self._default_instruction_template()

        # データファイルの読み込み
        for file_path in data_files:
            self._load_file(file_path)

        print(f"Loaded {len(self.examples)} examples from {len(data_files)} files")

    def _default_instruction_template(self):
        """デフォルトのインストラクションテンプレート"""
        return """以下の文章を要約してください。

文章: {text}

要約:"""

    def _load_file(self, file_path: str):
        """JSONLファイルの読み込み（gzip対応）"""
        if file_path.endswith('.gz'):
            with gzip.open(file_path, 'rt', encoding='utf-8') as f:
                for line in f:
                    data = json.loads(line.strip())
                    if 'text' in data and data['text'].strip():
                        self.examples.append(data['text'])
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    data = json.loads(line.strip())
                    if 'text' in data and data['text'].strip():
                        self.examples.append(data['text'])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        text = self.examples[idx]

        # テキストを短く切って、要約タスクとして使用
        # 実際のSFTでは、より適切な質問-回答ペアを使用することを推奨
        if len(text) > 200:
            input_text = text[:200]
            target_text = text[:50]  # 簡単な要約として最初の50文字を使用

            # インストラクション形式に変換
            full_text = self.instruction_template.format(text=input_text) + " " + target_text
        else:
            full_text = text

        # トークナイズ
        encoded = self.tokenizer(
            full_text,
            max_length=self.seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoded['input_ids'].squeeze(),
            'attention_mask': encoded['attention_mask'].squeeze(),
            'labels': encoded['input_ids'].squeeze()  # SFTでは入力と同じ
        }

    def collate_fn(self, batch):
        """バッチのコレート関数"""
        input_ids = torch.stack([item['input_ids'] for item in batch])
        attention_mask = torch.stack([item['attention_mask'] for item in batch])
        labels = torch.stack([item['labels'] for item in batch])

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }


In [11]:
class JapaneseWikiDataModule(pl.LightningDataModule):
    """日本語WikipediaデータのためのLightningDataModule"""

    def __init__(
        self,
        tokenizer,
        data_dir: str = "./data/ja_wiki",
        seq_length: int = 512,
        micro_batch_size: int = 1,
        global_batch_size: int = 2,
        num_workers: int = 0,
        pin_memory: bool = True,
        persistent_workers: bool = False,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_dir = Path(data_dir)
        self.seq_length = seq_length
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers

    def setup(self, stage: str = None):
        """データセットのセットアップ"""
        # トレーニングデータファイルの取得
        train_files = sorted(list(self.data_dir.glob("train_*.jsonl*")))
        val_files = sorted(list(self.data_dir.glob("validation_*.jsonl*")))

        if not train_files:
            raise ValueError(f"No training files found in {self.data_dir}")

        # データセットの作成
        self.train_dataset = JapaneseWikiDataset(
            [str(f) for f in train_files],
            self.tokenizer,
            self.seq_length
        )

        if val_files:
            self.val_dataset = JapaneseWikiDataset(
                [str(f) for f in val_files],
                self.tokenizer,
                self.seq_length
            )
        else:
            # 検証データがない場合は、トレーニングデータの一部を使用
            print("No validation files found, using 10% of training data for validation")
            train_size = int(0.9 * len(self.train_dataset))
            val_size = len(self.train_dataset) - train_size
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(
                self.train_dataset, [train_size, val_size]
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.micro_batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            collate_fn=self.train_dataset.collate_fn if hasattr(self.train_dataset, 'collate_fn') else None,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.micro_batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            collate_fn=self.val_dataset.collate_fn if hasattr(self.val_dataset, 'collate_fn') else None,
            shuffle=False,
        )


In [12]:
def japanese_wiki_data(tokenizer, data_dir="./data/ja_wiki", mbs=1, gbs=2) -> pl.LightningDataModule:
    """日本語Wikipediaデータモジュールをインスタンス化して返す

    Args:
        tokenizer (AutoTokenizer): 使用するトークナイザー
        data_dir (str): ja_wikiデータのディレクトリパス
        mbs (int): マイクロバッチサイズ
        gbs (int): グローバルバッチサイズ

    Returns:
        pl.LightningDataModule: トレーニング用のデータセット
    """
    return JapaneseWikiDataModule(
        tokenizer=tokenizer,
        data_dir=data_dir,
        seq_length=512,
        micro_batch_size=mbs,
        global_batch_size=gbs,
        num_workers=0,
    )


In [13]:
model_name = "Qwen/Qwen2.5-0.5B"  # 多言語対応モデル

strategy = ""  # 分散トレーニング戦略
max_steps = 100  # トレーニングステップ数
accelerator = "gpu"
num_devices = 1  # GPU数
wandb_name = None  # wandb実験名
use_torch_jit = False  # torch jit有効化
ckpt_folder = "/opt/checkpoints/japanese_wiki_experiments/"  # チェックポイント保存パス
data_dir = "/workspace/data/ja_wiki"  # ja_wikiデータのパス


In [14]:
def make_strategy(strategy, model, devices, num_nodes, adapter_only=False):
    """
    分散トレーニング戦略を作成
    """
    if strategy == 'ddp':
        return pl.strategies.DDPStrategy(
            checkpoint_io=model.make_checkpoint_io(adapter_only=adapter_only),
        )
    elif strategy == 'fsdp2':
        return nl.FSDP2Strategy(
            data_parallel_size=devices * num_nodes,
            tensor_parallel_size=1,
            checkpoint_io=model.make_checkpoint_io(adapter_only=adapter_only),
        )
    else:
        return pl.strategies.SingleDeviceStrategy(
            device='cuda:0',
            checkpoint_io=model.make_checkpoint_io(adapter_only=adapter_only),
        )



In [None]:
def run_japanese_sft():
    """日本語WikipediaデータでSFTを実行"""

    # WandBロガーの設定
    wandb = WandbLogger(
        project="nemo_japanese_wiki",
        name=wandb_name,
    ) if wandb_name is not None else None

    # コールバックの設定
    callbacks = []
    if use_torch_jit:
        jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': False}, use_thunder=False)
        callbacks = [JitTransform(jit_config)]

    callbacks.append(
        nl.ModelCheckpoint(
            every_n_train_steps=max_steps // 2,
            dirpath=ckpt_folder,
        )
    )

    # モデルの作成
    model = llm.HFAutoModelForCausalLM(model_name=model_name)

    # 戦略の作成
    training_strategy = make_strategy(strategy, model, num_devices, 1)

    # トレーナーの設定
    trainer = nl.Trainer(
        devices=num_devices,
        max_steps=max_steps,
        accelerator=accelerator,
        strategy=training_strategy,
        log_every_n_steps=1,
        limit_val_batches=0.1,  # 検証データの一部のみ使用
        num_sanity_val_steps=2,
        accumulate_grad_batches=1,
        gradient_clip_val=1.0,
        use_distributed_sampler=False,
        logger=wandb,
        callbacks=callbacks,
        precision="bf16",
    )

    # トークナイザーの作成
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # SFTの実行
    llm.api.finetune(
        model=model,
        data=japanese_wiki_data(
            tokenizer,
            data_dir=data_dir,
            gbs=1
        ),
        trainer=trainer,
        optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
        peft=None,  # PEFTを使用する場合はここでLoRAなどを設定
        log=None,
    )

    print(f"Training completed! Checkpoints saved to: {ckpt_folder}")
    return ckpt_folder



In [18]:
if __name__ == "__main__":
    # データディレクトリの確認
    if not os.path.exists(data_dir):
        print(f"Data directory {data_dir} not found!")
        print("Please download ja_wiki data using the commands from the blog post.")
    else:
        # SFTの実行
        checkpoint_dir = run_japanese_sft()

        # # 推論テスト（オプション）
        # print("\n=== Inference Test ===")
        # from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

        # new_checkpoint = Path(checkpoint_dir) / f"default--val_loss=0.0000-epoch=1-step={max_steps}-last/hf_weights"

        # if new_checkpoint.exists():
        #     pipe = pipeline(
        #         "text-generation",
        #         model=AutoModelForCausalLM.from_pretrained(new_checkpoint),
        #         tokenizer=AutoTokenizer.from_pretrained(new_checkpoint),
        #         torch_dtype=torch.bfloat16,
        #         device_map="auto",
        #     )

        #     # 日本語でのテスト
        #     result = pipe("日本の文化について教えてください。")
        #     print(result)


[NeMo I 2025-07-14 07:49:56 nemo_logging:393] use_linear_ce_loss: True


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


[NeMo I 2025-07-14 07:49:57 nemo_logging:393] Experiments will be logged at /workspace/src/nemo_experiments/default/2025-07-14_07-47-16


[NeMo W 2025-07-14 07:49:57 nemo_logging:405] "update_logger_directory" is True. Overwriting tensorboard logger "save_dir" to /workspace/src/nemo_experiments


Loaded 1363395 examples from 14 files
Loaded 1134 examples from 1 files
[NeMo I 2025-07-14 07:50:24 nemo_logging:393] Bad message (TypeError('not all arguments converted during string formatting')): {'name': 'nemo_logger', 'msg': 'Configuring model with attn_implementation:', 'args': ('sdpa',), 'levelname': 'INFO', 'levelno': 20, 'pathname': '/opt/NeMo/nemo/utils/nemo_logging.py', 'filename': 'nemo_logging.py', 'module': 'nemo_logging', 'exc_info': None, 'exc_text': None, 'stack_info': None, 'lineno': 393, 'funcName': 'info', 'created': 1752479424.0764952, 'msecs': 76.0, 'relativeCreated': 310648.4651565552, 'thread': 140005550290240, 'threadName': 'MainThread', 'processName': 'MainProcess', 'process': 248363, 'taskName': 'Task-2'}


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | model | Qwen2ForCausalLM | 494 M  | train
---------------------------------------------------
494 M     Trainable params
0         Non-trainable params
494 M     Total params
1,976.131 Total estimated model params size (MB)
319       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

TypeError: 'AutoTokenizer' object is not callable