## Install the latest compiled Octave build

In [1]:
%%capture
! apt-get update
# Install Octave's build/runtime dependencies.
# -y answers apt-get's confirmation prompt non-interactively; without it the
# install stops at "Do you want to continue? [Y/n]", and %%capture hides that.
# NOTE(review): some names below (e.g. texlive-generic-recommended,
# libqscintilla2-dev) may not exist on newer Ubuntu images. If any package
# cannot be located, apt-get installs nothing. Confirm against the current
# Colab image.
! apt-get install -y gcc g++ gfortran make libblas-dev liblapack-dev libpcre3-dev \
libarpack2-dev libcurl4-gnutls-dev epstool libfftw3-dev transfig libfltk1.3-dev \
libfontconfig1-dev libfreetype6-dev libgl2ps-dev libglpk-dev libreadline-dev \
gnuplot-x11 libgraphicsmagick++1-dev libhdf5-serial-dev openjdk-8-jdk \
libsndfile1-dev llvm-dev lpr texinfo libgl1-mesa-dev libosmesa6-dev pstoedit \
portaudio19-dev libqhull-dev libqrupdate-dev libqscintilla2-dev \
libsuitesparse-dev texlive texlive-generic-recommended libxft-dev zlib1g-dev \
autoconf automake bison flex gperf gzip icoutils librsvg2-bin libtool perl \
rsync tar qtbase5-dev qttools5-dev qttools5-dev-tools libqscintilla2-qt5-dev \
wget git libsundials-dev gnuplot x11-apps

In [2]:
import json
import os
import subprocess
import tarfile
import urllib.request

In [3]:
# Download latest compiled octave package
def get_octave(root_path: str) -> str:
    """Download and unpack the latest pre-built Octave release from cerr/octave-colab.

    The working directory is changed to ``root_path``. That chdir is kept
    because later relative-path commands, e.g. the CERR ``git clone``, may
    run from there. The first asset of the latest GitHub release is
    downloaded into it and extracted with ``tar``.

    Parameters
    ----------
    root_path : str
        Directory to download the archive into and extract it under.

    Returns
    -------
    str
        ``os.path.join(root_path, <top-level folder of the archive>)``.

    Raises
    ------
    RuntimeError
        If the release has no assets or the archive has no named members.
    subprocess.CalledProcessError
        If ``tar`` fails to extract the archive.
    """
    os.chdir(root_path)
    release_api = "https://api.github.com/repos/cerr/octave-colab/releases/latest"
    with urllib.request.urlopen(release_api, timeout=60) as response:
        release = json.loads(response.read().decode())

    assets = release.get('assets') or []
    if not assets:
        raise RuntimeError('The latest cerr/octave-colab release has no downloadable assets')
    fname = assets[0]['name']
    urllib.request.urlretrieve(assets[0]['browser_download_url'], fname)

    # Argument list: no shell interpolation of the asset name. `-f` always
    # consumes the next argument as the file name. check=True surfaces a
    # failed extraction instead of silently continuing.
    subprocess.run(['tar', '-xf', fname], check=True)

    # Top-level folder = first path component of the first named member.
    # Splitting on '/' works whether or not the archive contains explicit
    # directory entries. The old `[:-1]` trim assumed the first entry was
    # 'dir/' and truncated a file path otherwise.
    with tarfile.open(fname) as archive:
        for member in archive:
            parts = [p for p in member.name.split('/') if p not in ('', '.')]
            if parts:
                octave_folder = parts[0]
                break
        else:
            raise RuntimeError(f'Archive {fname} contains no named members')

    return os.path.join(root_path, octave_folder)

# Set path to the Octave executable
octave_path = get_octave('/content')
octave_bin_dir = os.path.join(octave_path, 'bin')
octave_cli = os.path.join(octave_bin_dir, 'octave-cli')
# Fail here with a clear message, not later inside oct2py, if the
# unpacked archive does not have the expected <octave_path>/bin/octave-cli.
if not os.path.isfile(octave_cli):
    raise FileNotFoundError(f'Octave CLI not found at {octave_cli}')
# oct2py reads OCTAVE_EXECUTABLE. Prepending bin/ to PATH makes the other
# Octave binaries resolvable as well.
os.environ['OCTAVE_EXECUTABLE'] = octave_cli
os.environ['PATH'] = octave_bin_dir + os.pathsep + os.environ['PATH']

## Install Octave-Python bridge

In [4]:
%%capture
# %pip (unlike !pip3) installs into the environment of the running kernel.
# oct2py is pinned. It provides the %%octave / %octave_pull magics used below.
%pip install octave_kernel
%pip install oct2py==5.3.0

