In [None]:
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as R
from sklearn.mixture import GaussianMixture
import sys
sys.path.append('../src/')
from forward_modeling import project_volume, slice_volume
from molecular_handling import particle
from scipy.interpolate import RegularGridInterpolator
from scipy.stats import norm

In [None]:
def gmm_image_dataset(particles, size=100, noise=0.1, size_grid=32):
    """
        Generate a synthetic dataset of noisy, randomly rotated particle images.

        Arguments:
        ---------
        - particles (list):
            list of the different particles in the dataset; each element must
            provide a ``rotate(quat=...)`` method and an ``image`` attribute of
            shape (size_grid, size_grid)
        - size (int {default:100}):
            length of the dataset (number of images)
        - noise (float {default: 0.1}):
            standard deviation of the Gaussian noise added to every pixel
        - size_grid (integer {default:32}):
            number of grid points along each image direction

        Returns:
        -------
        - np.ndarray of shape (size, size_grid, size_grid): for each image, a
          particle drawn uniformly at random from ``particles``, rotated about
          z by an angle drawn uniformly in [0, 360) degrees, plus i.i.d.
          Gaussian pixel noise scaled by ``noise``.

        Raises:
        ------
        - ValueError: if ``particles`` is empty or a particle image does not
          have shape (size_grid, size_grid).

        Notes:
        -----
        - Uses the global NumPy RNG (np.random), so seed it for reproducibility.
        - ``rotate`` is called on the particle objects themselves, so they are
          presumably left in the last sampled orientation afterwards
          (TODO confirm against molecular_handling.particle).
    """
    n_particles = len(particles)
    if n_particles == 0:
        raise ValueError("particles must contain at least one particle")

    dataset = np.zeros((size, size_grid, size_grid))
    for t in range(size):
        iat = np.random.randint(n_particles)
        angle = np.random.rand() * 360
        rot = R.from_euler('z', angle, degrees=True)

        particles[iat].rotate(quat=rot.as_quat())
        image = particles[iat].image
        # Fail loudly instead of relying on numpy broadcasting of a wrong-sized image.
        if np.shape(image) != dataset.shape[1:]:
            raise ValueError(
                f"particle {iat} image has shape {np.shape(image)}, "
                f"expected {dataset.shape[1:]}"
            )
        dataset[t] = image

    return dataset + np.random.randn(*dataset.shape) * noise

In [None]:
# Three toy particles on a 32-point grid, differing in the number and radii of their atoms.
particle1, particle2, particle3 = (
    particle(n_atom=len(radii), rads_atom=radii, size_grid=32)
    for radii in ([1., 2., 1.], [2., 2.], [3.])
)
particles = [particle1, particle2, particle3]

In [None]:
# Seed the global NumPy RNG used by gmm_image_dataset (particle choice, angle, noise)
# so that the dataset, and everything downstream, is reproducible on Restart & Run All.
np.random.seed(0)
dataset = gmm_image_dataset(particles, noise=0.1)

# Simple Gaussian mixture model applied directly to the images

In [None]:
# Flatten each image into one feature vector and fit one Gaussian component per particle type.
# Shapes are derived from the data so the cell keeps working if size/size_grid change.
reshaped_dataset = dataset.reshape(len(dataset), -1)
gm = GaussianMixture(n_components=len(particles), random_state=0)
gm.fit(reshaped_dataset)

# Component means reshaped back to images; component order is arbitrary.
images = gm.means_.reshape(-1, *dataset.shape[1:])

In [None]:
# NOTE(review): gmm_image_dataset called rotate() on these particles, so this presumably
# shows each particle in its last sampled orientation -- confirm.
fig, axes = plt.subplots(1, len(particles), figsize=(3 * len(particles), 3), squeeze=False)
for k, (ax, part) in enumerate(zip(axes[0], particles)):
    ax.imshow(part.image)
    ax.set_title(f"particle {k + 1}")
plt.show()

