In [None]:
def mrid_gaussian_registration_main(root, animals, mrids, sessions, mrid_dict, chMap=np.array([]), weighted_loss_f="none", bundle_start=0):
    gaussian_centers_coronal=np.array([])
    gaussian_centers_sagittal=np.array([])
    gaussian_centers_axial=np.array([])

    contrast_intensities_coronal=np.array([])
    contrast_intensities_sagittal=np.array([])
    contrast_intensities_axial=np.array([])
    
    for i, mrid in enumerate(mrids):
        for animal in animals:
            path = os.path.join(root, animal, "mri")
            for session in sessions:
                print("Animal: "+animal+" session: "+session+" MRID: "+mrid)
                sessionpath = os.path.join(path, session)
                analysedpath = os.path.join(sessionpath, "analysed")
                if os.path.exists(analysedpath):
                    mridpath = os.path.join(analysedpath, mrid)
                    print("The MRID directory: "+mridpath)
                    # Coronal gaussian centers
                    orient = "coronal"
                    orientpath = os.path.join(mridpath, orient)
                    if os.path.exists(orientpath):
                        gaussian_centers_coronal = np.load(os.path.join(orientpath, "gaussian_centers.npy"))
                        print("Coronal sliced gaussian centers exist: ")
                        print(gaussian_centers_coronal)
                        try:
                            contrast_intensities_coronal = np.load(os.path.join(orientpath, "contrast_intensities_fixedROI.npy"))
                            print("Coronal contrast intensities: ")
                            print(contrast_intensities_coronal)
                        except:
                            print("No fixed ROI contrast intensity available for Coronal slice: ")
                            contrast_intensities_coronal=np.array([])
                        
                    # Coronal gaussian centers
                    orient = "sagittal"
                    orientpath = os.path.join(mridpath, orient)
                    if os.path.exists(orientpath):
                        gaussian_centers_sagittal = np.load(os.path.join(orientpath, "gaussian_centers.npy"))
                        print("Sagittal sliced gaussian centers exist: ")
                        print(gaussian_centers_sagittal)
                        try:
                            contrast_intensities_sagittal = np.load(os.path.join(orientpath, "contrast_intensities_fixedROI.npy"))
                            print("Sagittal contrast intensities: ")
                            print(contrast_intensities_sagittal)
                        except:
                            print("No fixed ROI contrast intensity available for Coronal slice: ")
                            contrast_intensities_sagittal=np.array([])

                        
                    # Coronal gaussian centers
                    orient = "axial"
                    orientpath = os.path.join(mridpath, orient)
                    if os.path.exists(orientpath):
                        gaussian_centers_axial= np.load(os.path.join(orientpath, "gaussian_centers.npy"))
                        print("Axially sliced gaussian centers exist: ")
                        print(gaussian_centers_axial)
                        try:
                            contrast_intensities_axial = np.load(os.path.join(orientpath, "contrast_intensities_fixedROI.npy"))
                        except:
                            contrast_intensities_axial=np.array([])

                    
                    gaussian_centers_3d=combined_gaussian_centers(gaussian_centers_coronal,
                                                                  contrast_intensities_coronal,
                                                                  gaussian_centers_sagittal,
                                                                  contrast_intensities_sagittal,
                                                                  gaussian_centers_axial,
                                                                  contrast_intensities_axial,
                                                                  savepath=mridpath)

                    print("Measured 3D coordinates of gaussian centers: ")
                    print(gaussian_centers_3d)

                    if mrid=="electrode":
                        fitted_points = gaussian_centers_3d
                    else:
                        fitted_points = pointset_register_main(gaussian_centers_3d, mrid_dict[mrid], bundle_start, weighted_loss_f, visualization=True)
                        
                        print("Registered 3D coordinates of MRID CoMs: ")
                        print(fitted_points)
                        np.save(os.path.join(mridpath, "mrid_registered_coordinates.npy"), fitted_points)
                    if chMap.any():    
                        # Mapping the channels to physical coordinate indeces (integers) in MRI space
                        ch_coords=map_electrodes_main(fitted_points, mrid_dict[mrid])
                        np.save(os.path.join(mridpath, "channel_mri_coordinates.npy"), ch_coords[0])
    
                        moving_idx_filename = "moving_img_resampled25um-indeces.npy"
                        fixed_idx_filename = "fixed_img-indeces.npy"
                        moving_idx_path = os.path.join(sessionpath, "registration", moving_idx_filename)
                        fixed_idx_path = os.path.join(sessionpath, "registration", fixed_idx_filename)
                        print("Loading the moving coordinates: " + moving_idx_path)
                        print("Loading the fixed coordinates: "+fixed_idx_path)
                        moving_coordinates = np.load(moving_idx_path)
                        fixed_coordinates = np.load(fixed_idx_path)
                        
                        # Mapping the channel coordinates to the Atlas space
                        dwi1Dsignal = map_channels_to_atlas(chMap, ch_coords, moving_coordinates, fixed_coordinates, savepath=mridpath)
                        np.save(os.path.join(mridpath, "dwi_1D_cross_section_pixel_values.npy"),dwi1Dsignal)
                    
    return gaussian_centers_3d, fitted_points

