In [34]:
import numpy as np
from typing import List, Tuple
import matplotlib.pyplot as plt
# ============================================================
# Data Generation (This does the same thing as the genimages.py provided)
# ============================================================
# Fix the global NumPy RNG so the synthetic data set is identical on every run.
np.random.seed(0)

# Eight binary templates of 16 values each, written out as four rows of four.
# Each template is one ground-truth latent feature of the generative model.
features = [
    [0, 0, 1, 0,
     0, 1, 1, 1,
     0, 0, 1, 0,
     0, 0, 0, 0],
    [0, 1, 0, 0,
     0, 1, 0, 0,
     0, 1, 0, 0,
     0, 1, 0, 0],
    [1, 1, 1, 1,
     0, 0, 0, 0,
     0, 0, 0, 0,
     0, 0, 0, 0],
    [1, 0, 0, 0,
     0, 1, 0, 0,
     0, 0, 1, 0,
     0, 0, 0, 1],
    [0, 0, 0, 0,
     0, 0, 0, 0,
     1, 1, 0, 0,
     1, 1, 0, 0],
    [1, 1, 1, 1,
     1, 0, 0, 1,
     1, 0, 0, 1,
     1, 1, 1, 1],
    [0, 0, 0, 0,
     0, 1, 1, 0,
     0, 1, 1, 0,
     0, 0, 0, 0],
    [0, 0, 0, 1,
     0, 0, 0, 1,
     0, 0, 0, 1,
     0, 0, 0, 1],
]

num_samples = 2000   # number of generated data points (rows of `data`)
num_features = 16    # observed dimensions per data point; matches the template length
K_true = len(features)  # number of ground-truth latent features

# One random intensity per feature, drawn uniformly from [0.5, 1.0); shape (K_true, 1).
feature_weights = 0.5 + np.random.rand(K_true, 1) * 0.5
# Scale every template by its intensity; the result has shape (K_true, num_features).
mu_true = np.array([weight * feat for weight, feat in zip(feature_weights, features)])
# Independent binary indicators: each feature is switched on with probability 0.3.
latent_factors = (np.random.rand(num_samples, K_true) < 0.3).astype(float)
# Observation = sum of the active (weighted) features + Gaussian noise with std 0.1.
# NOTE: the order of the np.random calls above determines the generated data set.
data = latent_factors @ mu_true + np.random.randn(num_samples, num_features)*0.1


In [35]:
def m_step(X, ES, ESS):
    """
    Closed-form M-step of the binary latent factor model.

    Usage: mu, sigma, pie = m_step(X, ES, ESS)

    Inputs:
    -----------------
           X: shape (N, D) data matrix
          ES: shape (N, K) posterior means E_q[s]
         ESS: shape (K, K) second moments E_q[ss'] already summed over the
              data points, or shape (N, K, K) per-point second moments, in
              which case the sum over N is taken here.

    Outputs:
    --------
          mu: shape (D, K) matrix of means in p(y|{s_i},mu,sigma)
       sigma: scalar standard deviation of the observation noise
         pie: shape (1, K) prior probabilities of the latent variables

    Raises:
    -------
        TypeError: if ES and X disagree on N, or ESS is neither (K, K)
                   nor (N, K, K).
    """
    num_points, num_dims = X.shape
    if ES.shape[0] != num_points:
        raise TypeError('ES must have the same number of rows as X')
    num_latents = ES.shape[1]

    # Collapse per-data-point second moments into their sum over the data set.
    if ESS.shape == (num_points, num_latents, num_latents):
        ESS = np.sum(ESS, axis=0)
    if ESS.shape != (num_latents, num_latents):
        raise TypeError('ESS must be square and have the same number of columns as ES')

    # Least-squares solution of the normal equations: mu^T = ESS^{-1} ES^T X.
    inverse_second_moment = np.linalg.inv(ESS)
    mu = ((inverse_second_moment @ ES.T) @ X).T

    # Expected squared reconstruction error, expanded into three traces.
    data_energy = np.trace(X.T @ X)
    model_energy = np.trace((mu.T @ mu) @ ESS)
    cross_energy = np.trace((ES.T @ X) @ mu)
    mean_squared_residual = (data_energy + model_energy - 2 * cross_energy) / (num_points * num_dims)
    sigma = np.sqrt(mean_squared_residual)

    # The prior is the average posterior responsibility of each latent.
    pie = np.mean(ES, axis=0, keepdims=True)

    return mu, sigma, pie


