※以下の記載コマンドは本番環境向けではないので注意してください

* megablocksによって作られたMoEモデルに対してtrl.DPOTrainerで事後学習ができるか確認する
* 最低限のモデルができるところまで確認中（現在lossが下がっていない状態）

```shell
srun --job-name=kuma-eval --partition g2 --nodes=1 --gpus-per-node=8 --time=06:00:00 --mem=128GB --pty bash -i
```

上記コマンドで計算ノードを確保して作業する。`--time` は使用時間に合わせて調整する。

作業ノードでは以下を実行

```shell
conda activate venv39

jupyter-lab --no-browser --port 8888 --ip $(hostname -i)
```

※VSCodeからnotebook作業する場合は、[VSCodeがサーバーリソースを食い潰す問題への対策](https://blog.masuyoshi.com/%E3%80%90vscode%E4%BD%BF%E7%94%A8%E8%80%85%E6%B3%A8%E6%84%8F%E3%80%91%E3%82%B5%E3%83%BC%E3%83%90%E3%83%BC%E3%83%AA%E3%82%BD%E3%83%BC%E3%82%B9%E9%A3%9F%E3%81%84%E6%95%A3%E3%82%89%E3%81%8B%E3%81%99/)を必ず行うこと

In [None]:
# Show the current working directory; relative paths used later
# (./.cache, ./output-dpo, ./model-dpo) resolve against it.
%pwd

In [None]:
# Show the GPUs visible to this job and their current utilisation.
!nvidia-smi

In [None]:
# Install dependencies (only needed on the first run).
# NOTE(review): versions are not pinned, so trl/transformers APIs may drift between runs.
%pip install ipywidgets bitsandbytes peft pyzmq transformers trl datasets sentencepiece accelerate wandb huggingface_hub argilla python-dotenv 

In [None]:
# Load variables from a .env file into the environment; later cells read
# HF_TOKEN, WANDB_PROJECT/WANDB_ENTITY and RG_API_URL/RG_API_KEY/RG_WORKSPACE via os.getenv.
%load_ext dotenv
%dotenv

In [None]:
# Cache directories must always live on team storage.
# TEAM_DATASETS_CACHE_DIR="/persistentshare/storage/team_kumagai/datasets"
# NOTE(review): currently overridden with a local ./.cache directory. Despite the
# name, this is also passed as cache_dir for the config, model and tokenizer downloads.
TEAM_DATASETS_CACHE_DIR="./.cache"

In [None]:
import json
import os
import sys
from datetime import datetime
import logging
import random

import numpy as np
import pandas as pd

import wandb
from huggingface_hub import login, whoami

import argilla as rg

from datasets import load_dataset, Dataset

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    AutoModel,
    set_seed,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    LlamaTokenizer,
    TrainerCallback
)

from transformers import TrainingArguments
from trl import DPOTrainer

import torch.distributed as dist
import multiprocessing as mp

import torch
import transformers

from typing import Any

# Route INFO-and-above records from the root logger to the notebook's stdout.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Re-running this cell used to attach one more StreamHandler every time, so
# each record was printed several times. Reuse a named handler instead.
_NOTEBOOK_HANDLER_NAME = "dpo-notebook-stdout"
handler = next(
    (h for h in logger.handlers if h.get_name() == _NOTEBOOK_HANDLER_NAME),
    None,
)
if handler is None:
    handler = logging.StreamHandler(sys.stdout)
    handler.set_name(_NOTEBOOK_HANDLER_NAME)
    logger.addHandler(handler)
handler.setLevel(logging.INFO)

logger.info("start logging...!")

In [None]:
# Authenticate against the Hugging Face Hub with the token taken from the environment.
login(token=os.environ.get('HF_TOKEN'))

In [None]:
# Open the W&B run that the Trainer's "wandb" reporter logs into;
# entity and project come from the environment.
run = wandb.init(
    entity=os.getenv('WANDB_ENTITY'),
    project=os.getenv('WANDB_PROJECT'),
)

