# Hybrid GPT/SSM fine-tuning
Based on: https://unsloth.ai/docs/models/tutorials/ibm-granite-4.0#fine-tuning-granite-4.0-in-unsloth

In [None]:
import os, importlib.util
import platform

# Fail fast on non-Linux hosts. A bare `raise` outside an `except` block would
# only produce "RuntimeError: No active exception to reraise", hiding the real
# reason, so the explanation is carried by the exception itself.
if platform.system() != "Linux":
    raise RuntimeError("mamba_ssm currently only runs on linux")

if importlib.util.find_spec("torch") is None or "COLAB_" in "".join(os.environ.keys()):
    print("Torch missing or running on google colab, installing adjusted versions")
    try: import numpy, PIL; _numpy = f"numpy=={numpy.__version__}"; _pil = f"pillow=={PIL.__version__}"
    except: _numpy = "numpy"; _pil = "pillow"
    !pip install -qqq \
        "torch==2.7.1" "triton>=3.3.0" {_numpy} {_pil} torchvision bitsandbytes "transformers==4.56.2" \
        "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \
        "unsloth[base] @ git+https://github.com/unslothai/unsloth"
elif importlib.util.find_spec("unsloth") is None:
    print("Installing unsloth")
    !pip install -qqq unsloth
    print("Installing secondary dependencies")
    !pip install -qqq --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth_zoo
    print("Installing mamba_ssm")
    # Mamba is supported only on torch==2.7.1. If you have newer torch versions, please wait 30 minutes!
    !pip install -qqq --no-build-isolation mamba_ssm==2.2.5 causal_conv1d==1.5.2 tf_keras

In [None]:
from unsloth import FastLanguageModel
import torch
from utils.gpt_utils import Encoding

# Catalogue of model ids that can be used as `model_name` below. It is kept for
# reference only: nothing in this cell reads the list.
fourbit_models = [
    "unsloth/granite-4.0-micro",
    "unsloth/granite-4.0-h-micro",
    "unsloth/granite-4.0-h-tiny",
    "unsloth/granite-4.0-h-small",
    # Base pretrained Granite 4 models
    "unsloth/granite-4.0-micro-base",
    "unsloth/granite-4.0-h-micro-base",
    "unsloth/granite-4.0-h-tiny-base",
    "unsloth/granite-4.0-h-small-base",
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/Phi-4",
    "unsloth/Llama-3.1-8B",
    "unsloth/Llama-3.2-3B",
    "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit",  # [NEW] We support TTS models!
]  # More models at https://huggingface.co/unsloth

# Load `unsloth/granite-4.0-h-micro` with 4-bit quantization. The returned
# `model` and `tokenizer` are the globals every later cell works with.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/granite-4.0-h-micro",
    max_seq_length=2048,  # Choose any for long context!
    load_in_4bit=True,  # 4 bit quantization to reduce memory
    load_in_8bit=False,  # [NEW!] A bit more accurate, uses 2x memory
    full_finetuning=False,  # [NEW!] We have full finetuning now!
)

In [None]:
# Attach LoRA adapters (rank 32, alpha 32) to the loaded model; `model` is
# rebound to the adapter-wrapped model returned by `get_peft_model`.
model = FastLanguageModel.get_peft_model(
    model,
    r=32,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    # Modules that receive adapters: attention projections, MLP projections and
    # the `shared_mlp` input/output linears.
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "shared_mlp.input_linear",
        "shared_mlp.output_linear",
    ],
    lora_alpha=32,
    lora_dropout=0,  # Supports any, but = 0 is optimized
    bias="none",  # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
    random_state=3407,
    use_rslora=False,  # We support rank stabilized LoRA
    loftq_config=None,  # And LoftQ
)

In [None]:
from datasets import load_dataset, Dataset

# Training data source: a two-column CSV (support snippet, recommendation).
# Use the below shared sheet
# sheet_url = 'https://docs.google.com/spreadsheets/d/1NrjI5AGNIwRtKTAse5TW_hWq2CwAS03qCHif6vaaRh0/export?format=csv&gid=0'

# Or unsloth/Support-Bot-Recommendation
sheet_url = "https://huggingface.co/datasets/unsloth/Support-Bot-Recommendation/raw/main/support_recs.csv"

