In [None]:
# Notebook-level volumes and label table shared by the mapping functions below.
# NOTE(review): `root`, `nib` and `read_whs_labels` are assumed to be defined by earlier cells.

# DWI volume: only index 0 of the 4th axis is kept.
dwi_path = os.path.join(root, "WHS_SD_rat_atlas_v4_pack", "WHS_SD_rat_DWI_v1.01.nii.gz")
nii_dwi = nib.load(dwi_path)
dwi = np.asanyarray(nii_dwi.dataobj)[:, :, :, 0]

# Optional volumes, currently not loaded:
# mask_path = os.path.join(root, "WHS_SD_rat_atlas_v4_pack", "WHS_SD_v2_brainmask_bin.nii.gz")
# nii_mask = nib.load(mask_path)
# mask = np.asanyarray(nii_mask.dataobj)
#
# t2s_path = os.path.join(root, "WHS_SD_rat_atlas_v4_pack", "WHS_SD_rat_T2star_v1.01.nii.gz")
# nii_t2s = nib.load(t2s_path)
# t2s = np.asanyarray(nii_t2s.dataobj)

# Atlas label volume; its voxel values are matched against atlaslabelsdf["Labels"] later on.
atlas_path = os.path.join(root, "WHS_SD_rat_atlas_v4_pack", "WHS_SD_rat_atlas_v4.nii.gz")
nii_atlas = nib.load(atlas_path)
atlas = np.asanyarray(nii_atlas.dataobj)

# Label -> anatomical-region lookup table.
labels_path = './atlas_labels.rtf'
atlaslabelsdf = read_whs_labels(labels_path)

# Util functions

In [1]:
def map_electrodes_main(fitted_points, mrid_dict, px_size=25, total_ch=64):
    """Place probe channels along the piecewise-linear track through ``fitted_points``.

    The track is walked segment by segment from the deepest pair of points (the
    last two rows) upwards. ``ch_coordinates`` spaces channels 50 um apart inside
    each segment, starting at that segment's bottom point:

    * deepest segment: its length is extended by the distance from the deepest
      point to the deepest channel (``get_dist_to_deepest_ch``), and channels
      start that far below the deepest point;
    * intermediate segments: as many channels as fit in the segment length;
    * top-most segment: every channel not yet mapped, so channels may run past
      the last fitted point;
    * a single segment (two points): all ``total_ch`` channels on that line.

    The per-segment count is capped so the total never exceeds ``total_ch``.

    NOTE(review): each segment restarts at its own bottom point. The gap between
    the last channel of one segment and the first channel of the next can
    therefore be anywhere from 50 um up to (but not including) 100 um.

    Parameters
    ----------
    fitted_points : array-like, shape (P, 3)
        Fitted track points in voxel coordinates. The last row is treated as the
        deepest point.
    mrid_dict :
        Passed unchanged to ``get_dist_to_deepest_ch``. That function's result is
        used as a distance in um.
    px_size : float, default 25
        Voxel size in um. Converts segment lengths to um and is forwarded to
        ``ch_coordinates`` as ``vox_res``.
    total_ch : int, default 64
        Number of channels on the probe.

    Returns
    -------
    ndarray of int, shape (total_ch, 3)
        Voxel coordinates of all channels, deepest segment first.

    Raises
    ------
    ValueError
        If fewer than two fitted points are given (there is no segment to map).
    """
    fitted_points = np.asarray(fitted_points)
    # Fail before any side effects. With <2 points the loop below never runs,
    # and np.vstack([]) would raise an unhelpful error.
    if len(fitted_points) < 2:
        raise ValueError("map_electrodes_main needs at least two fitted points to define a segment")

    last_ch_dist = get_dist_to_deepest_ch(mrid_dict)
    print("Distance from last (deepest) CoM to the deepest channel (um):")
    print(last_ch_dist)
    inter_ch_dist = 50  # channel pitch in um; matches the 50 um step used by ch_coordinates
    n_segments = len(fitted_points) - 1
    chs_mapped = 0
    ch_coords = []

    for i in range(n_segments):
        # Walk from the deepest pair of points upwards.
        bottom = fitted_points[-i - 1, :]
        top = fitted_points[-i - 2, :]
        dist = np.linalg.norm(top - bottom) * px_size
        is_first = i == 0
        is_last = i == n_segments - 1

        if is_first and is_last:
            print("Only 2 patterns available, simply linear mapping")
            dist = dist + last_ch_dist
            print("Number channels between patterns: ")
            print(np.floor(dist / inter_ch_dist).astype(int))
            # With a single segment every channel lies on this one line.
            chs_in_segment = total_ch
            offset = last_ch_dist
        elif is_first:
            print("First (deepest) segment mapping")
            dist = dist + last_ch_dist
            chs_in_segment = min(np.floor(dist / inter_ch_dist).astype(int), total_ch - chs_mapped)
            offset = last_ch_dist
        elif is_last:
            print("Last segment mapping")
            chs_in_segment = total_ch - chs_mapped
            print("Remaining channels to be mapped: ")
            print(chs_in_segment)
            offset = 0
        else:
            chs_in_segment = min(np.floor(dist / inter_ch_dist).astype(int), total_ch - chs_mapped)
            offset = 0

        ch_coords_segment, _ = ch_coordinates(bottom, top, chs_in_segment, offset=offset, vox_res=px_size)
        ch_coords.append(ch_coords_segment)
        print("Number of channels mapped in this segment: ")
        print(chs_in_segment)
        chs_mapped = chs_mapped + chs_in_segment
        print("Total channels mapped: ")
        print(chs_mapped)

    ch_coords = np.vstack(ch_coords)
    print("Mapped channel coordinates: ")
    print(ch_coords)
    visualize_channelfit(ch_coords, fitted_points)
    return ch_coords

