In [1]:
import sys
from os import path

import numpy as np
sqrt = np.sqrt
eps = np.finfo(float).eps

# FISTA
Solve the regularised least squares problem
\begin{equation}\nonumber
\arg\min_x \frac12 \|Ax-y\|_2^2 + \Omega(x)
\end{equation}
with the FISTA algorithm described in [1].
The penalty term and its proximal operator must be defined in such a way that they already contain the regularisation parameter.

-------------
### Reference:
* Beck & Teboulle - A Fast Iterative Shrinkage Thresholding Algorithm for Linear Inverse Problems, 2009

In [2]:
def fista( y, A, At, tol_fun, tol_x, max_iter, verbose, x0, omega, proximal) :
    # Initialization
    res = -y.copy()
    xhat = x0.copy()
    x = np.zeros_like(xhat)
    res += A.dot(xhat)
    xhat = proximal( xhat )
    reg_term = omega( xhat )
    prev_obj = 0.5 * np.linalg.norm(res)**2 + reg_term

    told = 1
    beta = 0.9
    prev_x = xhat.copy()
    grad = np.asarray(At.dot(res))
    qfval = prev_obj

    # Step size computation
    L = ( np.linalg.norm( A.dot(grad) ) / np.linalg.norm(grad) )**2
    mu = 1.9 / L

    # Main loop
    if verbose >= 1 :
        print
        print "      |     ||Ax-y||     |  Cost function    Abs error      Rel error    |     Abs x          Rel x"
        print "------|------------------|-----------------------------------------------|------------------------------"
    iter = 1
    while True :
        if verbose >= 1 :
            print "%4d  |" % iter,
            sys.stdout.flush()

        # Smooth step
        x = xhat - mu*grad

        # Non-smooth step
        x = proximal( x )
        reg_term_x = omega( x )

        # Check stepsize
        tmp = x-xhat
        q = qfval + np.real( np.dot(tmp,grad) ) + 0.5/mu * np.linalg.norm(tmp)**2 + reg_term_x
        res = A.dot(x) - y
        res_norm = np.linalg.norm(res)
        curr_obj = 0.5 * res_norm**2 + reg_term_x

        # Backtracking
        while curr_obj > q :
            # Smooth step
            mu = beta*mu
            x = xhat - mu*grad

            # Non-smooth step
            x = proximal( x )
            reg_term_x = omega( x )

            # Check stepsize
            tmp = x-xhat
            q = qfval + np.real( np.dot(tmp,grad) ) + 0.5/mu * np.linalg.norm(tmp)**2 + reg_term_x
            res = A.dot(x) - y
            res_norm = np.linalg.norm(res)
            curr_obj = 0.5 * res_norm**2 + reg_term_x

        # Global stopping criterion
        abs_obj = abs(curr_obj - prev_obj)
        rel_obj = abs_obj / curr_obj
        abs_x   = np.linalg.norm(x - prev_x)
        rel_x   = abs_x / ( np.linalg.norm(x) + eps )
        if verbose >= 1 :
            print "  %13.7e  |  %13.7e  %13.7e  %13.7e  |  %13.7e  %13.7e" % ( res_norm, curr_obj, abs_obj, rel_obj, abs_x, rel_x )

        if abs_obj < eps :
            criterion = "Absolute tolerance on the objective"
            break
        elif rel_obj < tol_fun :
            criterion = "Relative tolerance on the objective"
            break
        elif abs_x < eps :
            criterion = "Absolute tolerance on the unknown"
            break
        elif rel_x < tol_x :
            criterion = "Relative tolerance on the unknown"
            break
        elif iter >= max_iter :
            criterion = "Maximum number of iterations"
            break

        # FISTA update
        t = 0.5 * ( 1 + sqrt(1+4*told**2) )
        xhat = x + (told-1)/t * (x - prev_x)

        # Gradient computation
        res = A.dot(xhat) - y
        xarr = np.asarray(x)

        grad = np.asarray(At.dot(res))

        # Update variables
        iter += 1
        prev_obj = curr_obj
        prev_x = x.copy()
        told = t
        qfval = 0.5 * np.linalg.norm(res)**2


    if verbose >= 1 :
        print "< Stopping criterion: %s >" % criterion

    opt_details = {}
    opt_details['residual'] = res_norm
    opt_details['cost_function'] = curr_obj
    opt_details['abs_cost'] = abs_obj
    opt_details['rel_cost'] = rel_obj
    opt_details['abs_x'] = abs_x
    opt_details['rel _x'] = rel_x
    opt_details['iterations'] = iter
    opt_details['stopping_criterion'] = criterion

    return x, opt_details

