In [1]:
import cuvis_ai
import cuvis
import yaml
from UNet_2D import FreshTwin2DUNet
import torch
import numpy as np
from torchvision.transforms import v2
import cv2 as cv
from pathlib import Path

In [2]:
def remove_prefix(s, prefix):
    """Strip a leading ``prefix`` from ``s``.

    Returns ``s`` unchanged when it does not start with ``prefix``.
    Used to map checkpoint keys such as ``'model.conv.weight'`` onto the
    bare module's parameter names.
    """
    has_prefix = s.startswith(prefix)
    return s[len(prefix):] if has_prefix else s

In [3]:
# Training configuration (YAML) and the checkpoint to load into the model node.
CONF = "UNet_train_config.yaml"
WEIGHTS = "with_pca.ckpt"

# cuvis ProcessingContext, created lazily the first time a session file without a cube is met.
proc = None

## Load Model

First, we create a cuvis.ai `Graph` using some of the config parameters from our training.

In [4]:
# Read the training configuration so the pipeline matches the trained model.
with open(CONF, 'r') as f:
    config = yaml.safe_load(f)

graph = cuvis_ai.pipeline.Graph('Strawberry_graph')

# Session file whose first cube is used to fit the PCA.
# NOTE: absolute local path — adjust for your machine.
PCA_FIT_SESSION = 'D:/FreshTwin/swir/Strawberry_133_5_22_04_25.cu3s'

# NOTE(review): this PCA node is fitted but never added to `graph`; the model's PCA is
# later restored from the checkpoint (`checkpoint['pca']`). Confirm whether this fit is needed.
Pca = cuvis_ai.preprocessor.PCA(config["pca_channels"])
sess = cuvis.SessionFile(PCA_FIT_SESSION)
input_mesu = sess.get_measurement(0)

if "cube" not in input_mesu.data:
    if proc is None:
        # create processing context only once if there are session files without cubes
        proc = cuvis.ProcessingContext(sess)
        # Reflectance mode is only enabled when both reference files exist;
        # otherwise the context keeps its default processing mode.
        if Path(config["white_path"]).exists() and Path(config["dark_path"]).exists():
            proc.set_reference(cuvis.SessionFile(config["white_path"]).get_measurement(0), cuvis.ReferenceType.White)
            proc.set_reference(cuvis.SessionFile(config["dark_path"]).get_measurement(0), cuvis.ReferenceType.Dark)
            proc.processing_mode = cuvis.ProcessingMode.Reflectance
    # Keep the processed measurement: the cube is read from `input_mesu` below.
    input_mesu = proc.apply(input_mesu)
if "cube" not in input_mesu.data:
    raise RuntimeError(f"No 'cube' in measurement from {PCA_FIT_SESSION} even after processing.")
input_data = input_mesu.data["cube"].array
# float32 matches the dtype used for inference and cannot overflow the way float16
# does (float16 max is 65504, below the uint16 range).
Pca.fit(input_data.astype(np.float32))
modelNode = cuvis_ai.node.wrap.make_node(FreshTwin2DUNet)(
    criterion=torch.nn.NLLLoss, num_classes=config["num_classes"], in_channels=config["pca_channels"])

## Create Pipeline

Using transformation nodes, a pipeline for data pre- and postprocessing is created. In this notebook, only the model node is appended to the graph.

In [5]:
# Append the model node to the graph and rebind `graph` to the result.
# NOTE(review): re-running this cell without re-running the graph-construction cell applies
# `>> modelNode` to an already-extended graph — presumably not intended; confirm.
graph = graph >> modelNode

## Load checkpoint

We can easily load the model we previously trained with `train.py` into our `modelNode`.

