# Example Model Loading

This notebook shows how to load the pre-trained model. Because this is only a test, the model can
currently be loaded only on the Explore system, where the checkpoint and configuration files referenced below are stored.

In [1]:
!pip install pytorch-lightning

Defaulting to user installation because normal site-packages is not writeable
Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.2-py3-none-any.whl.metadata (21 kB)
Collecting torch>=2.1.0 (from pytorch-lightning)
  Downloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading pytorch_lightning-2.5.2-py3-none-any.whl (825 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m825.4/825.4 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl (821.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m821.0/821.0 MB[0m [31m32.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading torchmetrics-1.8.0-py3-none-any.whl (981 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m981.9/981.9 kB[0m [31m7

In [1]:
import os
import sys
import logging
import argparse
import warnings

# Ignore warnings whose message matches "cuda capability 7.0" (presumably
# torch reporting an older GPU). The filter is registered before torch is
# imported so that a warning raised during import is filtered as well.
warnings.filterwarnings("ignore", message=".*cuda capability 7.0.*")

import torch

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor

# Location of the satvision-pix4d source tree. It can be overridden with the
# SATVISION_PIX4D_DIR environment variable. The default is the Explore
# development checkout this notebook was written against.
SATVISION_PIX4D_DIR = os.environ.get(
    'SATVISION_PIX4D_DIR',
    '/explore/nobackup/people/jacaraba/development/satvision-pix4d'
)
# Guard the append so re-running this cell does not keep growing sys.path.
if SATVISION_PIX4D_DIR not in sys.path:
    sys.path.append(SATVISION_PIX4D_DIR)

from satvision_pix4d.configs.config import _C, _update_config_from_file
from satvision_pix4d.utils import get_strategy, get_distributed_train_batches
from satvision_pix4d.pipelines import PIPELINES, get_available_pipelines
from satvision_pix4d.datamodules import DATAMODULES, get_available_datamodules

[2025-07-28 14:39:56,463] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cpu (auto detect)


/explore/nobackup/people/soehrle/envs/satvis_kernel/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


[2025-07-28 14:40:06,872] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False


In [2]:
# Weights from the SatMAE pretraining dev run. Here `epoch-epoch=0.ckpt` is a
# directory; the file inside follows the `mp_rank_00_model_states.pt` naming
# (presumably a DeepSpeed-style checkpoint -- confirm).
# NOTE(review): not referenced by any of the cells shown so far.
model_filename = '/'.join([
    '/explore/nobackup/projects/pix4dcloud/jacaraba/model_development/satmae',
    'satmae_satvision_pix4d_pretrain-dev',
    'satmae_satvision_pix4d_pretrain-dev',
    'epoch-epoch=0.ckpt',
    'checkpoint',
    'mp_rank_00_model_states.pt',
])

In [3]:
# YAML config for the SatMAE dev test run, taken from the repo's tests/ folder.
config_filename = '/'.join([
    '/explore/nobackup/people/jacaraba/development/satvision-pix4d',
    'tests',
    'configs',
    'test_satmae_dev.yaml',
])

In [4]:
# Build the run configuration. Start from a copy of the default config node `_C`
# (presumably a yacs-style CfgNode, so clone() leaves the defaults untouched --
# confirm), then overlay the values from the YAML file at `config_filename`.
# The return value is not used, so the update is presumably applied in place.
config = _C.clone()
_update_config_from_file(config, config_filename)

In [5]:
# All imports and the CUDA-capability warning filter used by `main` below are
# set up in the setup cell at the top of this notebook. Run that cell first;
# they are deliberately not repeated here.


# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
def _build_callbacks(output_dir):
    """Return the checkpoint and learning-rate callbacks used by ``main``.

    Args:
        output_dir: Directory both ModelCheckpoint callbacks write into.

    Returns:
        A list ``[best_by_val_loss, every_epoch, lr_monitor]``.
    """
    # Keep the three checkpoints with the lowest validation loss.
    checkpoint_best = ModelCheckpoint(
        dirpath=output_dir,
        monitor="val_loss",
        mode="min",
        save_top_k=3,
        filename="best-{epoch}-{val_loss:.4f}",
    )
    # Save a checkpoint after every epoch and keep all of them, plus a "last" one.
    checkpoint_periodic = ModelCheckpoint(
        dirpath=output_dir,
        every_n_epochs=1,
        save_top_k=-1,
        filename="epoch-{epoch}",
        save_last=True,
    )
    lr_monitor_cb = LearningRateMonitor(logging_interval="epoch")
    return [checkpoint_best, checkpoint_periodic, lr_monitor_cb]


def _build_mlflow_logger(config, tracking_uri):
    """Return an MLFlowLogger for experiment ``config.TAG``, tagged with model info.

    Args:
        config: Run configuration; reads TAG, MODEL.NAME, PIPELINE and, if
            present, DESCRIPTION.
        tracking_uri: MLflow tracking URI to log to.
    """
    return MLFlowLogger(
        experiment_name=config.TAG,
        tracking_uri=tracking_uri,
        tags={
            "Model": config.MODEL.NAME,
            "Pipeline": config.PIPELINE,
            "Notes": getattr(config, "DESCRIPTION", ""),
        },
    )


def main(config, output_dir,
         tracking_uri="file:///explore/nobackup/projects/pix4dcloud/mlruns"):
    """Train the pipeline selected by ``config`` and log the run to MLflow.

    Steps:
      1. Write ``config.dump()`` to ``<output_dir>/<config.TAG>.config.json``.
      2. Build ``PIPELINES[config.PIPELINE](config)``, or, if
         ``config.MODEL.RESUME`` is set, load it from that checkpoint instead.
      3. Build the strategy, checkpoint/LR callbacks and an MLflow logger, and
         attach the saved config file to the MLflow run as an artifact.
      4. Create a Lightning ``Trainer`` and call ``fit``. Use the datamodule
         named by ``config.DATA.DATAMODULE`` when one is configured; otherwise
         assume the pipeline supplies its own data.

    Args:
        config: Run configuration (as built from ``_C`` in this notebook).
        output_dir: Directory for the config dump and checkpoints; created if
            it does not exist.
        tracking_uri: MLflow tracking URI. Defaults to the shared Explore
            store. For a local store, pass e.g.
            ``"file://" + os.path.abspath(config.OUTPUT)``.

    Returns:
        None.
    """
    logging.info('Training')

    # Make sure the destination exists before writing the config dump into it.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the full configuration to disk.
    # NOTE(review): if config is a yacs CfgNode, dump() emits YAML despite the
    # .json suffix -- confirm before parsing this file as JSON.
    path = os.path.join(output_dir, f"{config.TAG}.config.json")
    with open(path, "w") as f:
        f.write(config.dump())
    logging.info(f"Full config saved to {path}")
    logging.info(config.dump())

    # Select the pipeline class. The available names are logged via an f-string
    # because logging treats extra positional args as %-format arguments.
    available_pipelines = get_available_pipelines()
    logging.info(f"Available pipelines: {available_pipelines}")

    pipeline = PIPELINES[config.PIPELINE]
    logging.info(f'Using {pipeline}')

    if config.MODEL.RESUME:
        logging.info(
            f'Attempting to resume from checkpoint {config.MODEL.RESUME}')
        # NOTE(review): load_from_checkpoint restores the module only. Full
        # training state (optimizer, epoch) is resumed by passing the path to
        # trainer.fit(ckpt_path=...) -- confirm which behavior is intended.
        ptlPipeline = pipeline.load_from_checkpoint(config.MODEL.RESUME)
    else:
        ptlPipeline = pipeline(config)

    strategy = get_strategy(config)
    callbacks = _build_callbacks(output_dir)

    mlflow_logger = _build_mlflow_logger(config, tracking_uri)
    # Attach the saved config file to this MLflow run.
    mlflow_logger.experiment.log_artifact(mlflow_logger.run_id, path)

    trainer = Trainer(
        accelerator=config.TRAIN.ACCELERATOR,
        # device_count() is 0 on CPU-only hosts, and Lightning does not accept
        # devices=0. Let Lightning choose in that case.
        devices=torch.cuda.device_count() or "auto",
        strategy=strategy,
        precision=config.PRECISION,
        max_epochs=config.TRAIN.EPOCHS,
        gradient_clip_val=1.0,
        log_every_n_steps=config.PRINT_FREQ,
        default_root_dir=output_dir,
        callbacks=callbacks,
        logger=mlflow_logger,
    )

    # Optionally cap the number of training batches (for debugging).
    if config.TRAIN.LIMIT_TRAIN_BATCHES:
        trainer.limit_train_batches = get_distributed_train_batches(
            config, trainer)

    # Read the datamodule name from one key, so the presence check and the
    # registry lookup always agree.
    datamodule_name = config.DATA.DATAMODULE
    if datamodule_name:
        available_datamodules = get_available_datamodules()
        logging.info(f"Available data modules: {available_datamodules}")
        datamodule = DATAMODULES[datamodule_name](config)
        logging.info(f'Training using datamodule: {datamodule_name}')
        trainer.fit(model=ptlPipeline, datamodule=datamodule)
    else:
        logging.info(
            'Training without datamodule, assuming data is set' +
            f' in pipeline: {ptlPipeline}')
        trainer.fit(model=ptlPipeline)