In [None]:
# GMM component order is arbitrary: component k need not correspond to particle k + 1.
fig, axes = plt.subplots(1, len(images), figsize=(3 * len(images), 3), squeeze=False)
for k, (ax, mean_image) in enumerate(zip(axes[0], images)):
    ax.imshow(mean_image)
    ax.set_title(f"GMM component {k} mean")
plt.show()

This worked quite well, even without estimating the orientations.

# Marginalizing over the orientations

In [None]:
# 90-degree rotation about the z axis.
# NOTE(review): `Rot` is not referenced anywhere else in this notebook as shown.
Rot = R.from_euler(seq='z', angles=90, degrees=True)

In [None]:
def rot_mat(theta):
    """Return the 2x2 matrix [[cos, -sin], [sin, cos]] for an angle ``theta`` given in radians."""
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    return np.array([[cos_t, -sin_t],
                     [sin_t, cos_t]])
    

def rotate_image(image, theta):
    """
    Resample a square image on a rotated grid.

    The image is treated as samples on the regular grid [-1, 1] x [-1, 1]
    (first axis -> x, second axis -> y). Output pixel (i, j), located at grid
    point p, takes the (bilinearly) interpolated value of ``image`` at R p with
    R = [[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]. Points whose
    rotated position falls outside [-1, 1]^2 are set to 0.
    For example, theta = pi/2 gives np.rot90(image, -1).

    Arguments:
    - image (array of shape (N, N), N >= 2): image to rotate
    - theta (float): rotation angle in radians

    Returns:
    - np.ndarray of shape (N, N): the resampled image

    Raises:
    - ValueError: if ``image`` is not a square 2-D array of at least 2x2 pixels.
    """
    image = np.asarray(image)
    if image.ndim != 2 or image.shape[0] != image.shape[1] or image.shape[0] < 2:
        raise ValueError(f"image must be a square 2-D array of at least 2x2, got shape {image.shape}")

    n = image.shape[0]
    grid = np.linspace(-1, 1, n)
    interpolating_function = RegularGridInterpolator((grid, grid), image)

    # Grid coordinates of every output pixel, evaluated all at once instead of per pixel.
    coords = -1 + np.arange(n) * 2 / (n - 1)
    xx, yy = np.meshgrid(coords, coords, indexing='ij')

    # Apply the rotation matrix to every grid point.
    c, s = np.cos(theta), np.sin(theta)
    x_rot = c * xx - s * yy
    y_rot = s * xx + c * yy

    # Only points still inside the interpolation domain can be sampled
    # (RegularGridInterpolator raises on out-of-bounds points by default).
    inside = np.maximum(np.abs(x_rot), np.abs(y_rot)) <= 1
    rot_image = np.zeros((n, n))
    if inside.any():
        rot_image[inside] = interpolating_function(
            np.column_stack((x_rot[inside], y_rot[inside]))
        )
    return rot_image

def proba(model, rotated_image, scale=1.0, log=False):
    """
    Score how well ``rotated_image`` matches ``model`` pixel by pixel.

    For every pixel the absolute residual d = |model - rotated_image| is mapped
    to the two-sided tail probability 2 * P(X > d) with X ~ N(0, scale**2).
    The score is the product of these factors over all pixels.

    Arguments:
    - model (array): model image
    - rotated_image (array, same shape as ``model``): image to compare
    - scale (float {default: 1.0}): noise standard deviation used for the residuals
    - log (bool {default: False}): return the sum of log-factors instead of the product

    Returns:
    - float: the product of per-pixel factors (each in [0, 1]), or its logarithm if ``log``.

    Notes:
    - The product of many factors <= 1 can underflow to 0.0 for large images or
      large residuals. ``log=True`` avoids this.
    - NOTE(review): this is a product of tail probabilities, not a Gaussian
      likelihood density. Confirm this is the intended score.
    """
    residual = np.abs(np.asarray(model, dtype=float) - np.asarray(rotated_image, dtype=float))
    if log:
        return np.sum(np.log(2.0) + norm.logsf(residual, loc=0, scale=scale))
    # norm.sf is the numerically accurate form of 1 - norm.cdf in the tails.
    return np.prod(2.0 * norm.sf(residual, loc=0, scale=scale))