In [6]:
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects (presumably required
# because the checkpoint also stores a non-tensor 'pca' entry). Only load trusted checkpoints.
checkpoint = torch.load(WEIGHTS, weights_only=False)

# Strip a leading 'model.' from every key so they match the wrapped module's parameter names.
raw_state = checkpoint['state_dict']
state_dict = {}
for key, tensor in raw_state.items():
    state_dict[remove_prefix(key, 'model.')] = tensor

modelNode.model.load_state_dict(state_dict)
modelNode.initialized = True
# Reuse the PCA stored alongside the weights.
modelNode.model.pca = checkpoint['pca']

## Save cuvis.ai graph

When saving the graph to disk, the whole pipeline and model code will be saved as well.

In [7]:
# Save the graph (pipeline and model code) to disk.
# NOTE(review): the load cell reads 'Strawberry_Unet.ai.zip', so save_to_file appears to
# append '.zip' to this name — keep the two names in sync if you rename it.
graph.save_to_file('Strawberry_Unet.ai')

Cant find class 'pca_image'
Cant find class 'super'
Cant find class 'x'
Cant find class 'self'
Cant find class 'c11'
Project saved to Strawberry_Unet.ai


## Load cuvis.ai graph from file

We can now load the zip file again into any project without having to worry about the model class being available there.

In [8]:
# Load the saved graph archive into a new Graph object.
# NOTE(review): `loaded` is not used by the inference cell, which still calls `graph.forward`.
# Switch that call to `loaded.forward` to exercise the round-trip, after verifying that the
# checkpoint weights survive the "Re-initializing module" step logged on load.
loaded = cuvis_ai.pipeline.Graph.load_from_file('Strawberry_Unet.ai.zip')

Re-initializing module because the following parameters were re-set: module__in_channels, module__num_classes.
Re-initializing criterion.
Re-initializing optimizer.


## Infer one cube

Using the `forward` function of our graph, we can now easily run inference on any Cubert SessionFile whose cubes have the same dimensions as the ones we trained our model with.

In [9]:
# Session file to run inference on (absolute local path). TODO: change mesu
MESU = "D:/FreshTwin/swir/Strawberry_133_5_22_04_25.cu3s"

In [17]:
# Open the selected session file and take its first measurement.
sess = cuvis.SessionFile(str(MESU))
input_mesu = sess.get_measurement(0)

if "cube" not in input_mesu.data:
    if proc is None:
        # create processing context only once if there are session files without cubes
        proc = cuvis.ProcessingContext(sess)
        # Reflectance mode is only enabled when both reference files exist;
        # otherwise the context keeps its default processing mode.
        if Path(config["white_path"]).exists() and Path(config["dark_path"]).exists():
            proc.set_reference(cuvis.SessionFile(config["white_path"]).get_measurement(0), cuvis.ReferenceType.White)
            proc.set_reference(cuvis.SessionFile(config["dark_path"]).get_measurement(0), cuvis.ReferenceType.Dark)
            proc.processing_mode = cuvis.ProcessingMode.Reflectance
    # Keep the processed measurement: the cube is read from `input_mesu` below.
    input_mesu = proc.apply(input_mesu)
if "cube" not in input_mesu.data:
    raise RuntimeError(f"No 'cube' in measurement from {MESU} even after processing.")
input_data = input_mesu.data["cube"].array
# Cast to float32 and add a leading batch axis of size 1 before running the graph.
# NOTE(review): this runs `graph`, not the `loaded` graph from the previous section.
output = graph.forward(np.expand_dims(input_data.astype(np.float32), 0))

--- 0.03195905685424805 seconds ---


## Display the result

In [19]:
# Show the input as a false-colour image built from the configured cube channels.
rgb_channels = config["cube_rgb_channels"]
input_rgb = cv.normalize(input_data[:, :, rgb_channels], None, 0, 255, cv.NORM_MINMAX, cv.CV_8U)
cv.imshow("input RGB", cv.cvtColor(input_rgb, cv.COLOR_BGR2RGB))

# Predicted class per pixel = argmax over axis 2 of `output`.
prediction = np.argmax(output, axis=2)
# Build a 3-channel uint8 BGR mask directly so the display works for any number of
# classes (np.zeros_like(output) is only a valid colour image with exactly 3 classes).
output_bgr = np.zeros((*prediction.shape, 3), dtype=np.uint8)
output_bgr[prediction == 1] = (0, 0, 255)  # Strawberry is shown in red
output_bgr[prediction == 2] = (255, 0, 0)  # Bruise is shown in blue
cv.imshow("output", output_bgr)

# Block until a key is pressed, then close the display windows.
cv.waitKey(-1)
cv.destroyAllWindows()

In [16]:
# Cleanup: close any OpenCV windows still open, e.g. if the display cell was interrupted
# before reaching its own destroyAllWindows call.
cv.destroyAllWindows()
