# Imports and Helper Functions

In [None]:
import os
import re
import pandas as pd
import itk
import tifffile
import skimage
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import colorcet as cc
from sklearn.linear_model import LinearRegression as LR

### Set the colormap
# Colormap name handed to imshow for the calibration heatmaps (a colorcet
# name; presumably made available to matplotlib by `import colorcet`).
obj = "cet_diverging_bwr_20_95_c54"

# Calibrations

## Setting Paths to Calibration Experiments

These paths must be updated manually as more calibration experiments are conducted.

In [None]:
def _calibration_paths(base_dir, animal_ids, dye_suffix):
    """Return the experiment folders ``<base_dir>/ANM<animal_id><dye_suffix>``, one per animal."""
    return [os.path.join(base_dir, 'ANM' + anm_id + dye_suffix) for anm_id in animal_ids]


### You should change this path to reflect your own working directory
ddir = '/path/to/Boaz/I2/20240930_iDISCO_round2/'
calims_552 = _calibration_paths(ddir, ['552101', '552102', '550750'], '_JF552')
calims_673 = _calibration_paths(ddir, ['549895', '552100', '555643'], '_JFX673')

### You should change this path to reflect your own working directory
ddir_r1 = '/path/to/Boaz/I2/2024-09-19_iDISCO_CalibrationBrains/'
# Round-1 folders use the '_JF673' suffix (not '_JFX673') for the 673 dye.
calims_552_r1 = _calibration_paths(ddir_r1, ['549057_left', '550749_left'], '_JF552')
calims_673_r1 = _calibration_paths(ddir_r1, ['550751_left', '551089_left'], '_JF673')

### Additional directories and animal monikers can be added below, following the template above

###

### You should add the directories, path lists, and animal lists to the following, as necessary
dirs = [ddir, ddir_r1]
calims_552_tot = calims_552 + calims_552_r1
calims_673_tot = calims_673 + calims_673_r1

# Experiment folder names (e.g. 'ANM552101_JF552'); passed to the heatmap function as animal labels.
# os.path.basename works with whatever separator os.path.join used (split('/') does not on Windows).
ANMS_552 = [os.path.basename(anm) for anm in calims_552_tot]
ANMS_673 = [os.path.basename(anm) for anm in calims_673_tot]


## Region Stats

Determine the number of regions.

In [None]:
### Checking the size of the region_csv to initialize the array
# Read the forward region table of the first 552 calibration brain (header row
# skipped) and keep only its first column, read here as the region identifiers.
fname = os.path.join(calims_552_tot[0], 'itk/region_stats.csv')
data = np.loadtxt(fname, skiprows=1, delimiter=',')
regions = data[:, 0]

Populate the region-statistics arrays for both channels.

In [None]:
def _load_region_stats(paths, n_regions):
    """Stack the forward region tables (``<path>/itk/region_stats.csv``) of several brains.

    Returns a float array of shape ``(len(paths), n_regions, 5)`` whose last axis
    holds, in order: region ID, chAF, ch552, ch673, area -- i.e. CSV columns
    0, 2, 5, 8 and 1 (the header row is skipped).

    Raises ValueError when a CSV does not have exactly ``n_regions`` data rows.
    """
    # CSV column for each slot of the last axis: region, chAF, ch552, ch673, area.
    columns = [0, 2, 5, 8, 1]
    stats = np.zeros((len(paths), n_regions, len(columns)))
    for j, path in enumerate(paths):
        fname = os.path.join(path, 'itk/region_stats.csv')
        # ndmin=2 keeps a single-row file 2-D so column indexing still works.
        data = np.loadtxt(fname, delimiter=',', skiprows=1, ndmin=2)
        # NOTE(review): only the row count is checked; rows are assumed to list the
        # regions in the same order in every CSV, since brains are compared row-by-row.
        if data.shape[0] != n_regions:
            raise ValueError(f'{fname}: expected {n_regions} region rows, found {data.shape[0]}')
        stats[j] = data[:, columns]
    return stats


region_stats_552 = _load_region_stats(calims_552_tot, regions.shape[0])
region_stats_673 = _load_region_stats(calims_673_tot, regions.shape[0])

## Allen-to-Experiment (Forward) Global Volume Corrections