In [None]:
def visualize_channelfit(ch_coords, fit3d):
    """Show an interactive plotly 3-D view of the mapped channels and the fitted track.

    ``ch_coords`` is drawn as blue markers. ``fit3d`` is drawn as red markers
    joined by lines. Columns 0, 1 and 2 of each array go to the x, y and z axes.
    """
    axis_titles = dict(
        xaxis_title="Medial --> Lateral",
        yaxis_title="Ventral --> Dorsal",
        zaxis_title="Posterior --> Anterior",
    )

    # The legend label 'gauss3d' is kept as-is; this trace holds the channel coordinates.
    channel_trace = go.Scatter3d(
        x=ch_coords[:, 0], y=ch_coords[:, 1], z=ch_coords[:, 2],
        mode='markers',
        marker=dict(size=5, color='blue'),
        name='gauss3d',
    )
    track_trace = go.Scatter3d(
        x=fit3d[:, 0], y=fit3d[:, 1], z=fit3d[:, 2],
        mode='markers+lines',
        marker=dict(size=5, color='red'),
        line=dict(width=2, color='red'),
        name='fit3d',
    )

    fig = go.Figure()
    for trace in (channel_trace, track_trace):
        fig.add_trace(trace)
    fig.update_layout(scene=axis_titles, legend=dict(x=0, y=1))
    fig.show()
    

