# ABOC Model Training on Google Colab (TensorFlow + FlatBuffers)

This notebook trains the Ancestor-Based Octree Compression (ABOC) model with TensorFlow, using a TPU when one is available and falling back to the default CPU/GPU strategy otherwise.
It loads data from **FlatBuffer (.fb)** files stored in a Google Cloud Storage (GCS) bucket.

## Source Code
Code is pulled from GitHub: [michaelnutt2/ABOC](https://github.com/michaelnutt2/ABOC)

**Key Features:**
*   **TPU Acceleration:** Uses `tf.distribute.TPUStrategy` for high-performance training.
*   **GCS Integration:** Reads `.fb` files from the GCS bucket named by `config.BUCKET_NAME` (e.g. `gs://mtn_fb_file_bucket`) into the local Colab runtime for processing.
*   **FlatBuffers:** Efficiently parses `.fb` files via the `OctreeData` module.
*   **Validation Split:** Randomly splits the list of `.fb` files (whole files, not individual nodes) into Training (90%) and Validation (10%) sets.

## 1. Setup & Clone Repo

In [None]:
# Install dependencies.
# Lines starting with `!` are IPython shell escapes, so this cell only runs
# inside a notebook, not as a plain Python script.
!pip install -q tensorflow gcsfs flatbuffers

import os
import sys
import time

# Clone the repository on the first run; on later runs pull the latest
# commits instead so the notebook uses the current repo code.
# NOTE(review): the clone lands in the current working directory, so this
# assumes the notebook starts in /content — confirm.
PROJECT_PATH = '/content/ABOC'
if not os.path.exists(PROJECT_PATH):
    !git clone https://github.com/michaelnutt2/ABOC.git
else:
    print("Repo already cloned. Pulling latest...")
    !cd {PROJECT_PATH} && git pull

# Work from the repo root and add it to sys.path so later cells can import
# the repo's modules (config, OctreeData, tf_model) by bare name.
os.chdir(PROJECT_PATH)
sys.path.append(PROJECT_PATH)
print(f"Working Directory set to: {os.getcwd()}")

import math
import glob
import tempfile
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import flatbuffers
import gcsfs

# Authenticate with Google Cloud.
# This is Colab-only (google.colab is not available elsewhere).
# NOTE(review): presumably this is what gives gcsfs and TensorFlow access to
# the gs:// paths used below — confirm.
from google.colab import auth
auth.authenticate_user()
print("Authenticated with Google Cloud.")

## 2. Configuration & TPU Setup

In [None]:
# Import Config from Repo
import config

# --- GCS Configuration ---
# The bucket name and the data/checkpoint prefixes are read from the repo's
# config module; all gs:// URIs used by this notebook are derived from them.
GCS_BUCKET = config.BUCKET_NAME
GCS_BASE_URI = f"gs://{GCS_BUCKET}"
# Glob pattern for .fb files directly under the data prefix.
# Adjust if files are in a deeper subdir (e.g., 'data/*/*.fb').
GCS_DATA_PATTERN = f"{GCS_BASE_URI}/{config.GCS_DATA_PREFIX}/*.fb"
# TensorBoard logs go to a fixed "logs" folder at the bucket root; unlike the
# other paths this one is not driven by config.
GCS_LOG_DIR = f"{GCS_BASE_URI}/logs"
GCS_CHECKPOINT_DIR = f"{GCS_BASE_URI}/{config.GCS_CHECKPOINT_PREFIX}"

# --- Model Hyperparameters ---
# Notebook-level aliases of the repo config values. CONTEXT_LEN sizes the
# context arrays built during parsing and EPOCHS is passed to model.fit().
EPOCHS = config.EPOCHS
CONTEXT_LEN = config.CONTEXT_LEN
MAX_OCTREE_LEVEL = config.MAX_OCTREE_LEVEL
VOCAB_SIZE = config.VOCAB_SIZE
EMBED_DIM = config.EMBED_DIM
NUM_HEADS = config.NUM_HEADS
FF_DIM = config.FF_DIM
NUM_LAYERS = config.NUM_LAYERS
DROPOUT = config.DROPOUT

# Learning rate for the Adam optimizer; defined here rather than in config.
INITIAL_LR = 1e-3

# --- TPU Setup ---
# Try to resolve, connect to and initialise a TPU, then build a TPUStrategy.
# A ValueError raised anywhere in that sequence is treated as "no TPU", and
# the notebook falls back to TensorFlow's current default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    print('TPU not found. Using Default Strategy (CPU/GPU).')
    strategy = tf.distribute.get_strategy()

# The global batch is the per-replica batch multiplied by the number of
# replicas the strategy keeps in sync; the datasets are batched with the
# global size.
BATCH_SIZE_PER_REPLICA = 128 # Adjust based on memory
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
print(f"Global Batch Size: {GLOBAL_BATCH_SIZE}")
print(f"Data Pattern: {GCS_DATA_PATTERN}")

## 3. Data Loading (Using Repo Modules)

We use `OctreeData.Dataset` and `OctreeData.OctreeNode` from the cloned repository to parse FlatBuffers.

In [None]:
import OctreeData.Dataset as Dataset
import OctreeData.OctreeNode as OctreeNode

def _neighbor_occupancies(node):
    """Return a node's neighbor occupancy codes as a 1-D array (possibly empty)."""
    try:
        neighbors = node.NeighborOccupanciesAsNumpy()
    except AttributeError:
        # Older generated code has no *AsNumpy accessor: read element-wise.
        count = node.NeighborOccupanciesLength()
        neighbors = np.zeros(count, dtype=np.int32)
        for j in range(count):
            neighbors[j] = node.NeighborOccupancies(j)
        return neighbors

    # FlatBuffers-generated *AsNumpy accessors return the scalar 0 (not an
    # empty array) when the vector is absent; normalise so len() works.
    if np.ndim(neighbors) == 0:
        return np.zeros(0, dtype=np.int32)
    return neighbors


def parse_flatbuffer_bytes(buf):
    """
    Parse one serialized FlatBuffer ``Dataset`` into per-node training arrays.

    Each context row has ``CONTEXT_LEN`` slots: the parent occupancy sits in
    the middle slot (``CONTEXT_LEN // 2``) and the first ``CONTEXT_LEN - 1``
    neighbor occupancies fill the slots on either side of it. Nodes that
    carry fewer than ``CONTEXT_LEN - 1`` neighbors keep zeros in the
    neighbor slots; surplus neighbors are ignored.

    Args:
        buf: Bytes of a single ``.fb`` file.

    Returns:
        Tuple ``(contexts, levels, octants, labels)`` of int32 arrays.
        ``contexts`` has shape ``(num_nodes, CONTEXT_LEN)``; the other three
        have shape ``(num_nodes,)``. ``levels`` holds ``Level() - 1`` clamped
        at 0 and ``labels`` holds each node's ``Occupancy()``.
    """
    dataset = Dataset.Dataset.GetRootAsDataset(buf, 0)
    num_nodes = dataset.NodesLength()

    # Pre-allocate outputs; rows stay zero wherever nothing is written.
    contexts = np.zeros((num_nodes, CONTEXT_LEN), dtype=np.int32)
    levels = np.zeros((num_nodes,), dtype=np.int32)
    octants = np.zeros((num_nodes,), dtype=np.int32)
    labels = np.zeros((num_nodes,), dtype=np.int32)

    mid = CONTEXT_LEN // 2
    needed = CONTEXT_LEN - 1  # neighbor slots = every slot except the parent

    for i in range(num_nodes):
        node = dataset.Nodes(i)

        # Label
        labels[i] = node.Occupancy()

        # Features
        levels[i] = max(0, node.Level() - 1)
        octants[i] = node.Octant()

        # Context: the parent always occupies the centre slot.
        contexts[i, mid] = node.ParentOccupancy()

        neighbors = _neighbor_occupancies(node)
        if len(neighbors) >= needed:
            contexts[i, :mid] = neighbors[:mid]
            # Stop at `needed`: an unbounded neighbors[mid:] does not fit
            # the remaining slots when the node has extra neighbors and
            # would raise a broadcast ValueError.
            contexts[i, mid + 1:] = neighbors[mid:needed]

    return contexts, levels, octants, labels

In [None]:
# Prepare Data Split
fs = gcsfs.GCSFileSystem()
all_files = fs.glob(GCS_DATA_PATTERN)
print(f"Found {len(all_files)} files in GCS.")

# Fail fast: with 0 or 1 files the 90/10 split below leaves the training set
# empty, which would otherwise only surface later as an obscure error inside
# model.fit().
if len(all_files) < 2:
    raise ValueError(
        f"Need at least 2 .fb files matching {GCS_DATA_PATTERN} to build "
        f"training and validation splits, found {len(all_files)}."
    )

# Shuffle at file level so the validation files are a random subset.
np.random.shuffle(all_files)

# The split is over whole files, not individual nodes. With two or more
# files and VAL_SPLIT = 0.1 both sides are guaranteed to be non-empty.
VAL_SPLIT = 0.1
split_idx = int(len(all_files) * (1 - VAL_SPLIT))

train_files = all_files[:split_idx]
val_files = all_files[split_idx:]

print(f"Training Files: {len(train_files)}")
print(f"Validation Files: {len(val_files)}")

def data_generator(file_list):
    """
    Yield one ``((contexts, levels, octants), labels)`` item per readable file.

    The generator works on a shuffled copy of ``file_list`` (the caller's
    list is left untouched), so every fresh invocation visits the files in a
    new random order. A file that cannot be read or parsed is reported with
    ``print`` and skipped; it never aborts the iteration.
    """
    shuffled_paths = list(file_list)
    np.random.shuffle(shuffled_paths)

    gcs = gcsfs.GCSFileSystem()

    for gcs_path in shuffled_paths:
        try:
            # Pull the whole file into memory, then decode it in one go.
            with gcs.open(gcs_path, 'rb') as handle:
                raw_bytes = handle.read()

            parsed = parse_flatbuffer_bytes(raw_bytes)
            contexts, levels, octants, labels = parsed
            yield (contexts, levels, octants), labels

        except Exception as err:
            # Best effort: log the failure and move on to the next file.
            print(f"Error processing {gcs_path}: {err}")

def create_dataset(file_list, batch_size, is_training=True, repeat=None,
                   shuffle_buffer=100000):
    """
    Build the tf.data pipeline that turns .fb files into model-ready batches.

    Stages: data_generator (one item per file) -> optional repeat -> split
    into per-node samples -> optional shuffle -> batch -> stack the three
    features into one tensor -> prefetch.

    Args:
        file_list: Paths of the FlatBuffer files, passed to data_generator.
        batch_size: Number of samples per batch.
        is_training: If True, shuffle samples and drop the last partial batch.
        repeat: Whether to loop over the files indefinitely. ``None``
            (default) means "repeat exactly when ``is_training`` is True".
        shuffle_buffer: Size of the sample-level shuffle buffer used when
            ``is_training`` is True.

    Returns:
        A ``tf.data.Dataset`` of ``(x, label)`` pairs. ``x`` is int32 with
        shape ``[batch, CONTEXT_LEN, 3]`` (context code, level, octant on the
        last axis); ``label`` is int32 with shape ``[batch]``.
    """
    if repeat is None:
        repeat = is_training

    # One element per file: ((contexts, levels, octants), labels), each with
    # a leading per-node dimension of unknown length.
    per_file_signature = (
        (tf.TensorSpec(shape=(None, CONTEXT_LEN), dtype=tf.int32),
         tf.TensorSpec(shape=(None,), dtype=tf.int32),
         tf.TensorSpec(shape=(None,), dtype=tf.int32)),
        tf.TensorSpec(shape=(None,), dtype=tf.int32),
    )

    ds = tf.data.Dataset.from_generator(
        lambda: data_generator(file_list),
        output_signature=per_file_signature,
    )

    if repeat:
        # Training is driven by a fixed number of steps per epoch, so the
        # stream must not run dry after a single pass over the files. Each
        # repetition re-invokes the generator, which reshuffles the file
        # order for that pass.
        ds = ds.repeat()

    def unbatch_file(inputs, labels):
        # Turn one per-file element into a dataset of per-node samples.
        return tf.data.Dataset.from_tensor_slices((inputs, labels))

    ds = ds.flat_map(unbatch_file)

    if is_training:
        ds = ds.shuffle(buffer_size=shuffle_buffer)

    # Partial batches are dropped only for training; validation keeps them.
    ds = ds.batch(batch_size, drop_remainder=is_training)

    def format_inputs(inputs, label):
        ctx, lvl, octant = inputs

        # Broadcast the per-node scalars along the sequence axis so all three
        # features share the shape [Batch, Seq] before stacking.
        lvl_seq = tf.repeat(tf.expand_dims(lvl, -1), CONTEXT_LEN, axis=1)
        oct_seq = tf.repeat(tf.expand_dims(octant, -1), CONTEXT_LEN, axis=1)

        x = tf.stack([ctx, lvl_seq, oct_seq], axis=-1)  # [Batch, Seq, 3]
        return x, label

    ds = ds.map(format_inputs, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.prefetch(tf.data.AUTOTUNE)

    return ds

## 4. Model Creation & Training

In [None]:
# Create Datasets
# Both pipelines are batched with the global (all-replica) batch size.
train_ds = create_dataset(train_files, GLOBAL_BATCH_SIZE, is_training=True)
val_ds = create_dataset(val_files, GLOBAL_BATCH_SIZE, is_training=False)

# Import model definition from repo
from tf_model import create_model

# Build and compile inside the strategy scope so the model and optimizer
# variables are created under the selected distribution strategy.
with strategy.scope():
    model = create_model()
    optimizer = keras.optimizers.Adam(learning_rate=INITIAL_LR)
    # NOTE(review): from_logits=True assumes create_model() ends in raw
    # logits (no softmax) — confirm against tf_model.
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    
    model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])
    model.summary()

# Checkpoint Callback
# The doubled braces survive the f-string as a literal "{epoch:02d}"
# placeholder in the checkpoint path, giving one weights file per epoch.
# NOTE(review): this writes .weights.h5 directly to a gs:// path — confirm
# the installed Keras version supports HDF5 writes to GCS.
checkpoint_path = f"{GCS_CHECKPOINT_DIR}/checkpoints_model_epoch_{{epoch:02d}}.weights.h5"
cp_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=1
)

# TensorBoard
# Logs (with weight histograms every epoch) are written to the GCS log dir.
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=GCS_LOG_DIR, histogram_freq=1)

# Train
# Epoch length is a fixed number of batches rather than a full pass over the
# data, so train_ds must be able to supply steps_per_epoch * EPOCHS batches.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    steps_per_epoch=200, # Training batches drawn per epoch (fixed count, not derived from data size)
    validation_steps=20, # Validation batches evaluated after each epoch
    callbacks=[cp_callback, tensorboard_callback]
)