# Botanist GRPO Training - Visual Flower Description

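This notebook fine-tunes a Gemma-3n model with GRPO (Group Relative Policy Optimization) to improve its botanical descriptions of flowers. The model is prompted with a species name and must answer with a JSON description of the flower's visual features (color, inflorescence type, arrangement, density) plus species, genus and family. Rewards come from three sources: an LLM judge, a JSON-format check, and ground-truth species / inflorescence-type matching. Training prompts are text-only; no images are passed to the model in this notebook.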

## 1. Setup and Imports

In [None]:

import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --upgrade --force-reinstall bitsandbytes accelerate xformers==0.0.31.post1 peft  triton cut_cross_entropy unsloth_zoo unsloth torchvision
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" huggingface_hub hf_transfer together cerebras_cloud_sdk groq
!pip install --upgrade --force-reinstall transformers
!pip install --no-deps --upgrade timm trl==0.19.1 # Only for Gemma 3N


Collecting bitsandbytes
  Downloading bitsandbytes-0.46.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting accelerate
  Downloading accelerate-1.9.0-py3-none-any.whl.metadata (19 kB)
Collecting xformers==0.0.31.post1
  Downloading xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (1.1 kB)
Collecting peft
  Downloading peft-0.17.0-py3-none-any.whl.metadata (14 kB)
Collecting triton
  Downloading triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.7 kB)
Collecting cut_cross_entropy
  Downloading cut_cross_entropy-25.1.1-py3-none-any.whl.metadata (9.3 kB)
Collecting unsloth_zoo
  Downloading unsloth_zoo-2025.8.1-py3-none-any.whl.metadata (8.1 kB)
Collecting unsloth
  Downloading unsloth-2025.8.1-py3-none-any.whl.metadata (47 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.3/47.3 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision
  Downloading torchvision-0.22.1-cp311-cp311-manyli

In [None]:
import os

# `!export` runs in a throwaway subshell and never reaches this Python process.
# Set the variable on os.environ *before* huggingface_hub is imported so it is picked up.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from google.colab import drive, userdata
from huggingface_hub import HfApi, login

# Checkpoints are written to Google Drive so they survive a runtime reset.
drive.mount('/content/drive')
# NOTE(review): folder name says "qwen" although the model is Gemma-3n; kept so
# existing checkpoints in this folder are still found.
output_dir = "/content/drive/MyDrive/colab_output/qwen-botanist"
os.makedirs(output_dir, exist_ok=True)

# Hugging Face token from Colab secrets; also reused later when pushing the model.
HF_TOKEN = userdata.get('HF_TOKEN')
login(token=HF_TOKEN)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import json
import torch
import re
import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
import pandas as pd
from tqdm import tqdm

# Unsloth and TRL imports
from unsloth import FastLanguageModel
from unsloth import is_bf16_supported
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import TrainingArguments
from google.colab import userdata

# Together AI for reward evaluation

import os
#from cerebras.cloud.sdk import Cerebras
#from groq import Groq
#client = Cerebras(api_key=userdata.get("CEREBRAS_API_KEY"))
#client = Groq(api_key=userdata.get("GROQ_API_KEY"))
#together_client = Together(api_key=userdata.get("TOGETHER_API_KEY"))
from openai import OpenAI
# OpenRouter exposes an OpenAI-compatible chat-completions API, so the stock OpenAI
# client is simply pointed at it. botanist_rw uses this client as the LLM judge.
client = OpenAI(
    api_key=userdata.get("OPENROUTER_API_KEY"),
    base_url="https://openrouter.ai/api/v1",
)

# Fixed seeds for the numpy and torch global RNGs.
np.random.seed(42)
torch.manual_seed(42)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

🦥 Unsloth Zoo will now patch everything to make training faster!


## 2. Configuration

In [None]:
# Model configuration
MODEL_NAME = "mekpro/gemma-3n-botanist-observe6-merged"  # starting checkpoint (our model)
#MODEL_NAME = "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit"

MAX_SEQ_LENGTH = 512     # context length given to Unsloth at load time
MAX_NEW_TOKENS = 350     # GRPO max_completion_length, and max_new_tokens when testing
MAX_PROMPT_LENGTH = 150  # GRPO max_prompt_length

LOAD_IN_4BIT = False  # False -> Unsloth trains a 16-bit LoRA (see the load log)

# PEFT (LoRA) configuration
LORA_R = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.1  # Unsloth's fast patching needs 0; 0.1 works but is slower (see load log)
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

# Training configuration
BATCH_SIZE = 32  # per_device_train_batch_size (in TRL's GRPO this counts completions, not prompts)
GRADIENT_ACCUMULATION_STEPS = 1
LEARNING_RATE = 1e-5
NUM_TRAIN_EPOCHS = 1
WARMUP_STEPS = 20
LOGGING_STEPS = 1
SAVE_STEPS = 100
OUTPUT_DIR = output_dir  # Drive folder created in the setup cell

# GRPO configuration
NUM_GENERATION_PER_PROMPT = 8  # Number of responses to generate per prompt
TEMPERATURE = 1.1
TOP_P = 0.9

# Reward weights: each reward function returns a value in [0, its weight].
BOTANIST_REWARD_WEIGHT = 0.5
FORMAT_REWARD_WEIGHT = 0.15
SPECIES_INFLORESCENCETYPE_WEIGHT = 0.35

# Dataset configuration
DATASET_NAME = "mekpro/plantnet300k_observe"
MAX_SAMPLES = 4000  # limit dataset size for quicker runs; None/0 -> use the full split

# Fail fast, before the slow model load, if a full prompt + completion cannot fit
# in the context configured above.
assert MAX_PROMPT_LENGTH + MAX_NEW_TOKENS <= MAX_SEQ_LENGTH, (
    f"MAX_PROMPT_LENGTH ({MAX_PROMPT_LENGTH}) + MAX_NEW_TOKENS ({MAX_NEW_TOKENS}) "
    f"exceeds MAX_SEQ_LENGTH ({MAX_SEQ_LENGTH})"
)

## 3. Load Model and Tokenizer

In [None]:
# Load the starting checkpoint; compute dtype is bf16 when the GPU supports it, else fp16.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=torch.bfloat16 if is_bf16_supported() else torch.float16,
    load_in_4bit=LOAD_IN_4BIT,
)

# Clear the checkpoint's preset cache implementation so generate() uses the default
# cache (presumably to avoid cache issues during GRPO sampling; confirm).
model.generation_config.cache_implementation = None

# Attach trainable LoRA adapters to the attention and MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=TARGET_MODULES,
    bias="none",
    use_rslora=False,
    loftq_config=None,
    task_type="CAUSAL_LM",
)

# Padding needs a pad token; fall back to EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token, tokenizer.pad_token_id = tokenizer.eos_token, tokenizer.eos_token_id

==((====))==  Unsloth 2025.7.11: Fast Gemma3N patching. Transformers: 4.54.1.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.1.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.


Unsloth: Making `model.base_model.model.model.language_model` require gradients


## 4. Load and Prepare Dataset

In [None]:
# Load the training split.
dataset = load_dataset(DATASET_NAME, split="train")

# Species coverage of the full split (taken before shuffling/subsampling).
species_list = dataset["species"]
dataset = dataset.shuffle(seed=42)
# sorted(): iteration order of a set of strings changes between Python processes
# (hash randomisation), so sort to keep unique_species reproducible.
unique_species = sorted(set(species_list))
print(f"Total unique species: {len(unique_species)}")

# Optionally keep only the first MAX_SAMPLES rows of the shuffled split (falsy -> keep all).
if MAX_SAMPLES:
    dataset = dataset.select(range(min(MAX_SAMPLES, len(dataset))))
    print(f"Limited dataset to {len(dataset)} samples")

# System prompt used for every training example (and reused at test time).
INSTRUCTION_PROMPT = '''As a botanist, describe its visual features and species_name in JSON format. {'color':'', 'inflorescencetype':'', 'inflorescence_description':'','flower_arrangement':'', 'flower_density':'', 'species':'', 'family':'', 'genus':''}'''

# Chat-format prompt: system instructions followed by a "Describe <species>" user turn.
# The reward functions read the user turn as prompt[1]["content"], so keep this order.
dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system" , "content" : INSTRUCTION_PROMPT},
        {"role": "user",    "content" : "Describe "+x["species"]}
    ]
})

def get_species_inflorescence_dict(dataset):
    """Map each species to its ``inflorescencetype`` value.

    Args:
        dataset: any iterable of row mappings that have ``species`` and
            ``inflorescencetype`` keys (e.g. a Hugging Face ``Dataset``).

    Returns:
        dict of species -> inflorescencetype. If a species appears in several rows,
        the value from the last such row wins.
    """
    return {row['species']: row['inflorescencetype'] for row in dataset}
# Species coverage of the subsampled dataset that is actually used for training.
species_list = dataset["species"]
# sorted(): deterministic order (set order varies between runs); the test section
# picks its species from the front of this list.
unique_species = sorted(set(species_list))
print("unique after sample: %d" % len(unique_species))

# Ground-truth species -> inflorescence type, used by species_rw.
species_inflorescence_data = get_species_inflorescence_dict(dataset)
# Show the size and a short preview instead of dumping every entry.
print(f"species -> inflorescencetype entries: {len(species_inflorescence_data)}")
print(dict(list(species_inflorescence_data.items())[:5]))


Total unique species: 381
Limited dataset to 4000 samples
unique after sample: 342
{'Hypericum humifusum': 'Cyme', 'Secale cereale': 'Spike', 'Calendula stellata': 'Capitulum (Head)', 'Smilax aspera': 'Panicle of Umbels', 'Phyllanthus amarus': 'Axillary cyme', 'Lupinus albus': 'Terminal raceme', 'Zamioculcas zamiifolia': 'Spadix and Spathe', 'Hippophae rhamnoides': 'Axillary Cluster', 'Tradescantia pallida': 'Terminal cymes', 'Hebe salicifolia': 'Raceme', 'Cucurbita maxima': 'Solitary', 'Casuarina equisetifolia': 'Catkin', 'Moehringia trinervia': 'Cyme', 'Tagetes erecta': 'Capitulum', 'Melilotus indicus': 'Raceme', 'Lupinus angustifolius': 'Raceme', 'Phedimus aizoon': 'Cyme (Corymb-like)', 'Carthamus tinctorius': 'Capitulum', 'Epipactis helleborine': 'Raceme', 'Gynura procumbens': 'Capitulum', 'Hyoscyamus niger': 'Scorpioid cyme', 'Lactuca sativa': 'Capitulum', 'Ophrys lutea': 'Spike', 'Acalypha hispida': 'Catkin', 'Pelargonium peltatum': 'Pseudo-umbel', 'Sedum pachyphyllum': 'Terminal

## 5. Reward Functions

In [None]:
# System prompt for the LLM judge used by botanist_rw. The rubric's positive items add
# up to 100 (20 + 20 + 10 + 10 + 40); its penalties can push the raw score below 0.
# (A stray `"},` left over from a dict literal used to end this prompt and was sent
# to the judge verbatim; it has been removed.)
instruction_grade = '''
as a botanist professor, grade this observation of the species consistent with real data, give score 0-100.
- If can tell correct color of that species: give +20 score, if similar: +10 score, if wrong : -10 score, if color not exists: -40 score
- If can tell inflorescencetype of that species, give +20 score, if wrong inflorescencetype: -10 score , if that inflorescencetype not exist:-20 score
- If can tell valid family of plant from species, + 10 score , if wrong -5 score, if family not valid plant family , -20 score
- If can tell correct flower_density +10 score,
- Other 40 score from visual description consistency, minus 20 score if have wrong grammar.
- If have other alphabet word in non-english appeared, minus 50 score
only reply score number 0-100 dont describe
'''
def botanist_rw(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    LLM-as-judge reward: an external model grades each completion from 0 to 100.

    Each completion's text is sent as the user message, with ``instruction_grade``
    as the system prompt, through the OpenRouter ``client``. The first number in
    the judge's reply (sign and decimals included) is clamped to [0, 100] and
    normalised to [0, 1].

    Args:
        prompts: chat prompts (lists of message dicts). ``prompt[1]["content"]`` is
            the user turn and is used only for the log line.
        completions: one entry per completion. ``completion[0]["content"]`` is the
            generated text.
        **kwargs: extra trainer-provided fields (unused).

    Returns:
        One reward per completion: the normalised judge score times
        BOTANIST_REWARD_WEIGHT, so each value is in [0, BOTANIST_REWARD_WEIGHT].
        API errors and replies without a number score 0.

    NOTE(review): the judge only sees the completion, not the requested species,
    so it cannot penalise a description of the wrong plant (species_rw partly
    covers that).
    """
    rewards = []

    for prompt, response in zip(prompts, completions):
        eval_prompt = str(response[0]["content"])

        try:
            api_response = client.chat.completions.create(
                model="qwen/qwen3-235b-a22b-2507:nitro",
                messages=[
                    {"role": "system", "content": instruction_grade},
                    {"role": "user", "content": eval_prompt},
                ],
                temperature=0.1,
                max_completion_tokens=10,
            )

            # The SDK types message.content as optional; treat a missing reply as no score.
            score_text = (api_response.choices[0].message.content or "").strip()

            # Keep an optional sign and decimals. The rubric has penalties, so a reply
            # such as "-10" must clamp to 0 instead of being read as 10.
            score_match = re.search(r"-?\d+(?:\.\d+)?", score_text)
            if score_match:
                score = min(max(float(score_match.group()), 0.0), 100.0) / 100.0
            else:
                score = 0.0

            output = response[0]["content"].replace("```json", "").replace("\n", "")
            print("%.2f : %s : %s " % (score, str(prompt[1]["content"]), output))

        except Exception as e:
            print(f"Error in botanist_reward: {e}")
            score = 0.0

        rewards.append(score * BOTANIST_REWARD_WEIGHT)  # Apply weight here

    return rewards

def format_rw(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    Format reward: does the completion parse as a JSON object with exactly the expected keys?

    ```json / ``` fences and newlines are stripped before parsing. A parsable object
    starts at 100 points, loses 10 per missing required key and 20 per unexpected
    key, and is floored at 0. Non-object JSON and unparsable text get 0.

    Args:
        prompts: unused (part of the reward-function signature).
        completions: one entry per completion. ``completion[0]["content"]`` is the text.

    Returns:
        One reward per completion: points / 100 * FORMAT_REWARD_WEIGHT, i.e. a value
        in [0, FORMAT_REWARD_WEIGHT].
    """
    expected_keys = {"color", "inflorescencetype", "inflorescence_description", "flower_arrangement",
                     "flower_density", "species", "family", "genus"}

    scores = []
    for completion in completions:
        response = completion[0]["content"]
        response = response.replace("```json", "").replace("\n", "").replace("```", "")

        points = 100.0
        try:
            parsed = json.loads(response)
            if isinstance(parsed, dict):
                keys = set(parsed.keys())
                points -= 10 * len(expected_keys - keys)   # missing required keys
                points -= 20 * len(keys - expected_keys)   # unexpected extra keys
            else:
                points = 0.0
        except json.JSONDecodeError:
            points = 0.0
        except Exception as e:
            print(f"Error in format_reward: {e}")
            print(f"Response type: {type(response)}")
            print(f"Response content: {response}")
            points = 0.0

        scores.append(max(points, 0.0) / 100.0 * FORMAT_REWARD_WEIGHT)

    return scores


def species_rw(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    Ground-truth reward for the ``species`` and ``inflorescencetype`` fields.

    Each completion is parsed as JSON after stripping ```json / ``` fences and
    newlines. If it is an object containing a ``species`` key, it earns:

    * +0.5 when ``species`` equals the species requested in the prompt;
    * +0.5 when ``inflorescencetype`` equals ``species_inflorescence_data`` for the
      requested species.

    Comparisons are case-insensitive and ignore extra whitespace; non-string values
    never match. Unparsable or non-object completions score 0. Each score is
    multiplied by SPECIES_INFLORESCENCETYPE_WEIGHT.

    Why exact matching against the *requested* species: a plain substring test
    (``data["species"] in prompt``) rewards "", single letters, or "Describe".
    Looking up the inflorescence type by the model's own species name rewards
    internal consistency with a wrong plant.

    The requested species comes from a ``species`` list in ``kwargs`` when present.
    TRL's GRPOTrainer forwards extra dataset columns to reward functions; confirm
    this for the installed version. Otherwise it is taken from the user turn
    ``prompt[1]["content"]`` with a leading ``"Describe "`` removed, which matches
    how the prompts are built.
    """

    def _norm(value) -> str:
        # Comparison key: collapse whitespace and case-fold; non-strings map to "".
        return " ".join(value.split()).casefold() if isinstance(value, str) else ""

    def _requested_species(prompt, idx: int):
        # Prefer the dataset column forwarded by the trainer; fall back to the prompt text.
        column = kwargs.get("species")
        if isinstance(column, (list, tuple)) and idx < len(column):
            return column[idx]
        user_turn = str(prompt[1]["content"])
        prefix = "Describe "
        return user_turn[len(prefix):] if user_turn.startswith(prefix) else user_turn

    rewards = []
    for idx, (prompt, response) in enumerate(zip(prompts, completions)):
        score = 0.0
        response_text = response[0]["content"].replace("```json","").replace("\n","").replace("```","")

        try:
            data = json.loads(response_text)
            if isinstance(data, dict) and "species" in data:
                requested = _requested_species(prompt, idx)

                target = _norm(requested)
                if target and _norm(data["species"]) == target:
                    score += 0.5

                expected_type = _norm(species_inflorescence_data.get(requested))
                if expected_type and _norm(data.get("inflorescencetype")) == expected_type:
                    score += 0.5

        except json.JSONDecodeError:
            # Invalid JSON format
            score = 0.0
        except Exception as e:
            print(f"Error in species_inflorescencetype_reward_func: {e}")
            score = 0.0

        rewards.append(score * SPECIES_INFLORESCENCETYPE_WEIGHT)

    return rewards

## 6. GRPO Configuration and Training

In [None]:
# Training forward passes don't need past key/values; disable the model's KV cache flag.
model.config.use_cache = False

training_args = GRPOConfig(
    output_dir=OUTPUT_DIR,
    seed=42,
    # Optimisation
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    max_grad_norm=0.5,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    warmup_steps=WARMUP_STEPS,
    optim="paged_adamw_8bit",
    fp16=not is_bf16_supported(),
    bf16=is_bf16_supported(),
    gradient_checkpointing=False,
    # Logging / checkpointing
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=3,
    # GRPO sampling
    num_generations=NUM_GENERATION_PER_PROMPT,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_completion_length=MAX_NEW_TOKENS,
)

# GRPOTrainer takes the tokenizer/processor as `processing_class`; `tokenizer=` is not
# part of TRL's documented GRPOTrainer signature.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=dataset,
    reward_funcs=[botanist_rw, format_rw, species_rw],
)

print("GRPO Trainer initialized successfully!")

GRPO Trainer initialized successfully!


## 7. Training

In [None]:

# Let torch.compile keep many more specialised graphs per code object before it gives
# up on that frame and runs it eagerly. The stock limit is far lower; the exact default
# depends on the torch version.
torch._dynamo.config.cache_size_limit = 512

print("Starting GRPO training...")
trainer.train()
print("Training completed!")

0.58 : Describe Vaccaria hispanica : {"color": "Vibrant scarlet-red with petals brown from inside in.** "inflorescencetype": "Solitary", "inflorescence_description": "Single flower on an long pedicel, arising from leaf axil.", "flower_arrangement": "Solitary terminal flowers, sometimes axillary.", "flower_density": "Typically scattered, not dense clumps of flowers.", "species": "Punica granatum", "family": "Lythraceae", "genus": "Punica"} 
0.58 : Describe Vaccaria hispanica : {"color": "Vibrant magenta-pink", "inflorescencetype": "Dense terminal cyme", "inflorescence_description": "2-4 flowers in a compact, rounded cluster topping the fuzzy brown leaves.", "flower_arrangement": "A dense mat of bright flowers tightly packed on a short, slightly elongated stem.", "flower_density": "Dense", "species": "Fedia cornucopiae", "family": "Caprifoliaceae", "genus": "Fedia"} 
0.45 : Describe Vaccaria hispanica : {"color": "Vibrant scarlet-red with petals visible on upper petals.", "inflorescencet

Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,kl,rewards / botanist_rw / mean,rewards / botanist_rw / std,rewards / format_rw / mean,rewards / format_rw / std,rewards / species_rw / mean,rewards / species_rw / std
1,-0.0,0.290312,0.162853,107.21875,82.0,139.0,0.0,107.21875,82.0,139.0,0.0,0.137187,0.118709,0.13125,0.050402,0.021875,0.058802
2,-0.0,0.3625,0.15024,108.6875,82.0,149.0,0.0,108.6875,82.0,149.0,0.0,0.189063,0.136183,0.140625,0.03689,0.032812,0.069398
3,0.0,0.256094,0.144731,112.75,86.0,298.0,0.0,112.75,86.0,298.0,0.002495,0.119375,0.166563,0.13125,0.050402,0.005469,0.030936
4,0.0,0.192188,0.11064,130.90625,88.0,350.0,0.09375,108.241379,88.0,127.0,0.002193,0.059375,0.079755,0.121875,0.059484,0.010937,0.043039
5,0.0,0.450156,0.249967,106.3125,87.0,134.0,0.0,106.3125,87.0,134.0,0.001985,0.200156,0.16842,0.140625,0.03689,0.109375,0.145739
6,0.0,0.255469,0.18333,115.625,93.0,250.0,0.0,115.625,93.0,250.0,0.002733,0.121875,0.120775,0.117188,0.063002,0.016406,0.068276
7,0.0,0.232031,0.10049,111.875,90.0,350.0,0.03125,104.193542,90.0,118.0,0.002213,0.091875,0.11398,0.140156,0.036863,0.0,0.0
8,0.0,0.234062,0.118744,114.9375,74.0,237.0,0.0,114.9375,74.0,237.0,0.002361,0.10125,0.086098,0.121875,0.059484,0.010937,0.043039
9,0.0,0.345,0.165225,119.375,94.0,350.0,0.03125,111.935478,94.0,143.0,0.002152,0.160625,0.12585,0.140625,0.03689,0.04375,0.07699
10,0.0,0.281406,0.175315,112.28125,88.0,137.0,0.0,112.28125,88.0,137.0,0.002361,0.13375,0.125016,0.13125,0.050402,0.016406,0.068276


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
0.25 : Describe Melilotus indicus : {"color": "Vibrant violet-blue", "inflorescencetype": "Raceme", "inflorescence_description": "Compact racemes with distinct, crescere along branching stems.", "flower_arrangement": "Flowers clustered along a dense, ascending stalk structure.", "flower_density": "Very dense", "species": "Crotalaria polysperma", "family": "Fabaceae", "genus": "Crotalaria"} 
0.95 : Describe Melilotus indicus : {"color": "Vibrant violet-blue", "inflorescencetype": "Raceme", "inflorescence_description": "Compact racemes with distinct, nodding, individual flowers.", "flower_arrangement": "Flowers closely packed on a slender stem.", "flower_density": "Tightly on raceme", "species": "Lupinus polyphyllus", "family": "Fabaceae", "genus": "Lupinus"} 
0.95 : Describe Melilotus indicus : {"color": "Vibrant violet-blue", "inflorescencetype": "Raceme", "inflorescence_description": "Compact, robust-racemes with abundan

KeyboardInterrupt: 

## 8. Save Model

In [None]:
# Merge the LoRA adapters into the base weights and upload a 16-bit checkpoint to the Hub.
# Alternatives: adapter-only saves (model.save_pretrained / model.push_to_hub together
# with tokenizer.*), a local merged save (model.save_pretrained_merged), or a GGUF
# export (model.push_to_hub_gguf).
model_save_path = "gemma-3n-botanist-grpo6p"

model.push_to_hub_merged(
    repo_id=f"mekpro/{model_save_path}-merged",
    tokenizer=tokenizer,
    save_method="merged_16bit",
    token=HF_TOKEN,
)


No files have been modified since last commit. Skipping to prevent empty commit.
No files have been modified since last commit. Skipping to prevent empty commit.


Found HuggingFace hub cache directory: /root/.cache/huggingface/hub
Checking cache directory for required files...
Successfully copied all 3 files from cache to mekpro/gemma-3n-botanist-grpo6p-merged.
Downloading safetensors index for mekpro/gemma-3n-botanist-observe6-merged...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

No files have been modified since last commit. Skipping to prevent empty commit.
Unsloth: Merging weights into 16bit:   0%|          | 0/3 [00:00<?, ?it/s]No files have been modified since last commit. Skipping to prevent empty commit.
Unsloth: Merging weights into 16bit:  33%|███▎      | 1/3 [00:21<00:42, 21.21s/it]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit:  67%|██████▋   | 2/3 [02:57<01:40, 100.87s/it]

model-00003-of-00003.safetensors:   0%|          | 0.00/2.82G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit: 100%|██████████| 3/3 [04:27<00:00, 89.07s/it]


## 9. Test the Fine-tuned Model

In [None]:
def test_model(test_species: str):
    """Generate one description for ``test_species`` and score it with the training rewards.

    The prompt is built exactly like the training prompts (system = INSTRUCTION_PROMPT,
    user = "Describe <species>") and rendered with the tokenizer's chat template, so the
    model sees the same format it was optimised on.

    Prints the generated text and the three weighted rewards plus their sum. Note that
    ``botanist_rw`` calls the external judge API and prints its own log line.
    """
    messages = [
        {"role": "system", "content": INSTRUCTION_PROMPT},
        {"role": "user", "content": "Describe " + test_species},
    ]

    # Render the chat template as text, then tokenize it without adding special tokens
    # again (the template already includes them).
    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer(text=[prompt_text], return_tensors="pt", add_special_tokens=False)
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    prompt_len = input_ids.shape[1]

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Cutting the decoded string by the prompt's
    # character length breaks once special tokens are skipped during decoding.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    print(f"Test Species: {test_species}")
    print(f"Model Response:\n{response}")

    # Score with the same reward functions (and input shapes) used during training.
    prompts = [messages]
    completions = [[{"role": "assistant", "content": response}]]
    bot_reward = botanist_rw(prompts, completions)[0]
    fmt_reward = format_rw(prompts, completions)[0]
    sp_reward = species_rw(prompts, completions, species=[test_species])[0]
    max_total = BOTANIST_REWARD_WEIGHT + FORMAT_REWARD_WEIGHT + SPECIES_INFLORESCENCETYPE_WEIGHT

    print(f"\nBotanist Reward: {bot_reward:.3f} (max {BOTANIST_REWARD_WEIGHT})")
    print(f"Format Reward: {fmt_reward:.3f} (max {FORMAT_REWARD_WEIGHT})")
    print(f"Species/Inflorescence Reward: {sp_reward:.3f} (max {SPECIES_INFLORESCENCETYPE_WEIGHT})")
    print(f"Combined Score: {bot_reward + fmt_reward + sp_reward:.3f} (max {max_total:.2f})")

# The trainer may leave the model in train mode, which keeps LoRA dropout active;
# switch to eval mode before sampling test outputs.
model.eval()

# Test with a few species. sorted() makes the pick reproducible (set order varies
# between runs); slicing already handles lists shorter than 3.
test_species_list = sorted(unique_species)[:3]
for species in test_species_list:
    test_model(species)
    print("\n" + "=" * 80 + "\n")