[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-ai-edge/model-explorer/blob/main/example_colabs/quick_start.ipynb)

# Google AI Edge Model Explorer
A visualization tool that lets you analyze ML models and graphs, accelerating deployment to on-device targets. [Learn more](https://ai.google.dev/edge/model-explorer).

**Key Features**

* Visualize large models effortlessly
* Find model conversion issues
* Identify optimization targets
* Easy-to-use, intuitive UI

Follow the [installation instructions](https://github.com/google-ai-edge/model-explorer/wiki/5.-Run-in-Colab-Notebook) to add it to your own Colab notebook.

Want to run Model Explorer locally? [Get started here](https://github.com/google-ai-edge/model-explorer/wiki/1.-Installation).

# Download a copy of the EfficientDet TFLite model

In [None]:
import os
import ssl
import tempfile
import urllib.request
import torch
import torchvision

# Use the default SSL context so the server certificate and hostname are
# verified. Disabling verification would let a network attacker substitute
# the downloaded model file.
ssl_context = ssl.create_default_context()
tmp_path = tempfile.mkdtemp()
model_path = os.path.join(tmp_path, "model.tflite")

# Download the model into a fresh temporary directory.
try:
  with urllib.request.urlopen(
      "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/efficientdet.tflite",
      context=ssl_context,
      # Fail instead of hanging the cell forever on a stalled connection.
      timeout=60,
  ) as response:
    data = response.read()

  with open(model_path, "wb") as file:
    file.write(data)

  print("Model downloaded successfully!")

except OSError as e:
  # urllib.error.URLError/HTTPError, ssl.SSLError, timeouts and file-write
  # failures are all OSError subclasses. Report the failure and keep going so
  # the remaining (PyTorch/JAX) cells can still be run.
  print(f"Failed to download model: {e}")

# Install Model Explorer using pip

In [None]:
# Install the Model Explorer package. The leading `!` runs pip as a shell
# command from the notebook (IPython syntax).
!pip install ai-edge-model-explorer

# Faster alternative: `--no-deps` skips dependency installation, which assumes
# the dependencies are already present in the Colab runtime:
# !pip install --no-deps ai-edge-model-explorer-adapter ai-edge-model-explorer

# Visualize the downloaded EfficientDet model

In [None]:
import model_explorer

# Visualize the TFLite model saved by the download cell above.
# NOTE: the download cell only prints a message on failure, so if it failed,
# `model_path` points to a file that does not exist.
model_explorer.visualize(model_path)

# Visualize a PyTorch model



In [None]:
# Build a torchvision MobileNetV2 as an example model (no `weights` argument
# is passed) and switch it to eval mode before exporting.
model = torchvision.models.mobilenet_v2()
model.eval()

# torch.export needs example positional inputs, given as a tuple; here that is
# a single random tensor of shape (1, 3, 224, 224).
inputs = (torch.rand(1, 3, 224, 224),)
ep = torch.export.export(model, inputs)

# Hand the exported program to Model Explorer under the label "mobilenet".
model_explorer.visualize_pytorch("mobilenet", exported_program=ep)

# Visualize a JAX model (via StableHLO MLIR)

In [None]:
import jax
import jax.numpy as jnp

def my_function(x):
  """Return the element-wise product of jnp.sin(x) and jnp.cos(x)."""
  return jnp.sin(x) * jnp.cos(x)

import os
import tempfile

inputs = jnp.array(2.0)

# JIT, lower the function, and get the textual representation
#   - jax.jit(fn): Creates a JIT-compiled version of our function.
#   - .lower(inputs): Traces the function with the dummy input to generate
#     the low-level representation without actually executing it.
#   - .as_text(debug_info=True): Converts the lowered representation (StableHLO)
#     to a human-readable string. `debug_info=True` includes source-level
#     location information, which is useful for visualization.
stablehlo_mlir = jax.jit(my_function).lower(inputs).as_text(debug_info=True)

# Write the MLIR to a file in a fresh temporary directory. A hard-coded
# '/content/...' path only exists on Colab; a temp dir works locally too.
file_path = os.path.join(tempfile.mkdtemp(), 'stablehlo.mlir')
with open(file_path, 'w', encoding='utf-8') as f:
  f.write(stablehlo_mlir)

# Visualize
model_explorer.visualize(file_path)