In [None]:
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
!pip install --no-deps unsloth

In [None]:
!pip install --no-deps --upgrade timm

In [None]:
!pip install -qU evaluate

In [None]:
from datasets import load_dataset, DatasetDict
import random
from unsloth import FastVisionModel
import torch

# Prepare Data

In [None]:
# Load the bug-bite image dataset, requesting its "train" and "validation"
# splits as one combined split.
bugs = load_dataset("eceunal/bug-bite-images-aug_v3", split="train+validation")

# Peek at the first sample: keep its image and the string name of its label.
first_sample = bugs[0]
label_feature = bugs.features["label"]
img = first_sample["image"]
label = label_feature.int2str(first_sample["label"])
print(label)

In [None]:
# Question paired with every training image.
instruction = "Which insect bite is this if there is a bite?"

# Converter from integer class ids to their string label names.
int2str = bugs.features["label"].int2str


def convert_to_conversation(sample):
    """Wrap one dataset sample as a two-turn chat example.

    Args:
        sample: Mapping with an "image" entry and an integer "label" entry.

    Returns:
        A dict ``{"messages": [...]}`` holding a user turn (the instruction
        text followed by the image) and an assistant turn whose text is the
        label name for ``sample["label"]``.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": instruction},
            {"type": "image", "image": sample["image"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": int2str(sample["label"])}],
    }
    return {"messages": [user_turn, assistant_turn]}


# Convert every sample up front; the result is a plain Python list.
converted_dataset = list(map(convert_to_conversation, bugs))

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the converted samples for testing; the fixed seed keeps the
# shuffled split reproducible between runs.
# NOTE(review): the dataset id contains "aug" - if it holds augmented variants
# of the same photo, a purely random split could place near-duplicates in both
# train and test. Confirm against the dataset card.
split_options = dict(test_size=0.2, random_state=42, shuffle=True)
train, test = train_test_split(converted_dataset, **split_options)

print(*map(len, (train, test)))

# Load Base Model

In [None]:
# Load the "unsloth/gemma-3n-E2B-it" checkpoint through Unsloth; the call
# returns the model together with its processor.
model, processor = FastVisionModel.from_pretrained(
    "unsloth/gemma-3n-E2B-it",
    load_in_4bit = True,  # request 4-bit loading
    use_gradient_checkpointing = "unsloth",  # select Unsloth's gradient-checkpointing mode
)

# Fine Tune

In [None]:
# Wrap the base model with PEFT/LoRA adapters. The four finetune_* switches
# enable adaptation of the vision layers, the language layers, the attention
# modules and the MLP modules.
model = FastVisionModel.get_peft_model(
     model,
     finetune_vision_layers     = True,
     finetune_language_layers   = True,
     finetune_attention_modules = True,
     finetune_mlp_modules       = True,

     r = 16,  # LoRA rank
     lora_alpha = 16,  # LoRA scaling numerator (equal to r here)
     lora_dropout = 0.05,
     bias = "none",
     random_state = 3407,  # seed passed to the adapter setup
     target_modules = "all-linear",
     # Modules listed here are passed as modules_to_save in addition to the adapters.
     modules_to_save = [
         "lm_head",
         "embed_tokens",
     ],
 )

In [None]:
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
# Switch the model to training mode before building the trainer.
FastVisionModel.for_training(model)

# Collator that batches the chat-formatted samples; images are resized via resize=512.
vision_collator = UnslothVisionDataCollator(model, processor, resize=512)

# Optimisation, logging and checkpointing settings for the run.
sft_args = SFTConfig(
    # schedule
    num_train_epochs=2,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # 1 x 4 samples per optimiser step
    gradient_checkpointing=False,
    # optimiser
    max_grad_norm=0.3,
    warmup_steps=5,
    learning_rate=2e-4,
    optim="adamw_torch_fused",
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    seed=3407,
    # logging / checkpoints
    logging_steps=1,
    save_strategy="steps",
    save_steps=200,
    output_dir="outputs",
    report_to="none",
    # MUST for vision finetuning:
    remove_unused_columns=False,
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train,
    processing_class=processor.tokenizer,
    data_collator=vision_collator,
    max_seq_length=2048,
    args=sft_args,
)

# Run the fine-tuning and keep the returned training summary.
trainer_stats = trainer.train()

In [None]:
# Display the value returned by trainer.train().
trainer_stats

In [None]:
import matplotlib.pyplot as plt

# Plot the training loss over steps. Only log records that carry a "loss" key
# are used; other records in the history are skipped.
logs = trainer.state.log_history
loss_entries = [record for record in logs if "loss" in record]

steps = [record["step"] for record in loss_entries]
loss_values = [record["loss"] for record in loss_entries]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(steps, loss_values, marker='o')
ax.set_title("Training Loss Curve")
ax.set_xlabel("Training Step")
ax.set_ylabel("Loss")
ax.grid(True)
plt.show()

# Inference

In [None]:
from datasets import load_dataset, DatasetDict

In [None]:
# Reload the same dataset and first sample as in the "Prepare Data" section;
# `img` and `label` are rebound to the first sample's image and label name.
bugs = load_dataset("eceunal/bug-bite-images-aug_v3", split="train+validation")
img, label = bugs[0]["image"], bugs.features["label"].int2str(bugs[0]["label"])
print(label)

In [None]:
# Display the first sample's image.
img

In [None]:
# The model was put into training mode for fine-tuning (LoRA dropout is
# active there); switch it to inference mode before generating so the output
# is not perturbed by training-only behaviour.
FastVisionModel.for_inference(model)

# One user turn: the sample image followed by the question.
messages = [{
        "role": "user",
        "content": [
            { "type": "image", "image": img},
            { "type": "text", "text": instruction }
        ]
    }]

# Render the chat template to tensors and move them to the model's device.
inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=250)

# generate() returns prompt + continuation; slice off the prompt tokens.
prompt_len = inputs["input_ids"].shape[-1]
answer_ids = outputs[0, prompt_len:]

# Raw decode (special tokens kept), useful for inspecting end-of-turn markers.
print(processor.decode(answer_ids))

# Clean decode without special tokens.
answer_text = processor.tokenizer.decode(
    answer_ids, skip_special_tokens=True
).strip()

print(answer_text)

In [None]:
# Display the processor object returned by FastVisionModel.from_pretrained.
processor

In [None]:
from google.colab import userdata
# Read the Hugging Face token from the Colab secret named 'HF_TOKEN'
# instead of hard-coding it in the notebook.
hf_token = userdata.get('HF_TOKEN')

# Push Merged Model to HF

In [None]:
# Upload the model with push_to_hub_merged to "eceunal/insectra-fine-tuned",
# passing the processor along and authenticating with hf_token.
model.push_to_hub_merged("eceunal/insectra-fine-tuned", processor, token = hf_token)

# Push Trainer to HF

In [None]:
# Save the trainer's model; 'gemma-fine-tuned' is passed as the save location.
trainer.save_model('gemma-fine-tuned')

In [None]:
# Trainer.push_to_hub takes a commit message as its first positional argument,
# not a repository name, so passing 'gemma-fine-tuned' positionally only set
# the commit message and the repo name fell back to the output directory.
# Select the target repo through hub_model_id instead.
# NOTE(review): confirm 'gemma-fine-tuned' is the intended repo name (it will
# be created under the authenticated user's namespace).
trainer.args.hub_model_id = 'gemma-fine-tuned'
trainer.push_to_hub(commit_message="End of training")

In [None]:
# Sanity check that the fine-tuned model still handles a plain text request.
# A dedicated name is used so the global `instruction` (the bite-classification
# question used by the image inference cell) is not silently overwritten.
story_instruction = "Write me a story"

# Make sure generation runs in inference mode rather than the training mode
# left over from fine-tuning.
FastVisionModel.for_inference(model)

messages = [{
        "role": "user",
        "content": [
            { "type": "text", "text": story_instruction }
        ]
    }]

# Render the chat template to tensors and move them to the model's device.
inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=250)

# generate() returns prompt + continuation; slice off the prompt tokens.
prompt_len = inputs["input_ids"].shape[-1]
answer_ids = outputs[0, prompt_len:]

# Raw decode (special tokens kept).
print(processor.decode(answer_ids))

# Clean decode without special tokens.
answer_text = processor.tokenizer.decode(
    answer_ids, skip_special_tokens=True
).strip()

print(answer_text)

# Testing

In [None]:
import torch, pandas as pd
from tqdm.auto import tqdm
from transformers import TextStreamer
from PIL import Image
from IPython.display import display



# Generate a prediction for every held-out sample with the fine-tuned model
# and store (row index, gold label, prediction) rows in predictions.csv.

# Generation must run in inference mode; the model is otherwise still in the
# training mode used for fine-tuning.
FastVisionModel.for_inference(model)

# Evaluation prompt; it is the same for every sample, so build it once.
instruction = "Which insect bite is this if there is a bite? Return only bite name or no bite, no additional comment needed"

pred_rows = []

for idx, sample in tqdm(enumerate(test), total=len(test)):
    # The image lives in the user turn's content list.
    pil_img = next(c for c in sample["messages"][0]["content"]
                   if c["type"]=="image")["image"]
    # Accept a {"path": ...} dict or a plain path as well as a PIL image.
    if isinstance(pil_img, dict):
        pil_img = Image.open(pil_img["path"])
    elif isinstance(pil_img, str):
        pil_img = Image.open(pil_img)
    # Greyscale images are expanded to three channels.
    if pil_img.mode == "L":
        pil_img = pil_img.convert("RGB")

    # The assistant turn carries the ground-truth label name.
    gold_label_text = sample["messages"][1]["content"][0]["text"]

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": pil_img},
                {"type": "text", "text": instruction}
            ],
        }
    ]

    inputs = processor.apply_chat_template(
      messages,
      add_generation_prompt=True,
      tokenize=True,
      return_dict=True,
      return_tensors="pt"
    ).to(model.device)

    outputs = model.generate(**inputs, max_new_tokens=512)

    prompt_len = inputs["input_ids"].shape[-1]
    answer_ids = outputs[0, prompt_len:]  # remove prompt tokens
    print(processor.decode(answer_ids))
    answer_text = processor.tokenizer.decode(
              answer_ids, skip_special_tokens=True
    ).strip()

    print("Prediction: " + answer_text + ".   Real Answer: " + gold_label_text)
    pred_rows.append({
        "row_idx":        idx,
        "gold_label":     gold_label_text,
        "raw_prediction": answer_text,
    })

df = pd.DataFrame(pred_rows)
df.to_csv("predictions.csv", index=False)
print(f"✅  Saved {len(df)} rows to predictions.csv")

# Evaluation for Fine-Tuned Gemma 3n

In [None]:
import pandas as pd
import re
from collections import Counter

df = pd.read_csv("predictions.csv")    # columns: row_idx, gold_label, raw_prediction


# ----- 2  normalise → list-of-tokens ---------------------------------------
def normalise_tokens(txt: str) -> list[str]:
    """Normalise a text cell into a list of lower-case alphanumeric tokens.

    Every character other than an ASCII letter, digit or space is replaced by
    a space, the text is lower-cased and then split on whitespace.

    Args:
        txt: The cell value. Missing values (``None`` / NaN) are accepted, and
            non-string values (e.g. a number parsed from the CSV) are
            converted with ``str`` instead of raising ``TypeError``.

    Returns:
        The list of clean tokens; empty for missing or blank input.
    """
    if not isinstance(txt, str):
        if txt is None or pd.isna(txt):
            return []
        # A CSV column may be parsed as numeric; grade its text form.
        txt = str(txt)
    cleaned = re.sub(r"[^0-9a-zA-Z ]", " ", txt).lower()
    return cleaned.split()

# Tokenise the model output and the gold label with the same normaliser so
# the two token lists are directly comparable.
df["tok_pred"] = df["raw_prediction"].apply(normalise_tokens)
df["tok_real"] = df["gold_label"].apply(normalise_tokens)


# ----- 3  rule: real-tokens ⊆ pred-tokens  OR  pred-tokens ⊆ real-tokens ----
def subset_or_superset(row) -> bool:
    """Decide whether a prediction matches its gold label on token level.

    The row passes when one token set contains the other, i.e. extra words on
    either side are ignored.

    Args:
        row: Object exposing ``tok_pred`` and ``tok_real`` token iterables
            (a DataFrame row when used with ``df.apply(..., axis=1)``).

    Returns:
        ``True`` when both token sets are non-empty and one is a subset of the
        other, otherwise ``False``.
    """
    pred_tokens, real_tokens = set(row.tok_pred), set(row.tok_real)
    # The empty set is a subset of every set, so without this guard a blank
    # prediction would be graded as a match for any label.
    if not pred_tokens or not real_tokens:
        return False
    return pred_tokens >= real_tokens or real_tokens >= pred_tokens

# Grade every row, then store the verdict as the strings "pass" / "fail".
verdict_names = {True: "pass", False: "fail"}
is_match = df.apply(subset_or_superset, axis=1)
df["result"] = is_match.map(verdict_names)


# ----- 4  quick report ------------------------------------------------------
# Tally of verdicts, e.g. Counter({'pass': 87, 'fail': 13})
print(Counter(df["result"]))
# Persist the graded rows so later cells can reload them.
df.to_csv("pred_vs_real_with_result.csv", index=False)

display(df.head())

In [None]:
# Reload the graded fine-tuned results from disk; the token-list columns come
# back as text after the CSV round trip.
df = pd.read_csv("pred_vs_real_with_result.csv")

In [None]:
# Overall accuracy: share of rows graded "pass", as a percentage.
accuracy = (df["result"] == "pass").mean() * 100

result_counts = df["result"].value_counts()

plt.figure(figsize=(8, 4))
plt.bar(["Accuracy"], [accuracy], color="skyblue")
plt.ylim(0, 100)
plt.ylabel("Accuracy (%)")
plt.title("Model Accuracy")
plt.show()

# value_counts() orders its rows by frequency, not by label, so colours must
# be looked up per label; a positional list would paint "fail" green whenever
# failures are the majority.
result_colors = {"pass": "green", "fail": "red"}
bar_colors = [result_colors.get(name, "gray") for name in result_counts.index]

plt.figure(figsize=(8, 4))
plt.bar(result_counts.index, result_counts.values, color=bar_colors)
plt.ylabel("Count")
plt.title("Pass vs Fail")
plt.show()

In [None]:
# Report the fine-tuned model's accuracy with two decimals.
print(f"Accuracy: {accuracy:.2f}%")

In [None]:
# Count verdicts per gold label: one row per class, one column per verdict,
# with 0 where a class/verdict combination never occurs.
class_results = (
    df.groupby(['gold_label', 'result'])
    .size()
    .unstack(fill_value=0)
)

# Grouped bars: one bar per verdict for each class.
ax = class_results.plot(kind='bar', figsize=(10,5))
ax.set_xlabel("Gold Label")
ax.set_ylabel("Count")
ax.set_title("True vs False Predictions per Class")
plt.xticks(rotation=45, ha='right')
ax.legend(title="Result")
plt.tight_layout()
plt.show()

# Evaluation by Gemini for raw Gemma 3n Answers

In [None]:
from unsloth import FastVisionModel
from peft import PeftModel
import torch

# Checkpoint id of the base model (the same id that was loaded for fine-tuning).
BASE_ID   = "unsloth/gemma-3n-E2B-it"

# Load a separate copy of the base checkpoint under the names raw_model /
# raw_processor so it can be compared against the fine-tuned `model`.
raw_model, raw_processor = FastVisionModel.from_pretrained(
    BASE_ID,
    load_in_4bit = True,
    use_gradient_checkpointing = "unsloth",
)

In [None]:
# Generate a prediction for every held-out sample with the *raw* base model
# and store (row index, gold label, prediction) rows in
# predictions_raw_model.csv.

# Imported here so `tqdm` is the callable progress bar no matter what the
# name was bound to earlier (`tqdm.tqdm(...)` fails when `tqdm` is the class).
from tqdm.auto import tqdm

# Generate in inference mode.
FastVisionModel.for_inference(raw_model)

# Evaluation prompt; it is the same for every sample, so build it once.
instruction = "Which insect bite is this if there is a bite? Return only bite name or no bite, no additional comment needed"

raw_pred_rows = []

for idx, sample in tqdm(enumerate(test), total=len(test)):
    # The image lives in the user turn's content list.
    pil_img = next(c for c in sample["messages"][0]["content"]
                   if c["type"]=="image")["image"]
    # Accept a {"path": ...} dict or a plain path as well as a PIL image.
    if isinstance(pil_img, dict):
        pil_img = Image.open(pil_img["path"])
    elif isinstance(pil_img, str):
        pil_img = Image.open(pil_img)
    # Greyscale images are expanded to three channels.
    if pil_img.mode == "L":
        pil_img = pil_img.convert("RGB")

    # The assistant turn carries the ground-truth label name.
    gold_label_text = sample["messages"][1]["content"][0]["text"]

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": pil_img},
                {"type": "text", "text": instruction}
            ],
        }
    ]

    # Use the raw model's own processor and device; the fine-tuned
    # `processor` / `model` must not leak into the baseline run.
    inputs = raw_processor.apply_chat_template(
      messages,
      add_generation_prompt=True,
      tokenize=True,
      return_dict=True,
      return_tensors="pt"
    ).to(raw_model.device)

    outputs = raw_model.generate(**inputs, max_new_tokens=512)

    prompt_len = inputs["input_ids"].shape[-1]
    answer_ids = outputs[0, prompt_len:]  # remove prompt tokens
    print(raw_processor.decode(answer_ids))
    answer_text = raw_processor.tokenizer.decode(
              answer_ids, skip_special_tokens=True
    ).strip()

    print("Prediction: " + answer_text + ".   Real Answer: " + gold_label_text)
    raw_pred_rows.append({
        "row_idx":        idx,
        "gold_label":     gold_label_text,
        "raw_prediction": answer_text,
    })

df = pd.DataFrame(raw_pred_rows)
df.to_csv("predictions_raw_model.csv", index=False)
print(f"✅  Saved {len(df)} rows to predictions_raw_model.csv")

In [None]:
!pip install -q -U google-generativeai
!pip install -q -U "google-genai>=0.6.0" tqdm

In [None]:
from google.colab import userdata
import google.generativeai as genai

In [None]:
# Read the Gemini API key from the Colab secret named 'GEMINI_KEY'.
gemini_key = userdata.get('GEMINI_KEY')
# Name of the Gemini model used for grading.
MODEL_NAME = "gemini-2.5-flash"

In [None]:
# Configure the `google.generativeai` SDK and build a model handle.
# NOTE(review): `gemini_model` is not referenced again in the visible notebook,
# and a later cell rebinds `genai` with `from google import genai`; re-running
# this cell after that one would call `configure` on a different module.
# Confirm whether this cell is still needed.
genai.configure(api_key=gemini_key)
gemini_model = genai.GenerativeModel(MODEL_NAME)

In [None]:
import json, time, pandas as pd
from google import genai
from google.genai import types

# Read the raw-model predictions from the same relative path they were
# written to, rather than an absolute Colab path that only resolves when the
# working directory happens to be /content.
CSV_PATH = "predictions_raw_model.csv"
df = pd.read_csv(CSV_PATH)

# Fail early if the CSV lacks the columns the grader reads.
required_cols = {"gold_label", "raw_prediction"}
missing = required_cols - set(df.columns)
assert not missing, f"❌ Missing columns in CSV: {missing}"

# Client of the `google.genai` SDK used by the grading function.
client = genai.Client(api_key=gemini_key)


def generate_promt(gt, pred):
    """Build the grading prompt for one (ground truth, prediction) pair.

    The prompt text describes a normalise / tokenise / singularise procedure,
    asks for a JSON object with a single boolean "pass" field, and ends with
    the two values to compare.

    Args:
        gt: Ground-truth label, inserted after the "**Answer:**" marker.
        pred: Model prediction, inserted after the "**Prediction:**" marker.

    Returns:
        The complete prompt string.
    """
    # Doubled braces are literal braces in the f-string; only {gt} and {pred}
    # are substituted.
    return f"""
      You are a grading assistant.

      INPUT (given at runtime)
      {{
        "prediction": "<model output>",
        "answer":     "<ground-truth label>"
      }}

      TASK
      1. **Normalise** both strings:
        • lower-case
        • trim leading/trailing spaces
        • replace every non-alphabetic character (underscores, punctuation, etc.) with a single space
        • collapse multiple spaces into one

      2. **Tokenise** (split on spaces).

      3. **Singularise** each token with this simple rule:
        – if the token ends with “s” **and** its length > 3, drop the trailing “s”.
        (thus “bites” → “bite”, “fleas” → “flea”, but “wasps” → “wasp”).

      4. Let **P** be the set of tokens from *prediction*, **A** the set from *answer*.
        If **A ⊆ P** **or** **P ⊆ A**, the result is **pass**; otherwise **fail**.

      OUTPUT – exactly and only this JSON schema:

      {{
        "pass": true
      }}

     **Answer:**\n{gt}\n\n**Prediction:**\n{pred}\n
     """

def grade_row(gt: str, pred: str, tries: int = 3, model: str = "gemini-2.5-flash"):
    """Ask Gemini whether ``pred`` matches ``gt`` and return the verdict.

    Args:
        gt: Ground-truth label.
        pred: Model prediction to grade.
        tries: Maximum number of attempts; failed attempts are retried after a
            growing pause (2 s, 4 s, ...).
        model: Name of the Gemini model to call.

    Returns:
        The value of the "pass" field of the JSON reply, or ``None`` when no
        attempt succeeded. Failures always yield plain ``None`` (never a
        tuple), so the results column holds one kind of value.
    """
    prompt = generate_promt(gt, pred)

    for attempt in range(1, tries + 1):
        try:
            response = client.models.generate_content(
                model=model,
                contents=prompt,
                # Ask for a JSON body so the reply can be parsed directly.
                config=types.GenerateContentConfig(
                    response_mime_type="application/json",
                ),
            )
            data = json.loads(response.text)
            print(f"\nGold: {gt}\nPred: {pred}\n→ {data}")
            return data["pass"]
        except Exception as e:
            # Best effort: API errors, malformed JSON and a missing "pass"
            # key are all retried.
            print(f"Attempt {attempt}/{tries} failed: {e}")
            if attempt < tries:
                time.sleep(2 * attempt)

    # Every attempt failed (or tries <= 0).
    return None

In [None]:
# Grade every raw-model prediction with Gemini and attach the verdicts.

# Imported here so `tqdm` is the callable progress bar no matter what the
# name was bound to earlier (`tqdm.tqdm(...)` fails when `tqdm` is the class).
from tqdm.auto import tqdm

scores = []
for _, row in tqdm(df.iterrows(), total=len(df)):
    verdict = grade_row(gt=row["gold_label"], pred=row["raw_prediction"])
    # Keep only genuine booleans; anything else (e.g. a failure marker) is
    # stored as a missing verdict.
    scores.append(verdict if isinstance(verdict, bool) else None)

# Same lists the original cell exposed, taken straight from the frame.
gold_label = df["gold_label"].tolist()
prediction = df["raw_prediction"].tolist()

df["result"] = scores
# Copy of raw_prediction, kept so the saved CSV keeps its "prediction" column.
df["prediction"] = prediction

In [None]:
# Persist the Gemini-graded raw-model results for the comparison cells below.
df.to_csv("pred_vs_real_with_result_raw.csv", index=False)

In [None]:
# Inspect the verdict column.
df["result"]

In [None]:
# Overall accuracy of the raw model: share of rows whose verdict is True.
# Rows without a verdict are not equal to True and therefore count as wrong.
accuracy = (df["result"] == True).mean() * 100

result_counts = df["result"].value_counts()

plt.figure(figsize=(8, 4))
plt.bar(["Accuracy"], [accuracy], color="skyblue")
plt.ylim(0, 100)
plt.ylabel("Accuracy (%)")
plt.title("Model Accuracy")
plt.show()

# value_counts() orders its rows by frequency, not by label, so colours must
# be looked up per label; a positional list would swap red and green whenever
# True is the majority.
result_labels = result_counts.index.astype(str)
result_colors = {"True": "green", "False": "red"}
bar_colors = [result_colors.get(name, "gray") for name in result_labels]

plt.figure(figsize=(6,4))
plt.bar(result_labels, result_counts.values, color=bar_colors)
plt.xlabel("Result")
plt.ylabel("Count")
plt.title("True vs False Predictions")
plt.show()

In [None]:
# Report the raw model's accuracy (unrounded).
print(f"Raw model accuracy {accuracy!s}%")

In [None]:
# Count verdicts per gold label: one row per class, one column per verdict,
# with 0 where a class/verdict combination never occurs.
class_results = (
    df.groupby(['gold_label', 'result'])
    .size()
    .unstack(fill_value=0)
)

# Grouped bars: one bar per verdict for each class.
ax = class_results.plot(kind='bar', figsize=(10,5))
ax.set_xlabel("Gold Label")
ax.set_ylabel("Count")
ax.set_title("True vs False Predictions per Class")
plt.xticks(rotation=45, ha='right')
ax.legend(title="Result")
plt.tight_layout()
plt.show()

In [None]:
# Reload both graded result files: raw base model and fine-tuned model.
df_raw = pd.read_csv("pred_vs_real_with_result_raw.csv")
df_ft = pd.read_csv("pred_vs_real_with_result.csv")

In [None]:
# Accuracy restricted to samples whose gold label is 'ticks'.
# The raw-model verdicts are compared against True, the fine-tuned verdicts
# against the string "pass".
is_ticks_raw = df_raw['gold_label'] == 'ticks'
ticks_df = df_raw[is_ticks_raw]
ticks_accuracy = ticks_df['result'].eq(True).mean() * 100

is_ticks_ft = df_ft['gold_label'] == 'ticks'
ticks_df_ft = df_ft[is_ticks_ft]
ticks_accuracy_ft = ticks_df_ft['result'].eq("pass").mean() * 100

In [None]:
# Two-row table: 'ticks' accuracy of the raw and the fine-tuned model.
accuracy_data = pd.DataFrame({
    "Model": ["Raw", "FT"],
    "Accuracy (%)": [ticks_accuracy, ticks_accuracy_ft]
})

# Bar chart with the percentage written just above each bar.
fig, ax = plt.subplots(figsize=(5,4))
ax.bar(accuracy_data["Model"], accuracy_data["Accuracy (%)"])
ax.set_ylabel("Accuracy (%)")
ax.set_title("Ticks Accuracy Comparison")
for position, pct in enumerate(accuracy_data["Accuracy (%)"]):
    ax.text(position, pct + 1, f"{pct:.2f}%", ha='center')
ax.set_ylim(0, 100)
plt.show()