In [None]:
def ch_coordinates(bottom, top, nchannels, offset=0, vox_res=25, pitch=50):
    """Return integer voxel coordinates of evenly spaced channels along a line.

    The first channel sits at ``bottom``. When ``offset > 0`` it is moved
    ``offset`` um further away from ``top``. Each following channel is placed
    ``pitch`` um further towards ``top``. Channels are not clipped at ``top``;
    with enough channels they continue past it.

    Parameters
    ----------
    bottom, top : array-like, length 3
        End points of the line in voxel coordinates. They must not coincide.
    nchannels : int
        Number of channels to place. Zero yields an empty ``(0, 3)`` array.
    offset : float, default 0
        Distance in um by which the first channel is moved below ``bottom``.
        Values <= 0 are ignored.
    vox_res : float, default 25
        Voxel size in um, used to convert ``offset`` and ``pitch`` to voxels.
    pitch : float, default 50
        Channel spacing in um.

    Returns
    -------
    (ndarray of int, shape (nchannels, 3), ndarray of float, shape (3,))
        Rounded channel coordinates, and the unit vector from ``bottom`` to ``top``.

    Raises
    ------
    ValueError
        If ``bottom`` and ``top`` coincide (the direction is undefined), or if
        ``nchannels`` is negative.
    """
    if nchannels < 0:
        raise ValueError(f"ch_coordinates: nchannels must be >= 0, got {nchannels}")
    bottom = np.asarray(bottom, dtype=float)
    direction = np.asarray(top, dtype=float) - bottom
    length = np.linalg.norm(direction)
    # Without this check, NaNs would be silently cast to meaningless integers below.
    if length == 0:
        raise ValueError("ch_coordinates: bottom and top coincide; direction is undefined")
    unitvec = direction / length

    if offset > 0:
        bottom = bottom - (offset / vox_res) * unitvec
    # Same operation order as a per-channel i*unitvec*(pitch/vox_res), so rounding is unchanged.
    steps = np.arange(nchannels)[:, None] * unitvec * (pitch / vox_res)
    ch_coord = np.round(bottom + steps)

    return ch_coord.astype(int), unitvec

def get_unitvec(p1, p2):
    """Return the unit vector pointing from ``p1`` to ``p2``.

    Raises
    ------
    ValueError
        If ``p1`` and ``p2`` coincide. The direction is undefined there, and the
        division would otherwise return NaNs.
    """
    direction = np.asarray(p2) - np.asarray(p1)
    length = np.linalg.norm(direction)
    if length == 0:
        raise ValueError("get_unitvec: p1 and p2 coincide; direction is undefined")
    return direction / length

def get_unitvec_linefit(popt):
    """Return the normalised 2-D direction ``(1, slope)`` of a fitted line, with ``slope = popt[0]``.

    The x component is always non-negative, and y = x * slope.
    """
    slope = popt[0]
    x_comp = np.sqrt(1 / (1 + slope ** 2))
    return np.array([x_comp, x_comp * slope])

In [None]:
def _plot_dwi_cross_section(dwi_signal, region_names, chMap, pyr_idx, savepath):
    """Plot the min-max normalised DWI profile along the probe, coloured by region, and save it as a PDF."""
    num_channels = len(dwi_signal)
    lo, hi = np.min(dwi_signal), np.max(dwi_signal)
    # A flat profile would otherwise divide by zero and plot NaNs.
    span = hi - lo if hi > lo else 1
    pixelValues = (dwi_signal - lo) / span

    plt.figure(figsize=(25, 10))
    # First-appearance order keeps region colours stable between runs.
    # The iteration order of a set of strings changes with hash randomisation.
    unique_regions = list(dict.fromkeys(region_names))
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap still accepts a lut size.
    colors = plt.get_cmap("tab10", len(unique_regions))
    region_to_color = {region: colors(i) for i, region in enumerate(unique_regions)}

    # Segment i -> i+1 takes the colour of channel i's region.
    for i in range(num_channels - 1):
        plt.plot([i, i + 1], [pixelValues[i], pixelValues[i + 1]],
                 color=region_to_color[region_names[i]], linewidth=2)
    # Empty plots add one legend entry per region.
    for region in unique_regions:
        plt.plot([], [], color=region_to_color[region], label=region)

    plt.axvline(x=pyr_idx, color='red', linestyle='--', linewidth=2, label='Pyramidal Layer')
    plt.xticks(ticks=np.arange(num_channels), labels=chMap)
    plt.legend(title="Anatomical Region")
    plt.xlabel("Channel Index")
    plt.ylabel("Pixel Value")
    plt.title("Pixel Values by Region")
    plt.grid(True)
    plt.savefig(os.path.join(savepath, "dwi_1D_cross_section.pdf"), dpi=2000)
    plt.show()