# Load the CSV as the "train" split and keep only that split. The column names
# are supplied explicitly, and the file's own first row is skipped.
dataset = load_dataset(
    "csv",
    data_files={"train": sheet_url},
    column_names=[
        "snippet",
        "recommendation",
    ],  # Replace with the actual column names of your sheet
    skiprows=1,  # skip header rows
)["train"]
def formatting_prompts_func(examples):
    """Render a batch of (snippet, recommendation) rows as chat-template text.

    `examples` is a batched mapping with parallel "snippet" and
    "recommendation" sequences. Each pair becomes a two-message conversation
    (user -> assistant) that is rendered to a string with the module-level
    `tokenizer`'s chat template, without tokenizing and without a trailing
    generation prompt.

    Returns a dict with a single "text" key holding the rendered strings.
    """
    rendered = []
    for snippet, recommendation in zip(examples["snippet"], examples["recommendation"]):
        conversation = [
            {"role": "user", "content": snippet},
            {"role": "assistant", "content": recommendation},
        ]
        rendered.append(
            tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=False
            )
        )
    return {"text": rendered}
# Narrow the type of `dataset` before calling `.map` on it.
assert isinstance(dataset, Dataset)

# Add the rendered "text" column, processing the rows in batches.
dataset = dataset.map(formatting_prompts_func, batched=True)

# Print one row (index 5) as a sanity check of the raw and rendered fields.
sample = dataset[5]
for column in ("snippet", "recommendation", "text"):
    print(column.capitalize(), sample[column])

In [None]:
from trl.trainer.sft_trainer import SFTTrainer
from trl.trainer.sft_config import SFTConfig
from transformers import PreTrainedTokenizerBase, ProcessorMixin
# Type-narrowing checks on the objects handed to the trainer.
assert(isinstance(dataset, Dataset))
assert(isinstance(tokenizer, PreTrainedTokenizerBase) or isinstance(tokenizer, ProcessorMixin))
# Supervised fine-tuning trainer over the rendered "text" column; no
# evaluation set is configured.
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    eval_dataset=None,  # Can set up evaluation!
    args=SFTConfig(
        dataset_text_field="text",
        # 2 sequences per device x 4 accumulation steps per optimizer update.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # Use GA to mimic batch size!
        warmup_steps=5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps=60,  # short demo run; see num_train_epochs above for a full pass
        learning_rate=2e-4,  # Reduce to 2e-5 for long training runs
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.001,
        lr_scheduler_type="linear",
        seed=3407,
        report_to="none",  # Use TrackIO/WandB etc
    ),
)

In [None]:
from unsloth.chat_templates import train_on_responses_only

# Rewrap the trainer so that only assistant responses are trained on. The two
# strings are the role headers that open a user turn and an assistant turn.
# The cell that decodes `labels` further down shows the resulting -100 entries.
trainer = train_on_responses_only(
    trainer,
    instruction_part="<|start_of_role|>user<|end_of_role|>",
    response_part="<|start_of_role|>assistant<|end_of_role|>",
)

In [None]:
# Decode the input ids of training example 100 back to text; as the last
# expression of the cell, the notebook displays the result.
tokenizer.decode(trainer.train_dataset[100]["input_ids"])

In [None]:
# Decode the labels of the same example. Entries equal to -100 are replaced by
# the pad token id before decoding, and the pad token text is then turned into
# spaces, so those positions show up as blanks in the displayed string.
tokenizer.decode(
    [
        tokenizer.pad_token_id if x == -100 else x
        for x in trainer.train_dataset[100]["labels"]
    ]
).replace(tokenizer.pad_token, " ")

In [None]:
# @title Show current memory stats
# Baseline taken before training: properties of GPU 0 and the peak reserved
# CUDA memory so far, both converted from bytes to GB (1024**3 bytes).
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 1024**3, 3)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
# Run the fine-tuning loop; the returned object's `metrics` dict is read by the
# memory/time report cell below.
trainer_stats = trainer.train()

In [None]:
# @title Show final memory and time stats
# Peak reserved CUDA memory after training, in GB (1024**3 bytes), and the part
# of it that was added on top of the pre-training baseline `start_gpu_memory`.
used_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)

# The same two figures as a share of the GPU's total memory.
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

# Wall-clock training time, in seconds and in minutes.
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")

print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

In [None]:
# @title Test Scenarios
# --- Scenario 1: Video-Conferencing Screen-Share Bug ---
# Support-chat transcript used as the user message of the first inference cell.
# The string holds 9 alternating User/Agent lines and ends on a User line.
# NOTE(review): the heading previously said "11 turns"; presumably the
# transcript was trimmed - confirm.
scenario_1 = """
User: Everyone in my meeting just sees a black screen when I share.
Agent: Sorry about that—are you sharing a window or your entire screen?
User: Entire screen on macOS Sonoma.
Agent: Thanks. Do you have “Enable hardware acceleration” toggled on in Settings → Video?
User: Yeah, that switch is on.
Agent: Could you try toggling it off and start a quick test share?
User: Did that—still black for attendees.
Agent: Understood. Are you on the desktop app v5.4.2 or the browser client?
User: Desktop v5.4.2—just updated this morning.
"""

