<a href="https://colab.research.google.com/github/mayakaripel/FieldScribe-Gemma3n/blob/main/Paligemma_To_Tflite.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ==========================================================
# CELL 1: INSTALL DEPENDENCIES (SIMPLE & MODERN)
# ==========================================================
print("--- Installing required libraries ---")

!pip install -q --upgrade pip

# Install the main libraries and let pip resolve the best versions
!pip install -q "transformers>=4.41.0" # Ensure a recent version of transformers
!pip install -q tensorflow torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q huggingface-hub accelerate bitsandbytes
!pip install -q "ai-edge-torch"

# This is a preventative measure for the Protobuf error we've seen before.
# It ensures a compatible version is used.
!pip install -q "protobuf==4.25.3"

print("✅ Libraries installed. A runtime restart is now required.")

--- Installing required libraries ---
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.19.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.3, but you have protobuf 6.32.1 which is incompatible.
grpcio-status 1.71.2 requires protobuf<6.0dev,>=5.26.1, but you have protobuf 6.32.1 which is incompatible.
google-ai-generativelanguage 0.6.15 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 6.32.1 which is incompatible.[0m[31m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ai-edge-tensorflow

In [2]:
from transformers import PaliGemmaForConditionalGeneration

# Checkpoint identifier on the Hugging Face Hub; swap in a fine-tuned
# checkpoint here to work with that instead.
model_name = "google/paligemma-3b-pt-224"

# Load the checkpoint and put the resulting module into eval mode.
paligemma = PaliGemmaForConditionalGeneration.from_pretrained(model_name, trust_remote_code=True).eval()

ImportError: cannot import name 'runtime_version' from 'google.protobuf' (/usr/local/lib/python3.12/dist-packages/google/protobuf/__init__.py)

In [None]:
# ==========================================================
# CELL 2.5: INTROSPECTION - LET'S LOOK INSIDE THE MODEL
# ==========================================================
from transformers import PaliGemmaForConditionalGeneration

# Requires that the checkpoint is accessible (Hub login done beforehand).
model_name = "google/paligemma-3b-pt-224"

# A separate instance, loaded only so its structure can be printed.
inspection_model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_name,
    trust_remote_code=True,
)

# The module repr lists the top-level submodules and their nesting.
print(inspection_model)

In [None]:
# ==========================================================
# CELL 3: THE FINAL, CORRECTED TFLITE WRAPPER CLASS
# ==========================================================
import torch
from transformers import PaliGemmaForConditionalGeneration

class PaliGemmaForTFLite(torch.nn.Module):
    """Tensor-in / tensor-out wrapper around PaliGemma for tracing and export.

    The wrapper takes an image tensor and prompt token ids, builds the joint
    embedding sequence ``[image tokens, text tokens]`` itself, runs it through
    the language model and returns the vocabulary logits as a plain tensor.

    It expects the checkpoint class to expose ``model.vision_tower``,
    ``model.multi_modal_projector``, ``model.language_model``, ``lm_head`` and
    ``config.text_config.hidden_size``.
    """

    def __init__(self, model_path: str):
        """Load the checkpoint at ``model_path`` in float16 and set eval mode.

        Args:
            model_path: Hub id or local path passed to ``from_pretrained``.
        """
        super().__init__()
        # Load the base model
        self.paligemma = PaliGemmaForConditionalGeneration.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.float16 # Use float16 for efficiency
        ).eval()

    def _prepare_input_embeds(self, pixel_values: torch.Tensor, input_ids: torch.Tensor):
        """Build the language-model input by prepending image embeddings to text embeddings.

        Args:
            pixel_values: Image batch fed to the vision tower.
            input_ids: Prompt token ids fed to the text embedding layer.

        Returns:
            The projected, scaled image embeddings concatenated with the text
            embeddings along dim 1 (image positions first).
        """
        core = self.paligemma.model

        vis_out = core.vision_tower(pixel_values=pixel_values)
        vis_proj = core.multi_modal_projector(vis_out.last_hidden_state)

        # PaliGemma's reference implementation divides the projected image
        # features by sqrt(text hidden size) before merging them with the text
        # embeddings, which offsets the sqrt(hidden size) factor the Gemma
        # decoder applies to every input embedding. Without this step the
        # image positions reach the decoder at the wrong scale.
        hidden_size = self.paligemma.config.text_config.hidden_size
        vis_proj = vis_proj / (hidden_size ** 0.5)

        text_emb_layer = core.language_model.get_input_embeddings()
        text_emb = text_emb_layer(input_ids)

        return torch.cat([vis_proj, text_emb], dim=1)

    def forward(self, pixel_values: torch.Tensor, input_ids: torch.Tensor):
        """Return logits for the concatenated image + text sequence.

        Args:
            pixel_values: Image batch fed to the vision tower.
            input_ids: Prompt token ids.

        Returns:
            The output of ``lm_head`` applied to the language model's
            ``last_hidden_state`` (one row per image and text position).
        """
        inputs_embeds = self._prepare_input_embeds(pixel_values, input_ids)

        # NOTE(review): no attention mask is passed, so the language model
        # falls back to its default masking; confirm that this matches the
        # masking PaliGemma's own forward uses for an image + prompt prefix.
        base_model_outputs = self.paligemma.model.language_model(inputs_embeds=inputs_embeds)

        # The base language model returns hidden states, not logits, so the
        # output head has to be applied explicitly.
        last_hidden_state = base_model_outputs.last_hidden_state
        logits = self.paligemma.lm_head(last_hidden_state)

        return logits

