# ABOC Model Training on Google Colab (TensorFlow + FlatBuffers)

This notebook trains the Ancestor-Based Octree Compression (ABOC) model with TensorFlow, using a TPU when one is available and falling back to the default CPU/GPU strategy otherwise.
It loads data from **FlatBuffer (.fb)** files stored in a Google Cloud Storage (GCS) bucket.

**Key Features:**
*   **TPU Acceleration:** Uses `tf.distribute.TPUStrategy` for high-performance training.
*   **GCS Integration:** Streams data directly from GCS (`gs://mtn_fb_file_bucket`) into the local Colab runtime for processing.
*   **FlatBuffers:** Parses `.fb` files without a separate `flatc` compilation step (the generated schema code is inlined).
*   **Custom Data Pipeline:** Reconstructs each 17-token octree context (parent + 16 neighbors) on the fly.

## 1. Setup & Imports

In [None]:
# Install dependencies
!pip install -q tensorflow gcsfs flatbuffers

import os
import sys
import time
import math
import glob
import tempfile
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import flatbuffers
import gcsfs

# Authenticate with Google Cloud (Colab-only API). Later cells read and write
# gs:// paths through gcsfs and tf.io.gfile, which presumably rely on the
# credentials established here.
from google.colab import auth
auth.authenticate_user()
print("Authenticated with Google Cloud.")

## 2. Configuration

In [None]:
# --- GCS Configuration ---
GCS_BUCKET = "mtn_fb_file_bucket"
GCS_BASE_URI = f"gs://{GCS_BUCKET}"
# Pattern to match .fb files. Adjust if files are in a subdir (e.g., 'data/*.fb')
GCS_DATA_PATTERN = f"{GCS_BASE_URI}/*.fb"
GCS_LOG_DIR = f"{GCS_BASE_URI}/logs"
GCS_CHECKPOINT_DIR = f"{GCS_BASE_URI}/checkpoints"

# --- Hyperparameters (Matching octAttention.py) ---
EPOCHS = 50
CONTEXT_LEN = 17 # Parent + 16 Neighbors
MAX_OCTREE_LEVEL = 21 # Matching networkTool.py

# Model Params
VOCAB_SIZE = 256
EMBED_DIM = 140 # 130 + 6 + 4 (occupancy + level + octant embedding widths)
NUM_HEADS = 4
FF_DIM = 300
NUM_LAYERS = 3
DROPOUT = 0.0

# Learning Rate
INITIAL_LR = 1e-3

# --- TPU Setup ---
# Try to attach to a TPU; if the resolver raises ValueError (no TPU found),
# fall back to the default (CPU/GPU) strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    # .get() instead of ['worker']: if the cluster spec has no 'worker' job,
    # a KeyError would escape the ValueError handler below and abort the cell.
    # NOTE(review): which runtimes omit 'worker' (e.g. local TPU VMs) is unconfirmed.
    print('Running on TPU ', tpu.cluster_spec().as_dict().get('worker'))
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    print('TPU not found. Using Default Strategy (CPU/GPU).')
    strategy = tf.distribute.get_strategy()

BATCH_SIZE_PER_REPLICA = 128 # Adjust based on memory
# The input pipeline batches at the global size; the strategy splits each
# batch across replicas.
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
print(f"Global Batch Size: {GLOBAL_BATCH_SIZE}")

## 3. FlatBuffers Definitions (Inlined)

The Python code generated by `flatc` for `OctreeData.Dataset` and `OctreeData.OctreeNode` is inlined below so that this notebook is self-contained.

In [None]:
import flatbuffers
from flatbuffers.compat import import_numpy
# numpy module as resolved by flatbuffers' compat helper (presumably None when
# numpy is unavailable). Carried over from the flatc template; not referenced
# elsewhere in the visible notebook code.
np_fb = import_numpy()