Calculate the growth/shrink factor for each experiment, as determined by the affine portion of the registration transformation

In [None]:
def _affine_volume_factor(exp_dir, n_param_files=4):
    """Determinant of the 3x3 linear part of the first registration transform of one experiment.

    Reads ``<exp_dir>/itk/TransformParameters.{0..n_param_files-1}.txt`` into an itk
    ParameterObject and takes the 'TransformParameters' of parameter map 0 (presumably
    the affine stage -- confirm against the registration setup). Its first nine values
    are reshaped row-major into a 3x3 matrix; the determinant of that matrix is the
    volume scale factor of the map (the translation terms do not affect volume).
    """
    itk_dir = os.path.join(exp_dir, 'itk')
    param_files = [os.path.join(itk_dir, f'TransformParameters.{i}.txt') for i in range(n_param_files)]
    parameter_object = itk.ParameterObject.New()
    parameter_object.ReadParameterFile(param_files)
    transform_parameters = np.array(parameter_object.GetParameter(0, 'TransformParameters'), dtype=float)
    linear_part = transform_parameters[:9].reshape(3, 3)
    return np.linalg.det(linear_part)


# One factor per brain, kept as a column vector (shape (n_brains, 1)).
vol_552 = np.array([_affine_volume_factor(p) for p in calims_552_tot], dtype=float).reshape(-1, 1)
vol_673 = np.array([_affine_volume_factor(p) for p in calims_673_tot], dtype=float).reshape(-1, 1)
    

## Calibration Heatmaps

In [None]:
# Channel name -> (index along the last axis of the stats arrays,
#                  column of the parsed inverse-registration `intensity_mean` vectors).
# These match what this function has always read: ch1 used stats index 2 / inverse
# column 1, and a cross-channel ch2 used stats index 3 / inverse column 2.
_CHANNEL_COLUMNS = {'552': (2, 1), '673': (3, 2)}

# One float token (optional sign, decimals, exponent, or nan/inf) inside strings such
# as "[ 12.5 4.1e+03 7.]"; exponents must be kept or 1.2e+03 would be read as 1.2.
_FLOAT_RE = re.compile(r'[-+]?(?:nan|inf|(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)', re.IGNORECASE)


def _parse_intensity_vectors(cells, width=3):
    """Parse vector-valued cells (e.g. "[12.5 40.1 3.2]") into a (len(cells), width) float array.

    Cells holding fewer than `width` numbers are zero-padded on the right.
    """
    out = np.zeros((len(cells), width))
    for row, cell in enumerate(cells):
        values = [float(tok) for tok in _FLOAT_RE.findall(str(cell))]
        out[row, :len(values)] = values
    return out


def _fit_slope(x, y):
    """Slope of an ordinary least-squares fit (with intercept) of y on x."""
    return LR().fit(np.reshape(x, (-1, 1)), np.reshape(y, (-1, 1))).coef_[0, 0]


def _regional_corrected_intensities(im_paths, inv_col):
    """Region-volume-corrected intensities for each experiment folder.

    For every folder, the regions present in both itk/region_stats.csv (forward) and
    invert_test/region_stats.csv (inverse) are kept, and column `inv_col` of the inverse
    `intensity_mean` vectors is scaled by the ratio of the `N` columns (inverse / forward).

    Returns (intensities, region_ids): two lists with one array per folder.
    """
    intensities, region_ids = [], []
    for path in im_paths:
        forward_regions = pd.read_csv(os.path.join(path, 'itk', 'region_stats.csv'))
        inverse_regions = pd.read_csv(os.path.join(path, 'invert_test', 'region_stats.csv'))
        common, fwd_idx, inv_idx = np.intersect1d(
            forward_regions.Region, inverse_regions.Region, return_indices=True)
        inv_int = _parse_intensity_vectors(inverse_regions.intensity_mean)
        n_forward = np.asarray(forward_regions.N)[fwd_idx]
        n_inverse = np.asarray(inverse_regions.N)[inv_idx]
        intensities.append(inv_int[inv_idx, inv_col] * (n_inverse / n_forward))
        region_ids.append(common)
    return intensities, region_ids