def map_channels_to_atlas(chMap, ch_coord, moving_coordinates, fixed_coordinates, savepath):
    """Label each channel with its atlas region and find the CA1 pyramidal-layer channel.

    Each channel coordinate (moving space) is translated to fixed space through
    the row-aligned ``moving_coordinates`` / ``fixed_coordinates`` correspondence
    (moving row k maps to fixed row k). An exact match is used when present.
    Otherwise the first row within +/-1 voxel on every axis is used. The fixed
    coordinate indexes the notebook-level ``atlas`` label volume and ``dwi``
    volume. Region names come from the notebook-level ``atlaslabelsdf``.

    Among channels labelled "Cornu ammonis 1", the one with the lowest DWI value
    is reported as the pyramidal-layer channel. Channel 9999 is written if no
    channel falls in CA1.

    Parameters
    ----------
    chMap : sequence
        Channel IDs, in the same order as ``ch_coord``.
    ch_coord : sequence of length-3 coordinates
        Channel positions in moving space.
    moving_coordinates, fixed_coordinates : ndarray
        2-D arrays whose rows correspond. Columns 0-2 hold the coordinates.
    savepath : str
        Output directory for ``channel_atlas_coordinates.txt`` and, when a CA1
        channel exists, ``dwi_1D_cross_section.pdf``.

    Returns
    -------
    ndarray, shape (len(ch_coord),)
        DWI value at each channel's atlas coordinate.

    Raises
    ------
    IndexError
        If a channel has no correspondence within +/-1 voxel, or if its atlas
        label is missing from ``atlaslabelsdf``.
    """
    num_channels = len(ch_coord)
    dwi1Dsignal = np.zeros((num_channels,))
    regionNames = []
    minPixVal = np.inf
    pyrCh = 9999  # sentinel written when no channel lands in CA1
    pyrChIdx = 0
    pyrLyExists = False

    with open(os.path.join(savepath, "channel_atlas_coordinates.txt"), 'w') as f:
        for idx, coord in enumerate(ch_coord):
            print(idx)
            atlasIdx = ((moving_coordinates[:, 0] == coord[0]) &
                        (moving_coordinates[:, 1] == coord[1]) &
                        (moving_coordinates[:, 2] == coord[2]))
            # Test the match mask itself. Testing the matched *values* treated a
            # genuine match at fixed coordinate (0, 0, 0) as missing.
            if atlasIdx.any():
                print("Exact coordinate exists")
            else:
                print("no exact atlas coord")
                atlasIdx = ((moving_coordinates[:, 0] >= coord[0] - 1) & (moving_coordinates[:, 0] <= coord[0] + 1) &
                            (moving_coordinates[:, 1] >= coord[1] - 1) & (moving_coordinates[:, 1] <= coord[1] + 1) &
                            (moving_coordinates[:, 2] >= coord[2] - 1) & (moving_coordinates[:, 2] <= coord[2] + 1))
                if not atlasIdx.any():
                    raise IndexError(f"No atlas correspondence within 1 voxel of channel {chMap[idx]} at {coord}")

            atlasCoord = fixed_coordinates[atlasIdx][0]
            x, y, z = atlasCoord.astype(int)
            label = atlas[x, y, z]
            anat_region = atlaslabelsdf["Anatomical Regions"][atlaslabelsdf["Labels"] == label].values[0]
            regionNames.append(anat_region)

            currPixVal = dwi[x, y, z]
            dwi1Dsignal[idx] = currPixVal
            if anat_region == "Cornu ammonis 1":
                pyrLyExists = True
                # The darkest CA1 voxel in the DWI is taken as the pyramidal layer.
                if currPixVal < minPixVal:
                    minPixVal = currPixVal
                    pyrCh = chMap[idx]
                    pyrChIdx = idx

            line = "CH:" + str(chMap[idx]) + " in " + anat_region + ' Segment: ' + str(label) + " atlas coord: " + str(atlasCoord)
            print(line)
            f.write(line)
            f.write('\n')

        # Writing the pyramidal channel
        line = "CH:" + str(pyrCh) + " in pyramidal layer CA1"
        print(line)
        f.write(line)
        f.write('\n')

    if pyrLyExists:
        _plot_dwi_cross_section(dwi1Dsignal, regionNames, chMap, pyrChIdx, savepath)

    return dwi1Dsignal

