# Sort 2D projections by common lines

In [None]:
import mrcfile
import itertools
import numpy as np
from igraph import Graph
from scipy import ndimage as ndi
from skimage import transform, measure
from scipy import signal, spatial, stats

### inputs

In [None]:
#Input stack of 2D class averages (.mrcs)
mrc_input = "path/to/input/mixture_2D.mrcs"

#Star file listing the particles that belong to those class averages
star_input = "path/to/matching/mixture_particles.star"

#Directory that receives every output file
outpath = "path/to/output/"

#Label inserted into the output file names
description = "mixture"

#Sampling of the class averages, in A/pixel
pixel_size = 4

#Score used to compare 1D projections; one of:
#NCC, cross-correlation, Euclidean, Norm-Euc, cosine, difference
metric = "Euclidean"

#How many edges each node contributes to the graph
neighbors = 5

#Community detection algorithm: betweenness or walktrap
community_detection = "betweenness"

In [None]:
#Angular sampling for the 2D->1D projections:
#step between rotations and the last rotation angle
interval, angle = 5, 360

#Metrics that get the 'sliding' comparison (analogous to cross-correlation)
slide = [
    'Euclidean',
    'Norm-Euc',
    'cosine',
    'difference',
]

### read in 2D projections

In [None]:
#Load every image of the stack as float64, keyed by its position in the file
with mrcfile.open(mrc_input) as mrc:
    projection2d = {
        index: frame.astype('float64')
        for index, frame in enumerate(mrc.data)
    }

#2D projections are named numerically to match the mrcs file
file_names = [index for index in range(len(projection2d))]

### extract class averages from background

In [None]:
#If using synthetic reprojections from EMAN, skip this block

#Each class average is reduced to a single connected region:
#  1. negative pixels are zeroed,
#  2. the remaining pixels are dilated and labelled,
#  3. the region whose bounding box contains the image center is selected
#     (or, failing that, the region whose box center is closest to it),
#  4. the image is masked to that region and cropped to its bounding box.

#The dilation footprint does not depend on the image, so build it once
extra = 3 #Angstroms to dilate (set minimum of 3A)
extend = int(np.ceil((pixel_size/extra)**-1))
struct = np.ones((extend, extend), dtype=bool)

for name, image in projection2d.items():
    #keep pixel values greater than the zero-mean, set everything else to zero
    #TODO: test different region extraction algorithms
    img_thresh = np.where(image < 0, 0, image)

    #dilate to connect neighbor regions, then label the connected components
    dilate = ndi.binary_dilation(input=img_thresh, structure=struct)
    labeled = measure.label(dilate, connectivity=2, background=False)

    #bounding boxes are (min_row, min_col, max_row, max_col);
    #entry i belongs to label i+1 because label 0 is the background
    rprops = measure.regionprops(labeled, cache=False)
    bbox = [r.bbox for r in rprops]

    if not bbox:
        #e.g. an empty (all zero) class average: nothing can be extracted
        raise ValueError('2D class average {0} has no positive pixels, '
                         'no region can be extracted'.format(name))

    #image center as (row, column); rows and columns are handled separately
    #so that non-square images are treated correctly
    n_rows, n_cols = image.shape
    center_row, center_col = n_rows/2, n_cols/2

    #first check if there is a region around the center
    #(when several boxes contain the center the last one is used)
    selected = None
    for region, (min_row, min_col, max_row, max_col) in enumerate(bbox, start=1):
        if min_row <= center_row <= max_row and min_col <= center_col <= max_col:
            selected = region

    if selected is None:
        #otherwise use the region whose box center is closest to the image center
        img_center = np.array((center_col, center_row))
        distance = {}
        for region, (min_row, min_col, max_row, max_col) in enumerate(bbox, start=1):
            box_center = np.array(((min_col+max_col)/2, (min_row+max_row)/2))
            distance[region] = spatial.distance.euclidean(img_center, box_center)
        selected = min(distance, key=distance.get)

    selected_region = (labeled == selected)
    y_min, x_min, y_max, x_max = bbox[selected-1]

    #keep only true pixels in bounding box
    true_region = np.where(selected_region, image, 0)
    new_region = true_region[y_min:y_max, x_min:x_max]

    projection2d[name] = new_region

