# 🪰 Bio-Fluid Surrogate: LBM + JAX Online Training

This notebook sets up a Google Colab GPU environment for online training of the Fly Surrogate framework.

**Hardware Note:** This notebook requires a **GPU runtime** (T4, L4, or A100). 
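
Before running the heavier cells, it can help to fail fast when no GPU is attached. The sketch below assumes the runtime exposes NVIDIA's `nvidia-smi` tool, as Colab GPU runtimes do. The helper names `gpu_names` and `require_gpu` are hypothetical and are not part of `fly_surrogate`.

```python
import shutil
import subprocess


def gpu_names():
    """Return the GPU names reported by nvidia-smi, or [] if none are visible."""
    if shutil.which("nvidia-smi") is None:
        return []  # no NVIDIA driver tooling on this runtime (e.g. a CPU-only runtime)
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return []  # tool exists but could not talk to a GPU
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


def require_gpu():
    """Raise early with a readable message instead of failing deep inside Taichi or JAX."""
    names = gpu_names()
    if not names:
        raise RuntimeError("No GPU found: select a GPU under Runtime > Change runtime type.")
    return names
```

Calling `require_gpu()` in the first code cell returns the list of GPU names on a GPU runtime and raises on a CPU-only one.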

**Pipeline Overview:**
* **Drive Integration:** Mounts Google Drive to persist training checkpoints and `.gif` visualizations.
* **Memory Management:** Stops JAX/XLA from preallocating most of the GPU memory so that JAX and Taichi can share VRAM.
* **LBM Engine:** Runs the Taichi-accelerated lattice Boltzmann (LBM) fluid solver to generate ground-truth data.
* **JAX Training:** Updates the ResNet surrogate model online as the solver produces new data.
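
The online pattern in this list works as follows: the solver produces fresh ground truth, the surrogate takes a gradient step on it, and no fixed dataset is stored. A deliberately tiny, pure-Python toy can illustrate it. This toy is not `train_surrogate_jax.py`. `toy_solver` is a hypothetical stand-in for the LBM solver, and a two-parameter linear model replaces the ResNet so that the sketch runs without a GPU or JAX.

```python
import random


def toy_solver(x):
    """Hypothetical stand-in for the LBM solver: returns the 'ground truth' for input x."""
    return 3.0 * x - 2.0


def train_online(steps=2000, lr=0.05, seed=0):
    """Online SGD: every step draws a brand-new sample from the solver, then updates once."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0  # toy linear surrogate: y_hat = w * x + b
    for _ in range(steps):
        x = rng.uniform(-1.0, 1.0)  # new input generated on the fly
        y = toy_solver(x)  # ground truth from the "solver"
        err = (w * x + b) - y  # prediction error of the surrogate
        # Gradient of 0.5 * err**2 with respect to w and b
        w -= lr * err * x
        b -= lr * err
    return w, b
```

With noise-free toy data, `train_online()` converges to approximately `w = 3`, `b = -2`. The loop structure is the same one the real solver/model pair would follow.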

In [None]:
# 1. MOUNT GOOGLE DRIVE
from google.colab import drive
import os

print("--> Mounting Google Drive...")
drive.mount('/content/drive')

# 2. SETUP PROJECT DIRECTORY
project_root = '/content/drive/MyDrive/FlySurrogate_Dev'

if not os.path.exists(project_root):
    print(f"--> Creating project folder at {project_root}...")
    os.makedirs(project_root)
else:
    print(f"--> Found existing project folder at {project_root}")

os.chdir(project_root)
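
`os.makedirs` will create `/content/drive/MyDrive/...` on the VM's local disk if nothing is mounted there. A defensive check can confirm that the project folder really lives on the Drive mount before checkpoints are written to it. The sketch below assumes Drive is mounted at `/content/drive`, as in the cell above. `ensure_project_dir` is a hypothetical helper.

```python
import os


def ensure_project_dir(project_root, mount_point="/content/drive"):
    """Create project_root only if it sits under an active mount point (e.g. Google Drive)."""
    if not os.path.ismount(mount_point):
        raise RuntimeError(f"{mount_point} is not mounted; refusing to write to local disk.")
    root = os.path.abspath(project_root)
    mount = os.path.abspath(mount_point)
    # Make sure the project folder is actually inside the mounted tree
    if os.path.commonpath([root, mount]) != mount:
        raise ValueError(f"{project_root} is not under {mount_point}.")
    os.makedirs(root, exist_ok=True)  # no error if the folder already exists
    return root
```

In the notebook, `os.chdir(ensure_project_dir(project_root))` could stand in for the `if/else` block.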

In [None]:
# 3. CLONE OR UPDATE REPOSITORY
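repo_name = "fly_surrogate"
repo_url = "https://github.com/lhooz/fly_surrogate.git"
# Use an absolute path so this cell behaves the same whatever the current directory is
repo_path = os.path.join(project_root, repo_name)

if not os.path.exists(repo_path):
    print(f"--> Cloning {repo_name}...")
    !git clone {repo_url} {repo_path}
else:
    print("--> Repository exists. Pulling latest changes...")
    !git -C {repo_path} pull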
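
The clone-or-pull decision can also be written as a plain Python function. This form is useful outside IPython, where `!` shell commands and `%` magics are unavailable. The sketch assumes that an existing checkout is identified by its `.git` folder. Checking for `.git` rather than for the folder alone avoids treating an unrelated directory with the same name as a repository. `sync_command` and `sync_repo` are hypothetical names.

```python
import os
import subprocess


def sync_command(repo_url, repo_path):
    """Return the git command that clones repo_url or updates an existing checkout."""
    if os.path.isdir(os.path.join(repo_path, ".git")):
        # Existing checkout: fast-forward only, failing instead of creating a merge commit
        return ["git", "-C", repo_path, "pull", "--ff-only"]
    # No checkout yet: clone into the exact target path
    return ["git", "clone", repo_url, repo_path]


def sync_repo(repo_url, repo_path):
    """Run the command from sync_command, raising CalledProcessError on failure."""
    subprocess.run(sync_command(repo_url, repo_path), check=True)
```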

In [None]:
# 4. INSTALL DEPENDENCIES
repo_path = os.path.join(project_root, repo_name)
os.chdir(repo_path)

print("--> Installing dependencies from pyproject.toml...")
!pip install -e .
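
After `pip install -e .`, a quick lookup can confirm that the core libraries resolved. The sketch assumes the top-level module names are `jax` and `taichi`. These names are taken from the pipeline description, not from `pyproject.toml`. `missing_modules` is a hypothetical helper.

```python
import importlib.util


def missing_modules(names):
    """Return the subset of module names that cannot be found on the current sys.path."""
    missing = []
    for name in names:
        try:
            found = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            found = False  # raised when a parent package of a dotted name is missing
        if not found:
            missing.append(name)
    return missing
```

On a successful install, `missing_modules(["jax", "taichi"])` would be expected to return `[]`.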

In [None]:
# 5. CONFIGURE GPU MEMORY & LAUNCH TRAINING
import os

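# Stop JAX/XLA from preallocating most of the VRAM (75% by default) and cap its
# memory pool at 45%, so Taichi has room. These variables are inherited by the
# `!python` subprocess below, so they apply even though this kernel never imports JAX.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".45"

print("--> Starting Online Training Loop...")
print("--> Note: Visualizations will be saved to the 'checkpoints' folder.")

# Run the training script directly (this cell blocks until training finishes)
!python train_surrogate_jax.py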
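
Setting `os.environ` changes the environment of the notebook kernel itself. The alternative sketch below passes the XLA settings only to the training process through `subprocess.run(..., env=...)`, so the kernel's environment stays untouched. `training_env` and `launch_training` are hypothetical helpers. The 0.45 default mirrors the cell above.

```python
import os
import subprocess
import sys


def training_env(mem_fraction=0.45, base=None):
    """Return a copy of the environment with JAX/XLA GPU memory settings applied."""
    if not 0.0 < mem_fraction < 1.0:
        raise ValueError("mem_fraction must be between 0 and 1")
    env = dict(os.environ if base is None else base)  # copy; never mutate the original
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand, not up front
    env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"  # leave the rest for Taichi
    return env


def launch_training(script="train_surrogate_jax.py", mem_fraction=0.45):
    """Run the training script with the memory settings scoped to that process only."""
    subprocess.run([sys.executable, script], env=training_env(mem_fraction), check=True)
```

As a rough guide, a fraction of 0.45 on a 16 GB T4 is about 7 GB for JAX.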

In [None]:
# 6. DISPLAY LATEST VISUALIZATION
import glob
import os
from IPython.display import Image, display

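# Absolute path, so the cell works from any working directory
viz_files = glob.glob(os.path.join(repo_path, "checkpoints", "viz_cycle_*.gif"))
if viz_files:
    # Most recently written GIF (modification time, not inode change time)
    latest_viz = max(viz_files, key=os.path.getmtime)
    print(f"--> Displaying latest simulation: {latest_viz}")
    display(Image(filename=latest_viz, width=600))
else:
    print("No visualizations found yet. Ensure training has reached at least Cycle 0.")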
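
File timestamps record when a file was last written or copied, not which training cycle it belongs to. Plain string sorting also has a pitfall: it places `viz_cycle_10.gif` before `viz_cycle_9.gif`. The sketch below instead picks the GIF with the highest cycle number. It assumes the part after `viz_cycle_` is a bare integer, which is an assumption because the naming used by `train_surrogate_jax.py` may differ. `latest_by_cycle` is a hypothetical helper.

```python
import os
import re

# Matches names such as "viz_cycle_12.gif" and captures the cycle number
_CYCLE_RE = re.compile(r"viz_cycle_(\d+)\.gif$")


def latest_by_cycle(paths):
    """Return the path with the largest cycle number, or None if no path matches."""
    best_path, best_cycle = None, -1
    for path in paths:
        match = _CYCLE_RE.search(os.path.basename(path))
        if match and int(match.group(1)) > best_cycle:
            best_path, best_cycle = path, int(match.group(1))
    return best_path
```

In the display cell, `latest_viz = latest_by_cycle(viz_files)` would replace the `max(...)` call.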