# Proximal operators
Define the proximal operators for some classical regularisation terms.

In [3]:
def non_negativity(x, lam, first_index, how_many) :
    """
    POCS for the first orthant (non-negativity)
    """
    v = x.copy()
    for i in range(first_index, how_many):
        if v[i] < 0.0:
            v[i] = 0.0
    return v

def soft_thresholding(x, lam, first_index, how_many) :
    """
    Proximal of L1 norm
    """
    # NB: this preserves non-negativity
    v = x.copy()
    for i in range(first_index, how_many):
        if v[i] <= lam:
            v[i] = 0.0
        else:
            v[i] -= lam
    return v

def projection_onto_l2_ball(x, lam, first_index, how_many) :
    """
    Proximal of L2 norm
    """
    # NB: this preserves non-negativity
    v = x.copy()
    xn = np.sqrt(sum(v[first_index:first_index+how_many]**2))
    # xn = np.linalg.norm(v[first_index:first_index+how_many])
    if xn > lam:
        for i in range(first_index, how_many):
            v[i] = v[i]/xn*lam
    return v

# Tractography optimisation
Find the relevant elements that describe the white matter starting from a tractogram.

* Forward problem:
\begin{equation}\nonumber
x,A,\epsilon \to y = Ax + \epsilon
\end{equation}
where $x$ encodes the weights associated to each element of the tractogram, $A$ is the linear operator that transforms the weights to the ground thruth signal and $\epsilon$ is the noise. Vector $y$ is the result of the forward model.

* Inverse problem:
\begin{equation}\nonumber
x^* = \arg\min_x \frac12 \|Ax-y\|_2^2 + \Omega(x)
\end{equation}
where $\Omega$ penalises solutions that don't exibit the a priori knowledge that we have on the forward model (e.g. sparsity, hierarchy, ...).

In [4]:
#####################################
######### ACCEPT THIS MAGIC #########
#####################################

from commit import trk2dictionary

datafolder = 'data'
path_trk = path.join(datafolder, 'fibers_connecting.trk')
# path_peaks = path.join(datafolder, 'peaks.nii.gz')
path_out = path.join(datafolder, 'example_output')

trk2dictionary.run(
    filename_trk = path_trk,
    path_out = path_out
#     filename_peaks = path_peaks
    )



import commit
commit.core.setup()

mit = commit.Evaluation( '.', datafolder )
name_dwi = 'dwi.nii.gz'
name_scheme = 'dwi.scheme'
mit.load_data( name_dwi, name_scheme )

mit.set_model( 'VolumeFractions' )

mit.model.set( hasISO = False)
mit.generate_kernels( regenerate=True )
mit.load_kernels()

mit.load_dictionary( path.basename(path_out) )

mit.set_threads()
mit.build_operator()

  from ._conv import register_converters as _register_converters



-> Creating the dictionary from tractogram:
	* Segment position = COMPUTE INTERSECTIONS
	* Fiber shift X    = 0.000 (voxel-size units)
	* Fiber shift Y    = 0.000 (voxel-size units)
	* Fiber shift Z    = 0.000 (voxel-size units)
	* Points to skip   = 0
	* Min segment len  = 1.00e-03
	* Do not blur fibers
	* Loading data:
		* tractogram
			- 50 x 50 x 50
			- 1.0000 x 1.0000 x 1.0000
			- 4589 fibers
		* no mask specified to filter IC compartments
		* no dataset specified for EC compartments
		* output written to "data/example_output"
	* Generate tractogram matching the dictionary: 
	  [ 4589 fibers kept ]
   [ 0.4 seconds ]

-> Loading data:
	* DWI signal...
		- dim    = 50 x 50 x 50 x 1
		- pixdim = 1.000 x 1.000 x 1.000
	* Acquisition scheme...
		- 1 samples, 1 shells
		- 0 @ b=0 , 1 @ b=1000.0
   [ 0.0 seconds ]

-> Preprocessing:
	* There are no b0 volume(s) for normalization... [ min=0.00,  mean=0.03, max=1.00 ]
   [ 0.0 seconds ]

-> Simulating with "Volume fractions" model:
   

# Define the regularisation term

In [5]:
tmp = np.loadtxt( path.join(datafolder,'fibers_connecting.txt')).astype(np.int32)
group_size = tmp[:,0][:]
group_isVB = tmp[:,1]

weights = np.array([1/float(g) for g in group_size])

bundles = np.insert(group_size,0,0)
structureIC = np.array(
    [np.arange(sum(bundles[:k+1]),sum(bundles[:k+1])+bundles[k+1]) for k in range(len(bundles)-1)]
    )


