<a href="https://colab.research.google.com/github/gwarnertt/AI_Project_4/blob/master/Untitled18.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [49]:
"""Mixture model for matrix completion"""
from typing import Tuple
import numpy as np
import common
from scipy.special import logsumexp
from common import GaussianMixture

X = np.loadtxt('toy_data.txt')
K =2
seed = 0
mixture, post = common.init(X,K,seed)

def estep(X: np.ndarray, mixture: GaussianMixture) -> Tuple[np.ndarray, float]:
    """E-step: Softly assigns each datapoint to a gaussian component

    Args:
        X: (n, d) array holding the data, with incomplete entries (set to 0)
        mixture: the current gaussian mixture

    Returns:
        np.ndarray: (n, K) array holding the soft counts
            for all components for all examples
        float: log-likelihood of the assignment

    """
    mixture, post = common.init(X,2,seed=1)
    n, _ = X.shape
    K, _ = mixture.mu.shape
    post = np.zeros((n, K))

    ll = 0
    for i in range(n):
        for j in range(K):
            likelihood = gaussian(X[i], mixture.mu[j], mixture.var[j])
            post[i, j] = mixture.p[j] * likelihood
        total = post[i, :].sum()
        post[i, :] = post[i, :] / total
        ll += np.log(total)

    return post, ll
    raise NotImplementedError

def gaussian(x: np.ndarray, mean: np.ndarray, var: float) -> float:
    """Computes the probablity of vector x under a normal distribution

    Args:
        x: (d, ) array holding the vector's coordinates
        mean: (d, ) mean of the gaussian
        var: variance of the gaussian

    Returns:
        float: the probability
    """
    d = len(x)
    log_prob = -d / 2.0 * np.log(2 * np.pi * var)
    log_prob -= 0.5 * ((x - mean)**2).sum() / var
    return np.exp(log_prob)    



def mstep(X: np.ndarray, post: np.ndarray, mixture: GaussianMixture,
          min_variance: float = .25) -> GaussianMixture:
    """M-step: Updates the gaussian mixture by maximizing the log-likelihood
    of the weighted dataset

    Args:
        X: (n, d) array holding the data, with incomplete entries (set to 0)
        post: (n, K) array holding the soft counts
            for all components for all examples
        mixture: the current gaussian mixture
        min_variance: the minimum variance for each gaussian

    Returns:
        GaussianMixture: the new gaussian mixture
    """
    n,d = X.shape
    K = post.shape[1]
    j = np.sum(post, axis = 0)
    pi = j/n
    mu = (post.T@X)/j.reshape(-1,1)
    norms = np.linalg.norm(X[:,None]-mu, ord=2, axis=2)**2
    var = np.sum(post*norms, axis=0)/(j*d)
    return GaussianMixture
    raise NotImplementedError


def run(X: np.ndarray, mixture: GaussianMixture,
        post: np.ndarray) -> Tuple[GaussianMixture, np.ndarray, float]:
    """Runs the mixture model

    Args:
        X: (n, d) array holding the data
        post: (n, K) array holding the soft counts
            for all components for all examples

    Returns:
        GaussianMixture: the new gaussian mixture
        np.ndarray: (n, K) array holding the soft counts
            for all components for all examples
        float: log-likelihood of the current assignment
    """
 
    old_loglh = None
    new_loglh = None
    
    while (old_loglh is None or new_loglh - old_loglh> 1e-6*np.abs(new_loglh)):
        old_loglh = new_loglh
  
        post, new_loglh = estep(X, mixture)
        mixture = mstep(X,post, mixture)
    
    return mixture, post, new_loglh
    raise NotImplementedError


def fill_matrix(X: np.ndarray, mixture: GaussianMixture) -> np.ndarray:
    """Fills an incomplete matrix according to a mixture model

    Args:
        X: (n, d) array of incomplete data (incomplete entries =0)
        mixture: a mixture of gaussians

    Returns
        np.ndarray: a (n, d) array with completed data
    """
    raise NotImplementedError

mixture, post = common.init(X, K, seed)
post, log_lh =  estep(X,mixture)
GaussianMixture = mstep(X, post, mixture)
mixture, post, new_loglh = run(X, post, mixture)
new_loglh


-1422.0774778715281