In [36]:
class MessagePassing:
    r"""
    Message passing (loopy BP / EP style) approximation of the posterior over
    the binary latent variables of a binary latent factor model.

    The approximation is fully factorised; every factor is a Bernoulli stored
    by its mean parameter.

    Attributes:
        bernoulli_parameter_matrix (np.ndarray):
            Tensor of shape (num_data_points, num_latents, num_latents).
            - Off-diagonal entry [n, i, j] is the message sent from latent i
              to latent j for data point n, i.e. \tilde{g}_{ij, \neg s_i}(s_j).
            - Diagonal entry [n, i, i] is the singleton factor \tilde{f}_i(s_i).
            The array passed to the constructor is stored by reference and is
            updated in place by `variational_expectation_step`.
    """

    def __init__(self, bernoulli_parameter_matrix: np.ndarray):
        """
        Wraps an initial Bernoulli parameter tensor.

        Args:
            bernoulli_parameter_matrix (np.ndarray):
                Initial Bernoulli parameter tensor with shape
                (num_data_points, num_latents, num_latents).
        """
        self.bernoulli_parameter_matrix = bernoulli_parameter_matrix

    @property
    def expectation_s(self) -> np.ndarray:
        """
        Expectation of the latent variables under the approximation.

        Returns:
            np.ndarray: E[S], shape (num_data_points, num_latents).
        """
        return self.lambda_matrix

    @property
    def expectation_ss(self) -> np.ndarray:
        """
        Second moment of the latent variables, summed over data points.

        Returns:
            np.ndarray: sum_n E[s_n s_n^T], shape (num_latents, num_latents).
        """
        lam = self.lambda_matrix
        ess = lam.T @ lam
        # s_i is binary, so E[s_i^2] = E[s_i] = lambda_i on the diagonal.
        np.fill_diagonal(ess, lam.sum(axis=0))
        return ess

    @property
    def log_lambda_matrix(self) -> np.ndarray:
        """
        Returns:
            np.ndarray: log(E[S]).
        """
        return np.log(self.lambda_matrix)

    @property
    def log_one_minus_lambda_matrix(self) -> np.ndarray:
        """
        Returns:
            np.ndarray: log(1 - E[S]).
        """
        return np.log(1 - self.lambda_matrix)

    @property
    def n(self) -> int:
        """
        Returns:
            int: Number of data points (first axis of the parameter tensor).
        """
        # Read the shape directly instead of rebuilding the lambda matrix.
        return self.bernoulli_parameter_matrix.shape[0]

    @property
    def k(self) -> int:
        """
        Returns:
            int: Number of latent variables (last axis of the parameter tensor).
        """
        return self.bernoulli_parameter_matrix.shape[-1]

    def compute_free_energy(
        self,
        x: np.ndarray,
        binary_latent_factor_model: 'LoopyBP',
    ) -> float:
        """
        Computes the free energy for the current approximation and model.

        Args:
            x (np.ndarray): Data matrix of shape (num_data_points, num_dimensions).
            binary_latent_factor_model (LoopyBP): Binary latent factor model instance.

        Returns:
            float: Average free energy per data point.
        """
        expectation_log_p_x_s_given_theta = (
            self._compute_expectation_log_p_x_s_given_theta(
                x, binary_latent_factor_model
            )
        )
        entropy = self._compute_approximation_model_entropy()
        return (expectation_log_p_x_s_given_theta + entropy) / self.n

    def _compute_expectation_log_p_x_s_given_theta(
        self,
        x: np.ndarray,
        binary_latent_factor_model: "LoopyBP",
    ) -> float:
        """
        Computes E_q[log P(X, S | theta)] summed over all data points.

        Args:
            x (np.ndarray): Data matrix of shape (num_data_points, num_dimensions).
            binary_latent_factor_model (LoopyBP): Binary latent factor model instance.

        Returns:
            float: Expectation of log P(X, S | theta).
        """
        model = binary_latent_factor_model
        lam = self.lambda_matrix  # computed once; the property is not cached
        num_dimensions = x.shape[1]

        # E[S] mu^T, shape (n, d).
        expected_mean = lam @ model.mu.T

        # sum_n E[s s^T] with the binary-variable diagonal E[s_i^2] = E[s_i],
        # so no further diagonal correction is needed below.
        expected_ss = lam.T @ lam
        np.fill_diagonal(expected_ss, lam.sum(axis=0))
        expected_quadratic = np.sum(expected_ss * (model.mu.T @ model.mu))

        # Gaussian normaliser: one factor (2 pi sigma^2)^(-1/2) per OBSERVED
        # dimension, hence n * d (not n * k).
        log_normaliser = -(self.n * num_dimensions / 2) * np.log(
            2 * np.pi * model.variance
        )
        expected_squared_error = (
            np.sum(x ** 2)
            - 2 * np.sum(x * expected_mean)
            + expected_quadratic
        )
        expectation_log_p_x_given_s_theta = (
            log_normaliser - 0.5 * model.precision * expected_squared_error
        )

        # E_q[log P(S | theta)] for independent Bernoulli priors.
        expectation_log_p_s_given_theta = np.sum(
            lam * model.log_pi + (1 - lam) * model.log_one_minus_pi
        )

        return expectation_log_p_x_given_s_theta + expectation_log_p_s_given_theta

    def _compute_approximation_model_entropy(self) -> float:
        """
        Computes the entropy of the factorised Bernoulli approximation.

        Returns:
            float: Entropy summed over data points and latents.
        """
        lam = self.lambda_matrix  # already clipped away from 0 and 1
        return -np.sum(lam * np.log(lam) + (1 - lam) * np.log(1 - lam))

    @property
    def lambda_matrix(self) -> np.ndarray:
        """
        Posterior means obtained by summing, for every target latent, the
        natural parameters of all factors that touch it (singleton + incoming
        messages) and applying the sigmoid.

        Returns:
            np.ndarray: E[S], shape (num_data_points, num_latents), clipped to
            [1e-10, 1 - 1e-10].
        """
        # Axis 1 indexes the source of each message; summing it leaves targets.
        aggregated_natural_params = self.natural_parameter_matrix.sum(axis=1)
        return self.calculate_bernoulli_parameter(aggregated_natural_params)

    @property
    def natural_parameter_matrix(self) -> np.ndarray:
        """
        Natural parameters (log-odds) of every stored Bernoulli factor.

        Returns:
            np.ndarray: Same shape as `bernoulli_parameter_matrix`.
        """
        return self._log_odds(self.bernoulli_parameter_matrix)

    @staticmethod
    def _log_odds(bernoulli_parameters: np.ndarray) -> np.ndarray:
        """Maps Bernoulli means to natural parameters, log(p / (1 - p))."""
        return np.log(bernoulli_parameters / (1 - bernoulli_parameters))

    def aggregate_incoming_binary_factor_messages(
        self, node_index: int, excluded_node_index: int
    ) -> np.ndarray:
        """
        Sums the natural parameters of the binary-factor messages arriving at
        `node_index`, leaving out the message from `excluded_node_index`.

        The diagonal (singleton) entry of `node_index` is not a binary-factor
        message and is skipped as well; callers add the singleton term
        themselves, so including it here would count it twice.

        Args:
            node_index (int): Index of the latent variable receiving the messages.
            excluded_node_index (int): Index of the sender whose message is left out.

        Returns:
            np.ndarray: Aggregated natural parameters, shape (num_data_points,).
        """
        is_binary_sender = np.ones(self.bernoulli_parameter_matrix.shape[1], dtype=bool)
        is_binary_sender[node_index] = False           # singleton factor
        is_binary_sender[excluded_node_index] = False  # cavity: drop this sender

        # Only convert the column that is needed instead of the whole tensor.
        incoming = self.bernoulli_parameter_matrix[:, :, node_index][:, is_binary_sender]
        return self._log_odds(incoming).sum(axis=1).reshape(-1)

    @staticmethod
    def calculate_bernoulli_parameter(
            natural_parameter_matrix: np.ndarray
    ) -> np.ndarray:
        """
        Computes Bernoulli parameters from natural parameters (sigmoid).

        Args:
            natural_parameter_matrix (np.ndarray): Natural parameters (log-odds).

        Returns:
            np.ndarray: Bernoulli parameters clipped to [1e-10, 1 - 1e-10].
        """
        # sigmoid(eta) = exp(-log(1 + exp(-eta))); logaddexp keeps this free of
        # overflow for natural parameters of large magnitude.
        bernoulli_parameter = np.exp(-np.logaddexp(0.0, -natural_parameter_matrix))
        # Keep strictly inside (0, 1) so that later logs stay finite.
        return np.clip(bernoulli_parameter, 1e-10, 1 - 1e-10)

    def variational_expectation_step(
        self, x: np.ndarray, binary_latent_factor_model: 'LoopyBP'
    ) -> List[float]:
        """
        Performs one sweep of the E-step: for every latent i the singleton
        factor is refreshed, then both directed messages of every pair (i, j)
        with j < i. `bernoulli_parameter_matrix` is modified in place.

        Args:
            x (np.ndarray): Data matrix of shape (num_data_points, num_dimensions).
            binary_latent_factor_model (LoopyBP): Binary latent factor model instance.

        Returns:
            List[float]: Free energy before the sweep, after each singleton
            update and after each pair update.
        """
        model = binary_latent_factor_model
        free_energies = [self.compute_free_energy(x, model)]

        for i in range(self.k):
            # Singleton factor of latent i.
            singleton_natural = self.calculate_singleton_message_update(
                LoopyBP=model,
                x=x,
                i=i,
            )
            self.bernoulli_parameter_matrix[:, i, i] = self.calculate_bernoulli_parameter(singleton_natural)
            free_energies.append(self.compute_free_energy(x, model))

            for j in range(i):
                # Message from i to j, stored at [source, target] = [i, j].
                natural_i_to_j = self.calculate_binary_message_update(
                    LoopyBP=model,
                    x=x,
                    i=i,
                    j=j,
                )
                self.bernoulli_parameter_matrix[:, i, j] = self.calculate_bernoulli_parameter(natural_i_to_j)

                # Message from j to i: the roles of source and target swap.
                natural_j_to_i = self.calculate_binary_message_update(
                    LoopyBP=model,
                    x=x,
                    i=j,
                    j=i,
                )
                self.bernoulli_parameter_matrix[:, j, i] = self.calculate_bernoulli_parameter(natural_j_to_i)

                free_energies.append(self.compute_free_energy(x, model))

        return free_energies

    def calculate_binary_message_update(
        self,
        x: np.ndarray,
        LoopyBP: "LoopyBP",
        i: int,
        j: int,
    ) -> np.ndarray:
        """
        Computes the natural parameter of the message sent from latent i to
        latent j.

        Args:
            x (np.ndarray): Data matrix of shape (num_data_points, num_dimensions).
            LoopyBP (LoopyBP): Binary latent factor model instance.
            i (int): Index of the source latent variable.
            j (int): Index of the target latent variable.

        Returns:
            np.ndarray: Natural parameters of the message, shape (num_data_points,).
        """
        # Cavity on s_i: its singleton term plus every binary message into i
        # except the one coming from the target j.
        cavity_natural = LoopyBP.b_index(x=x, node_index=i) + (
            self.aggregate_incoming_binary_factor_messages(
                node_index=i, excluded_node_index=j
            )
        )

        # Pairwise interaction between source and target.
        w_i_j = LoopyBP.w_matrix_index(i, j)

        # log(1 + e^{w + c}) - log(1 + e^{c}), evaluated with logaddexp so that
        # large c does not overflow to inf - inf = nan.
        return np.logaddexp(0.0, w_i_j + cavity_natural) - np.logaddexp(0.0, cavity_natural)

    @staticmethod
    def calculate_singleton_message_update(
        x: np.ndarray,
        LoopyBP: "LoopyBP",
        i: int,
    ) -> np.ndarray:
        """
        Computes the natural parameter of the singleton factor of latent i.

        Args:
            x (np.ndarray): Data matrix of shape (num_data_points, num_dimensions).
            LoopyBP (LoopyBP): Binary latent factor model instance.
            i (int): Index of the latent variable to update.

        Returns:
            np.ndarray: Natural parameters of the singleton factor.
        """
        # The singleton factor depends on the data only, not on other messages.
        return LoopyBP.b_index(x=x, node_index=i)