class OctreeNode(object):
    """Read-only accessor for one ``OctreeNode`` table inside a FlatBuffer.

    Inlined ``flatc``-generated code: keep it byte-identical to regenerated
    output. Every getter looks up its field's vtable slot via
    ``self._tab.Offset(...)``; a slot offset of 0 means the field is absent
    and the getter returns 0.

    NOTE(review): vtable offsets 4, 6, 8 and 16 are never read here --
    presumably other schema fields. Confirm against the .fbs schema before
    editing any offset by hand.
    """
    __slots__ = ['_tab']
    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an OctreeNode for the root table whose uoffset is stored at ``buf[offset]``."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = OctreeNode()
        x.Init(buf, n + offset)
        return x
    def Init(self, buf, pos):
        """Bind this accessor to the table starting at byte ``pos`` of ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
    def Level(self):
        """Return the int32 field at vtable offset 10 (level), or 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0
    def Octant(self):
        """Return the int32 field at vtable offset 12 (octant), or 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0
    def Occupancy(self):
        """Return the int32 field at vtable offset 14 (occupancy), or 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0
    def ParentOccupancy(self):
        """Return the int32 field at vtable offset 18 (parent occupancy), or 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0
    def NeighborOccupancies(self, j):
        """Return element ``j`` of the int32 vector at vtable offset 20, or 0 if the vector is absent.

        ``j`` is not bounds-checked; callers should stay below
        ``NeighborOccupanciesLength()``.
        """
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        if o != 0:
            a = self._tab.Vector(o)
            # Elements are 4-byte int32 values, hence the j * 4 stride.
            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return 0
    def NeighborOccupanciesLength(self):
        """Return the length of the vector at vtable offset 20, or 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

