### Unzip Dataset

In [9]:
!unzip archive.zip

Archive:  archive.zip
  inflating: Garbage_classification_benchmark/final_comprehensive_ground_truth_dataset.csv  
  inflating: Garbage_classification_benchmark/multiple_material_image/cardboard_paper/cardboard_044.jpg  
  inflating: Garbage_classification_benchmark/multiple_material_image/cardboard_paper/cardboard_052.jpg  
  inflating: Garbage_classification_benchmark/multiple_material_image/cardboard_paper/cardboard_110.jpg  
  inflating: Garbage_classification_benchmark/multiple_material_image/cardboard_paper/cardboard_129.jpg  
  inflating: Garbage_classification_benchmark/multiple_material_image/cardboard_paper/paper_140.jpg  
  inflating: Garbage_classification_benchmark/multiple_material_image/cardboard_paper/paper_156.jpg  
  inflating: Garbage_classification_benchmark/multiple_material_image/cardboard_paper/paper_162.jpg  
  inflating: Garbage_classification_benchmark/multiple_material_image/cardboard_paper/paper_181.jpg  
  inflating: Garbage_classification_benchmark/multipl

### Installation

In [6]:
%pip install unsloth

Collecting unsloth
  Downloading unsloth-2025.7.11-py3-none-any.whl.metadata (47 kB)
Collecting unsloth_zoo>=2025.7.11 (from unsloth)
  Downloading unsloth_zoo-2025.7.11-py3-none-any.whl.metadata (8.1 kB)
Collecting torch>=2.4.0 (from unsloth)
  Downloading torch-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting xformers>=0.0.27.post2 (from unsloth)
  Downloading xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (1.1 kB)
Collecting bitsandbytes (from unsloth)
  Downloading bitsandbytes-0.46.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting triton>=3.0.0 (from unsloth)
  Downloading triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.7 kB)
Collecting tyro (from unsloth)
  Downloading tyro-0.9.27-py3-none-any.whl.metadata (11 kB)
Collecting transformers!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,>=4.51.3 (from unsloth)
  Downloading transformers-4.54.1-py3-none-any.whl.metadata (41 kB)
Collec

In [7]:
%%capture
# Upgrade timm, which Gemma 3N requires
%pip install --no-deps --upgrade timm # Only for Gemma 3N

### Unsloth

In [None]:
from unsloth import FastVisionModel # FastLanguageModel for LLMs
import torch

model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/gemma-3n-E4B-it", # or "unsloth/gemma-3n-E2B-it"
    load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
    # token = ""
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.7.11: Fast Gemma3N patching. Transformers: 4.54.1.
   \\   /|    NVIDIA A100 80GB PCIe. Num GPUs = 1. Max memory: 79.151 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

We now add LoRA adapters for parameter-efficient finetuning, so that only a small fraction of the model's parameters is trained (about 0.49% in this run, according to the training log).

**[NEW]** We also support finetuning ONLY the vision part of the model, or ONLY the language part. Or you can select both! You can also select to finetune the attention or the MLP layers!
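
To make the parameter saving concrete, here is a minimal sketch of how many trainable weights one LoRA adapter adds to a single linear layer. The helper names and the 2048 x 2048 layer size are hypothetical, chosen only for illustration; they are not taken from Gemma 3N or from Unsloth's code.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable weights added by one LoRA adapter on a d_out x d_in linear layer.

    LoRA learns two small matrices, A (r x d_in) and B (d_out x r),
    instead of updating the full d_out x d_in weight.
    """
    return r * (d_in + d_out)


def lora_scale(lora_alpha, r):
    """Scaling applied to the low-rank update in standard LoRA (use_rslora = False)."""
    return lora_alpha / r


# Hypothetical layer size, for illustration only.
d_in, d_out, r = 2048, 2048, 16
full_params = d_in * d_out
adapter_params = lora_param_count(d_in, d_out, r)
print(f"full layer: {full_params:,} weights")
print(f"LoRA adapter: {adapter_params:,} weights "
      f"({adapter_params / full_params:.2%} of the layer)")
print(f"update scale alpha / r = {lora_scale(16, r)}")
```

With `r = 16` and `lora_alpha = 16`, as in the cell below, the update is scaled by 1.0.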

In [2]:
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True, # False if not finetuning vision layers
    finetune_language_layers   = True, # False if not finetuning language layers
    finetune_attention_modules = True, # False if not finetuning attention layers
    finetune_mlp_modules       = True, # False if not finetuning MLP layers

    r = 16,           # The larger, the higher the accuracy, but might overfit
    lora_alpha = 16,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
    # target_modules = "all-linear", # Optional now! Can specify a list if needed
)

Unsloth: Making `model.base_model.model.model.language_model` require gradients


In [3]:
import torch._dynamo.config
torch._dynamo.config.recompile_limit = 512

### Inference before FT

Before we do any finetuning, maybe the vision model already knows how to analyse the images? Let's check if this is the case!

We use Gemma 3N's recommended settings of `temperature = 1.0, top_p = 0.95, top_k = 64`
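
As a rough illustration of what these three settings do, the sketch below applies temperature scaling, then top-k, then top-p filtering to a tiny list of logits. It is a simplified, pure-Python illustration with a hypothetical function name, not the `transformers` implementation.

```python
import math


def filter_top_k_top_p(logits, temperature=1.0, top_k=64, top_p=0.95):
    """Return {token_index: probability} for the tokens that survive filtering."""
    # Temperature: divide logits before the softmax (1.0 leaves them unchanged).
    scaled = [x / temperature for x in logits]

    # Top-k: keep only the k highest-scoring tokens.
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    ranked = ranked[:top_k]

    # Softmax over the surviving tokens (shifted by the max for stability).
    peak = scaled[ranked[0]]
    exps = {i: math.exp(scaled[i] - peak) for i in ranked}
    total = sum(exps.values())
    probs = {i: e / total for i, e in exps.items()}

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalise so the kept probabilities sum to 1 before sampling.
    norm = sum(probs[i] for i in kept)
    return {i: probs[i] / norm for i in kept}


# Toy example with four "tokens"; the values are made up.
candidates = filter_top_k_top_p([2.0, 1.0, 0.0, -1.0], top_k=3, top_p=0.9)
print(candidates)
```

In this toy example, top-k drops the lowest-scoring token and top-p then drops the third, so sampling happens over the two most likely tokens only.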

In [4]:
from PIL import Image

FastVisionModel.for_inference(model) # Enable for inference!

PROMPT_FOR_VISION = (
    "You are a garbage classification assistant. Based on the image, identify and classify all distinct parts of the object. "
    "For each part, determine the type of garbage from the following options: A: Cardboard, B: Glass, C: Metal, D: Paper, E: Plastic, F: Trash. "
    "Your response must be in a JSON format. The JSON should contain a single key, 'material', which holds an array of objects. "
    "Each object in the array must have two keys: 'part_name' (a brief description of the item) and 'answer' (the classification from the provided options, in the format 'A: Cardboard'). "
    "If the image contains multiple distinct parts made of different materials, list each part as a separate object in the 'material' array. "
    "For example, if the image shows a paper coffee cup with a plastic lid, you should output two separate objects in the array. "
    "The cup should be classified as 'D: Paper' and the lid as 'E: Plastic'. "
    "If a part is not classified into a specific category, consider it as 'F: Trash'."
)

image_path = "WhatsApp Image 2025-08-01 at 23.50.15_38d0037a.jpg"

image = Image.open(image_path).convert("RGB").resize((512, 512))

messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": PROMPT_FOR_VISION}
    ]}
]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
# Defensive check: the image was already converted to RGB when it was loaded
if image.mode != "RGB":
    image = image.convert("RGB")

inputs = tokenizer(
    image,
    input_text,
    add_special_tokens = False,
    return_tensors = "pt",
).to("cuda")

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
                   use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)

```json
{
  "material": [
    {
      "part_name": "Cup body",
      "answer": "A: Cardboard"
    },
    {
      "part_name": "Lid",
      "answer": "E: Plastic"
    }
  ]
}
```<end_of_turn>
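
The reply above wraps the JSON in a Markdown fence and ends with the `<end_of_turn>` marker, so it needs a little cleanup before it can be used programmatically. The helper below is a minimal sketch (the function name is hypothetical) that assumes the reply follows the schema requested in the prompt.

```python
import json
import re


def parse_model_output(raw_text):
    """Extract (part_name, answer) pairs from a fenced JSON reply."""
    # Drop the end-of-turn marker emitted by the chat template.
    text = raw_text.replace("<end_of_turn>", "").strip()

    # If the reply is wrapped in a Markdown code fence, keep only its contents.
    match = re.search(r"```(?:json)?\s*(.*?)```", text, flags=re.DOTALL)
    payload = match.group(1) if match else text

    data = json.loads(payload)  # raises json.JSONDecodeError on malformed output
    return [(item["part_name"], item["answer"]) for item in data["material"]]


sample_reply = (
    "```json\n"
    '{"material": [{"part_name": "Cup body", "answer": "A: Cardboard"},'
    ' {"part_name": "Lid", "answer": "E: Plastic"}]}\n'
    "```<end_of_turn>"
)
parts = parse_model_output(sample_reply)
print(parts)
```

Because `max_new_tokens = 128` can cut a long reply short, callers may want to catch `json.JSONDecodeError` and treat it as a failed prediction.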


## Data Preparation for Trash Classification


In [5]:
import pandas as pd
from PIL import Image
import os
from sklearn.model_selection import train_test_split
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
from datasets import Dataset, DatasetDict, Features, Value, Sequence, Image as DatasetsImage

In [6]:
DATASET_CSV_PATH = os.path.abspath("Garbage_classification_benchmark/final_comprehensive_ground_truth_dataset.csv")
IMAGE_ROOT_DIR = "Garbage_classification_benchmark"

PROMPT_FOR_VISION = (
    "You are a garbage classification assistant. Based on the image, identify and classify all distinct parts of the object. "
    "For each part, determine the type of garbage from the following options: A: Cardboard, B: Glass, C: Metal, D: Paper, E: Plastic, F: Trash. "
    "Your response must be in a JSON format. The JSON should contain a single key, 'material', which holds an array of objects. "
    "Each object in the array must have two keys: 'part_name' (a brief description of the item) and 'answer' (the classification from the provided options, in the format 'A: Cardboard'). "
    "If the image contains multiple distinct parts made of different materials, list each part as a separate object in the 'material' array. "
    "For example, if the image shows a paper coffee cup with a plastic lid, you should output two separate objects in the array. "
    "The cup should be classified as 'D: Paper' and the lid as 'E: Plastic'. "
    "If a part is not classified into a specific category, consider it as 'F: Trash'."
)

### Read Dataset and Split the Single Material Classes

In [7]:
# Load the dataset
df = pd.read_csv(DATASET_CSV_PATH)

# Step 1: Split by material type
df_single = df[df["material_type"] == "single"].copy()
df_multi = df[df["material_type"] == "multiple"].copy()

In [8]:
# Step 2: Stratified train-test split on single-material by primary_class
single_train, single_test = train_test_split(
    df_single,
    test_size=0.3,
    stratify=df_single["primary_class"],
    random_state=42
)

# Step 3: Random train-test split on multiple-material
multi_train, multi_test = train_test_split(
    df_multi,
    test_size=0.3,
    random_state=42,
    shuffle=True
)

# Step 4: Add split labels
single_train["dataset_split"] = "train"
single_test["dataset_split"] = "test"
multi_train["dataset_split"] = "train"
multi_test["dataset_split"] = "test"

# Step 5: Combine train and test
train_split = pd.concat([single_train, multi_train], ignore_index=True)
test_split = pd.concat([single_test, multi_test], ignore_index=True)

# Step 6: Save to CSV
train_split.to_csv("train_split.csv", index=False)
test_split.to_csv("test_split.csv", index=False)
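
To show what `stratify=` buys us in Step 2, here is a small standard-library sketch of a stratified split: each class is shuffled and split separately, so the test set keeps the same class proportions as the full data. It is an illustration of the idea with a hypothetical helper and made-up rows, not scikit-learn's implementation.

```python
import random
from collections import Counter, defaultdict


def stratified_split(rows, label_key, test_size=0.3, seed=42):
    """Split rows so each label contributes about test_size of its rows to the test set."""
    rng = random.Random(seed)

    # Group rows by class label.
    by_label = defaultdict(list)
    for row in rows:
        by_label[row[label_key]].append(row)

    train, test = [], []
    for label in sorted(by_label):
        group = by_label[label][:]
        rng.shuffle(group)
        n_test = round(len(group) * test_size)
        test.extend(group[:n_test])
        train.extend(group[n_test:])
    return train, test


# Made-up rows: 10 "glass" and 20 "plastic" images.
rows = [{"primary_class": "glass", "id": i} for i in range(10)]
rows += [{"primary_class": "plastic", "id": i} for i in range(20)]

train_rows, test_rows = stratified_split(rows, "primary_class")
print(Counter(r["primary_class"] for r in train_rows))
print(Counter(r["primary_class"] for r in test_rows))
```

The multiple-material rows are split without stratification in Step 3, so their class mix in the test set can drift from the overall distribution.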

### Function for formatting training data for Unsloth

In [9]:
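# Format one dataset row as an Unsloth vision chat example (image loaded with PIL)
from PIL import UnidentifiedImageError  # referenced in the except clause below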
def format_data_for_unsloth(example):
    try:
        relative_path = example["relative_path"]
        full_path = os.path.normpath(os.path.join(IMAGE_ROOT_DIR, relative_path))

        print(f"Processing image: {full_path}")
        image = Image.open(full_path).convert("RGB").resize((512, 512))

        assistant_response = example["chatgpt4o_answer"]

        return {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": PROMPT_FOR_VISION},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": assistant_response},
                    ],
                },
            ]
        }

    except (FileNotFoundError, UnidentifiedImageError, OSError, KeyError) as e:
        print(f"❌ Skipping due to error: {e}")
        return None

In [10]:
processed_data = []

# Iterate only over the training set
for _, row in train_split.iterrows():
    result = format_data_for_unsloth(row)
    if result is not None:
        processed_data.append(result)

Processing image: Garbage_classification_benchmark/single_material_image/glass/glass_150.jpg
Processing image: Garbage_classification_benchmark/single_material_image/plastic/plastic_001.jpg
Processing image: Garbage_classification_benchmark/single_material_image/plastic/plastic_129.jpg
Processing image: Garbage_classification_benchmark/single_material_image/cardboard/cardboard_090.jpg
Processing image: Garbage_classification_benchmark/single_material_image/paper/paper_161.jpg
Processing image: Garbage_classification_benchmark/single_material_image/plastic/plastic_131.jpg
Processing image: Garbage_classification_benchmark/single_material_image/metal/metal_268.jpg
Processing image: Garbage_classification_benchmark/single_material_image/plastic/plastic_141.jpg
Processing image: Garbage_classification_benchmark/single_material_image/trash/trash_026.jpg
Processing image: Garbage_classification_benchmark/single_material_image/glass/glass_102.jpg
Processing image: Garbage_classification_bench

In [11]:
train_data, val_data = train_test_split(processed_data, test_size=0.2, random_state=42)

<a name="Train"></a>
### Train the model
Now let's use Hugging Face TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). This run trains for 3 epochs (`num_train_epochs = 3`), evaluates every 50 steps, and stops early if the validation loss fails to improve for 3 consecutive evaluations. For a quick test run, uncomment `max_steps` instead. We also support TRL's `DPOTrainer`!

We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup.

In [12]:
from unsloth.trainer import UnslothVisionDataCollator
from transformers import EarlyStoppingCallback
from trl import SFTTrainer, SFTConfig

FastVisionModel.for_training(model) # Enable for training!

early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3,  # Number of evaluations with no improvement before stopping
    early_stopping_threshold=0.0  # Minimum improvement to be considered as an improvement
)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    data_collator = UnslothVisionDataCollator(model, tokenizer, resize=512), # Must use!
    train_dataset = train_data, # Use the list of processed samples
    eval_dataset = val_data, # Also use for evaluation if desired
    args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        # gradient_checkpointing = True,  # reduces VRAM usage at the cost of slightly longer training time.
        warmup_steps = 5,
        # max_steps = 100,
        num_train_epochs = 3, # Set this instead of max_steps for full training runs
        learning_rate = 5e-6,
        logging_steps = 1,
        save_strategy="steps",
        optim = "adamw_torch_fused", # "adamw_8bit"
        weight_decay = 0.01,
        lr_scheduler_type = "cosine", # or use "linear"
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",     # Set to "wandb" to log to Weights and Biases

        # MUST put the below items for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
        
        eval_strategy="steps",
        eval_steps=50,                # Adjust as needed for how often to check eval loss
        save_total_limit=2,           # Optional: limit saved checkpoints
        load_best_model_at_end=True,  # Restore best model
        metric_for_best_model="eval_loss",  # Required for early stopping
        greater_is_better=False 
    ),
    callbacks=[early_stopping]
)
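
As a sanity check on this configuration, the effective batch size and total number of optimizer steps can be worked out by hand. The sketch below uses a hypothetical helper and the 992 training examples reported in the training log; the Trainer's own rounding may differ slightly when the dataset size is not a multiple of the effective batch size.

```python
import math


def total_optimizer_steps(num_examples, per_device_batch_size,
                          grad_accum_steps, num_epochs, num_gpus=1):
    """Return (effective_batch_size, total_steps) for a simple training setup."""
    effective_batch = per_device_batch_size * grad_accum_steps * num_gpus
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch * num_epochs


effective_batch, total_steps = total_optimizer_steps(
    num_examples=992,          # from the training log
    per_device_batch_size=2,
    grad_accum_steps=4,
    num_epochs=3,
)
print(f"effective batch size = {effective_batch}, total steps = {total_steps}")
```

This gives an effective batch size of 8 and 372 steps, matching the figures in the training banner.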

In [13]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100 80GB PCIe. Max memory = 79.151 GB.
9.795 GB of memory reserved.


In [14]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 992 | Num Epochs = 3 | Total steps = 372
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 38,420,480 of 7,888,398,672 (0.49% trained)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50   | 0.0742        | 2.312983        |
| 100  | 0.0546        | 2.31482         |
| 150  | 0.0463        | 2.319943        |
| 200  | 0.0406        | 2.321056        |


Unsloth: Not an error, but Gemma3nForConditionalGeneration does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient
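
In the log above, the validation loss is lowest at step 50 and rises at each later evaluation. The sketch below is a simplified model of patience-based early stopping (a hypothetical helper, not the `EarlyStoppingCallback` source) applied to those four values.

```python
def early_stopping_index(eval_losses, patience=3, threshold=0.0):
    """Return the index of the evaluation at which training would stop, or None."""
    best = None
    evals_without_improvement = 0
    for index, loss in enumerate(eval_losses):
        if best is None or loss < best - threshold:
            # New best validation loss: reset the patience counter.
            best = loss
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return index
    return None


eval_steps = 50
losses = [2.312983, 2.31482, 2.319943, 2.321056]  # validation losses from the log
stop_index = early_stopping_index(losses, patience=3)
if stop_index is not None:
    print(f"stop at step {(stop_index + 1) * eval_steps}")
```

With a patience of 3, the sketch stops at the fourth evaluation (step 200), which is consistent with the log ending there rather than at step 372. Since `load_best_model_at_end=True`, the weights restored afterwards would be those of the best saved checkpoint.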


In [15]:
# @title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

1062.4979 seconds used for training.
17.71 minutes used for training.
Peak reserved memory = 11.715 GB.
Peak reserved memory for training = 1.92 GB.
Peak reserved memory % of max memory = 14.801 %.
Peak reserved memory for training % of max memory = 2.426 %.


<a name="Save"></a>
### Saving, loading finetuned models
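To save the final model as LoRA adapters, either use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** `save_pretrained` and `push_to_hub` save ONLY the LoRA adapters, not the full model. To upload the full model with the adapters merged into 16-bit weights, use `push_to_hub_merged`. For GGUF conversion, see the last cell.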
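
Conceptually, merging folds each adapter's low-rank update into the base weight it is attached to: for standard LoRA, `W_merged = W + (lora_alpha / r) * B @ A`. The sketch below illustrates this on tiny matrices stored as plain Python lists; the helper names and values are hypothetical and this is not Unsloth's merge code.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_a, lora_b, lora_alpha, r):
    """Return weight + (lora_alpha / r) * (lora_b @ lora_a)."""
    scale = lora_alpha / r
    delta = matmul(lora_b, lora_a)  # (d_out x r) @ (r x d_in) -> d_out x d_in
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Tiny made-up example: a 2 x 2 base weight with a rank-1 adapter.
base_weight = [[1.0, 0.0], [0.0, 1.0]]
lora_a = [[1.0, 2.0]]        # r x d_in
lora_b = [[0.5], [0.0]]      # d_out x r
merged = merge_lora(base_weight, lora_a, lora_b, lora_alpha=1, r=1)
print(merged)
```

After merging, the adapter matrices are no longer needed at inference time, which is why the merged upload contains full-size weight shards while the adapter-only upload is much smaller.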

In [23]:
import os

save_dir = "qizunlee/gemma3n_E4B_it_ft_3RGarbageClassification_merged"
os.makedirs(save_dir, exist_ok=True)  # Create directory if it doesn't exist

In [None]:
# Merge the LoRA adapter into the base model and push the merged 16-bit model to the Hub
model.push_to_hub_merged("qizunlee/gemma3n_E4B_it_ft_3RGarbageClassification_merged", tokenizer, token = "<TOKEN>")

No files have been modified since last commit. Skipping to prevent empty commit.
Found HuggingFace hub cache directory: /home/azureuser/.cache/huggingface/hub
Checking cache directory for required files...
Cache check failed: model-00001-of-00004.safetensors not found in local cache.
Not all required files found in cache. Will proceed with downloading.
Downloading safetensors index for unsloth/gemma-3n-e4b-it...
Unsloth: Merging weights into 16bit: 100%|██████████| 4/4 [02:52<00:00, 43.14s/it]


In [None]:
# Merge the LoRA adapter with the base model
# model = model.merge_and_unload()

# Save the model (lora_adapter) locally
model.save_pretrained("gemma3n_E4B_it_ft_3RGarbageClassification")
tokenizer.save_pretrained("gemma3n_E4B_it_ft_3RGarbageClassification")

# Push the LoRA adapter and tokenizer to the Hugging Face Hub
model.push_to_hub("qizunlee/gemma3n_E4B_it_ft_3RGarbageClassification",
                  token = "<TOKEN>")
tokenizer.push_to_hub("qizunlee/gemma3n_E4B_it_ft_3RGarbageClassification",
                      token = "<TOKEN>")

Saved model to https://huggingface.co/qizunlee/gemma3n_E4B_it_ft_3RGarbageClassification

<a name="Inference"></a>
### Inference
Let's run the finetuned model! We reuse the same image and prompt as in the pre-finetuning check, so the two outputs can be compared directly.

In [None]:
FastVisionModel.for_inference(model) # Enable for inference!

PROMPT_FOR_VISION = (
    "You are a garbage classification assistant. Based on the image, identify and classify all distinct parts of the object. "
    "For each part, determine the type of garbage from the following options: A: Cardboard, B: Glass, C: Metal, D: Paper, E: Plastic, F: Trash. "
    "Your response must be in a JSON format. The JSON should contain a single key, 'material', which holds an array of objects. "
    "Each object in the array must have two keys: 'part_name' (a brief description of the item) and 'answer' (the classification from the provided options, in the format 'A: Cardboard'). "
    "If the image contains multiple distinct parts made of different materials, list each part as a separate object in the 'material' array. "
    "For example, if the image shows a paper coffee cup with a plastic lid, you should output two separate objects in the array. "
    "The cup should be classified as 'D: Paper' and the lid as 'E: Plastic'. "
    "If a part is not classified into a specific category, consider it as 'F: Trash'."
)

image_path = "WhatsApp Image 2025-08-01 at 23.50.15_38d0037a.jpg"

image = Image.open(image_path).convert("RGB").resize((512, 512))

messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": PROMPT_FOR_VISION}
    ]}
]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
# Defensive check: the image was already converted to RGB when it was loaded
if image.mode != "RGB":
    image = image.convert("RGB")

inputs = tokenizer(
    image,
    input_text,
    add_special_tokens = False,
    return_tensors = "pt",
).to("cuda")

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
                   use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)

In [None]:
# Conversion to GGUF format
# https://huggingface.co/spaces/ggml-org/gguf-my-repo