### Subtomogram Alignment and Averaging

In [None]:
import numpy as np
import torch
import scipy
import mrcfile
import matplotlib.pyplot as plt

In [None]:


# Inputs: the ground-truth simulated volume and its particle list.
volumes = '/Users/HenryJones/Desktop/SULI/gt_volume.mrc'
coordinates = '/Users/HenryJones/Desktop/SULI/new_ribo_coordinates.txt'

# Later cells index axis 0 of this tensor as z.
volume = torch.tensor(mrcfile.read(volumes))

# Particle table; the first two lines of the file are headers. Per the
# original note, coordinates are centred and in nm, before magnification.
coords = torch.Tensor(np.genfromtxt(coordinates, skip_header=2))


In [None]:
# Quick look at one slice along axis 0 (treated as z by the extraction code).
plt.imshow(volume[25,:,:])

In [4]:
def angles_to_matrix(angles: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 rotation matrix for ZXZ Euler angles given in degrees.

    ``angles`` holds ``(phi, theta, psi)``. The matrix equals
    ``Rz(psi) @ Rx(theta) @ Rz(phi)`` and is returned as a new
    ``torch.Tensor`` (default float dtype).
    """
    assert angles.shape[0] == 3
    phi, theta, psi = (angles[k] * torch.pi / 180 for k in range(3))

    c_phi, s_phi = torch.cos(phi), torch.sin(phi)
    c_theta, s_theta = torch.cos(theta), torch.sin(theta)
    c_psi, s_psi = torch.cos(psi), torch.sin(psi)

    row_x = [c_psi * c_phi - (s_psi * c_theta * s_phi),
             -(c_psi * s_phi) - (s_psi * c_theta * c_phi),
             s_psi * s_theta]
    row_y = [(s_psi * c_phi) + (c_psi * c_theta * s_phi),
             -(s_psi * s_phi) + (c_psi * c_theta * c_phi),
             -(c_psi * s_theta)]
    row_z = [s_theta * s_phi, s_theta * c_phi, c_theta]
    return torch.Tensor([row_x, row_y, row_z])

In [None]:
def check_transpose(in_tensor: torch.Tensor, correct_shape: torch.Size):
    """Return ``in_tensor`` itself if it already has ``correct_shape``.

    Otherwise return it with its first two dimensions swapped.
    """
    already_ok = in_tensor.shape == correct_shape
    return in_tensor if already_ok else in_tensor.transpose(0, 1)

def pick_particles(coords: torch.Tensor, correct_size, *, max_radius=600,
                   diameter=None, xy_factor=None, z_scale=None, pixel_radius=None):
    """Select particles whose window fits in the volume and that have no close neighbour.

    Three filters are applied in order:
      1. the particle plus a margin lies inside the volume (x/y margin
         ``diameter`` in coordinate units scaled by ``xy_factor``; z margin
         ``pixel_radius`` in pixels after scaling z by ``z_scale``);
      2. its xy distance from the origin is below ``max_radius - diameter``;
      3. every other remaining particle is farther than ``diameter`` away
         (distance over columns 0-2).

    Args:
        coords: one row per particle. Columns 0-2 are the centred x, y, z
            positions; any further columns (e.g. Euler angles) are carried
            through unchanged.
        correct_size: volume shape ordered (z, y, x).
        max_radius: radial limit in the xy plane (previously hard-coded 600).
        diameter: exclusion distance / x-y margin; defaults to the module-level
            ``ribosome_diameter``.
        xy_factor: x/y coordinate-to-pixel scale; defaults to ``factor``.
        z_scale: z coordinate-to-pixel scale; defaults to ``z_factor``.
        pixel_radius: z margin in pixels; defaults to ``ribo_pixel_radius``.

    Returns:
        The rows of ``coords`` passing all filters, in their original order.
    """
    # Fall back to the notebook-level settings so existing callers are unchanged.
    if diameter is None:
        diameter = ribosome_diameter
    if xy_factor is None:
        xy_factor = factor
    if z_scale is None:
        z_scale = z_factor
    if pixel_radius is None:
        pixel_radius = ribo_pixel_radius

    # 1) Window inside the volume. z is checked in pixel units.
    # NOTE(review): extraction uses a +/- 2*pixel_radius window in x/y; a
    # particle whose rounded pixel centre lands exactly on this limit can still
    # produce a window one voxel short at the high edge — confirm the margin.
    inside = (((torch.abs(coords[:, 0]) + diameter) * xy_factor <= correct_size[2] / 2)
              & ((torch.abs(coords[:, 1]) + diameter) * xy_factor <= correct_size[1] / 2)
              & (torch.abs(coords[:, 2]) * z_scale + pixel_radius <= correct_size[0] / 2))
    candidates = coords[inside]

    # 2) Radial cut in the xy plane.
    radial = torch.sqrt(candidates[:, 0] ** 2 + candidates[:, 1] ** 2)
    candidates = candidates[radial < max_radius - diameter]

    n_candidates = candidates.shape[0]
    if n_candidates == 0:
        return candidates

    # 3) Keep only isolated particles. Only the strict upper triangle of the
    # distance matrix is used and then mirrored, so the relation is symmetric.
    distances = torch.cdist(candidates[:, :3], candidates[:, :3], p=2)
    far_upper = torch.triu(distances > diameter, diagonal=1)
    far = far_upper | far_upper.T
    isolated = far.sum(dim=0) == n_candidates - 1
    return candidates[isolated]

def _dump_masked_subtomo(subtomo: torch.Tensor, particle: int, pixel_radius: int, pixel_diameter: int) -> None:
    """Write a z-padded, sphere-masked copy of one subtomogram to ``prerotate{particle}.mrc`` (debug aid)."""
    padded = torch.nn.functional.pad(subtomo, (0, 0, 0, 0, pixel_radius, pixel_radius),
                                     mode='constant', value=0)
    axis = torch.linspace(start=-pixel_diameter, end=pixel_diameter, steps=padded.shape[0])
    grid = torch.meshgrid(axis, axis, axis, indexing="ij")
    mask = torch.where(torch.sqrt(grid[0] ** 2 + grid[1] ** 2 + grid[2] ** 2) < pixel_radius, 1, 0)
    # overwrite=True: mrcfile.new raises if the file already exists, which made re-runs fail.
    with mrcfile.new(f"prerotate{particle}.mrc", overwrite=True) as mrc:
        mrc.set_data(padded.numpy().astype(np.float32) * mask.numpy().astype(np.float32))


def new_extract_subtomos(volume: torch.Tensor, picked_coordinates: torch.Tensor, factor, z_factor,
                         pixel_radius=None, pixel_diameter=None, debug_dump=True):
    """Cut one subtomogram per picked particle out of ``volume``.

    Particle coordinates (columns 0-2 = x, y, z, centred) are mapped to pixels
    as ``x * factor + X/2``, ``y * factor + Y/2`` and ``z * z_factor + Z/2``
    for a volume of shape (Z, Y, X). The z axis is not flipped here.
    Around the rounded centre the window spans:

    * ``+/- pixel_radius`` in z. It is zero-padded when it crosses either z
      face, so the z extent is always ``2 * pixel_radius + 1``.
    * ``+/- pixel_diameter`` in y and x. It is clamped at index 0 and
      truncated by slicing at the high edge, not padded (particle picking is
      expected to keep windows inside x/y).

    Args:
        volume: 3-D tensor ordered (z, y, x).
        picked_coordinates: one row per particle, columns 0-2 = x, y, z.
        factor: x/y coordinate-to-pixel scale.
        z_factor: z coordinate-to-pixel scale.
        pixel_radius: z half-width; defaults to the module-level ``ribo_pixel_radius``.
        pixel_diameter: x/y half-width; defaults to ``ribo_pixel_diameter``.
        debug_dump: when True (the previous behaviour), write sphere-masked
            copies of the first three subtomograms to ``prerotate{i}.mrc``.

    Returns:
        ``(subtomos, pixel_coords)``: ``subtomos`` is an object ndarray of
        tensors; ``pixel_coords`` is an (N, 3) tensor of unrounded (x, y, z)
        pixel centres.
    """
    if pixel_radius is None:
        pixel_radius = ribo_pixel_radius
    if pixel_diameter is None:
        pixel_diameter = ribo_pixel_diameter

    n_particles = int(picked_coordinates.shape[0])
    depth, height, width = (int(s) for s in volume.shape[-3:])
    tensor_of_subtomos = np.empty(shape=n_particles, dtype=object)
    tensor_of_coords = torch.empty(size=(n_particles, 3))

    for particle in range(n_particles):
        pixel_coord_float = torch.Tensor([picked_coordinates[particle, 0] * factor + width / 2,
                                          picked_coordinates[particle, 1] * factor + height / 2,
                                          picked_coordinates[particle, 2] * z_factor + depth / 2])
        tensor_of_coords[particle] = pixel_coord_float
        cx, cy, cz = (int(v) for v in torch.round(pixel_coord_float))

        z_lo, z_hi = cz - pixel_radius, cz + pixel_radius
        y_lo, y_hi = cy - pixel_diameter, cy + pixel_diameter
        x_lo, x_hi = cx - pixel_diameter, cx + pixel_diameter

        # The last valid z index is depth - 1. The old check compared against
        # depth, so a window ending exactly at index `depth` was one slice short.
        pad_low = max(0, -z_lo)
        pad_high = max(0, z_hi - (depth - 1))

        subtomo = volume[max(z_lo, 0): min(max(z_hi, 0), depth - 1) + 1,
                         max(y_lo, 0): max(y_hi, 0) + 1,
                         max(x_lo, 0): max(x_hi, 0) + 1]
        # Pad only z; the first four entries (x, then y) stay zero.
        subtomo = torch.nn.functional.pad(subtomo, (0, 0, 0, 0, pad_low, pad_high), mode='constant')
        tensor_of_subtomos[particle] = subtomo

        if debug_dump and particle <= 2:
            _dump_masked_subtomo(subtomo, particle, pixel_radius, pixel_diameter)

    return tensor_of_subtomos, tensor_of_coords

def new_meshgrid(subtomo_shape: torch.Size, particle_center: torch.Tensor, half_width=None) -> tuple:
    """Build the sampling grid for one subtomogram, shifted by the particle's sub-pixel offset.

    Each axis spans ``[-half_width, half_width]`` with as many samples as the
    matching entry of ``subtomo_shape`` (ordered z, y, x). It is then shifted
    by ``-(particle_center - round(particle_center))``. ``particle_center`` is
    ordered (x, y, z), e.g. ``torch.Tensor([1809.6265, 1761.0503, 54.2631])``.
    The original author's note says only the z centring should matter.

    Args:
        subtomo_shape: shape of the (padded) subtomogram, (z, y, x).
        particle_center: float pixel centre (x, y, z).
        half_width: half extent of every axis; defaults to the module-level
            ``ribo_pixel_diameter``.

    Returns:
        ``(grids, x, y, z)``, where ``grids = torch.meshgrid(x, y, z, indexing='xy')``
        and ``x``, ``y``, ``z`` are the shifted 1-D axes. (The previous
        annotation ``-> torch.meshgrid`` named a function, not the return type.)
    """
    if half_width is None:
        half_width = ribo_pixel_diameter

    center_diff = particle_center - torch.round(particle_center)
    # Original note: "best results with center dif 0, 1, 2" (x, y, z order below).
    x = torch.linspace(start=-half_width, end=half_width, steps=subtomo_shape[2]) - center_diff[0]
    y = torch.linspace(start=-half_width, end=half_width, steps=subtomo_shape[1]) - center_diff[1]
    z = torch.linspace(start=-half_width, end=half_width, steps=subtomo_shape[0]) - center_diff[2]

    return torch.meshgrid(x, y, z, indexing='xy'), x, y, z


def _rotate_subtomo(subtomo, angles, pixel_coord, pixel_radius):
    """Rotate the sampling grid of one subtomogram and resample its masked central cube.

    The subtomogram is zero-padded in z by ``pixel_radius`` on each side, which
    must make it a cube of side ``2 * ribo_pixel_diameter + 1`` (module-level
    value; ``new_meshgrid`` uses the same one). The meshgrid is rotated by
    ``angles_to_matrix(angles)`` and cropped to the central
    ``2 * pixel_radius + 1`` voxels per axis. The cropped grid is then
    interpolated and multiplied by a spherical mask of radius ``pixel_radius``.

    Returns:
        numpy array of shape ``(2 * pixel_radius + 1,) * 3``.

    Raises:
        ValueError: if the padded subtomogram is not the expected cube.
    """
    rmat = angles_to_matrix(angles)
    padded = torch.nn.functional.pad(subtomo, (0, 0, 0, 0, pixel_radius, pixel_radius),
                                     mode='constant', value=0)
    expected = 2 * ribo_pixel_diameter + 1
    if tuple(padded.shape) != (expected, expected, expected):
        raise ValueError(f"padded subtomogram has shape {tuple(padded.shape)}, expected a cube of "
                         f"side {expected}; its window was probably truncated at the volume edge")

    meshgrid, x_axis, y_axis, z_axis = new_meshgrid(padded.shape, pixel_coord)
    x = torch.flatten(meshgrid[0])
    y = torch.flatten(meshgrid[1])
    z = torch.flatten(meshgrid[2])
    # NOTE(review): the grid is stacked (z, y, x) for the rotation but the
    # interpolator is queried with rows (2, 1, 0); the original comment says
    # the order "doesn't matter" — verify against a known structure.
    rotation = torch.matmul(rmat, torch.stack([z, y, x]))

    # Keep only the central (2r+1)^3 region of the rotated grid.
    lo = (expected - 1) // 2 - pixel_radius
    hi = (expected - 1) // 2 + pixel_radius + 1
    grid = torch.unflatten(rotation, dim=1, sizes=tuple(padded.shape))[:, lo:hi, lo:hi, lo:hi]

    interpolator = scipy.interpolate.RegularGridInterpolator(
        (z_axis.numpy(), y_axis.numpy(), x_axis.numpy()), padded.numpy())
    mask = torch.where(torch.sqrt(grid[0] ** 2 + grid[1] ** 2 + grid[2] ** 2) < pixel_radius, 1, 0)
    return interpolator((grid[2].numpy(), grid[1].numpy(), grid[0].numpy())) * mask.numpy()


def align_and_average(new_out, picked_coords, new_pixel_coords, ribo_pixel_radius,
                      max_particles=3, debug_output=True):
    """Rotate each extracted subtomogram by its picked Euler angles and average the results.

    Args:
        new_out: subtomograms as returned by ``new_extract_subtomos``.
        picked_coords: per-particle rows; the last three columns are the Euler
            angles (degrees) passed to ``angles_to_matrix``.
        new_pixel_coords: per-particle pixel centres; columns 0-2 are used.
        ribo_pixel_radius: half-width r of the averaged cube (side 2r+1).
        max_particles: process at most this many particles. The default of 3
            keeps the debugging cap of the original loop; pass None to average
            every particle.
        debug_output: when True, write ``post_rotate{i}.mrc`` for each particle
            and show its central y-slice.

    Returns:
        The mean of the processed particles. The old code divided the sum by
        ``len(new_out)`` even though it stopped after three particles.
        Returns the zero cube if nothing was processed.
    """
    average_size = ribo_pixel_radius * 2 + 1
    average = torch.zeros((average_size, average_size, average_size))
    n_averaged = 0

    particles = zip(new_out, picked_coords[:, -3:], new_pixel_coords[:, :3])
    for index, (subtomo, angles, pixel_coord) in enumerate(particles):
        if max_particles is not None and index >= max_particles:
            break
        interpolation = _rotate_subtomo(subtomo, angles, pixel_coord, ribo_pixel_radius)

        if debug_output:
            # overwrite=True so re-running the notebook does not fail on existing files.
            with mrcfile.new(f"post_rotate{index}.mrc", overwrite=True) as mrc:
                mrc.set_data(interpolation.astype(np.float32))
            plt.imshow(interpolation[:, interpolation.shape[1] // 2, :])
            plt.show()

        average = average + interpolation
        n_averaged += 1

    if n_averaged:
        average /= n_averaged
    return average


"""
ARGS and inputs

"""
#80S ribosomes have diameters up to 300-320 A which is 30-32 nm
#We know our tem simulator uses magnification 
#750000 with detector size 16000 nm
magnification = 75000
detector_size_nm = 16000
#can get out of bounds with 8/1000
ribosome_radius = int(np.around(volume.shape[1] * 15/1000)) * 2 #old dataset was for   # nm #15 worked for 1000 pixel width volume, so 9 for 600
#factor = magnification / detector_size_nm /2
#for now, with volume
factor = 1
ribosome_diameter = 2 * ribosome_radius
ribo_pixel_radius = int(torch.round(torch.Tensor([factor * ribosome_radius])))
ribo_pixel_diameter = 2 * ribo_pixel_radius
print(ribo_pixel_radius, 'pixel_radius')
print(ribosome_diameter + ribosome_radius, 'neighbor distance')
#correct_size = torch.Size([200, 1000, 1000]) #z, y, x
correct_size = volume.shape
#z_factor = volume.shape[0] /(torch.abs(torch.min(coords[:,2])) + torch.max(coords[:,2]))
z_factor = volume.shape[0] /72
print(z_factor)
# We only need to crop in the x and y planes.
def main():
    """Pick particles, overlay them on an x-z section, extract subtomograms and average them.

    Reads the notebook-level ``coords``, ``correct_size``, ``volume``,
    ``factor``, ``z_factor`` and ``ribo_pixel_radius``.

    Returns:
        The result of ``align_and_average`` for the extracted subtomograms.
    """
    picked_coords = pick_particles(coords, correct_size)
    print(picked_coords.shape, "picked coords")

    # Sanity plot at y index 300. volume[:, 300, :] is indexed (z, x), so the
    # horizontal axis is x: use column 0 scaled into volume.shape[2], matching
    # the pixel mapping in new_extract_subtomos. The previous code plotted
    # column 1 (y) here.
    plt.imshow(volume[:, 300, :])
    plt.scatter(picked_coords[:, 0] * factor + volume.shape[2] // 2,
                picked_coords[:, 2] * z_factor + volume.shape[0] // 2, s=1, c="r")
    plt.show()

    subtomos, new_pixel_coords = new_extract_subtomos(volume, picked_coords, factor, z_factor)
    return align_and_average(subtomos, picked_coords=picked_coords,
                             new_pixel_coords=new_pixel_coords,
                             ribo_pixel_radius=ribo_pixel_radius)


average = main()

In [None]:
# Scratch check of how torch.stack lays out three rows.
quick, quick1, quick2 = (torch.arange(3) + offset for offset in (0, 3, 6))
print(torch.stack([quick, quick1, quick2]))

## Image space alignment and averaging

In [None]:
# Extent of the particle coordinates: max of columns 0-2 (x, y, z), then min.
for reduce_fn in (torch.max, torch.min):
    for column in range(3):
        print(reduce_fn(coords[:, column]))
# Presumably |min z| + max z from an earlier run; compare with the 72 used for z_factor.
print(35.9860 + 36.1435)
print(average.shape)

In [None]:
# Slice 25 along axis 0 (z).
plt.imshow(volume[25,:,:])

In [None]:
# Slice 500 along the last axis (x), i.e. a z-y section; assumes the volume is wider than 500 px.
plt.imshow(volume[:,:,500])

In [None]:
mrc = mrcfile.new("dotted_volume.mrc")
mrc.set_data(volume.numpy().astype(np.float32))
mrc.close()

In [None]:
# Expected to be (2r+1, 2r+1, 2r+1) with r = ribo_pixel_radius.
print(average.shape)

In [None]:
# NOTE(review): 17 is hard-coded; the central slice would be index ribo_pixel_radius — confirm.
plt.imshow(average[17,:,:])

In [15]:
mrc = mrcfile.new("now_avg.mrc")
mrc.set_data(average.numpy().astype(np.float32))
mrc.close()

## Rotation check

In [17]:
sample0 = np.array(mrcfile.read("rotate_sample0.mrc"))
sample1 = np.array(mrcfile.read("rotate_sample1.mrc"))
sample2 = np.array(mrcfile.read("rotate_sample2.mrc"))

In [None]:
# Slice 18 along axis 1 of each reloaded sample.
for rotated_sample in (sample0, sample1, sample2):
    plt.imshow(rotated_sample[:, 18, :])
    plt.show()

In [None]:
print(sample0.shape)

In [20]:
# Re-run particle selection at notebook level; the cells below read each
# particle's Euler angles from the last three columns.
picked_coords = pick_particles(coords, correct_size)

In [None]:
print(picked_coords.shape)

In [22]:
# Zero-pad every side of each sample by 18 voxels. The tuple is ordered
# (last-axis lo/hi, middle-axis lo/hi, first-axis lo/hi); unlike the old
# comments claimed, all three axes are padded.
# NOTE(review): 18 is hard-coded and presumably equals the sample half-width — confirm.
padded0, padded1, padded2 = (
    torch.nn.functional.pad(torch.tensor(sample), (18,) * 6, mode='constant')
    for sample in (sample0, sample1, sample2)
)

In [None]:
# Slice 36 along axis 1 of each padded sample.
for padded_sample in (padded0, padded1, padded2):
    plt.imshow(padded_sample[:, 36, :])
    plt.show()

In [None]:
# Coordinate axis for the padded samples: one point per voxel, spanning -36..36.
# NOTE(review): 36 presumably equals 2 x the 18-voxel padding above — confirm.
axes = np.linspace(-36, 36, padded0.shape[0])
print(axes)

In [None]:
# Preview the central half of the axis: centre index +/- (len - 1) // 4.
print(axes[(axes.shape[0]-1)//2 - (axes.shape[0]-1)//4 : (axes.shape[0]-1)//2 + (axes.shape[0]-1)//4 + 1])

In [None]:
# Same central slice, kept as the output sampling axis.
cropped_axes = axes[(axes.shape[0]-1)//2 - (axes.shape[0]-1)//4 : (axes.shape[0]-1)//2 + (axes.shape[0]-1)//4 + 1]
print(cropped_axes)

In [28]:
# 3-D grid over the cropped axis; "ij" indexing keeps the argument order for the output axes.
meshgrid = np.meshgrid(cropped_axes, cropped_axes, cropped_axes, indexing = "ij")

In [None]:
print(meshgrid[0].shape)

In [30]:
# Flattened grid coordinates, one entry per grid point; these are the points
# rotated in the cells below.
x = np.ravel(meshgrid[0])
y = np.ravel(meshgrid[1])
z = np.ravel(meshgrid[2])


In [38]:
def _inverse_rotate_sample(padded_sample, particle_index):
    """Resample one padded sample on the cropped grid rotated by its particle's inverse matrix.

    Reads the notebook-level ``axes``, ``x``, ``y``, ``z``, ``picked_coords``
    and ``sample0`` (whose shape sets the output grid). Returns every
    intermediate so this cell keeps exposing the same globals as before.
    """
    interp = scipy.interpolate.RegularGridInterpolator((axes, axes, axes), padded_sample.numpy())
    angles = picked_coords[particle_index, -3:]
    rmat = angles_to_matrix(angles)
    inverse_rmat = torch.linalg.inv(rmat)
    # Points are stacked as (z, x, y); before Nov 8 they were stacked (z, y, x).
    # NOTE(review): confirm which order matches angles_to_matrix's convention.
    points = torch.stack([torch.tensor(z, dtype=torch.float32),
                          torch.tensor(x, dtype=torch.float32),
                          torch.tensor(y, dtype=torch.float32)])
    rotation = torch.matmul(inverse_rmat, points)
    grid_shape = torch.unflatten(rotation, dim=1,
                                 sizes=(sample0.shape[0], sample0.shape[1], sample0.shape[2]))
    interpolation = interp((grid_shape[0].numpy(),
                            grid_shape[1].numpy(),
                            grid_shape[2].numpy()))
    return interp, angles, rmat, inverse_rmat, rotation, grid_shape, interpolation


# Each sample is resampled through its own interpolator. Sample 2 was previously
# resampled with interp0, i.e. from sample 0's data.
interp0, angles0, rmat0, inverse_rmat0, rotation0, grid_shape0, interpolation0 = _inverse_rotate_sample(padded0, 0)
interp1, angles1, rmat1, inverse_rmat1, rotation1, grid_shape1, interpolation1 = _inverse_rotate_sample(padded1, 1)
interp2, angles2, rmat2, inverse_rmat2, rotation2, grid_shape2, interpolation2 = _inverse_rotate_sample(padded2, 2)



In [None]:
# Toy check of the rotate/interpolate pipeline: a 5x5x5 ramp placed off-centre
# in an otherwise empty box the size of sample0, padded by 18 on every side.
toybox = np.zeros_like(sample0)
toybox[5:10, 5:10, 5:10] = np.arange(125).reshape(5, 5, 5)
paddedtoy = torch.nn.functional.pad(torch.tensor(toybox), (18,) * 6, mode='constant')
interptoy = scipy.interpolate.RegularGridInterpolator((axes, axes, axes), paddedtoy.numpy())

# Forward rotation with rmat0; grid points stacked as (x, y, z).
toy_points = torch.stack([torch.tensor(axis_values, dtype=torch.float32) for axis_values in (x, y, z)])
rotationtoy = torch.matmul(rmat0, toy_points)
gridtoy = torch.unflatten(rotationtoy, dim=1, sizes=(sample0.shape[0], sample0.shape[1], sample0.shape[2]))
print(gridtoy.shape)
interpolationtoy = interptoy(tuple(gridtoy[k].numpy() for k in range(3)))


In [None]:
# Undo the toy rotation: pad the rotated toy back out by 18 on every side, then
# resample it with the inverse of rmat0 on the same (x, y, z)-stacked grid.
paddedtoyi = torch.nn.functional.pad(torch.tensor(interpolationtoy),
                                     (18, 18, 18, 18, 18, 18),  # all three axes, like paddedtoy
                                     mode='constant')
interptoyi = scipy.interpolate.RegularGridInterpolator((axes, axes, axes), paddedtoyi.numpy())
rotationtoyi = torch.matmul(inverse_rmat0, torch.stack([torch.tensor(x, dtype=torch.float32),
                                                        torch.tensor(y, dtype=torch.float32),
                                                        torch.tensor(z, dtype=torch.float32)]))
gridtoyi = torch.unflatten(rotationtoyi, dim=1, sizes=(sample0.shape[0], sample0.shape[1], sample0.shape[2]))
print(gridtoyi.shape)  # previously printed gridtoy.shape (copy-paste slip)
# A stray "0" after this call used to make the cell a SyntaxError.
interpolationtoyi = interptoyi((gridtoyi[0].numpy(),
                                gridtoyi[1].numpy(),
                                gridtoyi[2].numpy()))

In [None]:
# Euler angles (degrees) of particle 0, the ones rmat0 was built from.
print(picked_coords[0,-3:])

In [None]:
# Toy box before rotation, after the forward rotation, and after the inverse rotation.
# NOTE(review): the middle panel uses slice 10 while the others use 7 — confirm this is intended.
plt.imshow(toybox[:,7,:])
plt.show()
plt.imshow(interpolationtoy[:,10,:])
plt.show()
plt.imshow(interpolationtoyi[:,7,:])
plt.show()

In [None]:
# Central slice along axis 1 of each inverse-rotated sample.
for resampled in (interpolation0, interpolation1, interpolation2):
    middle = resampled.shape[1] // 2
    plt.imshow(resampled[:, middle, :])
    plt.show()