<a href="https://colab.research.google.com/github/mohammed1916/tourmate/blob/main/nb/Gemma3N_(4B)-Vision.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)


### News

Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).

Read our **[Gemma 3N Guide](https://docs.unsloth.ai/basics/gemma-3n-how-to-run-and-fine-tune)** and check out our new **[Dynamic 2.0](https://docs.unsloth.ai/basics/unsloth-dynamic-2.0-ggufs)** quants which outperforms other quantization methods!

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [1]:
# %%capture
# import os
# if "COLAB_" not in "".join(os.environ.keys()):
#     !pip install unsloth
# else:
#     # Do this only in Colab notebooks! Otherwise use pip install unsloth
#     !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
#     !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
#     !pip install --no-deps unsloth

In [2]:
# %%capture
# # Install latest transformers for Gemma 3N
# !pip install --no-deps --upgrade timm # Only for Gemma 3N

In [3]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Tue_May_27_02:24:01_Pacific_Daylight_Time_2025
Cuda compilation tools, release 12.9, V12.9.86
Build cuda_12.9.r12.9/compiler.36037853_0


In [4]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [5]:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
print("GPU cache cleared and peak memory stats reset.")

GPU cache cleared and peak memory stats reset.


### Unsloth

In [9]:
from unsloth import FastVisionModel # FastLanguageModel for LLMs
import torch

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", # Llama 3.2 vision support
    "unsloth/Llama-3.2-11B-Vision-bnb-4bit",
    "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit", # Can fit in a 80GB card!
    "unsloth/Llama-3.2-90B-Vision-bnb-4bit",

    "unsloth/Pixtral-12B-2409-bnb-4bit",              # Pixtral fits in 16GB!
    "unsloth/Pixtral-12B-Base-2409-bnb-4bit",         # Pixtral base model

    "unsloth/Qwen2-VL-2B-Instruct-bnb-4bit",          # Qwen2 VL support
    "unsloth/Qwen2-VL-7B-Instruct-bnb-4bit",
    "unsloth/Qwen2-VL-72B-Instruct-bnb-4bit",

    "unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit",      # Any Llava variant works!
    "unsloth/llava-1.5-7b-hf-bnb-4bit",
] # More models at https://huggingface.co/unsloth

In [10]:
model, processor = FastVisionModel.from_pretrained(
    "unsloth/gemma-3n-E4B",
    load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
    # llm_int8_enable_fp32_cpu_offload=True,
    device_map="cuda"
)

  GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f"{DEVICE_TYPE}:{i}") for i in range(n_gpus)])


==((====))==  Unsloth 2025.7.11: Fast Gemma3N patching. Transformers: 4.54.1.
   \\   /|    NVIDIA GeForce RTX 4060 Laptop GPU. Num GPUs = 1. Max memory: 7.996 GB. Platform: Windows.
O^O/ \_/ \    Torch: 2.7.1+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!


Loading checkpoint shards: 100%|██████████| 3/3 [00:12<00:00,  4.16s/it]


We now add LoRA adapters for parameter efficient fine-tuning, allowing us to train only 1% of all model parameters efficiently.

**[NEW]** We also support fine-tuning only the vision component, only the language component, or both. Additionally, you can choose to fine-tune the attention modules, the MLP layers, or both!

In [11]:
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True, # False if not finetuning vision layers
    finetune_language_layers   = True, # False if not finetuning language layers
    finetune_attention_modules = True, # False if not finetuning attention layers
    finetune_mlp_modules       = True, # False if not finetuning MLP layers

    r = 32,                           # The larger, the higher the accuracy, but might overfit
    lora_alpha = 32,                  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,               # We support rank stabilized LoRA
    loftq_config = None,               # And LoftQ
    target_modules = "all-linear",    # Optional now! Can specify a list if needed
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

Unsloth: Making `model.base_model.model.model.language_model` require gradients


<a name="Data"></a>
### Data Prep

In [12]:
from datasets import load_dataset, load_from_disk
import os

local_path = "local_pirm_bicubic_x2"

if os.path.exists(local_path):
    print(f"Loading dataset from local disk: {local_path}")
    dataset = load_from_disk(local_path)