regnorms = [commit.solvers.group_sparsity, commit.solvers.non_negative, commit.solvers.non_negative]



group_norm = 2 # each group is penalised with its 2-norm

lambdas = [12.5, 0.0, 0.0]

regterm = commit.solvers.init_regularisation(mit,
                                             regnorms    = regnorms,
                                             structureIC = structureIC,
                                             group_is_ordered = False,
                                             weightsIC   = weights,
                                             group_norm  = group_norm,
                                             lambdas     = lambdas)

In [6]:
# RUN THIS BLOCK WITHOUT READING IT, PLEASE
from commit.proximals import (non_negativity,
                             omega_group_sparsity,
                             prox_group_sparsity,
                             soft_thresholding,
                             projection_onto_l2_ball)

group_sparsity = -1
non_negative = 0
norm1 = 1
norm2 = 2
norminf = np.inf
list_regnorms = [group_sparsity, non_negative, norm1, norm2]
list_group_sparsity_norms = [norm2, norminf]
import warnings

def regularisation2omegaprox(regularisation):
    lambdaIC  = float(regularisation.get('lambdaIC'))
    lambdaEC  = float(regularisation.get('lambdaEC'))
    lambdaISO = float(regularisation.get('lambdaISO'))
    if lambdaIC < 0.0 or lambdaEC < 0.0 or lambdaISO < 0.0:
        raise ValueError('Negative regularisation parameters are not allowed')

    normIC  = regularisation.get('normIC')
    normEC  = regularisation.get('normEC')
    normISO = regularisation.get('normISO')
    if not normIC in list_regnorms:
        raise ValueError('normIC must be one of commit.solvers.{group_sparsity,non_negative,norm1,norm2}')
    if not normEC in list_regnorms:
        raise ValueError('normEC must be one of commit.solvers.{group_sparsity,non_negative,norm1,norm2}')
    if not normISO in list_regnorms:
        raise ValueError('normISO must be one of commit.solvers.{group_sparsity,non_negative,norm1,norm2}')

    ## NNLS case
    if (lambdaIC == 0.0 and lambdaEC == 0.0 and lambdaISO == 0.0) or (normIC == non_negative and normEC == non_negative and normISO == non_negative):
        omega = lambda x: 0.0
        prox  = lambda x: non_negativity(x, 0, len(x))
        return omega, prox

    ## All other cases
    # Intracellular Compartment
    startIC = regularisation.get('startIC')
    sizeIC  = regularisation.get('sizeIC')
    if lambdaIC == 0.0:
        omegaIC = lambda x: 0.0
        proxIC  = lambda x: non_negativity(x, startIC, sizeIC)
    elif normIC == norm2:
        omegaIC = lambda x: lambdaIC * np.linalg.norm(x[startIC:sizeIC])
        proxIC  = lambda x: projection_onto_l2_ball(x, lambdaIC, startIC, sizeIC)
    elif normIC == norm1:
        omegaIC = lambda x: lambdaIC * sum( x[startIC:sizeIC] )
        proxIC  = lambda x: soft_thresholding(x, lambdaIC, startIC, sizeIC)
    elif normIC == non_negative:
        omegaIC = lambda x: 0.0
        proxIC  = lambda x: non_negativity(x, startIC, sizeIC)
    elif normIC == group_sparsity:
        weightsIC   = regularisation.get('weightsIC')
        structureIC = regularisation.get('structureIC')
        if regularisation.get('group_is_ordered'): # This option will be deprecated in future release
            warnings.warn('The ordered group structure will be deprecated. Check the documentation of commit.solvers.init_regularisation.',DeprecationWarning)
            bundles = np.insert(structureIC,0,0)
            structureIC = np.array([np.arange(sum(bundles[:k+1]),sum(bundles[:k+1])+bundles[k+1]) for k in range(len(bundles)-1)]) # check how it works with bundles=[2,5,4]
            regularisation['structureIC'] = structureIC
            regulatisation['group_is_ordered'] = False # the group structure is overwritten, hence the flag has to be changed
            del bundles
        if not len(structureIC) == len(weightsIC):
            raise ValueError('Number of groups and weights do not coincide.')
        group_norm = regularisation.get('group_norm')
        if not group_norm in list_group_sparsity_norms:
            raise ValueError('Wrong norm in the structured sparsity term. Choose between %s.' % str(list_group_sparsity_norms))

        omegaIC = lambda x: omega_group_sparsity( x, structureIC, weightsIC, lambdaIC, group_norm )
        proxIC  = lambda x:  prox_group_sparsity( x, structureIC, weightsIC, lambdaIC, group_norm )
    else:
        raise ValueError('Type of regularisation for IC compartment not recognized.')


    # Extracellular Compartment
    startEC = regularisation.get('startEC')
    sizeEC  = regularisation.get('sizeEC')
    if lambdaEC == 0.0:
        omegaEC = lambda x: 0.0
        proxEC  = lambda x: non_negativity(x, startEC, sizeEC)
    elif normEC == norm2:
        omegaEC = lambda x: lambdaEC * np.linalg.norm(x[startEC:sizeEC])
        proxEC  = lambda x: projection_onto_l2_ball(x, lambdaEC, startEC, sizeEC)
    elif normEC == norm1:
        omegaEC = lambda x: lambdaEC * sum( x[startEC:sizeEC] )
        proxEC  = lambda x: soft_thresholding(x, lambdaEC, startEC, sizeEC)
    elif normEC == non_negative:
        omegaEC = lambda x: 0.0
        proxEC  = lambda x: non_negativity(x, startEC, sizeEC)
    else:
        raise ValueError('Type of regularisation for EC compartment not recognized.')

    # Isotropic Compartment
    startISO = regularisation.get('startISO')
    sizeISO  = regularisation.get('sizeISO')
    if lambdaISO == 0.0:
        omegaISO = lambda x: 0.0
        proxISO  = lambda x: non_negativity(x, startISO, sizeISO)
    elif normISO == norm2:
        omegaISO = lambda x: lambdaISO * np.linalg.norm(x[startISO:sizeISO])
        proxISO  = lambda x: projection_onto_l2_ball(x, lambdaISO, startISO, sizeISO)
    elif normISO == norm1:
        omegaISO = lambda x: lambdaISO * sum( x[startISO:sizeISO] )
        proxISO  = lambda x: soft_thresholding(x, lambdaISO, startISO, sizeISO)
    elif normISO == non_negative:
        omegaISO = lambda x: 0.0
        proxISO  = lambda x: non_negativity(x, startISO, sizeISO)
    else:
        raise ValueError('Type of regularisation for ISO compartment not recognized.')

    omega = lambda x: omegaIC(x) + omegaEC(x) + omegaISO(x)
    prox = lambda x: non_negativity(proxIC(proxEC(proxISO(x))),0,x.size) # non negativity is redunduntly forced

    return omega, prox

