# Real-Time CNN Susceptibility Artifact Detection
Detects susceptibility artifacts in real time using a trained nnU-Net model.
The Siemens Access-i interface is used to connect to the MRI scanner and receive its images.

## Initialize CNN

In [17]:
"""
https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/inference/examples.py
"""
# %matplotlib inline
from skimage import measure
import torch
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
import cv2
import numpy as np
from scipy.interpolate import splprep, splev

"""
CNN model defined as global for easy access
"""
# Lazily-initialized nnU-Net predictor, created on the first image by image_callback_cnn.
# The name must be bound here: a module-level `global cnn_model` statement does not
# create it, so the `cnn_model is None` check would raise NameError on the first image.
cnn_model = None

def prepare_cnn(path_to_model_directory="MODEL", checkpoint_name="checkpoint_final.pth", folds=(4,)):
    """
    Build an nnU-Net predictor on CUDA device 0 and load trained weights into it.

    :param path_to_model_directory: trained-model folder; must contain dataset.json and a
        sub-folder per fold, e.g. fold_4/checkpoint_final.pth
    :param checkpoint_name: checkpoint file loaded from each fold folder, e.g. checkpoint_final.pth
    :param folds: tuple of fold indices to load, e.g. (4,)
    :return: the initialized nnUNetPredictor, ready to predict
    """
    print(f"GPU: {torch.cuda.get_device_name(0)}")

    predictor_options = {
        # 1 = sliding-window tiles do not overlap (presumably chosen for speed in real time).
        "tile_step_size": 1,
        "use_gaussian": True,
        "use_mirroring": True,
        "perform_everything_on_gpu": True,
        "device": torch.device('cuda', 0),
        # Keep the per-frame inference loop quiet.
        "verbose": False,
        "verbose_preprocessing": False,
        "allow_tqdm": False,
    }
    predictor = nnUNetPredictor(**predictor_options)

    predictor.initialize_from_trained_model_folder(
        path_to_model_directory,
        use_folds=folds,
        checkpoint_name=checkpoint_name,
    )
    return predictor

def image_callback_cnn(image_data):
    """
    Run the nnU-Net model on one image received from the Access-i websocket.

    The predictor is created lazily on the first call via prepare_cnn() and cached in
    the module-level ``cnn_model``.

    :param image_data: single-channel 2-D image array passed in by the websocket callback.
        NOTE(review): the session requests "raw16bit" images, but the input is scaled by
        1/255 as if it were 8-bit; confirm the intensity range the model was trained on.
    :return: tuple ``(image_display, output_display)`` of 512x512 uint8 arrays: the resized
        input image and the predicted segmentation scaled by 255.
    """
    global cnn_model
    # globals().get() tolerates the name never having been bound at module level.
    if globals().get("cnn_model") is None:
        cnn_model = prepare_cnn()

    image = cv2.resize(image_data, (512, 512)).astype(np.float32) / 255.0
    # nnU-Net's numpy API takes (channels, *spatial); a single 2-D slice is (1, 1, H, W).
    cnn_input = image.reshape(1, 1, image.shape[0], image.shape[1])
    # Placeholder spacing for the unused slice axis, as in nnU-Net's 2-D inference examples.
    props = {'spacing': (999, 1, 1)}

    # Only the segmentation is used, so skip computing and returning class probabilities.
    output = cnn_model.predict_single_npy_array(
        cnn_input,
        props,
        segmentation_previous_stage=None,
        output_file_truncated=None,
        save_or_return_probabilities=False,
    )

    # Assumes binary labels (0/1), so the mask becomes 0/255.
    output_display = (output * 255).astype(np.uint8).reshape(image.shape)
    # Clip first so out-of-range intensities saturate instead of wrapping around in uint8.
    image_display = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    return image_display, output_display
     
def draw_spline(image_data, threshold=128, num_points=1000):
    """
    Fit an interpolating parametric spline through the centroids of the bright regions of an image.

    :param image_data: 2-D array (e.g. a 0/255 segmentation mask); pixels > threshold are foreground
    :param threshold: foreground threshold (default 128)
    :param num_points: number of samples evaluated along the spline (default 1000)
    :return: [] when fewer than two regions are found; otherwise the list ``[rows, cols]``
        of coordinate arrays produced by splev, each of length ``num_points``
    """
    labels = measure.label(image_data > threshold)
    # Centroids are (row, col). They come out in label order, not ordered along a path.
    centers = np.array([prop.centroid for prop in measure.regionprops(labels)])
    if len(centers) < 2:
        return []

    # splprep needs more points than the spline degree (m > k). Cap k so that 2 or 3
    # regions give a linear or quadratic spline instead of raising an error.
    degree = min(3, len(centers) - 1)
    tck, _ = splprep(centers.T, s=0, k=degree)
    return splev(np.linspace(0, 1, num_points), tck)

## Connect to the Access-i websocket

In [19]:
import siemens_access_library as access_library
import asyncio
import threading

ACCESS_HOST = "127.0.0.1"  # address of the Access-i server
Access = access_library.Access(ACCESS_HOST)

active_check = Access.get_is_active()
# None means no answer; a falsy .value means the server answered but is not active.
if active_check is None or not active_check.value:
    raise SystemExit("Server not active")
print(f"Active: {active_check.value}")

version = Access.get_version()
print(f"Version: {version.value}")

register = Access.register()
print(f"Register: {register.result.success}, Session: {register.sessionId}")
# Every later call uses the session id, so stop here if registration failed.
if not register.result.success:
    raise SystemExit("Registration with Access-i failed")

# Ask for raw 16-bit pixel data on the websocket image stream.
image_format = Access.set_image_format(register.sessionId, "raw16bit")

"""
Initialize websocket loop for image service
"""
def run_websocket_in_thread(session_id, callback_function):
    """
    Run ``Access.connect_websocket(session_id, callback_function)`` to completion on a
    fresh asyncio event loop that belongs to the calling thread.

    This is intended as a ``threading.Thread`` target. A non-main thread has no default
    event loop, so one is created and installed here, then shut down when the coroutine ends.

    :param session_id: Access-i session id returned by ``Access.register()``
    :param callback_function: handed to ``connect_websocket``; presumably invoked with each
        received image (e.g. image_callback_cnn)
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(Access.connect_websocket(session_id, callback_function))
    finally:
        # Release the loop's resources whether the websocket coroutine finished or failed.
        loop.run_until_complete(loop.shutdown_asyncgens())
        asyncio.set_event_loop(None)
        loop.close()

# Run the websocket client on a background thread so this cell returns while images keep
# arriving. daemon=True lets the interpreter/kernel exit without waiting on a websocket
# that may stay open indefinitely.
thread = threading.Thread(
    target=run_websocket_in_thread,
    args=(register.sessionId, image_callback_cnn),
    name="access-i-websocket",
    daemon=True,
)
thread.start()

# Connect the image service to the websocket started above.
# NOTE(review): thread.start() does not wait for connect_websocket to be established
# before this call runs; confirm the server tolerates this ordering or add a readiness signal.
image_service = Access.connect_image_service_to_default_web_socket(register.sessionId)
print(f"ImageServiceConnection: {image_service.result.success}")


Active: True
Version: 2.6
Register: True, Session: df8bcfe4-38c3-4459-b96d-422eaae6ca46
ImageServiceConnection: True