def _resolve_channels(title, same_ch_flag, channels):
    """Return the (ch1, ch2) channel names to analyse; see the `channels` argument."""
    if channels is None:
        legacy = ('552', '552') if same_ch_flag else ('552', '673')
        channels = [str(t) if str(t) in _CHANNEL_COLUMNS else d for t, d in zip(title, legacy)]
    ch1 = channels[0]
    ch2 = ch1 if same_ch_flag else channels[1]
    unknown = [c for c in (ch1, ch2) if c not in _CHANNEL_COLUMNS]
    if unknown:
        raise ValueError(f'Unknown channel(s) {unknown}; expected one of {sorted(_CHANNEL_COLUMNS)}')
    return ch1, ch2


def _draw_coeff_heatmap(ax, coeffs, panel_title, row_labels, col_labels, vmin, vmax):
    """Draw one coefficient matrix on `ax` (rows on y, columns on x), annotating every cell."""
    image = ax.imshow(coeffs, cmap=obj, vmin=vmin, vmax=vmax)
    for row in range(coeffs.shape[0]):
        for col in range(coeffs.shape[1]):
            ax.text(col, row, f'{coeffs[row, col]:.3f}', ha='center', va='center', color='k')
    ax.set_title(panel_title)
    ax.set_xticks(np.arange(coeffs.shape[1]), col_labels, rotation=45, ha='right')
    ax.set_yticks(np.arange(coeffs.shape[0]), row_labels, rotation=45)
    return image


