# Getting Started

In [None]:
from pathlib import Path
import lightning as L

from ssg_tools.dataset.hetero_dataset import HeteroSceneGraphModule
from ssg_tools.models.ksgn import IncrementalKSGN
from lightning.pytorch.callbacks import ModelCheckpoint

You may need to adjust `root` below to point to your local copy of the data.

In [None]:
# Dataset location: change this to your local data root.
root = Path("/path/to/data/root")
# NOTE(review): data_dir is not referenced in the cells shown; the datamodule is given `root`.
data_dir = root.joinpath("hetero_scene_graph")
# Relative to the notebook's working directory; the model cell builds `weights_path` from it.
artifacts_dir = Path("artifacts")

In [None]:
# Configuration for the heterogeneous scene-graph datamodule.
datamodule_kwargs = dict(
    shuffle=True,
    batch_size=64,
    num_workers=16,
    corruption_rate=0.0,
    embedding_type="clip",  # keep in sync with the model's embedding_type
    # Drop ("new", "to", "old") edges: HGT fails otherwise.
    edges_to_remove=[("new", "to", "old")],
    hierarchical=True,
    pin_memory=True,
)
datamodule = HeteroSceneGraphModule(root, **datamodule_kwargs)
# Run the "fit" stage setup up front so the next cell can call train_dataloader()
# before the Trainer exists (presumably required by the datamodule — confirm).
datamodule.setup("fit")

Each batch fuses multiple heterogeneous graphs (frames; n = batch size) into one large, disconnected graph.

In [None]:
# Pull a single batch from the training loader for inspection.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))
batch  # notebook cell output: display the batched graph

The batch metadata is required to build the heterogeneous GNN.

In [None]:
# Graph metadata of the sample batch; the model cell passes it to IncrementalKSGN as `metadata`.
# NOTE(review): presumably PyG HeteroData.metadata(), i.e. (node_types, edge_types) — confirm.
metadata = batch.metadata()
metadata

In [None]:
# Hyperparameters for the incremental KSGN model. embedding_type should match
# the value given to the datamodule.
model_hparams = dict(
    gnn_type="sage",
    embedding_type="clip",
    pointnet_output_dim=32,
    edge_emb_dim=64,
    gnn_hidden_dim=256,
    gnn_out_dim=128,
    gnn_num_layers=2,
    dropout=0.5,
    # NOTE(review): this precise lr appears to come from a hyperparameter search — confirm.
    lr=0.00014981950627824397,
    norm="layer",
    num_node_classes=27,
    num_edge_classes=16,
    # Prediction targets (per the argument names): "new" nodes and new->new edges.
    nodes_to_predict=["new"],
    edges_to_predict=[("new", "to", "new")],
    weights_node="nodes_log_scaled",
    weights_edge="edges_log_scaled",
    weights_path=artifacts_dir / "weights.json",
    gamma_edge=40.0,
)
model = IncrementalKSGN(metadata=metadata, **model_hparams)

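Optional: a checkpoint callback for saving the model. Note that Lightning disables checkpoint callbacks when the trainer runs with `fast_dev_run=True`.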

In [None]:
# Keep a single checkpoint named "last": the monitored "epoch" value only grows,
# so mode="max" with save_top_k=1 always retains the most recent epoch.
# NOTE: Lightning skips checkpoint callbacks when the Trainer uses fast_dev_run.
checkpoint_topk = ModelCheckpoint(
    filename="last",
    save_top_k=1,
    monitor="epoch",
    mode="max",
)

# Optional MLflow logging (uncomment, then pass logger=mlf_logger to the Trainer):
# from lightning.pytorch.loggers import MLFlowLogger
# mlf_logger = MLFlowLogger(experiment_name="Default", log_model=False, save_dir="logs")

In [None]:
# Smoke-test toggle. With fast_dev_run=True, Lightning runs a single batch of
# train/val/test. It ignores max_epochs and the limit_*_batches fractions, and
# it disables checkpoint callbacks (checkpoint_topk writes nothing) and loggers.
# Set it to False to train on the fractions below and save checkpoints.
fast_dev_run = True

trainer = L.Trainer(
    fast_dev_run=fast_dev_run,
    # logger=mlf_logger,
    max_epochs=1,
    accelerator="auto",
    callbacks=[checkpoint_topk],
    # Fractions of each split to use; only effective when fast_dev_run is False.
    limit_train_batches=0.05,
    limit_val_batches=0.05,
    limit_test_batches=0.01,
)
# Pass the datamodule by keyword: the second positional slot of fit()/test()
# is for dataloaders, and Lightning only re-routes a datamodule there implicitly.
trainer.fit(model, datamodule=datamodule)
# trainer.validate(model, datamodule=datamodule)
# Evaluates the in-memory model (no ckpt_path); res holds the per-dataloader metric dicts.
res = trainer.test(model, datamodule=datamodule)