# --- Scenario 2: Smart-Lock Low-Battery Loop (9 turns) ---
# Support-chat transcript used as the user message of the second inference
# cell. It holds 9 alternating User/Agent lines and ends on a User line.
scenario_2 = """
User: I changed the batteries, but the lock app still says 5 % and won’t auto-lock.
Agent: Let’s check firmware. In the app, go to Settings → Device Info—what version shows?
User: 3.18.0-alpha.
Agent: Latest stable is 3.17.5. Did you enroll in the beta program?
User: I might have months ago.
Agent: Beta builds sometimes misreport battery. Remove one battery, wait ten seconds, reinsert, and watch the LED pattern.
User: LED blinks blue twice, then red once.
Agent: That blink code means “config mismatch.” Do you still have the old batteries handy?
User: Tossed them already.
"""

# --- Scenario 3: Accounting SaaS — Corrupted Invoice Export ---
# Support-chat transcript of 9 alternating User/Agent lines, ending on a User
# line. It is not referenced by the inference cells visible in this notebook.
# NOTE(review): the heading previously said "10 turns"; presumably the
# transcript was trimmed - confirm.
scenario_3 = """
User: Every invoice I download today opens as a blank PDF.
Agent: Is this happening to historic invoices, new ones, or both?
User: Both. Anything I export is 0 bytes.
Agent: Are you exporting through “Bulk Actions” or individual invoice pages?
User: Individual pages.
Agent: Which browser/OS combo?
User: Chrome on Windows 11, latest update.
Agent: We released a new PDF renderer at 10 a.m. UTC. Could you try Edge quickly, just to rule out a caching quirk?
User: Tried Edge—same zero-byte file.
"""

# --- Scenario 4: Fitness-Tracker App — Stuck Step Count ---
# Support-chat transcript of 7 alternating User/Agent lines, ending on a User
# line. It is not referenced by the inference cells visible in this notebook.
# NOTE(review): the heading previously said "8 turns"; presumably the
# transcript was trimmed - confirm.
scenario_4 = """
User: My step count has been frozen at 4,237 since last night.
Agent: Which phone are you syncing with?
User: iPhone 15, iOS 17.5.
Agent: In the Health Permissions screen, does “Motion & Fitness” show as ON?
User: Yes, it’s toggled on.
Agent: When you pull down to refresh the dashboard, does the sync spinner appear?
User: Spinner flashes for a second, then nothing changes.
"""

# --- Scenario 5: Online-Course Platform — Quiz Submission Error ---
# Support-chat transcript of 9 alternating User/Agent lines, ending on a User
# line. It is not referenced by the inference cells visible in this notebook.
# NOTE(review): the heading previously said "12 turns"; presumably the
# transcript was trimmed - confirm.
scenario_5 = """
User: My quiz submits but then shows “Unknown grading error” and resets the answers.
Agent: Which course and quiz name?
User: History 301, Unit 2 Quiz.
Agent: Do you notice a red banner or any code like GR-### in the corner?
User: Banner says “GR-412”.
Agent: That code points to answer-payload size. Were you pasting images or long text into any answers?
User: Maybe a long essay—about 800 words in Question 5.
Agent: Are you on a laptop or mobile?
User: Laptop, Safari on macOS.
"""

In [None]:
# Generate a reply to scenario_1 with the fine-tuned model, using sampling.
FastLanguageModel.for_inference(model)  # Enable native 2x faster inference

# Single-turn conversation: the whole transcript is the user message.
messages = [
    {"role": "user", "content": scenario_1},
]
# Render with the chat template, tokenize to PyTorch tensors and move the
# resulting batch to the GPU.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,  # Must add for generation
    padding=True,
    return_tensors="pt",
    return_dict=True,
).to("cuda")

from transformers import TextStreamer

# Stream tokens to the cell output as they are produced; the prompt is echoed
# as well because skip_prompt is False.
text_streamer = TextStreamer(tokenizer, skip_prompt=False)

# The return value is discarded; the streamer already prints the text.
_ = model.generate(
    **inputs,
    streamer=text_streamer,
    max_new_tokens=512,  # Increase if tokens are getting cut off
    use_cache=True,
    # Adjust the sampling params to your preference
    do_sample=True,
    temperature=0.7,
    top_p=0.8,
    top_k=20,
)

In [None]:
# Generate a reply to scenario_2 with the fine-tuned model, using greedy
# (deterministic) decoding.
FastLanguageModel.for_inference(model)  # Enable native 2x faster inference