%load_ext oct2py.ipython

## Download CERR

In [5]:
%%capture
# Clone into an explicit path so addpath(genpath('/content/CERR')) below does
# not depend on the current working directory. --depth 1 skips history.
!git clone --depth 1 --single-branch --branch octave_dev https://github.com/cerr/CERR.git /content/CERR

## Register scans using structure priors

**Add CERR to path**

In [6]:
%%octave

% Load the Octave packages this notebook needs
pkg('load', 'statistics');

% Put CERR and every subfolder of it on the Octave search path
addpath(genpath('/content/CERR'));


**Specify I/O paths**


In [7]:
%%octave
#Plastimatch command file (passed to register_scans)
plmSettingsFile = 'Path/to/PlastimatchCommandFile.txt'; #Replace with plm cmd file path
#Temp dir (passed to loadPlanC / register_scans)
tmpDirPath = '/tmp';
#Registration output dir
registeredDir = '/content/Out';
#Create the output dir only if it is missing; the semicolon keeps mkdir's
#status value from being printed, so re-running the cell stays quiet.
if ~exist(registeredDir, 'dir')
    mkdir(registeredDir);
end

**Define registration settings**

In [None]:
%%octave
#These settings are arguments of the Octave register_scans call, so they must
#be created in the Octave session. Without %%octave these assignments would
#run in Python and be invisible to Octave.
registration_tool = 'PLASTIMATCH';
movMask3M = [];
threshold_bone = -800;
inBspFile = '';
outBspFile = '';
algorithm = 'BSPLINE PLASTIMATCH';

**Register scans using Plastimatch**

In [None]:
%%octave

#Load fixed (base) scan
baseScanFile = 'Path/to/basePlanC.mat'; #Replace with base planC path
planC = loadPlanC(baseScanFile,tmpDirPath);
planC = updatePlanFields(planC);
planC = quality_assure_planC(baseScanFile,planC);
indexS = planC{end};

#Load moving scan
movScanFile = 'Path/to/movingPlanC.mat'; #Replace with moving planC path
planD = loadPlanC(movScanFile,tmpDirPath);
planD = updatePlanFields(planD);
planD = quality_assure_planC(movScanFile,planD);
indexSD = planD{end};
movMask3M = [];
movScanNum  = 1;

#Define initial translation based on centroids (calcIsocenter 'COM') of
#landmark structures. A landmark is used only if its name matches exactly one
#structure associated with the scan. Otherwise its row stays NaN and is
#ignored by nanmean below. The same rule applies to base and moving scans.

# On base scan:
baseMask3M = [];
baseScanNum = 1;
baseLandmarkStrC = {'PERICARDIUM'}; #Replace with landmark structure name
baseLandmarkListM = [];
xyzBaseM = nan(length(baseLandmarkStrC),3);
strBaseC = {planC{indexS.structures}.structureName};
for iStr = 1:length(baseLandmarkStrC)
    baseLandmarkStr = baseLandmarkStrC{iStr};
    strIndV = getMatchingIndex(baseLandmarkStr,strBaseC,'exact');
    assocScanNumV = getStructureAssociatedScan(strIndV,planC);
    baseLandmarkInd = strIndV(assocScanNumV == baseScanNum);
    if length(baseLandmarkInd) ~= 1
        continue;
    end
    [baseX,baseY,baseZ] = calcIsocenter(baseLandmarkInd,'COM',planC);
    xyzBaseM(iStr,:) = [baseX,baseY,baseZ];
end

# On moving scan:
strMovC = {planD{indexSD.structures}.structureName};
landmarkListM = [];
movLandmarkStrC = {'PERICARDIUM'}; #Replace with landmark structure name
xyzMovM = nan(length(movLandmarkStrC),3);
for iStr = 1:length(movLandmarkStrC)
    movLandmarkStr = movLandmarkStrC{iStr};
    strIndV = getMatchingIndex(movLandmarkStr,strMovC,'exact');
    assocScanNumV = getStructureAssociatedScan(strIndV,planD);
    movLandmarkInd = strIndV(assocScanNumV == movScanNum);
    if length(movLandmarkInd) ~= 1
        continue;
    end
    [movX,movY,movZ] = calcIsocenter(movLandmarkInd,'COM',planD);
    xyzMovM(iStr,:) = [movX,movY,movZ];
end

#Mean moving-minus-base centroid offset over landmarks found on both scans
#(landmark lists are paired by position)
initialTranslationXyzM = xyzMovM - xyzBaseM;
initialTranslationXyzV = nanmean(initialTranslationXyzM,1);
if any(isnan(initialTranslationXyzV))
    error('No landmark structure was matched exactly once on both the base and moving scans.');
