# Running predictions with snceg
In this notebook we will download and apply our pretrained Attention U-Net to an example neuromelanin MRI.

[Model repository](https://huggingface.co/lillepeder/SNceg-0.1).

## Import libraries

In [None]:
import sys
# add the parent directory (which contains snceg.py) to the module search path
sys.path.insert(0, '..')

In [None]:
from pathlib import Path
from huggingface_hub import snapshot_download
from snceg import *

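## Optional: resample the input manually

When `resample` is set, `run_one_sample` (used further down) resamples the input before prediction, so this section is optional. It shows how to resample the example image to the grid the model expects and save the result for inspection.

In [None]:
import nibabel as nib

# example input image (the same file used for prediction below)
fn = Path('../data/mean_NM.nii.gz')
nimg = nib.load(fn)

# grid the model expects: matrix size (voxels) and voxel size (mm)
target_shape = (288, 384, 48)
target_res = (.677, .677, 1.34)

In [None]:
# Resample; `conform` is expected to come from the `snceg` star import
# (the call matches the signature of nibabel.processing.conform)
resampled_nimg = conform(nimg, target_shape, target_res)

# Save output next to the input, e.g. ../data/mean_NM_RESAMPLED.nii.gz
out_name = fn.parent / fn.name.replace('.nii.gz', '_RESAMPLED.nii.gz')
resampled_nimg.to_filename(str(out_name))

print(out_name)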
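As a quick sanity check on the target grid, multiplying the matrix size by the voxel size gives the physical field of view the model sees. The helper `field_of_view` is a hypothetical name introduced for illustration, not part of `snceg`.

```python
# Target grid: matrix size (voxels) and voxel size (mm per voxel)
target_shape = (288, 384, 48)
target_res = (0.677, 0.677, 1.34)


def field_of_view(shape, voxel_size):
    """Hypothetical helper: physical extent in mm along each axis (voxels x mm/voxel)."""
    return tuple(round(n * v, 3) for n, v in zip(shape, voxel_size))


print(field_of_view(target_shape, target_res))  # (194.976, 259.968, 64.32)
```

Each input is therefore mapped into a box of roughly 195 × 260 × 64 mm; presumably anything outside it is cropped and smaller images are padded.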

## Download the model from Hugging Face
The model and its preprocessing settings are downloaded to `../models`.

In [None]:
snapshot_download(repo_id="lillepeder/SNceg-0.1", local_dir='../models')
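Before moving on, you might check that the download contains the two files loaded later in this notebook (`SNceg-0.1.pkl` and `vars_SNceg-0.1.pkl`, names taken from the paths used below). `missing_model_files` is a hypothetical helper, not part of `snceg` or `huggingface_hub`.

```python
from pathlib import Path

# File names used later in this notebook (by load_variables and load_learner)
EXPECTED_FILES = ("SNceg-0.1.pkl", "vars_SNceg-0.1.pkl")


def missing_model_files(model_dir, expected=EXPECTED_FILES):
    """Hypothetical helper: return the expected file names not found in model_dir."""
    model_dir = Path(model_dir)
    return [name for name in expected if not (model_dir / name).is_file()]


missing = missing_model_files("../models")
print("Missing files:", missing if missing else "none")
```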

## A quick note on resampling

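The model was trained primarily on anisotropic images with a voxel size of $0.677 \times 0.677 \times 1.340\ \text{mm}^3$. The input image is therefore resampled to this resolution before prediction, and the prediction is resampled back to the space of the original image.

These preprocessing settings (`size`, `reorder` and `resample`) are stored with the model and loaded with `load_variables`.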
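Resampling the prediction back is a label-map operation, so nearest-neighbour interpolation is the usual choice: it copies existing labels instead of blending them into fractional values. The sketch below is a simplified illustration under stated assumptions (axis-aligned grids that share the centre of the first voxel, no rotation). It is not how `snceg` implements the step, and `nearest_resample` is a hypothetical name.

```python
import numpy as np


def nearest_resample(vol, in_vox, out_vox, out_shape):
    """Nearest-neighbour resampling between two axis-aligned grids.

    Simplifying assumptions: both grids share the centre of voxel (0, 0, 0)
    and have no rotation; a real resampler would use the full affines.
    """
    index_arrays = []
    for axis in range(vol.ndim):
        # physical position (mm) of each output voxel centre along this axis
        pos = np.arange(out_shape[axis]) * out_vox[axis]
        # index of the nearest input voxel, clipped to the input volume
        src = np.rint(pos / in_vox[axis]).astype(int)
        index_arrays.append(np.clip(src, 0, vol.shape[axis] - 1))
    # open-mesh indexing picks vol[i, j, k] for every combination of indices
    return vol[np.ix_(*index_arrays)]


# Toy label map on a coarse 2 x 2 x 2 grid (1 mm voxels) ...
labels = np.array([[[0, 1], [1, 0]], [[1, 0], [0, 1]]], dtype=np.uint8)
# ... resampled onto a finer 4 x 4 x 4 grid (0.5 mm voxels)
fine = nearest_resample(labels, (1.0, 1.0, 1.0), (0.5, 0.5, 0.5), (4, 4, 4))
print(fine.shape, np.unique(fine))  # (4, 4, 4) [0 1]
```

The output contains only the labels present in the input, which is the property that matters for a segmentation; linear interpolation would instead produce values between 0 and 1 along the mask boundary.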

In [None]:
size, reorder, resample = load_variables(pkl_fn='../models/vars_SNceg-0.1.pkl')

In [None]:
print(size, reorder, resample)
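Storing preprocessing settings next to the model weights keeps training and inference consistent. A minimal sketch of that idea with Python's `pickle`, assuming the settings are a plain `(size, reorder, resample)` triple; the helper names `save_settings`/`load_settings` and the example values are hypothetical and not necessarily what `vars_SNceg-0.1.pkl` contains.

```python
import pickle
import tempfile
from pathlib import Path


def save_settings(path, size, reorder, resample):
    """Hypothetical helper: pickle the preprocessing settings as a list."""
    with open(path, "wb") as f:
        pickle.dump([size, reorder, resample], f)


def load_settings(path):
    """Hypothetical helper: read back (size, reorder, resample)."""
    with open(path, "rb") as f:
        size, reorder, resample = pickle.load(f)
    return size, reorder, resample


# Round trip with illustrative values (not the model's actual settings)
with tempfile.TemporaryDirectory() as tmp:
    settings_fn = Path(tmp) / "vars_example.pkl"
    save_settings(settings_fn, [288, 384, 48], True, [0.677, 0.677, 1.34])
    print(load_settings(settings_fn))
```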

In the rare case that your data is already close to the given resolution, you can override the resampling by uncommenting the line below.

In [None]:
# resample = None
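Whether your data counts as "close" is a judgment call. One way to decide is to compare the image's voxel size with the training resolution; the sketch below uses a 5% relative tolerance, which is an arbitrary assumption, and `close_to_training_res` is a hypothetical helper. With nibabel, the voxel size of a NIfTI file can be read with `nib.load(path).header.get_zooms()[:3]`.

```python
# Training voxel size in mm (from the note above)
TRAIN_RES = (0.677, 0.677, 1.34)


def close_to_training_res(voxel_size, target=TRAIN_RES, rel_tol=0.05):
    """Hypothetical helper: True if every axis is within rel_tol of the target."""
    return all(abs(v - t) <= rel_tol * t for v, t in zip(voxel_size, target))


# In the notebook you could use, e.g.:
#   zooms = nib.load('../data/mean_NM.nii.gz').header.get_zooms()[:3]
#   if close_to_training_res(zooms):
#       resample = None
print(close_to_training_res((0.68, 0.68, 1.35)))  # True
print(close_to_training_res((1.0, 1.0, 1.0)))     # False
```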

## Load the model into memory
fastai's `load_learner` loads the exported model in a single call:

In [None]:
from fastai.learner import load_learner

model_fn = '../models/SNceg-0.1.pkl'
learner = load_learner(model_fn)

## Predict
With the learner loaded, we can apply it to the example input image.

In [None]:
# the input image
input_fn = '../data/mean_NM.nii.gz'

# path of the output
output_fn = '../data/SN_prediction.nii.gz'

In [None]:
pred = run_one_sample(fn=input_fn,
                      learner=learner,
                      reorder=reorder,
                      resample=resample,
                      pred_fn=output_fn)
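To segment several images, the same call can be wrapped in a loop. This is a sketch under assumptions: every input is a NIfTI file in one folder, outputs are written next to it with a `_SN_prediction` suffix (a naming scheme chosen here, not required by `snceg`), and `prediction_path`/`predict_folder` are hypothetical helpers. The prediction itself is passed in as a function, so the loop does not depend on `snceg`.

```python
from pathlib import Path

SUFFIX = "_SN_prediction"


def prediction_path(input_path, suffix=SUFFIX):
    """Hypothetical helper: '.../x.nii.gz' -> '.../x_SN_prediction.nii.gz'."""
    p = Path(input_path)
    for ext in (".nii.gz", ".nii"):
        if p.name.endswith(ext):
            return p.with_name(p.name[: -len(ext)] + suffix + ext)
    raise ValueError(f"not a NIfTI file: {p}")


def predict_folder(folder, predict_one, suffix=SUFFIX):
    """Hypothetical helper: call predict_one(in_path, out_path) for each input."""
    outputs = []
    for in_path in sorted(Path(folder).glob("*.nii*")):
        if suffix in in_path.name:  # skip outputs from an earlier run
            continue
        out_path = prediction_path(in_path, suffix)
        predict_one(in_path, out_path)
        outputs.append(out_path)
    return outputs


# Usage in this notebook (not run here; 'path/to/images' is a placeholder
# for a folder that contains only input images):
# predict_folder('path/to/images', lambda i, o: run_one_sample(
#     fn=str(i), learner=learner, reorder=reorder,
#     resample=resample, pred_fn=str(o)))
```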

## Visualize the result

In [None]:
from torchio import Subject, ScalarImage, LabelMap

subject = Subject(image=ScalarImage(input_fn), mask=LabelMap(output_fn))

subject.plot(figsize=(20,8),
             percentiles=(0.5, 99.9),
             indices=(95,110,118),
            )
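The slice indices `(95, 110, 118)` are specific to this example image. For other images, one option is to centre the plot on the predicted mask. The sketch below assumes the mask can be read as a NumPy array, e.g. via `subject['mask'].data[0].numpy()` (torchio stores image data as a channels-first tensor), and `mask_center_indices` is a hypothetical helper.

```python
import numpy as np


def mask_center_indices(mask):
    """Hypothetical helper: voxel indices of the centroid of the nonzero voxels."""
    coords = np.argwhere(mask > 0)  # one row of (i, j, k) per labelled voxel
    if coords.size == 0:
        raise ValueError("mask is empty; nothing to centre on")
    return tuple(int(round(c)) for c in coords.mean(axis=0))


# Toy example: a short line of labelled voxels
toy = np.zeros((5, 5, 5), dtype=np.uint8)
toy[1:4, 2, 3] = 1
print(mask_center_indices(toy))  # (2, 2, 3)

# In the notebook (assumed usage):
# indices = mask_center_indices(subject['mask'].data[0].numpy())
# subject.plot(figsize=(20, 8), percentiles=(0.5, 99.9), indices=indices)
```

Because the substantia nigra is bilateral, the centroid along the left-right axis will lie near the midline, between the two nuclei, so the sagittal view may miss the mask; you may want to shift that one index by hand.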

## View in an external viewer

Alternatively, open the image and the prediction in a viewer of your choice, e.g. Freeview:

In [None]:
!freeview {input_fn} {output_fn}:colormap=lut
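Outside a notebook, the `!` shell escape is not available. A plain-Python equivalent using `subprocess` might look like the following; it passes both files with Freeview's `-v` (volume) option and only launches Freeview if it is found on the `PATH`. `freeview_command` is a hypothetical helper.

```python
import shutil
import subprocess


def freeview_command(image_fn, label_fn):
    """Hypothetical helper: build the Freeview call, showing labels with a lookup table."""
    return ["freeview", "-v", str(image_fn), f"{label_fn}:colormap=lut"]


input_fn = "../data/mean_NM.nii.gz"
output_fn = "../data/SN_prediction.nii.gz"

cmd = freeview_command(input_fn, output_fn)
if shutil.which("freeview"):
    subprocess.run(cmd, check=False)  # blocks until the Freeview window is closed
else:
    print("Freeview not found on PATH; command would be:", " ".join(cmd))
```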