# New contrast method

Testing out a new contrast method. 

In [None]:
from hydra import initialize, compose
from omegaconf import OmegaConf

# Overrides layered on top of the "train" config, written as the same
# key=value strings Hydra accepts on the command line.
run_overrides = [
    "model=vanilla_contrast_v3",
    "scene.data_root=/home/jackd/source/egolifter/adt_processed",
    "scene.scene_name=Apartment_release_golden_skeleton_seq100_10s_sample",
    "output_root=./output/adt",
    "exp_name=3dgs_new_contrast",
    "lift.use_contr=True",
    "wandb.project=egolifter_adt",
]

# Initialize Hydra with the directory where the config lives.
# Hydra takes care of composing all of the separate config files into `cfg`.
with initialize(config_path="conf", job_name="notebook_app"):
    cfg = compose(config_name="train", overrides=run_overrides)

In [None]:
# Render the composed config as YAML and print it to see what was loaded.
config_yaml = OmegaConf.to_yaml(cfg)
print(config_yaml)

In [None]:
# Make the output directory for this run (cfg.scene.model_path).
# exist_ok=True means re-running this cell does not fail if the directory
# already exists.
# NOTE: later cells also use `os`, so this cell must be run before them.
import os
os.makedirs(cfg.scene.model_path, exist_ok=True)

In [None]:
# Set up the Weights & Biases logger used by the trainer.
from lightning.pytorch.loggers import WandbLogger

# Make sure both the run-local "wandb" folder and the configured wandb save
# directory exist before the logger is created.
wandb_subdir = os.path.join(cfg.scene.model_path, "wandb")
os.makedirs(wandb_subdir, exist_ok=True)
os.makedirs(cfg.wandb.save_dir, exist_ok=True)

# Project, entity, run name and save location all come from the config.
wandb_kwargs = dict(
    project=cfg.wandb.project,
    entity=cfg.wandb.entity,
    name=cfg.exp_name,
    save_dir=cfg.wandb.save_dir,
)
logger = WandbLogger(**wandb_kwargs)

# Log the full config as hyperparameters. It is converted to a plain
# container with interpolations resolved (resolve=True) first.
resolved_hparams = OmegaConf.to_container(cfg, resolve=True)
logger.log_hyperparams(resolved_hparams)

In [None]:
# (OPTIONAL) Store a copy of the config alongside the run outputs.
# This keeps a record of exactly what was run; resolve=True is passed so the
# saved file contains resolved values.
saved_config_path = os.path.join(cfg.scene.model_path, "config.yaml")
OmegaConf.save(cfg, saved_config_path, resolve=True)

In [None]:
# `L` (Lightning) is also used by the trainer cell further down, and `Scene`
# is the project-local scene class, so this cell must run before those cells.
import lightning as L
from scene import Scene

# Set the seed for reproducibility.
# This is done before building the Scene so that any randomness during scene
# construction is presumably seeded as well -- TODO confirm Scene uses RNG.
L.seed_everything(cfg.seed)

# Create a new scene object from the composed config. Later cells read
# scene.scene_info and call scene.get_data_loader(...) on it.
scene = Scene(cfg)

In [None]:
from model import get_model

# Build the model from the config and the scene. This is one of our
# LightningModules (e.g., VanillaGaussian, Unc2DUnet, etc.).
model = get_model(cfg, scene)
# Print the module so its structure can be inspected in the notebook output.
print(model)

In [None]:
# Initialize the Gaussians from an initial point cloud. The point cloud is
# taken from scene.scene_info.point_cloud, which was set up by the Scene class
# above. Internally, that comes from the global_points.csv.gz file (Aria
# dataset only; other datasets initialize this differently).
initial_points = scene.scene_info.point_cloud

# NOTE: not sure what this radius is used for inside the model -- TODO confirm.
scene_radius = scene.scene_info.nerf_normalization["radius"]

# load_iteration=None is passed explicitly, i.e. no specific iteration is
# requested here.
model.init_or_load_gaussians(
    initial_points,
    scene_radius,
    cfg.scene.model_path,
    load_iteration=None,
)

In [None]:
# List every named parameter of the model together with its shape.
for param_name, tensor in model.named_parameters():
    print(param_name, tensor.shape)

In [None]:
# Build the data loaders (PyTorch DataLoader objects) for each split.
# Only the training split is shuffled; all three share the same worker count.
loader_workers = cfg.scene.num_workers
train_loader = scene.get_data_loader(
    "train", shuffle=True, num_workers=loader_workers
)
valid_loader = scene.get_data_loader(
    "valid", shuffle=False, num_workers=loader_workers
)
valid_novel_loader = scene.get_data_loader(
    "valid_novel", shuffle=False, num_workers=loader_workers
)

In [None]:
# Init the trainer.
# Validation cadence is read from the config (cfg.opt.val_every_n_steps).
val_interval = cfg.opt.val_every_n_steps

trainer = L.Trainer(
    devices=cfg.gpus,
    logger=logger,
    # Training length is bounded by a step count rather than an epoch count.
    max_steps=cfg.opt.iterations,
    # check_val_every_n_epoch=None makes Lightning schedule validation purely
    # by val_check_interval, counted in training batches across epochs.
    check_val_every_n_epoch=None,
    val_check_interval=val_interval,
    # callbacks=[checkpoint_callback],
)

In [None]:
# Train the model! Validation runs on two loaders: the regular validation
# split and the "valid_novel" split.
validation_loaders = [valid_loader, valid_novel_loader]
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=validation_loaders,
)