In [1]:
import cuvis_ai
import cuvis
import yaml
from EfficientAD_torch_model import EfficientAdModel
import torch
import numpy as np
from torchvision.transforms import v2
import cv2 as cv

In [2]:
def remove_prefix(s, prefix):
    if s.startswith(prefix):
        return s[len(prefix):]
    return s

In [3]:
CONF = 'example_train_config.yaml'
WEIGHTS = 'EAD_model_0.93_new.ckpt'

## Load Model

First we create a cuvis.ai Graph using some the config parameters from our training.

In [4]:
with open(CONF, 'r') as f:
    config = yaml.safe_load(f)

graph = cuvis_ai.pipeline.Graph('EfficientAD_graph')

modelNode = cuvis_ai.node.wrap.make_node(EfficientAdModel)(
    criterion=torch.nn.NLLLoss, teacher_out_channels=384, model_size='medium', in_channels=6)

## Create Pipeline

Using some transformation nodes, a pipeline for data pre- and postprocessing is created

In [5]:
CenterCrop = cuvis_ai.node.wrap.make_node(v2.CenterCrop)

Resize = cuvis_ai.node.wrap.make_node(v2.Resize)

Pad = cuvis_ai.node.wrap.make_node(v2.Pad)

graph = graph >> CenterCrop((1800, 4300)) >> Resize(
    (1050, 2300)) >> modelNode >> Pad(padding=150)

In [6]:
torch.load(WEIGHTS)

{'epoch': 196,
 'global_step': 37828,
 'pytorch-lightning_version': '2.4.0',
 'state_dict': OrderedDict([('model.teacher.conv1.weight',
               tensor([[[[ 3.8301e-02,  2.7897e-02, -4.1028e-02,  1.8517e-02],
                         [-7.3324e-02,  1.0422e-02,  4.2204e-05,  6.8185e-03],
                         [ 1.7180e-02, -1.2565e-03,  5.1365e-02,  5.4532e-02],
                         [-6.9995e-02, -5.3737e-02,  3.3014e-02,  4.2294e-02]],
               
                        [[-4.5705e-02, -7.5281e-02,  4.9149e-02,  7.0959e-02],
                         [-6.7040e-02, -3.6553e-02, -1.7474e-02,  6.6298e-02],
                         [ 2.5458e-02,  9.9867e-03, -6.7165e-02,  3.8104e-02],
                         [ 8.0557e-03, -2.9160e-02,  5.6200e-03, -3.5618e-03]],
               
                        [[-8.3991e-02,  3.1498e-02, -5.0874e-02,  6.2174e-02],
                         [-6.9594e-02, -7.4470e-02,  7.0632e-03,  3.7253e-02],
                         [ 3.0271e-02,  

## Load checkpoint

We can easily load the model we previously trained using `train.py` into our modelNode

In [8]:
checkpoint = torch.load(WEIGHTS, weights_only=True)

state_dict = {remove_prefix(
    k, 'model.'): v for k, v in checkpoint['state_dict'].items()}

modelNode.model.load_state_dict(state_dict)
modelNode.initialized = True

## Save cuvis.ai graph

When saving the graph to disk, the whole pipeline and model code will be saved as well.

In [9]:
graph.save_to_file('effAD_cuvis.ai')

Cant find class 'ValueError'
Cant find class 'self'
Cant find class 'super'
Cant find class 'self'
Cant find class 'self'
Cant find class 'super'
Cant find class 'super'
Cant find class 'x'
Cant find class 'self'
Cant find class 'any'
Cant find class 'p_dic'
Cant find class 'self'
Cant find class 'super'
Cant find class 'x'
Cant find class 'super'
Cant find class 'self'
Cant find class 'super'
Cant find class 'x'
Cant find class 'batch'
Cant find class 'value'
Project saved to effAD_cuvis.ai


## Load cuvis.ai graph from file

We can now load the zip file again into any project without having to worry about the model class being available there.

In [10]:
loaded = cuvis_ai.pipeline.Graph.load_from_file(
    'effAD_cuvis.ai.zip')

Re-initializing module because the following parameters were re-set: module__in_channels, module__model_size, module__teacher_out_channels.
Re-initializing criterion.
Re-initializing optimizer.


## Infer one cube

Using the forward function of our graph we can now easily infer any cubert SessionFile that has the same dimensions as the ones we trained our model with. 

In [12]:
MESU = '/home/nathaniel/Downloads/val/20250310_151530_frame_102_ok_nok_rdx_rwx.cu3s'

In [13]:
num_samples = 1
input_mesu = cuvis.SessionFile(str(MESU)).get_measurement(0)
input_data = input_mesu.data["cube"].array
normalized_data = (input_data - config["mean"]) / config["std"]
normalized_data = np.expand_dims(normalized_data, 0)
normalized_data = normalized_data.astype(np.float32)
output = loaded.forward(normalized_data)

--- 23.731627941131592 seconds ---


## Display the result


In [None]:
cv.imshow("input SWIR", cv.resize(cv.normalize(input_data[:, :, 3:], None, 0, 255, cv.NORM_MINMAX, 0), (1000, 500)))
cv.imshow("input RGB", cv.resize(cv.normalize(input_data[:, :, :3], None, 0, 255, cv.NORM_MINMAX, 0), (1000, 500)))
cv.imshow("output", cv.resize(cv.normalize(output[0], None, 255, 0, cv.NORM_MINMAX, 0), (1000, 500)))

cv.waitKey(-1)
cv.destroyAllWindows()