[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/docto-rin/Med-LLM-Jp/blob/main/colab_notebooks/JMLE-SFT.ipynb)

**Note on Hardware Requirements:**
*   **GPU:** This notebook was run on an NVIDIA A100 (40 GB VRAM).
*   **Training VRAM:** Peak usage observed during training was 22.8 GB.

## Installation (restart the session after the first run)

In [1]:
# ==============================================================================
# Installation (requiring a session restart after the first run)
# ==============================================================================
!pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# !pip install unsloth

Collecting unsloth@ git+https://github.com/unslothai/unsloth.git (from unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-ufpraro9/unsloth_b33c1d04081348cea92d3f27fe067308
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-ufpraro9/unsloth_b33c1d04081348cea92d3f27fe067308
  Resolved https://github.com/unslothai/unsloth.git to commit 05f4875aff111bf3801f6e740cc03cb7a8594c9b
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [2]:
!pip install --upgrade torch
!pip install --upgrade xformers

import torch
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install --no-deps packaging ninja einops "flash-attn>=2.6.3"

Collecting trl>=0.8.3
  Downloading trl-0.16.1-py3-none-any.whl.metadata (12 kB)
Downloading trl-0.16.1-py3-none-any.whl (336 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m336.4/336.4 kB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trl
  Attempting uninstall: trl
    Found existing installation: trl 0.15.2
    Uninstalling trl-0.15.2:
      Successfully uninstalled trl-0.15.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
unsloth-zoo 2025.3.17 requires trl!=0.15.0,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.15.2,>=0.7.9, but you have trl 0.16.1 which is incompatible.[0m[31m
[0mSuccessfully installed trl-0.16.1
Collecting xformers
  Downloading xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Downloading xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl (43.4 MB)

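The resolver error above shows that `unsloth-zoo 2025.3.17` accepts `trl` only in the range `>=0.7.9,<=0.15.2` (minus a few excluded releases), while the upgrade installed `trl 0.16.1`. Below is a minimal sketch of checking a version string against that constraint before training, using only the standard library. `parse_version`, `trl_is_compatible`, and `installed_trl_version` are hypothetical helper names, and the parser assumes plain numeric `X.Y.Z` versions.

```python
from importlib import metadata

# Constraint copied from the pip error message above (unsloth-zoo 2025.3.17).
TRL_MIN = (0, 7, 9)
TRL_MAX = (0, 15, 2)
TRL_EXCLUDED = {(0, 9, 0), (0, 9, 1), (0, 9, 2), (0, 9, 3), (0, 15, 0)}


def parse_version(text):
    """Turn 'X.Y.Z' into a tuple of ints (assumes numeric components only)."""
    return tuple(int(part) for part in text.split("."))


def trl_is_compatible(version_text):
    """Return True if the version satisfies the constraint quoted above."""
    version = parse_version(version_text)
    return TRL_MIN <= version <= TRL_MAX and version not in TRL_EXCLUDED


def installed_trl_version():
    """Return the installed trl version string, or None if trl is absent."""
    try:
        return metadata.version("trl")
    except metadata.PackageNotFoundError:
        return None


# The two versions that appear in the pip log above.
for candidate in ("0.15.2", "0.16.1"):
    print(candidate, "compatible" if trl_is_compatible(candidate) else "outside the pinned range")
```

The run recorded in this notebook completed with `trl 0.16.1` installed, so the conflict was not fatal here; the check is only a way to surface the mismatch early.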
## Manual Auth

In [4]:
# ==============================================================================
# Authentication
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive')

use_secret = False

Mounted at /content/drive


In [5]:
if use_secret:
    from huggingface_hub import login as hf_login
    from google.colab import userdata
    hf_login(userdata.get('HF_TOKEN'))
else:
    from huggingface_hub import notebook_login
    notebook_login()


In [6]:
import wandb

if use_secret:
    from google.colab import userdata
    wandb.login(key=userdata.get('WANDB_API_KEY'))
else:
    wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdocto-rin[0m ([33mnagoya-u[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Workflow (press "Ctrl + F10" to run every cell from here on)

### Setup

In [3]:
# ==============================================================================
# Imports (after potential restart)
# ==============================================================================
import os
import gc
import math
import torch
from unsloth import (
    FastLanguageModel,
    is_bfloat16_supported,
    # UnslothTrainer,             # We'll use SFTTrainer from TRL
    # UnslothTrainingArguments    # We'll use TrainingArguments from TRL/Transformers
)
from transformers import (
    AutoModelForCausalLM, # Good practice, though FastLanguageModel wraps it
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,    # Use standard TrainingArguments for SFTTrainer
    pipeline,
    TextStreamer,
)
from datasets import load_dataset, load_from_disk
from datasets.dataset_dict import DatasetDict
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM # Import SFTTrainer for Supervised Fine-Tuning

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [7]:
# ==============================================================================
# Configuration
# ==============================================================================
print("\nSetting up configuration for SFT...")

# --- Model IDs and Names ---
model_id = "cyberagent/DeepSeek-R1-Distill-Qwen-32B-Japanese" # Base model used for CPT

# Define a name for your SFT model and the repo on HF Hub
sft_model_suffix = "sft-0.6" # Suffix for the SFT model version
new_sft_model_id = f"DeepSeek-R1-Distill-Qwen-32B-Japanese-{sft_model_suffix}" # Full name for the SFT adapter/model
hf_username = "doctorin" # <<<<<<<<<================= REPLACE THIS (Your Hugging Face Username)

# --- Paths ---
# Base directory on Google Drive
base_save_dir = "/content/drive/MyDrive/student_iwase/finetuned-models" # CHANGE this path if desired
# Specific directory structure for this SFT model
model_base_path_sft = os.path.join(base_save_dir, new_sft_model_id)
output_dir_sft = os.path.join(model_base_path_sft, "training_checkpoints") # For SFT checkpoints
final_sft_adapter_save_path = os.path.join(model_base_path_sft, "final_sft_adapter") # For final SFT LoRA adapter

# Path to the *original* dataset used for CPT (needed for SFT formatting)
dataset_id = "doctorin/JMLE-CoT-gemini-2.5-pro-dataset-combined" # <<<<<<<<<================= REPLACE THIS if different
# Path where the *processed* dataset was saved during CPT (used to get split info if needed, but not directly for SFTTrainer)
# processed_dataset_path_cpt = os.path.join(base_save_dir, cpt_model_base_name, "processed_dataset")

os.makedirs(output_dir_sft, exist_ok=True)
os.makedirs(final_sft_adapter_save_path, exist_ok=True)

# --- SFT Training Parameters ---
max_seq_length = 2048       # Keep consistent with CPT or adjust if needed
load_in_4bit = True         # Use QLoRA (recommended)
lora_r = 16                 # LoRA rank (can be same as CPT or adjusted for SFT)
lora_alpha = 16             # LoRA alpha (can be same as CPT or adjusted for SFT)
lora_dropout = 0.05         # LoRA dropout
target_modules = [          # Modules to target with LoRA during SFT (should generally include those from CPT)
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj"
    # "embed_tokens", "lm_head" # Keep these if trained during CPT
]
per_device_train_batch_size = 1 # Adjust based on VRAM
gradient_accumulation_steps = 8 # Adjust based on VRAM (effective batch size = batch * grad_accum)
learning_rate = 2e-5        # Often slightly lower for SFT than CPT (e.g., 1e-5 to 5e-5)
warmup_ratio = 0.05         # Use ratio instead of steps for flexibility with epochs
max_grad_norm = 0.3
num_train_epochs = 1        # <<<<<<<<<================= ADJUST SFT Epochs (typically 1-3)
optim = "adamw_8bit"        # Use 8-bit AdamW optimizer
save_strategy = "steps"     # Must match eval_strategy ("steps") because load_best_model_at_end=True
save_steps = 100            # Checkpoint interval; must be a multiple of eval_steps (= logging_steps)
logging_steps = 10
save_total_limit = 2        # Keep last N checkpoints + final adapter
validation_split_percentage = 10 # Use the same split ratio
seed = 42                   # Seed for reproducibility in SFT

# --- Derived Parameters ---
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
print(f"Number of GPUs detected: {num_gpus}")

print(f"\nConfiguration set for SFT:")
print(f"  Base Model: {model_id}")
print(f"  SFT Model Name: {new_sft_model_id}")
print(f"  Dataset ID: {dataset_id}")
print(f"  Max Sequence Length: {max_seq_length}")
print(f"  Output (Checkpoints): {output_dir_sft}")
print(f"  Final SFT Adapter Save Path: {final_sft_adapter_save_path}")
print(f"  Effective Batch Size: {effective_batch_size * num_gpus}")


Setting up configuration for SFT...
Number of GPUs detected: 1

Configuration set for SFT:
  Base Model: cyberagent/DeepSeek-R1-Distill-Qwen-32B-Japanese
  SFT Model Name: DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6
  Dataset ID: doctorin/JMLE-CoT-gemini-2.5-pro-dataset-combined
  Max Sequence Length: 2048
  Output (Checkpoints): /content/drive/MyDrive/student_iwase/finetuned-models/DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6/training_checkpoints
  Final SFT Adapter Save Path: /content/drive/MyDrive/student_iwase/finetuned-models/DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6/final_sft_adapter
  Effective Batch Size: 8


In [8]:
# ==============================================================================
# Load Base Model and Tokenizer
# ==============================================================================
print("\nLoading base model and tokenizer...")
dtype = None # Auto detection by Unsloth
compute_dtype = torch.float16
if is_bfloat16_supported():
    compute_dtype = torch.bfloat16
    print("bfloat16 is supported. Using bfloat16 for computation.")
else:
    print("bfloat16 not supported. Using float16 for computation.")

# Load the base model WITH quantization config for QLoRA
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_id,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    trust_remote_code=True,
    device_map = "auto", # Automatically distribute across GPUs
    quantization_config = BitsAndBytesConfig(
        load_in_4bit = load_in_4bit,
        bnb_4bit_quant_type = "nf4",
        bnb_4bit_compute_dtype = compute_dtype,
        bnb_4bit_use_double_quant = True,
    ) if load_in_4bit else None,
    # token = "hf_...", # Add Hugging Face token if the base model is private
)
print("Base model and tokenizer loaded.")

# --- Set Padding Token ---
if tokenizer.pad_token is None:
    print("Setting pad_token to eos_token")
    tokenizer.pad_token = tokenizer.eos_token
    # Ensure pad_token_id is also set, needed by some components
    tokenizer.pad_token_id = tokenizer.eos_token_id


Loading base model and tokenizer...
bfloat16 is supported. Using bfloat16 for computation.
Are you certain you want to do remote code execution?
==((====))==  Unsloth 2025.3.19: Fast Qwen2 patching. Transformers: 4.50.3.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


cyberagent/DeepSeek-R1-Distill-Qwen-32B-Japanese does not have a padding token! Will use pad_token = <|vision_pad|>.
Base model and tokenizer loaded.

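A back-of-the-envelope estimate of why 4-bit loading matters on a 40 GB card: the memory taken by the weights alone at 16 and at 4 bits per parameter. This is a rough sketch, not a measurement. The parameter count of 32.8e9 is an assumption for a "32B" model, and the estimate ignores quantization constants, layers kept in higher precision, activations, LoRA weights, and optimizer state, so real usage is higher.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Approximate memory for the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


# Assumption: roughly 32.8e9 parameters for a "32B" model.
NUM_PARAMS = 32.8e9

bf16_gib = weight_memory_gib(NUM_PARAMS, 16)  # unquantized bfloat16 weights
nf4_gib = weight_memory_gib(NUM_PARAMS, 4)    # 4-bit (NF4) weights

print(f"bf16 weights : {bf16_gib:.1f} GiB")
print(f"4-bit weights: {nf4_gib:.1f} GiB")
```

Under these assumptions the 16-bit weights alone exceed the card's memory, while the 4-bit weights leave headroom for activations and the adapter, which is consistent with the 22.8 GB peak reported at the top of the notebook.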

In [9]:
# ==============================================================================
# Configure PEFT (attach a fresh LoRA adapter for SFT)
# ==============================================================================
print("\nConfiguring PEFT (LoRA) for SFT and loading CPT adapter...")

# get_peft_model attaches new LoRA matrices to the target modules of the quantized base model.
# Note: no CPT adapter weights are loaded in this cell, despite the log message above.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_r,
    lora_alpha = lora_alpha,
    lora_dropout = lora_dropout,
    target_modules = target_modules,
    bias = "none", # Required for QLoRA
    use_gradient_checkpointing = "unsloth", # Let Unsloth handle gradient checkpointing
    random_state = seed,
    max_seq_length = max_seq_length,
    # adapter_name = "default" # Keep it simple or name if managing multiple adapters later
)
print("PEFT configured for SFT.")
model.print_trainable_parameters() # Show number of trainable LoRA parameters

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.



Configuring PEFT (LoRA) for SFT and loading CPT adapter...


Unsloth 2025.3.19 patched 64 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


PEFT configured for SFT.
trainable params: 134,217,728 || all params: 32,898,094,080 || trainable%: 0.4080

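The trainable-parameter count printed above can be reproduced by hand. Each LoRA-wrapped linear layer adds `r * (in_features + out_features)` parameters (one `r x in` matrix and one `out x r` matrix). The sketch below assumes Qwen2.5-32B-style dimensions (hidden size 5120, MLP size 27648, 40 query heads, 8 key/value heads, head size 128, 64 layers); these values are not printed anywhere in this notebook and are an assumption.

```python
# Assumed Qwen2.5-32B-style dimensions (not shown in this notebook).
HIDDEN = 5120
INTERMEDIATE = 27648
NUM_HEADS = 40
NUM_KV_HEADS = 8
HEAD_DIM = 128
NUM_LAYERS = 64
LORA_R = 16  # lora_r from the configuration cell

# (in_features, out_features) of every projection listed in target_modules.
shapes = {
    "q_proj": (HIDDEN, NUM_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "o_proj": (NUM_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features, out_features, r):
    """Parameters added by one LoRA pair: A is (r x in), B is (out x r)."""
    return r * (in_features + out_features)


per_layer = sum(lora_params(i, o, LORA_R) for i, o in shapes.values())
total = per_layer * NUM_LAYERS
all_params = 32_898_094_080  # "all params" from the output above

print(f"LoRA params per layer: {per_layer:,}")
print(f"LoRA params total    : {total:,}")
print(f"trainable%           : {100 * total / all_params:.4f}")
```

With these assumed dimensions the total matches the 134,217,728 trainable parameters reported by `print_trainable_parameters()`.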

In [10]:
# ==============================================================================
# Prepare Dataset for SFT (Load Original, Convert to Chat Format, Filter)
# ==============================================================================
import warnings
from unsloth.chat_templates import get_chat_template # Import Unsloth's helper

# --- Define the target chat template ---
# Use a template known to work well with the base model, or a generic one like "chatml"
# For DeepSeek-R1-Distill-Qwen, "qwen" or "chatml" might be appropriate. Let's try "chatml" as a robust default.
TARGET_CHAT_TEMPLATE = "chatml" # Options: "llama-3.1", "chatml", "zephyr", "qwen", etc.
print(f"Using chat template: {TARGET_CHAT_TEMPLATE}")

# --- Apply the chat template to the tokenizer ---
# This adds special tokens and the apply_chat_template method
try:
    tokenizer = get_chat_template(
        tokenizer,
        chat_template=TARGET_CHAT_TEMPLATE,
        # mapping={"role" : "from", "content" : "value"}, # Adjust if your roles differ from standard 'user'/'assistant'
        map_eos_token=True, # Automatically adds EOS token
    )
except Exception as e:
    print(f"Warning: Could not apply chat template '{TARGET_CHAT_TEMPLATE}' directly. Error: {e}")
    print("Proceeding without get_chat_template, manual formatting might be less optimal.")
    # If get_chat_template fails, apply_chat_template might still exist but without added special tokens
    if not hasattr(tokenizer, "apply_chat_template"):
         raise RuntimeError("Tokenizer does not support apply_chat_template and get_chat_template failed.") from e

# --- Function to convert your data structure to chat format ---
def convert_to_chat_format(example):
    """Converts a single example to the chat format [{role: ..., content: ...}, ...]"""
    question = example.get("question", "")
    choices_list = example.get("choices", [])
    cot = example.get("cot", "")
    explanation = example.get("explanation", "")
    raw_answer_list = example.get("answer", [])

    # --- Input Validation (ensure necessary fields are present) ---
    if not question or not isinstance(question, str) or not question.strip() or \
       not cot or not isinstance(cot, str) or not cot.strip() or \
       not explanation or not isinstance(explanation, str) or not explanation.strip() or \
       not isinstance(raw_answer_list, list) or not raw_answer_list:
        return {"conversations": None} # Return None if essential data is missing/invalid

    # --- Format user message ---
    choices_text = "\n".join([f"- {str(choice)}" for choice in choices_list if choice is not None]) if choices_list else "選択肢なし"
    user_content = f"""以下の医師国家試験問題について、思考過程と解答、簡潔な解説を生成してください。
- 思考過程は<think></think>タグで囲んでください。
- 解答はanswer: の後に続けて、後述の形式に厳密に従うように書いてください。
- 簡潔な解説はexplanation: の後に続けて書いてください。
【answerについて注意】
- 問題は単数選択、複数選択、数値入力のいずれかであり、問題文からその形式を判断する。
- 「どれか。」で終わる選択問題で数が明記されていない場合は、五者択一を意味するので選択肢を必ず1つだけ選び小文字のアルファベットで回答する。（単数選択）
- 「2つ選べ」「3つ選べ」などと書いてある場合に限り、指定された数だけの複数選択肢を選び、小文字のアルファベット順（abcde順）に並び替えて列挙する。（複数選択）
- 選択肢が存在しない場合は、小数や四捨五入など、問題文で特に指示があればそれに従い、選択肢記号ではなく数値を回答する。（数値入力）

問題：
{question}

選択肢：
{choices_text}"""

    # --- Format assistant message ---
    string_answer_list = [str(item) for item in raw_answer_list if item is not None]
    answer = "".join(sorted(string_answer_list))

    assistant_content = f"""<think>
{cot}
</think>
answer: {answer}
explanation: {explanation}"""

    # --- Create the conversation list ---
    conversations = [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]
    return {"conversations": conversations}


# --- Function to apply the chat template and create the 'text' field ---
# This is similar to the function in the article
def formatting_chat_func(examples):
    """Applies the chat template to the 'conversations' field."""
    convos = examples["conversations"]
    # Ensure we only process valid conversation lists
    valid_convos = [c for c in convos if c is not None and isinstance(c, list) and len(c) > 0]
    if not valid_convos:
        # Handle cases where a batch might contain only invalid conversations
        # Return an empty dict or a dict with empty lists for expected keys
        return {"text": []} # Important: return empty list for the 'text' key

    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in valid_convos]
    # Ensure the output dictionary has the same number of rows as the input batch
    # If some convos were invalid, we need a placeholder or strategy.
    # Simplest is to filter invalid ones *before* this function.
    # Assuming filtering happened, len(texts) should match the number of valid inputs.
    return {"text": texts}


# --- Load, Convert, Filter, and Format Dataset ---
print(f"\nLoading original dataset '{dataset_id}'...")
train_dataset_sft = None
eval_dataset_sft = None
try:
    # Load the raw dataset
    raw_dataset = load_dataset(dataset_id, split="train")
    print(f"Original dataset loaded. Number of examples: {len(raw_dataset)}")

    # --- 1. Convert to chat format ---
    print("Converting dataset to chat format...")
    num_proc_convert = max(1, (os.cpu_count() or 1) // 2)  # os.cpu_count() can return None
    chat_dataset = raw_dataset.map(
        convert_to_chat_format,
        num_proc=num_proc_convert,
        desc="Converting to chat format"
    )
    # Remove original columns if desired, keep 'conversations'
    # chat_dataset = chat_dataset.remove_columns(raw_dataset.column_names)

    # --- 2. Filter out examples where conversion failed (returned None) ---
    print("Filtering out invalid/incomplete examples...")
    original_chat_count = len(chat_dataset)
    # The filter function now simply checks if 'conversations' is not None
    filtered_chat_dataset = chat_dataset.filter(
        lambda example: example["conversations"] is not None,
        num_proc=num_proc_convert,
        desc="Filtering invalid chats"
    )
    filtered_chat_count = len(filtered_chat_dataset)
    print(f"Filtered out {original_chat_count - filtered_chat_count} examples that failed chat conversion.")

    if filtered_chat_count == 0:
        raise ValueError("Dataset is empty after conversion and filtering. Cannot proceed.")

    # --- 3. Apply chat template formatting ---
    print("Applying chat template to conversations...")
    # This map function now takes the 'conversations' and creates 'text'
    formatted_dataset = filtered_chat_dataset.map(
        formatting_chat_func,
        batched=True, # Crucial for efficiency
        num_proc=num_proc_convert, # Use multiple processes
        remove_columns=["conversations"], # Remove the intermediate conversations column
        desc="Applying chat template"
    )
    # After this step, the dataset should contain a 'text' column with fully formatted strings

    # --- 4. Split the formatted dataset ---
    print(f"Splitting formatted dataset into train/test ({100-validation_split_percentage}/{validation_split_percentage}%) using seed {seed}...")
    # Shuffle before splitting
    formatted_dataset = formatted_dataset.shuffle(seed=seed)
    split_dataset = formatted_dataset.train_test_split(
        test_size=validation_split_percentage / 100.0,
        seed=seed
    )
    train_dataset_sft = split_dataset["train"]
    eval_dataset_sft = split_dataset["test"]
    print(f"Dataset split: Train={len(train_dataset_sft)}, Eval={len(eval_dataset_sft)}")

    # --- Display Sample Formatted Text ---
    if len(train_dataset_sft) > 0:
        print("\n--- Sample Formatted Text for SFT (Using Chat Template) ---")
        # The dataset now directly contains the 'text' field
        print(train_dataset_sft[0]['text'])
        print("-----------------------------------------\n")
    else:
        print("\nWarning: Training dataset is empty after split, cannot show sample prompt.")


except Exception as e:
    print(f"\n❌ FATAL ERROR during dataset preparation using chat format. Error: {e}")
    print("Please check the dataset structure, conversion logic, and template application.")
    raise RuntimeError("Dataset preparation failed.") from e

# --- Final Verification ---
if train_dataset_sft is None or eval_dataset_sft is None or len(train_dataset_sft) == 0 or "text" not in train_dataset_sft.column_names:
     raise RuntimeError("Training or evaluation dataset is missing, empty, or lacks the 'text' column after preparation.")
else:
    print("Dataset preparation using chat format finished. Proceeding to Trainer initialization.")

Unsloth: Will map <|im_end|> to EOS = <｜end▁of▁sentence｜>.


Using chat template: chatml


You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.



Loading original dataset 'doctorin/JMLE-CoT-gemini-2.5-pro-dataset-combined'...


Original dataset loaded. Number of examples: 3390
Converting dataset to chat format...


Converting to chat format (num_proc=6):   0%|          | 0/3390 [00:00<?, ? examples/s]

Filtering out invalid/incomplete examples...


Filtering invalid chats (num_proc=6):   0%|          | 0/3390 [00:00<?, ? examples/s]

Filtered out 28 examples that failed chat conversion.
Applying chat template to conversations...


Applying chat template (num_proc=6):   0%|          | 0/3362 [00:00<?, ? examples/s]

Splitting formatted dataset into train/test (90/10%) using seed 42...
Dataset split: Train=3025, Eval=337

--- Sample Formatted Text for SFT (Using Chat Template) ---
<|im_start|>user
以下の医師国家試験問題について、思考過程と解答、簡潔な解説を生成してください。
- 思考過程は<think></think>タグで囲んでください。
- 解答はanswer: の後に続けて、後述の形式に厳密に従うように書いてください。
- 簡潔な解説はexplanation: の後に続けて書いてください。
【answerについて注意】
- 問題は単数選択、複数選択、数値入力のいずれかであり、問題文からその形式を判断する。
- 「どれか。」で終わる選択問題で数が明記されていない場合は、五者択一を意味するので選択肢を必ず1つだけ選び小文字のアルファベットで回答する。（単数選択）
- 「2つ選べ」「3つ選べ」などと書いてある場合に限り、指定された数だけの複数選択肢を選び、小文字のアルファベット順（abcde順）に並び替えて列挙する。（複数選択）
- 選択肢が存在しない場合は、小数や四捨五入など、問題文で特に指示があればそれに従い、選択肢記号ではなく数値を回答する。（数値入力）

問題：
医師の行動として適切なのはどれか。 

選択肢：
- a. 診断のため本人の同意なく患者の家系を調べた。
- b. 診療の内容を患者の実名を含めてSNSに投稿した。
- c. 検体の血液が余ったので本人の同意なく遺伝子配列を解析した。
- d. 学習のため本人の同意なく患者の皮膚所見をホームページに載せた。
- e. 虐待が疑われるため家族の同意なく児童の情報を児童相談所に通報した。<|im_end|>
<|im_start|>assistant
<think>
1.  **問題の分析**: この問題は、医師の行動の適切性を問う倫理・法規に関する問題である。5つの選択肢の中から、医師として適切な行動を1つ選ぶ必要がある。
2.  **選択肢aの評価**: 「診断のため本人の同意なく患者の家系を調べた。」
    *   家系情報は

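The assistant turn built in `convert_to_chat_format` has a fixed shape: a `<think>` block, an `answer:` line, and an `explanation:` line. A minimal sketch of the inverse operation, which is useful when scoring generated text against the reference answer. `normalize_answer` mirrors the sort-and-join logic of the cell above; `build_assistant_content` and `parse_response` are hypothetical helper names, and the regular expressions assume the model follows the format exactly.

```python
import re


def normalize_answer(raw_answer_list):
    """Mirror the notebook: drop None, stringify, sort, and join."""
    return "".join(sorted(str(item) for item in raw_answer_list if item is not None))


def build_assistant_content(cot, answer, explanation):
    """Same layout as the assistant message in convert_to_chat_format."""
    return f"<think>\n{cot}\n</think>\nanswer: {answer}\nexplanation: {explanation}"


def parse_response(text):
    """Split a response into think / answer / explanation (None if a part is missing)."""
    think = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    # Only look for answer/explanation after the reasoning block.
    rest = text[think.end():] if think else text
    answer = re.search(r"^answer:[ \t]*(.*)$", rest, re.MULTILINE)
    explanation = re.search(r"^explanation:[ \t]*(.*)", rest, re.MULTILINE | re.DOTALL)
    return {
        "think": think.group(1).strip() if think else None,
        "answer": answer.group(1).strip() if answer else None,
        "explanation": explanation.group(1).strip() if explanation else None,
    }


# Round trip on a toy example (not taken from the dataset).
content = build_assistant_content("step 1", normalize_answer(["c", "a"]), "short reason")
parsed = parse_response(content)
print(parsed["answer"])
```

Sorting the answer letters before joining means a multiple-choice answer is compared in a canonical order, so `["c", "a"]` and `["a", "c"]` yield the same target string.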
In [11]:
# ==============================================================================
# Initialize SFT Trainer (Using Chat Formatted Data)
# ==============================================================================
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq # Import necessary collators
from unsloth import is_bfloat16_supported

print("\nInitializing SFTTrainer for chat-formatted data...")

# --- Select Data Collator ---
# For chat-formatted data using apply_chat_template, the standard LM collator or Seq2Seq collator often works.
# SFTTrainer might default to an appropriate one if None is provided and dataset_text_field is set.
# Let's use DataCollatorForLanguageModeling which is common for Causal LM fine-tuning.
# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Or use the one from the article (might be suitable if specific padding/masking is needed)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
print(f"Using Data Collator: {type(data_collator).__name__}")


# --- Define TrainingArguments ---
# Evaluation and checkpointing both run on a step schedule so that
# load_best_model_at_end can restore the checkpoint with the lowest eval loss.
sft_training_args = TrainingArguments(
    output_dir=output_dir_sft,
    run_name=f"{new_sft_model_id}-run-{num_train_epochs}epochs-chat",
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    learning_rate=learning_rate,
    warmup_ratio=warmup_ratio,
    max_grad_norm=max_grad_norm,
    seed=seed,
    logging_strategy="steps",
    logging_steps=logging_steps,
    logging_first_step=True,
    eval_strategy="steps",
    eval_steps=logging_steps,
    save_strategy="steps",
    save_steps=save_steps,
    save_total_limit=save_total_limit,
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),
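    # Note: wandb.run stays None until wandb.init() is called, so after wandb.login() alone this evaluates to "none".
    report_to="wandb" if wandb.run is not None else "none",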
    per_device_eval_batch_size=per_device_train_batch_size * 2,
    load_best_model_at_end=True,
)

# --- Initialize the SFTTrainer ---
# We no longer need `formatting_func` as the dataset is pre-formatted.
# We need to specify `dataset_text_field`.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset_sft, # Pass the formatted and split dataset
    eval_dataset=eval_dataset_sft,   # Pass the formatted and split dataset
    dataset_text_field="text",       # Specify the column containing the formatted text
    args=sft_training_args,
    max_seq_length=max_seq_length,   # Keep max sequence length
    data_collator=data_collator,     # Use the selected data collator
    # packing=False, # Consider setting packing=True if sequences are short and you want to combine them
)
print("SFTTrainer initialized for chat-formatted data.")


Initializing SFTTrainer for chat-formatted data...
Using Data Collator: DataCollatorForSeq2Seq


Unsloth: Tokenizing ["text"] (num_proc=12):   0%|          | 0/3025 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"] (num_proc=12):   0%|          | 0/337 [00:00<?, ? examples/s]

SFTTrainer initialized for chat-formatted data.

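The arguments above combine `load_best_model_at_end=True` with step-based evaluation and saving. A minimal sketch of the consistency rule this relies on, assuming (as in current Transformers releases) that the evaluation and save strategies must match and that `save_steps` must be a whole multiple of `eval_steps`. `check_checkpoint_schedule` is a hypothetical helper, not part of any library.

```python
def check_checkpoint_schedule(eval_strategy, eval_steps, save_strategy, save_steps, load_best):
    """Return a list of problems; an empty list means the schedule is consistent."""
    problems = []
    if not load_best:
        return problems  # No constraint when the best checkpoint is not reloaded.
    if eval_strategy != save_strategy:
        problems.append(
            f"eval_strategy={eval_strategy!r} and save_strategy={save_strategy!r} must match"
        )
    elif eval_strategy == "steps" and save_steps % eval_steps != 0:
        problems.append(
            f"save_steps={save_steps} is not a multiple of eval_steps={eval_steps}"
        )
    return problems


# Values used in this notebook: evaluate every 10 steps, save every 100 steps.
print(check_checkpoint_schedule("steps", 10, "steps", 100, True))
# An "epoch" save strategy would clash with step-based evaluation.
print(check_checkpoint_schedule("steps", 10, "epoch", 100, True))
```

Because a checkpoint is written only every 100 steps while evaluation runs every 10, the "best" model that gets restored is the best among the saved checkpoints, not necessarily the step with the lowest validation loss overall.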

### Train

In [12]:
# ==============================================================================
# Train
# ==============================================================================
print("\nStarting SFT training...")
gc.collect()
torch.cuda.empty_cache()

# Start training
train_result = trainer.train()

# Training logs (like loss) are automatically sent to W&B if enabled.
# You can also access metrics via train_result:
# print(f"Training completed. Metrics: {train_result.metrics}")
print("SFT training finished.")


Starting SFT training...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 3,025 | Num Epochs = 1 | Total steps = 378
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 8 x 1) = 8
 "-____-"     Trainable parameters = 134,217,728/32,000,000,000 (0.42% trained)


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss | Validation Loss |
|---:|---:|---:|
| 10 | 1.3783 | 1.373395 |
| 20 | 1.3879 | 1.36308 |
| 30 | 1.3385 | 1.298748 |
| 40 | 1.2434 | 1.176807 |
| 50 | 1.1411 | 1.073096 |
| 60 | 1.0361 | 0.977468 |
| 70 | 0.9459 | 0.88153 |
| 80 | 0.8655 | 0.81918 |
| 90 | 0.8156 | 0.796856 |
| 100 | 0.8191 | 0.778634 |


Unsloth: Not an error, but Qwen2ForCausalLM does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient


SFT training finished.

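The sizes reported in this run follow from a few lines of arithmetic: 3,390 raw examples minus 28 filtered ones, a 90/10 split, and 8 examples per optimizer step. The sketch below reproduces those numbers. The rounding choices (test split rounded up, partial accumulation step dropped, warmup rounded up) are assumptions chosen because they match the logged values, not a statement of how the libraries are implemented.

```python
import math


def split_sizes(num_examples, test_fraction):
    """Train/eval sizes, assuming the eval split is rounded up."""
    n_test = math.ceil(num_examples * test_fraction)
    return num_examples - n_test, n_test


def optimizer_steps(num_train, per_device_batch, grad_accum, num_gpus, epochs):
    """Optimizer steps, assuming an incomplete accumulation window is dropped."""
    batches_per_epoch = math.ceil(num_train / (per_device_batch * num_gpus))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return steps_per_epoch * epochs


def warmup_steps(total_steps, warmup_ratio):
    """Warmup length, assuming the ratio is rounded up to whole steps."""
    return math.ceil(total_steps * warmup_ratio)


n_train, n_eval = split_sizes(3390 - 28, 0.10)
total = optimizer_steps(n_train, per_device_batch=1, grad_accum=8, num_gpus=1, epochs=1)
warmup = warmup_steps(total, 0.05)
checkpoints = list(range(100, total + 1, 100))  # save_steps = 100

print(f"train={n_train}, eval={n_eval}")
print(f"total optimizer steps={total}, warmup steps={warmup}")
print(f"checkpoint steps={checkpoints}")
```

With `save_total_limit = 2`, only a subset of these checkpoints stays on disk at the end of the run.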

In [13]:
# ==============================================================================
# Save Final SFT Adapter
# ==============================================================================
print(f"\nSaving final SFT LoRA adapter weights to {final_sft_adapter_save_path}...")
# The trainer saves the learned adapter weights relative to the base model.
trainer.save_model(final_sft_adapter_save_path) # Saves adapter_model.safetensors and adapter_config.json

# It's good practice to also save the tokenizer configuration alongside the adapter
tokenizer.save_pretrained(final_sft_adapter_save_path)
print(f"Final SFT adapter and tokenizer saved to: {final_sft_adapter_save_path}")

# ==============================================================================
# Optional: Push SFT Adapter to Hub
# ==============================================================================
upload_sft_adapter_to_hub = True # Set to False to skip uploading adapter

if upload_sft_adapter_to_hub:
    sft_adapter_repo_id = f"{hf_username}/{new_sft_model_id}-LoRA-Adapter"
    print(f"\nAttempting to push SFT adapter to Hugging Face Hub: {sft_adapter_repo_id}")
    try:
        # Use the model object (which has the adapter loaded) to push
        model.push_to_hub(sft_adapter_repo_id) # Pushes the currently active adapter
        tokenizer.push_to_hub(sft_adapter_repo_id) # Push tokenizer too
        print("SFT adapter successfully pushed to Hub.")
    except Exception as e:
        print(f"\n⚠️ Warning: Failed to push SFT adapter to Hub. Error: {e}")
        print("You may need to manually upload the files from:", final_sft_adapter_save_path)



Saving final SFT LoRA adapter weights to /content/drive/MyDrive/student_iwase/finetuned-models/DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6/final_sft_adapter...
Final SFT adapter and tokenizer saved to: /content/drive/MyDrive/student_iwase/finetuned-models/DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6/final_sft_adapter

Attempting to push SFT adapter to Hugging Face Hub: doctorin/DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6-LoRA-Adapter


README.md:   0%|          | 0.00/617 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]

Saved model to https://huggingface.co/doctorin/DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6-LoRA-Adapter


tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

SFT adapter successfully pushed to Hub.
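
As a sanity check, the uploaded adapter size can be compared against the trainable parameter count printed in the training banner. The sketch below is only an illustration under stated assumptions: the "537M" in the upload log is read as decimal megabytes, the small safetensors header is ignored, and the function name `adapter_size_mb` is hypothetical.

```python
TRAINABLE_PARAMS = 134_217_728   # "Trainable parameters" from the training banner
REPORTED_UPLOAD_MB = 537         # adapter_model.safetensors size from the upload log


def adapter_size_mb(num_params, bytes_per_param):
    """Approximate tensor payload in decimal megabytes (header overhead ignored)."""
    return num_params * bytes_per_param / 1e6


fp32_mb = adapter_size_mb(TRAINABLE_PARAMS, 4)  # 4 bytes per parameter
fp16_mb = adapter_size_mb(TRAINABLE_PARAMS, 2)  # 2 bytes per parameter

print(f"Estimated size at 4 bytes/param: {fp32_mb:.1f} MB")
print(f"Estimated size at 2 bytes/param: {fp16_mb:.1f} MB")
print(f"Reported upload size:            {REPORTED_UPLOAD_MB} MB")
```

The 4-byte estimate matches the reported upload size, which is consistent with the adapter weights being stored at 32-bit precision, although the log alone does not confirm the dtype.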


### Inference

In [14]:
# ==============================================================================
# Inference with SFT Adapter (Using Chat Template)
# ==============================================================================
print("\nSetting up for inference with the final SFT adapter (using chat template)...")

# Ensure the model is prepared for inference (Unsloth handles this)
# FastLanguageModel.for_inference(model) # Already done implicitly by Unsloth's generate

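# --- Prepare inference input using the chat template ---
# Sample question (English gloss): "Which spinal cord injury level causes loss of
# all sensation at and below the level of the groin?"
# Choices: a. C4, b. T5, c. T10, d. L1, e. conus medullaris.
inference_question = "鼠径部レベル以下の全感覚消失の脊髄損傷レベルはどれか。"
inference_choices = ["a. 第4頸髄。", "b. 第5胸髄。", "c. 第10胸髄。", "d. 第1腰髄。", "e. 脊髄円錐部"]
inference_choices_text = "\n".join(f"- {choice}" for choice in inference_choices)

# Construct the user message content. The Japanese instructions ask for the reasoning
# inside <think></think>, then the answer after "answer:" and a brief explanation after
# "explanation:". Single-choice answers are one lowercase letter, multiple-choice answers
# are letters in a-e order, and numeric questions are answered with a number.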
user_content_inference = f"""以下の医師国家試験問題について、思考過程と解答、簡潔な解説を生成してください。
- 思考過程は<think></think>タグで囲んでください。
- 解答はanswer: の後に続けて、後述の形式に厳密に従うように書いてください。
- 簡潔な解説はexplanation: の後に続けて書いてください。
【answerについて注意】
- 問題は単数選択、複数選択、数値入力のいずれかであり、問題文からその形式を判断する。
- 「どれか。」で終わる選択問題で数が明記されていない場合は、五者択一を意味するので選択肢を必ず1つだけ選び小文字のアルファベットで回答する。（単数選択）
- 「2つ選べ」「3つ選べ」などと書いてある場合に限り、指定された数だけの複数選択肢を選び、小文字のアルファベット順（abcde順）に並び替えて列挙する。（複数選択）
- 選択肢が存在しない場合は、小数や四捨五入など、問題文で特に指示があればそれに従い、選択肢記号ではなく数値を回答する。（数値入力）

問題：
{inference_question}

選択肢：
{inference_choices_text}"""

# Format the prompt using apply_chat_template
messages = [
    {"role": "user", "content": user_content_inference},
    # IMPORTANT: No assistant message here, the model should generate it.
]

# Apply the template, tokenize, and add the prompt for generation
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True, # Crucial for inference - tells the model to generate the next turn
    return_tensors="pt",
).to("cuda")

print("\n--- Generating Response using SFT Model (Chat Format) ---")
# Displaying the prompt part is harder here as it includes special tokens.
# We can decode the input_ids to see roughly what the model sees.
# print(f"Input Tokens (Decoded):\n{tokenizer.decode(inputs[0])}")


# --- Setup TextStreamer ---
from transformers import TextStreamer  # No-op if already imported in an earlier cell

text_streamer = TextStreamer(tokenizer, skip_prompt=True) # Skip the prompt part during streaming

print("\nModel Response:")
# Generate response
with torch.no_grad():
    _ = model.generate(
        input_ids=inputs, # Pass input_ids directly
        streamer=text_streamer,
        max_new_tokens=2048,        # Max tokens for the generated response
        temperature=1.0,            # Sampling temperature (1.0 leaves the distribution unscaled)
        top_p=0.95,                 # Nucleus sampling
        do_sample=True,             # Enable sampling
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id, # Stop at the tokenizer's EOS token
    )
print("\n-----------------------------------------")


Setting up for inference with the final SFT adapter (using chat template)...

--- Generating Response using SFT Model (Chat Format) ---

Model Response:
<think>
1.  **問題の分析**: この問題は、感覚消失の範囲から脊髄損傷のレベルを推定するものである。具体的には、「鼠径部レベル以下の全感覚消失」という記述から、最も上位にある損傷部位を特定する必要がある。
2.  **解剖学的基礎**: 脊髄には感覚神経の出入路があり、各レベルの神経線索は体の特定の部位を支配する。
    *   **頸髄 (C1-C8)**: 上肢および部分的な肩部・項部を支配する。
    *   **胸髄 (T1-T12)**: 体幹前部と上肢後方の高さを支配する。
    *   **腰髄 (L1-L5)**: 下肢を支配する。L1-L3は大腿・鼠径部・会陰部、L4-L5は大腿・脛部・足趾などを支配する。
    *   **尾髄 (S1-S5, 椎体) : 下肢の一部を支配する。
3.  **損傷部位と麻痺範囲の関係**: 脊髄損傷（脊髄損傷は完全切断と仮定する）のレベルにより、麻痺範囲が決まる。上位にある損傷が存在すれば、その下位の感覚は消失する。
    *   例: L1損傷があれば、L1以下（L1, L2, L3, L4, L5, S1~S5）は感覚消失になる。
4.  **症例の分析**: 症例の麻痺は「鼠径部レベル以下の全感覚消失」とある。
    *   鼠径部を支配する神経は L1-L3 である。したがって、L1-L3 の下位（L4-L5, S1~S5）は麻痺範囲内になる。
    *   鼠径部レベル以下（以下）が感覚消失（麻痺）ということは、損傷部位は L1 の下位（L2-L5, S1~S5）であるか、または L1 の上位（C1-C8, T1-T12）である可能性がある。
5.  **解答の推論**: 「鼠径部レベル以下の全感覚消失」というのは、感覚消失の範囲が鼠径部を含む下肢全体（L1-L5, S1~S5）であることを示唆する。最も上位にある損傷部位であると推定されるのは、L1 の上位に存在しない場合
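
The response above stops inside the `<think>` block, before any `answer:` line appears, so it is worth checking generated text for completeness before scoring it. The sketch below is a minimal illustration of such a check based on the output format requested in the prompt. The helpers `parse_response` and `normalize_answer` and both sample strings are hypothetical and are not part of the notebook.

```python
import re

THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
ANSWER_RE = re.compile(r"answer:\s*([^\n]+)")
EXPLANATION_RE = re.compile(r"explanation:\s*(.+)", re.DOTALL)


def normalize_answer(raw):
    """Sort choice letters into a-e order; leave other answers (e.g. numbers) as-is."""
    raw = raw.strip()
    if re.fullmatch(r"[a-e,\s]+", raw):
        return "".join(sorted(set(re.findall(r"[a-e]", raw))))
    return raw


def parse_response(text):
    """Split a response into its answer and explanation.

    If the <think> block was never closed (for example because generation hit
    max_new_tokens), no answer is extracted.
    """
    think_closed = "</think>" in text
    visible = THINK_RE.sub("", text) if think_closed else ""
    answer = ANSWER_RE.search(visible)
    explanation = EXPLANATION_RE.search(visible)
    return {
        "think_closed": think_closed,
        "answer": normalize_answer(answer.group(1)) if answer else None,
        "explanation": explanation.group(1).strip() if explanation else None,
    }


# Hypothetical responses used only to exercise the parser.
complete = "<think>\nreasoning...\n</think>\nanswer: d\nexplanation: short rationale"
truncated = "<think>\nreasoning that never finishes"

print(parse_response(complete))
print(parse_response(truncated))
```

A response that fails this check could be regenerated with a larger `max_new_tokens` or a lower temperature, which is a design choice left to the evaluation setup.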

In [15]:
# ==============================================================================
# Clean Up GPU Memory
# ==============================================================================
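import gc  # No-op if already imported in an earlier cell

def clean_memory():
    """Runs garbage collection and releases PyTorch's cached, unused GPU memory."""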
    print("\nCleaning up GPU memory...")
    # Optional: delete large objects if RAM is also constrained
    # global model, trainer
    # del model
    # del trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("Memory cleaned.")

clean_memory()



Cleaning up GPU memory...
Memory cleaned.


### Upload

In [16]:
# ==============================================================================
# Optional: Merge Adapter and Upload (GGUF, 4-bit)
# ==============================================================================
# Merging adapters requires significant RAM/VRAM as the full model weights are loaded.
# Unsloth provides helpers but be mindful of resource constraints.

# --- Merge and Upload GGUF ---
# Requires llama.cpp to be cloned and built if not done automatically by Unsloth
upload_gguf_merged = True # <<<<<<<<< SET TO TRUE TO ENABLE GGUF UPLOAD

if upload_gguf_merged:
    print("\nAttempting to merge SFT adapter and push GGUF...")
    gguf_repo_id = f"{hf_username}/{new_sft_model_id}-GGUF"
    quantization_method = "q4_k_m" # Common GGUF quantization type
    try:
        # Unsloth's push_to_hub_gguf merges the LoRA weights into a 16-bit copy of the
        # model, converts it to GGUF with llama.cpp, and then quantizes it.
        # This needs substantial RAM and temporary disk space.
        model.push_to_hub_gguf(
            gguf_repo_id,
            tokenizer=tokenizer,
            quantization_method=quantization_method,
            # token=userdata.get("HF_TOKEN") if use_secret else None # Pass token if needed
        )
        print(f"GGUF model pushed successfully to {gguf_repo_id}")
    except Exception as e:
        print(f"\n⚠️ Warning: Failed to create or push GGUF. Error: {e}")
        print("Merging and quantizing can fail due to memory limits or llama.cpp issues.")
        print("Try merging manually or ensure sufficient resources.")

    clean_memory()
    # Clean up llama.cpp build directory if it exists and causes issues
    !rm -rf /content/llama.cpp


Attempting to merge SFT adapter and push GGUF...


Unsloth: Kaggle/Colab has limited disk space. We need to delete the downloaded
model which will save 4-16GB of disk space, allowing you to save on Kaggle/Colab.
Unsloth: Will remove a cached repo with size 65.5G


Unsloth: Merging 4bit and LoRA weights to 16bit...
Unsloth: Will use up to 59.18 out of 83.48 RAM for saving.
Unsloth: Saving model... This might take 5 minutes ...


 23%|██▎       | 15/64 [00:00<00:01, 29.79it/s]
We will save to Disk and not RAM now.
100%|██████████| 64/64 [02:21<00:00,  2.21s/it]


Unsloth: Saving tokenizer... Done.
Done.


Unsloth: Converting qwen2 model. Can use fast conversion = False.


==((====))==  Unsloth: Conversion from QLoRA to GGUF information
   \\   /|    [0] Installing llama.cpp might take 3 minutes.
O^O/ \_/ \    [1] Converting HF to GGUF 16bits might take 3 minutes.
\        /    [2] Converting GGUF 16bits to ['q4_k_m'] might take 10 minutes each.
 "-____-"     In total, you will have to wait at least 16 minutes.

Unsloth: Installing llama.cpp. This might take 3 minutes...
Unsloth: CMAKE detected. Finalizing some steps for installation.
Unsloth: [1] Converting model at doctorin/DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6-GGUF into bf16 GGUF format.
The output location will be /content/doctorin/DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6-GGUF/unsloth.BF16.gguf
This might take 3 minutes...
INFO:hf-to-gguf:Loading model: DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6-GGUF
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Exporting model...
INFO:hf-to-gguf:gguf: loading model weight map from 'model.safetensors.index.json'


unsloth.Q4_K_M.gguf:   0%|          | 0.00/19.9G [00:00<?, ?B/s]

Saved GGUF to https://huggingface.co/doctorin/DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6-GGUF


No files have been modified since last commit. Skipping to prevent empty commit.


Saved Ollama Modelfile to https://huggingface.co/doctorin/DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6-GGUF
GGUF model pushed successfully to doctorin/DeepSeek-R1-Distill-Qwen-32B-Japanese-sft-0.6-GGUF

Cleaning up GPU memory...
Memory cleaned.
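
The size of the uploaded GGUF file gives a rough estimate of the effective bits per weight after `q4_k_m` quantization. The sketch below is a back-of-the-envelope calculation only. It assumes the "19.9G" in the upload log means decimal gigabytes and uses the rounded parameter count from the training banner, so the result is approximate.

```python
GGUF_SIZE_BYTES = 19.9e9        # unsloth.Q4_K_M.gguf size from the upload log (assumed decimal GB)
TOTAL_PARAMS = 32_000_000_000   # rounded parameter count from the training banner
BF16_BYTES_PER_PARAM = 2        # size of one bf16 weight

# Average number of bits stored per weight in the quantized file.
bits_per_weight = GGUF_SIZE_BYTES * 8 / TOTAL_PARAMS

# Size of the intermediate bf16 GGUF and the resulting compression ratio.
bf16_size_bytes = TOTAL_PARAMS * BF16_BYTES_PER_PARAM
compression_ratio = bf16_size_bytes / GGUF_SIZE_BYTES

print(f"Effective bits per weight: {bits_per_weight:.2f}")
print(f"Estimated bf16 size: {bf16_size_bytes / 1e9:.0f} GB")
print(f"Compression vs. bf16: {compression_ratio:.2f}x")
```

A value somewhat above 4 bits is plausible because the file also stores quantization scales and metadata, and the estimate shifts if the true parameter count differs from the rounded figure.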


In [17]:
# --- Merge and Upload 4-bit Merged Model ---
upload_4bit_merged = False # <<<<<<<<< SET TO TRUE TO ENABLE 4-BIT MERGED UPLOAD

if upload_4bit_merged:
    print("\nAttempting to merge SFT adapter and push 4-bit model...")
    merged_4bit_repo_id = f"{hf_username}/{new_sft_model_id}-4bit"
    try:
        # Use Unsloth's helper to merge and quantize to 4-bit directly
        model.push_to_hub_merged(
            merged_4bit_repo_id,
            tokenizer,
            save_method="merged_4bit_forced", # Forces 4-bit quantization after merge
            # token=userdata.get("HF_TOKEN") if use_secret else None # Pass token if needed
        )
        print(f"4-bit merged model pushed successfully to {merged_4bit_repo_id}")
    except Exception as e:
        print(f"\n⚠️ Warning: Failed to create or push 4-bit merged model. Error: {e}")
        print("This process requires significant RAM/VRAM for merging before quantization.")

    clean_memory()


print("\nSFT Script finished.")


SFT Script finished.
