# TP3 - *Latent Dirichlet Allocation* et Inférence variationnelle 

### Estimation avancée - G3 SDIA

Dans ce TP, on s'intéresse à la méthode "inférence variationnelle" (VI) qui permet d'approcher la loi a posteriori d'un modèle (généralement inconnue) par une autre loi plus simple (généralement un produit de lois bien connues). Nous allons l'appliquer à un modèle probabiliste pour des données textuelles, appelé *Latent Dirichlet Allocation* (LDA, qui n'a rien à voir avec la LDA *Linear Discriminant Analysis* du cours de ML).

### Instructions

1. Renommer votre notebook sous la forme `tp3_Nom1_Nom2.ipynb`, et inclure le nom du binôme dans le notebook. 

2. Votre code, ainsi que toute sortie du code, doivent être commentés !

3. Déposer votre notebook sur Moodle dans la section prévue à cet effet avant la date limite : 23 Décembre 2023, 23h59.

In [82]:
import numpy as np
from matplotlib import pyplot as plt
import pickle as pkl

## Partie 0 - Introduction

LDA is a popular probabilistic model for text data, introducted in [Blei et al. (2003)](https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf). In this model, the posterior distribution is intractable, and we choose to resort to variational inference (note that a Gibbs sampler would be feasible as well, but would be very slow). In particular, the CAVI updates can be easily derived.

In a few words, in LDA, each document is a mixture of topics, and each topic is a mixture of words. Uncovering those is the goal of *topic modeling*, and this is what we are going to do today. We will be using a collection of abstracts of papers published in JMLR (*Journal of Machine Learning Research*), one of the most prominent journals of the field.

**Check the .pdf file describing the model.**
The posterior is :
$$p(\boldsymbol{\beta}, \boldsymbol{\theta}, \mathbf{z} | \mathcal{D}),$$
which we are going to approximate in the following way :
$$\simeq \left[ \prod_{k=1}^K q(\beta_k) \right] \left[ \prod_{d=1}^D q(\theta_d) \right] \left[ \prod_{d=1}^D \prod_{n=1}^{N_d} q(z_{dn}) \right], $$
with :
* $q(\beta_k)$ a Dirichlet distribution (of size V) with parameter $[\lambda_{k1}, ...,\lambda_{kV}]$
* $q(\gamma_d)$ a Dirichlet distribution (of size K) with parameter $[\gamma_{d1}, ...,\gamma_{dK}]$
* $q(z_{dn})$ a Multinomial distribution (of size K) with parameter $[\phi_{dn1}, ..., \phi_{dnK}]$

The updates are as follows :
* $$\lambda_{kv} = \eta + \sum_{d=1}^D \sum_{n=1}^{N_d} w_{dnv} \phi_{dnk} $$
* $$\gamma_{dk} = \alpha + \sum_{n=1}^{N_d} \phi_{dnk}$$
* $$ \phi_{dnk} \propto \exp \left( \Psi(\gamma_{dk}) + \Psi(\lambda_{k, w_{dn}}) - \Psi(\sum_{v=1}^V \lambda_{kv}) \right)$$

$\Psi$ is the digamma function, use `scipy.special.digamma`.

## Partie 1 - Les données

The data is already prepared, see code below. We have a total of 1898 abstracts.

In [83]:
jmlr_papers = pkl.load(open("jmlr.pkl","rb")) 

**Q1.** Fill in a list of keywords from the course, to see how many papers are about probabilistic ML.

In [84]:
bayesian_jmlr_papers = []

for paper in jmlr_papers:
    bayesian_keywords = ['Bayesian', 'Bayes', 'Gaussian process', 'Markov chain Monte Carlo', 'MCMC', 'Variational inference']
    if any([kwd in paper["abstract"] for kwd in bayesian_keywords]):
        bayesian_jmlr_papers.append(paper)
        
print("There are", str(len(bayesian_jmlr_papers))+" Bayesian papers out of", str(len(jmlr_papers)))

There are 314 Bayesian papers out of 1898


Let us now preprocess the data. It is important to remove so-called "stop-words" like a, is, but, the, of, have... Scikit-learn will do the job for us. We will keep only the top-1000 words from the abstracts.

As a result, we get the count matrix $\mathbf{C}$ of size $D = 1898 \times V = 1000$. $c_{dv}$ is the number of occurrences of word $v$ in document $d$. This compact representation is called "bag-of-words". Of course from $\mathbf{C}$ you easily recover the words, since in LDA the order does not matter.

In [85]:
from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer(max_features = 1000, stop_words='english')
X = vectorizer.fit_transform([paper["abstract"] for paper in jmlr_papers])
print(vectorizer.get_feature_names_out()) # Top-1000 words
C = X.toarray() # Count matrix

# Removing documents with 0 words
idx = np.where(np.sum(C, axis = 1)==0)
C = np.delete(C, idx, axis = 0)

['100' '16' '17' '18' '949' '_blank' 'ability' 'able' 'abs' 'according'
 'account' 'accuracy' 'accurate' 'achieve' 'achieved' 'achieves' 'action'
 'actions' 'active' 'adaboost' 'adaptive' 'addition' 'additional'
 'additive' 'address' 'advantage' 'advantages' 'agent' 'aggregation' 'al'
 'algorithm' 'algorithmic' 'algorithms' 'allow' 'allowing' 'allows'
 'alternative' 'analysis' 'analyze' 'applicable' 'application'
 'applications' 'applied' 'apply' 'applying' 'approach' 'approaches'
 'appropriate' 'approximate' 'approximately' 'approximation'
 'approximations' 'arbitrary' 'art' 'article' 'artificial' 'associated'
 'assume' 'assumed' 'assumption' 'assumptions' 'asymptotic'
 'asymptotically' 'attributes' 'available' 'average' 'averaging' 'bandit'
 'base' 'based' 'basic' 'basis' 'batch' 'bayes' 'bayesian' 'behavior'
 'belief' 'benchmark' 'best' 'better' 'bias' 'bib' 'binary' 'block'
 'boosting' 'bound' 'bounded' 'bounds' 'br' 'build' 'building' 'called'
 'capture' 'carlo' 'case' 'cases' 'ca

**Q2.** How many elements of $\mathbf{C}$ are non-zero ? Is this surprising ?

In [86]:
#count the number of non-zero in C
print("There are", str(np.count_nonzero(C))+" non-zero entries in C")
print("It represents {:.2f}% of the matrix".format(np.count_nonzero(C)/C.size*100))

There are 85804 non-zero entries in C
It represents 4.53% of the matrix


> This sparsity results from the large, diverse vocabulary and the specific usage of words in academic texts. The bag-of-words model, focusing on word frequencies and the removal of common stop words, further contributes to this sparsity. The wide range of topics in JMLR leads to varied word usage across documents, accentuating the sparsity in the matrix.

## Partie 2 - Inférence variationnelle

> As you know from the lecture, VI aims at maximizing the ELBO. I have prepared for you the function to compute the ELBO.

In [87]:
from scipy.special import digamma, loggamma

def ELBO(L, G, phi, a, e, W):
    # Computes the ELBO with the values of the parameters L (Lambda), G (Gamma), and Phi
    # a, e are hyperparameters (alpha and eta)
    # W are the words (obsereved)
    
    # L - K x V matrix (variational parameters Lambda)
    # G - D x K matrix (variational parameters Gamma)
    # phi - List of D elements, each element is a Nd x K matrix (variational parameters Phi)
    # a - Scalar > 0 (hyperparameter alpha)
    # e - Scalar > 0 (hyperparameter eta)
    # W - List of D elements, each element is a Nd x V matrix (observed words)
    
    e_log_B = (digamma(L).T - digamma(np.sum(L, axis = 1))).T
    e_log_T = (digamma(G).T - digamma(np.sum(G, axis = 1))).T
    D,K=G.shape
    V=L.shape[1]
    t1 = (e-1)*np.sum(e_log_B)
    t2 = (a-1)*np.sum(e_log_T)

    phi_s = np.zeros((D,K))
    for d in range(0,D):
        phi_s[d,:] = np.sum(phi[d], axis = 0)
    t3 = np.sum(e_log_T*phi_s)
    
    tmp = np.zeros((K,V))
    for d in range(0,D):
        tmp = tmp + np.dot(phi[d].T, W[d])
    t4 = np.sum(e_log_B*tmp)
    
    t5 = np.sum(loggamma(np.sum(L, axis = 1))) - np.sum(loggamma(L)) + np.sum((L-1)*e_log_B)
    t6 = np.sum(loggamma(np.sum(G, axis = 1))) - np.sum(loggamma(G)) + np.sum((G-1)*e_log_T)

    t7 = 0
    for d in range(0,D):
        t7 = t7 + np.sum(phi[d]*np.log(phi[d] + np.spacing(1)))

    return t1 + t2 + t3 + t4 - t5 - t6 - t7

**Q1.** Transform the matrix $\mathbf{C}$ into the observed words $\mathbf{w}$. $\mathbf{w}$ should be a list of $D$ elements, each element of the list being a $N_d \times V$ matrix.

In [88]:
#transform  the matrix C into the observed words w
def transform_C_to_W(C):
    """
    Transform a word count matrix C into a list of word matrices w.

    :param C: a DxV matrix where C[d, v] is the count of word v in document d.
    :return: A list of D elements, each a N_d x V matrix representing the words in document d.
    """
    D, V = C.shape
    w = []

    for d in range(D):
        N_d = np.sum(C[d, :]) # Total number of words in document d
        word_matrix = np.zeros((N_d, V), dtype=int) # Initialize the word matrix for document d

        # Fill in the word occurrences
        word_idx = 0
        for v in range(V):
            count = C[d, v]
            for _ in range(count):
                word_matrix[word_idx, v] = 1
                word_idx += 1

        w.append(word_matrix) # Append the word matrix to the list
    return w

w = transform_C_to_W(C)
# Display the resulting word matrices
print(w)

def percentage_zeros_w(w):
    total_elements = sum(matrix.size for matrix in w)
    total_zeros = sum(np.count_nonzero(matrix == 0) for matrix in w)
    return (total_zeros / total_elements) * 100

print(percentage_zeros_w(w))


## Check 
print("On vérifie que la 1ère ligne de C est égale à la somme de la 1ere ligne de W : ", np.sum(C[0]) ==  np.sum(np.sum(w[0],axis=1)), ", et vaut :",np.sum(C[0]) )

def verify_transformation(C, W):
    # Check if the number of documents matches
    if len(C) != len(W):
        return False, "Number of documents does not match."

    for doc_idx, (doc_c, doc_w) in enumerate(zip(C, W)):
        # Total word count in the document from C
        total_words_c = np.sum(doc_c)

        # Total word count in the document from W
        total_words_w = doc_w.shape[0]

        # Check if total word counts match
        if total_words_c != total_words_w:
            return False, f"Total word count mismatch in document {doc_idx}."

        # Check if word frequencies match
        for word_idx, word_count in enumerate(doc_c):
            if word_count != np.sum(doc_w[:, word_idx]):
                return False, f"Word frequency mismatch in document {doc_idx} for word index {word_idx}."

    return True, "Transformation is correct."

# Verify the transformation
print(verify_transformation(C, w))


[array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]]), array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]]), array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]]), array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]]), array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
   

