# A mini-tutorial for Recognition Parametrized Factor Analysis (RP-FA)

**TL;DR** This notebook contains implementation examples of RP-FA as well as a few empirical results on latent identifiability and some useful tricks for fitting a model. 

## 1. Notations, Methods and Tricks

### 1.1 Model

Let $\mathcal{X} = \{ x_j \}$ be one observation comprising $J$ factors that are conditionally independent given the latent variable $\mathcal{Z}$. The Recognition Parametrised Model (RPM) reads:

$$ \mathsf{P_{\theta}}(\mathcal{X}, \mathcal{Z}) = \mathsf{p_{\theta_z}}(\mathcal{Z}) \prod_{j} \left( \mathsf{p_{0,j}}(\mathsf{x_{j}}) \frac{\mathsf{f_{\theta j}}(\mathsf{\mathcal{Z}} | \mathsf{x_{j}})}{\mathsf{F_{\theta j}}(\mathcal{Z})} \right) $$

where, given observations $\{ \mathcal{X}^{(n)} \}_{n=1}^N$:

- $\mathsf{p_{0,j}}(\mathsf{x_{j}}) = \frac{1}{N} \sum_n \delta(\mathsf{x_{j}} - \mathsf{x_{j}}^{(n)})$

- $\mathsf{F_{\theta j}}(\mathcal{Z}) = \frac{1}{N} \sum_n \mathsf{f_{\theta j}}(\mathsf{\mathcal{Z}} | \mathsf{x_{j}^{(n)}})$

We assume that each recognition factor $\mathsf{f_{\theta j}}( \cdot | \mathsf{x_{j}^{(n)}})$ is defined by natural parameters $\eta_j( \mathsf{x_{j}}^{(n)}) = \eta_j^{(n)}$, which are parametrised by neural networks. 
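As a minimal sketch of what "natural parameters" means for a multivariate normal factor, the snippet below converts between (mean, covariance) and $(\eta_1, \eta_2)$. It assumes the usual convention $\eta_1 = \Sigma^{-1} \mu$ and $\eta_2 = -\frac{1}{2} \Sigma^{-1}$, which appears to match the formulas used later in this section; the function names `to_natural` and `from_natural` are illustrative and not taken from the notebook's code.

```python
import numpy as np

def to_natural(mean, cov):
    """Map (mean, covariance) to natural parameters (eta1, eta2).

    Assumed convention: eta1 = precision @ mean, eta2 = -0.5 * precision.
    """
    precision = np.linalg.inv(cov)
    eta1 = precision @ mean
    eta2 = -0.5 * precision
    return eta1, eta2

def from_natural(eta1, eta2):
    """Inverse map: recover (mean, covariance) from (eta1, eta2)."""
    cov = -0.5 * np.linalg.inv(eta2)
    mean = cov @ eta1
    return mean, cov

# Round trip on a small 2D example (values are arbitrary).
mean = np.array([1.0, -2.0])
cov = np.array([[2.0, 0.3], [0.3, 1.0]])
eta1, eta2 = to_natural(mean, cov)
mean_back, cov_back = from_natural(eta1, eta2)
```

In this convention a valid factor requires $\eta_2$ to be negative definite, which is why the precision is usually the quantity that gets a constrained parametrisation.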


### 1.2 Free Energy and Inner Variational Bound (or ELBO and ELLBO)

The Free Energy (or ELBO) is a lower bound on the log-likelihood. Given variational distributions $q^{(n)}$ (defined by natural parameters $\eta_q^{(n)}$), it reads:

$$\mathcal{F} = - \sum_n \mathsf{KL}(q^{(n)} || \mathsf{p_{\theta_z}}) + \sum_{jn} \log \mathsf{p_{0,j}}(x_{j}^{(n)}) + \sum_{jn} \bigg \langle \log \frac{\mathsf{f_{\theta j}}(\mathsf{\mathcal{Z}} | \mathsf{x_{j}^{(n)}})}{\mathsf{F_{\theta j}}(\mathcal{Z})} \bigg\rangle_{q^{(n)}}$$

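When the recognition factors and the variational distributions are multivariate normal, $\mathcal{F}$ is not tractable: $\mathsf{F_{\theta j}}$ is a mixture, and the expectation of its logarithm under $q^{(n)}$ has no closed form. We therefore rely on an inner variational lower bound (or ELLBO):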

$$\tilde {\mathcal{F}} = - \sum_n \mathsf{KL}(q^{(n)} || \mathsf{p_{\theta_z}}) - \sum_{jn} \mathsf{KL}(q^{(n)} || \mathsf{\hat{f}_{\theta j}}( \cdot | \mathsf{x_{j}^{(n)}})) + \sum_{nj} \log \Gamma_j^n$$ 



where:
- $\mathsf{\hat{f}_{\theta j}}( \cdot | \mathsf{x_{j}^{(n)}})$ is defined by natural parameter $\eta_j( \mathsf{x_{j}}^{(n)}) + \tilde{\eta}_j^{(n)}$

- $\Gamma_j^n = \frac{\exp(s_j^{n,n})}{\sum_m \exp(s_j^{m,n})}$

- $s_j^{m,n} = - \frac{1}{4} \eta_{j,1}^{(m) \top} \left[  \left(\eta_{j,2}^{(n)} + \tilde{\eta}_{j, 2}^{(n)} \right)^{-1} + \left(\eta_{j,2}^{(n)} \right)^{-1} \right] \eta_{j,1}^{(m)} + \frac{1}{2} \eta_{j,1}^{(m) \top}  \left(\eta_{j,2}^{(n)} + \tilde{\eta}_{j, 2}^{(n)} \right)^{-1} \tilde{\eta}_{j,1}^{(n)}$

**Note:** this parametrisation differs slightly from the [original paper](https://arxiv.org/abs/2209.05661) but, with no loss of generality, it ensures that the model is valid provided that all precision parameters are valid. To this end, one typically parametrises their Cholesky factors, possibly with some jitter. Using a numerically stable logsumexp is also critical for stability.
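The two numerical points in the note can be sketched as follows. This is only an illustration under assumptions: `precision_from_cholesky` and `log_gamma` are hypothetical helper names, the jitter value is arbitrary, and the score matrix is assumed to be stored as `scores[m, n]` $= s_j^{m,n}$ for a single factor $j$.

```python
import numpy as np

def precision_from_cholesky(raw, jitter=1e-6):
    """Build a symmetric positive-definite precision from an unconstrained matrix.

    Only the lower triangle of `raw` is used as a Cholesky-like factor; the
    jitter keeps the result strictly positive definite.
    """
    lower = np.tril(raw)
    return lower @ lower.T + jitter * np.eye(raw.shape[0])

def log_gamma(scores):
    """Return log Gamma_j^n for every n, given scores[m, n] = s_j^{m,n}.

    Subtracting the column-wise maximum before exponentiating (the
    log-sum-exp trick) avoids overflow for large scores.
    """
    col_max = scores.max(axis=0)
    log_sum_exp = col_max + np.log(np.exp(scores - col_max).sum(axis=0))
    return np.diag(scores) - log_sum_exp

# A degenerate raw matrix still yields a valid precision thanks to the jitter.
raw = np.array([[0.0, 5.0], [2.0, 0.0]])
precision = precision_from_cholesky(raw)

# Scores this large would overflow with a naive exp / sum implementation.
large_scores = np.array([[1000.0, 0.0], [0.0, 1000.0]])
log_gamma_large = log_gamma(large_scores)

# With identical scores, every Gamma equals 1 / N.
log_gamma_uniform = log_gamma(np.zeros((3, 3)))
```

The second natural parameter would then be obtained as $-\frac{1}{2}$ times the precision, assuming the standard Gaussian convention.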

### 1.3 Learning and inference

- A convenient yet principled assumption is to make all second natural parameters independent of the sample $n$, that is, $\eta_{j,2}^{(n)} = \eta_{j,2}$ for all $n = 1, \dots, N$. Not only does this follow from related conjugacy results, it also significantly speeds up and stabilizes learning.

- Using fully flexible auxiliary parameters $\tilde{\eta}$ greatly complicates the fitting procedure and increases memory requirements. Instead, we recall that the ELLBO is tight when $\mathsf{F_{\theta j}} \times \tilde{f}_j^{(n)} \approx q^{(n)}$. We cannot guarantee equality in the finite-sample case since $\mathsf{F_{\theta j}}$ is a mixture, but we can approximate $\mathsf{F_{\theta j}}$ with either (i) the prior or (ii) a moment-matched Gaussian. The former is more efficient, while the latter tailors the approximation to each factor $j$.

- In the latter case, we obtain $\tilde{\eta}_j^{(n)} = \eta_q^{(n)} - \eta_{\mathsf{Fj}}$ where 

$$\eta_{\mathsf{Fj}, 1} = \left( I + \mathbb{V}(\eta_{j, 1}) \left(- \frac{1}{2} \eta_{j,2}^{-1} \right) \right)^{-1} \mathbb{E}(\eta_{j, 1}) \quad \text{and} \quad \eta_{\mathsf{Fj}, 2} = \left( I + \mathbb{V}(\eta_{j, 1}) \left(- \frac{1}{2} \eta_{j,2}^{-1} \right) \right)^{-1} \eta_{j, 2}$$

Here the expectation $\mathbb{E}$ and the variance $\mathbb{V}$ are taken over samples $n$.
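The moment-matched parameters above can be sketched in a few lines of NumPy. The sketch assumes a shared second natural parameter $\eta_{j,2}$ across samples, the convention $\eta_2 = -\frac{1}{2} \Sigma^{-1}$, and a population (not unbiased) variance over $n$; the function names and the toy data are illustrative only.

```python
import numpy as np

def moment_matched_mixture(eta1, eta2):
    """Natural parameters of the Gaussian moment-matched to the mixture F_j.

    eta1: (N, D) first natural parameters, one row per sample n.
    eta2: (D, D) second natural parameter, shared across samples.
    """
    n_samples, dim = eta1.shape
    mean_eta1 = eta1.mean(axis=0)                 # E(eta_{j,1}) over n
    centered = eta1 - mean_eta1
    var_eta1 = centered.T @ centered / n_samples  # V(eta_{j,1}) over n
    cov = -0.5 * np.linalg.inv(eta2)              # shared covariance
    correction = np.linalg.inv(np.eye(dim) + var_eta1 @ cov)
    return correction @ mean_eta1, correction @ eta2

def auxiliary_parameters(eta_q1, eta_q2, eta_f1, eta_f2):
    """Auxiliary parameters: eta_tilde = eta_q - eta_F, for both components."""
    return eta_q1 - eta_f1, eta_q2 - eta_f2

# Toy data: 50 samples in 2 dimensions with an arbitrary shared covariance.
rng = np.random.default_rng(0)
eta1 = rng.normal(size=(50, 2))
cov_shared = np.array([[1.0, 0.2], [0.2, 0.5]])
eta2 = -0.5 * np.linalg.inv(cov_shared)
eta_f1, eta_f2 = moment_matched_mixture(eta1, eta2)
```

One way to sanity-check the result is to compute the mixture mean and covariance directly in moment space and convert them back to natural parameters; the two routes should agree.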



- If the prior $\mathsf{p_{\theta_z}}$ is also a multivariate normal distribution defined by $\eta_0$, the update for $q$ is closed form and obeys:

$$\eta_q^{(n)} = \frac{1}{J+1} \left( \eta_0 + \sum_{j=1}^J \left( \eta_j^{(n)} + \tilde{\eta}_j^{(n)} \right) \right)$$
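A minimal sketch of this update is given below. It assumes that each set of natural parameters is stored as one flat vector of size $P$ (first and second natural parameters stacked), which is a storage choice made here for illustration; `update_q` is a hypothetical name.

```python
import numpy as np

def update_q(eta0, eta_factors, eta_tilde):
    """Closed-form update of the variational natural parameters.

    eta0:        (P,)      prior natural parameters.
    eta_factors: (J, N, P) recognition factor parameters eta_j^(n).
    eta_tilde:   (J, N, P) auxiliary parameters eta_tilde_j^(n).
    Returns an (N, P) array holding eta_q^(n) for every sample n.
    """
    num_factors = eta_factors.shape[0]
    return (eta0 + (eta_factors + eta_tilde).sum(axis=0)) / (num_factors + 1)

# Arbitrary numbers: only the arithmetic of the update is illustrated here.
rng = np.random.default_rng(0)
J, N, P = 2, 3, 4
eta0 = rng.normal(size=P)
eta_factors = rng.normal(size=(J, N, P))
eta_f = rng.normal(size=(J, P))  # stand-in for the moment-matched eta_Fj

# If eta_tilde = eta_q - eta_Fj, solving the update for eta_q gives
# eta_q = eta0 + sum_j (eta_j - eta_Fj); this is a fixed point of update_q.
eta_q_star = eta0 + (eta_factors - eta_f[:, None, :]).sum(axis=0)
eta_tilde = eta_q_star[None, :, :] - eta_f[:, None, :]
eta_q = update_q(eta0, eta_factors, eta_tilde)
```

The fixed-point expression in the comments follows from substituting $\tilde{\eta}_j^{(n)} = \eta_q^{(n)} - \eta_{\mathsf{Fj}}$ into the update; it is plain algebra on the two formulas rather than a result stated elsewhere in this notebook.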

- A more general scenario that I am currently testing uses a mixture-of-Gaussians prior, $\mathsf{p_{\theta_z}} = \sum_u \omega_u \mathsf{p_u}$, where each $\mathsf{p_u}$ is a multivariate Gaussian with natural parameter $\eta_u$. In this case, we can use another variational approximation to the KL divergence between the variational distribution and the prior. It yields the ELLLBO (or Evidence Lower Lower Lower Bound):

$$\tilde{\tilde{\mathcal{F}}} = \sum_n \log \beta^{(n)} - \sum_{jn} \mathsf{KL}(q^{(n)} || \mathsf{\hat{f}_{\theta j}}( \cdot | \mathsf{x_{j}^{(n)}})) + \sum_{nj} \log \Gamma_j^n$$ 

where the last two terms are unchanged and $\beta^{(n)} = \sum_u \omega_u e^{- \mathsf{KL}(q^{(n)} || \mathsf{p_u})}$. This prior seems appropriate when the underlying prior is multimodal, or badly described by a Gaussian distribution. For initialization, we try to space the centroids evenly on a unit hypersphere to avoid them collapsing early. **Note**: One might consider using a similar bound on Mixture Parametrised Recognition factors.
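The term $\log \beta^{(n)}$ can be evaluated stably with the same log-sum-exp trick. The sketch below works in (mean, covariance) space rather than natural parameters, purely for readability; `gaussian_kl` and `log_beta` are illustrative names, and the two-component prior is a made-up example.

```python
import numpy as np

def gaussian_kl(mean_q, cov_q, mean_p, cov_p):
    """KL(q || p) between two multivariate normal distributions."""
    dim = mean_q.shape[0]
    precision_p = np.linalg.inv(cov_p)
    diff = mean_p - mean_q
    _, logdet_q = np.linalg.slogdet(cov_q)
    _, logdet_p = np.linalg.slogdet(cov_p)
    return 0.5 * (
        np.trace(precision_p @ cov_q)
        + diff @ precision_p @ diff
        - dim
        + logdet_p
        - logdet_q
    )

def log_beta(mean_q, cov_q, weights, means_p, covs_p):
    """log beta = log sum_u w_u exp(-KL(q || p_u)), computed stably."""
    terms = np.array([
        np.log(w) - gaussian_kl(mean_q, cov_q, m, c)
        for w, m, c in zip(weights, means_p, covs_p)
    ])
    top = terms.max()
    return top + np.log(np.exp(terms - top).sum())

# Two symmetric unit-covariance components around a standard normal q.
mean_q = np.zeros(2)
cov_q = np.eye(2)
weights = np.array([0.5, 0.5])
means_p = [np.array([1.0, 0.0]), np.array([-1.0, 0.0])]
covs_p = [np.eye(2), np.eye(2)]
value = log_beta(mean_q, cov_q, weights, means_p, covs_p)
```

In this example each component has a KL of 0.5 from $q$, so $\log \beta = -0.5$. Since every KL is non-negative and the weights sum to one, $\log \beta^{(n)}$ is never positive.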

- When the variational update is not closed form, one solution is to parametrize the variational distribution with a fusion neural network that takes the factors $\{ x_j \}$ of each observation as input. Yet, this seems unnecessarily complicated. Although the update for $q^{(n)}$ is no longer closed form, differentiating the ELLLBO with respect to $\eta_q^{(n)}$ shows that it should only depend on the recognition factor embeddings. We therefore propose to parametrize the variational distribution with a simple MultiLayer Perceptron (MLP) whose input is the concatenation of the factors' natural parameters (see figure below). Finally, as initialization empirically proved to be key, we pretrain the variational recognition network to output the average (or similar) of the natural parameters of the $J$ factors.





<img src="figures/tmp.png" style="height:300px">
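To make the pretraining idea concrete, here is a small NumPy sketch of an MLP that takes the concatenated natural parameters of the $J$ factors and is pretrained to output their average. Everything in it is an assumption made for illustration: the one-hidden-layer tanh architecture, the plain gradient descent loop, the learning rate, and the flat-vector storage of natural parameters. In practice the second natural parameter would also need a constrained parametrisation to remain valid, which this sketch ignores.

```python
import numpy as np

def init_mlp(rng, dim_in, dim_hidden, dim_out):
    """Random initial weights for a one-hidden-layer MLP."""
    return {
        "W1": rng.normal(0.0, 1.0 / np.sqrt(dim_in), size=(dim_in, dim_hidden)),
        "b1": np.zeros(dim_hidden),
        "W2": rng.normal(0.0, 1.0 / np.sqrt(dim_hidden), size=(dim_hidden, dim_out)),
        "b2": np.zeros(dim_out),
    }

def forward(params, inputs):
    """Return the network output and the hidden activations."""
    hidden = np.tanh(inputs @ params["W1"] + params["b1"])
    return hidden @ params["W2"] + params["b2"], hidden

def pretrain_step(params, inputs, targets, lr=0.05):
    """One gradient descent step on the mean squared error; returns the loss."""
    outputs, hidden = forward(params, inputs)
    error = outputs - targets
    loss = np.mean(error ** 2)
    grad_out = 2.0 * error / error.size
    grad_hidden = (grad_out @ params["W2"].T) * (1.0 - hidden ** 2)
    params["W2"] -= lr * hidden.T @ grad_out
    params["b2"] -= lr * grad_out.sum(axis=0)
    params["W1"] -= lr * inputs.T @ grad_hidden
    params["b1"] -= lr * grad_hidden.sum(axis=0)
    return loss

# Toy stand-in for the factors' natural parameters: N samples, J factors, P values each.
rng = np.random.default_rng(0)
N, J, P = 200, 3, 2
eta_factors = rng.normal(size=(N, J, P))
inputs = eta_factors.reshape(N, J * P)   # concatenation over factors
targets = eta_factors.mean(axis=1)       # average natural parameter over factors

params = init_mlp(rng, dim_in=J * P, dim_hidden=16, dim_out=P)
losses = [pretrain_step(params, inputs, targets) for _ in range(500)]
```

After pretraining, the same network would be trained further on the ELLLBO itself; the averaging target only serves as an initialization.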


## 2. Results