end
#x is used as-is; y and z are negated. Presumably this converts CERR's y/z
#axis convention to the one register_scans expects -- TODO confirm.
initialTranslationXyzV(2:3) = -initialTranslationXyzV(2:3);

In [None]:
%%octave
#Call registration wrapper
#Every argument below must already exist in the Octave workspace.
#The result replaces planC. The next display cell reads the registered
#scan as the last scan of planC, which presumably register_scans appends
#-- TODO confirm.
planC = register_scans(planC, baseScanNum, planD, movScanNum,...
    algorithm, registration_tool, tmpDirPath, baseMask3M,...
    movMask3M, threshold_bone, plmSettingsFile, inBspFile,...
    outBspFile, landmarkListM, initialTranslationXyzV);

## Display result

In [None]:
%%octave
% Base scan and registered scan as single-precision arrays, each shifted by
% its own CTOffset. The registered scan is taken to be the last scan in planC.
indexS = planC{end};
scanNum = length(planC{indexS.scan});   % index of the last (registered) scan

ctOffset = planC{indexS.scan}(baseScanNum).scanInfo(1).CTOffset;
origScanArray = single(getScanArray(baseScanNum, planC)) - ctOffset;

ctOffset = planC{indexS.scan}(scanNum).scanInfo(1).CTOffset;
regScanArray = single(getScanArray(scanNum, planC)) - ctOffset;

In [None]:
# Copy the two Octave arrays into the Python namespace for plotting below
%octave_pull origScanArray regScanArray

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display
from IPython.display import clear_output
import ipywidgets as widgets

# Unit-spaced 512-pixel grid; its min/max give the imshow extent.
dx = dy = 1
x = np.arange(0, 512, dx)
y = np.arange(0, 512, dy)
extent = (x.min(), x.max(), y.min(), y.max())

clear_output(wait=True)

def window_image(image, window_center, window_width):
    """Return a copy of ``image`` clamped to a display window.

    The window runs from ``window_center - window_width // 2`` to
    ``window_center + window_width // 2``. Values below the lower edge are
    set to it first, then values above the upper edge are set to that edge.
    The input array is left unmodified.
    """
    half_width = window_width // 2
    low = window_center - half_width
    high = window_center + half_width

    clamped = image.copy()
    clamped[clamped < low] = low      # saturate the dark side first ...
    clamped[clamped > high] = high    # ... then the bright side
    return clamped

def show_axial_slice(slcNum):
    """Show slice ``slcNum`` (1-based) of the base and registered scans side by side.

    Reads the module-level ``origScanArray`` and ``regScanArray`` and indexes
    them as ``[:, :, slcNum - 1]``. Both panels use the same window
    (center 40, width 400) and a fixed gray-level range, so brightness is
    comparable between panels and between slices.
    """
    clear_output(wait=True)
    print('Slice ' + str(slcNum))
    window_center = 40
    window_width = 400
    # Same bounds window_image clamps to. Passing them as vmin/vmax stops
    # imshow from re-stretching contrast to each slice's own min/max.
    vmin = window_center - window_width // 2
    vmax = window_center + window_width // 2

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 15))
    panels = ((ax1, origScanArray, 'Base scan'),
              (ax2, regScanArray, 'Registered scan'))
    for ax, volume, title in panels:
        slice_img = window_image(volume[:, :, slcNum - 1],
                                 window_center, window_width)
        n_rows, n_cols = slice_img.shape
        ax.imshow(slice_img, cmap='gray', alpha=1, interpolation='nearest',
                  vmin=vmin, vmax=vmax,
                  extent=(0, n_cols - 1, 0, n_rows - 1))
        ax.set_title(title)
    plt.show()
    # A new figure is created on every slider move; release this one.
    plt.close(fig)

# The slider range follows the data (slices along axis 2, 1-based) instead of
# a hard-coded 299, so it never points past the end of either volume.
num_slices = min(origScanArray.shape[2], regScanArray.shape[2])
slice_slider = widgets.IntSlider(value=min(120, num_slices), min=1,
                                 max=num_slices, step=1, description='Slice')
outputSlc = widgets.Output()

display(slice_slider, outputSlc)

def update_slice(change):
    """Redraw the comparison whenever the slider value changes."""
    with outputSlc:
        show_axial_slice(change['new'])

slice_slider.observe(update_slice, names='value')

# Draw the starting slice right away instead of waiting for the first move.
with outputSlc:
    show_axial_slice(slice_slider.value)