else:
    print("Local dataset not found. Downloading from Hugging Face...")
    try:
        dataset = load_dataset("eugenesiow/PIRM", "bicubic_x2", split="validation")
        dataset.save_to_disk(local_path)
        print(f"Dataset saved locally to '{local_path}'.")
    except Exception as e:
        print(f"Could not load 'eugenesiow/PIRM': {e}")
        dataset = None
        print("Failed to load a suitable upscaling dataset.")

if dataset is not None:
    print("Dataset loaded successfully:")
    print(dataset)
    print("\nSample keys:", dataset[0].keys())

Loading dataset from local disk: local_pirm_bicubic_x2
Dataset loaded successfully:
Dataset({
    features: ['hr', 'lr'],
    num_rows: 100
})

Sample keys: dict_keys(['hr', 'lr'])


In [13]:
dataset["hr"][0]

'/storage/hf-datasets-cache/all/datasets/49041657089242-config-parquet-and-info-eugenesiow-PIRM-26573a06/downloads/extracted/1294fc48f51536a18237eefb895c701560eb5fd13b7cfffe76b582034db420f4/PIRM_valid_HR/1.png'

In [14]:
def preprocess_upscaling_dataset(sample):
    if "lr" not in sample or "hr" not in sample:
        # Handle cases where keys might be different
        print("Warning: 'lr' or 'hr' keys not found in sample. Skipping.")
        return None

    # low_res_image is the input for the model
    low_res_image = sample["lr"]

    # high_res_image is the ground truth. The model can't output an image directly
    # so we use a placeholder text. In a true upscaling pipeline, this would be
    # handled by a different model architecture, not a VLM.
    high_res_image_representation = "upscaled image." # Placeholder text

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Upscale this image to a higher resolution."},
                {"type": "image", "image": low_res_image},
            ],
        },
        {"role": "assistant", "content": [
            {"type": "text", "text": high_res_image_representation}
        ]},
    ]
    return {"messages": conversation}

In [15]:
if dataset is not None:
    processed_dataset = dataset.map(preprocess_upscaling_dataset, remove_columns=dataset.column_names)
    print("Dataset successfully processed:")
    print(processed_dataset)

Map: 100%|██████████| 100/100 [00:00<00:00, 8619.26 examples/s]

Dataset successfully processed:
Dataset({
    features: ['messages'],
    num_rows: 100
})





In [16]:
# instruction = "Write the LaTeX representation for this image."

# def convert_to_conversation(sample):
#     conversation = [
#         {
#             "role": "user",
#             "content": [
#                 {"type": "text", "text": instruction},
#                 {"type": "image", "image": sample["image"]},
#             ],
#         },
#         {"role": "assistant", "content": [{"type": "text", "text": sample["text"]}]},
#     ]
#     return {"messages": conversation}
# pass

In [18]:
# converted_dataset = [convert_to_conversation(sample) for sample in dataset]

In [None]:
# converted_dataset[0]

