# TensorFlow EffAxNet 2D Demo
This notebook demonstrates how to instantiate and train the 2D Efficient Axial Network using TensorFlow.

## Environment Setup

Install `axiom[tf]`, or make sure TensorFlow 2.13+ and Keras are already available, before running the rest of this notebook. The next cell imports TensorFlow, Keras, and the EffAxNet TensorFlow API.

In [None]:
import tensorflow as tf
from tensorflow import keras
from axiom.tf.models import EffAxNetV1

print(tf.__version__)  # confirm the installed TensorFlow version (2.13+ expected)
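The printed version can also be turned into a hard check, so the notebook fails fast on an older install. The sketch below uses a hypothetical helper, `meets_minimum`, that compares only the major and minor numbers against the 2.13 minimum stated above. It relies on the standard library alone.

```python
import re


def meets_minimum(version: str, minimum: tuple = (2, 13)) -> bool:
    """Return True if a version string such as "2.15.0" or "2.16.0rc1" is >= minimum.

    Only the leading MAJOR.MINOR numbers are compared; pre-release and build
    suffixes are ignored, which is enough for a coarse "new enough" check.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= tuple(minimum)


# In the notebook (hypothetical usage):
#   assert meets_minimum(tf.__version__), "TensorFlow 2.13+ is required"
print(meets_minimum("2.15.0"))     # True
print(meets_minimum("2.12.1"))     # False
print(meets_minimum("2.16.0rc1"))  # True
```

Comparing tuples keeps the check correct across major releases, for example `"3.0.0"` is treated as newer than 2.13 even though its minor number is smaller.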


## Create a Feature Extractor
Use `EffAxNetV1` with `include_top=False` to drop the classification head and obtain a convolutional backbone. Setting `pooling="avg"` average-pools the final feature map, so the backbone outputs a flat feature vector suitable for a custom head.

In [None]:
tf.random.set_seed(42)

backbone = EffAxNetV1(
    variant="2d",
    include_top=False,
    pooling="avg",
    input_shape=(128, 128, 3),
)

print(backbone.output_shape)
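Assuming `pooling="avg"` follows the usual Keras convention of global average pooling, the backbone collapses its final `(height, width, channels)` feature map into one value per channel, so the printed output shape should have the form `(None, channels)`. A minimal NumPy sketch of that reduction, independent of EffAxNet (the helper name is hypothetical):

```python
import numpy as np


def global_average_pool_2d(feature_map: np.ndarray) -> np.ndarray:
    """Average an NHWC feature map over its spatial axes.

    (batch, height, width, channels) -> (batch, channels), the same reduction
    that keras.layers.GlobalAveragePooling2D performs with channels_last data.
    """
    if feature_map.ndim != 4:
        raise ValueError("expected a 4-D (batch, height, width, channels) array")
    return feature_map.mean(axis=(1, 2))


# Toy feature map: 2 images, 4x4 spatial grid, 3 channels. The real channel count
# depends on the EffAxNetV1 variant and is whatever backbone.output_shape reports.
rng = np.random.default_rng(0)
features = rng.normal(size=(2, 4, 4, 3))
pooled = global_average_pool_2d(features)
print(pooled.shape)  # (2, 3)
```

Because the spatial axes are averaged away, the head that follows does not depend on the input resolution, only on the channel count.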


## Attach a Custom Classification Head
Stack additional layers on the backbone's pooled output. The final dense layer can target any number of classes (10 in this example). Freeze the backbone first so that only the new head is trained during a quick transfer-learning warm-up.

In [None]:
backbone.trainable = False

inputs = backbone.input
x = keras.layers.Dense(128, activation="relu", name="transfer_dense1")(backbone.output)
x = keras.layers.Dropout(0.3)(x)
outputs = keras.layers.Dense(10, activation="softmax", name="transfer_probs")(x)  # softmax outputs class probabilities, not logits
model = keras.Model(inputs, outputs, name="effaxnet2d_transfer")

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.summary(line_length=110)
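With the backbone frozen, the "Trainable params" line of `model.summary()` should equal the parameter count of the new head alone. A quick arithmetic check, written as small hypothetical helpers; the feature width below is only a placeholder, so substitute the last dimension of `backbone.output_shape`:

```python
def dense_params(in_features: int, units: int) -> int:
    """Weights plus biases of a keras.layers.Dense layer with use_bias=True."""
    return in_features * units + units


def head_params(feature_dim: int, hidden: int = 128, num_classes: int = 10) -> int:
    """Trainable parameters of the Dense -> Dropout -> Dense head defined above.

    Dropout has no weights, so only the two Dense layers contribute.
    """
    return dense_params(feature_dim, hidden) + dense_params(hidden, num_classes)


# 1280 is an illustrative feature width, not the actual EffAxNetV1 output size.
print(head_params(1280))  # 165258
```

If the summary reports noticeably more trainable parameters than this, some backbone layers are still trainable.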


## Quick Smoke-Test Training
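Train for two epochs on random tensors to verify that the pipeline is wired correctly (input shapes, loss, and optimizer); the resulting accuracy is meaningless. Replace this with your dataset pipeline (e.g., `tf.data.Dataset`) for real training.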

In [None]:
dummy_images = tf.random.normal([32, 128, 128, 3])
dummy_labels = tf.random.uniform([32], minval=0, maxval=10, dtype=tf.int32)

history = model.fit(
    dummy_images,
    dummy_labels,
    batch_size=8,
    epochs=2,
    verbose=1,
)
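Because the labels are random, accuracy in this smoke test should hover around chance (1/10), and the first-epoch loss should start roughly around ln(10) ≈ 2.30, the cross-entropy of a uniform prediction over 10 classes. A first-epoch loss far from this reference is worth investigating. A NumPy sketch of sparse categorical cross-entropy (integer labels, probability inputs) that reproduces the reference value; the function here is illustrative, not the Keras implementation:

```python
import math

import numpy as np


def sparse_categorical_crossentropy(labels: np.ndarray, probs: np.ndarray) -> float:
    """Mean negative log-probability assigned to the true (integer) class."""
    eps = 1e-7  # clip to avoid log(0)
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.mean(np.log(np.clip(picked, eps, 1.0))))


num_classes = 10
rng = np.random.default_rng(0)
labels = rng.integers(0, num_classes, size=32)
uniform = np.full((32, num_classes), 1.0 / num_classes)

print(sparse_categorical_crossentropy(labels, uniform))  # about 2.3026
print(math.log(num_classes))                             # about 2.3026 (ln 10)
```

The "sparse" variant indexes the true-class probability directly from integer labels, which is why `dummy_labels` can stay as plain `int32` class IDs rather than one-hot vectors.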


## Optional Fine-Tuning
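After the classifier head stabilises, unfreeze the top of the backbone (here, its last 20 layers) and continue training with a lower learning rate. Recompile the model after changing any `trainable` flag so that the change takes effect.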

In [None]:
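# Unfreeze the last 20 backbone layers, but keep BatchNormalization layers frozen so
# small fine-tuning batches do not overwrite their moving statistics.
for layer in backbone.layers[-20:]:
    if not isinstance(layer, keras.layers.BatchNormalization):
        layer.trainable = True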

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.fit(dummy_images, dummy_labels, epochs=1, batch_size=8, verbose=1)
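A common refinement when unfreezing is to leave BatchNormalization layers frozen, so that small fine-tuning batches do not overwrite their moving statistics. That selection rule can be prototyped without TensorFlow. The sketch below uses a hypothetical `StubLayer` dataclass in place of real Keras layers (for a real layer, the `kind` field would correspond to `type(layer).__name__`); all layer names are invented for illustration.

```python
from dataclasses import dataclass


@dataclass
class StubLayer:
    """Hypothetical stand-in for a Keras layer: a name, a type tag, and a flag."""
    name: str
    kind: str
    trainable: bool = False


def unfreeze_top(layers, n, keep_frozen=("BatchNormalization",)):
    """Mark the last n layers trainable, skipping kinds listed in keep_frozen.

    Returns the names of the layers that were unfrozen.
    """
    if n <= 0:
        return []  # layers[-0:] would select everything, so guard explicitly
    unfrozen = []
    for layer in layers[-n:]:
        if layer.kind in keep_frozen:
            continue
        layer.trainable = True
        unfrozen.append(layer.name)
    return unfrozen


stack = [
    StubLayer("stem_conv", "Conv2D"),
    StubLayer("stem_bn", "BatchNormalization"),
    StubLayer("block1_conv", "Conv2D"),
    StubLayer("block1_bn", "BatchNormalization"),
    StubLayer("block1_act", "Activation"),
]
print(unfreeze_top(stack, 3))  # ['block1_conv', 'block1_act']
```

The explicit `n <= 0` guard matters because a Python slice of `[-0:]` returns the whole list, which would silently unfreeze every layer.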


## Save and Reload Weights
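Use the new `include_weights` parameter to restore checkpoints saved from a previous run. Because the call below rebuilds only the backbone, save the backbone's weights rather than the full transfer model's so that the checkpoint matches the architecture being restored. To restore the classification head as well, rebuild the full transfer model and call `model.load_weights()` on a checkpoint saved with `model.save_weights()`.

In [None]:
# Save only the backbone so the checkpoint matches the architecture rebuilt below.
checkpoint_path = "effaxnet2d_backbone.weights.h5"
backbone.save_weights(checkpoint_path)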

reloaded_backbone = EffAxNetV1(
    variant="2d",
    include_top=False,
    pooling="avg",
    input_shape=(128, 128, 3),
    include_weights=checkpoint_path,
)

print("Weights restored for feature extractor:", reloaded_backbone.output_shape)
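The printed output shape is identical for freshly initialised and restored backbones, so it does not confirm that the restore took effect. Comparing weight arrays directly does. A hypothetical helper, `weights_match`, operating on the lists of NumPy arrays returned by Keras' `Model.get_weights()`:

```python
import numpy as np


def weights_match(expected, actual, atol=0.0):
    """Return True if two weight lists have the same length, shapes, and values.

    With the default atol=0.0 the comparison is exact, which is appropriate for
    a save/load round trip of the same architecture.
    """
    if len(expected) != len(actual):
        return False
    return all(
        a.shape == b.shape and np.allclose(a, b, rtol=0.0, atol=atol)
        for a, b in zip(expected, actual)
    )


rng = np.random.default_rng(0)
original = [rng.normal(size=(3, 3)), rng.normal(size=(3,))]
restored = [w.copy() for w in original]
print(weights_match(original, restored))  # True
restored[1][0] += 1.0
print(weights_match(original, restored))  # False
```

In the notebook, `weights_match(backbone.get_weights(), reloaded_backbone.get_weights())` should return `True` when the checkpoint holds the backbone's weights and was loaded into the same architecture.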


## Next Steps
- Replace dummy data with your dataset input pipeline.
- Adjust `include_top`, `pooling`, and `classifier_activation` depending on your objective.
- Consider saving/exporting the entire model via `model.save()` for deployment.