In [None]:
def get_mapped_ch(chmap_path, num_channels=64, selected_channels=None):
    """Read the atlas coordinates of selected channels from a channel-map text file.

    The first two lines are skipped. This matches the "Top Coord" / "Bottom Coord"
    header that ``map_channels_legacy`` writes. Every other non-blank line is
    expected to look like ``CH:<id> in <region> Segment: <label> atlas coord: [x y z]``.
    Coordinates of channels whose id is selected are stored in file order.

    Parameters
    ----------
    chmap_path : str
        Path of the channel-map text file.
    num_channels : int or sequence, default 64
        Number of rows in the result. A sequence is accepted and its length is
        used. Rows beyond the number of matched channels stay zero.
    selected_channels : container of int, optional
        Channel ids to keep. Defaults to the notebook-level ``ch_selected``.

    Returns
    -------
    ndarray of int, shape (num_channels, 3)

    Raises
    ------
    IndexError
        If more lines match than there are rows.
    """
    if selected_channels is None:
        # NOTE(review): falls back to a global expected to be defined elsewhere in the notebook.
        selected_channels = ch_selected
    # The old len(num_channels) crashed on the int default.
    n_rows = num_channels if np.isscalar(num_channels) else len(num_channels)
    ch_coord = np.zeros((n_rows, 3))

    detected = 0
    with open(chmap_path) as f:
        for i, line in enumerate(f):
            if i <= 1 or not line.strip():
                continue
            ch = int(line.split()[0].split(":")[1])
            if ch in selected_channels:
                x, y, z = line.split("[")[1].split("]")[0].split()
                ch_coord[detected, :] = [float(x), float(y), float(z)]
                detected += 1

    return ch_coord.astype(int)

def map_coord(coord, moving_coordinates, fixed_coordinates):
    """Map one moving-space coordinate to fixed space through a row-aligned correspondence table.

    Row k of ``moving_coordinates`` corresponds to row k of ``fixed_coordinates``.
    If ``coord`` matches a moving row exactly, the fixed coordinate of the first
    matching row is returned. Otherwise the per-axis median of the fixed
    coordinates of all rows within +/-1 voxel of ``coord`` (on every axis) is
    returned.

    Returns
    -------
    ndarray of int
        The fixed-space coordinate (the median is truncated by ``astype(int)``).

    Raises
    ------
    ValueError
        If no moving row lies within +/-1 voxel of ``coord``. The median of an
        empty set would otherwise be NaN, cast to a meaningless integer.
    """
    exact = ((moving_coordinates[:, 0] == coord[0]) &
             (moving_coordinates[:, 1] == coord[1]) &
             (moving_coordinates[:, 2] == coord[2]))
    # Test the match mask, not the matched values. A genuine match whose fixed
    # coordinate is (0, 0, 0) must not fall through to the neighbourhood median.
    if exact.any():
        fixedCoord = fixed_coordinates[exact][0]
    else:
        near = ((moving_coordinates[:, 0] >= coord[0] - 1) & (moving_coordinates[:, 0] <= coord[0] + 1) &
                (moving_coordinates[:, 1] >= coord[1] - 1) & (moving_coordinates[:, 1] <= coord[1] + 1) &
                (moving_coordinates[:, 2] >= coord[2] - 1) & (moving_coordinates[:, 2] <= coord[2] + 1))
        if not near.any():
            raise ValueError(f"map_coord: no correspondence within 1 voxel of {coord}")
        fixedCoord = np.median(fixed_coordinates[near], axis=0)

    return fixedCoord.astype(int)