### create 1d projection vectors

In [None]:
n_projections = len(projection2d)
#n*(n+1)/2 pairs, i.e. every unordered pair including each image with itself
print('\n'+'calculting line scores for {0} pairs of projections'\
      .format(n_projections*(n_projections + 1)//2))

#Keep projections as a list to slice later
#Each entry is [(projection name, rotation angle), 1D projection]
projection_1D = []

rot_angles = np.arange(0, angle+1, interval)

#These metrics compare z-scored 1D projections
#(spelling must match the metric names used in the settings and in 'slide')
normalize = metric in ('NCC', 'Norm-Euc')

#Rotate and project to 1D on x-axis, effectively creates a sinogram
for key, projection in projection2d.items():
    #the loop variable must not be called 'angle': that would overwrite the
    #maximum angle setting and change rot_angles when the cell is run again
    for rot_angle in rot_angles:
        proj_1D = transform.rotate(projection, rot_angle).sum(axis=0)
        if normalize:
            proj_1D = stats.zscore(proj_1D)
        projection_1D.append([(key, rot_angle), proj_1D])

### sliding vector comparisons

In [None]:
def _dissimilarity(vec_a, vec_b, metric):
    """Return the dissimilarity of two equal-length 1D vectors.

    metric is one of 'difference' (sum of absolute differences),
    'Euclidean' / 'Norm-Euc' (Euclidean distance) or 'cosine'
    (cosine distance). Raises ValueError for any other name.
    """
    if metric == 'difference':
        return np.abs(vec_a - vec_b).sum()
    if metric == 'Euclidean' or metric == 'Norm-Euc':
        return spatial.distance.euclidean(vec_a, vec_b)
    if metric == 'cosine':
        return spatial.distance.cosine(vec_a, vec_b)
    raise ValueError('unsupported sliding metric: {0}'.format(metric))


def _sliding_score(vec_a, vec_b, metric):
    """Return the smallest dissimilarity over all shifts of the shorter vector.

    The shorter vector is zero padded to the length of the longer one at
    every possible offset. Vectors of equal length are compared once,
    without any shift.
    """
    if len(vec_a) <= len(vec_b):
        moving, static = vec_a, vec_b
    else:
        moving, static = vec_b, vec_a
    gap = len(static) - len(moving)
    #metrics in 'slide' are dissimilarity, smaller values are better
    return min(_dissimilarity(static,
                              np.pad(moving, pad_width=(offset, gap-offset), mode='constant'),
                              metric)
               for offset in range(gap+1))


score_matrix = {}

#Calculate sliding vector scores to account for translations in 1D projections
#format for score_matrix [proj_1, angle_1, proj_2, angle_2] = score
if metric in slide:
    for start, (key_1, vec_1) in enumerate(projection_1D):
        #slice to calculate upper triangular matrix
        for key_2, vec_2 in projection_1D[start:]:
            score_matrix[key_1[0], key_1[1], key_2[0], key_2[1]] = \
                _sliding_score(vec_1, vec_2, metric)

elif metric == 'cross-correlation' or metric == 'NCC':
    for start, (key_1, vec_1) in enumerate(projection_1D):
        for key_2, vec_2 in projection_1D[start:]:
            #keep only the best correlation over all valid shifts
            correlation = signal.correlate(vec_1, vec_2, mode='valid')
            score_matrix[key_1[0], key_1[1], key_2[0], key_2[1]] = np.amax(correlation)

else:
    #an unknown metric would otherwise leave score_matrix empty
    raise ValueError('unknown metric: {0}'.format(metric))

In [None]:
#optional output of all scores (this is a large text file! 100 projections --> 2.6*10^7 lines)

#with open(outpath+'/{0}_raw_scores.txt'.format(description), 'w') as f:
#    f.write('projection_1' + '\t' + 'angle_1' + '\t' +'projection_2' + '\t' + 'angle_2' + '\t' + 'score' + '\n')
#    for key, value in score_matrix.items():
#        f.write(str(key[0]) +'\t'+ str(key[1]) +'\t'+ str(key[2]) +'\t'+ str(key[3]) +'\t'+ '%f'%(value) + '\n')

###  final common line scores

In [None]:
#Reduce the per-angle scores to one optimum score per pair of projection images

#Generate pairs for the number of 2D projections
pairs = list(itertools.combinations_with_replacement(file_names, 2))

#Initialize final score (proj1, proj2) = [score, angle_proj1, angle_proj2]
#with the score of the unrotated pair
final_scores = {}
for first, second in pairs:
    final_scores[(first, second)] = [score_matrix[first, 0, second, 0], 0, 0]

#Dissimilarity metrics look for the minimum, correlations for the maximum
if metric in slide:
    prefer_smaller = True
elif metric == 'cross-correlation' or metric == 'NCC':
    prefer_smaller = False
else:
    prefer_smaller = None

if prefer_smaller is not None:
    for (first, angle_first, second, angle_second), value in score_matrix.items():
        best = final_scores[(first, second)]
        if prefer_smaller:
            improved = value < best[0]
        else:
            improved = value > best[0]
        if improved:
            best[0] = value
            best[1] = angle_first
            best[2] = angle_second

#Mirror every pair so that the matrix can be read in both directions
complete_score_matrix = {}
for (first, second), best in final_scores.items():
    complete_score_matrix[(first, second)] = best
    score, angle_1, angle_2 = best
    complete_score_matrix[(second, first)] = [score, angle_2, angle_1]

### nearest neighbors

In [None]:
n_images = len(projection2d)

#z-score the scores of every projection against all projections
#TODO: apply different edgeweighting strategies
proj_scores = {}
for row in range(n_images):
    raw_scores = [complete_score_matrix[(row, col)][0] for col in range(n_images)]
    proj_scores[row] = stats.zscore(raw_scores)

proj_knn = {key: [] for key in range(n_images)}

#similarity metrics are ranked from the largest value down
descending = metric not in slide

for projection in proj_scores:
    ranked = sorted((score, i) for i, score in enumerate(proj_scores[projection]))
    if descending:
        ranked.sort(reverse=True)

    kept = 0
    for zscore, partner in ranked:
        #stop adding once enough neighbors are stored, never pair with itself
        if kept >= neighbors or partner == projection:
            continue
        best = complete_score_matrix[(projection, partner)]
        #proj_knn format is (proj_1) = [[score, (proj_1, angle_1, proj_2, angle_2)]]
        proj_knn[projection].append([abs(zscore),
                                     (projection, best[1], partner, best[2])])
        kept += 1

### cluster 2d projections

In [None]:
#Flatten the nearest neighbor lists into weighted edges (proj_1, proj_2, score)
#The loop variable must not be called 'neighbors': that would overwrite the
#number-of-neighbors setting with a list
flat = []
for proj, knn in proj_knn.items():
    for n in knn:
        flat.append((str(proj), str(n[1][2]), n[0]))

#Vertices are named by the projection number (as str), edges carry 'weight'
g = Graph.TupleList(flat, weights=True)

if community_detection == 'walktrap':
    wt = g.community_walktrap(weights='weight', steps=4)
    cluster_dendrogram = wt.as_clustering()
elif community_detection == 'betweenness':
    ebs = g.community_edge_betweenness(weights='weight', directed=True)
    cluster_dendrogram = ebs.as_clustering()
else:
    #without this the next statement fails with an unrelated NameError
    raise ValueError('unknown community detection algorithm: {0}'
                     .format(community_detection))

#clusters format is (community number) = [projection numbers]
#Vertex names are converted back to int
clusters = {}
for comm, proj in enumerate(cluster_dendrogram.subgraphs()):
    clusters[comm] = [int(name) for name in proj.vs['name']]

### remove outliers from clusters

In [None]:
#Use median absolute deviation of summed 2D projections to remove outliers
#Evaluate outliers for further processing

#Total pixel intensity of every projection, in the same order as clusters[cluster]
pixel_sums = {}
for cluster, nodes in clusters.items():
    pixel_sums[cluster] = [np.sum(projection2d[node]) for node in nodes]

for cluster, psums in pixel_sums.items():
    if not psums:
        continue

    med = np.median(psums)
    mad = np.median([abs(x - med) for x in psums])

    #the modified z-score is undefined without spread (e.g. a single node)
    if mad == 0:
        continue

    #Build the list of kept nodes instead of popping while iterating:
    #popping shifts the remaining indices and removes the wrong nodes
    kept = []
    for node, psum in zip(clusters[cluster], psums):
        #Boris Iglewicz and David Hoaglin (1993)
        z = 0.6745*(psum - med)/mad
        if abs(z) > 3.5:
            print('projection node {0} was removed from cluster {1} as an outlier'\
                  .format(node, cluster))
        else:
            kept.append(node)
    clusters[cluster][:] = kept

### score output files

In [None]:
#Both files are tab separated with one header line

#Complete score matrix output
#Columns: projection_1, angle_1, projection_2, angle_2, score
score_columns = ['projection_1', 'angle_1', 'projection_2', 'angle_2', 'score']
with open(outpath+'/complete_scores_{0}.txt'.format(description), 'w') as f:
    f.write('\t'.join(score_columns) + '\n')
    for (proj_1, proj_2), (score, angle_1, angle_2) in complete_score_matrix.items():
        fields = [str(proj_1), str(angle_1), str(proj_2), str(angle_2), '%f'%(score)]
        f.write('\t'.join(fields) + '\n')

#Nearest neighbors output
#Columns: projection_1, angle_1, projection_2, angle_2, edge_score
neighbor_columns = ['projection_1', 'angle_1', 'projection_2', 'angle_2', 'edge_score']
with open(outpath+'/neighbors_{0}.txt'.format(description), 'w') as f:
    f.write('\t'.join(neighbor_columns) + '\n')
    for proj_1, knn in proj_knn.items():
        for edge_score, link in knn:
            fields = [str(proj_1), str(link[1]), str(link[2]), str(link[3]), str(edge_score)]
            f.write('\t'.join(fields) + '\n')

### output cluster star files

In [None]:
#Split the particle star file into one star file per cluster

header = []
class_average = {}

#NOTE(review): the first 32 lines are taken as the header and field 24 as the
#class number - this depends on the layout of the star file, confirm for new inputs
with open(star_input) as f:
    for line_number, raw_line in enumerate(f, start=1):
        line = raw_line.rstrip('\n').split()
        if line_number < 33:
            header.append(line)
            continue
        #skip blank lines
        if len(line) <= 10:
            continue
        #-1 because relion index starts at 1, mrc and clusters start at 0
        class_index = int(line[23])-1
        class_average.setdefault(class_index, []).append(line)

#Format header for output
for position, tokens in enumerate(header):
    if position == 1:
        #second header line is written surrounded by newlines
        header[position] = ['\n'+str(tokens[0])+'\n']
    if len(tokens) > 1:
        header[position] = [' '.join(tokens)]

flat_header = list(itertools.chain.from_iterable(header))

#Particle lines of every class average, grouped by cluster
output = {cluster: [class_average[average] for average in averages]
          for cluster, averages in clusters.items()}

for cluster, data in output.items():
    with open(outpath+'/{0}_cluster_{1}.star'.format(description, cluster), 'w') as f:
        f.write('\n'.join(flat_header)+'\n')
        f.writelines('\t'.join(particles)+'\n'
                     for particle_list in data
                     for particles in particle_list)

#Summary of the class averages assigned to every cluster
with open(outpath+'/{0}_clusters.txt'.format(description), 'w') as f:
    f.writelines(str(cluster) + '\t' + str(averages) + '\n'
                 for cluster, averages in clusters.items())