**Q2.** Implement the CAVI algorithm. The updates are given at the beginning of the notebook. Monitor the convergence with the values of the ELBO (but start with a fixed number of iterations, like 50).

In [110]:
from tqdm import tqdm
def CAVI(W, K, a, e, seed):
    np.random.seed(seed)  # Set random seed for reproducibility
    D = len(W)  # Number of documents
    V = W[0].shape[1]  # Vocabulary size

    # Initialize variational parameters
    L = np.random.rand(K, V)   # Lambda - K x V
    G = 1/K*np.ones((D,K))   # Gamma - D x K
    phi = [np.zeros((W[d].shape[0],K)) for d in range(0,D)]  # Phi - List of D Nd x K matrices
    elbo_list=[]
    max_iter = 50  # Set a fixed number of iterations

    for i in tqdm(range(max_iter)):
        for d in range(D):
            Nd = W[d].shape[0]
            for n in range(Nd):
                phi_update = digamma(G[d,:]) - digamma(np.sum(L, axis = 1)) +digamma(L[:,np.where(W[d][n,:]==1)[0]].T) 
                phi[d][n, :] = np.exp(phi_update)
                phi[d][n, :] /= np.sum(phi[d][n, :])  # Normalize
                
            G[d, :] =a + np.sum(phi[d], axis = 0)
                
        for k in range(K):
            L[k, :] = e + np.sum([np.dot(W[d].T, phi[d][:,k]) for d in range(0,D)], axis = 0)
        elbo=ELBO(L, G, phi, a, e, W)
        elbo_list.append(elbo)
        print(elbo)
    return L, G, phi, elbo_list

