## Demonstration of Conditional Inference with Trained Model

In [1]:
import os
import torch
import pyvista as pv

from geogen.model import GeoModel
import geogen.plot as geovis

In [2]:
# Run inference on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Set the configuration to match the one used to train the model.

In [3]:
config = {
    "project": {
        "name": "generative-conditional-3d",
        # Base directory that the sample and checkpoint paths used later in the
        # notebook are joined onto.
        "root_dir": os.getcwd(),
    },
    # Data loader configurations
    "data": {
        "shape": (64, 64, 64),  # voxel grid size [X, Y, Z] (no channel axis)
        # (min, max) extent of each axis, in the same order as `shape`
        "bounds": (
            (-1920, 1920),
            (-1920, 1920),
            (-1920, 1920),
        ),
        "batch_size": 8,
        "epoch_size": 10_000,
    },
    # Categorical embedding parameters
    "embedding": {
        "num_categories": 15,
        "dim": 15,
    },
    # Inference parameters
    "inference": {
        "seed": None,
        "n_samples": 1,
        "batch_size": 4,
        "save_imgs": True,
    },
}

### Generating Conditional Samples
The model is trained on surface and borehole data. To generate samples, a random set of conditional data is first drawn from the StructuralGeo synthetic geology generator.

In [4]:
from model_inference_experiments import (
    create_cond_data,
    save_model_and_boreholes,
    load_model_and_boreholes,
)


def show_model_and_boreholes(model, boreholes):
    """
    Plot a geological model and its borehole data side by side.

    This local definition is the one used throughout the notebook, so the
    same name is not also imported from ``model_inference_experiments``
    (the import would be shadowed immediately anyway).

    Parameters
    ----------
    model, boreholes : torch.Tensor
        3D volumes, possibly carrying extra singleton dimensions; those are
        squeezed away and the data is moved to the CPU before being wrapped
        with ``GeoModel.from_tensor``.
    """
    # One PyVista window with two side-by-side panes
    plotter = pv.Plotter(shape=(1, 2))

    # Left pane: the full synthetic model
    plotter.subplot(0, 0)
    model_geo = GeoModel.from_tensor(model.squeeze().detach().cpu())
    geovis.volview(model_geo, plotter=plotter, show_bounds=True)

    # Right pane: the borehole data
    plotter.subplot(0, 1)
    borehole_geo = GeoModel.from_tensor(boreholes.squeeze().detach().cpu())
    geovis.volview(borehole_geo, plotter=plotter, show_bounds=True)

    plotter.show()


# Generate conditional samples and save to folder
save_dir = os.path.join(config["project"]["root_dir"], "samples/jupyter-demo")
cond_data_folder_title = "cond_generation"
num_samples = 4
create_cond_data(save_dir, cond_data_folder_title, device, num_samples)

In [5]:
# Load and check saved model
sample_number = 0
samples_dir = os.path.join(save_dir, f"{cond_data_folder_title}_{sample_number}")
model, boreholes = load_model_and_boreholes(samples_dir)
show_model_and_boreholes(model, boreholes)