def calc_and_display_calibrations_cross_channel(stats_ch1,vols_ch1,im_paths_ch1,animals_ch1,
                                                title=('Insert Channel 1 Name Here','Insert Channel 2 Name Here'),axes=None,same_ch_flag=True,
                                                stats_ch2=None,vols_ch2=None,im_paths_ch2=None,animals_ch2=None,
                                                channels=None):
    """
    Description:

        Regress the regional intensity of every ch1 brain against every ch2 brain and
        plot the regression slopes as three heatmaps:
            Raw intensity
            Global-volume-corrected intensity
            Regional-volume-corrected intensity
        Each slope is the coefficient of an ordinary least-squares fit (with intercept)
        of the ch2 brain's values on the ch1 brain's values. In every matrix, row j is
        ch1 brain j and column jj is ch2 brain jj.

        Variables:

            stats_ch1 : numpy array, shape (n_brains, n_regions, >= 4)
                        Forward regional statistics per brain. Index 2 / 3 of the last
                        axis are read as the 552 / 673 intensities.

            vols_ch1 : numpy array, one entry per brain
                        Global volume factor per brain; each brain's intensities are
                        multiplied by its factor for the global correction.

            im_paths_ch1 : string list
                        Experiment folders; each must contain itk/region_stats.csv and
                        invert_test/region_stats.csv (used for the regional correction).

            animals_ch1 : string list
                        Animal label for each experiment (heatmap row labels).

            title : pair of strings
                        Channel names shown in the figure title. When `channels` is not
                        given, entries equal to '552' or '673' also select the channel
                        that is analysed.

            axes : sequence of >= 3 matplotlib Axes
                        Axes onto which the heatmaps will be plotted.
                        Can be omitted and a new 1x3 figure will be created.

            same_ch_flag : bool
                        If True, the ch1 inputs are reused as ch2 (every ch1 brain against
                        every ch1 brain, same channel).
                        If False, stats_ch2, vols_ch2, im_paths_ch2 and animals_ch2 are
                        required (ValueError otherwise).

            stats_ch2, vols_ch2, im_paths_ch2, animals_ch2 : as for ch1
                        The ch2 counterparts; pass them as keyword arguments, e.g.
                        function(... , stats_ch2 = input_stats_array, ...)

            channels : pair of strings, optional
                        Channel ('552' or '673') to read for the ch1 and ch2 brains.
                        Default: inferred from `title`, falling back to '552' for ch1 and
                        '673' for ch2 ('552' for both when same_ch_flag is True).

        Returns:

            numpy array of shape (3 * n_ch1, n_ch2): the raw, global-corrected and
            regional-corrected slope matrices stacked vertically.
    """
    if same_ch_flag:
        # Inputs are only read, never modified, so aliasing is enough.
        stats_ch2, vols_ch2, im_paths_ch2, animals_ch2 = stats_ch1, vols_ch1, im_paths_ch1, animals_ch1
    elif any(v is None for v in (stats_ch2, vols_ch2, im_paths_ch2, animals_ch2)):
        raise ValueError('same_ch_flag=False requires stats_ch2, vols_ch2, im_paths_ch2 and animals_ch2')

    # Resolve the channel explicitly: always reading the 552 columns would make a
    # '673 vs 673' comparison regress the 552 signal of the 673 brains.
    ch1, ch2 = _resolve_channels(title, same_ch_flag, channels)
    stat_col1, inv_col1 = _CHANNEL_COLUMNS[ch1]
    stat_col2, inv_col2 = _CHANNEL_COLUMNS[ch2]

    ### Raw intensity and global volume correction
    n1, n2 = stats_ch1.shape[0], stats_ch2.shape[0]
    coeffs_raw = np.zeros((n1, n2))
    coeffs_global = np.zeros((n1, n2))
    for j in range(n1):
        x = stats_ch1[j, :, stat_col1]
        for jj in range(n2):
            y = stats_ch2[jj, :, stat_col2]
            coeffs_raw[j, jj] = _fit_slope(x, y)
            coeffs_global[j, jj] = _fit_slope(x * vols_ch1[j], y * vols_ch2[jj])

    ### Regional volume correction
    ints_ch1, ids_ch1 = _regional_corrected_intensities(im_paths_ch1, inv_col1)
    if same_ch_flag:
        ints_ch2, ids_ch2 = ints_ch1, ids_ch1
    else:
        ints_ch2, ids_ch2 = _regional_corrected_intensities(im_paths_ch2, inv_col2)

    coeffs_region = np.zeros((len(ints_ch1), len(ints_ch2)))
    for j, (x, id_x) in enumerate(zip(ints_ch1, ids_ch1)):
        for jj, (y, id_y) in enumerate(zip(ints_ch2, ids_ch2)):
            # Only regions present in both brains enter the fit.
            _, ix, iy = np.intersect1d(id_x, id_y, return_indices=True)
            coeffs_region[j, jj] = _fit_slope(x[ix], y[iy])

    ### Plot the coefficients on a shared colour scale
    if axes is None:
        fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
    else:
        fig = axes[0].figure
    all_coeffs = np.concatenate([coeffs_raw.ravel(), coeffs_global.ravel(), coeffs_region.ravel()])
    vmin, vmax = all_coeffs.min(), all_coeffs.max()

    # Rows are ch1 brains (y axis) and columns are ch2 brains (x axis).
    _draw_coeff_heatmap(axes[0], coeffs_raw, 'Raw Intensity ', animals_ch1, animals_ch2, vmin, vmax)
    _draw_coeff_heatmap(axes[1], coeffs_global, 'Global Volume Correction', [], animals_ch2, vmin, vmax)
    image = _draw_coeff_heatmap(axes[2], coeffs_region, 'Regional Volume Correction', [], animals_ch2, vmin, vmax)
    fig.colorbar(image, ax=axes[:])

    fig.suptitle(title[0] + ' vs ' + title[1] + ' Covariance', fontsize=16)

    return np.vstack((coeffs_raw, coeffs_global, coeffs_region))
    

### 552 vs 552

In [None]:
# Every 552 calibration brain regressed against every other 552 brain.
ceoffs_552 = calc_and_display_calibrations_cross_channel(
    region_stats_552,
    vol_552,
    calims_552_tot,
    ANMS_552,
    title=['552', '552'],
    same_ch_flag=True,
)

### 673 vs 673

In [None]:
# Every 673 calibration brain regressed against every other 673 brain.
ceoffs_673 = calc_and_display_calibrations_cross_channel(
    region_stats_673,
    vol_673,
    calims_673_tot,
    ANMS_673,
    title=['673', '673'],
    same_ch_flag=True,
)

### 552 vs 673

In [None]:
# 552 brains (rows) regressed against 673 brains (columns).
ceoffs = calc_and_display_calibrations_cross_channel(
    region_stats_552, vol_552, calims_552_tot, ANMS_552,
    title=['552', '673'],
    same_ch_flag=False,
    stats_ch2=region_stats_673,
    vols_ch2=vol_673,
    im_paths_ch2=calims_673_tot,
    animals_ch2=ANMS_673,
)