In [None]:
def mrid_main(sessionpath, roi_name, basestructs, filename_data, slice_orientation, mrid_dict, bundle_start=0, slice_thickness=0.8, transform_filename="", inverseTransform=False, relaxation_verbose=False, fixedROIanalysis=True, te=[4, 4.09], labelsname="labels.txt"):
    """Run the MRID analysis pipeline for one ROI of one session.

    Steps, in order:
      1. Load the data with ``get_data`` and detect the ROI with ``run_roi``.
      2. If the data has more than three dimensions (last axis taken as the
         number of echo images), run ``run_relaxation``.
      3. If ``transform_filename`` is given, warp the heatmaps with
         ``heatmap_warp`` and run ``run_gaussian_analysis`` on the warped
         volume.
      4. If ``fixedROIanalysis`` is true, run ``mrid_contrast_fixedROI`` and
         plot contrasts against densities.

    All outputs are written under
    ``<sessionpath>/analysed/<roi_name>/<slice_orientation>``.

    NOTE(review): ``run_roi`` and ``run_relaxation`` are called without the
    orientation; as written in this file they read a module-level
    ``slice_orientation`` for their output file names — confirm that it
    matches the ``slice_orientation`` argument passed here.

    Args:
        sessionpath: Session directory.
        roi_name: ROI/MRID name; also the key into ``mrid_dict``.
        basestructs: Forwarded to the ROI, relaxation and contrast helpers.
        filename_data: Base file name of the data set.
        slice_orientation: Orientation name used for the output directory and
            forwarded to ``run_gaussian_analysis``.
        mrid_dict: Mapping whose ``[roi_name]`` entry provides "dimensions"
            and "ionp_amount".
        bundle_start: First index of "dimensions"/"ionp_amount" to use.
        slice_thickness: Forwarded to ``mrid_contrast_fixedROI``.
        transform_filename: When non-empty, enables the warp + Gaussian step.
        inverseTransform: Forwarded to ``heatmap_warp``.
        relaxation_verbose: Forwarded to ``run_relaxation``.
        fixedROIanalysis: Enables the fixed-ROI contrast step.
        te: Forwarded to ``run_relaxation`` (presumably echo times — confirm).
            The default list is not mutated in this function.
        labelsname: Currently unused.

    Returns:
        Tuple ``(detectedPixels, heatmaps)``. ``heatmaps`` comes from the
        Gaussian analysis when a transform is given, otherwise from the
        relaxation step, and is ``None`` when neither step ran.
    """
    savepath = os.path.join(sessionpath, "analysed", roi_name, slice_orientation)
    os.makedirs(savepath, exist_ok=True)

    nii_data, data, segmentation, anat, labelsdf = get_data(sessionpath, filename_data)
    voxel_size = nii_data.header['pixdim'][1:3]
    print("Voxel size of the data: "+str(voxel_size))

    num_echos = 0
    if np.ndim(data) > 3:
        num_echos = np.shape(data)[-1]
        print("Detected number of echo images:" + str(num_echos))
    else:
        print("No echo images detected, this is not a T2*MAP image")

    detectedPixels, _, _ = run_roi(filename_data, roi_name, data, nii_data, segmentation, anat, basestructs,
                                   labelsdf, num_echos, savepath)

    # Optional debugging overlay:
    # plot_all_roi(data, detectedPixels[roi_name], roi_name, savepath, color='r',slice_orientation=slice_orientation, num_echos=num_echos, savefigs=True, dpi=1000)

    # Defined up front so that data without echo images no longer ends in an
    # UnboundLocalError at the fixed-ROI step or at the return statement.
    heatmaps = None
    img_slice = None

    if num_echos > 0:
        heatmaps, img_slice = run_relaxation(savepath, filename_data, data, nii_data, roi_name, segmentation, anat,
                                             basestructs, labelsdf, te, r=1, unsupervised=False, relaxation_verbose=relaxation_verbose)
        print("The image slice where the bundle is:" + str(img_slice))

    if transform_filename:
        print("Transformation exists, warping the contrast heatmaps")

        fixed_path = heatmap_warp(filename_data, roi_name, savepath, sessionpath, transform_filename, inverseTransform, num_echos)
        _, data_volume = read_data(fixed_path)

        # NOTE: this replaces the relaxation heatmaps in the return value,
        # as in the original implementation.
        _, heatmaps, _ = run_gaussian_analysis(filename_data, savepath, roi_name, slice_orientation,
                                               data_volume, labelsdf, px_size=25)

        # Release the warped volume as soon as it is no longer needed.
        del data_volume

    if fixedROIanalysis:
        if img_slice is None:
            # The bundle slice is only produced by the relaxation step.
            print("Fixed ROI analysis skipped: no echo images, so the bundle slice is unknown")
        else:
            pattern_echocenters = get_centers(detectedPixels[roi_name])
            pattern_centers = np.nanmean(pattern_echocenters, axis=1)
            contrasts, densities = mrid_contrast_fixedROI(data, segmentation, anat, labelsdf, voxel_size, slice_thickness,
                                                          img_slice, basestructs,
                                                          pattern_centers, mrid_dict[roi_name]["dimensions"][bundle_start:, :],
                                                          mrid_dict[roi_name]["ionp_amount"][bundle_start:], num_echos, savepath, roi_name)
            plt.figure()
            plt.plot(densities, contrasts)
            plt.show()

    return detectedPixels, heatmaps