**Q3.** Run the algorithm with $K = 10$, $\alpha = 0.5$, $\eta = 0.1$. From the results, compute the MMSE of $\lambda_{kv}$ and $\gamma_{dk}$.

**Bonus** : Re-run the algorithm several times with different initializations, and keep the solution which returns the highest ELBO.

NB : In my implementation, one iteration of the CAVI algorithm takes about 4 seconds to run.

In [111]:
K,alpha,eta = 10,0.5,0.1
seed =42
# Running the CAVI algorithm with the corrected implementation
lambda_mmse, gamma_mmse, phi_mmse,elbo_list = CAVI(w, K, alpha, eta, seed)

# Displaying a portion of the MMSE estimates of lambda and gamma
lambda_mmse[:5], gamma_mmse[:5]  # Showing the first 5 rows for brevity 

  2%|▏         | 1/50 [00:04<04:04,  5.00s/it]

-778582.1597238141


  4%|▍         | 2/50 [00:09<03:55,  4.91s/it]

-775316.0050709111


  6%|▌         | 3/50 [00:14<03:50,  4.90s/it]

-772421.2164063791


  8%|▊         | 4/50 [00:19<03:45,  4.90s/it]

-769628.308051209


 10%|█         | 5/50 [00:24<03:39,  4.88s/it]

-766859.1886799146


 12%|█▏        | 6/50 [00:29<03:35,  4.89s/it]

-764084.2335513106


 14%|█▍        | 7/50 [00:34<03:32,  4.94s/it]

-761332.9579953447


 16%|█▌        | 8/50 [00:39<03:29,  4.98s/it]

-758724.123984636


 18%|█▊        | 9/50 [00:44<03:23,  4.97s/it]

-756378.9858813848


 20%|██        | 10/50 [00:49<03:18,  4.96s/it]

-754321.4105888437


 22%|██▏       | 11/50 [00:54<03:13,  4.95s/it]

-752508.7671648411


 24%|██▍       | 12/50 [00:59<03:07,  4.93s/it]

-750890.3875742368


 26%|██▌       | 13/50 [01:04<03:01,  4.91s/it]

-749426.4439893623


 28%|██▊       | 14/50 [01:08<02:54,  4.85s/it]

-748093.2362843428


 30%|███       | 15/50 [01:13<02:49,  4.84s/it]

-746873.4224224736


 32%|███▏      | 16/50 [01:18<02:42,  4.79s/it]

-745764.2349806747


 34%|███▍      | 17/50 [01:22<02:37,  4.77s/it]

-744769.8706113254


 36%|███▌      | 18/50 [01:27<02:31,  4.74s/it]

-743890.9091140557


 38%|███▊      | 19/50 [01:32<02:26,  4.72s/it]

-743114.2758480182


 40%|████      | 20/50 [01:37<02:21,  4.72s/it]

-742424.2393100639


 42%|████▏     | 21/50 [01:41<02:16,  4.71s/it]

-741806.6666872359


 44%|████▍     | 22/50 [01:46<02:11,  4.70s/it]

-741248.2442605817


 46%|████▌     | 23/50 [01:51<02:06,  4.70s/it]

-740742.6659992235


 48%|████▊     | 24/50 [01:55<02:02,  4.70s/it]

-740282.0063250811


 50%|█████     | 25/50 [02:00<01:59,  4.79s/it]

-739861.9424122045


 52%|█████▏    | 26/50 [02:05<01:55,  4.82s/it]

-739478.617999777


 54%|█████▍    | 27/50 [02:10<01:51,  4.85s/it]

-739125.0717355504


 56%|█████▌    | 28/50 [02:15<01:45,  4.80s/it]

-738806.2105113312


 58%|█████▊    | 29/50 [02:20<01:41,  4.81s/it]

-738512.8926595523


 60%|██████    | 30/50 [02:24<01:36,  4.81s/it]

-738241.5322691823


 62%|██████▏   | 31/50 [02:29<01:31,  4.83s/it]

-737994.1108405346


 64%|██████▍   | 32/50 [02:34<01:28,  4.93s/it]

-737766.9122430632


 66%|██████▌   | 33/50 [02:40<01:24,  4.99s/it]

