In [None]:
%%capture
# Environment setup. %%capture hides all output of this cell, including pip
# errors -- remove it temporarily when an install fails.
# 1) Install pip-autoremove, then uninstall the preinstalled torch stack
#    (pip-autoremove also drops dependencies that become unused).
# 2) Reinstall torch/torchvision/torchaudio + xformers from the PyTorch
#    CUDA 12.1 wheel index, then install unsloth.
# NOTE(review): no versions are pinned, so a later run may pull different
# torch/unsloth/trl releases; consider pinning for reproducibility.
!pip install pip3-autoremove
!pip-autoremove torch torchvision torchaudio -y
!pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu121
!pip install unsloth

# Tải checkpoint của đợt train trước từ HuggingFace

In [None]:
from huggingface_hub import login, snapshot_download
import os

# Authenticate with the Hugging Face Hub. The token is read from the HF_TOKEN
# environment variable (e.g. populated from Kaggle Secrets) instead of being
# pasted into the notebook source; if HF_TOKEN is unset, login() falls back
# to its interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

# Download the previous run's checkpoint repo into a plain local directory so
# it can later be passed to trainer.train(resume_from_checkpoint=...).
checkpoint_path = snapshot_download(
    repo_id="TrinhHoangKhang/mistral-7B-2500",
    allow_patterns=["*"],  # matches every file in the repo (same as no filter)
    repo_type="model",
    local_dir="/kaggle/working/mistral_checkpoint",
    # Copy real files instead of symlinking into the HF cache (older
    # huggingface_hub releases); newer releases ignore this and warn.
    local_dir_use_symlinks=False,
)


# Tải Base model

In [None]:
from unsloth import FastLanguageModel
import torch

# Base-model settings; `max_seq_length` is also passed to the trainer later.
max_seq_length = 2048
dtype = None          # None: let Unsloth choose the compute dtype
load_in_4bit = True   # load 4-bit bitsandbytes weights