print("✅ Final Corrected PaliGemmaForTFLite class is defined.")

In [None]:
# ==========================================================
# CELL 4: INSTANTIATE AND TRACE THE MODEL (FINAL VERSION)
# ==========================================================
# Builds the wrapper, traces it with fixed-shape dummy inputs and saves the
# resulting TorchScript module to paligemma_traced.pt.
import torch

# Re-declared here so the cell does not depend on an earlier cell having run.
model_name = "google/paligemma-3b-pt-224"

# 1. Create an instance of the wrapper class (defined in the previous cell)
print("--- Instantiating the TFLite-ready model ---")
traceable_model = PaliGemmaForTFLite(model_name)
print("✅ Model instantiated.")

# 2. Create dummy inputs
print("--- Preparing dummy inputs for tracing ---")
# Seed the generator so the example inputs are identical on every run.
torch.manual_seed(0)
H, W = 224, 224
# float16 image to match the dtype the wrapper loads the model in
dummy_image = torch.randn(1, 3, H, W, dtype=torch.float16)
# 16 prompt token ids drawn from [0, 1000)
dummy_ids = torch.randint(0, 1000, (1, 16), dtype=torch.long)
print("✅ Dummy inputs created.")

# 3. Trace the model, disabling the strict comparison check
print("--- Tracing the model (with check_trace=False) ---")
# Tracing only needs a forward pass; no_grad keeps autograd from tracking
# the whole model while the example inputs are run through it.
with torch.no_grad():
    traced = torch.jit.trace(
        traceable_model,
        (dummy_image, dummy_ids),
        check_trace=False
    )
print("✅ Model traced successfully.")

# 4. Save the traced module to a file
traced.save("paligemma_traced.pt")
print("✅✅✅ Model traced and saved to paligemma_traced.pt. Ready for the final conversion step!")

In [None]:
# ==========================================================
# CELL 5: THE FINAL CONVERSION TO TFLITE
# ==========================================================
# Converts the wrapper model to a TFLite flatbuffer through the ai_edge_torch
# Python API (convert() on the nn.Module plus sample inputs, then export()).
# Relies on `traceable_model`, `dummy_image` and `dummy_ids` from the tracing
# cell; the converter works on the eager module, not on the saved .pt file.
import os

import ai_edge_torch

TFLITE_PATH = "paligemma.tflite"

print("--- 🚀 Starting the final conversion to TFLite... This will take a few minutes. ---")

# Remove any output left by an earlier run so the check below can only pass
# if THIS run produced the file.
if os.path.exists(TFLITE_PATH):
    os.remove(TFLITE_PATH)

# int8 quantization is intentionally not requested here: static int8 needs
# real calibration samples, and this notebook only ever created an empty
# placeholder folder for them.
# NOTE(review): confirm the converter handles this model's size and float16
# weights; if not, the ai_edge_torch generative tooling may be the better route.
edge_model = ai_edge_torch.convert(traceable_model.eval(), (dummy_image, dummy_ids))
edge_model.export(TFLITE_PATH)

# Only announce success when the output file actually exists.
if not os.path.isfile(TFLITE_PATH):
    raise RuntimeError(f"TFLite conversion finished without producing {TFLITE_PATH}")

print("🎉🎉🎉 CONGRATULATIONS! TFLite conversion complete! You can now download paligemma.tflite! 🎉🎉🎉")