def reconstruct(dataset, models, size, verbose=False):
    """
    One orientation-marginalised update of the class models.

    Every image is rotated to ``size`` equally spaced angles in [0, 2*pi)
    (radians, as expected by ``rotate_image``). Each rotated copy is scored
    against every model with ``proba``. Model k is then replaced by the
    weighted average of all rotated copies, where the weight is the score
    times the angular step ``dtheta``.

    Arguments:
    - dataset (array of shape (n_images, N, N)): observed images
    - models (array of shape (n_models, N, N)): current class models
    - size (int >= 1): number of sampled rotation angles
    - verbose (bool {default: False}): print the index of each processed image

    Returns:
    - np.ndarray with the same shape as ``models``: the updated models. If every
      weight of a model is zero (e.g. the scores underflowed), that model is
      returned unchanged and a warning is issued.

    Raises:
    - ValueError: if ``size`` < 1.

    NOTE(review): weights are not normalised per image (no posterior over
    (angle, class) per image). This is a weighted average, not a textbook EM
    M-step. Confirm intent.
    """
    import warnings

    if size < 1:
        raise ValueError(f"size must be >= 1, got {size}")

    models = np.asarray(models)
    # endpoint=False: 0 and 2*pi are the same orientation and must not be counted twice.
    thetas = np.linspace(0, 2 * np.pi, size, endpoint=False)
    dtheta = 2 * np.pi / size
    next_models = np.zeros(models.shape)
    probs = np.zeros((len(dataset), len(thetas), models.shape[0]))

    for i, image in enumerate(dataset):
        if verbose:
            print(i)
        for j, theta in enumerate(thetas):
            # The rotated image does not depend on the model, so compute it once per angle.
            rotated_image = rotate_image(image, theta)
            for k, model in enumerate(models):
                weight = proba(model, rotated_image) * dtheta
                probs[i, j, k] = weight
                next_models[k] += weight * rotated_image

    for k in range(models.shape[0]):
        total = np.sum(probs[:, :, k])
        if total > 0:
            # Normalise the weighted sum into a weighted average.
            next_models[k] /= total
        else:
            warnings.warn(
                f"reconstruct: all weights for model {k} are zero "
                "(scores may have underflowed); keeping the previous model"
            )
            next_models[k] = models[k]

    return next_models

def EM(dataset, images_init, n_iter=10, size=10, tol=None):
    """
    Refine class models by repeatedly applying the update step ``reconstruct``.

    Arguments:
    - dataset (array of shape (n_images, N, N)): observed images
    - images_init (array of shape (n_models, N, N)): initial models, e.g. GMM mean images
    - n_iter (int {default: 10}): maximum number of update steps
    - size (int {default: 10}): number of rotation angles passed to ``reconstruct``
    - tol (float or None {default: None}): if given, stop early once the largest
      absolute pixel change between two successive iterates is <= tol

    Returns:
    - images_final (np.ndarray): the refined models (a float copy of
      ``images_init`` when n_iter == 0)

    NOTE(review): ``reconstruct`` weights rotated images by their raw scores
    without normalising per image, so this is an EM-like fixed-point iteration
    rather than textbook EM. Confirm intent.
    """
    images_final = np.array(images_init, dtype=float)
    for _ in range(n_iter):
        updated = reconstruct(dataset, images_final, size)
        converged = tol is not None and np.max(np.abs(updated - images_final)) <= tol
        images_final = updated
        if converged:
            break
    return images_final

In [None]:
# One orientation-marginalised update of the GMM means, using 10 sampled angles.
next_model = reconstruct(dataset, models=images, size=10)

In [None]:
# Sanity check: overall mean intensity of the updated models.
next_model.mean()

In [None]:
# One figure with every updated model, instead of one copy-pasted cell per model.
fig, axes = plt.subplots(1, len(next_model), figsize=(3 * len(next_model), 3), squeeze=False)
for k, (ax, model_image) in enumerate(zip(axes[0], next_model)):
    ax.imshow(model_image)
    ax.set_title(f"updated model {k}")
plt.show()