def map_channels_legacy(bottom, top, px_size, offset_pattern_elec, working_dir, chMap, moving_coordinates, fixed_coordinates, atlas, labelsdf, filename="Elastic_chMap", num_ch=64):
    """Place ``num_ch`` channels on the straight line bottom -> top and write each channel's atlas region.

    ``bottom`` is first moved ``offset_pattern_elec`` (in the same length unit as
    ``px_size``, per voxel) away from ``top`` and rounded to integer voxels.
    Channels are then generated with ``ch_coordinates``. Each channel coordinate
    is translated to fixed space through the row-aligned ``moving_coordinates`` /
    ``fixed_coordinates`` correspondence. An exact match is used when present;
    otherwise the first row within +/-1 voxel on every axis is used. The result
    is looked up in ``atlas`` and ``labelsdf``.

    The output file ``<working_dir>/<filename>.txt`` contains a "Top Coord" line,
    a "Bottom Coord" line, then one ``CH:<id> in <region> Segment: <label> atlas
    coord: [...]`` line per channel.

    Raises
    ------
    IndexError
        If a channel has no correspondence within +/-1 voxel, or if its label is
        missing from ``labelsdf``.
    """
    unitvec = get_unitvec(bottom, top)

    bottom = bottom - (offset_pattern_elec / px_size) * unitvec
    bottom = np.round(bottom).astype(int)

    ch_coord, _ = ch_coordinates(bottom, top, num_ch, vox_res=px_size)

    with open(os.path.join(working_dir, filename + ".txt"), 'w') as f:
        f.write("Top Coord:" + str(top))
        f.write('\n')
        f.write("Bottom Coord:" + str(bottom))
        f.write('\n')

        for idx, coord in enumerate(ch_coord):
            print(idx)
            atlasIdx = ((moving_coordinates[:, 0] == coord[0]) &
                        (moving_coordinates[:, 1] == coord[1]) &
                        (moving_coordinates[:, 2] == coord[2]))
            # Test the match mask itself. Testing the matched *values* treated a
            # genuine match at fixed coordinate (0, 0, 0) as missing.
            if atlasIdx.any():
                print("Exact coordinate exists")
            else:
                print("no exact atlas coord")
                atlasIdx = ((moving_coordinates[:, 0] >= coord[0] - 1) & (moving_coordinates[:, 0] <= coord[0] + 1) &
                            (moving_coordinates[:, 1] >= coord[1] - 1) & (moving_coordinates[:, 1] <= coord[1] + 1) &
                            (moving_coordinates[:, 2] >= coord[2] - 1) & (moving_coordinates[:, 2] <= coord[2] + 1))
                if not atlasIdx.any():
                    raise IndexError(f"No atlas correspondence within 1 voxel of channel {chMap[idx]} at {coord}")

            atlasCoord = fixed_coordinates[atlasIdx][0]
            x, y, z = atlasCoord.astype(int)
            label = atlas[x, y, z]
            anat_region = labelsdf["Anatomical Regions"][labelsdf["Labels"] == label].values[0]
            line = "CH:" + str(chMap[idx]) + " in " + anat_region + ' Segment: ' + str(label) + " atlas coord: " + str(atlasCoord)
            print(line)
            f.write(line)
            f.write('\n')

## Reading atlas labels

In [None]:
# labels_path='./atlas_labels.rtf'
# labelsdf=read_whs_labels(labels_path)

In [None]:
# # Probe map for right-hemi implantation
# siteMap = np.flip(np.reshape(scipy.io.loadmat('right_hemi_sitemap.mat')['siteMap'], -1))

In [None]:
# root="/Users/eminhanozil/Dropbox (Yanik Lab)/Localization Manuscript 2024/RAT DATA"

In [None]:
# dwi_path=os.path.join(root, "WHS_SD_rat_atlas_v4_pack","WHS_SD_rat_DWI_v1.01.nii.gz")
# nii_dwi=nib.load(dwi_path)
# dwi=np.asanyarray(nii_dwi.dataobj)
# dwi=dwi[:,:,:,0]

# mask_path=os.path.join(root, "WHS_SD_rat_atlas_v4_pack","WHS_SD_v2_brainmask_bin.nii.gz")
# nii_mask=nib.load(mask_path)
# mask=np.asanyarray(nii_mask.dataobj)

# t2s_path=os.path.join(root, "WHS_SD_rat_atlas_v4_pack","WHS_SD_rat_T2star_v1.01.nii.gz")
# nii_t2s=nib.load(t2s_path)
# t2s=np.asanyarray(nii_t2s.dataobj)

# atlas_path=os.path.join(root, "WHS_SD_rat_atlas_v4_pack","WHS_SD_rat_atlas_v4.nii.gz")
# nii_atlas=nib.load(atlas_path)
# atlas=np.asanyarray(nii_atlas.dataobj)