def init_message_passing(k: int, n: int) -> MessagePassing:
    """
    Builds a message passing approximation with random initial factors.

    Every Bernoulli parameter is drawn from the global NumPy random state
    with `np.random.random`.

    :param k: number of latent variables
    :param n: number of data points
    :return: message passing approximation holding an (n, k, k) parameter tensor
    """
    tensor_shape = (n, k, k)
    return MessagePassing(np.random.random(size=tensor_shape))


In [37]:
class LoopyBP:
    """
    Parameters of the binary latent factor model, together with the
    Boltzmann-machine quantities (`w_matrix`, `b`) that the message passing
    E-step reads.

    For one data point x the posterior over binary latents s satisfies
        log p(s | x) = sum_i b_i s_i + sum_{i<j} W_ij s_i s_j + const,
    with W_ij = -mu_i . mu_j / sigma^2 and
    b_i = x . mu_i / sigma^2 + log(pi_i / (1 - pi_i)) - |mu_i|^2 / (2 sigma^2).
    """

    def __init__(
        self,
        mu: np.ndarray,
        sigma: float,
        pi: np.ndarray,
    ):
        """
        Initializes the model with means, noise level and prior probabilities.

        Args:
            mu (np.ndarray): Mean matrix of shape (num_dimensions, num_latents).
            sigma (float): Standard deviation of the observation noise.
            pi (np.ndarray): Prior probabilities of latent variables, shape (1, num_latents).
        """
        self._mu = mu  # Shape: (num_dimensions, num_latents)
        self._sigma = sigma
        self._pi = pi   # Shape: (1, num_latents)

    # ==========================
    # Property Definitions
    # ==========================

    @property
    def d(self) -> int:
        """Number of observed dimensions (rows of mu)."""
        return self.mu.shape[0]

    @property
    def k(self) -> int:
        """Number of latent variables (columns of mu)."""
        return self.mu.shape[1]

    @property
    def mu(self) -> np.ndarray:
        """Mean matrix."""
        return self._mu

    @mu.setter
    def mu(self, value: np.ndarray) -> None:
        self._mu = value

    @property
    def sigma(self) -> float:
        """Standard deviation."""
        return self._sigma

    @sigma.setter
    def sigma(self, value: float) -> None:
        self._sigma = value

    @property
    def pi(self) -> np.ndarray:
        """Prior probabilities of latent variables."""
        return self._pi

    @pi.setter
    def pi(self, value: np.ndarray) -> None:
        self._pi = value

    @property
    def variance(self) -> float:
        """Variance, square of the standard deviation."""
        return self.sigma ** 2

    @property
    def precision(self) -> float:
        """Precision, inverse of variance."""
        return 1.0 / self.variance

    @property
    def log_pi(self) -> np.ndarray:
        """Logarithm of prior probabilities."""
        return np.log(self.pi)

    @property
    def log_one_minus_pi(self) -> np.ndarray:
        """Logarithm of (1 - prior probabilities)."""
        return np.log(1 - self.pi)

    @property
    def log_pi_ratio(self) -> np.ndarray:
        """Log ratio of pi to (1 - pi)."""
        return self.log_pi - self.log_one_minus_pi

    @property
    def num_dimensions(self) -> int:
        """Number of dimensions in the data (same value as `d`)."""
        return self.mu.shape[0]

    @property
    def num_latents(self) -> int:
        """Number of latent variables (same value as `k`)."""
        return self.mu.shape[1]

    # ==========================
    # Static Methods
    # ==========================

    @staticmethod
    def calculate_maximisation_parameters(
        x: np.ndarray,
        approximation: "MessagePassing",
    ) -> Tuple[np.ndarray, float, np.ndarray]:
        """
        Runs the M-step on the expectations of a variational approximation.

        Args:
            x (np.ndarray): Data matrix of shape (num_data_points, num_dimensions).
            approximation (MessagePassing): Variational approximation instance;
                its `expectation_s` and `expectation_ss` are passed to `m_step`.

        Returns:
            Tuple[np.ndarray, float, np.ndarray]:
                - Updated mu matrix.
                - Updated sigma (float).
                - Updated pi vector.
        """
        return m_step(
            X=x,
            ES=approximation.expectation_s,
            ESS=approximation.expectation_ss,
        )

    # ==========================
    # Instance Methods
    # ==========================

    def maximisation_step(
        self,
        x: np.ndarray,
        binary_latent_factor_approximation: "MessagePassing",
    ) -> None:
        """
        Replaces mu, sigma and pi with the result of the M-step.

        Args:
            x (np.ndarray): Data matrix of shape (num_data_points, num_dimensions).
            binary_latent_factor_approximation (MessagePassing): Variational approximation instance.
        """
        mu_updated, sigma_updated, pi_updated = self.calculate_maximisation_parameters(
            x, binary_latent_factor_approximation
        )
        self.mu = mu_updated
        self.sigma = sigma_updated
        self.pi = pi_updated

    def w_matrix(self) -> np.ndarray:
        """
        Computes the pairwise weight matrix W = -mu^T mu / sigma^2.

        Returns:
            np.ndarray: Weight matrix of shape (num_latents, num_latents).
        """
        return -self.precision * (self.mu.T @ self.mu)

    def w_matrix_index(self, i, j) -> float:
        """
        Retrieves the pairwise weight between two latent variables.

        Args:
            i (int): Index of the first latent variable.
            j (int): Index of the second latent variable.

        Returns:
            float: W_ij = -mu_i . mu_j / sigma^2.
        """
        return -self.precision * (self.mu[:, i] @ self.mu[:, j])

    def b(self, x: np.ndarray) -> np.ndarray:
        """
        Computes the singleton natural parameters (posterior log-odds of
        s_i = 1 when all other latents are off) for all data points.

        Args:
            x (np.ndarray): Data matrix of shape (num_data_points, num_dimensions).

        Returns:
            np.ndarray: 'b' matrix of shape (num_data_points, num_latents).
        """
        # Positive sign: b multiplies s_i in the log-probability, matching the
        # sign convention of w_matrix.
        return (
            self.precision * x @ self.mu
            + self.log_pi_ratio
            - 0.5 * self.precision * np.sum(self.mu ** 2, axis=0)
        )

    def b_index(self, x: np.ndarray, node_index: int) -> np.ndarray:
        """
        Computes the singleton natural parameter of one latent variable
        across all data points.

        Args:
            x (np.ndarray): Data matrix of shape (num_data_points, num_dimensions).
            node_index (int): Index of the target latent variable.

        Returns:
            np.ndarray: 'b' vector for the specified node, shape (num_data_points,).
        """
        mu_column = self.mu[:, node_index]
        return (
            self.precision * x @ mu_column
            + (self.log_pi[0, node_index] - self.log_one_minus_pi[0, node_index])
            - 0.5 * self.precision * (mu_column @ mu_column)
        ).reshape(-1)