# Single-turn conversation: the whole transcript is the user message.
messages = [
    {"role": "user", "content": scenario_2},
]
# Render with the chat template, tokenize to PyTorch tensors and move the
# resulting batch to the GPU.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,  # Must add for generation
    padding=True,
    return_tensors="pt",
    return_dict=True,
).to("cuda")

from transformers import TextStreamer

# Stream tokens to the cell output as they are produced; the prompt is echoed
# as well because skip_prompt is False.
text_streamer = TextStreamer(tokenizer, skip_prompt=False)

# The return value is discarded; the streamer already prints the text.
_ = model.generate(
    **inputs,
    streamer=text_streamer,
    max_new_tokens=512,  # Increase if tokens are getting cut off
    use_cache=True,
    # Greedy decoding. temperature / top_p / top_k only take effect when
    # do_sample=True, so they are not passed here (passing them alongside
    # do_sample=False has no effect on the output and only triggers warnings).
    # To sample instead, set do_sample=True and add e.g. temperature=0.7,
    # top_p=0.8, top_k=20.
    do_sample=False,
)

In [None]:
# Save the model and the tokenizer into the local "granite_lora" directory,
# the same path the optional reload example in the next cell reads from.
model.save_pretrained("granite_lora")  # Local saving
tokenizer.save_pretrained("granite_lora")

In [None]:
# Optional save / upload recipes. Every section is disabled with `if False:`;
# change a guard to `if True:` to run that section. Replace HF_USERNAME and
# YOUR_HF_TOKEN with your own values before pushing to the Hub.

# Reload the adapters saved in "granite_lora" (e.g. in a fresh session)
if False:
    from unsloth import FastLanguageModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="granite_lora",  # YOUR MODEL YOU USED FOR TRAINING
        max_seq_length=2048,
        load_in_4bit=True,
    )
# Merge to 16bit
if False:
    model.save_pretrained_merged(
        "granite_finetune_16bit",
        tokenizer,
        save_method="merged_16bit",
    )
if False:  # Pushing to HF Hub
    model.push_to_hub_merged(
        "HF_USERNAME/granite_finetune_16bit",
        tokenizer,
        save_method="merged_16bit",
        token="YOUR_HF_TOKEN",
    )

# Merge to 4bit
if False:
    model.save_pretrained_merged(
        "granite_finetune_4bit",
        tokenizer,
        save_method="merged_4bit",
    )
if False:  # Pushing to HF Hub
    model.push_to_hub_merged(
        "HF_USERNAME/granite_finetune_4bit",
        tokenizer,
        save_method="merged_4bit",
        token="YOUR_HF_TOKEN",
    )

# Just LoRA adapters
if False:
    model.save_pretrained("granite_lora")
    tokenizer.save_pretrained("granite_lora")
if False:  # Pushing to HF Hub
    model.push_to_hub("HF_USERNAME/granite_lora", token="YOUR_HF_TOKEN")
    tokenizer.push_to_hub("HF_USERNAME/granite_lora", token="YOUR_HF_TOKEN")

## GGUF / llama.cpp conversion

In [None]:
# Optional GGUF export recipes. Every section is disabled with `if False:`;
# change a guard to `if True:` to run that section.

# Save to 8bit Q8_0
if False:
    model.save_pretrained_gguf(
        "granite_finetune",
        tokenizer,
    )
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And replace HF_USERNAME with your Hugging Face username!
if False:
    model.push_to_hub_gguf(
        "HF_USERNAME/granite_finetune", tokenizer, token="YOUR_HF_TOKEN"
    )

# Save to 16bit GGUF
if False:
    model.save_pretrained_gguf("granite_finetune", tokenizer, quantization_method="f16")
if False:  # Pushing to HF Hub
    model.push_to_hub_gguf(
        "HF_USERNAME/granite_finetune",
        tokenizer,
        quantization_method="f16",
        token="YOUR_HF_TOKEN",
    )

# Save to q4_k_m GGUF
if False:
    model.save_pretrained_gguf(
        "granite_finetune", tokenizer, quantization_method="q4_k_m"
    )
if False:  # Pushing to HF Hub
    model.push_to_hub_gguf(
        "HF_USERNAME/granite_finetune",
        tokenizer,
        quantization_method="q4_k_m",
        token="YOUR_HF_TOKEN",
    )

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "HF_USERNAME/granite_finetune",  # Replace HF_USERNAME with your username!
        tokenizer,
        quantization_method=[
            "q4_k_m",
            "q8_0",
            "q5_k_m",
        ],
        token="YOUR_HF_TOKEN",  # Get a token at https://huggingface.co/settings/tokens
    )