base_model_kwargs = dict(
    model_name="unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
model, tokenizer = FastLanguageModel.from_pretrained(**base_model_kwargs)

# LoRA adapter on all attention and MLP projections.
# NOTE(review): these hyperparameters (r, alpha, target modules) must match the
# adapter stored in the checkpoint that is resumed later -- confirm against its
# adapter_config.json.
lora_kwargs = dict(
    r=16,
    lora_alpha=16,
    lora_dropout=0,  # 0 is Unsloth's optimized path
    bias="none",     # "none" is Unsloth's optimized path
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",  # Unsloth's lower-VRAM checkpointing
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
model = FastLanguageModel.get_peft_model(model, **lora_kwargs)

# Chuẩn bị dataset

In [None]:
from datasets import load_dataset, DatasetDict

# Load the full MetaMathQA-395K JSON file and keep only rows
# [SUBSET_START, SUBSET_END) as this run's training split.
SUBSET_START, SUBSET_END = 200000, 300000

metamathqa_full = load_dataset("json", data_files='/kaggle/input/metamathqa-395k/MetaMathQA-395K.json')
train_subset = metamathqa_full["train"].select(range(SUBSET_START, SUBSET_END))
metamathqa = DatasetDict({"train": train_subset})

In [None]:
def split_query(example):
    """Convert a MetaMathQA record into Alpaca-style fields.

    The first line of ``example["query"]`` becomes ``instruction``; everything
    after the first newline becomes ``input`` (empty for a single-line query).
    ``example["response"]`` is copied unchanged to ``output``.
    """
    head, _newline, rest = example["query"].partition("\n")
    return {
        "instruction": head,
        "input": rest,
        "output": example["response"],
    }

# Add instruction/input/output columns derived from each record's query/response.
metamathqa = metamathqa.map(function=split_query)

In [None]:
# Alpaca-style prompt templates; "prompt_input" is used when a record has a
# non-empty "input" field, "prompt_no_input" otherwise.
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


def format_meta_math(example, eos_token=None):
    """Render one record into the single training string used for SFT.

    Args:
        example: mapping with "instruction", "input" and "output" string fields
            (as produced by ``split_query``).
        eos_token: string appended after the response. When None (the default,
            and what ``Dataset.map`` uses), it is read at call time from the
            notebook-global ``tokenizer.eos_token``; pass it explicitly to use
            this function without that global.

    Returns:
        ``{"text": prompt + " " + output + eos_token}`` where ``prompt`` is the
        "prompt_input" template if ``example["input"]`` is non-empty, else the
        "prompt_no_input" template.
    """
    if eos_token is None:
        eos_token = tokenizer.eos_token

    if example["input"]:
        prompt = PROMPT_DICT["prompt_input"].format(
            instruction=example["instruction"],
            input=example["input"],
        )
    else:
        prompt = PROMPT_DICT["prompt_no_input"].format(
            instruction=example["instruction"],
        )
    # The response follows "### Response:" after one space; the trailing EOS
    # marks where the answer ends.
    return {
        "text": prompt + " " + example["output"] + eos_token
    }

# Add the "text" column (prompt + response + EOS) that the trainer reads.
metamathqa = metamathqa.map(function=format_meta_math)

In [None]:
# Inspect formatted training example #0 to sanity-check the prompt template.
# Index the row before the column so only this record is read, instead of
# materializing the whole 100k-row "text" column.
print(metamathqa['train'][0]['text'])

In [None]:
# Inspect formatted training example #4 (row-first indexing avoids loading the
# whole "text" column just to print one string).
print(metamathqa['train'][4]['text'])

# Finetune mô hình

In [None]:
from trl import SFTTrainer, SFTConfig

# Supervised fine-tuning on the "text" column built by format_meta_math.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=metamathqa["train"],
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=SFTConfig(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch = 2 * 8 = 16 per device
        warmup_steps=0,
        num_train_epochs=1,  # adjust as needed
        learning_rate=1e-4,
        save_strategy="steps",
        save_steps=100,
        # Keep only the newest checkpoints. Each checkpoint holds the LoRA
        # adapter plus optimizer state; with 100k rows / 16 per step (~6,250
        # steps on one GPU) and a save every 100 steps, unbounded checkpoints
        # would pile up in the limited Kaggle working directory.
        save_total_limit=3,
        logging_steps=10,
        weight_decay=0.0,
        lr_scheduler_type="cosine",
        output_dir="mistral_output_new",
        report_to="none",  # avoid wandb
    ),
)

In [None]:
# Remove rng_state.pth from the downloaded checkpoint before resuming, because
# loading it raised an error in this environment. Without it, training resumes
# without restoring the saved RNG state.
# NOTE(review): the error is presumably torch.load's weights_only=True default
# (torch >= 2.6) rejecting numpy objects in the RNG state -- confirm.
import os

try:
    os.remove('/kaggle/working/mistral_checkpoint/rng_state.pth')
except FileNotFoundError:
    pass  # already removed (e.g. this cell was re-run) -- nothing to do

In [None]:
# Resume training from the checkpoint directory downloaded at the top of the
# notebook (adapter weights, optimizer/scheduler state and trainer state).
# NOTE(review): the Trainer continues from the checkpoint's global step and, by
# default, skips that many batches of the *current* train_dataset, while the
# step budget and cosine schedule come from this run's dataset. If the
# checkpoint was produced on a different data slice, confirm that this is
# intended (see `ignore_data_skip`, or load only the adapter weights instead).
resume_dir = "/kaggle/working/mistral_checkpoint"
trainer.train(resume_from_checkpoint=resume_dir)

# Lưu mô hình lên HuggingFace

In [None]:
from huggingface_hub import login, create_repo, upload_folder

import os

# Push the trained LoRA model and tokenizer to the Hugging Face Hub.
# The token comes from the HF_TOKEN environment variable instead of being
# pasted into the notebook; when it is unset (None), push_to_hub uses the
# credentials stored by the earlier login() call.
HUB_REPO_ID = "TrinhHoangKhang/mistral-7B-200k-300k"
hf_token = os.environ.get("HF_TOKEN")

model.push_to_hub(HUB_REPO_ID, token=hf_token)
tokenizer.push_to_hub(HUB_REPO_ID, token=hf_token)

# -> Model is saved at: TrinhHoangKhang/mistral-7B-200k-300k