# Adjusted Mutual Information Tutorial:
In this tutorial, we go over how to run the adjusted_mutual.py script, which calculates the Adjusted Mutual Information (AMI) between every pair of atlases in a set. The script outputs a png heatmap and a csv file containing the AMI values.

First, the relevant libraries must be imported. adjusted_mutual uses the nibabel, numpy, glob, sklearn, argparse, and matplotlib libraries.

In [None]:
import nibabel as nb
import numpy as np
import glob
from matplotlib import pyplot as plt
from sklearn import metrics as skm
from argparse import ArgumentParser

Next, specify the inputs. In this Jupyter notebook the inputs are stored in sys.argv; when running the script from the terminal, use:

`python <path_to_script>/adjusted_mutual.py <input_dir> --output_dir <output_dir> --fig_name <fig_name> --voxel_size <vox_size> --atlases <atlas1> <atlas2> <atlas3> ...`
    
input_dir = The path to the directory containing the atlases you intend to analyze

output_dir = The path to the directory where the output files will be saved. If not specified, the input directory is used

fig_name = Name of the output png and csv files containing the adjusted mutual information heatmap and matrix. If not specified, then the files will be called "AMI_Matrix"

voxel_size = Voxel size of the atlases to be analyzed. When atlases are not specified, this filters the files in the input_dir folder for names ending in `<VOX>x<VOX>x<VOX>.nii.gz` (for example `<atlas_name>_res-1x1x1.nii.gz`). Default is 1, meaning any file ending in `1x1x1.nii.gz` will be selected if atlases are not specified.

atlases = List of atlas file names to analyze, located in the directory specified by input_dir. If not specified, then all atlases in the input directory that match voxel_size will be analyzed. If specified, these override the value given in '--voxel_size'.
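The --voxel_size filter can be previewed without touching the file system. The sketch below uses a hypothetical helper, `select_atlases`, that is not part of adjusted_mutual.py: it applies the same `*<VOX>x<VOX>x<VOX>.nii.gz` wildcard that the script passes to glob.glob, using fnmatch (the shell-style matching that glob relies on) on a list of example file names.

```python
import fnmatch


def select_atlases(filenames, voxel_size="1"):
    """Hypothetical helper: mimic the script's glob pattern on a list of names.

    The script itself calls glob.glob(input_dir + f"/*{dims}.nii.gz")
    on the real directory instead of filtering a list.
    """
    dims = f"{voxel_size}x{voxel_size}x{voxel_size}"
    pattern = f"*{dims}.nii.gz"
    # fnmatch.filter applies the same wildcard rules that glob uses
    return sorted(fnmatch.filter(filenames, pattern))


# Example file names, for illustration only
names = [
    "AAL_space-MNI152NLin6_res-1x1x1.nii.gz",
    "AAL_space-MNI152NLin6_res-2x2x2.nii.gz",
    "AICHAJoliot2015_space-MNI152NLin6_res-1x1x1.nii.gz",
    "README.md",
]
print(select_atlases(names, "1"))
```

Only the two `1x1x1` files are selected; the `2x2x2` file and the README are skipped.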

In [None]:
input_dir = '/Users/ross/Documents/neuroparc/atlases/label/Human'
output_dir = '/Users/ross/Documents/neuroparc/atlases'
fig_name = 'AMI_analysis'
voxel_size = '1'
atlases = ['AAL_space-MNI152NLin6_res-1x1x1.nii.gz','AICHAJoliot2015_space-MNI152NLin6_res-1x1x1.nii.gz']

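#Necessary for running this function in a jupyter notebook: argparse reads sys.argv,
#whose first element is normally the script name, so '' stands in for it
import sys
sys.argv = ['', input_dir, '--output_dir', output_dir, '--fig_name', fig_name,
            '--voxel_size', voxel_size, '--atlases', *atlases]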
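
Patching sys.argv works, but argparse's parse_args() also accepts an explicit list of strings, which avoids changing global state in the notebook. The sketch below assumes a hypothetical build_parser() helper holding a trimmed copy of the parser from main(); the script as written does not expose one, so this approach would require refactoring main() to accept an argument list. Unlike sys.argv, the list passed to parse_args() has no leading program-name entry.

```python
from argparse import ArgumentParser


def build_parser():
    # Hypothetical helper: a trimmed copy of the parser defined in main()
    parser = ArgumentParser(
        description="Calculate the adjusted mutual information of atlases"
    )
    parser.add_argument("input_dir")
    parser.add_argument("--output_dir", default=None)
    parser.add_argument("--fig_name", default="AMI_Matrix")
    parser.add_argument("--voxel_size", default="1")
    parser.add_argument("--atlases", nargs="+")
    return parser


# Placeholder inputs with their own names so the notebook's variables are untouched
demo_dir = "/path/to/atlases/label/Human"
demo_atlases = [
    "AAL_space-MNI152NLin6_res-1x1x1.nii.gz",
    "AICHAJoliot2015_space-MNI152NLin6_res-1x1x1.nii.gz",
]

# parse_args() takes the argument list directly; no '' placeholder is needed
args = build_parser().parse_args(
    [demo_dir, "--fig_name", "AMI_analysis", "--atlases", *demo_atlases]
)
print(args.fig_name, args.output_dir, args.atlases)
```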

Next, define the adjusted_mutual_info function, where the actual AMI calculation happens. It loads both atlases with nibabel, flattens each label volume into a vector, passes the two vectors to sklearn.metrics.adjusted_mutual_info_score, and returns the resulting value.

In [None]:
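def adjusted_mutual_info(atlas1, atlas2):
    """Calculate adjusted mutual information between two atlases

    Parameters
    ----------
    atlas1 : str
        path to the first atlas being analyzed
    atlas2 : str
        path to the second atlas being analyzed

    Returns
    -------
    float
        AMI between the two label volumes; 1.0 means identical
        parcellations and values near 0 mean chance-level agreement
    """

    #Load in the atlas raw data; get_data() was removed in nibabel 5,
    #so read the stored label values through dataobj instead
    labels1 = np.asanyarray(nb.load(atlas1).dataobj)
    labels2 = np.asanyarray(nb.load(atlas2).dataobj)

    #A voxel-wise comparison only makes sense if both volumes share a shape
    if labels1.shape != labels2.shape:
        raise ValueError(
            f"Atlas shapes differ: {labels1.shape} vs {labels2.shape}"
        )

    #Flatten both matrices into vectors and feed them into the function
    AMI = skm.adjusted_mutual_info_score(labels1.flatten(), labels2.flatten())

    #Return resulting value
    return AMI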
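
To see what adjusted_mutual_info_score rewards, the sketch below scores small, made-up label volumes instead of real atlases. AMI ignores the actual label numbers: a renumbered copy of a parcellation scores 1.0, while a coarser parcellation that merges two regions scores lower. As in the script, background voxels labelled 0 are counted as one more region.

```python
import numpy as np
from sklearn.metrics import adjusted_mutual_info_score

# Toy 2x2x2 label volumes (made-up data, not real atlases); 0 plays the role of background
atlas_a = np.array([[[0, 0], [1, 1]],
                    [[2, 2], [3, 3]]])
# The same parcellation with every region renumbered
atlas_b = np.array([[[5, 5], [7, 7]],
                    [[9, 9], [4, 4]]])
# A coarser parcellation in which regions 2 and 3 of atlas_a are merged
atlas_c = np.array([[[0, 0], [1, 1]],
                    [[2, 2], [2, 2]]])

# Flatten exactly as adjusted_mutual_info() does before scoring
same = adjusted_mutual_info_score(atlas_a.flatten(), atlas_b.flatten())
coarser = adjusted_mutual_info_score(atlas_a.flatten(), atlas_c.flatten())
print(same, coarser)  # same is 1.0 up to rounding; coarser is below 1.0
```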


Now define and call the main function, which parses the user's inputs, computes an AMI value for every pair of atlases (including each atlas with itself, which should give 1.0 on the diagonal), and writes the png and csv files. Because AMI is symmetric, each pair is calculated once and stored in both halves of the matrix.

If you wish to change the format of the png or csv file, you can edit any of the formatting code after the comment "Save AMI matrix to csv file, comma delimited" without affecting the calculations.
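One possible formatting change: the script writes the bare matrix with np.savetxt, so the csv does not record which row belongs to which atlas. A minimal sketch, assuming you want the atlas labels as a header row, uses the header and comments arguments of np.savetxt. The labels and values below are placeholders, not results, and an in-memory io.StringIO buffer stands in for the output file.

```python
import io

import numpy as np

labels = ["atlasA", "atlasB"]                 # placeholder atlas labels
ami = np.array([[1.0, 0.5], [0.5, 1.0]])      # placeholder AMI values

buf = io.StringIO()  # stands in for f"{output_dir}/{fig_name}.csv"
# comments="" stops numpy from prefixing the header line with "# "
np.savetxt(buf, ami, delimiter=",", header=",".join(labels),
           comments="", fmt="%.6f")
csv_text = buf.getvalue()
print(csv_text)
```

When reading such a file back, skip the header with `np.loadtxt(path, delimiter=",", skiprows=1)`.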

In [None]:
def main():
    parser = ArgumentParser(
        description="Script to calculate the adjusted mutual information of atlases"
    )
    parser.add_argument(
        "input_dir",
        help="""The directory where the atlases you wish to analyze are saved.""",
        action="store",
    )
    parser.add_argument(
        "--output_dir",
        help="""Directory to save the output AMI matrix and heatmap. If not
        specified, then the input directory will be used. Default is None.""",
        action="store",
        default=None,
    )
    parser.add_argument(
        "--fig_name",
        help = """Name to use for the output png and csv files. If not specified, then
        the name 'AMI_Matrix' will be used. Default is 'AMI_Matrix'""",
        action = "store",
        default = 'AMI_Matrix',
    )
    parser.add_argument(
        "--voxel_size",
        help="""Voxel size for atlases to be analyzed. If --atlases is not given,
        this filters the files located in the input_dir directory for anything
        ending in <VOX>x<VOX>x<VOX>.nii.gz. Default is 1.""",
        action="store",
        default="1",
    )
    parser.add_argument(
        "--atlases",
        help="""List of atlas file names to analyze, located in the directory
        specified by input_dir. If not specified, then all atlases in the input
        directory matching --voxel_size will be analyzed. These override the
        value specified in '--voxel_size'. Default is None.""",
        nargs="+",
    )

    # ------- Parse CLI arguments ------- #
    result = parser.parse_args()
    input_dir = result.input_dir
    output_dir=result.output_dir
    fig_name = result.fig_name
    voxel_size = result.voxel_size
    atlases = result.atlases
    

    # Save outputs to input directory if output directory not specified
    if not output_dir:
        output_dir=input_dir

    #Search for all applicable files
    if not atlases:
        dims = f"{voxel_size}x{voxel_size}x{voxel_size}"
        #Keep just the file names (text after the last '/') so they can also serve as labels
        atlases = sorted(
            path.split("/")[-1] for path in glob.glob(f"{input_dir}/*{dims}.nii.gz")
        )
    #Append input directory to atlas names
    atlas_paths = [f"{input_dir}/{name}" for name in atlases]

    #Create an n x n array of zeros to be filled in (one row/column per atlas)
    n = len(atlas_paths)
    AMI_array = np.zeros((n, n))

    #Loop through each unique pair of atlases (including each atlas with itself);
    #AMI is symmetric, so one calculation fills both halves of the matrix
    for i in range(n):
        for j in range(i + 1):
            AMI = adjusted_mutual_info(atlas_paths[i], atlas_paths[j])
            AMI_array[i, j] = AMI
            AMI_array[j, i] = AMI

    #Save AMI matrix to csv file, comma delimited
    np.savetxt(f'{output_dir}/{fig_name}.csv', AMI_array, delimiter=",")

    #Generate labels for figure (works whether or not --atlases was given)
    atlases = [name.split('_space-')[0] for name in atlases]

    fig, ax = plt.subplots()
    im = ax.imshow(AMI_array, cmap="gist_heat_r") #Can specify the colorscheme you wish to use
    ax.set_xticks(np.arange(len(atlases)))
    ax.set_yticks(np.arange(len(atlases)))

    ax.set_xticklabels(atlases)
    ax.set_yticklabels(atlases)

    #Label x and y-axis, adjust fontsize as necessary
    plt.setp(ax.get_xticklabels(), fontsize=6, rotation=90, ha="right", va="center", rotation_mode="anchor")
    plt.setp(ax.get_yticklabels(), fontsize=6)

    plt.colorbar(im, aspect=30)
    ax.set_title("Adjusted Mutual Information Between Atlases")
    
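    fig.tight_layout()

    #Save the figure before showing it; calling plt.savefig() after plt.show()
    #can write a blank image once the plot window has been closed
    fig.savefig(f'{output_dir}/{fig_name}.png', dpi=1000)
    plt.show()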



if __name__ == "__main__":
    main()
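
The nested loop in main() only computes pairs with i >= j and mirrors each value, because AMI is symmetric. The same idea can be sketched with itertools.combinations_with_replacement, which yields each unordered pair (including an atlas paired with itself) exactly once. This hypothetical ami_matrix helper takes label arrays already in memory rather than file paths, so each atlas is flattened once instead of being reloaded from disk for every pair as adjusted_mutual_info() does; random toy volumes stand in for real atlases.

```python
from itertools import combinations_with_replacement

import numpy as np
from sklearn.metrics import adjusted_mutual_info_score


def ami_matrix(label_volumes):
    """Hypothetical helper: symmetric AMI matrix for equally shaped label arrays."""
    n = len(label_volumes)
    matrix = np.zeros((n, n))  # one row and one column per atlas
    flat = [v.ravel() for v in label_volumes]  # flatten each volume once
    # Each unordered pair (i, j) with i <= j appears once, diagonal included
    for i, j in combinations_with_replacement(range(n), 2):
        value = adjusted_mutual_info_score(flat[i], flat[j])
        matrix[i, j] = matrix[j, i] = value
    return matrix


# Toy 4x4x4 volumes with labels 0..3 (random data, not real atlases)
rng = np.random.default_rng(0)
vols = [rng.integers(0, 4, size=(4, 4, 4)) for _ in range(3)]
m = ami_matrix(vols)
print(np.round(m, 3))
```

Keeping all flattened atlases in memory trades RAM for fewer disk reads; with many high-resolution atlases, loading per pair as the script does may be the safer choice.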

## Outputs:
After the script runs to completion, your output directory should contain the following two files:
1. {fig_name}.png (or AMI_Matrix.png by default), containing a heatmap of the AMI values for every pair of atlases.
2. {fig_name}.csv (or AMI_Matrix.csv by default), containing the AMI matrix shown in the accompanying png file.
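The csv has no header, so its rows and columns follow the order in which the script processed the atlases. As an example of using the output, the hypothetical helper below finds the pair of distinct atlases with the highest AMI. In practice the matrix would come from `np.loadtxt(f"{output_dir}/{fig_name}.csv", delimiter=",")`; placeholder values are used here.

```python
import numpy as np


def most_similar_pair(ami, labels):
    """Hypothetical helper: return the two distinct atlases with the highest AMI."""
    scores = np.array(ami, dtype=float)   # copy so the caller's matrix is untouched
    np.fill_diagonal(scores, -np.inf)     # ignore each atlas compared with itself
    i, j = np.unravel_index(np.argmax(scores), scores.shape)
    return labels[i], labels[j], scores[i, j]


# Placeholder matrix, not real results
ami = np.array([[1.0, 0.3, 0.6],
                [0.3, 1.0, 0.2],
                [0.6, 0.2, 1.0]])
print(most_similar_pair(ami, ["atlasA", "atlasB", "atlasC"]))
```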

## Common Errors:
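- All atlases being compared must share the same voxel grid. If their voxel sizes (and therefore array shapes) differ, the flattened label vectors have different lengths and the AMI cannot be computed, so a ValueError is raised. Atlases with the same shape but different affines are not voxel-wise comparable either, so their AMI values would be misleading.
- If the atlases you are using do not end with '{vox}x{vox}x{vox}.nii.gz', they will not be found when you rely on --voxel_size to select atlases. Either specify the exact atlases you wish to analyze with --atlases, rename the atlases to follow the expected pattern, or edit the file-selection code under the comment "Search for all applicable files" in the main function.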
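A quick way to catch the first error before running the full analysis is to compare each atlas's array shape and affine. The hypothetical helper below takes those values directly (with nibabel they are `img.shape` and `img.affine` of a loaded image) and uses toy 1 mm and 2 mm affines for illustration rather than real MNI headers.

```python
import numpy as np


def same_voxel_grid(shape1, affine1, shape2, affine2, atol=1e-5):
    """Hypothetical helper: True if two images share a voxel grid.

    With nibabel, pass img.shape and img.affine for each loaded atlas.
    """
    # Shapes must match exactly; affines are compared with a small tolerance
    return tuple(shape1) == tuple(shape2) and np.allclose(affine1, affine2, atol=atol)


one_mm = np.diag([1.0, 1.0, 1.0, 1.0])  # toy 1 mm affine
two_mm = np.diag([2.0, 2.0, 2.0, 1.0])  # toy 2 mm affine
print(same_voxel_grid((182, 218, 182), one_mm, (182, 218, 182), one_mm))  # True
print(same_voxel_grid((182, 218, 182), one_mm, (91, 109, 91), two_mm))    # False
```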