def init_LoopyBP(
    x: np.ndarray,
    binary_latent_factor_approximation: "MessagePassing",
) -> LoopyBP:
    """
    Builds a LoopyBP model whose parameters come from one M-step evaluated
    on the expectations of the given approximation.

    Args:
        x (np.ndarray): Data matrix passed to the M-step.
        binary_latent_factor_approximation (MessagePassing): Approximation
            supplying the expectations.

    Returns:
        LoopyBP: Model holding the resulting mu, sigma and pi.
    """
    initial_parameters = LoopyBP.calculate_maximisation_parameters(
        x, binary_latent_factor_approximation
    )
    initial_mu, initial_sigma, initial_pi = initial_parameters
    return LoopyBP(mu=initial_mu, sigma=initial_sigma, pi=initial_pi)


In [38]:
def plot_latent_features(
    mu: np.ndarray,
    num_latents: int,
    feature_shape: Tuple[int, int],
    title: str,
    save_path: str,
) -> None:
    """
    Draws every latent feature as a grayscale image in one row of panels and
    writes the figure to disk.

    Args:
        mu (np.ndarray): Mean matrix; column i is latent feature i.
        num_latents (int): Number of latent features (panels) to draw.
        feature_shape (Tuple[int, int]): Image shape each column is reshaped to.
        title (str): Figure title.
        save_path (str): File path the figure is saved to.

    Raises:
        ValueError: If a column of `mu` cannot be reshaped to `feature_shape`.
    """
    # squeeze=False always yields a 2-D axes array, so a single latent needs
    # no special case; flattening gives one panel per latent.
    fig, axis_grid = plt.subplots(
        1, num_latents, figsize=(num_latents * 2, 2), squeeze=False
    )
    panels = axis_grid.ravel()

    for latent_index in range(num_latents):
        panel = panels[latent_index]
        try:
            feature_image = mu[:, latent_index].reshape(feature_shape)
        except ValueError as e:
            raise ValueError(
                f"Cannot reshape latent factor {latent_index} to shape {feature_shape}: {e}"
            )

        panel.imshow(feature_image, cmap='gray', interpolation='none')
        panel.set_title(f"Feature {latent_index + 1}")
        panel.axis('off')

    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def plot_free_energy(
    free_energy: List[float],
    title: str,
    xlabel: str,
    ylabel: str,
    save_path: str,
) -> None:
    """
    Draws the free energy trace as a line plot with circular markers and
    writes the figure to disk.

    Args:
        free_energy (List[float]): Free energy values, one per recorded step.
        title (str): Title of the plot.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        save_path (str): File path the figure is saved to.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(free_energy, marker='o')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    # Close the figure once it has been written.
    plt.close(fig)



In [39]:
import matplotlib.pyplot as plt
import numpy as np

def runLoopyBP(
        x: np.ndarray,
        k: int,
        em_iterations: int,
        save_path: str
) -> None:
    """
    Runs the full Loopy BP pipeline on a data set: random initialisation,
    EM learning, and three plots (initial features, learned features and the
    free energy trace), each saved under a path derived from `save_path`.

    Args:
        x (np.ndarray): Data matrix of shape (num_data_points, num_dimensions).
        k (int): Number of latent variables (factors) in the model.
        em_iterations (int): Maximum number of EM iterations to perform.
        save_path (str): Path prefix for the generated plot files.
    """
    def save_feature_plot(feature_means: np.ndarray, plot_title: str, target: str) -> None:
        # Every feature is rendered as a 4x4 image.
        plot_latent_features(
            mu=feature_means,
            num_latents=k,
            feature_shape=(4, 4),
            title=plot_title,
            save_path=target,
        )

    num_points = x.shape[0]

    # Random approximation first; the model is then derived from it.
    approximation = init_message_passing(k, num_points)
    model = init_LoopyBP(x, approximation)

    save_feature_plot(
        model.mu,
        "Initial Latent Features (Loopy BP)",
        f"{save_path}-init-latent-factors.png",
    )

    # EM learning returns the final approximation, model and free energies.
    approximation, model, free_energy = learn_binary_factors(
        x=x,
        k=k,
        em_iterations=em_iterations,
        binary_latent_factor_model=model,
        binary_latent_factor_approximation=approximation,
    )

    save_feature_plot(
        model.mu,
        "Learned Latent Features (Loopy BP)",
        f"{save_path}-learned-latent-factors.png",
    )

    plot_free_energy(
        free_energy=free_energy,
        title="Free Energy (Loopy BP)",
        xlabel="Iterations",
        ylabel="Free Energy",
        save_path=f"{save_path}-free-energy.png",
    )



In [40]:
def is_converge(
    free_energies: List[float],
    current_lambda_matrix: np.ndarray,
    previous_lambda_matrix: np.ndarray,
    free_energy_threshold: float = 1e-6,
    lambda_threshold: float = 1e-6,
) -> bool:
    """
    Decide whether the iteration has settled.

    Both of the following must hold:
    the absolute difference between the last two free energies is at most
    `free_energy_threshold`, and the norm of the difference between the two
    lambda matrices (`np.linalg.norm` default, i.e. Frobenius for 2-D input)
    is at most `lambda_threshold`.

    Parameters
    ----------
    free_energies : List[float]
        Free energy values recorded so far; only the last two are used.
    current_lambda_matrix : np.ndarray
        Lambda matrix after the latest iteration.
    previous_lambda_matrix : np.ndarray
        Lambda matrix before the latest iteration.
    free_energy_threshold : float, optional
        Largest free energy change still counted as converged, by default 1e-6.
    lambda_threshold : float, optional
        Largest lambda change still counted as converged, by default 1e-6.

    Returns
    -------
    bool
        True when both criteria are met; False otherwise, including when
        fewer than two free energies are available.
    """
    # A change cannot be measured from fewer than two recorded values.
    if len(free_energies) < 2:
        return False

    latest, before_latest = free_energies[-1], free_energies[-2]
    energy_delta = abs(latest - before_latest)
    lambda_delta = np.linalg.norm(current_lambda_matrix - previous_lambda_matrix)

    energy_settled = energy_delta <= free_energy_threshold
    lambda_settled = lambda_delta <= lambda_threshold
    return energy_settled and lambda_settled


def learn_binary_factors(
    x: np.ndarray,
    k: int,
    em_iterations: int,
    binary_latent_factor_model: 'LoopyBP',
    binary_latent_factor_approximation: 'MessagePassing',
) -> Tuple['MessagePassing', 'LoopyBP', List[float]]:
    """
    Run EM to learn binary latent factors.

    Each iteration performs an E-step on the approximation, an M-step on the
    model, records the resulting free energy and stops early once
    `is_converge` reports convergence. Both objects are updated in place and
    also returned.

    Parameters
    ----------
    x : np.ndarray
        Data matrix of shape (n_samples, n_dimensions).
    k : int
        Number of latent variables; only used in the convergence message.
    em_iterations : int
        Maximum number of EM iterations to perform.
    binary_latent_factor_model : LoopyBP
        Model providing `maximisation_step`.
    binary_latent_factor_approximation : MessagePassing
        Approximation providing `compute_free_energy`,
        `variational_expectation_step` and `lambda_matrix`.

    Returns
    -------
    Tuple[MessagePassing, LoopyBP, List[float]]
        - The updated approximation.
        - The updated model.
        - The free energies: the initial value followed by one value per
          completed EM iteration.
    """
    approximation = binary_latent_factor_approximation
    model = binary_latent_factor_model

    # Start the trace with the free energy of the initial state.
    free_energies: List[float] = [approximation.compute_free_energy(x, model)]

    for iteration in range(1, em_iterations + 1):
        # Snapshot of E[S] before the update, for the convergence check.
        previous_lambda_matrix = np.copy(approximation.lambda_matrix)

        # E-step: refresh the approximation. The per-update free energies it
        # returns are not needed here.
        approximation.variational_expectation_step(
            x=x,
            binary_latent_factor_model=model,
        )

        # M-step: refresh the model parameters.
        model.maximisation_step(
            x=x,
            binary_latent_factor_approximation=approximation,
        )

        current_free_energy = approximation.compute_free_energy(x, model)
        free_energies.append(current_free_energy)

        if is_converge(
            free_energies=free_energies,
            current_lambda_matrix=approximation.lambda_matrix,
            previous_lambda_matrix=previous_lambda_matrix,
        ):
            print(f"current K = {k},"
                  f" Convergence achieved at iteration {iteration},"
                  f" Free Energy at Convergence: {current_free_energy}.")
            break

    return approximation, model, free_energies

In [None]:
import os
from dataclasses import asdict


import numpy as np
import pandas as pd

# Constants for output directories and random seed
OUTPUTS_FOLDER = "LoopyBP"
DEFAULT_SEED = 43

if __name__ == "__main__":
    # Re-seed so the random initialisation of the messages is reproducible.
    np.random.seed(DEFAULT_SEED)

    # exist_ok avoids the check-then-create race and is a no-op on re-runs.
    os.makedirs(OUTPUTS_FOLDER, exist_ok=True)

    x = data
    k = 8
    em_iterations = 100
    # NOTE(review): the two settings below are not passed to runLoopyBP;
    # kept in case another cell reads them.
    e_maximum_steps = 50
    e_convergence_criterion = 0

    runLoopyBP(x, k, em_iterations, save_path=os.path.join(OUTPUTS_FOLDER, "all"))