In [None]:
# Connect the Argilla client; server URL, API key and workspace are read from the environment.
rg.init(
    workspace=os.getenv("RG_WORKSPACE"),
    api_url=os.getenv("RG_API_URL"),
    api_key=os.getenv("RG_API_KEY"),
)

In [None]:
def print_trainable_parameters(model):
    """
    Print the number of trainable parameters in ``model``.

    Sums ``param.numel()`` over ``model.named_parameters()`` and prints the
    trainable element count, the total element count and the trainable
    percentage. A model without parameters reports 0.0% instead of raising
    ZeroDivisionError.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        n_elements = param.numel()
        all_param += n_elements
        if param.requires_grad:
            trainable_params += n_elements
    # Guard the percentage: an empty model previously divided by zero.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )


def find_all_linear_names(peft_model, int4=False, int8=False):
    """Return the sorted, de-duplicated short names of the model's linear layers.

    The layer class is ``bitsandbytes`` ``Linear4bit`` when ``int4`` is set,
    ``Linear8bitLt`` when only ``int8`` is set, and ``torch.nn.Linear``
    otherwise. Only the last dotted component of each module path is kept
    (``"model.layers.0.q_proj"`` -> ``"q_proj"``), and any module whose path
    contains ``lm_head`` or ``output_layer`` (the output projection) is
    skipped. Reference: qlora.
    """
    if int4:
        import bitsandbytes as bnb
        linear_cls = bnb.nn.Linear4bit
    elif int8:
        import bitsandbytes as bnb
        linear_cls = bnb.nn.Linear8bitLt
    else:
        linear_cls = torch.nn.Linear

    skipped = ("lm_head", "output_layer")
    short_names = {
        qualified.rsplit(".", 1)[-1]
        for qualified, module in peft_model.named_modules()
        if isinstance(module, linear_cls)
        and not any(tag in qualified for tag in skipped)
    }
    return sorted(short_names)


def return_prompt_and_responses(examples) -> dict[str, list[str]]:
    r"""Convert a batch of Anthropic/hh-rlhf rows into DPO prompt/response triples.

    Each hh-rlhf ``chosen``/``rejected`` text is a whole dialogue
    ("\n\nHuman: ...\n\nAssistant: ...", possibly several turns) whose final
    Assistant turn is the preferred / dispreferred reply. DPOTrainer joins
    ``prompt`` with ``chosen``/``rejected``. Those fields must therefore hold
    only the reply; previously they held the whole dialogue, repeating the
    question after the prompt.

    The batch is converted to a dictionary with the following structure:
    {
        'prompt': list[str],
        'chosen': list[str],
        'rejected': list[str],
    }

    Prompts are structured as follows:
      "Question: " + <dialogue before the final Assistant turn> + "\n\nAnswer: "
    For a single-turn dialogue that is just the Human message. For multi-turn
    dialogues the earlier turns are kept so the reply keeps its context.

    Raises:
        ValueError: if a ``chosen`` text contains no "\n\nAssistant:" turn.
    """
    # see: https://github.com/ZHZisZZ/emulated-disalignment/blob/2f8e441fdf9117490c36d9f54adf536c23b6eb69/utils/utils.py#L80
    human_marker = "\n\nHuman: "
    assistant_marker = "\n\nAssistant:"

    def _reply_after(text, prefix):
        # chosen/rejected normally share the dialogue prefix; if the rejected
        # text diverges earlier, fall back to its own last Assistant turn.
        if text.startswith(prefix):
            return text[len(prefix):].strip()
        cut = text.rfind(assistant_marker)
        if cut == -1:
            return text.strip()
        return text[cut + len(assistant_marker):].strip()

    prompts, chosen, rejected = [], [], []
    for chosen_text, rejected_text in zip(examples["chosen"], examples["rejected"]):
        cut = chosen_text.rfind(assistant_marker)
        if cut == -1:
            raise ValueError(
                f"chosen text has no {assistant_marker!r} turn: {chosen_text[:80]!r}"
            )
        question = chosen_text[:cut].removeprefix(human_marker)
        prompts.append("Question: " + question + "\n\nAnswer: ")
        prefix = chosen_text[: cut + len(assistant_marker)]
        chosen.append(chosen_text[len(prefix):].strip())
        rejected.append(_reply_after(rejected_text, prefix))

    return {
        "prompt": prompts,
        "chosen": chosen,
        "rejected": rejected,
    }


In [None]:
# Hugging Face Hub repo of the megablocks-built MoE checkpoint to post-train with DPO.
model_name = "geniacllm/dMoEHf2"

In [None]:
# Fetch the model config. trust_remote_code=True matches the model load that
# receives this config. A megablocks-derived repo may ship custom config code
# (confirm for geniacllm/dMoEHf2); for built-in model types the flag is a no-op.
config = AutoConfig.from_pretrained(
    model_name,
    trust_remote_code=True,
    cache_dir=TEAM_DATASETS_CACHE_DIR,
)

In [None]:
# Inspect the loaded model configuration.
print(config)

In [None]:
# Load the megablocks-built MoE checkpoint as a causal LM.
# - config: reuses the config fetched above.
# - trust_remote_code: lets the repo's own modelling code (if any) run.
# - device_map="auto": spreads the weights over the available devices.
model_moe = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    cache_dir=TEAM_DATASETS_CACHE_DIR,
    trust_remote_code=True,
    config=config,
)

In [None]:
# Report trainable vs. total parameter counts of the loaded model.
print_trainable_parameters(model_moe)

In [None]:
# Load the tokenizer from the same Hub repo as the model.
# (A tokenizer has no device, so the former device_map="auto" kwarg was dropped.)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    # use_fast=False,
    # add_eos_token=True,
    # trust_remote_code=True,
    cache_dir=TEAM_DATASETS_CACHE_DIR,
)
# The trainer pads batches and failed with:
#   ValueError: Padding is enabled, but the tokenizer is not configured with a padding token.
#   Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`) before calling the trainer.
# so reuse EOS as the padding token, but only when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
eos_token_text = tokenizer.eos_token
print(f'eos token: {eos_token_text}')

In [None]:
# Download the preference datasets for DPO:
# - Anthropic/hh-rlhf: used for training below.
# - llm-jp/hh-rlhf-12k-ja: presumably a Japanese version; not used by the
#   training cells below.
dataset_rlhf, dataset_rlhf_ja = (
    load_dataset(repo_id, cache_dir=TEAM_DATASETS_CACHE_DIR)
    for repo_id in ("Anthropic/hh-rlhf", "llm-jp/hh-rlhf-12k-ja")
)

In [None]:
# Show the split/column summaries of both datasets, one per line.
print(dataset_rlhf, dataset_rlhf_ja, sep="\n")

megablocksで作られたモデルでHFに上がったものを使って、DPOを試す

see: https://llama2-accessory.readthedocs.io/en/latest/projects/mixtral-8x7b.html

ref: https://github.com/shibing624/MedicalGPT/blob/726bd2a62686bd7ed62262be44f8e0233edc2443/dpo_training.py#L30

In [None]:
# Build the DPO train/eval splits from Anthropic/hh-rlhf.
max_source_length = 256
max_target_length = 256
full_max_length = max_source_length + max_target_length

raw_datasets = dataset_rlhf


def _prepare_dpo_split(split_name, *, shuffle=False):
    """Return one raw split converted to prompt/chosen/rejected rows.

    Rows are produced by ``return_prompt_and_responses``. Pairs whose
    prompt+response text is empty or longer than ``full_max_length``
    characters (string length, not tokens) are dropped.

    Raises:
        ValueError: if ``split_name`` is missing or no row survives the filter.
    """
    if split_name not in raw_datasets:
        raise ValueError(f"raw_datasets has no '{split_name}' split")
    raw_split = raw_datasets[split_name]
    logger.info(f"Example {split_name}_dataset[0]: {raw_split[0]}")

    if shuffle:
        raw_split = raw_split.shuffle()
    formatted = raw_split.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=1,
        remove_columns=raw_split.column_names,
        load_from_cache_file=False,
        desc=f"Formatting {split_name} split for DPO",
    )
    kept = formatted.filter(
        lambda x: 0 < len(x['prompt'] + x['chosen']) <= full_max_length
        and 0 < len(x['prompt'] + x['rejected']) <= full_max_length
    )
    # Logged at INFO: the logger level is INFO, so the former debug lines never showed.
    logger.info(f"Num {split_name} samples: {len(kept)} (of {len(formatted)})")
    if len(kept) == 0:
        raise ValueError(
            f"no {split_name} samples fit within {full_max_length} characters"
        )
    logger.info(f"First {split_name} example:\n{kept[0]['prompt'] + kept[0]['chosen']}")
    return kept


train_dataset = _prepare_dpo_split("train", shuffle=True)
eval_dataset = _prepare_dpo_split("test")
max_train_samples = len(raw_datasets["train"])
max_eval_samples = len(raw_datasets["test"])

In [None]:
# Peek at the first formatted training row.
train_dataset[0]

In [None]:
# Peek at the first formatted evaluation row.
eval_dataset[0]

In [None]:
# Trainer settings for a short DPO smoke-test run (50 optimizer steps).
# docs:   https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
# source: https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/training_args.py#L176
training_args = TrainingArguments(
    output_dir="./output-dpo",
    run_name=f"dpo_{config.model_type}",
    report_to=["wandb"],
    # Schedule: 50 steps, checkpoint and evaluate every 5 steps, log every step.
    max_steps=50,
    save_steps=5,
    evaluation_strategy="steps",
    eval_steps=5,
    logging_steps=1,
    # Batching: one example per device, no gradient accumulation.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    gradient_checkpointing=False,
    # Optimizer and LR schedule.
    # NOTE(review): 5e-4 is far above LRs usually used for full-parameter DPO;
    # worth checking since the loss is currently not decreasing.
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    warmup_steps=2,
    optim="paged_adamw_32bit",  # see: https://github.com/eyess-glitch/phi-2-fine-tuning/blob/79cd01554482973f5e709ca9da9a5746d305b46e/dpo_train.py#L34
    # Precision: T4 cannot use bf16, but L4 can.
    bf16=True,
    fp16=False,
    # Keep the prompt/chosen/rejected string columns; they are not model forward() arguments.
    remove_unused_columns=False,
    # No device_map="auto" here: TrainingArguments has no such option.
)

In [None]:
# DPO trainer over the formatted hh-rlhf splits. No explicit reference model
# and no PEFT config are passed, so trl derives the reference itself.
trainer_dpo = DPOTrainer(
    model_moe,
    ref_model=None,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=None,
    # The prompt budget must be smaller than max_length. trl truncates
    # over-long responses to (max_length - max_prompt_length) tokens, and the
    # former value 1024 against max_length 512 made that budget negative.
    max_prompt_length=max_source_length,
    max_length=full_max_length,
)

In [None]:
# Run DPO training, then save the final model. Trainer.save_model also writes
# the tokenizer (whose pad_token/padding_side were customised above), so
# ./model-dpo can be reloaded with from_pretrained as a complete checkpoint.
trainer_dpo.train()
trainer_dpo.save_model("./model-dpo")

In [None]:
# push HF see: https://github.com/huggingface/transformers/blob/v4.27.2/src/transformers/trainer.py#L3559

# trainer_dpo.push_to_hub("geniacllm/dMoEHf2-dpo-test")