{'messages': [{'role': 'user',
   'content': [{'type': 'text',
     'text': 'Write the LaTeX representation for this image.'},
    {'type': 'image',
     'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=160x40>}]},
  {'role': 'assistant',
   'content': [{'type': 'text',
     'text': '{ \\frac { N } { M } } \\in { \\bf Z } , { \\frac { M } { P } } \\in { \\bf Z } , { \\frac { P } { Q } } \\in { \\bf Z }'}]}]}

Lets take the Gemma 3n instruction chat template and use it in our base model

In [19]:
from unsloth import get_chat_template

processor = get_chat_template(
    processor,
    "gemma-3n"
)

Before fine-tuning, let us evaluate the base model's performance. We do not expect strong results, as it has not encountered this chat template before.

In [None]:
# Import necessary libraries
import torch
from transformers import TextStreamer
from unsloth import FastVisionModel, get_chat_template

# NOTE: This code assumes that the model, processor, and dataset
# have already been loaded and prepared in the preceding steps.
# Specifically, 'model' and 'processor' should be the FastVisionModel
# and processor objects, and 'dataset' should be the upscaling dataset
# from eugenesiow/PIRM.

# Set the model to inference mode.
FastVisionModel.for_inference(model)

# 1. Select a low-resolution image from the dataset.
# The `eugenesiow/PIRM` dataset uses 'lr' for low-resolution images.

# 2. Define the instruction for the upscaling task.
# This instruction guides the model's behavior.
instruction = "Upscale this image to a higher resolution."

# 3. Construct the conversational message format.
# The user's message contains both the text instruction and the low-resolution image.
messages = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": instruction}],
    }
]

# 4. Apply the chat template to the messages.
# This formats the conversation into the model's expected input format.
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

# 5. Process the image and text into tensors for the model.
# FIX: The `image` variable is now passed inside a list, `[image]`.
# This resolves the `ValueError` by providing the expected input format.
inputs = processor(
    dataset["lr"],
    input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to("cuda")

# 6. Initialize a TextStreamer to print the model's output in real-time.
text_streamer = TextStreamer(processor.tokenizer, skip_prompt=True)

# 7. Generate the output from the model.
# The generation parameters are kept the same as they are general purpose.
result = model.generate(
    **inputs,
    streamer=text_streamer,
    max_new_tokens=128,
    use_cache=True,
    temperature=1.0,
    top_p=0.95,
    top_k=64
)

# NOTE: As discussed, this model is a vision-to-text model. It will
# generate a textual response (e.g., "Image restoration complete.") rather
# than a new image file. A true upscaling solution requires a different
# model architecture specifically designed for image-to-image tasks.


ValueError: Invalid input type. Must be a single image, a list of images, or a list of batches of images.

In [None]:
# FastVisionModel.for_inference(model)  # Enable for inference!

# image = dataset[2]["image"]
# instruction = "Write the LaTeX representation for this image."

# messages = [
#     {
#         "role": "user",
#         "content": [{"type": "image"}, {"type": "text", "text": instruction}],
#     }
# ]
# input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
# inputs = processor(
#     image,
#     input_text,
#     add_special_tokens=False,
#     return_tensors="pt",
# ).to("cuda")

# from transformers import TextStreamer

# text_streamer = TextStreamer(processor, skip_prompt=True)
# result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
#                         use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)

  number="1" 
  count_sym=1 
  count_sym_arg=1 
  count_sym_arg_arg=1 
  count_sym_arg_arg_arg=1 
  count_sym_arg_arg_arg_arg=1 
  count_sym_arg_arg_arg_arg_arg=1 
  count_sym_arg_arg_arg_arg_arg_arg=1 
  count_sym_arg_arg_arg_arg_arg_arg_arg=1 
  


You can see it's absolutely terrible! It doesn't follow instructions at all

<a name="Train"></a>
### Train the model
Now let's use Huggingface TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. We also support TRL's `DPOTrainer`!

We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup.

In [None]:
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig

FastVisionModel.for_training(model) # Enable for training!

trainer = SFTTrainer(
    model=model,
    train_dataset=converted_dataset,
    processing_class=processor.tokenizer,
    data_collator=UnslothVisionDataCollator(model, processor),
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        gradient_checkpointing = True,

        # use reentrant checkpointing
        gradient_checkpointing_kwargs = {"use_reentrant": False},
        max_grad_norm = 0.3,              # max gradient norm based on QLoRA paper
        warmup_ratio = 0.03,
        max_steps = 60,
        #num_train_epochs = 2,          # Set this instead of max_steps for full training runs
        learning_rate = 2e-4,
        logging_steps = 1,
        save_strategy="steps",
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",             # For Weights and Biases

        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
    )
)

Unsloth: Model does not have a default image size - using 512


In [None]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = Tesla T4. Max memory = 14.741 GB.
5.416 GB of memory reserved.


In [None]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 68,686 | Num Epochs = 1 | Total steps = 60
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 76,840,960 of 7,926,819,152 (0.97% trained)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
1,6.085
2,6.6772
3,5.7421
4,6.6135
5,6.2421
6,6.3108
7,5.6542
8,5.6249
9,6.5855
10,3.0446


In [None]:
# @title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

760.6053 seconds used for training.
12.68 minutes used for training.
Peak reserved memory = 6.061 GB.
Peak reserved memory for training = 0.645 GB.
Peak reserved memory % of max memory = 41.117 %.
Peak reserved memory for training % of max memory = 4.376 %.


<a name="Inference"></a>
### Inference
Let's run the model! You can modify the instruction and input—just leave the output blank.

We'll use the best hyperparameters for inference on Gemma: `top_p=0.95`, `top_k=64`, and `temperature=1.0`.

In [None]:
FastVisionModel.for_inference(model)  # Enable for inference!

image = dataset[10]["image"]
instruction = "Write the LaTeX representation for this image."

messages = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": instruction}],
    }
]

