# TensorFlow EffAxNet 3D Demo
This notebook demonstrates how to instantiate and train the 3D Efficient Axial Network using TensorFlow.

## Environment Setup
Install TensorFlow 2.13 or later (a GPU is recommended for 3D convolutions), then import the EffAxNet constructor from `axiom.tf.models`.

In [None]:
import tensorflow as tf
from tensorflow import keras
from axiom.tf.models import EffAxNetV1

print(tf.__version__)


## Build a Volumetric Backbone
Creating a 3D EffAxNet feature extractor works the same way: set `variant="3d"` and request pooled outputs.

In [None]:
tf.random.set_seed(13)

backbone = EffAxNetV1(
    variant="3d",
    include_top=False,
    pooling="avg",
    input_shape=(64, 64, 64, 1),
)

backbone.summary(line_length=100, expand_nested=False)
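
The `include_top` and `pooling` arguments share their names with the `keras.applications` constructors. Assuming `EffAxNetV1` follows that convention (this notebook does not confirm it), `pooling="avg"` collapses the spatial axes into one feature vector per volume, while `pooling=None` keeps a 5D feature map. The sketch below illustrates the difference with a hypothetical stand-in built from plain Keras layers; it is not the EffAxNet architecture.

```python
from tensorflow import keras

def build_standin_backbone(input_shape=(64, 64, 64, 1), pooling="avg"):
    """Hypothetical stand-in backbone: three strided Conv3D layers."""
    inputs = keras.Input(shape=input_shape)
    x = inputs
    for filters in (8, 16, 32):
        # Each stride-2 convolution halves depth, height, and width.
        x = keras.layers.Conv3D(
            filters, 3, strides=2, padding="same", activation="relu"
        )(x)
    if pooling == "avg":
        x = keras.layers.GlobalAveragePooling3D()(x)
    elif pooling == "max":
        x = keras.layers.GlobalMaxPooling3D()(x)
    return keras.Model(inputs, x, name="standin_backbone3d")

pooled = build_standin_backbone(pooling="avg")
unpooled = build_standin_backbone(pooling=None)

print(pooled.output_shape)    # (None, 32)
print(unpooled.output_shape)  # (None, 8, 8, 8, 32)
```

A dense classification head needs the pooled vector; a decoder for segmentation would need the unpooled feature map.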


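## Attach a Classification or Segmentation Head
For volumetric classification, add dense layers on top of the pooled backbone output, as in the cell below. Segmentation needs spatial feature maps rather than a pooled vector, so for those tasks replace the head with 3D convolutions or a UNet-style decoder.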

In [None]:
backbone.trainable = False

inputs = backbone.input
x = keras.layers.Dense(256, activation="relu")(backbone.output)
x = keras.layers.Dropout(0.4)(x)
# Despite the layer name, softmax makes these class probabilities, not raw logits.
outputs = keras.layers.Dense(4, activation="softmax", name="lesion_logits")(x)
model = keras.Model(inputs, outputs, name="effaxnet3d_transfer")

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=5e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
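
`sparse_categorical_crossentropy` expects integer class labels (here 0 to 3) rather than one-hot vectors. As a rough sanity reference, a model that assigns equal probability to all four classes scores a loss of ln(4). The NumPy sketch below reproduces that number by hand; the helper name is illustrative and is not part of Keras.

```python
import numpy as np

def sparse_cce(labels, probs, eps=1e-7):
    """Mean negative log-probability assigned to the true class."""
    probs = np.clip(probs, eps, 1.0)
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.log(picked).mean())

num_classes = 4
labels = np.array([0, 3, 1, 2])  # integer labels, not one-hot
uniform = np.full((len(labels), num_classes), 1.0 / num_classes)

baseline = sparse_cce(labels, uniform)
print(f"chance-level loss: {baseline:.4f}")  # chance-level loss: 1.3863
```

A loss far above this value on random data can hint at a label or output-shape mismatch.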


## Smoke-Test Training Step
Use randomly generated volumetric data to validate that the model compiles and trains. Replace this with real DICOM/NIfTI pipelines.

In [None]:
dummy_volumes = tf.random.normal([8, 64, 64, 64, 1])
dummy_labels = tf.random.uniform([8], minval=0, maxval=4, dtype=tf.int32)

model.fit(dummy_volumes, dummy_labels, batch_size=2, epochs=1, verbose=1)
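
When real scans replace the dummy tensors, each volume has to reach the `(64, 64, 64, 1)` shape the backbone was built with. The sketch below is one possible preprocessing step, assuming the scan is already loaded as a 3D NumPy array (the DICOM/NIfTI reading itself is not shown). The function name and the z-score normalization are illustrative choices, not requirements of EffAxNet.

```python
import numpy as np

def prepare_volume(volume, target_shape=(64, 64, 64)):
    """Z-score a 3D array, center-crop or zero-pad it, and add a channel axis."""
    volume = np.asarray(volume, dtype=np.float32)
    std = volume.std()
    volume = (volume - volume.mean()) / (std if std > 0 else 1.0)

    out = np.zeros(target_shape, dtype=np.float32)
    src, dst = [], []
    for size, target in zip(volume.shape, target_shape):
        if size >= target:  # too large along this axis: center-crop
            start = (size - target) // 2
            src.append(slice(start, start + target))
            dst.append(slice(0, target))
        else:               # too small along this axis: center in zero padding
            start = (target - size) // 2
            src.append(slice(0, size))
            dst.append(slice(start, start + size))
    out[tuple(dst)] = volume[tuple(src)]
    return out[..., np.newaxis]

rng = np.random.default_rng(13)
scan = rng.normal(loc=100.0, scale=20.0, size=(80, 60, 64))  # synthetic stand-in scan
batch = np.stack([prepare_volume(scan)])
print(batch.shape)  # (1, 64, 64, 64, 1)
```

Because normalization happens before padding, the padded voxels sit at the mean intensity (zero) of the normalized scan.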


## Fine-Tune Selected Blocks
3D models are memory-heavy. Unfreeze only the last few layers, recompile so the new `trainable` flags take effect, lower the learning rate, and consider mixed precision.

In [None]:
for layer in backbone.layers[-15:]:
    layer.trainable = True

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=2e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.fit(dummy_volumes, dummy_labels, batch_size=2, epochs=1, verbose=1)
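
To make the memory cost concrete, the arithmetic below estimates the size of a single dense activation tensor. The 32-channel feature map is a hypothetical example, not a figure taken from EffAxNet, and a real training step stores many such tensors plus gradients and optimizer state.

```python
def activation_mib(batch_size, spatial_shape, channels, bytes_per_value=4):
    """Size in MiB of one dense activation tensor."""
    voxels = 1
    for dim in spatial_shape:
        voxels *= dim
    return batch_size * voxels * channels * bytes_per_value / 2**20

print(activation_mib(2, (64, 64, 64), 1))       # 2.0   input batch, float32
print(activation_mib(2, (64, 64, 64), 32))      # 64.0  hypothetical 32-channel map
print(activation_mib(2, (64, 64, 64), 32, 2))   # 32.0  same map stored as float16
print(activation_mib(2, (128, 128, 128), 32))   # 512.0 doubling each side costs 8x
```

Backpropagation has to keep activations from the earliest trainable layer onward, so unfreezing deeper into the backbone generally raises peak memory.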


## Persist Checkpoints
Save the trained weights (`model.save_weights` stores the backbone and the head together), then rebuild the backbone later using the `include_weights` shortcut.

In [None]:
checkpoint_path = "effaxnet3d_transfer.weights.h5"
model.save_weights(checkpoint_path)

reloaded_backbone = EffAxNetV1(
    variant="3d",
    include_top=False,
    pooling="avg",
    input_shape=(64, 64, 64, 1),
    include_weights=checkpoint_path,
)

print("Reloaded backbone output shape:", reloaded_backbone.output_shape)
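
Whether `include_weights` accepts a checkpoint that also contains head weights depends on `axiom`'s loader, which this notebook does not show. A cautious alternative is to save the backbone's weights on their own with standard Keras calls. The sketch below demonstrates the round trip on a hypothetical stand-in backbone (not the EffAxNet architecture); the file name is illustrative.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

def build_standin_backbone():
    """Hypothetical stand-in for the real backbone."""
    inputs = keras.Input(shape=(16, 16, 16, 1))
    x = keras.layers.Conv3D(4, 3, padding="same", activation="relu", name="conv_a")(inputs)
    x = keras.layers.GlobalAveragePooling3D(name="gap")(x)
    return keras.Model(inputs, x, name="standin_backbone")

backbone = build_standin_backbone()
outputs = keras.layers.Dense(4, activation="softmax", name="head")(backbone.output)
model = keras.Model(backbone.input, outputs)  # shares its layers with `backbone`

with tempfile.TemporaryDirectory() as tmp:
    backbone_path = os.path.join(tmp, "backbone_only.weights.h5")
    backbone.save_weights(backbone_path)  # head weights are not written

    restored = build_standin_backbone()
    restored.load_weights(backbone_path)

# Every restored array should match the original backbone exactly.
for saved, loaded in zip(backbone.get_weights(), restored.get_weights()):
    np.testing.assert_array_equal(saved, loaded)
print("backbone arrays restored:", len(restored.get_weights()))
```

Keeping a backbone-only file alongside the full-model checkpoint avoids relying on how a loader treats extra head weights.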


## Next Steps
- Swap dummy volumes with your medical dataset (e.g., `tfio.IODataset.from_parquet`, NIfTI readers, etc.).
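- Enable mixed precision for speed by calling `tf.keras.mixed_precision.set_global_policy('mixed_float16')` before building the backbone; the policy applies to layers created after the call and does not convert an existing model. Keep the final softmax in `float32` for numerical stability.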
- Export to SavedModel or ONNX for deployment once fine-tuning is complete.
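
The mixed-precision step can be sketched as follows. A tiny dense model keeps the example fast; the same policy mechanism applies to `Conv3D` layers. The layer names and sizes are illustrative.

```python
from tensorflow import keras

# Set the policy BEFORE creating layers; existing layers keep their dtype.
keras.mixed_precision.set_global_policy("mixed_float16")

inputs = keras.Input(shape=(32,))
x = keras.layers.Dense(16, activation="relu", name="hidden")(inputs)
x = keras.layers.Dense(4, name="class_scores")(x)
# Compute the final softmax in float32 so the probabilities stay numerically stable.
outputs = keras.layers.Activation("softmax", dtype="float32", name="probs")(x)
model = keras.Model(inputs, outputs)

print(model.get_layer("hidden").compute_dtype)  # float16
print(model.get_layer("probs").compute_dtype)   # float32

# Restore the default so later cells are unaffected.
keras.mixed_precision.set_global_policy("float32")
```

Under `mixed_float16`, layers compute in half precision but keep their variables in `float32`.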