# ABOC GPU Benchmark on Google Colab

This notebook runs the Encoder and Parallel Decoder on a GPU, using **pre-generated FlatBuffer data** from Google Cloud Storage (GCS).

## Prerequisites
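1.  **Google Drive**: Upload the entire `ABOC` project folder to your Google Drive. The notebook expects it at `MyDrive/ABOC`; if you put it elsewhere, update `PROJECT_PATH` in step 2.
2.  **GCS Bucket**: Ensure your Google account has read access to `gs://mtn_fb_file_bucket` (the bucket name and the data/checkpoint prefixes are read from `config.py`).
3.  **GPU Runtime**: In Colab, select *Runtime → Change runtime type* and choose **T4 GPU** (or better).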

In [None]:
# 1. Mount Google Drive (to access the code)
# Mounts Drive at /content/drive; PROJECT_PATH in the next cell points inside
# this mount, so this cell must run first.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# 2. Setup Environment
import os
import sys

# Change this path to where you uploaded the ABOC folder
PROJECT_PATH = '/content/drive/MyDrive/ABOC' 
if not os.path.exists(PROJECT_PATH):
    print(f"ERROR: Path {PROJECT_PATH} does not exist. Please check your Drive structure.")
else:
    os.chdir(PROJECT_PATH)
    sys.path.append(PROJECT_PATH)
    print(f"Working Directory set to: {os.getcwd()}")

# Install dependencies
# Added hdf5storage to fix ImportErrors in legacy modules
# Pinning numpy<2 to avoid PyTorch binary incompatibility (ValueError: METH_CLASS)
!pip install "numpy<2" flatbuffers plyfile Ninja hdf5storage
!sudo apt-get install ninja-build

In [None]:
# 3. Authenticate with Google Cloud (for GCS Access)
# Needed so the `gsutil ls` / `gsutil cp` calls in later cells can reach the
# bucket configured in config.py.
from google.colab import auth
auth.authenticate_user()

# Set project ID if needed (optional for public/accessible buckets)
# !gcloud config set project YOUR_PROJECT_ID

In [None]:
# 4. JIT Compile the C++ Backend (Force Rebuild to ensure GPU/Colab compat)
import torch
# NOTE(review): `load` is not used in this cell; presumably kept for manual
# debugging — confirm before removing.
from torch.utils.cpp_extension import load

print("Compiling numpyAc C++ backend...")
try:
    # Remove old build artifacts if they exist to force clean build
    # NOTE(review): this is an IPython shell escape, so it never raises into the
    # except below. If numpyAc builds through torch.utils.cpp_extension, its
    # build cache may live elsewhere (e.g. TORCH_EXTENSIONS_DIR) and this rm
    # may not force a rebuild — confirm against numpyAc.
    !rm -rf ./numpyAc/backend/build
    
    # Load/Compile
    # This will compile numpyAc_backend.cpp and bind it
    # The build happens as an import side effect. Python caches imported
    # modules, so re-running this cell in the same runtime does not rebuild.
    import numpyAc.numpyAc 
    print("Success: numpyAc module loaded.")
except Exception as e:
    # Failure is only printed; execution of the notebook continues.
    print(f"Compilation Error: {e}")
    # Fallback compilation command if import fails (debug)
    # !cd numpyAc/backend && python3 setup.py install

In [None]:
# 5. Configuration (Loaded from config.py)
import config

# GCS locations are defined once in the project's config.py.
BUCKET_NAME, FB_PREFIX, CKPT_PREFIX = (
    config.BUCKET_NAME,
    config.GCS_DATA_PREFIX,
    config.GCS_CHECKPOINT_PREFIX,
)

# Sampling Configuration: how many random FlatBuffer files to benchmark.
# Statistical Significance Guide (Pop = 48,000):
# - 381 samples: 95% Confidence, 5% Margin of Error (Gold Standard)
# - 100 samples: 95% Confidence, ~10% Margin of Error (Good Engineering Trade-off)
# - 50 samples: Quick Smoke Test
NUM_SAMPLES = 50

# Scratch directory outside the Drive mount; "decoded" receives the .ply outputs.
TEMP_DIR = "/content/tmp_data"
for scratch_dir in (TEMP_DIR, os.path.join(TEMP_DIR, "decoded")):
    os.makedirs(scratch_dir, exist_ok=True)

import tensorflow as tf
import glob

visible_gpus = tf.config.list_physical_devices('GPU')
print(f"TensorFlow Device: {visible_gpus}")
print(f"Using Bucket: {BUCKET_NAME}")

# -----------------------------------------------------
# Fetch Best Checkpoint from GCS
# -----------------------------------------------------
print(f"Listing checkpoints in gs://{BUCKET_NAME}/{CKPT_PREFIX}...")
gcs_ckpts = !gsutil ls gs://{BUCKET_NAME}/{CKPT_PREFIX}/*.weights.h5
ckpt_candidates = [c for c in gcs_ckpts if c.endswith('.weights.h5')]

MODEL_PATH = None

if ckpt_candidates:
    ckpt_candidates.sort()
    best_ckpt_gcs = ckpt_candidates[-1]
    ckpt_filename = os.path.basename(best_ckpt_gcs)
    MODEL_PATH = f"{TEMP_DIR}/{ckpt_filename}"
    
    print(f"Downloading checkpoint: {best_ckpt_gcs} -> {MODEL_PATH}")
    !gsutil cp {best_ckpt_gcs} {MODEL_PATH}
else:
    print("WARNING: No checkpoints found in GCS! Falling back to Drive default if available.")
    default_drive_path = os.path.join(PROJECT_PATH, 'modelsave/lidar/checkpoints_model_epoch_50.weights.h5')
    if os.path.exists(default_drive_path):
        MODEL_PATH = default_drive_path
        print(f"Using Drive Checkpoint: {MODEL_PATH}")
    else:
        print("ERROR: No checkpoints found anywhere. Inference will use Random Weights.")

# -----------------------------------------------------

from encoder_tf import EncoderTF
from decoder_tf_parallel import DecoderTFParallel

# Initialize Models. When MODEL_PATH points at an existing file (GCS download
# or Drive fallback), both models receive it as model_path; otherwise they are
# constructed without it (random weights).
have_weights = bool(MODEL_PATH) and os.path.exists(MODEL_PATH)
if not have_weights:
    print("Initializing with RANDOM WEIGHTS (Benchmark is valid for speed, invalid for BPP).")
weight_kwargs = {"model_path": MODEL_PATH} if have_weights else {}
encoder = EncoderTF(f"{TEMP_DIR}/encoder_log.txt", **weight_kwargs)
decoder = DecoderTFParallel(**weight_kwargs)

In [None]:
# 6. Benchmark Loop (With Averaging & Accuracy Check)
import time
import subprocess
import glob
import random
import shutil
import numpy as np
import OctreeData.Dataset as Dataset

def count_points_in_fb(fb_path):
    """Return the number of octree nodes stored in a FlatBuffer dataset file.

    Despite the name, this counts nodes (``Dataset.NodesLength()``), which the
    benchmark uses as an input-complexity proxy. Any failure (missing file,
    unparsable buffer, ...) yields 0 instead of raising.
    """
    try:
        with open(fb_path, 'rb') as fb_file:
            raw_bytes = fb_file.read()
        root = Dataset.Dataset.GetRootAsDataset(raw_bytes, 0)
        return root.NodesLength()
    except Exception:
        return 0

def count_points_in_ply(ply_path):
    """Return the vertex count declared in a PLY file's header, or 0.

    Only the header is read. The file is opened in binary mode because PLY
    bodies are often binary: reading them through a text decoder can raise
    UnicodeDecodeError on the first buffered chunk, before the
    ``element vertex`` line is even reached.

    Args:
        ply_path: Path to a ``.ply`` file (ASCII or binary body).

    Returns:
        The integer from the ``element vertex N`` header line, or 0 when the
        file is missing/unreadable, the header has no vertex element, or the
        count is malformed.
    """
    try:
        with open(ply_path, 'rb') as f:
            for raw_line in f:
                # Header lines are ASCII; 'replace' keeps stray bytes harmless.
                line = raw_line.decode('ascii', errors='replace').strip()
                if line.startswith('element vertex'):
                    return int(line.split()[-1])
                if line == 'end_header':
                    # Header is over; don't scan the (possibly large) body.
                    break
    except (OSError, ValueError):
        return 0
    return 0

# Metrics Storage
# One independent list per metric; benchmark_gcs_file appends one value to
# each list for every successfully processed file.
stats = {
    metric: []
    for metric in ('download', 'encode', 'decode', 'bpp', 'input_nodes', 'output_points')
}

def benchmark_gcs_file(gcs_path):
    """Download, encode and decode one FlatBuffer file, recording its metrics.

    On success one value is appended to every list in the global ``stats``
    dict. If the download fails or the encoder output is missing, the file is
    skipped and nothing is recorded. Local temporary files are removed on every
    exit path, including when the encoder or decoder raises.

    Args:
        gcs_path: ``gs://`` URI of a ``.fb`` file (as listed by ``gsutil ls``).
    """
    filename = os.path.basename(gcs_path)
    local_fb = f"{TEMP_DIR}/{filename}"
    local_encoded = f"{TEMP_DIR}/{filename}.enc.bin"
    local_decoded = f"{TEMP_DIR}/decoded/{filename}.ply"

    print(f"--- Processing {filename} ---")

    try:
        # A. Download (.fb). subprocess with an argument list (no shell) lets us
        # detect a failed copy instead of benchmarking a missing file.
        t0 = time.time()
        result = subprocess.run(
            ["gsutil", "cp", gcs_path, local_fb],
            capture_output=True,
            text=True,
        )
        download_time = time.time() - t0
        if result.returncode != 0:
            print(f"Error: download failed ({result.stderr.strip()})")
            return

        # Get Input Stats
        input_node_count = count_points_in_fb(local_fb)

        # B. Encode (GPU)
        # Note: EncoderTF.compress returns bpp=0.0 for FB inputs.
        t1 = time.time()
        _, size_bytes = encoder.compress(local_fb)
        # Stop the clock before moving the output: ./Exp is relative to the
        # working directory (on the Drive mount) while TEMP_DIR is not, so the
        # move may be a slow cross-filesystem copy that is not encoder work.
        enc_time = time.time() - t1

        basename = os.path.splitext(filename)[0]
        default_out = f"./Exp/Kitti_TF/data/{basename}.bin"
        if not os.path.exists(default_out):
            print("Error: Encoder output not found.")
            return
        shutil.move(default_out, local_encoded)

        # C. Decode (GPU Parallel)
        t2 = time.time()
        decoder.decode(local_encoded, local_decoded)
        dec_time = time.time() - t2

        # Accuracy Check
        out_pt_count = count_points_in_ply(local_decoded)

        # BPP Calculation (Bits per Output Point); max() guards an empty decode.
        bpp_custom = (size_bytes * 8.0) / max(1, out_pt_count)

        print(f"  DL: {download_time:.3f}s | Enc: {enc_time:.3f}s | Dec: {dec_time:.3f}s")
        print(f"  Size: {size_bytes} B | BPP (est): {bpp_custom:.2f}")
        print(f"  Accuracy Check: InputNodes={input_node_count}, OutputPoints={out_pt_count}")

        # Store Stats
        stats['download'].append(download_time)
        stats['encode'].append(enc_time)
        stats['decode'].append(dec_time)
        stats['bpp'].append(bpp_custom)
        stats['input_nodes'].append(input_node_count)
        stats['output_points'].append(out_pt_count)
    finally:
        # Cleanup on every exit path so skipped/failed files don't pile up.
        for path in (local_fb, local_encoded, local_decoded):
            if os.path.exists(path):
                os.remove(path)

# List files in Bucket (FlatBuffers)
print("Listing .fb files in GCS...")
# IPython shell capture: each output line of `gsutil ls` becomes a list entry.
gcs_list = !gsutil ls gs://{BUCKET_NAME}/{FB_PREFIX}/*.fb
# Keep only entries that look like .fb URIs.
all_files = [f for f in gcs_list if f.endswith('.fb')]
print(f"Found {len(all_files)} FlatBuffer files.")

# Sample N Files (capped at the number available, so random.sample cannot raise)
num_to_sample = min(len(all_files), NUM_SAMPLES)
print(f"Benchmarking {num_to_sample} random files...")
# NOTE(review): `random` is not seeded, so each run benchmarks a different
# subset; call random.seed(...) beforehand if runs must be comparable.
samples = random.sample(all_files, num_to_sample)

# Files are processed sequentially; each call appends its metrics to `stats`.
for f in samples:
    benchmark_gcs_file(f)

# -----------------------------------------------------
# FINAL SUMMARY
# -----------------------------------------------------
def _format_rate(seconds):
    """Render a per-file time as a rate in Hz, guarding against a zero time."""
    return f"{1.0 / seconds:.2f} Hz" if seconds > 0 else "n/a"


if stats['encode']:
    n_files = len(stats['encode'])
    avg_dl = np.mean(stats['download'])
    avg_enc = np.mean(stats['encode'])
    avg_dec = np.mean(stats['decode'])
    avg_bpp = np.mean(stats['bpp'])
    # Medians are shown next to the means: the first file may carry one-off
    # warm-up cost (e.g. GPU / graph initialization) that skews a mean taken
    # over only a few dozen samples.
    med_enc = np.median(stats['encode'])
    med_dec = np.median(stats['decode'])
    # A decode that yielded 0 points is recorded with bpp = size_bytes * 8
    # (the max(1, ...) guard), which inflates avg_bpp; report how often.
    n_empty = sum(1 for c in stats['output_points'] if c == 0)

    print("\n==========================================")
    print("           BENCHMARK RESULTS              ")
    print("==========================================")
    print(f"Files Processed:   {n_files}")
    print(f"Avg Download Time: {avg_dl:.4f} s")
    print(f"Avg Encode Time:   {avg_enc:.4f} s  ({_format_rate(avg_enc)})  median {med_enc:.4f} s")
    print(f"Avg Decode Time:   {avg_dec:.4f} s  ({_format_rate(avg_dec)})  median {med_dec:.4f} s")
    print(f"Avg Compressed BPP:{avg_bpp:.4f} (Bits per Point)")
    if n_empty:
        print(f"WARNING: {n_empty} file(s) decoded to 0 points; their BPP values are not meaningful.")
    print("==========================================")

    # Accuracy validation: show input node count next to the reconstructed
    # point count for the first few files.
    print("Accuracy Validation (Sample 5):")
    for i, (nodes, count) in enumerate(
            zip(stats['input_nodes'][:5], stats['output_points'][:5]), start=1):
        print(f"  File {i}: {count} points reconstructed (input nodes: {nodes}).")
else:
    print("No files processed.")