# Flax Runtime-Only Setup and Verification

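This notebook trims the repo to only what's required to run Flax as a library and verifies the installation by importing modules and running small examples. Every deletion step is a dry run by default; nothing is removed until you flip the corresponding `DO_*`/`APPLY_*` flag.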

Goals:
- Keep all code under `flax/` and `flaxlib_src/` untouched (no modules removed).
- Remove non-runtime assets (docs, tests, examples, CI, images, notebooks inside code trees, caches).
- Install locally in editable mode and verify imports.
- Run minimal Linen and NNX examples, plus serialization and TrainState checks.
- Optionally load CIFAR-10 / ImageNet-1k via TFDS and fine-tune a TF Hub ViT on CIFAR-10.

Note: We do not remove any Python modules under `flax/`. The only internal trimming we do is safe housekeeping: delete `__pycache__` and `.ipynb` files within the package, which do not affect runtime.

## What We Keep vs Remove

- Keep:
  - `flax/` (all Python runtime code; no modules trimmed)
  - `flaxlib_src/` (Rust/C bindings used by Flax)
  - `pyproject.toml`, `LICENSE`, `README.md` (package + legal info)
- Remove (non-runtime):
  - `docs/`, `docs_nnx/`, `examples/`, `images/`, `benchmarks/`, `tests/`, `.github/`
  - Any `.ipynb` inside `flax/` (developer notes) and `__pycache__/`

> Rationale: These removed items are for documentation, testing, examples, CI, or development. They are not needed to import or run Flax in your environment. Keeping the entire `flax/` tree ensures no functionality is lost. The only internal cleanup is deleting caches and notebooks which do not affect runtime.

In [1]:
# Safety Check: Locate Non-Essential Folders
import os
from pathlib import Path

ROOT = Path("/Applications/CODES/DL - Tensorflow/flax").resolve()
print("Project root:", ROOT)

# Top-level project folders holding docs, tests, examples, CI config or static
# assets; none of them is imported when `flax` is used as a library.
NON_RUNTIME_DIRS = ("docs", "docs_nnx", "examples", "images", "benchmarks", "tests", ".github")
to_remove = [ROOT / dirname for dirname in NON_RUNTIME_DIRS]

existing = list(filter(Path.exists, to_remove))
print("Will consider removing:")
if existing:
    for candidate in existing:
        print(" -", candidate)
else:
    print("(No non-essential folders found; repo may already be trimmed.)")

Project root: /Applications/CODES/DL - Tensorflow/flax
Will consider removing:
(No non-essential folders found; repo may already be trimmed.)


In [2]:
# Optional Cleanup: Remove Non-Essential Folders Safely
# Dry run unless DO_REMOVE is flipped: each existing folder from `to_remove`
# (built in the safety-check cell) is either reported or deleted recursively.
import shutil

DO_REMOVE = False  # set to True to actually delete
for folder in filter(lambda candidate: candidate.exists(), to_remove):
    if not DO_REMOVE:
        print("Would remove:", folder)
        continue
    print("Removing:", folder)
    shutil.rmtree(folder)
print("Done.")

Done.


In [3]:
# Housekeeping inside `flax/`: remove caches + notebooks (no modules removed)
# Lists every developer notebook and __pycache__ directory under the package;
# they are deleted only when DO_INTERNAL_CLEAN is True.
import shutil
from pathlib import Path

PKG = ROOT / "flax"
ipynbs = [*PKG.rglob("*.ipynb")]
caches = [*PKG.rglob("__pycache__")]
print(f"Found {len(ipynbs)} notebooks and {len(caches)} cache dirs under flax/.")
for label, paths in (("Would remove notebook:", ipynbs), ("Would remove cache dir:", caches)):
    for path in paths:
        print(label, path)

DO_INTERNAL_CLEAN = False  # set True to apply housekeeping
if not DO_INTERNAL_CLEAN:
    print("Housekeeping dry-run only; set DO_INTERNAL_CLEAN=True to apply.")
else:
    for notebook in ipynbs:
        notebook.unlink()
    for cache_dir in caches:
        shutil.rmtree(cache_dir)
    print("Housekeeping done.")

Found 1 notebooks and 2 cache dirs under flax/.
Would remove notebook: /Applications/CODES/DL - Tensorflow/flax/flax/core/flax_functional_engine.ipynb
Would remove cache dir: /Applications/CODES/DL - Tensorflow/flax/flax/__pycache__
Would remove cache dir: /Applications/CODES/DL - Tensorflow/flax/flax/core/__pycache__
Housekeeping dry-run only; set DO_INTERNAL_CLEAN=True to apply.


In [None]:
# Environment Setup: install this package in editable mode (with fallback)
import sys, os, subprocess, shlex, importlib
from pathlib import Path

ROOT = Path("/Applications/CODES/DL - Tensorflow/flax").resolve()
print("Using project root:", ROOT)

PACKAGING_FILES = [ROOT / "pyproject.toml", ROOT / "setup.py", ROOT / "setup.cfg"]

# Runtime deps used later in this notebook; TF/TFDS make the dataset and
# TF Hub sections work out-of-the-box.
RUNTIME_DEPS = ["jax", "jaxlib", "numpy", "optax", "tensorflow", "tensorflow-datasets"]


def run(cmd):
    """Run a command, echo its combined stdout/stderr, and return its exit code.

    Args:
        cmd: an argument list (preferred) or a shell-like string. A list is
            passed to ``subprocess.run`` as-is, so paths containing spaces
            (interpreter, project root) are never split apart.

    Returns:
        The process return code (0 on success). Failures are not raised.
    """
    args = shlex.split(cmd) if isinstance(cmd, str) else [str(a) for a in cmd]
    print("$", shlex.join(args))
    proc = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    print(proc.stdout)
    return proc.returncode


def _pip(*pip_args):
    """Run ``python -m pip <pip_args>`` with this kernel's own interpreter."""
    return run([sys.executable, "-m", "pip", *pip_args])


def _import_flax_package():
    """Import ``flax`` and check that it is a real package, not a namespace shim.

    When the kernel's working directory contains the repo folder ``flax/``
    (presumably without an ``__init__.py``), ``import flax`` can resolve to an
    empty namespace package, which has no ``__file__``. That case is undone and
    reported as ImportError so the caller can fall back.
    """
    importlib.invalidate_caches()
    mod = importlib.import_module("flax")
    if getattr(mod, "__file__", None) is None:
        sys.modules.pop("flax", None)
        raise ImportError("`flax` resolved to a namespace package, not the Flax library")
    return mod


print("Python executable:", sys.executable)

# Install common runtime deps used later in this notebook. pip reports failure
# through its exit code (not an exception), so check the code explicitly.
for pip_args in (("install", "--upgrade", "pip"), ("install", "--upgrade", *RUNTIME_DEPS)):
    try:
        rc = _pip(*pip_args)
    except Exception as e:
        print("Dependency install warning:", e)
    else:
        if rc != 0:
            print(f"Dependency install warning: `pip {' '.join(pip_args)}` exited with code {rc}")

installed = False
if any(p.exists() for p in PACKAGING_FILES):
    rc = _pip("install", "-e", str(ROOT))
    if rc == 0:
        print("Editable install completed.")
        # Editable installs are wired up through a .pth hook that Python reads
        # only at start-up, so this running kernel may not see the package yet.
        try:
            _import_flax_package()
            installed = True
        except ImportError:
            print("Editable install is not visible to this kernel yet (restart to use it); "
                  "falling back to source path import.")
    else:
        print("Editable install failed; falling back to source path import.")
else:
    print("No packaging metadata found in project root (no pyproject/setup). Using source path import.")

# Fallback: import directly from source tree by putting project root on sys.path
if not installed:
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))
    try:
        _flax_check = _import_flax_package()
        print("Imported flax from source path:", Path(_flax_check.__file__).resolve())
    except Exception as e:
        print("Source-path import of flax failed:", type(e).__name__, str(e)[:300])
        raise

# Final sanity print
try:
    import flax
    ver = getattr(flax, "__version__", "(no __version__)")
    print("flax import OK; version:", ver)
except Exception as e:
    print("flax import failed even after fallback:", type(e).__name__, str(e)[:300])
    raise

Using project root: /Applications/CODES/DL - Tensorflow/flax
/opt/anaconda3/bin/python
$ /opt/anaconda3/bin/python -m pip install --upgrade pip

$ /opt/anaconda3/bin/python -m pip install --upgrade jax jaxlib numpy optax

$ /opt/anaconda3/bin/python -m pip install --upgrade jax jaxlib numpy optax
Collecting numpy
  Using cached numpy-2.3.4-cp312-cp312-macosx_14_0_arm64.whl.metadata (62 kB)
Collecting optax
  Using cached optax-0.2.6-py3-none-any.whl.metadata (7.6 kB)
Collecting chex>=0.1.87 (from optax)
  Using cached chex-0.1.91-py3-none-any.whl.metadata (18 kB)
Collecting absl-py>=0.7.1 (from optax)
  Using cached absl_py-2.3.1-py3-none-any.whl.metadata (3.3 kB)
Collecting toolz>=1.0.0 (from chex>=0.1.87->optax)
  Using cached toolz-1.1.0-py3-none-any.whl.metadata (5.1 kB)
Using cached optax-0.2.6-py3-none-any.whl (367 kB)
Using cached chex-0.1.91-py3-none-any.whl (100 kB)
Using cached absl_py-2.3.1-py3-none-any.whl (135 kB)
Using cached toolz-1.1.0-py3-none-any.whl (58 kB)
Installin

RuntimeError: Command failed: /opt/anaconda3/bin/python -m pip install -e '/Applications/CODES/DL - Tensorflow/flax'

In [None]:
# Import Verification: core modules
# Imports each listed module, records "OK" or the failure, then (only if the
# top-level package imported) walks flax's direct subpackages best-effort.
import importlib, pkgutil, traceback
mods = [
    "flax",
    "flax.core",
    "flax.linen",
    "flax.training",
    "flax.serialization",
    "flax.traverse_util",
    "flax.struct",
    "flax.jax_utils",
    # Experimental / optional:
    "flax.nnx",
    # Avoid importing tensorboard utilities by default (optional deps).
]
results = {}
for m in mods:
    try:
        importlib.import_module(m)
        results[m] = "OK"
    except Exception as e:
        results[m] = f"FAIL: {e.__class__.__name__}: {e}"
for k, v in results.items():
    print(f"{k:25s} -> {v}")

# Optionally attempt to discover and import subpackages under flax/ (best-effort)
print("\nBest-effort import of subpackages (skips known optional/deps-heavy areas)...")
if results.get("flax") != "OK":
    # Without a working top-level package there is nothing to walk; report it
    # instead of crashing, so the summary above stays the cell's result.
    print("skipped: `flax` itself failed to import ->", results.get("flax"))
else:
    import flax, types
    skips = {"metrics", "testing"}
    for f in pkgutil.iter_modules(flax.__path__, prefix="flax."):
        name = f.name.split(".")[-1]
        if name in skips:
            print(f"skip {f.name}")
            continue
        try:
            importlib.import_module(f.name)
            print("ok  ", f.name)
        except Exception as e:
            print("fail", f.name, "->", e.__class__.__name__, str(e)[:140])

In [None]:
# Linen Example: simple MLP forward pass
import jax
import jax.numpy as jnp
from flax import linen as nn


class MLP(nn.Module):
    """Multi-layer perceptron.

    Every width in ``features`` except the last gets a Dense layer followed by
    ReLU; the last width is a plain Dense output layer (no activation).
    """

    features: tuple  # layer widths, e.g. (16, 4) -> one hidden layer of 16, output of 4

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.Dense(feat)(x)
            x = nn.relu(x)
        x = nn.Dense(self.features[-1])(x)
        return x


# Split the root key so the input data and the parameter init use independent
# randomness; reusing one key for both would correlate them.
key = jax.random.key(0)
data_key, init_key = jax.random.split(key)
x = jax.random.normal(data_key, (4, 8))
model = MLP(features=(16, 4))
params = model.init(init_key, x)
y = model.apply(params, x)
print("Output shape:", y.shape)

In [None]:
# Serialization roundtrip: in-memory state dict and msgpack bytes
from flax import serialization
import jax.tree_util as jtu


def _trees_match(a, b):
    """True when ``a`` and ``b`` have the same tree structure and all leaves are allclose.

    Comparing the treedefs first matters: ``zip`` over the leaf lists alone would
    silently ignore missing or extra leaves.
    """
    leaves_a, treedef_a = jtu.tree_flatten(a)
    leaves_b, treedef_b = jtu.tree_flatten(b)
    return treedef_a == treedef_b and all(jnp.allclose(u, v) for u, v in zip(leaves_a, leaves_b))


state_dict = serialization.to_state_dict(params)
restored = serialization.from_state_dict(params, state_dict)
# Also exercise the on-disk format: msgpack bytes restored into the same target.
restored_bytes = serialization.from_bytes(params, serialization.to_bytes(params))

leaves_a, _ = jtu.tree_flatten(params)
print("num leaves:", len(leaves_a))
print("allclose:", _trees_match(params, restored))
print("bytes roundtrip allclose:", _trees_match(params, restored_bytes))

In [None]:
# TrainState example with Optax
import optax
from flax.training import train_state


def loss_fn(params, x, apply_fn=None):
    """Mean of the squared model outputs on ``x`` (a toy objective for a few updates).

    ``apply_fn`` defaults to the Linen ``model.apply`` for backward compatibility.
    ``train_step`` passes ``state.apply_fn`` explicitly, so the jitted step keeps
    working even after a later cell rebinds the global name ``model``.
    """
    if apply_fn is None:
        apply_fn = model.apply
    y = apply_fn(params, x)
    return jnp.mean(jnp.square(y))


tx = optax.adam(1e-2)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


@jax.jit
def train_step(state, x):
    """One optimizer update; returns the new state and the loss before the update."""
    loss, grads = jax.value_and_grad(loss_fn)(state.params, x, state.apply_fn)
    state = state.apply_gradients(grads=grads)
    return state, loss


for i in range(3):
    state, loss = train_step(state, x)
    print(f"step {i}: loss={float(loss):.6f}")

In [None]:
# NNX Example (best-effort; optional depending on version)
# Any failure (missing flax.nnx, API differences) is reported, not raised.
try:
    from flax import nnx

    class NNXMLP(nnx.Module):
        """Two-layer NNX MLP: Linear -> ReLU -> Linear."""

        def __init__(self, in_features, hidden_features, out_features, *, rngs: nnx.Rngs):
            self.d1 = nnx.Linear(in_features, hidden_features, rngs=rngs)
            self.d2 = nnx.Linear(hidden_features, out_features, rngs=rngs)

        def __call__(self, x):
            hidden = nnx.relu(self.d1(x))
            return self.d2(hidden)

    rngs = nnx.Rngs(0)
    model_nnx = NNXMLP(in_features=8, hidden_features=16, out_features=4, rngs=rngs)
    y2 = model_nnx(x)
    print("NNX forward ok, shape:", y2.shape)
except Exception as exc:
    reason = str(exc)[:200]
    print("NNX example skipped:", type(exc).__name__, reason)

## Trimming Justification
- `docs/`, `docs_nnx/`: documentation sources; not needed to import or run Flax.
- `examples/`, `benchmarks/`: usage samples and perf scripts; not required at runtime.
- `tests/`: test suite only; not required by users of the library.
- `images/`: static assets for docs; safe to remove for runtime.
- `.github/`: CI and templates; no effect on local import or execution.
- Notebooks inside `flax/`: developer materials; safe to remove for runtime.
- `__pycache__/`: compiled bytecode caches; recreated on demand and safe to delete.

We deliberately keep the entire `flax/` package tree and `flaxlib_src/` to ensure no functional module is omitted. If any optional-import failure occurs (e.g., tensorboard/Orbax), install those extras only when you use those features.

## Full Correlation & Justification Report
This section scans the project and labels every file and folder as either `keep` or `remove`, with a one-line justification for each decision. Nothing is left unclassified — the report covers every path under the `flax/` project folder and, if present, a workspace-level `flaxlib_src/`, so you can review before applying any deletions.

Policy recap:
- Keep: all runtime code under `flax/` (the Python package) and `flaxlib_src/` (native bindings), plus packaging files within the `flax/` project folder (e.g., `pyproject.toml`, `setup.cfg`, `setup.py`) if present.
- Remove: non-runtime collateral under the `flax/` project folder such as `docs/`, `docs_nnx/`, `examples/`, `images/`, `benchmarks/`, `tests/`, `.github/`, developer notebooks inside `flax/flax/`, and any `__pycache__/` directories.

Outputs:
- `trim_report.json`: exhaustive per-path decisions with reasons.
- `trim_report.csv`: CSV version for quick scanning.
- Optional: an apply step (next cell) that enforces removals with a safety toggle and writes `trim_log.txt`.

In [None]:
# Correlate all paths and generate report
from pathlib import Path
import json, csv
from typing import List, Dict

WS = Path("/Applications/CODES/DL - Tensorflow").resolve()
FLAX_PROJECT = WS / "flax"
FLAX_PKG = FLAX_PROJECT / "flax"  # python package root
FLAXLIB_SRC = WS / "flaxlib_src"

print("Workspace:", WS)
print("Flax project root:", FLAX_PROJECT)
print("Flax package root:", FLAX_PKG)
print("flaxlib_src root:", FLAXLIB_SRC)

targets: List[Path] = [p for p in [FLAX_PROJECT, FLAXLIB_SRC] if p.exists()]
if not targets:
    raise SystemExit("No targets found to scan (expected 'flax/' or 'flaxlib_src/').")

# Non-runtime folders, matched ONLY as direct children of FLAX_PROJECT (the same
# top-level folders the first cell checks). Matching them at any depth would also
# hit same-named folders inside the `flax/` package or `flaxlib_src/`, which this
# notebook promises never to trim.
REMOVE_DIR_NAMES = {"docs", "docs_nnx", "examples", "images", "benchmarks", "tests", ".github"}
PACKAGING_NAMES = {"pyproject.toml", "setup.cfg", "setup.py", "README.md", "LICENSE"}


def under(path: Path, base: Path) -> bool:
    """Return True if ``path`` equals ``base`` or lies beneath it (lexical check).

    Symlinks are deliberately not resolved: an entry is classified by where it
    sits in the tree, and a link pointing outside the workspace must not crash
    the scan.
    """
    try:
        path.relative_to(base)
        return True
    except ValueError:
        return False


def decide(p: Path):
    """Classify one scanned path; returns ``(action, reason)`` with action "keep"/"remove"."""
    rel = p.relative_to(WS)

    # Always keep the control notebook
    if rel.as_posix() == "JAX2TF-ViT-FLAX.ipynb":
        return "keep", "control notebook (keep)"

    # Cache directories, and everything inside them
    if "__pycache__" in rel.parts:
        return "remove", "bytecode cache (safe to regenerate)"

    # Non-runtime collateral inside the Flax project
    if under(p, FLAX_PROJECT):
        proj_parts = p.relative_to(FLAX_PROJECT).parts
        top = proj_parts[0] if proj_parts else ""
        if top in REMOVE_DIR_NAMES:
            return "remove", "non-runtime collateral (docs/tests/examples/ci/assets)"
        # In-project native sources are kept whole, whatever folders they contain.
        if top == "flaxlib_src":
            return "keep", "native bindings/runtime sources"
        in_pkg = under(p, FLAX_PKG)
        # Developer notebooks inside the python package
        if in_pkg and p.suffix == ".ipynb":
            return "remove", "developer notebook inside package (not required at runtime)"
        # Packaging files in project root are kept
        if p.parent == FLAX_PROJECT and p.name in PACKAGING_NAMES:
            return "keep", "project packaging/metadata (required to build/install)"
        if in_pkg:
            return "keep", "Flax runtime package"
        # Otherwise, default keep within project unless matched above
        return "keep", "project runtime/ancillary needed for install"

    # Keep workspace-level native sources entirely
    if under(p, FLAXLIB_SRC):
        return "keep", "native bindings/runtime sources"

    return "keep", "out of scope"


# Collect all paths (including the roots)
all_paths: List[Path] = []
for base in targets:
    all_paths.append(base)
    all_paths.extend(sorted(base.rglob("*")))

decisions: List[Dict] = []
for p in all_paths:
    act, why = decide(p)
    decisions.append({
        "path": str(p.relative_to(WS)),
        "type": "dir" if p.is_dir() else "file",
        "action": act,
        "reason": why,
    })

keep_n = sum(1 for d in decisions if d["action"] == "keep")
rem_n = sum(1 for d in decisions if d["action"] == "remove")
print(f"Decisions: keep={keep_n}, remove={rem_n}, total={len(decisions)}")

# Write reports
report_json = WS / "trim_report.json"
report_csv = WS / "trim_report.csv"
with report_json.open("w") as f:
    json.dump(decisions, f, indent=2)
with report_csv.open("w", newline="") as f:
    w = csv.DictWriter(f, fieldnames=["path", "type", "action", "reason"])
    w.writeheader()
    w.writerows(decisions)
print("Wrote:", report_json)
print("Wrote:", report_csv)

In [None]:
# Apply removals from report (guarded)
import json, shutil, os
from pathlib import Path

WS = Path("/Applications/CODES/DL - Tensorflow").resolve()
FLAX_PROJECT = WS / "flax"
FLAX_PKG = FLAX_PROJECT / "flax"
FLAXLIB_SRC = WS / "flaxlib_src"

report_json = WS / "trim_report.json"
if not report_json.exists():
    raise SystemExit("trim_report.json not found. Run the previous cell first.")

with report_json.open() as f:
    decisions = json.load(f)


def safe_under(p: Path) -> bool:
    """True only for paths strictly inside one of the intended project trees.

    Uses path-component containment, not string prefixes, so a sibling such as
    ``flax_backup/`` is not mistaken for ``flax/``. The tree roots themselves
    are never eligible for deletion.
    """
    for base in (FLAX_PROJECT, FLAXLIB_SRC):
        try:
            rel = p.relative_to(base)
        except ValueError:
            continue
        if rel.parts and ".." not in rel.parts:
            return True
    return False


def _remove(p: Path) -> bool:
    """Delete ``p``; return False if it does not exist. Symlinks are unlinked, never followed."""
    if p.is_symlink() or p.is_file():
        p.unlink()
    elif p.is_dir():
        shutil.rmtree(p)
    else:
        return False
    return True


APPLY_REMOVE = False  # set True to actually delete
removed, failed = 0, 0
log_lines = []

to_delete = [d for d in decisions if d["action"] == "remove"]
# Delete deepest paths first so files go before the directories containing them.
for d in sorted(to_delete, key=lambda x: len(Path(x["path"]).parts), reverse=True):
    # Normalize lexically (collapses any ".."), but do not resolve symlinks.
    p = Path(os.path.normpath(WS / d["path"]))
    if not safe_under(p):
        log_lines.append(f"SKIP (unsafe): {p}")
        continue
    if not APPLY_REMOVE:
        log_lines.append(f"DRY-RUN would remove: {p}")
        continue
    try:
        if _remove(p):
            log_lines.append(f"REMOVED: {p}")
            removed += 1
        else:
            log_lines.append(f"MISSING (already gone): {p}")
    except Exception as e:
        log_lines.append(f"FAILED: {p} -> {e}")
        failed += 1

log_path = WS / "trim_log.txt"
with log_path.open("w") as f:
    f.write("\n".join(log_lines))

mode = "APPLIED" if APPLY_REMOVE else "DRY-RUN"
print(f"{mode}: decisions processed. To actually delete, set APPLY_REMOVE=True and re-run.")
print(f"Entries marked remove: {len(to_delete)} | removed={removed} | failed={failed}")
print("Wrote:", log_path)

## Datasets: CIFAR-10 and ImageNet-1k
This notebook can run quick demos and fine-tunes on two canonical datasets:

- CIFAR-10: 5/5 — Krizhevsky, 2009 (https://www.cs.toronto.edu/~kriz/cifar.html).
  Usability: TFDS `cifar10` auto-downloads; tiny and ideal for sanity checks.
- ImageNet-1k (ILSVRC 2012): 5/5 — Russakovsky et al., 2015 (https://image-net.org/challenges/LSVRC/2012/).
  Usability: TFDS `imagenet2012` requires manual download/acceptance; large and best with GPU/TPU.

ImageNet-1k manual setup (TFDS):
- Create a data directory and set `TFDS_DATA_DIR` or pass `data_dir` to loaders.
- Follow TFDS instructions to place tar files. See: https://www.tensorflow.org/datasets/catalog/imagenet2012#manual_download_instructions

Example (zsh):
```zsh
# Choose a data directory
export TFDS_DATA_DIR="$HOME/tfds_data"
mkdir -p "$TFDS_DATA_DIR"
# Then run the ImageNet loader cell in this notebook; it will use TFDS.
```

In [None]:
# TFDS dataset utilities for CIFAR-10 and ImageNet-1k
import os
try:
    import tensorflow as tf
    import tensorflow_datasets as tfds
except Exception:
    # Show how to install the missing packages, then re-raise the original error.
    for hint in ("TensorFlow/TFDS not found. Install with:",
                 "  python -m pip install -U tensorflow tensorflow-datasets"):
        print(hint)
    raise

# Parallelism for tf.data map/prefetch; falls back to 16 if this TF build has
# no `tf.data.AUTOTUNE` attribute.
AUTOTUNE = getattr(tf.data, "AUTOTUNE", 16)

def _resize_to(image, size=224, antialias=True):
    """Return ``image`` as float32 resized to ``size`` x ``size``, clipped to [0, 1].

    ``tf.image.convert_image_dtype`` rescales integer images (e.g. uint8) into
    [0, 1]; the final clip guards against resampled values drifting slightly
    outside that range.
    """
    as_float = tf.image.convert_image_dtype(image, tf.float32)
    resized = tf.image.resize(as_float, (size, size), method="bilinear", antialias=antialias)
    return tf.clip_by_value(resized, 0.0, 1.0)

def _augment_train(image):
    """Training-time augmentation: a random horizontal flip only (extend as needed)."""
    return tf.image.random_flip_left_right(image)

def _prep_example(example, size=224, train=False):
    """Map a TFDS feature dict to an ``(image, label)`` pair.

    The image is resized via ``_resize_to`` and, when ``train`` is true, passed
    through ``_augment_train``; the label is cast to int32.
    """
    label = tf.cast(example["label"], tf.int32)
    image = _resize_to(example["image"], size)
    return (_augment_train(image) if train else image), label

def _make_pipeline(ds, batch_size, size=224, train=False, shuffle=True,
                   shuffle_buffer=10_000, seed=None):
    """Turn a raw TFDS split into a batched, prefetched ``(image, label)`` pipeline.

    Args:
        ds: a split from ``tfds.load(..., as_supervised=False)`` (feature dicts).
        batch_size: examples per batch; the final batch may be smaller.
        size: square output resolution forwarded to ``_prep_example``.
        train: enables augmentation and, together with ``shuffle``, shuffling.
        shuffle: shuffling is applied only when ``train`` is also true.
        shuffle_buffer: shuffle buffer size (default 10_000, as before).
        seed: optional seed for ``Dataset.shuffle`` to make the example order
            reproducible; ``None`` keeps the previous unseeded behaviour.
    """
    if train and shuffle:
        ds = ds.shuffle(shuffle_buffer, seed=seed)
    ds = ds.map(lambda ex: _prep_example(ex, size=size, train=train), num_parallel_calls=AUTOTUNE)
    return ds.batch(batch_size).prefetch(AUTOTUNE)

def get_cifar10_datasets(batch_size=64, size=224, data_dir=None):
    """Return ``(train, validation)`` CIFAR-10 pipelines built by ``_make_pipeline``.

    The TFDS ``test`` split serves as the validation set. Only the training
    pipeline is shuffled and augmented.
    """
    train_split, test_split = tfds.load(
        "cifar10",
        split=["train", "test"],
        data_dir=data_dir,
        as_supervised=False,
        with_info=False,
    )
    train_ds = _make_pipeline(train_split, batch_size, size=size, train=True, shuffle=True)
    val_ds = _make_pipeline(test_split, batch_size, size=size, train=False, shuffle=False)
    return train_ds, val_ds

def get_imagenet_datasets(batch_size=128, size=224, data_dir=None):
    """Return ``(train, validation)`` ImageNet-1k pipelines (TFDS ``imagenet2012``).

    ImageNet cannot be fetched automatically: its archives must first be placed
    in TFDS's manual-download directory. If the dataset has not been prepared
    yet, ``tfds.load`` (``download=True`` by default) prepares it from those
    archives, which can take a long time.

    Raises:
        RuntimeError: if TFDS has no builder for the dataset, or loading fails.
    """
    name = "imagenet2012"
    try:
        # Fail fast with a clear message if this TFDS install lacks the dataset.
        tfds.builder(name, data_dir=data_dir)
    except Exception as e:
        raise RuntimeError("TFDS builder for ImageNet-1k not available. Set TFDS_DATA_DIR or data_dir and ensure manual files are placed.") from e
    try:
        ds_train, ds_val = tfds.load(
            name, split=["train", "validation"], data_dir=data_dir, as_supervised=False
        )
    except Exception as e:
        msg = [
            "Could not load ImageNet-1k via TFDS.",
            "Ensure manual download is completed per https://www.tensorflow.org/datasets/catalog/imagenet2012#manual_download_instructions",
            "Set TFDS_DATA_DIR or pass data_dir to this function.",
        ]
        raise RuntimeError("\n".join(msg)) from e
    return (
        _make_pipeline(ds_train, batch_size, size=size, train=True, shuffle=True),
        _make_pipeline(ds_val, batch_size, size=size, train=False, shuffle=False),
    )

In [None]:
# Quick run: CIFAR-10 demo (sanity check)
# Pull one batch from each split and print its shapes/dtypes; any failure is
# reported instead of stopping the notebook.
try:
    ds_tr, ds_val = get_cifar10_datasets(batch_size=64, size=224)
    for images, labels in ds_tr.take(1):
        print("CIFAR-10 train batch:", images.shape, labels.shape, images.dtype, labels.dtype)
    for images, labels in ds_val.take(1):
        print("CIFAR-10 val batch:", images.shape, labels.shape)
except Exception as err:
    err_kind, err_text = type(err).__name__, str(err)[:300]
    print("CIFAR-10 loader error:", err_kind, err_text)

In [None]:
# Quick run: ImageNet-1k loader check
import os

# TFDS data location for ImageNet; replace with a path string to override.
IMAGENET_DATA_DIR = os.environ.get("TFDS_DATA_DIR")
print("TFDS_DATA_DIR:", IMAGENET_DATA_DIR)

try:
    ds_tr_imnet, ds_val_imnet = get_imagenet_datasets(batch_size=128, size=224, data_dir=IMAGENET_DATA_DIR)
    for images, labels in ds_val_imnet.take(1):
        print("ImageNet val batch:", images.shape, labels.shape, images.dtype, labels.dtype)
except Exception as e:
    # Missing data is expected on most machines: explain rather than fail.
    print("\n".join([
        "ImageNet loader notice:",
        str(e),
        "If you have the data, set TFDS_DATA_DIR to its location and re-run.",
    ]))

## CIFAR-10 fine-tuning with a TF Hub ViT (explained simply)
We will:
1) Install the TF Hub package (lets us load the ViT).
2) Pick a ViT feature-extractor from TF Hub (no classification head).
3) Build a small Keras model: ViT (features) -> Dense(10) for CIFAR-10 classes.
4) Train for a few epochs and see accuracy.
5) Evaluate on the validation set.

Why this works:
- The TF Hub ViT gives a general image representation.
- The small Dense layer learns to map that representation to 10 CIFAR classes.
- This is fast and usually reaches good accuracy with very little code.

Tip: You can freeze the ViT (faster) or unfreeze it (more accuracy). We start frozen, then optionally unfreeze.

In [None]:
# 1) Install/import TensorFlow Hub (only runs once)
import importlib, sys, subprocess, shlex


def _run(cmd):
    """Run ``cmd`` (argument list or shell-like string), echo its output, raise on failure.

    Passing a list avoids ``shlex.split`` breaking an interpreter path that
    contains spaces.

    Raises:
        RuntimeError: if the command exits with a non-zero status.
    """
    args = shlex.split(cmd) if isinstance(cmd, str) else [str(a) for a in cmd]
    shown = shlex.join(args)
    print("$", shown)
    p = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    print(p.stdout)
    if p.returncode != 0:
        raise RuntimeError(f"Command failed: {shown}")


try:
    import tensorflow_hub as hub
except Exception:
    _run([sys.executable, "-m", "pip", "install", "-U", "tensorflow-hub"])
    # Let the running kernel's import system notice the newly installed package.
    importlib.invalidate_caches()
    import tensorflow_hub as hub
print("tensorflow-hub version:", hub.__version__)

In [None]:
# 2) Choose a TF Hub ViT feature extractor (no classification head)
import tensorflow as tf
import tensorflow_hub as hub

# You can change this to another handle from the same collection.
# If loading fails, pick a different handle or version.
HUB_HANDLES = [
    "https://tfhub.dev/sayakpaul/vit_b16_fe/1",
    "https://tfhub.dev/sayakpaul/vit_s16_fe/1",
    # Add more known handles here if needed
]


def _first_loadable_handle(handles):
    """Return the first handle that hub.KerasLayer can load, else None (failures are printed)."""
    for handle in handles:
        try:
            hub.KerasLayer(handle, trainable=False)
        except Exception as err:
            print("Handle failed:", handle, "->", type(err).__name__, str(err)[:120])
        else:
            return handle
    return None


selected_handle = _first_loadable_handle(HUB_HANDLES)
if selected_handle is None:
    raise SystemExit("No TF Hub ViT handle could be loaded. Please set HUB_HANDLES[0] to a valid handle from the README.")

print("Using TF Hub handle:", selected_handle)

In [None]:
# 3) Build the model: ViT feature extractor -> Dense(10)
import tensorflow as tf
import tensorflow_hub as hub

# NOTE: this rebinds `model`, which the Flax cells above used for the Linen MLP;
# re-run those cells from the top if you go back to the Flax examples.
# NOTE(review): TF >= 2.16 ships Keras 3, which hub.KerasLayer does not support in
# a functional model; there, install `tf-keras` and set TF_USE_LEGACY_KERAS=1
# before TensorFlow is first imported — confirm for your TF version.
inputs = tf.keras.Input(shape=(224, 224, 3), name="image")
# The data pipeline yields pixels in [0, 1]; the ViT checkpoints ported from the
# original JAX ViT expect inputs in [-1, 1], so map x -> 2x - 1 first.
scaled = tf.keras.layers.Rescaling(2.0, offset=-1.0, name="to_minus1_1")(inputs)
vit = hub.KerasLayer(selected_handle, trainable=False, name="vit_fe")
features = vit(scaled)  # typically a 1D embedding
outputs = tf.keras.layers.Dense(10, activation="softmax", name="classifier")(features)
model = tf.keras.Model(inputs, outputs)

# Outputs are softmax probabilities, matching the loss's default from_logits=False.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)
model.summary()

In [None]:
# 4) Train for a few epochs on CIFAR-10 (simple and quick)
BATCH = 64
EPOCHS = 3  # increase for better accuracy
steps_tr = 50   # limit steps for a quicker demo; set None for full epoch
steps_val = 20  # limit steps for a quicker demo; set None for full epoch

ds_tr, ds_val = get_cifar10_datasets(batch_size=BATCH, size=224)

# Step caps keep the demo short; training still runs for EPOCHS epochs.
fit_kwargs = dict(
    validation_data=ds_val,
    epochs=EPOCHS,
    steps_per_epoch=steps_tr,
    validation_steps=steps_val,
)
history = model.fit(ds_tr, **fit_kwargs)

In [None]:
# 5) Evaluate and (optional) unfreeze ViT for extra accuracy
val_loss, val_acc = model.evaluate(ds_val, steps=steps_val)
print({"val_loss": float(val_loss), "val_acc": float(val_acc)})

DO_UNFREEZE = False  # set True to fine-tune the ViT weights as well
if DO_UNFREEZE:
    print("Unfreezing the ViT and training briefly...")
    vit_layer = model.get_layer("vit_fe")
    vit_layer.trainable = True
    # Keras requires recompiling after a `trainable` change; a much smaller
    # learning rate only nudges the pretrained weights.
    finetune_optimizer = tf.keras.optimizers.Adam(1e-5)  # lower LR when unfreezing
    model.compile(
        optimizer=finetune_optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=["accuracy"],
    )
    history2 = model.fit(
        ds_tr,
        epochs=1,
        steps_per_epoch=steps_tr,
        validation_data=ds_val,
        validation_steps=steps_val,
    )