-737561.487134031


 68%|██████▊   | 34/50 [02:44<01:19,  4.96s/it]

-737368.7178509962


 70%|███████   | 35/50 [02:50<01:14,  4.99s/it]

-737185.3253294305


 72%|███████▏  | 36/50 [02:55<01:12,  5.14s/it]

-737015.0211893513


 74%|███████▍  | 37/50 [03:00<01:07,  5.23s/it]

-736860.4711768219


 76%|███████▌  | 38/50 [03:06<01:02,  5.18s/it]

-736718.2373376116


 78%|███████▊  | 39/50 [03:11<00:56,  5.13s/it]

-736586.8308270946


 80%|████████  | 40/50 [03:15<00:50,  5.07s/it]

-736464.818119973


 82%|████████▏ | 41/50 [03:20<00:45,  5.03s/it]

-736352.9013066012


 84%|████████▍ | 42/50 [03:25<00:39,  5.00s/it]

-736247.0781103596


 86%|████████▌ | 43/50 [03:30<00:34,  4.97s/it]

-736145.3585255087


 88%|████████▊ | 44/50 [03:35<00:29,  4.94s/it]

-736046.364624216


 90%|█████████ | 45/50 [03:40<00:24,  4.93s/it]

-735953.1423321783


 92%|█████████▏| 46/50 [03:45<00:19,  4.91s/it]

-735867.5958820526


 94%|█████████▍| 47/50 [03:50<00:14,  4.90s/it]

-735785.9085668377


 96%|█████████▌| 48/50 [03:55<00:09,  4.89s/it]

-735707.5240694513


 98%|█████████▊| 49/50 [04:00<00:04,  4.89s/it]

-735636.4389549136


100%|██████████| 50/50 [04:04<00:00,  4.90s/it]

-735568.0198208147