input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

inputs = processor(
    image,
    input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to("cuda")

from transformers import TextStreamer

text_streamer = TextStreamer(processor, skip_prompt=True)
result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
                        use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)

[ [ B _ { n } ^ { + } , b _ { 2 } ^ { + } ] = n B _ { n } ^ { + } , \quad [ [ B _ { n } ^ { - } , b _ { 2 } ^ { + } ] , b _ { 2 } ^ { - } ] = n B _ { n } ^ { - } .
<eos>


<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, use Hugging Face’s `push_to_hub` for online saving, or `save_pretrained` for local storage.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!

In [None]:
model.save_pretrained("lora_model")  # Local saving
processor.save_pretrained("lora_model")
# model.push_to_hub("your_name/lora_model", token = "...") # Online saving
# processor.push_to_hub("your_name/lora_model", token = "...") # Online saving

['lora_model/processor_config.json']

Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:

In [None]:
if False:
    from unsloth import FastVisionModel

    model, processor = FastVisionModel.from_pretrained(
        model_name="lora_model",  # YOUR MODEL YOU USED FOR TRAINING
        load_in_4bit=True,  # Set to False for 16bit LoRA
    )
    FastVisionModel.for_inference(model)  # Enable for inference!

FastVisionModel.for_inference(model)  # Enable for inference!

sample = dataset[1]
image = sample["image"].convert("RGB")
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": sample["text"],
            },
            {
                "type": "image",
            },
        ],
    },
]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
    image,
    input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to("cuda")

from transformers import TextStreamer

text_streamer = TextStreamer(processor.tokenizer, skip_prompt=True)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
                   use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)

### Saving to float16 for VLLM
ref:
https://huggingface.co/settings/tokens

In [None]:
# Select ONLY 1 to save! (Both not needed!)

# Save locally to 16bit
if True: model.save_pretrained_merged("unsloth_finetune", processor,)

# To export and save to your Hugging Face account
if False: model.push_to_hub_merged("YOUR_USERNAME/unsloth_finetune", processor, token = "PUT_HERE")

In [None]:
from datasets import load_dataset
import requests

# Attempt to find a public image restoration dataset on Hugging Face.
# Searching the hub for "image restoration" or similar tags.
# Let's try a dataset that is known to be public and image-based,
# like a small subset or a different task that can be adapted if necessary.
# The "Reflectance_and_Synthetically_Generated_Shading" dataset seems relevant and might be public.
try:
    dataset = load_dataset("AIML-Lab/Reflectance_and_Synthetically_Generated_Shading", split="train")
except Exception as e:
    print(f"Could not load 'AIML-Lab/Reflectance_and_Synthetically_Generated_Shading': {e}")
    dataset = None
    print("Failed to load a suitable image restoration dataset.")

