[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mahdimplus/DeepRetroMoco/blob/main/functions.ipynb)

In [18]:
pip install voxelmorph

Collecting voxelmorph
  Downloading voxelmorph-0.1-py3-none-any.whl (75 kB)
[?25l[K     |████▍                           | 10 kB 23.2 MB/s eta 0:00:01[K     |████████▊                       | 20 kB 10.5 MB/s eta 0:00:01[K     |█████████████                   | 30 kB 9.3 MB/s eta 0:00:01[K     |█████████████████▌              | 40 kB 8.2 MB/s eta 0:00:01[K     |█████████████████████▉          | 51 kB 7.1 MB/s eta 0:00:01[K     |██████████████████████████▏     | 61 kB 7.4 MB/s eta 0:00:01[K     |██████████████████████████████▌ | 71 kB 7.2 MB/s eta 0:00:01[K     |████████████████████████████████| 75 kB 2.2 MB/s 
Collecting neurite
  Downloading neurite-0.1-py3-none-any.whl (86 kB)
[K     |████████████████████████████████| 86 kB 4.5 MB/s 
Collecting pystrum
  Downloading pystrum-0.1-py3-none-any.whl (18 kB)
Installing collected packages: pystrum, neurite, voxelmorph
Successfully installed neurite-0.1 pystrum-0.1 voxelmorph-0.1


In [19]:
import nibabel as nib
import os
import numpy as np
import random
from nibabel.affines import apply_affine
import time
import voxelmorph as vxm
import pandas as pd
import matplotlib.pyplot as plt


### Load data cropped to a 64×64 in-plane shape

In [None]:
def load_m(file_path):
    """Load a NIfTI series and crop it to a 64x64 in-plane grid.

    Parameters
    ----------
    file_path : str or os.PathLike
        Path to a ``.nii`` or ``.nii.gz`` file.

    Returns
    -------
    numpy.ndarray
        The image data from ``get_fdata()``. If the first two dimensions are
        not (64, 64), the array is cropped to ``[23:87, 23:87, :, :]``
        (this indexing requires a 4-D image).

    Raises
    ------
    ValueError
        If ``file_path`` does not end with ``.nii`` or ``.nii.gz``. The check
        happens before any file is opened.
    """
    file_path = os.fspath(file_path)
    # Validate the extension before touching the disk, so a wrong path fails
    # fast with this message rather than with a nibabel loading error.
    if not file_path.endswith((".nii", ".nii.gz")):
        raise ValueError(
            f"Nifti file path must end with .nii or .nii.gz, got {file_path}."
        )

    img = nib.load(file_path)
    img_data = img.get_fdata()

    if img.shape[0:2] != (64, 64):
        # Fixed crop window. NOTE(review): assumes the in-plane matrix has at
        # least 87 voxels so that [23:87] yields exactly 64 of them.
        img_data = img_data[23:87, 23:87, :, :]

    return img_data

### Load data cropped to a 64×64 in-plane shape, together with its (updated) header

In [None]:
def load_with_head(file_path: str):
    """Load a NIfTI series cropped to 64x64 in-plane, plus the image object.

    Parameters
    ----------
    file_path : str or os.PathLike
        Path to a ``.nii`` or ``.nii.gz`` file.

    Returns
    -------
    img_data : numpy.ndarray
        Data from ``get_fdata()``. It is cropped to ``[23:87, 23:87, :, :]``
        when the first two dimensions are not (64, 64).
    img : nibabel image
        The loaded image. When a crop is applied, ``img.header['dim'][1:5]``
        is overwritten with the cropped shape. The intent is that a file
        written with this header describes the cropped data.

    Raises
    ------
    ValueError
        If ``file_path`` does not end with ``.nii`` or ``.nii.gz``. The check
        happens before any file is opened.
    """
    file_path = os.fspath(file_path)
    # Validate first so a wrong path never reaches nib.load.
    if not file_path.endswith((".nii", ".nii.gz")):
        raise ValueError(
            f"Nifti file path must end with .nii or .nii.gz, got {file_path}."
        )

    img = nib.load(file_path)
    img_data = img.get_fdata()
    if img.shape[0:2] != (64, 64):
        img_data = img_data[23:87, 23:87, :, :]
        # Keep the header's dimension record consistent with the cropped array.
        header = img.header
        header['dim'][1:5] = img_data.shape

    return img_data, img

### List the file names in a directory and count them

In [None]:
def count(data_dir):
    """List the entries of ``data_dir`` and return how many there are.

    Parameters
    ----------
    data_dir : str or os.PathLike
        Directory to list.

    Returns
    -------
    n : int
        Number of entries.
    train_data_num : numpy.ndarray
        Column array of shape ``(n, 1)`` holding the entry names, sorted
        alphabetically. An empty directory yields an empty array of shape
        ``(0,)``. Each name is read as ``train_data_num[k][0]``.
    """
    # os.listdir returns entries in an arbitrary, filesystem-dependent order.
    # Callers split subjects into train/validation by their position in this
    # array, so sort it to make that split reproducible across machines.
    names = sorted(os.listdir(os.fspath(data_dir)))
    train_data_num = np.array([[name] for name in names])
    return train_data_num.shape[0], train_data_num

### Compute the maximum intensity over all files in a directory (used for normalization)

In [None]:
def maxx(data_dir):
    """Return the largest voxel intensity found across every file in ``data_dir``.

    Each entry listed by ``count`` is loaded with ``load_m``, so the same
    64x64 crop is applied. Its maximum is compared against a running best
    that starts at 0. The result is the normaliser ``m`` used by the
    generators below.
    """
    _, entries = count(data_dir)
    highest = 0
    for entry in entries:
        volume_peak = load_m(data_dir + '/' + str(entry[0])).max()
        # ``>=`` mirrors the original comparison (a NaN peak never wins).
        if volume_peak >= highest:
            highest = volume_peak
    return highest

### Prepare inputs (moving, fixed) and ground truth (reference image, zero deformation map) for training the network

In [None]:
def data_generator(data_dir, batch_size,m,split):
    """Yield endless training batches of (moving, fixed) frame pairs.

    One training subject and one slice are drawn once, when the generator
    first runs. Every batch afterwards pairs randomly chosen volumes (time
    points) of that same slice.

    Parameters
    ----------
    data_dir : str
        Directory of NIfTI files. The first ``n - int(split * n)`` entries
        returned by ``count`` are treated as training subjects.
    batch_size : int
        Number of image pairs per batch.
    m : float
        Maximum intensity over all subjects; every image is divided by it.
    split : float
        Fraction of subjects held out for validation.

    Yields
    ------
    (inputs, outputs)
        ``inputs = [moving, fixed]``, each of shape ``(batch_size, H, W, 1)``.
        ``outputs = [fixed, zero_phi]``, where ``zero_phi`` is a zero array of
        shape ``(batch_size, H, W, 2)`` that serves as the target of the
        deformation-field loss term.
    """
    n_subjects, file_names = count(data_dir)
    n_train = n_subjects - int(split * n_subjects)

    subject = random.randint(0, n_train - 1)
    series = load_m(data_dir + '/' + str(file_names[subject][0]))

    slice_idx = random.randint(0, series.shape[2] - 1)
    n_volumes = series.shape[3]
    vol_shape = series.shape[:2]

    # (H, W, volumes) -> (volumes, H, W) so volumes index the first axis.
    frames = np.einsum('jki->ijk', series[:, :, slice_idx, :])
    zero_phi = np.zeros([batch_size, *vol_shape, len(vol_shape)])

    while True:
        moving_picks = np.random.randint(0, n_volumes, size=batch_size)
        moving_images = frames[moving_picks, ..., np.newaxis] / m

        fixed_picks = np.random.randint(0, n_volumes, size=batch_size)
        fixed_images = frames[fixed_picks, ..., np.newaxis] / m

        # No true "moved" image exists: the warped moving image is compared
        # with the fixed image, and the flow is pushed towards zero.
        yield ([moving_images, fixed_images], [fixed_images, zero_phi])

### Prepare one batch of validation data

In [None]:
def val_generator(data_dir, batch_size,m,split):
    """Return one validation batch of (moving, fixed) frame pairs.

    Despite its name this is a plain function, not a generator. It draws
    one held-out subject, i.e. a position in ``n - int(split * n)`` ..
    ``n - 1`` of the list from ``count``. It then draws one slice and
    ``batch_size`` random volume pairs of that slice.

    Parameters
    ----------
    data_dir : str
        Directory of NIfTI files.
    batch_size : int
        Number of image pairs.
    m : float
        Maximum intensity over all subjects; every image is divided by it.
    split : float
        Fraction of subjects held out for validation.

    Returns
    -------
    (inputs, outputs)
        ``inputs = [moving, fixed]``, each of shape ``(batch_size, H, W, 1)``.
        ``outputs = [fixed, zero_phi]``, where ``zero_phi`` is zeros of shape
        ``(batch_size, H, W, 2)``.
    """
    n_subjects, file_names = count(data_dir)
    first_val = n_subjects - int(split * n_subjects)

    subject = random.randint(first_val, n_subjects - 1)
    series = load_m(data_dir + '/' + str(file_names[subject][0]))

    slice_idx = random.randint(0, series.shape[2] - 1)
    n_volumes = series.shape[3]
    vol_shape = series.shape[:2]

    # (H, W, volumes) -> (volumes, H, W)
    frames = np.einsum('jki->ijk', series[:, :, slice_idx, :])
    zero_phi = np.zeros([batch_size, *vol_shape, len(vol_shape)])

    moving_picks = np.random.randint(0, n_volumes, size=batch_size)
    moving_images = frames[moving_picks, ..., np.newaxis] / m

    fixed_picks = np.random.randint(0, n_volumes, size=batch_size)
    fixed_images = frames[fixed_picks, ..., np.newaxis] / m

    return ([moving_images, fixed_images], [fixed_images, zero_phi])

### Angle conversion, nearest-neighbor sampling, and random affine matrices for augmentation

In [None]:
def a(teta):
    """Convert an angle from degrees to radians.

    Works on scalars and, element-wise, on NumPy arrays.
    """
    scaled = teta * np.pi
    return scaled / 180

def nearest_neighbors(i, j, M, T_inv):
    """Sample ``M`` at the pre-image of output pixel ``(i, j)``.

    The output coordinate ``(i, j, 1)`` is mapped back into the source
    image with ``T_inv`` via ``apply_affine`` (the third coordinate is
    discarded). The value of the nearest source pixel is returned, with
    fractional parts of exactly 0.5 rounding up.

    Parameters
    ----------
    i, j : int
        Row / column of the output pixel.
    M : numpy.ndarray
        Source image; only its first two axes are indexed.
    T_inv : numpy.ndarray
        Affine matrix accepted by ``apply_affine`` for 3-D points (4x4 in
        this notebook, the inverse of the augmentation transform).

    Returns
    -------
    The element ``M[x, y]`` at the rounded source coordinate.

    NOTE(review): if EITHER mapped coordinate is out of range, BOTH are reset
    to 0. The output pixel then takes the corner value ``M[0, 0]`` instead of
    a fill value, and e.g. (-0.2, 30), which would round inside the image,
    is also sent to the corner. Confirm this is intended.
    """
    x_max, y_max = M.shape[0] - 1, M.shape[1] - 1
    
    x, y, k = apply_affine(T_inv, np.array([i, j, 1]))
   
    # Out-of-range pre-images (below 0, or at/after the last pixel + 1)
    # collapse to the (0, 0) corner.
    if x<0 or y<0:
            x=0
            y=0
    if x>=x_max+1 or y>=y_max+1:
            x=0
            y=0 
            
            
    # Exact integer coordinates need no rounding.
    if np.floor(x) == x and np.floor(y) == y:
        x, y = int(x), int(y)
        return M[x, y]
    
    # Round each coordinate to the nearest integer; an exact .5 tie goes to ceil.
    if np.abs(np.floor(x) - x) < np.abs(np.ceil(x) - x):
        x = int(np.floor(x))
    else:
        x = int(np.ceil(x))
    if np.abs(np.floor(y) - y) < np.abs(np.ceil(y) - y):
        y = int(np.floor(y))
    else:
        y = int(np.ceil(y))
        
    # Rounding up from [x_max, x_max + 1) can land one past the last index;
    # clamp it back inside the image.
    if x > x_max:
        x = x_max
    if y > y_max:
        y = y_max
    return M[x, y]

   

def affine_matrix(max_angle=5, max_dx=3, max_dy=6):
    """Draw a random in-plane rigid transform as a 4x4 homogeneous matrix.

    The values are drawn with ``random.randint`` (inclusive bounds), in the
    order angle, dx, dy:

    - the rotation angle, in whole degrees, from ``[-max_angle, max_angle]``;
    - the translation ``dx``, in whole voxels, from ``[-max_dx, max_dx]``;
    - the translation ``dy``, in whole voxels, from ``[-max_dy, max_dy]``.

    The defaults are the previously hard-coded ranges (+-5 degrees, +-3 and
    +-6 voxels), so existing callers get identical behaviour.

    Parameters
    ----------
    max_angle : int, optional
        Maximum absolute rotation in degrees.
    max_dx, max_dy : int, optional
        Maximum absolute translation along the first / second axis.

    Returns
    -------
    numpy.ndarray
        ``[[cos, -sin, 0, dx], [sin, cos, 0, dy], [0, 0, 1, 0], [0, 0, 0, 1]]``
    """
    angle_deg = random.randint(-max_angle, max_angle)
    angle_rad = (angle_deg * np.pi) / 180
    cos_gamma = np.cos(angle_rad)
    sin_gamma = np.sin(angle_rad)
    dx = random.randint(-max_dx, max_dx)
    dy = random.randint(-max_dy, max_dy)
    return np.array([[cos_gamma, -sin_gamma, 0, dx],
                     [sin_gamma, cos_gamma, 0, dy],
                     [0, 0, 1, 0],
                     [0, 0, 0, 1]])

### Augmentation

In [None]:
def augsb(ref,volume,affine_matrix):
    """Warp one frame of ``ref`` with an affine transform (nearest neighbour).

    Parameters
    ----------
    ref : numpy.ndarray
        4-D batch indexed as ``ref[volume, x, y, 0]``, i.e. (batch, H, W,
        channel).
    volume : int
        Index along the first axis selecting the frame to warp.
    affine_matrix : numpy.ndarray
        Forward transform (4x4 here). Each output pixel is filled by
        ``nearest_neighbors`` using the inverse of this matrix.

    Returns
    -------
    numpy.ndarray
        Float64 array of shape ``(H, W)``.
    """
    _, xdim, ydim, _ = ref.shape
    frame = ref[volume, :, :, 0]

    # Backward mapping only: every output pixel looks up exactly one source
    # pixel at T^-1 * (i, j), so the result has no holes. A forward-mapping
    # pass is unnecessary because nothing would read its output.
    T_inv = np.linalg.inv(affine_matrix)
    img_nn = np.empty((xdim, ydim), dtype=np.float64)
    for i in range(xdim):
        for j in range(ydim):
            img_nn[i, j] = nearest_neighbors(i, j, frame, T_inv)

    return img_nn

### Prepare augmented training data

In [None]:
def affine_generator(data_dir,batch_size,m,split):
    """Yield endless batches whose moving images are affine-warped fixed images.

    One training subject, i.e. a position in ``0 .. n - int(split * n) - 1``
    of the list from ``count``, and one slice are drawn once when the
    generator first runs. ``batch_size`` random transforms are also drawn
    once from ``affine_matrix`` and reused for every batch. Each batch picks
    random volumes of the slice as fixed images, and warps fixed image ``k``
    with transform ``k`` to obtain moving image ``k``.

    Parameters
    ----------
    data_dir : str
        Directory of NIfTI files.
    batch_size : int
        Number of image pairs per batch.
    m : float
        Intensity normaliser; fixed images are divided by it.
    split : float
        Fraction of subjects held out for validation.

    Yields
    ------
    (inputs, outputs)
        ``inputs = [moving, fixed]``, each of shape ``(batch_size, H, W, 1)``.
        ``outputs = [fixed, zero_phi]``, where ``zero_phi`` is zeros of shape
        ``(batch_size, H, W, 2)``.
    """
    n_subjects, file_names = count(data_dir)
    n_train = n_subjects - int(split * n_subjects)

    subject = random.randint(0, n_train - 1)
    series = load_m(data_dir + '/' + str(file_names[subject][0]))

    slice_idx = random.randint(0, series.shape[2] - 1)
    n_volumes = series.shape[3]
    vol_shape = series.shape[:2]

    # (H, W, volumes) -> (volumes, H, W)
    frames = np.einsum('jki->ijk', series[:, :, slice_idx, :])

    # Drawn once; the same transforms are applied in every batch.
    transforms = np.array([affine_matrix() for _ in range(batch_size)])

    zero_phi = np.zeros([batch_size, *vol_shape, len(vol_shape)])

    while True:
        picks = np.random.randint(0, n_volumes, size=batch_size)
        fixed_images = frames[picks, ..., np.newaxis] / m

        warped = [augsb(fixed_images, k, transforms[k]) for k in range(batch_size)]
        moving_images = np.array(warped)[..., np.newaxis]

        yield ([moving_images, fixed_images], [fixed_images, zero_phi])

In [None]:
def label_generator(data_dir,batch_size,m,split):
    """Yield endless (stacked image pair, affine matrices) batches.

    One subject is drawn from the held-out part of the split, i.e. a
    position in ``n - int(split * n)`` .. ``n - 1`` of the list from
    ``count``. One slice is also drawn. Both draws happen once, together
    with ``batch_size`` transforms from ``affine_matrix`` that are reused for
    every batch and serve as the regression targets.

    Parameters
    ----------
    data_dir : str
        Directory of NIfTI files.
    batch_size : int
        Number of samples per batch.
    m : float
        Intensity normaliser; fixed images are divided by it.
    split : float
        Fraction of subjects held out for validation.

    Yields
    ------
    (inputs, outputs)
        ``inputs = [c]``, where ``c`` stacks the warped (moving) and original
        (fixed) frames on axis 2, giving shape ``(batch_size, H, 2, W)``.
        ``outputs = [y]``, where ``y`` holds the ``batch_size`` 4x4 transforms.
    """
    n_subjects, file_names = count(data_dir)
    first_val = n_subjects - int(split * n_subjects)

    subject = random.randint(first_val, n_subjects - 1)
    series = load_m(data_dir + '/' + str(file_names[subject][0]))

    slice_idx = random.randint(0, series.shape[2] - 1)
    n_volumes = series.shape[3]

    # (H, W, volumes) -> (volumes, H, W)
    frames = np.einsum('jki->ijk', series[:, :, slice_idx, :])

    # Labels are drawn once and reused for every batch.
    y = np.array([affine_matrix() for _ in range(batch_size)])

    while True:
        picks = np.random.randint(0, n_volumes, size=batch_size)
        fixed_images = frames[picks, ...] / m

        # augsb unpacks a 4-D shape and reads ref[volume, x, y, 0], so it
        # needs an explicit channel axis; the bare (batch, H, W) array made
        # that unpacking fail on every call.
        fixed_with_channel = fixed_images[..., np.newaxis]
        moving_images = np.array(
            [augsb(fixed_with_channel, k, y[k]) for k in range(batch_size)]
        )

        # NOTE(review): stacking two (batch, H, W) arrays on axis=2 gives
        # (batch, H, 2, W). A channels-last network would expect axis=-1,
        # i.e. (batch, H, W, 2). Kept unchanged; confirm against the model.
        c = np.stack([moving_images, fixed_images], axis=2)

        yield ([c], [y])
       

In [None]:
def ref(data_dir,m,slice_ID,reference):
    """Build one inference batch pairing every volume of a slice with a reference.

    Parameters
    ----------
    data_dir : str
        Path of a single NIfTI series (a file, despite the name), loaded
        with ``load_m``.
    m : float
        Intensity normaliser; all images are divided by it.
    slice_ID : int
        Index of the slice (third axis) to extract.
    reference : str
        Either a string of digits, interpreted as the index of the reference
        volume within the same series, or the path of a NIfTI image whose
        ``slice_ID`` slice is used as the reference.

    Returns
    -------
    (inputs, outputs)
        ``inputs = [moving, fixed]``. ``moving`` holds all ``v`` volumes of
        the slice in order, with shape ``(v, H, W, 1)``; ``fixed`` is the
        reference frame repeated ``v`` times. ``outputs = [fixed, zero_phi]``,
        where ``zero_phi`` is zeros of shape ``(v, H, W, 2)``.
    """
    series = load_m(data_dir)
    n_volumes = series.shape[3]
    vol_shape = series.shape[:2]

    # (H, W, volumes) -> (volumes, H, W)
    frames = np.einsum('jki->ijk', series[:, :, slice_ID, :])
    zero_phi = np.zeros([n_volumes, *vol_shape, len(vol_shape)])

    # Every volume of the slice is a moving image, in acquisition order.
    every_volume = np.array([*range(n_volumes)])
    moving_images = frames[every_volume, ..., np.newaxis] / m

    if reference.strip().isdigit():
        # Numeric reference: repeat that volume of the same series.
        ref_picks = (np.ones(n_volumes) * int(reference)).astype(int)
        fixed_images = frames[ref_picks, ..., np.newaxis] / m
    else:
        # Otherwise: a separate reference image on disk, cropped like load_m.
        ref_img = nib.load(reference)
        ref_data = ref_img.get_fdata()
        if ref_img.shape[0:2] != (64, 64):
            ref_data = ref_data[23:87, 23:87, :]
        ref_data = ref_data[np.newaxis, :, :, slice_ID]
        ref_picks = np.zeros(n_volumes).astype(int)
        fixed_images = ref_data[ref_picks, ..., np.newaxis] / m

    return ([moving_images, fixed_images], [fixed_images, zero_phi])

In [None]:
def main(input_direction,reference,output_direction,maximum_intensity,loadable_model):
    """Register every volume of a NIfTI series to a reference with a trained VxmDense.

    Parameters
    ----------
    input_direction : str
        Path of the series to correct.
    reference : str
        Reference specifier forwarded to ``ref``: a digit string (volume
        index) or the path of a reference image.
    output_direction : str
        Path where the registered series is written.
    maximum_intensity : float
        Normaliser applied by ``ref``; the prediction is multiplied back by it.
    loadable_model : str
        Weights file passed to ``load_weights``.

    Side effects: writes ``output_direction`` and prints the elapsed time.
    """
    started = time.time()

    img_data, img = load_with_head(input_direction)
    header = img.header
    img_mask_affine = img.affine
    n_slices = img_data.shape[2]

    # VxmDense on 64x64 slices: encoder and decoder feature widths.
    nb_features = [
        [64, 64, 64, 64],
        [64, 64, 64, 64, 64, 32, 16],
    ]
    vxm_model = vxm.networks.VxmDense((64, 64), nb_features, int_steps=0)
    # MSE image term plus an L2 flow-gradient penalty, weighted 1 : 0.05.
    vxm_model.compile(
        optimizer='Adam',
        loss=[vxm.losses.MSE().loss, vxm.losses.Grad('l2').loss],
        loss_weights=[1, 0.05],
        metrics=['accuracy'],
    )
    vxm_model.load_weights(loadable_model)

    registered = np.zeros((img_data.shape[0], img_data.shape[1],
                           img_data.shape[2], img_data.shape[3]))

    for slice_idx in range(n_slices):
        batch_inputs, _ = ref(input_direction, maximum_intensity, slice_idx, reference)
        prediction = vxm_model.predict(batch_inputs)
        # First model output, shape (volumes, H, W, 1), reordered to
        # (H, W, volumes) to fill this slice.
        registered[:, :, slice_idx, :] = np.einsum('jki->kij', prediction[0][:, :, :, 0])

    # Undo the normalisation before saving with the (possibly cropped) header.
    img_reg = nib.Nifti1Image(registered * maximum_intensity, affine=img_mask_affine, header=header)
    nib.save(img_reg, output_direction)
    print("--- %s second ---" % (time.time() - started))



In [None]:
def snr(direction):
    """Per-slice signal-to-noise ratio (mean / std) of a NIfTI image.

    For each index ``k`` along the third axis, the mean and standard
    deviation of all voxels in ``img[:, :, k]`` are computed. This covers
    every remaining axis, e.g. all volumes of a 4-D series.

    Returns
    -------
    (ratio, mean, deviation) : tuple of numpy.ndarray
        Element-wise ``mean / deviation``, followed by the per-slice means and
        standard deviations. A zero deviation gives inf/nan in ``ratio``.
    """
    data = nib.load(direction).get_fdata()
    slices = [data[:, :, k] for k in range(data.shape[2])]
    means = np.array([np.mean(s) for s in slices])
    stds = np.array([np.std(s) for s in slices])
    return (means / stds), means, stds

In [None]:
def mean(direction):
    """Return per-slice mean intensities of a NIfTI image, plus their average.

    NaN voxels are replaced by 0 before averaging.

    Parameters
    ----------
    direction : str
        Path of the NIfTI image.

    Returns
    -------
    numpy.ndarray
        Length ``n_slices + 1``. Entry ``k`` is the mean of ``img[:, :, k]``
        and the last entry is the mean of those per-slice means.
    """
    img = nib.load(direction).get_fdata()
    # Bare ``isnan`` is never imported in this notebook (NameError on every
    # call); NumPy's version is the intended one.
    img[np.isnan(img)] = 0

    slice_means = [np.mean(img[:, :, k]) for k in range(img.shape[2])]
    slice_means.append(np.mean(slice_means))
    return np.array(slice_means)

In [None]:
def seg_mean(img):
    """Mean intensity of the non-zero pixels of a 2-D image.

    Zero-valued pixels are treated as background. They add nothing to the
    sum and are excluded from the pixel count. The denominator is the
    image's own size, not a fixed 64*64, so any image size works; for 64x64
    images the result equals the previous implementation's.

    Parameters
    ----------
    img : array_like
        2-D image.

    Returns
    -------
    float
        ``sum(img) / count_nonzero(img)``, or NaN when no pixel is non-zero.
        Empty and all-zero images both give NaN.
    """
    img = np.asarray(img)
    # NaN pixels count as non-zero, matching the ``== 0`` test used before.
    n_foreground = np.count_nonzero(img)
    if n_foreground == 0:
        return np.nan
    return np.sum(img) / n_foreground


In [None]:
def mean_all(direction):
    """Average, over slices, of each slice's foreground (non-zero) mean.

    The image is loaded and its NaN voxels are set to 0. ``seg_mean`` is
    applied to every ``img[:, :, k]`` and the per-slice values are averaged.
    NOTE(review): ``seg_mean`` compares individual pixels with 0, so this
    presumably expects a 3-D image; confirm for 4-D inputs.
    """
    volume = nib.load(direction).get_fdata()
    volume[np.isnan(volume)] = 0
    per_slice = [seg_mean(volume[:, :, k]) for k in range(volume.shape[2])]
    return np.mean(per_slice)

In [None]:
def shift_image(X, dx, dy):
    """Translate an image by whole pixels, filling vacated pixels with 0.

    Parameters
    ----------
    X : numpy.ndarray
        Image; rows are axis 0, columns axis 1.
    dx : int
        Shift along axis 1 (positive moves content to higher column indices).
    dy : int
        Shift along axis 0 (positive moves content to higher row indices).

    Returns
    -------
    numpy.ndarray
        A new array; ``X`` is not modified. Content that ``np.roll`` wraps
        around an edge is zeroed instead of kept.
    """
    shifted = np.roll(np.roll(X, dy, axis=0), dx, axis=1)
    for axis, offset in ((0, dy), (1, dx)):
        if offset == 0:
            continue
        # Band of rows/columns that received wrapped-around content.
        band = slice(None, offset) if offset > 0 else slice(offset, None)
        index = [slice(None)] * 2
        index[axis] = band
        shifted[tuple(index)] = 0
    return shifted

In [None]:
def cplus(source_centerline_directory,centerlines_directory,main_data_directory,
          center_fix_directory,final_cplus_directory,
          maximum_intensity,model,reference,mean_directory
          ):
    """Two-stage motion correction: a centerline-based shift, then VoxelMorph.

    Stage 1 shifts every frame ``img[:, :, s, v]`` by
    ``source_y[s] - y_v[s]``. ``source_y`` comes from the reference
    centerline CSV. ``y_v`` comes from the ``v``-th CSV file in
    ``centerlines_directory``, taking the files ending in ``.csv`` in sorted
    order. Each CSV has three header-less columns (x, y, unused) and one row
    per slice; only y is used. The shift is passed as ``shift_image``'s
    ``dx`` argument, i.e. it rolls the second array axis. The shifted series
    is saved to ``center_fix_directory``.

    Stage 2 runs ``main`` on that file and writes ``final_cplus_directory``.

    Parameters
    ----------
    reference : int
        ``0`` for the first volume, ``-1`` for the middle volume
        (``int(n_csv / 2)``), ``-2`` for the image at ``mean_directory``,
        and ``> 0`` for that volume index.
    maximum_intensity, model :
        Forwarded to ``main`` as the normaliser and the weights file.

    Raises
    ------
    ValueError
        For any other ``reference`` value. The check runs before any file is
        written.
    """
    # Reference centerline: only its y column is used.
    source = pd.read_csv(source_centerline_directory, header=None)
    source.columns = ['x', 'y', 'delete']
    n_slices = source.shape[0]
    source_y = [int(source.loc[s, 'y']) for s in range(n_slices)]

    n2, name2 = count_endwith(centerlines_directory, '.csv')

    # Read each per-volume centerline file once, not once per slice.
    # os.path.join also works when the directory lacks a trailing slash.
    volume_y = []
    for entry in name2:
        df = pd.read_csv(os.path.join(centerlines_directory, entry[0]), header=None)
        df.columns = ['x', 'y', 'delete']
        volume_y.append(df['y'])

    # shifts[s][v]: source y minus the y of volume v, for slice s.
    shifts = [[source_y[s] - int(volume_y[v].loc[s]) for v in range(n2)]
              for s in range(n_slices)]

    # Map the integer reference code onto what ref() expects, before writing.
    if reference > 0:
        reference = str(reference)
    elif reference == 0:
        reference = '0'
    elif reference == -1:
        reference = str(int(n2 / 2))
    elif reference == -2:
        reference = mean_directory
    else:
        raise ValueError(
            f"reference must be 0, -1, -2 or a positive volume index, got {reference}."
        )

    img = nib.load(main_data_directory)
    img_data = img.get_fdata()
    img_mask_affine = img.affine
    header = img.header
    nb_img = header.get_data_shape()
    o = np.zeros((nb_img[0], nb_img[1], nb_img[2], nb_img[3]))

    for s in range(n_slices):
        for v in range(n2):
            # The centerline y-difference goes into shift_image's dx slot
            # (a roll along axis 1); the axis-0 shift is always zero.
            o[:, :, s, v] = shift_image(img_data[:, :, s, v], shifts[s][v], 0)

    img_reg = nib.Nifti1Image(o, affine=img_mask_affine, header=header)
    nib.save(img_reg, center_fix_directory)

    main(center_fix_directory, reference, final_cplus_directory, maximum_intensity, model)

In [None]:
def count_startwith(data_dir,prefix):
    """Count and list the entries of ``data_dir`` whose names start with ``prefix``.

    Returns
    -------
    (n, names)
        ``n`` is the number of matches. ``names`` is a sorted Python list of
        one-element NumPy arrays (the rows of an ``(n, 1)`` array), so each
        name is read as ``names[k][0]``. No matches gives ``(0, [])``.
    """
    hits = [[entry] for entry in os.listdir(os.path.join(data_dir))
            if entry.startswith(prefix)]
    table = np.array(hits)
    return table.shape[0], sorted(table)

In [None]:
def count_endwith(data_dir,prefix):
    """Count and list the entries of ``data_dir`` whose names end with ``prefix``.

    Despite its name, ``prefix`` is matched as a suffix (e.g. ``'.csv'``).

    Returns
    -------
    (n, names)
        ``n`` is the number of matches. ``names`` is a sorted Python list of
        one-element NumPy arrays (the rows of an ``(n, 1)`` array), so each
        name is read as ``names[k][0]``. No matches gives ``(0, [])``.
    """
    hits = [[entry] for entry in os.listdir(os.path.join(data_dir))
            if entry.endswith(prefix)]
    table = np.array(hits)
    return table.shape[0], sorted(table)

# Movement plot for one slice

In [None]:
def flow_one_slice(input_direction,reference,maximum_intensity,loadable_model,slice_num,mean_directory,title):
    """Plot the mean predicted displacement per volume for one slice.

    Every volume of slice ``slice_num`` is registered to the reference with
    a VxmDense model loaded from ``loadable_model``. The spatial mean of the
    two channels of the second model output (the flow) is then plotted per
    volume, as curves "x" and "y".

    Parameters
    ----------
    reference : int
        ``0`` for the first volume, ``-1`` for the middle volume, ``-2`` for
        the image at ``mean_directory``, and ``> 0`` for that volume index.
    title : str
        Plot title.
    """
    img_data, _ = load_with_head(input_direction)

    nb_features = [
        [64, 64, 64, 64],
        [64, 64, 64, 64, 64, 32, 16],
    ]
    vxm_model = vxm.networks.VxmDense((64, 64), nb_features, int_steps=0)
    vxm_model.compile(
        optimizer='Adam',
        loss=[vxm.losses.MSE().loss, vxm.losses.Grad('l2').loss],
        loss_weights=[1, 0.05],
        metrics=['accuracy'],
    )
    vxm_model.load_weights(loadable_model)

    # Map the integer reference code onto what ref() expects.
    if reference > 0:
        reference = str(reference)
    elif reference == 0:
        reference = '0'
    elif reference == -1:
        reference = str(int(img_data.shape[3] / 2))
    elif reference == -2:
        reference = mean_directory

    batch_inputs, _ = ref(input_direction, maximum_intensity, slice_num, reference)
    flow = vxm_model.predict(batch_inputs)[1]
    n_volumes = flow[:, 0, 0, 0].shape[0]
    x = np.array([np.mean(flow[k, ..., 0]) for k in range(n_volumes)])
    y = np.array([np.mean(flow[k, ..., 1]) for k in range(n_volumes)])

    volume = range(n_volumes)
    plt.figure(figsize=(20,5))
    plt.plot(volume, x, label = "x")
    plt.plot(volume, y, label = "y")
    plt.xlabel('volumes')
    plt.ylabel('movement')
    plt.title(title)
    plt.legend()


# Movement plot averaged over all slices

In [None]:
def flow_all_slice(input_direction,reference,maximum_intensity,loadable_model,mean_directory,title):
    """Plot the slice-averaged predicted displacement for every volume.

    For each slice, every volume is registered to the reference, and the
    spatial means of the two flow channels (second model output) are taken
    per volume. These are averaged over slices, and ``(mean_x + mean_y) / 2``
    is plotted.

    Parameters
    ----------
    reference : int
        ``0`` for the first volume, ``-1`` for the middle volume, ``-2`` for
        the image at ``mean_directory``, and ``> 0`` for that volume index.
        When the reference is a volume of the series, that volume's entry is
        set to 0 (it is registered to itself).
    title : str
        Plot title.
    """
    img_data, _ = load_with_head(input_direction)
    slice_number = img_data.shape[2]

    nb_features = [
        [64, 64, 64, 64],
        [64, 64, 64, 64, 64, 32, 16],
    ]
    vxm_model = vxm.networks.VxmDense((64, 64), nb_features, int_steps=0)
    vxm_model.compile(
        optimizer='Adam',
        loss=[vxm.losses.MSE().loss, vxm.losses.Grad('l2').loss],
        loss_weights=[1, 0.05],
        metrics=['accuracy'],
    )
    vxm_model.load_weights(loadable_model)

    # Map the integer reference code onto what ref() expects.
    if reference > 0:
        reference = str(reference)
    elif reference == 0:
        reference = '0'
    elif reference == -1:
        reference = str(int(img_data.shape[3] / 2))
    elif reference == -2:
        reference = mean_directory

    x_all_slice = []
    y_all_slice = []
    for slice_idx in range(slice_number):
        batch_inputs, _ = ref(input_direction, maximum_intensity, slice_idx, reference)
        flow = vxm_model.predict(batch_inputs)[1]
        n_volumes = flow[:, 0, 0, 0].shape[0]
        x_all_slice.append([np.mean(flow[k, ..., 0]) for k in range(n_volumes)])
        y_all_slice.append([np.mean(flow[k, ..., 1]) for k in range(n_volumes)])

    mean_x = np.array(x_all_slice).mean(axis=0)
    mean_y = np.array(y_all_slice).mean(axis=0)

    # Remove the reference's self-registration error. This only applies when
    # the reference is a volume index; for a mean-image path (reference -2)
    # int(reference) would raise ValueError.
    if isinstance(reference, str) and reference.strip().isdigit():
        mean_x[int(reference)] = 0
        mean_y[int(reference)] = 0

    overal = (mean_x + mean_y) / 2

    volume = range(len(overal))
    plt.figure(figsize=(20,5))
    # The curve is the average of the x and y displacement, not x alone.
    plt.plot(volume, overal, label = "mean of x and y")
    plt.xlabel('volumes')
    plt.ylabel('movement')
    plt.title(title)
    plt.legend()


In [None]:
def _slice_averaged_flow(vxm_model, input_direction, maximum_intensity, reference, slice_number):
    """Per-volume ``(mean_x + mean_y) / 2`` of the predicted flow, averaged over slices.

    ``reference`` is the already-resolved specifier passed to ``ref`` (a
    digit string or an image path). For a digit string, that volume's entry
    is set to 0.
    """
    x_all_slice = []
    y_all_slice = []
    for slice_idx in range(slice_number):
        batch_inputs, _ = ref(input_direction, maximum_intensity, slice_idx, reference)
        flow = vxm_model.predict(batch_inputs)[1]
        n_volumes = flow[:, 0, 0, 0].shape[0]
        x_all_slice.append([np.mean(flow[k, ..., 0]) for k in range(n_volumes)])
        y_all_slice.append([np.mean(flow[k, ..., 1]) for k in range(n_volumes)])

    mean_x = np.array(x_all_slice).mean(axis=0)
    mean_y = np.array(y_all_slice).mean(axis=0)

    # Remove the reference's self-registration error, but only when the
    # reference is a volume index; a mean-image path has no such volume.
    if isinstance(reference, str) and reference.strip().isdigit():
        mean_x[int(reference)] = 0
        mean_y[int(reference)] = 0

    return (mean_x + mean_y) / 2


def flow_between_two(input_direction0,input_direction1,reference,
                     maximum_intensity,loadable_model,mean_directory,
                     title,label1,label2):
    """Plot the slice-averaged displacement curves of two series on one figure.

    Both series are registered with the same VxmDense weights and the same
    reference. The slice count and the ``-1`` (middle volume) reference are
    taken from ``input_direction0``.

    Parameters
    ----------
    reference : int
        ``0`` for the first volume, ``-1`` for the middle volume, ``-2`` for
        the image at ``mean_directory``, and ``> 0`` for that volume index.
    title, label1, label2 : str
        Plot title and the legend labels of the two curves.
    """
    img_data, _ = load_with_head(input_direction0)
    slice_number = img_data.shape[2]

    nb_features = [
        [64, 64, 64, 64],
        [64, 64, 64, 64, 64, 32, 16],
    ]
    vxm_model = vxm.networks.VxmDense((64, 64), nb_features, int_steps=0)
    vxm_model.compile(
        optimizer='Adam',
        loss=[vxm.losses.MSE().loss, vxm.losses.Grad('l2').loss],
        loss_weights=[1, 0.05],
        metrics=['accuracy'],
    )
    vxm_model.load_weights(loadable_model)

    # Map the integer reference code onto what ref() expects.
    if reference > 0:
        reference = str(reference)
    elif reference == 0:
        reference = '0'
    elif reference == -1:
        reference = str(int(img_data.shape[3] / 2))
    elif reference == -2:
        reference = mean_directory

    overal = _slice_averaged_flow(vxm_model, input_direction0, maximum_intensity,
                                  reference, slice_number)
    overal1 = _slice_averaged_flow(vxm_model, input_direction1, maximum_intensity,
                                   reference, slice_number)

    volume = range(len(overal1))

    plt.figure(figsize=(25,10))
    plt.plot(volume, overal, label = label1)
    plt.plot(volume, overal1, label = label2)
    plt.xlabel('volumes', fontsize=18)
    plt.ylabel('movement', fontsize=18)
    plt.title(title, fontsize=20)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    plt.grid()
    plt.legend(fontsize=15)