(array([[ 0.10000671,  0.10000911,  0.1000069 , ...,  0.10003935,
         12.50768267,  0.10002145],
        [ 0.100006  ,  0.10001607,  0.10000588, ...,  0.10002878,
          0.10002653,  0.10004146],
        [ 0.10000977,  0.10002698,  0.10000628, ...,  0.10003936,
          0.10004494, 17.8440665 ],
        [ 0.10000761, 46.09978008,  0.10000819, ...,  0.10002035,
         26.25195678,  0.10003497],
        [ 0.10000679,  0.10001357,  0.10000483, ...,  0.10002751,
         23.54014278, 35.72820022]]),
 array([[ 5.49439474,  0.61086968, 12.20792458,  3.99961019,  1.10895762,
         20.00222314,  0.75291176,  4.93088198,  3.67395101, 10.21827529],
        [ 4.33471045,  0.6298793 ,  2.91261609,  2.44727937,  1.45309445,
         21.35867972,  1.51821376,  0.5548628 ,  3.34860355,  2.44206052],
        [10.83092077, 14.84835561,  0.63879722,  0.58085656,  2.25417587,
          4.94205685,  3.82845649, 12.0077235 ,  0.62320005,  9.44545709],
        [ 0.80221874,  7.81185834,  5.895

In [113]:
#try with different initialisation
# Running the CAVI algorithm with the corrected implementation
lambda_mmse, gamma_mmse, phi_mmse,elbo_list2 = CAVI(w, 20, alpha, eta, seed)
lambda_mmse, gamma_mmse, phi_mmse,elbo_list3 = CAVI(w, 5, alpha, eta, seed)
lambda_mmse, gamma_mmse, phi_mmse,elbo_list4 = CAVI(w, K, 0.1, eta, seed)
lambda_mmse, gamma_mmse, phi_mmse,elbo_list5 = CAVI(w, K, alpha, 0.5, seed)



  2%|▏         | 1/50 [00:07<05:47,  7.10s/it]

-798853.9884589749


  4%|▍         | 2/50 [00:14<05:54,  7.39s/it]

-794644.8247014575


  6%|▌         | 3/50 [00:21<05:42,  7.28s/it]

-790523.4123153432


  8%|▊         | 4/50 [00:29<05:33,  7.25s/it]

-786158.9721163716


 10%|█         | 5/50 [00:36<05:23,  7.19s/it]

-781363.7576828237


 12%|█▏        | 6/50 [00:43<05:14,  7.15s/it]

-776127.139711158


 14%|█▍        | 7/50 [00:50<05:05,  7.10s/it]

-770595.7889834435


 16%|█▌        | 8/50 [00:56<04:53,  6.99s/it]

-765127.808622489


 18%|█▊        | 9/50 [01:03<04:45,  6.96s/it]

-760098.3633800854


 20%|██        | 10/50 [01:10<04:36,  6.92s/it]

-755705.5402324477


 22%|██▏       | 11/50 [01:17<04:30,  6.94s/it]

-751917.002540249


 24%|██▍       | 12/50 [01:24<04:25,  7.00s/it]

-748636.0893190008


 26%|██▌       | 13/50 [01:31<04:17,  6.95s/it]

-745800.791804696


 28%|██▊       | 14/50 [01:38<04:12,  7.02s/it]

-743350.0961535436


 30%|███       | 15/50 [01:45<04:05,  7.02s/it]

-741238.6672259733


 32%|███▏      | 16/50 [01:52<03:58,  7.01s/it]

-739414.4393646042


 34%|███▍      | 17/50 [01:59<03:50,  6.99s/it]

-737828.0476254037


 36%|███▌      | 18/50 [02:06<03:42,  6.96s/it]

-736459.8126338995


 38%|███▊      | 19/50 [02:13<03:38,  7.05s/it]

-735267.8840879096


 40%|████      | 20/50 [02:20<03:31,  7.05s/it]

-734221.2191466991


 42%|████▏     | 21/50 [02:28<03:25,  7.07s/it]

-733295.926452711


 44%|████▍     | 22/50 [02:34<03:16,  7.01s/it]

-732480.9340625047


 46%|████▌     | 23/50 [02:42<03:09,  7.03s/it]

-731751.6572516039


 48%|████▊     | 24/50 [02:48<03:01,  6.98s/it]

-731089.3528594866


 50%|█████     | 25/50 [02:55<02:53,  6.95s/it]

-730486.0798571047


 52%|█████▏    | 26/50 [03:02<02:45,  6.89s/it]

-729933.4829604949


 54%|█████▍    | 27/50 [03:09<02:38,  6.90s/it]

-729422.9148654633


 56%|█████▌    | 28/50 [03:16<02:32,  6.92s/it]

-728954.1041230041


 58%|█████▊    | 29/50 [03:23<02:24,  6.89s/it]

-728533.2356026773


 60%|██████    | 30/50 [03:30<02:17,  6.88s/it]

-728157.7818137208


 62%|██████▏   | 31/50 [03:36<02:10,  6.85s/it]

-727818.6301248235


 64%|██████▍   | 32/50 [03:43<02:04,  6.89s/it]

-727507.1140225725


 66%|██████▌   | 33/50 [03:51<01:59,  7.01s/it]

-727219.6312024102


 68%|██████▊   | 34/50 [03:58<01:52,  7.00s/it]

-726954.7409856968


 70%|███████   | 35/50 [04:05<01:45,  7.02s/it]

-726708.5514571203


 72%|███████▏  | 36/50 [04:12<01:38,  7.05s/it]

-726477.1140850516


 74%|███████▍  | 37/50 [04:19<01:32,  7.11s/it]

-726262.8685524805


 76%|███████▌  | 38/50 [04:26<01:25,  7.09s/it]

-726063.6390829406


 78%|███████▊  | 39/50 [04:33<01:18,  7.15s/it]

-725877.6290066445


 80%|████████  | 40/50 [04:40<01:10,  7.07s/it]

-725701.6632513214


 82%|████████▏ | 41/50 [04:47<01:02,  6.99s/it]

-725536.3618576352


 84%|████████▍ | 42/50 [04:54<00:55,  6.92s/it]

-725379.9479251562


 86%|████████▌ | 43/50 [05:01<00:48,  6.88s/it]

-725236.3880998986


 88%|████████▊ | 44/50 [05:07<00:41,  6.85s/it]

-725107.8594191629


 90%|█████████ | 45/50 [05:14<00:34,  6.83s/it]

-724988.2924201154


 92%|█████████▏| 46/50 [05:21<00:27,  6.81s/it]

-724876.9592549291


 94%|█████████▍| 47/50 [05:28<00:20,  6.80s/it]

-724770.1188180147


 96%|█████████▌| 48/50 [05:35<00:13,  6.79s/it]

-724668.412462427


 98%|█████████▊| 49/50 [05:41<00:06,  6.80s/it]

-724569.2149325744


100%|██████████| 50/50 [05:48<00:00,  6.97s/it]


-724474.7455634095


  2%|▏         | 1/50 [00:03<03:14,  3.97s/it]

-767608.2321981466


  4%|▍         | 2/50 [00:07<03:09,  3.95s/it]

-765377.7674529199


  6%|▌         | 3/50 [00:11<03:05,  3.94s/it]

-763882.1919322435


  8%|▊         | 4/50 [00:15<03:01,  3.94s/it]

-762705.3013176785


 10%|█         | 5/50 [00:19<02:57,  3.95s/it]

-761673.3841480523


 12%|█▏        | 6/50 [00:23<02:54,  3.97s/it]

-760698.8055510157


 14%|█▍        | 7/50 [00:27<02:51,  3.98s/it]

-759752.8425229634


 16%|█▌        | 8/50 [00:31<02:46,  3.96s/it]

-758847.1382295989


 18%|█▊        | 9/50 [00:35<02:42,  3.97s/it]

-758004.2006120846


 20%|██        | 10/50 [00:39<02:38,  3.95s/it]

-757220.1400255265


 22%|██▏       | 11/50 [00:43<02:33,  3.94s/it]

-756491.0974946391


 24%|██▍       | 12/50 [00:47<02:28,  3.91s/it]

-755816.8808793215


 26%|██▌       | 13/50 [00:51<02:23,  3.89s/it]

-755184.068866666


 28%|██▊       | 14/50 [00:55<02:19,  3.88s/it]

-754594.0394474141


 30%|███       | 15/50 [00:58<02:15,  3.87s/it]

-754050.6300649125


 32%|███▏      | 16/50 [01:02<02:11,  3.88s/it]

-753548.2850973175


 34%|███▍      | 17/50 [01:06<02:08,  3.89s/it]

-753089.2214372295


 36%|███▌      | 18/50 [01:10<02:04,  3.88s/it]

-752668.1938240724


 38%|███▊      | 19/50 [01:14<02:00,  3.87s/it]

-752276.413894375


 40%|████      | 20/50 [01:18<01:56,  3.87s/it]

-751914.4932278455


 42%|████▏     | 21/50 [01:22<01:52,  3.86s/it]

-751579.4992057242


 44%|████▍     | 22/50 [01:25<01:47,  3.86s/it]

-751270.4925748648


 46%|████▌     | 23/50 [01:29<01:44,  3.88s/it]

-750986.0869534389


 48%|████▊     | 24/50 [01:33<01:41,  3.90s/it]

-750721.2707480454


 50%|█████     | 25/50 [01:37<01:37,  3.91s/it]

-750477.5378257531


 52%|█████▏    | 26/50 [01:41<01:33,  3.91s/it]

-750250.6300165425


 54%|█████▍    | 27/50 [01:45<01:29,  3.91s/it]

-750040.015889852


 56%|█████▌    | 28/50 [01:49<01:26,  3.92s/it]

-749843.5246250454


 58%|█████▊    | 29/50 [01:53<01:22,  3.93s/it]

-749660.7535513127


 60%|██████    | 30/50 [01:57<01:18,  3.93s/it]

-749489.5999544943


 62%|██████▏   | 31/50 [02:01<01:14,  3.93s/it]

-749325.499486087


 64%|██████▍   | 32/50 [02:05<01:10,  3.94s/it]

-749167.7054239266


 66%|██████▌   | 33/50 [02:09<01:07,  3.95s/it]

-749016.495946375


 68%|██████▊   | 34/50 [02:13<01:03,  3.95s/it]

-748872.2808968942


 70%|███████   | 35/50 [02:17<00:59,  3.95s/it]

-748736.189313847


 72%|███████▏  | 36/50 [02:21<00:55,  3.94s/it]

-748605.6714126664


 74%|███████▍  | 37/50 [02:25<00:51,  3.94s/it]

-748480.0486681297


 76%|███████▌  | 38/50 [02:28<00:47,  3.94s/it]

-748360.4883907482


 78%|███████▊  | 39/50 [02:32<00:43,  3.93s/it]

-748246.8692874499


 80%|████████  | 40/50 [02:36<00:39,  3.94s/it]

-748138.7204182569


 82%|████████▏ | 41/50 [02:40<00:35,  3.94s/it]

-748038.3162581564


 84%|████████▍ | 42/50 [02:44<00:31,  3.93s/it]

-747943.2430506564


 86%|████████▌ | 43/50 [02:48<00:27,  3.95s/it]

-747854.4657955314


 88%|████████▊ | 44/50 [02:52<00:23,  4.00s/it]

-747769.1778148944


 90%|█████████ | 45/50 [02:56<00:19,  3.99s/it]

-747683.9387292953


 92%|█████████▏| 46/50 [03:00<00:15,  3.98s/it]

-747601.085343063


 94%|█████████▍| 47/50 [03:04<00:11,  3.98s/it]

-747522.3393825312


 96%|█████████▌| 48/50 [03:08<00:07,  3.97s/it]

-747450.4464701323


 98%|█████████▊| 49/50 [03:12<00:03,  3.97s/it]

-747381.6747788048


100%|██████████| 50/50 [03:16<00:00,  3.93s/it]


-747315.2178497462


  2%|▏         | 1/50 [00:04<04:04,  5.00s/it]

-760350.1340762714


  4%|▍         | 2/50 [00:10<04:02,  5.05s/it]

-756526.7504435282


  6%|▌         | 3/50 [00:15<03:58,  5.08s/it]

-752725.0548325927


  8%|▊         | 4/50 [00:20<03:57,  5.15s/it]

-748704.889593527


 10%|█         | 5/50 [00:25<03:51,  5.15s/it]

-744459.5342834567


 12%|█▏        | 6/50 [00:30<03:46,  5.14s/it]

-740088.5070222928


 14%|█▍        | 7/50 [00:35<03:40,  5.14s/it]

-735774.7302416003


 16%|█▌        | 8/50 [00:40<03:35,  5.13s/it]

-731773.1170427885


 18%|█▊        | 9/50 [00:46<03:30,  5.14s/it]

-728278.06189267


 20%|██        | 10/50 [00:51<03:26,  5.15s/it]

-725304.8591369068


 22%|██▏       | 11/50 [00:56<03:21,  5.16s/it]

-722784.6913446302


 24%|██▍       | 12/50 [01:01<03:15,  5.15s/it]

-720637.870259904


 26%|██▌       | 13/50 [01:06<03:10,  5.15s/it]

-718788.4756624596


 28%|██▊       | 14/50 [01:11<03:05,  5.15s/it]

-717176.5219361366


 30%|███       | 15/50 [01:17<03:00,  5.15s/it]

-715779.6395930286


 32%|███▏      | 16/50 [01:22<02:55,  5.15s/it]

-714588.962145405


 34%|███▍      | 17/50 [01:27<02:49,  5.15s/it]

-713567.8547203878


 36%|███▌      | 18/50 [01:32<02:44,  5.14s/it]

-712675.117755897


 38%|███▊      | 19/50 [01:37<02:39,  5.16s/it]

-711894.8968258689


 40%|████      | 20/50 [01:42<02:34,  5.16s/it]

-711222.612797922


 42%|████▏     | 21/50 [01:47<02:28,  5.12s/it]

-710637.4228078998


 44%|████▍     | 22/50 [01:52<02:22,  5.09s/it]

-710131.3925845827


 46%|████▌     | 23/50 [01:57<02:16,  5.07s/it]

-709702.5256306367


 48%|████▊     | 24/50 [02:03<02:12,  5.08s/it]

-709325.731186501


 50%|█████     | 25/50 [02:08<02:07,  5.09s/it]

-708988.1328585112


 52%|█████▏    | 26/50 [02:13<02:02,  5.11s/it]

-708700.0335602491


 54%|█████▍    | 27/50 [02:18<01:57,  5.11s/it]

-708445.4311261054


 56%|█████▌    | 28/50 [02:23<01:54,  5.22s/it]

-708230.1211636482


 58%|█████▊    | 29/50 [02:29<01:52,  5.34s/it]

-708046.2163024768


 60%|██████    | 30/50 [02:34<01:45,  5.30s/it]

-707877.8433362263


 62%|██████▏   | 31/50 [02:39<01:39,  5.26s/it]

-707722.0355755333


 64%|██████▍   | 32/50 [02:44<01:33,  5.22s/it]

-707579.1062346586


 66%|██████▌   | 33/50 [02:50<01:28,  5.19s/it]

-707450.1760949215


 68%|██████▊   | 34/50 [02:55<01:22,  5.18s/it]

-707340.0243891897


 70%|███████   | 35/50 [03:00<01:17,  5.17s/it]

-707239.9602630842


 72%|███████▏  | 36/50 [03:05<01:12,  5.15s/it]

-707144.0410656402


 74%|███████▍  | 37/50 [03:10<01:06,  5.11s/it]

-707055.0321082624


 76%|███████▌  | 38/50 [03:15<01:01,  5.13s/it]

-706971.5322585755


 78%|███████▊  | 39/50 [03:20<00:56,  5.13s/it]

-706891.9549409356


 80%|████████  | 40/50 [03:25<00:51,  5.13s/it]

-706819.0480410432


 82%|████████▏ | 41/50 [03:31<00:46,  5.13s/it]

-706752.7595819781


 84%|████████▍ | 42/50 [03:36<00:41,  5.14s/it]

-706688.389962347


 86%|████████▌ | 43/50 [03:41<00:35,  5.14s/it]

-706626.0719966731


 88%|████████▊ | 44/50 [03:46<00:30,  5.14s/it]

-706570.2465405331


 90%|█████████ | 45/50 [03:51<00:25,  5.12s/it]

-706519.6811167971


 92%|█████████▏| 46/50 [03:56<00:20,  5.12s/it]

-706473.3542994676


 94%|█████████▍| 47/50 [04:01<00:15,  5.12s/it]

-706432.8744659355


 96%|█████████▌| 48/50 [04:07<00:10,  5.14s/it]

-706396.1969099827


 98%|█████████▊| 49/50 [04:12<00:05,  5.10s/it]

-706361.418810081


100%|██████████| 50/50 [04:17<00:00,  5.14s/it]


-706327.1441258165


  2%|▏         | 1/50 [00:05<04:12,  5.14s/it]

-813375.6647963491


  4%|▍         | 2/50 [00:10<04:06,  5.14s/it]

-810540.9587825945


  6%|▌         | 3/50 [00:15<04:04,  5.20s/it]

-808291.2897658278


  8%|▊         | 4/50 [00:20<03:59,  5.20s/it]

-806153.6045794714


 10%|█         | 5/50 [00:25<03:52,  5.18s/it]

-803985.0639737338


 12%|█▏        | 6/50 [00:31<03:48,  5.19s/it]

-801745.3278146258


 14%|█▍        | 7/50 [00:36<03:42,  5.18s/it]

-799461.1990185921


 16%|█▌        | 8/50 [00:41<03:37,  5.17s/it]

-797231.1490026374


 18%|█▊        | 9/50 [00:46<03:31,  5.16s/it]

-795185.149832079


 20%|██        | 10/50 [00:51<03:26,  5.16s/it]

-793364.6740703956


 22%|██▏       | 11/50 [00:56<03:22,  5.18s/it]

-791710.0272196694


 24%|██▍       | 12/50 [01:02<03:16,  5.18s/it]

-790158.4030438662


 26%|██▌       | 13/50 [01:07<03:11,  5.17s/it]

-788691.6229166107


 28%|██▊       | 14/50 [01:12<03:05,  5.16s/it]

-787322.9297071744


 30%|███       | 15/50 [01:17<03:01,  5.19s/it]

-786067.9143486462


 32%|███▏      | 16/50 [01:22<02:56,  5.18s/it]

-784926.7966194162


 34%|███▍      | 17/50 [01:27<02:50,  5.17s/it]

-783887.9444297304


 36%|███▌      | 18/50 [01:33<02:45,  5.17s/it]

-782936.6804473435


 38%|███▊      | 19/50 [01:38<02:40,  5.18s/it]

-782058.279173585


 40%|████      | 20/50 [01:43<02:35,  5.18s/it]

-781238.5469160123


 42%|████▏     | 21/50 [01:48<02:29,  5.17s/it]

-780466.2404947378


 44%|████▍     | 22/50 [01:53<02:24,  5.17s/it]

-779736.9671590722


 46%|████▌     | 23/50 [01:58<02:19,  5.16s/it]

-779053.8943695151


 48%|████▊     | 24/50 [02:04<02:14,  5.16s/it]

-778422.1831398056


 50%|█████     | 25/50 [02:09<02:08,  5.16s/it]

-777843.0307490318


 52%|█████▏    | 26/50 [02:14<02:03,  5.15s/it]

-777313.9218374353


 54%|█████▍    | 27/50 [02:19<01:58,  5.15s/it]

-776830.6944368859


 56%|█████▌    | 28/50 [02:24<01:53,  5.14s/it]

-776386.7602557649


 58%|█████▊    | 29/50 [02:29<01:48,  5.18s/it]

-775974.2726110601


 60%|██████    | 30/50 [02:35<01:43,  5.18s/it]

-775586.5678867722


 62%|██████▏   | 31/50 [02:40<01:38,  5.17s/it]

-775218.5714397385


 64%|██████▍   | 32/50 [02:45<01:32,  5.16s/it]

-774866.646014906


 66%|██████▌   | 33/50 [02:50<01:27,  5.17s/it]

-774528.509606466


 68%|██████▊   | 34/50 [02:55<01:22,  5.17s/it]

-774203.0424319927


 70%|███████   | 35/50 [03:00<01:17,  5.16s/it]

-773890.2576526124


 72%|███████▏  | 36/50 [03:05<01:12,  5.14s/it]

-773591.2386978132


 74%|███████▍  | 37/50 [03:11<01:06,  5.14s/it]

-773307.4259277761


 76%|███████▌  | 38/50 [03:16<01:01,  5.15s/it]

-773039.5248473041


 78%|███████▊  | 39/50 [03:21<00:56,  5.15s/it]

-772787.1398694022


 80%|████████  | 40/50 [03:26<00:51,  5.15s/it]

-772549.1307105881


 82%|████████▏ | 41/50 [03:31<00:46,  5.14s/it]

-772324.0534311627


 84%|████████▍ | 42/50 [03:36<00:40,  5.12s/it]

-772110.4768902421


 86%|████████▌ | 43/50 [03:41<00:36,  5.15s/it]

-771907.0350898309


 88%|████████▊ | 44/50 [03:47<00:30,  5.15s/it]

-771712.418406833


 90%|█████████ | 45/50 [03:52<00:25,  5.12s/it]

-771525.6739968936


 92%|█████████▏| 46/50 [03:57<00:20,  5.12s/it]

-771346.622076085


 94%|█████████▍| 47/50 [04:02<00:15,  5.18s/it]

-771175.752180179


 96%|█████████▌| 48/50 [04:07<00:10,  5.14s/it]

-771013.736077374


 98%|█████████▊| 49/50 [04:12<00:05,  5.15s/it]

-770860.9526772839


100%|██████████| 50/50 [04:18<00:00,  5.16s/it]

-770717.2727594622





In [114]:
import plotly.graph_objects as go
fig = go.Figure()
fig.add_trace(go.Scatter(x=np.arange(len(elbo_list)), y=elbo_list,
                    mode='lines',
                    name='lines'))
                    
fig.add_trace(go.Scatter(x=np.arange(len(elbo_list2)), y=elbo_list2,mode='lines',name='lines'))
fig.add_trace(go.Scatter(x=np.arange(len(elbo_list3)), y=elbo_list3,mode='lines',name='lines'))
fig.add_trace(go.Scatter(x=np.arange(len(elbo_list4)), y=elbo_list4,mode='lines',name='lines'))
fig.add_trace(go.Scatter(x=np.arange(len(elbo_list5)), y=elbo_list5,mode='lines',name='lines'))
fig.update_layout(title='ELBO',xaxis_title='iteration',yaxis_title='ELBO')


fig.show()

**Q4.** Based on the MMSE estimates :
* What are the top-10 words per topic ? With your machine learning knowledge, can you make sense of some of the topics ?
* Choose one document at random and display its topic proportions. Comment.

In [115]:
#top 10 words for each topic
def top_10_words(L):
    """
    Return the top 10 words with highest probability for each topic.

    :param L: a KxV matrix where L[k, v] is the probability of word v in topic k.
    :return: A list of K elements, each a list of 10 words with highest probability for topic k.
    """
    K, V = L.shape
    top_words = []

    for k in range(K):
        top_word_indices = np.argsort(L[k, :])[-10:][::-1]
        top_word_probs = L[k, top_word_indices]
        top_words.append(list(zip(top_word_indices, top_word_probs)))

    return top_words

z=top_10_words(lambda_mmse)
print(z)



[[(69, 334.2903942012143), (264, 260.42609744003994), (115, 246.68218378515556), (763, 180.047649705949), (30, 173.86777296552324), (438, 170.47760076270353), (799, 159.94230556353813), (87, 140.78082034201512), (242, 134.47383840357037), (580, 131.59767665673525)], [(486, 324.98935424655934), (594, 308.5404264874659), (277, 299.76485385990554), (351, 179.3685605407065), (655, 176.25391624930893), (193, 142.77556033136932), (85, 133.3922403178852), (775, 131.96989052390347), (822, 127.17397587221924), (474, 119.19054716885023)], [(871, 574.0845754162134), (881, 448.15523117604243), (380, 264.4967879994341), (30, 250.22818015721379), (930, 152.49713104337062), (486, 139.49621791329506), (382, 122.32349428567679), (501, 110.8232232470101), (84, 101.25343334385143), (915, 99.93599662973713)], [(528, 407.6103089902703), (287, 192.6540877983804), (721, 182.85810364242062), (286, 182.32367238016494), (85, 146.7515370800831), (775, 135.88825208211793), (476, 134.22063500679954), (588, 118.901

----- Your answer here -----

**Q5.** Open questions :

- ###  What are some limitations of the LDA model ? Can you imagine an improvement ?
> Les limites visibiles du modèle LDA sont : 
> >Comme nous l'avons dit, le modèle LDA ne prends pas en compte l'ordre des mots dans ses Bags of Words(BOW). Cette simplification peut entraîner une perte d'informations, notamment sur le contexte et la nuance des propos.
> >
> >La LDA semble être un modèle coûteux en termes de calcul, surtout pour de grands ensembles de données. L'inférence peut être lente, et le modèle peut également être sensible aux paramètres d'initialisation, notamment le nombre de sujets K dont le choix n'est pas trivial et peut nécessiter une approche d'essai et d'erreur.
>
> Comme amélioration on pourrait suggérer : 
> > Utiliser des méthodes d'inférence plus efficaces/avancées, comme les réseaux de neurones pour capturer des relations plus complexes dans les données textuelles.
>>
> > Introduire des hiérarchies dans les sujets pour une représentation plus riche et plus nuancée des données.
> >

- ###  In this notebook, we have treated the hyperparameters as fixed. How could they be learned ?

- ### Can you imagine a method to choose the number of topics ?

- ### What strategies should we use to make the algorithm more efficient ?

**BONUS.** Papier-crayon. À partir du modèle, pouvez-vous dériver les lois conditionnelles de l'échantillonneur de Gibbs ? Pour rappel, nous avons besoin de ces lois pour dériver ensuite les updates de l'algorithme CAVI.