# Tutorial 6: Saving and Exporting Models (LibTorch-friendly)

Train a small tabular classifier, save/load its weights, and export a TorchScript artifact that can be loaded from C++ via LibTorch. ZenML is not required for these steps.

In [1]:
import torch
from pathlib import Path

from pioneerml.pipelines.tutorial_examples.tabular_datamodule_pipeline import (
    TabularConfig,
    TabularClassifier,
    TabularDataModule,
)
from pioneerml.common.zenml.utils import detect_available_accelerator


## 1) Train a small model (Lightning for convenience)

We reuse the tabular DataModule/LightningModule to get a trained model quickly.

In [2]:
config = TabularConfig(num_samples=200, num_features=8, num_classes=3, batch_size=32)
datamodule = TabularDataModule(config)
datamodule.setup(stage="fit")

accelerator, devices = detect_available_accelerator()
model = TabularClassifier(config)

try:
    import pytorch_lightning as pl

    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=3,
        limit_train_batches=5,
        limit_val_batches=2,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )
    trainer.fit(model, datamodule=datamodule)
    model.eval()
    print("Training complete.")
except Exception as exc:
    print(f"Skipping training (dependency/runtime issue): {exc}")
    model.eval()


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | Sequential | 1.4 K  | train
---------------------------------------------
1.4 K     Trainable params
0         Non-trainable params
1.4 K     Total params
0.006     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode


/home/jack/virtual_environments/miniconda3/envs/pioneerml/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=15` in the `DataLoader` to improve performance.

/home/jack/virtual_environments/miniconda3/envs/pioneerml/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=15` in the `DataLoader` to improve performance.


`Trainer.fit` stopped: `max_epochs=3` reached.


Training complete.


## 2) Save and reload via `state_dict`

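This is the standard PyTorch approach. The file holds only the parameter tensors, so it loads anywhere PyTorch runs, provided the model class is available to rebuild the architecture first.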

In [3]:
save_dir = Path("outputs/tutorials/06_model_exports")
save_dir.mkdir(parents=True, exist_ok=True)

state_path = save_dir / "tabular_classifier.pt"
torch.save(model.state_dict(), state_path)
print("Saved state_dict ->", state_path)

reloaded = TabularClassifier(config)
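# map_location="cpu" lets the file load on machines without a GPU, even if it was saved from CUDA tensors
reloaded.load_state_dict(torch.load(state_path, map_location="cpu"))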
reloaded.eval()
print("Reloaded model; param checksum:", sum(p.sum().item() for p in reloaded.parameters()))


Saved state_dict -> outputs/tutorials/06_model_exports/tabular_classifier.pt
Reloaded model; param checksum: 2.5657803267240524
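
A summed checksum is a quick sanity check, but two different sets of weights can share the same sum. A stricter round-trip check compares the tensors entry by entry. The sketch below is illustrative only: it uses a small stand-in `torch.nn.Sequential` rather than `TabularClassifier`, writes to a temporary directory, and the helper `state_dicts_equal` is a hypothetical name, not part of PyTorch or this project.

```python
import tempfile
from pathlib import Path

import torch


def state_dicts_equal(a: dict, b: dict) -> bool:
    """Return True if both state_dicts have the same keys and identical tensors."""
    if a.keys() != b.keys():
        return False
    return all(torch.equal(a[k], b[k]) for k in a)


def build_net() -> torch.nn.Sequential:
    # Stand-in architecture for illustration (not the tutorial's TabularClassifier).
    return torch.nn.Sequential(
        torch.nn.Linear(8, 32), torch.nn.ReLU(), torch.nn.Linear(32, 3)
    )


original = build_net()

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "weights.pt"
    torch.save(original.state_dict(), path)
    # map_location="cpu" keeps the load working on machines without a GPU.
    restored_state = torch.load(path, map_location="cpu")

restored = build_net()  # freshly initialized, so its weights differ before loading
restored.load_state_dict(restored_state)

round_trip_ok = state_dicts_equal(original.state_dict(), restored.state_dict())
print("Round trip exact:", round_trip_ok)
```

The same helper could be applied to `model.state_dict()` and `reloaded.state_dict()` from the cell above, after moving both models to the same device.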


## 3) Build a pure PyTorch inference module

To avoid Lightning-specific attributes (e.g., `trainer`) during export, wrap the underlying MLP in a plain `nn.Module`.

In [4]:
class TabularInference(torch.nn.Module):
    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(num_features, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

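# Initialize inference module and load weights from the Lightning model
inference_model = TabularInference(config.num_features, config.num_classes)
# Keep only the keys that belong to the Sequential under `model`. TabularInference also
# stores its Sequential as `model`, so the keys ("model.0.weight", ...) are used unchanged;
# stripping the prefix would leave nothing to match and the weights would stay random.
inference_state = {k: v for k, v in reloaded.state_dict().items() if k.startswith("model.")}
# strict=True raises on any missing or unexpected key instead of silently skipping it
inference_model.load_state_dict(inference_state, strict=True)
inference_model.eval()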


TabularInference(
  (model): Sequential(
    (0): Linear(in_features=8, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=3, bias=True)
  )
)
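
When weights are copied between modules, a key mismatch is easy to miss because `load_state_dict(..., strict=False)` skips unmatched keys without raising. The call returns an object with `missing_keys` and `unexpected_keys`, which makes the mismatch visible. The sketch below uses a hypothetical `Wrapper` class as a stand-in for any module that stores its network under `model`; it is not the tutorial's `TabularClassifier`.

```python
import torch


class Wrapper(torch.nn.Module):
    """Hypothetical stand-in for a module that keeps its network under `model`."""

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(torch.nn.Linear(8, 3))


source = Wrapper()
target = Wrapper()

# Stripping the "model." prefix produces keys that the target does not have.
stripped = {k.replace("model.", "", 1): v for k, v in source.state_dict().items()}
result = target.load_state_dict(stripped, strict=False)
print("missing:   ", result.missing_keys)
print("unexpected:", result.unexpected_keys)
loaded_nothing = not torch.equal(source.model[0].weight, target.model[0].weight)

# With the keys unchanged, every parameter matches and strict loading succeeds.
clean = target.load_state_dict(source.state_dict(), strict=True)
weights_match = torch.equal(source.model[0].weight, target.model[0].weight)
print("weights match after strict load:", weights_match)
```

Checking that both lists are empty, or simply loading with `strict=True`, is a cheap guard before exporting a model whose weights were transferred this way.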

## 4) Export to TorchScript for LibTorch/C++

Export **both** TorchScript variants so the difference is explicit:

- `torch.jit.script(model)`: compiles from Python code and preserves control flow (`if`, loops).
- `torch.jit.trace(model, example_input)`: records ops for one example path; great for static feed-forward graphs, but it can miss data-dependent branches.

In production, prefer `script` when possible; use `trace` when scripting is not supported or for simple static models.
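
The difference only shows up when the forward pass branches on tensor values, which the tutorial's MLP does not. As a minimal sketch, the hypothetical `SignGate` module below (not part of this project) branches on the sign of the input sum; tracing it with a positive input freezes that branch into the graph.

```python
import warnings

import torch


class SignGate(torch.nn.Module):
    """Hypothetical module whose output depends on the input's values."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return torch.zeros_like(x)


module = SignGate().eval()
positive = torch.ones(3)
negative = -torch.ones(3)

scripted = torch.jit.script(module)
with warnings.catch_warnings():
    # Tracing a data-dependent branch emits a TracerWarning; silenced here for brevity.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(module, positive)

# The scripted module re-evaluates the condition on every call.
scripted_out = scripted(negative)
# The traced module replays the branch recorded for `positive`, i.e. x * 2.
traced_out = traced(negative)
print("scripted:", scripted_out.tolist())
print("traced:  ", traced_out.tolist())
```

For the negative input, the scripted module returns zeros while the traced module still doubles the input, which is why the `TracerWarning` is worth reading rather than suppressing in real exports.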


In [5]:
example_input = torch.randn(1, config.num_features)

# 1) Scripted TorchScript (captures model code/control flow)
scripted_mod = torch.jit.script(inference_model)
scripted_path = save_dir / "tabular_classifier_scripted.pt"
scripted_mod.save(scripted_path)
print("Saved scripted TorchScript ->", scripted_path)

# 2) Traced TorchScript (records ops from an example input)
traced_mod = torch.jit.trace(inference_model, example_input)
traced_path = save_dir / "tabular_classifier_traced.pt"
traced_mod.save(traced_path)
print("Saved traced TorchScript ->", traced_path)

# Verify round-trip load for both artifacts
loaded_scripted = torch.jit.load(scripted_path)
loaded_traced = torch.jit.load(traced_path)

with torch.no_grad():
    out_scripted = loaded_scripted(example_input)
    out_traced = loaded_traced(example_input)

print("Scripted output shape:", tuple(out_scripted.shape))
print("Traced output shape:", tuple(out_traced.shape))
print("Max |scripted - traced| on sample input:", float((out_scripted - out_traced).abs().max().item()))


Saved TorchScript -> outputs/tutorials/06_model_exports/tabular_classifier_scripted.pt
TorchScript output shape: (1, 3)
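
A single example input checks the round trip but not whether the exported modules handle other batch sizes or agree with the eager model. The sketch below, assuming a stand-in `torch.nn.Sequential` with layer sizes chosen to resemble the tutorial's config, traces with batch size 1 and then compares all three variants on a batch of 16 using `torch.testing.assert_close`.

```python
import torch

torch.manual_seed(0)

# Stand-in network for illustration; sizes are assumptions, not read from TabularConfig.
net = torch.nn.Sequential(
    torch.nn.Linear(8, 32), torch.nn.ReLU(), torch.nn.Linear(32, 3)
).eval()

scripted = torch.jit.script(net)
traced = torch.jit.trace(net, torch.randn(1, 8))  # traced with batch size 1

batch = torch.randn(16, 8)  # a batch size the trace never saw
with torch.no_grad():
    eager_out = net(batch)
    scripted_out = scripted(batch)
    traced_out = traced(batch)

# assert_close raises AssertionError if values differ beyond the default float32 tolerances.
torch.testing.assert_close(scripted_out, eager_out)
torch.testing.assert_close(traced_out, eager_out)
print("Output shape on a batch of 16:", tuple(traced_out.shape))
```

Comparing against the eager model, not just scripted against traced, also catches the case where both exports were made from a module with the wrong weights.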


## 5) Loading from C++ (LibTorch)

Both exported files can be loaded from C++ via `torch::jit::load`:

- `tabular_classifier_scripted.pt`
- `tabular_classifier_traced.pt`

A minimal C++ snippet:

```cpp
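#include <torch/torch.h>
#include <torch/script.h>  // torch::jit::load, torch::jit::Module

#include <iostream>
#include <vector>

int main() {
    // torch::jit::load throws c10::Error if the file is missing or malformed
    torch::jit::Module module = torch::jit::load("tabular_classifier_scripted.pt");
    module.eval();

    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::randn({1, 8}));  // must match num_features

    torch::NoGradGuard no_grad;  // inference only, so skip autograd bookkeeping
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.sizes() << std::endl;
    return 0;
}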
```