if dataset is not None:
    print("Dataset loaded successfully:")
    print(dataset)

    # Examine a sample to determine the keys for degraded and clean images.
    if len(dataset) > 0:
        sample = dataset[0]
        print("\nSample structure:")
        print(sample)
    else:
        print("\nDataset is empty.")
        sample = None

    if sample is not None:
        # Define the preprocessing function based on the sample structure.
        # Assuming keys like 'input_image' and 'output_image' or similar based on the dataset name.
        # We will adapt this based on the actual keys observed in the sample structure printout.
        # For now, using placeholder keys 'input_image' and 'output_image'.

        def preprocess_image_restoration(sample):
            try:
                degraded_image = sample["input_image"] # Placeholder key
                clean_image = sample["output_image"]   # Placeholder key
            except KeyError as e:
                print(f"KeyError: Missing expected image key in sample: {e}. Please check dataset structure.")
                # Fallback or re-examine dataset keys.
                # If keys are different, manually inspect sample output and correct.
                # For example, if keys are 'image1' and 'image2':
                # degraded_image = sample['image1']
                # clean_image = sample['image2']
                return None # Return None for samples that cannot be processed

            # Convert images to RGB if they are not already, to ensure compatibility
            if degraded_image.mode != 'RGB':
                degraded_image = degraded_image.convert('RGB')
            if clean_image.mode != 'RGB':
                clean_image = clean_image.convert('RGB')


            # Use a placeholder text representation for the clean image.
            # A more advanced approach would be needed for actual image generation.
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Restore this image to its clean version."},
                        {"type": "image", "image": degraded_image},
                    ],
                },
                {"role": "assistant", "content": [{"type": "text", "text": "Image restoration complete."}]}, # Placeholder text
            ]
            return {"messages": conversation}

        # Apply the preprocessing function and filter out any None results
        try:
            processed_dataset = dataset.map(preprocess_image_restoration).filter(lambda x: x is not None)
            print("\nDataset successfully processed:")
            print(processed_dataset)
        except Exception as e:
            print(f"Error during dataset mapping: {e}")
    else:
        print("Cannot preprocess empty dataset or retrieve sample.")


Could not load 'AIML-Lab/Reflectance_and_Synthetically_Generated_Shading': Dataset 'AIML-Lab/Reflectance_and_Synthetically_Generated_Shading' doesn't exist on the Hub or cannot be accessed.
Failed to load a suitable image restoration dataset.


In [None]:
from datasets import load_dataset

# Try another publicly available dataset with paired images.
# Searching for datasets tagged with 'image-to-image' or similar.
# The 'raiven/coco-stuff-filtered' dataset seems to contain image pairs (image and segmented image).
# While not strictly image restoration, it has paired images that can be used to test the preprocessing logic.
try:
    dataset = load_dataset("raiven/coco-stuff-filtered", split="train")
    print("Dataset loaded successfully:")
    print(dataset)

    # Examine a sample to determine the keys for input and output images.
    if len(dataset) > 0:
        sample = dataset[0]
        print("\nSample structure:")
        print(sample)
    else:
        print("\nDataset is empty.")
        sample = None

    if sample is not None:
        # Based on the dataset name and potential structure, let's assume keys 'image' and 'segmented_image'.
        # We will confirm this from the printed sample structure and adjust if needed.
        # For this task, we will treat 'image' as the degraded input and 'segmented_image' as the target output
        # to demonstrate the preprocessing structure for paired images, even if the task is different.

        def preprocess_image_restoration(sample):
            try:
                degraded_image = sample["image"] # Assuming 'image' is the input
                clean_image = sample["segmented_image"]   # Assuming 'segmented_image' is the output
            except KeyError as e:
                print(f"KeyError: Missing expected image key in sample: {e}. Please check dataset structure.")
                return None

            # Convert images to RGB if they are not already
            if degraded_image.mode != 'RGB':
                degraded_image = degraded_image.convert('RGB')
            if clean_image.mode != 'RGB':
                clean_image = clean_image.convert('RGB')

            # Use a placeholder text representation for the clean image.
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Process this image."}, # Generic instruction
                        {"type": "image", "image": degraded_image},
                    ],
                },
                {"role": "assistant", "content": [{"type": "text", "text": "Processed image representation."}]}, # Placeholder text
            ]
            return {"messages": conversation}

        # Apply the preprocessing function and filter out any None results
        try:
            processed_dataset = dataset.map(preprocess_image_restoration).filter(lambda x: x is not None)
            print("\nDataset successfully processed:")
            print(processed_dataset)
        except Exception as e:
            print(f"Error during dataset mapping: {e}")
    else:
        print("Cannot preprocess empty dataset or retrieve sample.")

except Exception as e:
    print(f"Could not load dataset 'raiven/coco-stuff-filtered': {e}")
    dataset = None
    print("Failed to load a suitable paired image dataset.")


Could not load dataset 'raiven/coco-stuff-filtered': Dataset 'raiven/coco-stuff-filtered' doesn't exist on the Hub or cannot be accessed.
Failed to load a suitable paired image dataset.