In [None]:
def run_roi(filename_data, roi_name, data, nii_data, segmentation, anat, basestructs, labelsdf, num_echos, savepath, orientation=None):
    """Detect the ROI with ``find_roi`` and save its results into ``savepath``.

    Written files:
      * ``induced_contrast_dict.pkl`` — pickled ``detectedPixels``
      * ``<roi_name>-<orientation>-zmap.npy`` — the z-map array
      * ``<filename_data>-<roi_name>-zmap.nii.gz`` — the z-map saved with
        ``save_nii`` using ``nii_data.affine``

    Args:
        filename_data: Base file name used for the NIfTI output name.
        roi_name: ROI name; passed to ``find_roi`` as a one-element list.
        data, segmentation, anat, basestructs, labelsdf, num_echos:
            Forwarded unchanged to ``find_roi``.
        nii_data: Object providing the ``affine`` for the saved NIfTI file.
        savepath: Existing directory receiving the output files.
        orientation: Slice orientation used in the ``.npy`` file name. When
            omitted, the module-level ``slice_orientation`` variable is used,
            which is what this function silently relied on before.

    Returns:
        Tuple ``(detectedPixels, zmap, zmap_nii)`` as returned by ``find_roi``.

    Raises:
        NameError: If ``orientation`` is omitted and no module-level
            ``slice_orientation`` exists.
    """
    if orientation is None:
        try:
            # Backward compatibility with callers that set a notebook global.
            orientation = slice_orientation
        except NameError:
            raise NameError(
                "run_roi needs the slice orientation: pass `orientation=...` "
                "or define a module-level `slice_orientation`"
            ) from None

    detectedPixels, zmap, zmap_nii = find_roi(data, segmentation, anat, basestructs, labelsdf, [roi_name], num_echos)

    # Save dictionary
    with open(os.path.join(savepath, "induced_contrast_dict.pkl"), 'wb') as f:
        pickle.dump(detectedPixels, f)

    zmap_filename = roi_name+"-"+orientation+"-zmap.npy"
    np.save(os.path.join(savepath, zmap_filename), zmap)

    zmap_nii_filename = filename_data+"-"+roi_name+"-zmap.nii.gz"
    save_nii(zmap_nii, nii_data.affine, os.path.join(savepath, zmap_nii_filename))
    return detectedPixels, zmap, zmap_nii

