### Introduction

In this notebook we will fine-tune an LLM for the following task:

- A Gemma model was given a text and a rewrite prompt, such as "Rewrite this text as a sea shanty."
- We are given the original and the rewritten text, and we want to fine-tune our model to reconstruct the prompt.
- I curated a dataset, GemmaPromptPrediction, from several publicly available datasets, with a diverse set of texts, prompts, and rewritten texts. We will train the model on this dataset.

### Required imports

In [1]:
%%capture
!pip install -U xformers --index-url https://download.pytorch.org/whl/cu121
!pip install "unsloth[kaggle-new] @ git+https://github.com/unslothai/unsloth.git"

# Temporary fix for https://github.com/huggingface/datasets/issues/6753
!pip install datasets==2.16.0 fsspec==2023.10.0 gcsfs==2023.10.0

import os
os.environ["WANDB_DISABLED"] = "true"

### We will use Mistral-7B (without instruction-tuning)

In [2]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 4096 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/mistral-7b-bnb-4bit",
    "unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
    "unsloth/llama-2-7b-bnb-4bit",
    "unsloth/llama-2-13b-bnb-4bit",
    "unsloth/codellama-34b-bnb-4bit",
    "unsloth/tinyllama-bnb-4bit",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/mistral-7b-bnb-4bit", # Choose ANY! eg teknium/OpenHermes-2.5-Mistral-7B
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


2025-04-24 20:04:19.234410: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745525059.634741      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745525059.755502      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Unsloth: Failed to patch Gemma3ForConditionalGeneration.
🦥 Unsloth Zoo will now patch everything to make training faster!


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).

==((====))==  Unsloth 2025.3.19: Fast Mistral patching. Transformers: 4.51.1.
   \\   /|    Tesla T4. Num GPUs = 2. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu121. CUDA: 7.5. CUDA Toolkit: 12.1. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!



### LoRA settings (rank = 32)

Comment: A lower rank (4 or 8) would probably work better. We are only training on ~8,000 completions of ~20 tokens each, so r = 32 is likely to overfit.

In [3]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    use_gradient_checkpointing = True,
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth 2025.3.19 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
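As a sanity check on the rank discussion above, the number of trainable LoRA parameters can be computed by hand: each adapted linear layer with `in` inputs and `out` outputs gets two low-rank matrices, A (r × in) and B (out × r), i.e. r · (in + out) parameters. The sketch below assumes Mistral-7B's shapes (hidden size 4096, 8 key/value heads of dimension 128, MLP size 14336, 32 layers); `lora_param_count` and `lora_scaling` are hypothetical helpers, not Unsloth or PEFT functions. For r = 32 it reproduces the 83,886,080 trainable parameters reported in the training log, and at 4 bytes per float32 weight that is about 336 MB, consistent with the adapter upload later in the notebook.

```python
# Assumed Mistral-7B shapes: hidden size 4096, 8 KV heads x 128 dims = 1024,
# MLP (intermediate) size 14336, 32 decoder layers.
HIDDEN, KV_DIM, MLP, LAYERS = 4096, 1024, 14336, 32

# (in_features, out_features) of each adapted linear layer in one decoder block.
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}

def lora_param_count(r: int) -> int:
    """Hypothetical helper: A is r x in, B is out x r, so r * (in + out) per layer."""
    per_block = sum(r * (d_in + d_out) for d_in, d_out in TARGET_SHAPES.values())
    return per_block * LAYERS

def lora_scaling(alpha: float, r: int, use_rslora: bool = False) -> float:
    """Hypothetical helper: LoRA scales its update by alpha / r (alpha / sqrt(r) with rsLoRA)."""
    return alpha / (r ** 0.5) if use_rslora else alpha / r

for r in (8, 32):
    n = lora_param_count(r)
    # float32 adapter weights take 4 bytes per parameter.
    print(f"r={r}: {n:,} trainable params, ~{n * 4 / 1e6:.0f} MB in fp32")
print("scaling for r=32, alpha=16:", lora_scaling(16, 32))
```

With `lora_alpha = 16` and `r = 32`, the update is scaled by 0.5; dropping to r = 8 would also cut the trainable parameters by 4×.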


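### We add a question-answer template

The template frames the task as a question followed by the start of an answer ("The prompt is: "), so the base model already assigns a fairly high probability to a sensible answer before training. This gives the fine-tuning a head start, which should make the final model more accurate.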

In [6]:
import pandas as pd
df = pd.read_csv("/kaggle/input/gemmapromptprediction/sft_df.csv")
# Truncate both texts to 1,200 characters to keep sequences short.
df['original_text'] = df['original_text'].str.slice(0, 1200)
df['preprocessed_text'] = df['preprocessed_text'].str.slice(0, 1200)
# Training string: question, both texts, then the answer terminated by </s>.
df['prompt'] = ('### Question: What is a possible rewrite prompt that turns the first text into the second?\n'
                + '<first_text>' + df['original_text'] + '<end_of_first_text>\n'
                + '<second_text>' + df['preprocessed_text'] + '<end_of_second_text>\n'
                + '### Answer: The prompt is: ' + df['rewrite_prompt'] + '</s>')
print(df['prompt'][0])

### Question: What is a possible rewrite prompt that turns the first text into the second?
<first_text>FARGO-Law enforcement agencies in Fargo, Moorhead and surrounding communities are launching a new multi-agency street crimes unit to battle growing crime concerns in the metro area, Fargo police said Tuesday.

Fargo police Lt. Mike Mitchell said law enforcement officials plan to announce the new unit at a news conference set for Thursday afternoon.

North Dakota Attorney General Wayne Stenehjem is expected to join acting Fargo Police Chief Dave Todd for the news conference, along with North Dakota Bureau of Criminal Investigation Director Dallas Carlson, West Fargo Police Chief Mike Reitan, Moorhead Police Chief David Ebinger and Cass County and Clay County sheriffs Paul Laney and Bill Bergquist.

The news conference, which will be held in the Community Room at the Main Library, 102 3rd St. N., is designed to discuss the formation of the new, multi-agency unit being developed to addre

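### We train the model only on completions

This means that only the tokens after the response template (the rewrite prompt and the closing `</s>`) are included in the loss.

The response template is given as token IDs, [26307, 28747, 415, 11510, 349, 28747], which appear to encode "Answer: The prompt is:" as it is tokenized inside the full prompt. IDs are more reliable than a string here, because a SentencePiece tokenizer can tokenize the same phrase differently in isolation than in context.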
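The masking idea can be sketched in a few lines of plain Python. This is a simplified illustration, not TRL's actual implementation (which works on padded tensor batches): every label up to and including the first occurrence of the response template is set to -100, the default `ignore_index` of PyTorch's cross-entropy loss, so only the completion contributes to the loss. `mask_prompt_tokens` is a hypothetical helper, and apart from the template itself the token IDs in the example are made up (1 and 2 stand in for BOS and EOS).

```python
IGNORE_INDEX = -100  # labels with this value are skipped by PyTorch's cross-entropy loss

def mask_prompt_tokens(input_ids: list[int], response_template: list[int]) -> list[int]:
    """Return labels where everything up to and including the first match of
    response_template is IGNORE_INDEX, so only the completion is trained on."""
    n = len(response_template)
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == response_template:
            end = start + n
            return [IGNORE_INDEX] * end + input_ids[end:]
    # Template not found (e.g. the example was truncated): ignore the whole example.
    return [IGNORE_INDEX] * len(input_ids)

response_template = [26307, 28747, 415, 11510, 349, 28747]
# Toy sequence: BOS, two question tokens, the template, two answer tokens, EOS.
ids = [1, 10, 11] + response_template + [50, 51, 2]
print(mask_prompt_tokens(ids, response_template))
# → [-100, -100, -100, -100, -100, -100, -100, -100, -100, 50, 51, 2]
```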

In [7]:
from datasets import Dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = Dataset.from_pandas(df)

# Token IDs of the response template (these appear to encode "Answer: The prompt is:").
# The collator masks every label up to and including it, so only the completion is trained on.
response_template = [26307, 28747, 415, 11510, 349, 28747]
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

def formatting_prompts_func(example):
    # Called on a batch of rows: return one full training string per row.
    return list(example['prompt'])


### Set optimizer hyperparameters
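(Effective batch size = `per_device_train_batch_size` × `gradient_accumulation_steps`; the training log below reports 4 × 8 = 32 samples per optimizer step.)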
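The training log below reports 8,077 examples, a total batch size of 4 × 8 = 32 and 252 total steps. Assuming a single GPU and that the Trainer floor-divides the number of dataloader batches by the accumulation steps (a simplification of its internal bookkeeping), the step count can be reproduced as follows; `optimizer_steps_per_epoch` is a hypothetical helper.

```python
import math

def optimizer_steps_per_epoch(num_examples: int, per_device_batch: int, grad_accum: int) -> int:
    """Single-GPU approximation: dataloader batches per epoch // accumulation steps."""
    batches = math.ceil(num_examples / per_device_batch)
    return batches // grad_accum

print(optimizer_steps_per_epoch(8077, 4, 8))  # 4 x 8 = 32 samples per step
print(optimizer_steps_per_epoch(8077, 2, 8))  # 2 x 8 = 16 samples per step
```

With a per-device batch of 2, the same epoch would take 504 steps of 16 samples each instead of 252 steps of 32.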

In [8]:
from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    formatting_func=formatting_prompts_func,
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # True can speed up short sequences, but DataCollatorForCompletionOnlyLM requires packing = False.
    data_collator=collator,
    args = TrainingArguments(
        per_device_train_batch_size = 4, # 4 x 8 = 32 samples per step, as reported in the training log
        gradient_accumulation_steps = 8,
        warmup_steps = 5,
        num_train_epochs=1,
        learning_rate = 1e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",
    ),
)

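With `warmup_steps = 5`, `learning_rate = 1e-4` and `lr_scheduler_type = "linear"`, the learning rate ramps up linearly over the first 5 steps and then decays linearly to zero at the last step. The sketch below mirrors the multiplier used by Transformers' `get_linear_schedule_with_warmup`; `linear_lr` is a hypothetical helper, and the 252 total steps are taken from the training log.

```python
def linear_lr(step: int, base_lr: float = 1e-4, warmup_steps: int = 5,
              total_steps: int = 252) -> float:
    """Learning rate at an optimizer step: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * (step / max(1, warmup_steps))
    remaining = max(0, total_steps - step)
    return base_lr * (remaining / max(1, total_steps - warmup_steps))

for step in (0, 1, 5, 126, 252):
    print(step, linear_lr(step))
```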

### GPU memory stats

In [9]:
#@title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = Tesla T4. Max memory = 14.741 GB.
6.883 GB of memory reserved.


### Start training

In [10]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 8,077 | Num Epochs = 1 | Total steps = 252
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 8 x 1) = 32
 "-____-"     Trainable parameters = 83,886,080/7,000,000,000 (1.20% trained)


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training loss |
|-----:|--------------:|
| 1 | 2.8676 |
| 2 | 3.1437 |
| 3 | 2.7244 |
| 4 | 2.382 |
| 5 | 2.3432 |
| 6 | 1.9136 |
| 7 | 2.2263 |
| 8 | 1.8419 |
| 9 | 1.9126 |
| 10 | 1.8443 |


In [11]:
model.push_to_hub("eruzak/mistral_predict_prompt_sft_final", token = "YOUR_TOKEN") 


Saved model to https://huggingface.co/eruzak/mistral_predict_prompt_sft_final


### Load validation dataset

In [25]:
from transformers import TextStreamer

FastLanguageModel.for_inference(model) # Enable native 2x faster inference
df_val=pd.read_csv("/kaggle/input/gemmapromptprediction/eval_df.csv")
df_val['original_text'] = df_val['original_text'].str.slice(0, 1200)
df_val['preprocessed_text'] = df_val['preprocessed_text'].str.slice(0, 1200)
df_val['pre_prompt'] = '### Question: What is a possible rewrite prompt that turns the first text into the second?\n<first_text>' + df_val['original_text'] + '<end_of_first_text>\n<second_text>' + df_val['preprocessed_text'] + '<end_of_second_text>\n### Answer: The prompt is: '
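The validation prefix above repeats the training template by hand, minus the answer and `</s>`. Since the model only works well if the two strings match exactly, one option is a single hypothetical helper, `build_prompt`, used for both. This is a sketch that assumes plain strings (no missing values) and applies the same 1,200-character truncation as `.str.slice(0, 1200)`.

```python
from typing import Optional

QUESTION = "### Question: What is a possible rewrite prompt that turns the first text into the second?\n"
ANSWER_PREFIX = "### Answer: The prompt is: "

def build_prompt(original: str, rewritten: str, rewrite_prompt: Optional[str] = None,
                 max_chars: int = 1200) -> str:
    """Training string (answer + </s>) if rewrite_prompt is given, otherwise the
    inference prefix ending right after 'The prompt is: '."""
    text = (QUESTION
            + "<first_text>" + original[:max_chars] + "<end_of_first_text>\n"
            + "<second_text>" + rewritten[:max_chars] + "<end_of_second_text>\n"
            + ANSWER_PREFIX)
    if rewrite_prompt is not None:
        text += rewrite_prompt + "</s>"
    return text

train_text = build_prompt("The cat sat.", "The feline was seated.", "Make it formal")
infer_text = build_prompt("The cat sat.", "The feline was seated.")
print(train_text.startswith(infer_text))  # same prefix at training and inference time
```

Applied to a DataFrame, this would look like `[build_prompt(o, r) for o, r in zip(df_val['original_text'], df_val['preprocessed_text'])]`.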

### Validation

In [29]:
for n in range(10):
    model_inputs = tokenizer([df_val['pre_prompt'][n]], return_tensors="pt")
    input_ids = model_inputs.input_ids.to('cuda')
    # Print the prompt and the sampled answer as tokens are generated.
    text_streamer = TextStreamer(tokenizer)
    # Generation normally stops at </s>; max_new_tokens is only an upper bound.
    _ = model.generate(input_ids, streamer = text_streamer, do_sample=True, max_new_tokens = 5000)
    print(f"Original prompt: {df_val['rewrite_prompt'][n]}")
    print("_____________________________________________\n\n")

<s> ### Question: What is a possible rewrite prompt that turns the first text into the second?
<first_text>Archaeologists in Egypt have unearthed the tomb of a previously unknown queen, Egyptian officials say.<end_of_first_text>
<second_text>At [Detective Agency Name], we believe that every case is a puzzle waiting to be solved, and we're experts at piecing together even the most complex mysteries. From the rich history of the tomb discovered in Abu-Sir, to the secrets hidden within the walls of its ancient civilization, we are here to help you uncover the truth.

**Call us today for a free consultation and let us turn your mystery into a revelation.**

**[Detective Agency Name]**

**[Phone Number]**
**[Email Address]**<end_of_second_text>
### Answer: The prompt is:  Change it into a detective agency's ad.</s>
Original prompt: Make the text into a detective agency advertisement
_____________________________________________


<s> ### Question: What is a possible rewrite prompt that turn
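The streamed output contains the whole sequence, so comparing predictions with `rewrite_prompt` programmatically requires cutting out the answer. A minimal sketch, assuming the generated IDs have already been decoded to a string (for example with `tokenizer.decode`) that keeps the `</s>` marker; `extract_predicted_prompt` is a hypothetical helper.

```python
from typing import Optional

ANSWER_MARKER = "### Answer: The prompt is:"

def extract_predicted_prompt(generated: str) -> Optional[str]:
    """Text between the answer marker and the first </s>, stripped; None if no marker."""
    _, sep, tail = generated.partition(ANSWER_MARKER)
    if not sep:
        return None
    return tail.split("</s>", 1)[0].strip()

sample = ("<s> ### Question: ...<end_of_second_text>\n"
          "### Answer: The prompt is:  Change it into a detective agency's ad.</s>")
print(extract_predicted_prompt(sample))  # Change it into a detective agency's ad.
```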

### Comments

This solution would not achieve a high score on the competition leaderboard. The leaderboard metric is based on a sentence-embedding similarity model, and because that model is not very robust, the top solutions exploited the metric: for example, by returning "Improve the text to this." for every sample, or by predicting the prompt and appending the string "lucrarea" many times to it.
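The metric itself can be sketched as follows. As an assumption (the text above only says it is embedding-based), this uses a *sharpened* cosine similarity: the cosine between the predicted-prompt and true-prompt embeddings raised to the power 3, averaged over samples. The vectors are toy values, not real sentence embeddings, and the function names are hypothetical.

```python
import math

def cosine(u: list[float], v: list[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def sharpened_cosine(u: list[float], v: list[float], p: int = 3) -> float:
    """Cosine similarity raised to the power p (p = 3 is an assumption)."""
    return cosine(u, v) ** p

def mean_score(preds: list[list[float]], targets: list[list[float]]) -> float:
    """Average sharpened similarity over (prediction, target) embedding pairs."""
    return sum(sharpened_cosine(p, t) for p, t in zip(preds, targets)) / len(preds)

print(sharpened_cosine([0.8, 0.6], [1.0, 0.0]))  # cosine 0.8 shrinks to about 0.512
```

Sharpening rewards near-exact matches much more than partial ones, which is why the choice of what to output under uncertainty matters so much for this metric.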

Another limitation is that this model is trained with SFT, so it learns to sample a plausible prompt conditioned on the two texts. A high-scoring model would instead be trained with RL, or another method, to give the answer that maximizes the competition metric; for example, it would output "Improve the text to this" whenever it is not confident about the answer, instead of sampling a random reasonable prompt.
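One simple alternative to RL is minimum Bayes risk (MBR) selection, a swapped-in technique that this notebook does not use: sample several prompts from the SFT model, then return the candidate with the highest average metric score against those samples, treating them as stand-ins for the unknown true prompt. In the toy sketch below, the 4-D embeddings are made up (three specific prompts share a common direction plus one of their own, and the generic prompt lies along the shared direction), and the metric is the cubed cosine assumed in the previous example. When the samples disagree, the generic prompt wins; when they agree, the specific one does.

```python
import math
from typing import Callable, Sequence

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def mbr_select(candidates: Sequence[str], samples: Sequence[str],
               similarity: Callable[[str, str], float]) -> str:
    """Return the candidate with the highest average similarity to the sampled prompts."""
    return max(candidates, key=lambda c: sum(similarity(c, s) for s in samples) / len(samples))

# Made-up embeddings for illustration only.
EMB = {
    "Rewrite it as a sea shanty": [2.0, 1.0, 0.0, 0.0],
    "Rewrite it as a haiku":      [2.0, 0.0, 1.0, 0.0],
    "Rewrite it as a limerick":   [2.0, 0.0, 0.0, 1.0],
    "Improve the text":           [1.0, 0.0, 0.0, 0.0],
}

def sharpened_sim(a: str, b: str, p: int = 3) -> float:
    return cosine(EMB[a], EMB[b]) ** p

samples = ["Rewrite it as a sea shanty", "Rewrite it as a haiku", "Rewrite it as a limerick"]
print(mbr_select(list(EMB), samples, sharpened_sim))      # Improve the text
print(mbr_select(list(EMB), samples[:1], sharpened_sim))  # Rewrite it as a sea shanty
```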