# SatVision Inference Using Triton Server Inside NASA GSFC

This Jupyter Notebook demonstrates how to perform inference with the SatVision foundation model deployed on a Triton Inference Server inside NASA GSFC. It walks through the end-to-end process of formatting inputs, sending requests to the server, and retrieving model predictions. The notebook is designed to support high-throughput inference on multi-channel MODIS top-of-atmosphere (TOA) data, leveraging GPU acceleration and Triton’s efficient serving capabilities for scalable downstream applications.

The purpose of this server is to host the model centrally so that users can extract the features they need for model training, and then train their own models without needing the compute resources required to run SatVision-TOA themselves.

## 1. Download Configuration Dependencies

### 1.1. Download SatVision Repository

In [None]:
import os
import sys
import torch
import urllib
import subprocess
import gevent.ssl
import numpy as np
import matplotlib.pyplot as plt
import tritonclient.http as httpclient

from glob import glob
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download

In [None]:
# Clone the SatVision-TOA source tree exactly once; re-running the cell after
# a successful clone just reports that the checkout is already present.
repo_url = "https://github.com/nasa-nccs-hpda/satvision-toa"
if os.path.exists('satvision-toa'):
    print("Repository already exists.")
else:
    subprocess.run(["git", "clone", repo_url, 'satvision-toa'], check=True)
    print(f"Cloned {repo_url} into satvision-toa")

In [None]:
# Model configuration published alongside the SatVision-TOA giant checkpoint
# on Hugging Face. The file name is defined once and reused both as the last
# URL component and as the local output path (saved in the working directory).
config_filename = "mim_pretrain_swinv2_satvision_giant_128_window08_50ep.yaml"
config_url = (
    "https://huggingface.co/nasa-cisto-data-science-group/"
    "satvision-toa-giant-patch8-window8-128/resolve/main/"
    + config_filename
)
config_output_path = config_filename

In [None]:
# `import urllib` alone does not load the `urllib.request` submodule; import it
# explicitly instead of relying on some other library having imported it first.
import urllib.request

# Download the config only if it is not already present. The file is written
# under a temporary name and renamed on success, so an interrupted download
# cannot leave a truncated YAML that the existence check would later accept.
if not os.path.exists(config_output_path):
    partial_path = config_output_path + ".part"
    urllib.request.urlretrieve(config_url, partial_path)
    os.replace(partial_path, config_output_path)
    print(f"Downloaded to {config_output_path}")
else:
    print(f"Config already exists at {config_output_path}")

### 1.2. Set Up SatVision Dependencies

In [None]:
# Make the cloned checkout importable (path is relative to the current working
# directory), then import the SatVision-TOA model builder, MODIS TOA transform,
# config helpers, and PDF plotting utility.
satvision_repo_dir = 'satvision-toa'
sys.path.append(satvision_repo_dir)
from satvision_toa.models.mim import build_mim_model
from satvision_toa.transforms.mim_modis_toa import MimTransform
from satvision_toa.configs.config import _C, _update_config_from_file
from satvision_toa.plotting.modis_toa import plot_export_pdf

### 1.3. Load SatVision Configuration

In [None]:
# Work on a clone of the package-level defaults rather than on `_C` itself,
# then overlay the values read from the downloaded YAML file.
config = _C.clone()
_update_config_from_file(config, config_output_path)

# 2. Set Up Triton Server Client

In [None]:
# Hostname of the GSFC-internal Triton Inference Server. Set the
# TRITON_SERVER_URL environment variable to target a different node without
# editing the notebook; otherwise the default host below is used.
triton_server_url = os.environ.get(
    "TRITON_SERVER_URL", "gs6n-dgx02.sci.gsfc.nasa.gov"
)

In [None]:
# Create the HTTPS client used for every inference request below.
# NOTE(review): this uses an *unverified* SSL context together with
# insecure=True, i.e. TLS certificate checks are switched off. That is only
# acceptable on the trusted internal network; prefer the server's CA bundle
# if one is available.
ssl_context_factory = gevent.ssl._create_unverified_context
client_options = dict(
    ssl=True,
    insecure=True,
    ssl_context_factory=ssl_context_factory,
)
client = httpclient.InferenceServerClient(url=triton_server_url, **client_options)