class Dataset(object):
    """Read-only accessor for the root ``Dataset`` table of a FlatBuffer.

    Inlined ``flatc``-generated code: keep it byte-identical to regenerated
    output. Its only field, read here, is a vector of ``OctreeNode`` tables
    at vtable offset 4.
    """
    __slots__ = ['_tab']
    @classmethod
    def GetRootAsDataset(cls, buf, offset=0):
        """Return a Dataset for the root table whose uoffset is stored at ``buf[offset]``."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Dataset()
        x.Init(buf, n + offset)
        return x
    def Init(self, buf, pos):
        """Bind this accessor to the table starting at byte ``pos`` of ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
    def Nodes(self, j):
        """Return an OctreeNode accessor for element ``j`` of the node vector, or None if absent.

        ``j`` is not bounds-checked; callers should stay below ``NodesLength()``.
        """
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            x = self._tab.Vector(o)
            # Each element is a 4-byte offset; Indirect() follows it to the table.
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            obj = OctreeNode()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
    def NodesLength(self):
        """Return the number of entries in the node vector, or 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

## 4. Utilities

In [None]:
class PrintLog:
    """Print messages and mirror them to a text log object on GCS.

    Cloud Storage objects are immutable, so opening one in append mode ('ab')
    is not a reliable way to accumulate a log. Instead, every message is kept
    in memory and the whole log is re-uploaded in 'wb' mode on each call.
    Logs in this notebook are a few hundred lines, so the rewrite is cheap.

    Mirroring to GCS is best-effort: failures are printed as warnings and
    never interrupt training.
    """

    def __init__(self, log_path):
        """Remember ``log_path`` (a gs:// URI) and open a gcsfs filesystem."""
        self.log_path = log_path
        self.fs = gcsfs.GCSFileSystem()
        self._lines = []  # every message logged so far, newline-terminated
        # Best-effort "directory" creation; GCS uses key prefixes, so this may
        # be a no-op. Narrowed from a bare except so Ctrl-C still interrupts.
        try:
            self.fs.makedirs(os.path.dirname(log_path), exist_ok=True)
        except Exception as e:
            print(f"Warning: Failed to create GCS log directory: {e}")

    def __call__(self, msg):
        """Print ``msg`` and rewrite the GCS log with all messages so far."""
        msg = str(msg)  # accept non-str messages (numbers, tensors, ...)
        print(msg)
        self._lines.append(msg + "\n")
        try:
            with self.fs.open(self.log_path, 'wb') as f:
                f.write("".join(self._lines).encode('utf-8'))
        except Exception as e:
            print(f"Warning: Failed to write to GCS log: {e}")

# One log object per session; the Unix-time suffix keeps reruns from
# overwriting each other's logs.
_session_stamp = int(time.time())
print_log = PrintLog(f"{GCS_LOG_DIR}/training_{_session_stamp}.log")
print_log("Starting training session.")

## 5. Model Definition (TensorFlow/Keras Port)

Ported from `octAttention.py` (PyTorch) to TensorFlow/Keras.

*   **Input:** `[Batch, 17, 3]` integer tensor (channels: Occupancy, Level, Octant)
*   **Output:** `[Batch, 256]` logits for the center token's occupancy

In [None]:
def create_model():
    """Build the ABOC transformer as a Keras functional model.

    Ported from ``octAttention.py`` (PyTorch). Each example is a window of
    ``CONTEXT_LEN`` tokens with three integer channels (0 = occupancy,
    1 = level, 2 = octant).
    - The channels are embedded separately and concatenated to ``EMBED_DIM``
      features.
    - The features are scaled by sqrt(EMBED_DIM) and given sinusoidal
      positional encodings.
    - ``NUM_LAYERS`` post-norm transformer encoder layers follow.
    - A two-layer decoder head produces the output. Only the centre token's
      logits are returned.

    Returns:
        keras.Model mapping int32 ``[batch, CONTEXT_LEN, 3]`` to logits
        ``[batch, VOCAB_SIZE]``.

    Raises:
        ValueError: if ``EMBED_DIM`` leaves no room for the occupancy
            embedding after the fixed level/octant widths.
    """

    class ToDense(layers.Layer):
        """Defensive pass-through that densifies sparse inputs.

        Dense Embedding + Concatenate outputs are expected to be dense
        already, so this is normally the identity.
        """

        def __init__(self, **kwargs):
            super(ToDense, self).__init__(**kwargs)

        def compute_output_shape(self, input_shape):
            return input_shape

        def call(self, inputs):
            if isinstance(inputs, tf.SparseTensor):
                return tf.sparse.to_dense(inputs)
            if hasattr(inputs, 'sparse') and inputs.sparse:
                return tf.sparse.to_dense(inputs)
            return inputs

    class PositionalEncoding(layers.Layer):
        """Add fixed sinusoidal position encodings to ``[batch, seq, d_model]`` inputs."""

        def __init__(self, d_model, max_len=5000, **kwargs):
            # **kwargs forwards standard Layer arguments (name, dtype, ...).
            super().__init__(**kwargs)
            self.d_model = d_model
            # Computed once. Even feature indices get sin, odd ones get cos.
            position = np.arange(max_len)[:, np.newaxis]
            div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
            pe = np.zeros((max_len, d_model))
            pe[:, 0::2] = np.sin(position * div_term)
            # Slice div_term so an odd d_model (one fewer cos column) also works.
            pe[:, 1::2] = np.cos(position * div_term[: d_model // 2])
            self.pe = tf.constant(pe, dtype=tf.float32)

        def compute_output_shape(self, input_shape):
            return input_shape

        def call(self, x):
            seq_len = tf.shape(x)[1]
            return x + self.pe[:seq_len, :]

    # Level and octant embedding widths are fixed (as in octAttention.py).
    # The occupancy embedding takes the rest of EMBED_DIM (130 for the default
    # 140), so the concatenation is exactly EMBED_DIM wide, as the residual
    # adds below require.
    level_dim = 6
    octant_dim = 4
    occ_dim = EMBED_DIM - level_dim - octant_dim
    if occ_dim <= 0:
        raise ValueError(
            f"EMBED_DIM={EMBED_DIM} is too small; it must exceed "
            f"{level_dim + octant_dim} (level + octant embedding widths)."
        )
    # Token whose occupancy is predicted: the middle of the window
    # (index 8 for CONTEXT_LEN = 17).
    center_idx = CONTEXT_LEN // 2

    inputs = keras.Input(shape=(CONTEXT_LEN, 3), dtype=tf.int32)

    # Split channels into three [batch, CONTEXT_LEN] integer tensors.
    occ_input = inputs[:, :, 0]
    lvl_input = inputs[:, :, 1]
    oct_input = inputs[:, :, 2]

    # Embeddings (octAttention.py: encoder / encoder1 / encoder2).
    occ_emb = layers.Embedding(VOCAB_SIZE, occ_dim)(occ_input)
    lvl_emb = layers.Embedding(MAX_OCTREE_LEVEL + 1, level_dim)(lvl_input)
    oct_emb = layers.Embedding(9, octant_dim)(oct_input)  # octant ids in [0, 8]

    x = layers.Concatenate(axis=-1)([occ_emb, lvl_emb, oct_emb])
    x = ToDense()(x)

    # Scale by sqrt(EMBED_DIM), as the PyTorch forward does. A Python float
    # keeps this a plain elementwise op on the symbolic tensor and follows
    # EMBED_DIM instead of a hard-coded 140.
    x = x * math.sqrt(EMBED_DIM)

    x = PositionalEncoding(EMBED_DIM)(x)
    if DROPOUT > 0:
        x = layers.Dropout(DROPOUT)(x)

    # Transformer encoder: post-norm residual blocks.
    for _ in range(NUM_LAYERS):
        att_output = layers.MultiHeadAttention(num_heads=NUM_HEADS, key_dim=EMBED_DIM // NUM_HEADS)(x, x)
        if DROPOUT > 0:
            att_output = layers.Dropout(DROPOUT)(att_output)
        x1 = layers.LayerNormalization(epsilon=1e-6)(x + att_output)

        ffn_output = layers.Dense(FF_DIM, activation='relu')(x1)
        ffn_output = layers.Dense(EMBED_DIM)(ffn_output)
        if DROPOUT > 0:
            ffn_output = layers.Dropout(DROPOUT)(ffn_output)
        x = layers.LayerNormalization(epsilon=1e-6)(x1 + ffn_output)

    # Decoder head (octAttention.py: decoder0 -> ReLU -> decoder1).
    x = layers.Dense(EMBED_DIM, activation='relu')(x)
    logits = layers.Dense(VOCAB_SIZE)(x)

    # PyTorch indexes [seq, batch, dim] as output[8]; here the layout is
    # [batch, seq, dim], so take the centre token along axis 1.
    center_logits = logits[:, center_idx, :]

    return keras.Model(inputs=inputs, outputs=center_logits)

with strategy.scope():
    # Build under the distribution strategy so the model's variables are
    # created within `strategy`.
    model = create_model()
model.summary()

## 6. Data Pipeline

Loads `.fb` files, parses them with the inlined FlatBuffer classes, and streams the resulting examples into `tf.data`.

In [None]:
def _node_to_example(node):
    """Convert one OctreeNode into an ``(input [CONTEXT_LEN, 3] int32, target)`` pair.

    Channel layout of ``input``:
      0 -- occupancy context: first half of the neighbours, then the parent
           at the centre, then the remaining neighbours (dataset.py layout).
           If fewer than ``CONTEXT_LEN - 1`` neighbours are stored, only the
           parent is filled and the rest stay 0.
      1 -- node level minus 1 (floored at 0), repeated on every token.
      2 -- node octant on the centre token only, 0 elsewhere.
    The target is the node's own occupancy.
    """
    center = CONTEXT_LEN // 2          # 8 for CONTEXT_LEN = 17
    n_neighbors = CONTEXT_LEN - 1      # 16 for CONTEXT_LEN = 17
    neighbors = [node.NeighborOccupancies(j) for j in range(node.NeighborOccupanciesLength())]

    example = np.zeros((CONTEXT_LEN, 3), dtype=np.int32)
    example[center, 0] = node.ParentOccupancy()
    if len(neighbors) >= n_neighbors:
        example[:center, 0] = neighbors[:center]
        example[center + 1:, 0] = neighbors[center:n_neighbors]
    example[:, 1] = max(0, node.Level() - 1)  # dataset.py logic
    example[center, 2] = node.Octant()
    return example, node.Occupancy()


def load_flatbuffer_data(gcs_file_path_tensor):
    """Load a single .fb file from GCS and yield ``(input, target)`` pairs.

    Intended as the generator for ``tf.data.Dataset.from_generator``.

    Args:
        gcs_file_path_tensor: gs:// path as bytes (what ``from_generator``
            passes), an eager string tensor, or a plain ``str``.

    Yields:
        ``(input, target)``: int32 array ``[CONTEXT_LEN, 3]`` (see
        ``_node_to_example``) and the node's occupancy as an int.

    Read/parse errors are printed and end this file's stream (best-effort:
    one bad file should not stop training). Examples yielded before the
    error are kept.
    """
    raw_path = gcs_file_path_tensor
    if hasattr(raw_path, 'numpy'):
        raw_path = raw_path.numpy()
    gcs_path = raw_path.decode('utf-8') if isinstance(raw_path, bytes) else str(raw_path)

    fs = gcsfs.GCSFileSystem()

    try:
        # Read the object straight into memory. The former
        # NamedTemporaryFile(delete=False) round trip never closed its handle
        # and leaked one file descriptor per file.
        with fs.open(gcs_path, 'rb') as f:
            buf = f.read()

        dataset = Dataset.GetRootAsDataset(buf, 0)
        for i in range(dataset.NodesLength()):
            yield _node_to_example(dataset.Nodes(i))

    except Exception as e:
        print(f"Error loading {gcs_path}: {e}")

def create_dataset(file_pattern, batch_size, is_training=True, shuffle_buffer=10000):
    """Build a batched ``tf.data`` pipeline over every .fb file matching a pattern.

    Args:
        file_pattern: glob understood by ``tf.io.gfile.glob`` (e.g. ``gs://bucket/*.fb``).
        batch_size: examples per batch. The final partial batch is dropped so
            every batch has a static shape.
        is_training: if True, shuffle the file order, repeat forever and
            shuffle examples.
        shuffle_buffer: example-level shuffle buffer used when training
            (default 10000, as before). A value <= 0 disables example shuffling.

    Returns:
        tf.data.Dataset of ``(inputs [batch, CONTEXT_LEN, 3] int32, targets [batch] int32)``.

    Raises:
        ValueError: if no file matches ``file_pattern``.

    NOTE(review): ``from_generator`` runs Python in the process hosting this
    notebook. Confirm the TPU runtime in use can consume such a pipeline
    (TPU VM vs. legacy remote TPU Node).
    """
    files = tf.io.gfile.glob(file_pattern)
    print_log(f"Found {len(files)} files matching {file_pattern}")
    if not files:
        raise ValueError(f"No files found for {file_pattern}")

    # Must match what load_flatbuffer_data yields.
    example_signature = (
        tf.TensorSpec(shape=(CONTEXT_LEN, 3), dtype=tf.int32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    )

    def _examples_from_file(path):
        """Stream the examples of one file; ``path`` is a scalar string tensor."""
        return tf.data.Dataset.from_generator(
            load_flatbuffer_data,
            output_signature=example_signature,
            args=[path],
        )

    dataset = tf.data.Dataset.from_tensor_slices(files)
    if is_training:
        dataset = dataset.shuffle(len(files))
        dataset = dataset.repeat()

    # Read several files concurrently and mix their examples one at a time.
    dataset = dataset.interleave(
        _examples_from_file,
        cycle_length=tf.data.AUTOTUNE,
        block_length=1,
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    if is_training and shuffle_buffer > 0:
        dataset = dataset.shuffle(shuffle_buffer)

    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

## 7. Training Loop

In [None]:
train_ds = create_dataset(GCS_DATA_PATTERN, GLOBAL_BATCH_SIZE)
train_dist_ds = strategy.experimental_distribute_dataset(train_ds)

with strategy.scope():
    optimizer = keras.optimizers.Adam(learning_rate=INITIAL_LR)
    # Per-example losses. compute_loss averages over the GLOBAL batch, so the
    # cross-replica SUM below gives the mean loss of the whole global batch.
    # The string "none" works in both Keras 2 and Keras 3, whereas
    # tf.keras.losses.Reduction no longer exists in Keras 3 (TF >= 2.16).
    loss_object = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")

    def compute_loss(labels, predictions):
        """Sum of per-example losses divided by GLOBAL_BATCH_SIZE (one replica's share)."""
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

    @tf.function
    def train_step(inputs):
        """One optimisation step on a single replica's slice of the batch."""
        x, y = inputs
        with tf.GradientTape() as tape:
            predictions = model(x, training=True)
            loss = compute_loss(y, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss


@tf.function
def distributed_train_step(dist_inputs):
    """Run train_step on every replica and return the summed (global-mean) loss.

    Wrapping strategy.run in tf.function compiles the whole distributed step
    once, instead of dispatching it eagerly from Python on every step.
    """
    per_replica_losses = strategy.run(train_step, args=(dist_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)


def _save_weights_to_gcs(model, gcs_path):
    """Save ``model`` weights to a local temp file, then copy it to ``gcs_path``.

    HDF5 weights are written through h5py, which expects a local filename,
    so the file is staged locally and then copied with tf.io.gfile.
    Keras 3 also requires the ``.weights.h5`` suffix.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = os.path.join(tmp_dir, os.path.basename(gcs_path))
        model.save_weights(local_path)
        tf.io.gfile.copy(local_path, gcs_path, overwrite=True)


# --- Run Training ---
print_log("Starting Training Loop...")
STEPS_PER_EPOCH = 1000 # Configurable or calculated based on dataset size
train_iter = iter(train_dist_ds)

for epoch in range(EPOCHS):
    print_log(f"Epoch {epoch+1}/{EPOCHS}")
    total_loss = 0.0
    start = time.time()

    for step in range(STEPS_PER_EPOCH):
        loss = distributed_train_step(next(train_iter))
        # Accumulate as a tensor: no host sync on every step.
        total_loss += loss

        if step % 100 == 0:
            print(f"  Step {step}, Loss: {float(loss):.4f}")

    avg_loss = float(total_loss) / STEPS_PER_EPOCH
    duration = time.time() - start
    print_log(f"  Epoch Ended. Avg Loss: {avg_loss:.4f}, Time: {duration:.2f}s")

    # Save Checkpoint every 5 epochs.
    if (epoch + 1) % 5 == 0:
        save_path = f"{GCS_CHECKPOINT_DIR}/model_epoch_{epoch+1}.weights.h5"
        _save_weights_to_gcs(model, save_path)
        print_log(f"  Saved checkpoint to {save_path}")