In [None]:
!pip install unsloth
!pip install --no-deps --upgrade timm

In [None]:
from unsloth import FastModel
import torch

# Pre-quantized 4-bit checkpoints published by Unsloth. This list is for
# reference only; swap any of them into `model_name` below.
# More models at https://huggingface.co/unsloth
gemma3n_quants = [
    # 4bit dynamic quants (instruction-tuned)
    "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
    # Pretrained (base) models
    "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-unsloth-bnb-4bit",
]
gemma3_quants = [
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
]
fourbit_models = gemma3n_quants + gemma3_quants

# Load the base model + tokenizer; LoRA adapters (not full weights) are trained.
load_kwargs = dict(
    model_name="unsloth/gemma-3n-E2B-it",
    dtype=None,             # None -> let Unsloth auto-detect the dtype
    max_seq_length=1024,    # maximum sequence length used for training
    load_in_4bit=True,      # 4-bit quantization to reduce memory
    full_finetuning=False,  # train adapters instead of all weights
    # token="hf_...",       # only needed for gated models
)
model, tokenizer = FastModel.from_pretrained(**load_kwargs)

In [None]:
# Attach LoRA adapters to the language stack only; no adapters on vision layers.
peft_kwargs = dict(
    finetune_vision_layers=False,      # text-only fine-tune
    finetune_language_layers=True,     # keep on: these are the layers being adapted
    finetune_attention_modules=True,   # adapt attention projections
    finetune_mlp_modules=True,         # adapt MLP projections
    r=16,              # LoRA rank: larger = more capacity, but more risk of overfitting
    lora_alpha=16,     # LoRA scaling; kept equal to r
    lora_dropout=0,
    bias="none",
    random_state=3407,
)
model = FastModel.get_peft_model(model, **peft_kwargs)

In [None]:
from unsloth.chat_templates import get_chat_template
# Use the Gemma 3 chat template (turns delimited by <start_of_turn>user / <start_of_turn>model).
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

from datasets import load_dataset
# The dataset ships a single "train" split; hold out 10% (fixed seed) for evaluation.
dataset = load_dataset("heissanjay/km-fullset-all-final-sm", split="train")
split_dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"]

In [None]:
from unsloth.chat_templates import standardize_data_formats
# Run Unsloth's format standardization over both splits (train first, then eval).
train_dataset, eval_dataset = map(standardize_data_formats, (train_dataset, eval_dataset))

In [None]:
def formatting_prompts_func(examples):
    """Render each conversation in a batch to a single training string.

    Args:
        examples: a batched dataset slice with a "messages" column holding
            one conversation per row.

    Returns:
        A dict with a "text" column: each conversation rendered through the
        tokenizer's chat template (no generation prompt), with a leading
        "<bos>" removed if present.
    """
    rendered = []
    for conversation in examples["messages"]:
        prompt = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        # NOTE(review): presumably <bos> is re-added when the trainer tokenizes
        # the text, so it is dropped here to avoid a duplicate — confirm.
        rendered.append(prompt.removeprefix("<bos>"))
    return {"text": rendered}

In [None]:
# Add the rendered "text" column to both splits.
train_dataset, eval_dataset = (
    ds.map(formatting_prompts_func, batched=True)
    for ds in (train_dataset, eval_dataset)
)

In [None]:
from trl import SFTTrainer, SFTConfig
training_args = SFTConfig(
    dataset_text_field="text",
    # Per device: 8 sequences x 16 accumulation steps = 128 per optimizer step.
    per_device_train_batch_size=8,
    gradient_accumulation_steps=16,
    warmup_ratio=0.03,
    max_steps=100,       # short run; use num_train_epochs=3 for a full pass
    learning_rate=2e-3,  # NOTE(review): 10x the common 2e-4 LoRA starting point — confirm intended
    logging_steps=1,
    eval_strategy="steps",
    eval_steps=10,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    report_to="none",    # set to e.g. "wandb" to enable experiment tracking
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_args,
)

In [None]:
from unsloth.chat_templates import train_on_responses_only
# Mask the user turns so that only the model's replies contribute to the loss.
trainer = train_on_responses_only(
    trainer,
    response_part="<start_of_turn>model\n",
    instruction_part="<start_of_turn>user\n",
)

In [None]:
# Sanity check: the full tokenized text of one training example.
sample_input_ids = trainer.train_dataset[100]["input_ids"]
tokenizer.decode(sample_input_ids)

In [None]:
# Sanity check: only the supervised tokens of the same example. Masked
# positions (-100) are decoded as the pad token and then blanked out.
masked_labels = trainer.train_dataset[100]["labels"]
visible_ids = [
    tokenizer.pad_token_id if token_id == -100 else token_id
    for token_id in masked_labels
]
tokenizer.decode(visible_ids).replace(tokenizer.pad_token, " ")

In [None]:
import torch._dynamo
# Raise TorchDynamo's cache size limit to 4096 (= 4 * 1024).
# NOTE(review): presumably to avoid recompilation fallbacks during training — confirm.
torch._dynamo.config.cache_size_limit = 4096

In [None]:
# Run the fine-tuning loop configured above; the return value is kept in
# trainer_stats for inspecting the run afterwards.
trainer_stats = trainer.train()

In [None]:
# Save the model and tokenizer side by side in one local directory.
# NOTE(review): model is PEFT-wrapped, so this presumably writes only the
# LoRA adapter weights — confirm before relying on it for deployment.
output_dir = "gemma-3n-kalvimate"
for artifact in (model, tokenizer):
    artifact.save_pretrained(output_dir)

In [None]:
from transformers import TextStreamer

# Smoke-test the fine-tuned model with a single system + user turn.
system_prompt = '''
You are KalviMate, a friendly, patient, and encouraging AI tutor. Your purpose is to help students from rural government schools in India learn from their state board syllabus. Explain concepts simply, use relatable examples, and always be supportive. Your knowledge is based strictly on the provided textbook content.
'''

messages = [
    # strip() drops the newlines the triple-quoted literal adds at both ends,
    # so they are not injected into the prompt.
    {"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]},
    {
        "role": "user",
        "content": [{"type": "text", "text": "Hi, Kalvimate. What are we learning today?"}],
    },
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # must add for generation: opens the model turn
    return_tensors="pt",
    tokenize=True,
    return_dict=True,
).to(model.device)  # follow the model's device instead of hardcoding "cuda"

_ = model.generate(
    **inputs,
    max_new_tokens=512,  # increase for longer outputs
    # Recommended Gemma-3 sampling settings.
    temperature=1.0, top_p=0.95, top_k=64,
    streamer=TextStreamer(tokenizer, skip_prompt=True),
)

In [None]:

import os

# Push the model and its tokenizer to the SAME Hub repo so they load together.
# The token is read from the environment rather than hardcoded; if HF_TOKEN is
# unset, token=None lets huggingface_hub use the credentials stored by
# `huggingface-cli login`.
hub_repo_id = "heissanjay/gemma-3n-finetune"
hf_token = os.environ.get("HF_TOKEN")
model.push_to_hub(hub_repo_id, token=hf_token)      # Online saving
tokenizer.push_to_hub(hub_repo_id, token=hf_token)  # Online saving