# Training a model

This notebook demonstrates how to train a model on a pre-collected dataset.

In [1]:
import os

# These environment variables control where
# training and eval logs are written.
os.environ["RUN_DIR"] = "runs"
os.environ["EVAL_RUN_DIR"] = "eval_runs"

# This is used to set a constant Tensorboard port.
os.environ["TENSORBOARD_PORT"] = str(8989)

In [2]:
import ml.api as ml  # Source: https://github.com/codekansas/ml-starter

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Configures logging, with output that plays well with tqdm progress bars.
ml.configure_logging(use_tqdm=True)

The framework used to train the models specifies five parts:

1. Model: The USA net model follows the vanilla NeRF implementation and uses a simple MLP that maps 3D points to an output vector.
2. Task: This coordinates training by passing dataset samples to the model and computing the loss function.
3. Optimizer
4. Learning rate scheduler
5. Trainer: This runs the training loop; its configuration also sets the checkpointing and validation intervals.
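
To make the division of labor between the five parts concrete, here is a minimal pure-Python sketch of the same pattern on a toy regression problem. It is not the `ml` framework's API: `ToyModel`, `ToyTask`, `ToySGD`, `linear_decay` and `toy_train` are hypothetical names invented for illustration, and the linear decay to zero is only an assumed shape for a "linear" schedule.

```python
from dataclasses import dataclass


@dataclass
class ToyModel:
    """Model: a single scalar weight, y = w * x."""

    w: float = 0.0

    def forward(self, x: float) -> float:
        return self.w * x


class ToyTask:
    """Task: passes a sample to the model and computes the loss and its gradient."""

    def loss_and_grad(self, model: ToyModel, sample: tuple) -> tuple:
        x, y = sample
        err = model.forward(x) - y
        return err * err, 2.0 * err * x  # squared error and d(loss)/dw


class ToySGD:
    """Optimizer: plain gradient descent on the model's weight."""

    def step(self, model: ToyModel, grad: float, lr: float) -> None:
        model.w -= lr * grad


def linear_decay(step: int, max_steps: int, base_lr: float) -> float:
    """LR scheduler: decays linearly from base_lr towards 0 (assumed shape)."""
    return base_lr * (1.0 - step / max_steps)


def toy_train(model, task, optimizer, samples, max_steps, base_lr):
    """Trainer: owns the loop and ties the other four parts together."""
    losses = []
    for step in range(max_steps):
        sample = samples[step % len(samples)]
        loss, grad = task.loss_and_grad(model, sample)
        optimizer.step(model, grad, linear_decay(step, max_steps, base_lr))
        losses.append(loss)
    return losses


samples = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # generated from y = 2x
toy_model = ToyModel()
toy_losses = toy_train(toy_model, ToyTask(), ToySGD(), samples, max_steps=200, base_lr=0.05)
```

The trainer never looks inside the model or the loss; it only sequences the calls, which is why each part can be swapped independently through configuration.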

In [4]:
config = {
    "model": {
        "name": "point2emb",          # `register_model` name in `usa.models.point2emb`
        "num_layers": 4,
        "hidden_dims": 256,
        "output_dims": 513,           # CLIP = 512, SDF = 1
    },
    "task": {
        "name": "clip_sdf",           # `register_task` name in `usa.tasks.clip_sdf`
        "dataset": "lab_r3d",         # Pre-collected dataset
        "clip_model": "ViT_B_16",
        "queries": [
            "Chair",
            "Shelves",
            "Man sitting at a computer",
            "Desktop computers",
            "Wooden box",
            "Doorway",
        ],
        "rotate_image": True,         # Dataset-specific, for visualization purposes
        "finished": {
            "max_steps": 10_000,      # Number of training steps
        },
        "dataloader": {
            "train": {
                "batch_size": 16,
                "num_workers": 0,
                "persistent_workers": False,
            },
        },
    },
    "optimizer": {
        "name": "adam",
        "lr": 3e-4,
    },
    "lr_scheduler": {
        "name": "linear",
    },
    "trainer": {
        "name": "vanilla_sl",
        "exp_name": "jupyter",
        "log_dir_name": "test",
        "base_run_dir": "runs",
        "run_id": 0,
        "checkpoint": {
            "save_every_n_steps": 2500,
            "only_save_most_recent": True,
        },
        "validation": {
            "valid_every_n_steps": 250,
            "num_init_valid_steps": 1,
        },
    },
    "logger": [{"name": "tensorboard"}],
}
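
The `output_dims` comment implies that each 513-dimensional output packs a 512-dimensional CLIP embedding together with one SDF value. A rough sketch of unpacking such a vector is shown below. The helper `split_output` is hypothetical, and the channel order (CLIP first, SDF last) is an assumption; the actual `point2emb` model may lay out its output differently.

```python
CLIP_DIMS = 512  # embedding width noted in the config comment
SDF_DIMS = 1     # one signed-distance value per 3D point


def split_output(vec):
    """Splits one output vector into (clip_embedding, sdf_value).

    Assumes the CLIP channels come first; this ordering is a guess.
    """
    expected = CLIP_DIMS + SDF_DIMS
    if len(vec) != expected:
        raise ValueError(f"expected {expected} values, got {len(vec)}")
    return vec[:CLIP_DIMS], vec[CLIP_DIMS]


# A dummy output vector standing in for a real model prediction.
dummy_output = [0.0] * CLIP_DIMS + [0.25]
clip_emb, sdf_value = split_output(dummy_output)
```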

In [5]:
objs = ml.instantiate_config(config)

# Unpacking the different components.
model = objs.model
task = objs.task
optimizer = objs.optimizer
lr_scheduler = objs.lr_scheduler
trainer = objs.trainer

  INFO   2023-04-17 12:08:23 [ml.trainers.base] Experiment directory: /private/home/bbolte/Github/usa-net/notebooks/runs/jupyter/run_0
INFOALL  2023-04-17 12:08:23 [ml.utils.device.auto] Device: [cuda:0]
  INFO   2023-04-17 12:08:23 [ml.loggers.tensorboard] Tensorboard command: tensorboard serve --logdir /private/home/bbolte/Github/usa-net/notebooks/runs/jupyter/run_0/test/tensorboard/12-08-21 --bind_all --port 8989 --reload_interval 15
  INFO   2023-04-17 12:08:27 [ml.loggers.tensorboard] Running TensorBoard process:
-------------------------------------------------------------------
TensorBoard 2.12.2 at http://localhost:8989/ (Press CTRL+C to quit)
-------------------------------------------------------------------
  INFO   2023-04-17 12:08:27 [ml.core.registry] Components:
 ↪ Model: usa.models.point2emb.Point2EmbModel (/private/home/bbolte/Github/usa-net/usa/mo

In [6]:
from tensorboard import notebook

# Show Tensorboard inside the notebook.
notebook.display(port=int(os.environ['TENSORBOARD_PORT']))

# Runs the training loop.
trainer.train(model, task, optimizer, lr_scheduler)

Selecting TensorBoard with logdir /private/home/bbolte/Github/usa-net/notebooks/runs/jupyter/run_0/test/tensorboard/12-08-21 (started 0:00:00 ago; port 8989, pid 1504685).

 ↪ + logger.0.log_id=12-08-21
 ↪ - logger.0.log_id=12-07-10
  INFO   2023-04-17 12:08:29 [usa.tasks.datasets.r3d] Preprocessing R3D arrays

Loading R3D file: 100%|██████████| 681/681 [00:03<00:00, 225.10it/s]

  INFO   2023-04-17 12:08:32 [ml.trainers.mixins.cpu_stats] Starting CPU stats monitor for PID 1504360 with PID 1504872
  INFO   2023-04-17 12:10:16 [ml.trainers.base] Exiting training job for /private/home/bbolte/Github/usa-net/notebooks/runs/jupyter/run_0/config.yaml


KeyboardInterrupt: 