In [None]:
def run_relaxation(savepath, filename_data, data, nii_data, roi_name, segmentation, anat, basestructs, labelsdf, te, r=1, unsupervised=False, relaxation_verbose=False, orientation=None):
    """Run ``get_relaxation`` for one ROI and save the heatmaps and an overlay.

    Written files (all in ``savepath``):
      * ``<roi_name>-<orientation>-heatmap_r=<r>sqr.npy`` — ``heatmaps``
      * ``<roi_name>-<orientation>-heatmap-img_slice=<img_slice>.pdf`` —
        ``data[:, :, img_slice, 0]`` in gray with ``heatmap_data`` overlaid
      * ``<filename_data>-<roi_name>-heatmap.nii.gz`` — ``heatmap_data`` saved
        with ``save_nii`` using ``nii_data.affine``

    Args:
        savepath: Existing output directory.
        filename_data: Base file name used for the NIfTI output name.
        data: Image array; indexed as ``data[:, :, img_slice, 0]`` for the
            overlay, so it must have four dimensions.
        nii_data: Object providing the ``affine`` for the saved NIfTI file.
        roi_name, segmentation, anat, basestructs, labelsdf, te:
            Forwarded unchanged to ``get_relaxation``.
        r: Forwarded to ``get_relaxation`` and used in the ``.npy`` file
            name. Previously accepted but ignored (1 was hard-coded).
        unsupervised: Forwarded to ``get_relaxation``. Previously accepted
            but ignored (False was hard-coded).
        relaxation_verbose: When true, ``savepath`` is also forwarded to
            ``get_relaxation``.
        orientation: Slice orientation used in the output file names. When
            omitted, the module-level ``slice_orientation`` variable is used,
            which is what this function silently relied on before.

    Returns:
        Tuple ``(heatmaps, img_slice)`` from ``get_relaxation``.

    Raises:
        NameError: If ``orientation`` is omitted and no module-level
            ``slice_orientation`` exists.
    """
    if orientation is None:
        try:
            # Backward compatibility with callers that set a notebook global.
            orientation = slice_orientation
        except NameError:
            raise NameError(
                "run_relaxation needs the slice orientation: pass `orientation=...` "
                "or define a module-level `slice_orientation`"
            ) from None

    # running the relaxation for given mrid
    relaxation_kwargs = {"r": r, "unsupervised": unsupervised}
    if relaxation_verbose:
        relaxation_kwargs["savepath"] = savepath
    heatmaps, heatmap_data, img_slice = get_relaxation(data, roi_name, segmentation, anat,
                                                       basestructs, labelsdf, te, **relaxation_kwargs)

    # saving the relaxation results
    heatmap_npy_filename = roi_name+"-"+orientation+"-heatmap_r="+str(r)+"sqr.npy"
    np.save(os.path.join(savepath, heatmap_npy_filename), heatmaps)

    # Overlay of the heatmap on the first echo of the bundle slice.
    # `cmap_args` is a module-level dict of imshow keyword arguments defined
    # outside this function.
    figure_filename = roi_name+"-"+orientation+"-heatmap-img_slice="+str(img_slice)+".pdf"
    plt.figure()
    plt.imshow(data[:, :, img_slice, 0], cmap='gray')
    plt.imshow(heatmap_data[:, :, img_slice], **cmap_args, alpha=0.75)
    plt.colorbar()
    plt.savefig(os.path.join(savepath, figure_filename), dpi=3000)
    plt.show()

    heatmap_nii_filename = filename_data+"-"+roi_name+"-heatmap.nii.gz"
    save_nii(heatmap_data, nii_data.affine, os.path.join(savepath, heatmap_nii_filename))

    return heatmaps, img_slice

In [2]:
## MOVED TO handlers.ipynb

# def get_data(sessionpath, filename_data):
#     """
#     Gets the raw data together with MRID segmentation and anatomical segmentation
#     filename_data: filename of the raw T2*Map MGE data
#     img_slice: image slice of interest
#     """
#     anatpath = os.path.join(sessionpath, "anat")
#     filename_data_full = ".".join((filename_data,"nii", "gz"))
#     filename_segmentation=".".join((filename_data+"-segmentation", "nii", "gz"))
#     filename_anat=".".join((filename_data+"-anat", "nii", "gz"))