# 3. Download Data from HuggingFace

In [None]:
# Hugging Face dataset repository that hosts the validation tiles (.npy files).
hf_dataset_repo_id: str = (
    'nasa-cisto-data-science-group/'
    'modis_toa_cloud_reconstruction_validation'
)

In [None]:
# Fetch only the .npy files from the validation dataset repository and
# load one of them as the batch of tiles used in the rest of the notebook.
validation_tiles_dir = snapshot_download(
    repo_id=hf_dataset_repo_id, allow_patterns="*.npy", repo_type='dataset'
)
# Note: despite the variable name, this is a glob pattern, not a regex.
validation_tiles_regex = os.path.join(validation_tiles_dir, '*.npy')

# glob() returns matches in filesystem-dependent order; sort so the same file
# is selected on every machine, and fail with a clear message (instead of a
# bare StopIteration) when nothing matched.
validation_tiles_files = sorted(glob(validation_tiles_regex))
if not validation_tiles_files:
    raise FileNotFoundError(f"No .npy files found matching {validation_tiles_regex}")
validation_tiles_filename = validation_tiles_files[0]
validation_tiles = np.load(validation_tiles_filename)
validation_tiles.shape

# 3. Perform Inference

## 3.1 Apply Transform

In a future version, this transform step will be performed on the Triton server instead of in this notebook.

In [None]:
# Masked-image-modeling transform tailored to MODIS TOA inputs, built from the
# SatVision config loaded earlier.
transform = MimTransform(config)

In [None]:
# Run every tile through the MIM transform. Each call returns a pair whose
# first element is the transformed image and whose second element is the mask
# the transform generated for it.
images_and_masks = [transform(tile) for tile in validation_tiles]

# Separate images from masks and stack each into one batch array.
# Both remain NumPy arrays here; conversion to torch happens after inference.
images = np.stack([pair[0] for pair in images_and_masks])
masks = np.stack([pair[1] for pair in images_and_masks])

## 3.2 Run Inference

In [None]:
# Name under which SatVision-TOA is registered on the Triton server.
triton_model_name = "satvision_toa_model"

# The requested output is identical for every request, so build it once.
output_tensor = httpclient.InferRequestedOutput("output")

input_images = []
input_masks = []
output_images = []

# One request per tile; each request carries a batch of size 1.
for image, mask in tqdm(zip(images, masks), total=len(images)):

    # Add the leading batch axis and coerce dtypes to the datatypes declared
    # in the InferInput objects below. tritonclient's set_data_from_numpy
    # validates the array dtype against the declared datatype, so e.g. a
    # float64 image would be rejected for an "FP32" input.
    single_image = np.expand_dims(image, 0).astype(np.float32)
    single_mask = np.expand_dims(mask, 0).astype(bool)

    # Prepare input tensors
    image_tensor = httpclient.InferInput("image", single_image.shape, "FP32")
    image_tensor.set_data_from_numpy(single_image)

    mask_tensor = httpclient.InferInput("mask", single_mask.shape, "BOOL")
    mask_tensor.set_data_from_numpy(single_mask)

    # Perform inference
    response = client.infer(
        model_name=triton_model_name,
        inputs=[image_tensor, mask_tensor],
        outputs=[output_tensor],
    )

    # Keep the inputs and the reconstruction (size-1 axes squeezed away) as
    # torch tensors for the plotting step.
    input_images.append(torch.from_numpy(np.squeeze(single_image)))
    input_masks.append(torch.from_numpy(np.squeeze(single_mask)))
    output_images.append(torch.from_numpy(np.squeeze(response.as_numpy("output"))))

# output reconstructions
print(f"Reconstructed {len(output_images)} images")

In [None]:
# Display the dimensions (a torch.Size) of the first input tile.
input_images[0].size()

In [None]:
# Channel indices handed to plot_export_pdf for the RGB rendering.
# NOTE(review): presumably interpreted in R, G, B order — confirm against the
# plotting helper.
rgb_index = [0, 2, 1]
pdf_path = 'reconstructions.pdf'
plot_export_pdf(pdf_path, input_images, output_images, input_masks, rgb_index)