In [1]:
# finetuning llama2

In [2]:
#!pip install transformers==4.35.0
#!pip install peft==0.5.0
#!pip install bitsandbytes==0.41.1
#!pip install accelerate==0.23.0
#!pip install flash-attn==2.3.1.post1
#!pip install datasets==2.14.5

In [3]:
import os
#os.environ["CUDA_VISIBLE_DEVICES"]="1"

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer,pipeline
from datasets import Dataset
import copy
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
#ハイパラ関連
#モデル名
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
model_name = "kanhatakeyama/0405_100m_clean_ja"
model_name= "mistral-community/Mixtral-8x22B-v0.1"
#LoRA関連
r=8
lora_alpha=r
bit=16
#bit=8
bit=4
flash_atten=False

#LoRAのadapter
target_modules= [
    "lm_head",
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate",
    "w1",
    "w2",
    "w3"
]


target_modules=[
    "c_attn",
    "c_proj",
]
target_modules= [
    #"lm_head",
    "q_proj",
    "k_proj",
    "v_proj",
    #"o_proj",
    #"gate",
    #"w1",
    #"w2",
    #"w3"
]



#学習関連
#gradient_checkpointing =True  #vramの節約をしたい場合
gradient_checkpointing =False

lr=10**-5
do_train=True
#do_train=False

In [5]:

device_map="auto"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

def init_model(model_name, r, lora_alpha, target_modules, bit=4):
    if bit == 4:
        print("Using 4-bit mode")
        model = AutoModelForCausalLM.from_pretrained(model_name,
                                                     quantization_config=bnb_config,
                                                     device_map=device_map,
                                                     use_flash_attention_2=flash_atten,
                                                     )
    elif bit == 8:
        print("Using 8-bit mode")
        model = AutoModelForCausalLM.from_pretrained(model_name,
                                                     load_in_8bit=True,
                                                     device_map=device_map,
                                                     use_flash_attention_2=flash_atten,
                                                     )
    elif bit == 16:
        print("Using fp16 mode")
        model = AutoModelForCausalLM.from_pretrained(model_name,
                                                     device_map=device_map,
                                                     torch_dtype=torch.float16,
                                                     use_flash_attention_2=flash_atten,
                                                     )
    else:
        raise ValueError("bit must be 4, 8 or 16")

    if len(target_modules)==0:
        return model
    peft_config = LoraConfig(
        task_type="CAUSAL_LM", inference_mode=False, r=r, lora_alpha=lora_alpha,
        lora_dropout=0.1,
        target_modules=target_modules,
    )
    model = get_peft_model(model, peft_config)
    return model


In [6]:

#モデル初期化
model=init_model(model_name, r, lora_alpha, target_modules, bit=bit)

Using 4-bit mode


Loading checkpoint shards: 100%|██████████| 59/59 [02:23<00:00,  2.43s/it]


In [7]:


tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

#pipe = pipeline("text-generation", model=model,
#                tokenizer=tokenizer, max_new_tokens=1000)

# データセットの準備

In [8]:
import random
from datasets import load_dataset
#system_prompt="You are a professional chemist. Predict the melting point of the following compound."


dataset= load_dataset('hatakeyama-llm-team/nhk-news-170k', split="train")
print(len(dataset))

max_chars = 1000  # 最大文字数

def filter_by_max_char_length(example):
    # 文字数をチェック
    return len(example['text']) <= max_chars

# フィルタリング関数をデータセットに適用
dataset= dataset.filter(filter_by_max_char_length)

print(len(dataset))
dataset = dataset.map(lambda samples: tokenizer(
        samples['text']), batched=True)



168839
132553


# モデルの訓練

In [9]:
import transformers
from datetime import datetime
per_device_train_batch_size=6
epochs=0.2

#train
train_args = transformers.TrainingArguments(
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=2,
        warmup_steps=0,
        num_train_epochs=epochs,
        learning_rate=lr,
        fp16=True,
        logging_steps=100,
        save_total_limit=1,
        output_dir='outputs/'+datetime.now().strftime('%Y%m%d%H%M%S'),
        gradient_checkpointing=gradient_checkpointing,

    )

# trainer
#callbacks = [EarlyStoppingCallback()]
callbacks = []

trainer = transformers.Trainer(
    model=model,
    train_dataset=dataset,
    args=train_args,
    callbacks=callbacks,
    data_collator=transformers.DataCollatorForLanguageModeling(
        tokenizer, mlm=False)
)

if do_train:
    training_result = trainer.train()
    training_result.training_loss

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkanhatakeyama[0m ([33mkanhatakeyamas[0m). Use [1m`wandb login --relogin`[0m to force relogin


You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
100,1.2683
200,1.2297
300,1.1922
400,1.1845
500,1.1637
600,1.1679
700,1.1503
800,1.141
900,1.1374
1000,1.1375


In [10]:


#モデルの保存: adapterのみ保存するとき｡
#from datetime import datetime
#current_datetime = datetime.now()
#model.save_pretrained(f"./outputs/{current_datetime}")
model.save_pretrained(f"./outputs/mixtral_1epoch_0415")

#モデルの読み込み: afapter経由で読み込むとき
from peft import AutoPeftModelForCausalLM
model_path="./outputs/7b_ft"
#model_path="./outputs/7b_ft_with_self_prediction_0115"

"""
model = AutoPeftModelForCausalLM.from_pretrained(model_path,
                                                 device_map=device_map,
                                                     torch_dtype=torch.float16,
                                                     use_flash_attention_2=True,
                                                 )
"""

'\nmodel = AutoPeftModelForCausalLM.from_pretrained(model_path,\n                                                 device_map=device_map,\n                                                     torch_dtype=torch.float16,\n                                                     use_flash_attention_2=True,\n                                                 )\n'

In [11]:
max_new_tokens=100
pipeline=pipeline("text-generation",
                        model=model,tokenizer=tokenizer,
                        max_new_tokens=max_new_tokens,
                        repetition_penalty=1.5,)
pipeline("""以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。
要求を適切に満たす応答を書きなさい。
質問: 元気ですか?
回答: """)

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonF

[{'generated_text': '以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。\n要求を適切に満たす応答を書きなさい。\n質問: 元気ですか?\n回答: えーっ, ちょうどそろばんが壊れてしまったこともあり、少々不調だけど大丈夫ですよ!\n\\end{lstlisting}'}]