#     nii_data, data=read_data(os.path.join(anatpath, filename_data_full))
#     _, segmentation=read_data(os.path.join(anatpath, filename_segmentation))
#     _, anat=read_data(os.path.join(anatpath, filename_anat))

#     labelsdf=read_labels(os.path.join(sessionpath, "anat", "labels.txt"))
#     print(labelsdf)
    
#     print("Data shape of anatomy segmentation" + str(np.shape(anat)))
#     print("Data shape of MRI data" + str(np.shape(data)))
#     print("Data shape of MRID segmentation" + str(np.shape(segmentation)))

#     # print("Voxel dimensions: " + str(nii_data.header['pixdim']))
#     return nii_data, data, segmentation, anat, labelsdf

In [None]:
def run_gaussian_analysis(filename, savepath, roi_name, orientation, data_volume, labelsdf, px_size=25):
    """Fit Gaussian centers on the warped heatmap of one ROI and save them.

    Reads ``<filename>-<roi_name>-heatmap-warped.nii.gz`` and
    ``<filename>-<roi_name>-heatmap-segmentation-warped.nii.gz`` from
    ``savepath``, builds the projection with ``get_maxproj``, takes the
    matching slice of ``data_volume`` as the fixed image and calls
    ``find_gaussian_centers``. The centers, amplitudes and sigmas are saved
    as ``gaussian_centers.npy``, ``gaussian_amplitudes.npy`` and
    ``gaussian_sigmas.npy`` in ``savepath``.

    Args:
        filename: Base file name of the data set.
        savepath: Directory holding the warped inputs and receiving outputs.
        roi_name: ROI name used in the input file names.
        orientation: One of "coronal", "sagittal" or "axial".
        data_volume: 3D array sliced along the axis selected by ``orientation``.
        labelsdf: Forwarded to ``get_maxproj``.
        px_size: Forwarded to ``find_gaussian_centers``.

    Returns:
        Tuple ``(gaussian_centers, heatmaps, ind)``.

    Raises:
        ValueError: If ``orientation`` is not a supported value. Previously
            this surfaced late as an ``UnboundLocalError`` on ``fixed_img``.
    """
    supported_orientations = ("coronal", "sagittal", "axial")
    if orientation not in supported_orientations:
        raise ValueError(
            "Unsupported orientation " + repr(orientation)
            + "; expected one of " + ", ".join(supported_orientations)
        )

    heatmap_warped_filename = ".".join((filename+"-"+roi_name+"-heatmap-warped", "nii", "gz"))
    resampled_path = os.path.join(savepath, heatmap_warped_filename)

    segmentation_filename = ".".join((filename+"-"+roi_name+"-heatmap-segmentation-warped", "nii", "gz"))
    segmentation_path = os.path.join(savepath, segmentation_filename)

    _, heatmap_warped = read_data(resampled_path)
    _, segmentation_warped = read_data(segmentation_path)

    heatmaps, ind = get_maxproj(heatmap_warped, segmentation_warped, roi_name, labelsdf, orientation=orientation)

    # Pick the slice of the fixed volume that matches the projection plane.
    if orientation == "coronal":
        fixed_img = data_volume[:, :, ind[2]]
    elif orientation == "sagittal":
        fixed_img = data_volume[ind[0], :, :]
    else:  # "axial"
        fixed_img = data_volume[:, ind[1], :]

    # Only the sagittal case clears the flag; axial keeps coronal=True exactly
    # as before. NOTE(review): confirm that this is intended for axial slices.
    coronal_flag = orientation != "sagittal"

    gaussian_centers, gaussAmp, gaussSig, _ = find_gaussian_centers(heatmaps, fixed_img, px_size, coronal=coronal_flag)

    np.save(os.path.join(savepath, "gaussian_centers.npy"), gaussian_centers)
    np.save(os.path.join(savepath, "gaussian_amplitudes.npy"), gaussAmp)
    np.save(os.path.join(savepath, "gaussian_sigmas.npy"), gaussSig)

    return gaussian_centers, heatmaps, ind