Widget(value='<iframe src="http://localhost:36623/index.html?ui=P_0x7d484a41c1a0_0&reconnect=auto" class="pyvi…

## Conditional Inference
From the generated conditional data, the trained inference model is used to generate multiple reconstructions.

In [6]:
from model_inference_experiments import load_model_with_ema_option
from utils import download_if_missing

# Pretrained conditional checkpoint, stored under the project root
relative_checkpoint_path = os.path.join("demo_model", "conditional-weights.ckpt")
checkpoint_path = os.path.join(config["project"]["root_dir"], relative_checkpoint_path)

# Released weights; download_if_missing presumably skips the fetch when the
# checkpoint already exists locally
weights_url = "https://github.com/chipnbits/flowtrain_stochastic_interpolation/releases/download/v1.0.0/conditional-weights.ckpt"
download_if_missing(checkpoint_path, weights_url)

# Load onto the inference device with the EMA shadow weights applied
flowmatching_model = load_model_with_ema_option(
    checkpoint_path,
    use_ema=True,
    map_location=device,
)

Applying EMA shadow to model...


An auto-populating function, `populate_solutions`, is provided that:

1. Iterates through the folder `save_dir`, whose subfolders (named `{cond_data_folder_title}_{n}`) each contain a paired `boreholes.pt` and `true_model.pt`: the boreholes and the ground-truth geological model they were extracted from.
2. Creates the conditional data for the inverse problem that includes surface, air, and boreholes from `boreholes.pt` and `true_model.pt`
3. Runs the inference routine on the data to produce `n_samples_each` for each set of conditional data
4. Saves the solutions in the same subfolder with `sample_title_000.pt` naming convention

The script below samples 9 conditional reconstructions for each boreholes/true-model pair. (The true model is only used to obtain the surface and air data; its subsurface is not used in the inference.)

Sampling takes a long time, so precomputed inference results are provided for the ensemble analysis below. To run the inference locally, set `USE_PRECOMPUTED_INFERENCE_RESULTS = False` in the next cell.

In [7]:
# When True, the slow local inference and its display cells are skipped and the
# precomputed results unpacked later in the notebook are used instead.
USE_PRECOMPUTED_INFERENCE_RESULTS = True

In [8]:
from model_inference_experiments import populate_solutions

# Slow path: draw 9 reconstructions for every conditional dataset in save_dir.
# Skipped entirely when the precomputed results are used instead.
if not USE_PRECOMPUTED_INFERENCE_RESULTS:
    populate_solutions(
        model=flowmatching_model,
        device=device,
        save_dir=save_dir,
        cond_data_folder_title=cond_data_folder_title,
        sample_title="sample",
        n_samples_each=9,
        batch_size=1,
    )

### Loading and Viewing Solutions

In [9]:
from model_inference_experiments import load_solutions, show_solutions

# Only meaningful after a local inference run; the imports above are still needed
# by the precomputed-results cells further down.
if not USE_PRECOMPUTED_INFERENCE_RESULTS:
    # Conditional data and its solutions share the same subfolder
    sample_number = 0
    samples_dir = os.path.join(save_dir, f"{cond_data_folder_title}_{sample_number}")
    print("Loading from:", samples_dir)

    # Ground truth and boreholes first, then every saved reconstruction
    geomodel, boreholes = load_model_and_boreholes(samples_dir)
    solutions = load_solutions(samples_dir, sample_title="sample")

    show_model_and_boreholes(geomodel, boreholes)
    show_solutions(solutions)

### Ensemble Analysis
Precomputed models and borehole data are provided by decompressing an archive containing the model featured in the paper.

In [10]:
import os, tarfile, tempfile, gzip


def unpack_pt_archive(archive_path: str, dest_dir: str) -> None:
    """
    Decompress a .tar.gz produced by pack_pt_folder into dest_dir.

    The original relative filenames/folder structure are recreated, but only
    regular files whose names end in ``.pt`` are restored; every other member
    (directories, links, other file types) is skipped.

    Parameters
    ----------
    archive_path : str
        Path to the gzip-compressed tar archive.
    dest_dir : str
        Directory to restore into; created if it does not exist.

    Raises
    ------
    ValueError
        If a member's path would resolve outside ``dest_dir`` (path traversal).
    """
    import shutil

    os.makedirs(dest_dir, exist_ok=True)
    dest_root = os.path.abspath(dest_dir)

    def is_within_directory(directory: str, target: str) -> bool:
        # Compare whole path components. A character-wise prefix test would
        # wrongly accept "/data/out2/x.pt" as lying inside "/data/out".
        abs_dir = os.path.abspath(directory)
        abs_target = os.path.abspath(target)
        try:
            return os.path.commonpath([abs_dir, abs_target]) == abs_dir
        except ValueError:
            # e.g. paths on different drives on Windows
            return False

    # Decompress the gzip layer into a temporary plain tar first, so the tar
    # reader can seek freely instead of re-inflating the stream on each seek.
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp_path = tmp.name

    try:
        with gzip.open(archive_path, "rb") as gz, open(tmp_path, "wb") as out:
            shutil.copyfileobj(gz, out, 1024 * 1024)

        with tarfile.open(tmp_path, "r") as tar:
            for member in tar.getmembers():
                if not (member.isfile() and member.name.endswith(".pt")):
                    continue
                target_path = os.path.join(dest_root, member.name)
                if not is_within_directory(dest_root, target_path):
                    raise ValueError("Unsafe path in archive")
                os.makedirs(os.path.dirname(target_path), exist_ok=True)
                with tar.extractfile(member) as src, open(target_path, "wb") as dst:
                    # Stream in chunks rather than holding the whole member in memory
                    shutil.copyfileobj(src, dst)
    finally:
        try:
            os.unlink(tmp_path)
        except FileNotFoundError:
            pass

Decompress `boreholes.pt`, `true_model.pt`, and solutions into a samples directory

In [11]:
# Unpack the precomputed paper results next to the freshly generated samples
root = os.getcwd()
archive_path = os.path.join(root, "dikes_ptpack.tar.gz")
rel = "samples/jupyter-demo/paper_cond_gen_0"
samples_dir = os.path.join(root, rel)

unpack_pt_archive(archive_path, samples_dir)
print("Restored to:", samples_dir)

Restored to: /home/sghys/projects/flowtrain_stochastic_interpolation/project/geodata-3d-conditional/samples/jupyter-demo/paper_cond_gen_0


Verify that the data loads and displays correctly.

In [12]:
# Load the restored ground truth, boreholes and precomputed reconstructions on the CPU
sample_number = 0
geomodel, boreholes = load_model_and_boreholes(samples_dir, device="cpu")
solutions = load_solutions(samples_dir, sample_title="sample", device="cpu")

show_model_and_boreholes(geomodel, boreholes)
# Only the first 10 reconstructions are rendered to keep the figure readable
show_solutions(solutions[:10])

Widget(value='<iframe src="http://localhost:36623/index.html?ui=P_0x7d484de739b0_1&reconnect=auto" class="pyvi…

Widget(value='<iframe src="http://localhost:36623/index.html?ui=P_0x7d484e2fa9f0_2&reconnect=auto" class="pyvi…

In [15]:
def vote_probabilities(
    solutions: torch.Tensor, num_categories: int = 15, min_category: int = -1
) -> torch.Tensor:
    """
    Compute per-voxel class probabilities (vote frequencies) over a batch.

    Each voxel's category is one-hot encoded and averaged across the batch, so
    channel ``c`` of the output holds the fraction of samples in which that
    voxel had category ``c + min_category``.

    Parameters
    ----------
    solutions : torch.Tensor
        ``[B, X, Y, Z]`` integer categories in
        ``[min_category, min_category + num_categories - 1]``.
    num_categories : int
        Number of output channels ``C``.
    min_category : int
        Category value mapped to channel 0. The default ``-1`` ("air") means
        channel ``c`` corresponds to category ``c - 1``. The offset is always
        applied, so channel meaning does not depend on whether a particular
        batch happens to contain the minimum value.

    Returns
    -------
    torch.Tensor
        ``[C, X, Y, Z]`` float32 probabilities that sum to 1 over ``C``.

    Raises
    ------
    ValueError
        If ``solutions`` is not 4D, the batch is empty, or a value falls
        outside the representable category range.
    """
    if solutions.dim() != 4:
        raise ValueError(
            f"expected a 4D [B, X, Y, Z] tensor, got shape {tuple(solutions.shape)}"
        )
    B, X, Y, Z = solutions.shape
    if B == 0:
        raise ValueError("cannot compute probabilities from an empty batch")

    # Fixed offset: channel c always means category c + min_category
    indices = (solutions - min_category).to(torch.long)
    lowest, highest = indices.min().item(), indices.max().item()
    if lowest < 0 or highest >= num_categories:
        raise ValueError(
            f"categories must lie in [{min_category}, {min_category + num_categories - 1}], "
            f"got [{lowest + min_category}, {highest + min_category}]"
        )

    # Accumulator for per-class voxel counts
    accumulator = torch.zeros(
        num_categories, X, Y, Z, dtype=torch.float32, device=solutions.device
    )
    # Loop over samples so only one [X, Y, Z, C] one-hot tensor is alive at a time
    for sample in indices:
        one_hot = torch.nn.functional.one_hot(sample, num_classes=num_categories)  # [X, Y, Z, C]
        accumulator += one_hot.permute(3, 0, 1, 2).float()  # [C, X, Y, Z]

    # Normalize counts by the number of samples
    return accumulator / B


# Per-voxel vote frequencies over all loaded reconstructions: [15, X, Y, Z].
# Downstream indexing assumes categories were shifted by +1 (air = -1 -> channel 0).
solution_probabilistic = vote_probabilities(solutions, num_categories=15)

## Display Probabilistic Results
The probability for one of the dike categories is compared against the true model and the conditional data.

In [16]:
import numpy as np


def get_voxel_grid_from_tensor(
    data, bounds=((-1920, 1920), (-1920, 1920), (-1920, 1920)), threshold=None
):
    """
    Wrap a 3D voxel array as a PyVista ``ImageData`` with one cell per voxel.

    Voxel centres are placed on an evenly spaced lattice whose first and last
    centres sit on ``bounds``, so the grid extends half a cell past the bounds
    on every side.

    Parameters
    ----------
    data : torch.Tensor or numpy.ndarray
        3D array of per-voxel values, indexed ``[X, Y, Z]``.
    bounds : sequence of three (min, max) pairs
        Extent of the voxel centres along each axis.
    threshold : float or None
        If given, the grid is passed through
        ``grid.threshold(threshold, all_scalars=True)``. ``None`` returns the
        full, unthresholded grid.

    Returns
    -------
    pyvista.DataSet
        The full ``ImageData`` when ``threshold`` is None, otherwise the
        thresholded mesh.

    Raises
    ------
    ValueError
        If ``data`` is not 3D, ``bounds`` does not have one pair per axis, or an
        axis has fewer than 2 voxels (spacing would be undefined).
    """
    if isinstance(data, torch.Tensor):
        # detach() so tensors that require grad can still be converted
        data = data.detach().cpu().numpy()
    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError("Data must be 3D")
    dims = data.shape
    if len(bounds) != len(dims):
        raise ValueError("bounds must contain one (min, max) pair per axis")
    if min(dims) < 2:
        raise ValueError("each axis needs at least 2 voxels to infer the spacing")

    # n voxels need n + 1 grid nodes to form n cells along each axis
    dimensions = tuple(n + 1 for n in dims)
    # Bounds are the centres of the first/last voxel, hence (n - 1) intervals
    spacing = tuple((hi - lo) / (n - 1) for (lo, hi), n in zip(bounds, dims))
    # Shift the origin back half a cell so each cell is centred on its voxel
    origin = tuple(lo - step / 2 for (lo, _hi), step in zip(bounds, spacing))

    grid = pv.ImageData(dimensions=dimensions, spacing=spacing, origin=origin)
    # Fortran order so the flattened values match the grid's cell ordering
    grid["values"] = data.flatten(order="F")

    if threshold is None:
        return grid
    return grid.threshold(threshold, all_scalars=True)


# Drop singleton (presumably batch/channel) dimensions so both volumes are 3D
# [X, Y, Z]. A single squeeze() removes every size-1 dim, and rebinding avoids
# mutating the tensors that were loaded and displayed earlier.
geomodel = geomodel.squeeze()
boreholes = boreholes.squeeze()

# Reuse the configured extent instead of repeating the literal bounds.
# threshold=-0.5 presumably drops the -1 (air / unobserved) voxels.
grid_bounds = config["data"]["bounds"]
true_grid = get_voxel_grid_from_tensor(geomodel, bounds=grid_bounds, threshold=-0.5)
borehole_grid = get_voxel_grid_from_tensor(boreholes, bounds=grid_bounds, threshold=-0.5)

# Category values of the dike units in the synthetic model
DIKE_VALS = [6, 7, 8]
# Assumes vote_probabilities shifted categories by +1 (air -1 -> channel 0)
dike_indices = [val + 1 for val in DIKE_VALS]
dike_probs = solution_probabilistic[dike_indices, :, :, :]  # [len(DIKE_VALS), X, Y, Z]

# Point grid at the voxel centres. Resolution follows the probability volume and
# the extent follows the configured bounds, so the two cannot drift apart.
axis_coords = [
    np.linspace(lo, hi, n)
    for (lo, hi), n in zip(config["data"]["bounds"], dike_probs.shape[1:])
]
x, y, z = np.meshgrid(*axis_coords, indexing="ij")
mesh = pv.StructuredGrid(x, y, z)

# "dike1"/"dike2" are DIKE_VALS[1] and DIKE_VALS[2] (categories 7 and 8).
# Fortran order matches the point ordering of the "ij" meshgrid above.
dike1_data = dike_probs[1].detach().cpu().numpy().ravel(order="F")
mesh["dike1"] = dike1_data
dike2_data = dike_probs[2].detach().cpu().numpy().ravel(order="F")
mesh["dike2"] = dike2_data

def _isolate_category(grid, value):
    """Return a copy of ``grid`` keeping only cells whose "values" equal ``value``.

    All other cells are set to -1 and then removed by thresholding at -0.5.
    ``value`` may also be a list of categories to keep.
    """
    isolated = grid.copy()
    cell_values = isolated["values"]
    isolated["values"] = np.where(np.isin(cell_values, value), cell_values, -1)
    return isolated.threshold(-0.5, all_scalars=True)


# Ground-truth and borehole-observed cells for the two dikes visualised below
dike1_true = _isolate_category(true_grid, DIKE_VALS[1])
dike2_true = _isolate_category(true_grid, DIKE_VALS[2])
dike1_samples = _isolate_category(borehole_grid, DIKE_VALS[1])
dike2_samples = _isolate_category(borehole_grid, DIKE_VALS[2])

# Ground-truth dike 1 (translucent orange) overlaid with its borehole cells (red)
p = pv.Plotter()
p.add_mesh(
    dike1_true,
    color="orange",
    opacity=0.3,
    show_scalar_bar=False,
    interpolate_before_map=False,
)
p.add_mesh(
    dike1_samples,
    scalars=None,
    color="red",
    opacity=1.0,
    show_scalar_bar=False,
    interpolate_before_map=False,
)
p.add_title("Dike 1 True with Borehole Samples")
p.show()

p = pv.Plotter()
# Borehole cells that intersect dike 1, drawn opaque for reference
p.add_mesh(
    dike1_samples,
    scalars=None,
    color="red",
    show_scalar_bar=False,
    interpolate_before_map=False,
    opacity=1.0,
)

# Iso-surfaces of the ensemble probability that a voxel belongs to dike 1
PROBABILITY_LEVELS = [0.05, 0.3, 0.6, 0.9]
contour = mesh.contour(PROBABILITY_LEVELS, scalars="dike1")

# PyVista refuses to plot empty meshes, which happens if no voxel reaches
# the lowest level
if contour.n_points > 0:
    p.add_mesh(
        contour,
        opacity=0.3,
        cmap="Wistia",
        show_scalar_bar=False,
    )
    # Horizontal legend for the probability levels
    p.add_scalar_bar(
        "Probability Contour",
        vertical=False,
        title_font_size=24,
        label_font_size=24,
        fmt="%.2f",
        n_labels=len(PROBABILITY_LEVELS),
    )
else:
    print("No voxel reaches the lowest probability level; skipping contours.")

p.add_title("Dike 1 Probabilistic Surfaces with Borehole Samples")
p.show()

Widget(value='<iframe src="http://localhost:36623/index.html?ui=P_0x7d483cb153d0_3&reconnect=auto" class="pyvi…

Widget(value='<iframe src="http://localhost:36623/index.html?ui=P_0x7d486e937620_4&reconnect=auto" class="pyvi…