In [7]:
def solve(y, A, At, tol_fun = 1e-4, tol_x = 1e-6, max_iter = 1000, verbose = 1, x0 = None, regularisation = None):
    """
    Solve the regularised least squares problem
        argmin_x 0.5*||Ax-y||_2^2 + Omega(x)
    with the Omega described by 'regularisation'.
    Check the documentation of commit.solvers.init_regularisation to see how to
    solve a specific problem.
    """
    if regularisation is None:
        omega = lambda x: 0.0
        prox  = lambda x: non_negativity(x, 0, x.size)
    else:
        omega, prox = regularisation2omegaprox(regularisation)

    if x0 is None:
        x0 = np.ones(A.shape[1])

    return fista( y, A, At, tol_fun, tol_x, max_iter, verbose, x0, omega, prox)

In [16]:
y = mit.get_y()
A = mit.A
At = mit.A.T
verbose = True

x = solve(y,A,At,regularisation=regterm)


      |     ||Ax-y||     |  Cost function    Abs error      Rel error    |     Abs x          Rel x
------|------------------|-----------------------------------------------|------------------------------
   1  |   2.8184227e+01  |  1.0383389e+03  3.3059971e+04  3.1839287e+01  |  5.1252171e+01  4.6156715e+00
   2  |   1.6311518e+01  |  6.6110372e+02  3.7723522e+02  5.7061428e-01  |  5.9891936e+00  6.3545502e-01
   3  |   1.5496527e+01  |  4.9300316e+02  1.6810056e+02  3.4097257e-01  |  2.8826590e+00  3.8613705e-01
   4  |   1.5099432e+01  |  4.4030020e+02  5.2702960e+01  1.1969779e-01  |  1.5718034e+00  2.2010439e-01
   5  |   1.5219797e+01  |  4.1716826e+02  2.3131938e+01  5.5449899e-02  |  8.8699461e-01  1.2869810e-01
   6  |   1.5129286e+01  |  4.1360092e+02  3.5673456e+00  8.6250911e-03  |  5.4072623e-01  7.8431112e-02
   7  |   1.5100378e+01  |  4.1319707e+02  4.0384113e-01  9.7735718e-04  |  3.8188371e-01  5.5503930e-02
   8  |   1.5044498e+01  |  4.1444862e+02  1.2515471e+00  3

7.04583808734049