Matchbox recommender message 7 #428
Message (7) is implemented in infer/src/Runtime/Factors/Product.cs#AAverageLogarithm. Messages (2) and (7) come from Variational Message Passing section 4.2. A concise form of those messages can be found in Table 1 of Gates: A Graphical Notation for Mixture Models. |
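For concreteness, here is a minimal sketch (illustrative names of our own, not the Product.cs API) of the moment form those two product-factor messages take when all incoming messages are Gaussian; it is consistent with the formulas worked out later in this thread:

def product_msg_to_output(mA, vA, mB, vB):
    # Message from the product factor z = a*b to z, in moment form:
    # mean E[a]E[b], variance E[a^2]E[b^2] - (E[a]E[b])^2
    EA2, EB2 = vA + mA**2, vB + mB**2
    return mA * mB, EA2 * EB2 - (mA * mB)**2

def product_msg_to_input(mOther, vOther, mDown, vDown):
    # Message from the product factor to one of its inputs, given the marginal of
    # the *other* input (mOther, vOther) and the downstream Gaussian message (mDown, vDown):
    # precision = E[other^2] / vDown, mean*precision = E[other] * mDown / vDown
    EOther2 = vOther + mOther**2
    return mOther * mDown / EOther2, vDown / EOther2

# Example with the numbers used later in this thread: marginal of t = N(1, 2^2), message 6 = N(2, 1^2)
print(product_msg_to_input(1.0, 4.0, 2.0, 1.0))  # -> (0.4, 0.2)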
Hi, and thanks Tom. The reference has been very helpful. It allowed us to verify that the exact message 7 we calculated was in fact correct, since when we apply the AverageLog to it we get almost exactly the approximate Gaussian proposed in the MatchBox paper.

We know that the variational approximations should minimize the KL divergence with respect to the exact message. However, the proposed Gaussian is a good approximation of the transformation of the exact message, but not of the exact message itself; the distance between the approximate message and the exact one is in fact very large. For example, in the following figure we can see the three distributions: the exact message (blue), the transformation of the exact message (orange), and the approximate Gaussian proposed in the MatchBox paper (green). Why then are we using this approximation, which does not minimize the KL divergence with respect to the exact message? Even if applying the inverse of the AverageLog to it would give us a good approximation to the exact message, we would be back to the original problem: an intractable distribution. What am I missing?

Thank you for all your contributions to modern Bayesian inference.

Appendix: the code to replicate the figure.
|
The reason for using this approximation is in section 3.1.2 of the paper you linked. |
I still don't understand why the approximate message 7 does not minimize the KL divergence with respect to the exact message 7. |
It minimizes KL divergence with the arguments swapped. Is your issue that the code doesn't do what the paper says, or that you disagree with what the paper says to do? |
Our original goal was to implement the Matchbox model from scratch as an exercise, to learn as much as possible from the methods created by you. Matchbox is particularly interesting because it offers an analytical approximation (the proposed messages 2 and 7 in the paper) to a common problem: the multiplication of two Gaussian variables. The issue is that during the implementation process we found that the approximate message 7 proposed by the Matchbox paper does not minimize the reverse KL divergence with respect to the exact message 7. Since we couldn't find the error in our calculations, we decided to examine the original code implemented by you. Thanks to your initial response, we were able to verify that the exact message 7 calculated by us was correct, as when we apply the AverageLog to it we obtain exactly the approximate Gaussian proposed in the MatchBox paper. Now the question is: why does the approximate message 7 proposed by the paper not minimize the reverse KL divergence (at least according to our calculations)?

Exact analytic message
The following is the exact analytic message. Each value of message 7 represents an integral of the joint distribution below the isoline defined by s_k. In the following images, we present the joint distribution with four isolines on the left side, and the corresponding areas below those isolines on the right side. Collectively, all of these integrals make up the likelihood that the exact message 7 sends to the latent variable s_k.

The reverse KL divergence
To evaluate the reverse KL divergence, we implemented a simple numerical integration. For example, when approximating a mixture of Gaussians with a unimodal Gaussian using reverse KL divergence, we obtained the following result. There are two local minima, one for each peak, and a single global minimum, corresponding to the highest peak.

Reverse KL divergence between the approximate and the exact messages
The exact message 7 has a similar structure. However, unlike the mixture of Gaussians, the tails of the exact message 7 are extremely wide, which forces the reverse-KL approximation to find a very different compromise compared to the mixture-of-Gaussians example. It seems that the width of the tails has a more significant impact than the two peaks, which, ultimately, are not that far apart.

Thanks for all
We provided this explanation because we are genuinely interested in understanding the details of the valuable methodology that you have developed. We will always be indebted to you. Really, thanks @tminka |
Please write out the formula that you think should be minimized. I suspect there is some confusion about what "minimize the KL divergence" actually means. |
The mathematical definition of the KL divergence is

KL(P || Q) = ∫ p(x) log( p(x) / q(x) ) dx,

with P the distribution we want to approximate and Q the approximation; the reverse KL divergence is KL(Q || P). In the previous message I minimized both KL divergences in a well-known example. In the following figure (already shown in the previous message), the red Gaussian distribution minimizes the reverse KL divergence (within the Gaussian family) with respect to the blue distribution (a mixture of Gaussians). Here is the code that reproduces the left-side figure:

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm # Gaussian distribution
from scipy.special import rel_entr # KL-Divergence
from scipy.optimize import basinhopping # Optimization algorithm
def Dkl(p, q):
    return rel_entr(p, q)
# Grid
step = 0.025
x_grid = np.arange(-6,6,step)
# True distribution and their approximation
p_mixture = norm.pdf(x_grid,2,0.5)*0.6 + norm.pdf(x_grid,-2,0.5)*0.4
def objective_reverse_KL(mean_sd):
    # Deriving discrete analogues of the continuous distributions (density times bin width)
    P = p_mixture * step
    Q = norm.pdf(x_grid, mean_sd[0], mean_sd[1]) * step
    return sum(Dkl(Q, P))
muMix_QP, sigmaMix_QP = basinhopping(objective_reverse_KL, [2, 0.5])["x"]
fig = plt.figure()
plt.plot(x_grid,p_mixture )
plt.plot(x_grid,norm.pdf(x_grid,muMix_QP,sigmaMix_QP))
plt.show()

We can verify that the optimized mean and standard deviation are correct, since they correspond to the global minimum in the right-side figure, where we plot the value of the reverse KL divergence for each combination of mean and standard deviation.

Exactly the same procedure was executed to compute the reverse KL divergence between the exact distribution (known as message 7 in the MatchBox paper) and the approximate distribution proposed in the paper and implemented by Infer.NET. As a reminder, the factor graph of MatchBox is the one copied in the original post. Then, the exact distribution can be computed as

m7(s_k) ∝ ∫ N(s_k · t_k; μ_6, σ_6²) N(t_k; μ_t, σ_t²) dt_k,

and the proposed distribution in the paper is

m7(s_k) ≈ N(s_k; μ_t μ_6 / (σ_t² + μ_t²), σ_6² / (σ_t² + μ_t²)),

where μ_t and σ_t² are the mean and variance of message 1 (the prior of t_k), and μ_6 and σ_6² are the mean and variance of message 6. In the following figure we plot the exact distribution (blue) and the approximate distribution (green). With the naked eye we can see that the approximate distribution is not close to the exact distribution. Here is the code to reproduce the figure:

import matplotlib.pyplot as plt
from scipy.stats import norm
import math
import numpy as np
# Mean and SD of message 1 and message 6 at the MatchBox factor graph
mtk = 1; sdtk = 2 # Message 1 (equivalent to the prior of variable t)
m6k = 2; sd6k = 1 # Message 6
# Grid
step = 0.025
tk = np.arange(-6,6.1, step); tk = list(tk) # grid for t_k variable
zk = np.arange(-6,6.1, step); zk = list(zk) # grid for z_k variable
Sk = np.arange(-9,9+step, step) # grid for s_k variable
# The proposed approximation
message_7_approx = norm.pdf(Sk, mtk*m6k /(sdtk**2 + mtk**2), sd6k /np.sqrt((sdtk**2 + mtk**2)))
# The exact message 7 (an integral for each value of s_k)
message_7_exact = []
for s in Sk:
    integral = sum(norm.pdf(s*np.array(tk),m6k,sd6k)*norm.pdf(np.array(tk),mtk,sdtk))*step
    message_7_exact.append(integral)
message_7_exact = np.array(message_7_exact)/(sum(message_7_exact)*step)  # Normalization
# Plot
fig = plt.figure()
plt.plot(Sk, message_7_exact, label="Exact m7")
plt.plot(Sk, message_7_approx, label="Approx m7")
plt.legend(loc="upper left",fontsize=10)
plt.show()

After computing the reverse KL divergence for all possible combinations of mean and standard deviation, we obtain the following figure. Here is the code to reproduce it:

import matplotlib.pyplot as plt
from scipy.stats import norm
import math
import numpy as np
from scipy.special import rel_entr # KL-Divergence
# Mean and SD of message 1 and message 6 at the MatchBox factor graph
mtk = 1; sdtk = 2 # Message 1 (equivalent to the prior of variable t)
m6k = 2; sd6k = 1 # Message 6
# Grid for variable t and variable s
step_long=0.05; Sk_long = np.arange(-100,100+0.05, step_long); tk_long = Sk_long
# Exact distribution (message 7 in MatchBox factor graph)
message_7_exact_long = []
for s in Sk_long:
    # message7(s_k) = integral over t of N(s*t; m6k, sd6k^2) * N(t; mtk, sdtk^2) dt
    integral = sum(norm.pdf(s*np.array(tk_long),m6k,sd6k) * norm.pdf(np.array(tk_long),mtk,sdtk))*step_long
    message_7_exact_long.append(integral)
message_7_exact_long = np.array(message_7_exact_long)/(sum(message_7_exact_long)*step_long)
# Compute the KL Divergence for a range of MU and SD
mus = np.arange(-15,15.25,0.25)
sigmas = np.arange(1,48,1)
muXsigma = []; min_divergence = np.inf;
for m in mus:
    muXsigma.append([])
    for s in sigmas:
        P = message_7_exact_long * step_long  # Deriving discrete analogues (exact dist)
        Q = norm.pdf(Sk_long,m,s) * step_long  # Deriving discrete analogues (approx dist)
        Dkl_value = sum(rel_entr(Q,P))  # Compute numerical reverse KL Divergence
        muXsigma[-1].append( Dkl_value )
        # Follow the distribution with minimum reverse KL divergence
        if muXsigma[-1][-1] < min_divergence:
            min_divergence = muXsigma[-1][-1]
            argmin = (m, s)
fig = plt.figure()
plt.imshow(muXsigma, interpolation='nearest', extent = (np.min(sigmas), np.max(sigmas), np.min(mus), np.max(mus) ) )
contours = plt.contour(sigmas, mus, muXsigma, 10, colors='black')
plt.clabel(contours, inline=1, fontsize=10)
plt.axhline(y=0, color="gray")
plt.scatter(argmin[1],argmin[0], color="black" )
plt.xlabel("Sigma",fontsize=14)
plt.ylabel("Mu",fontsize=14)
plt.show()

Even though the true distribution has two modes, in this example the tails are extremely wide (unlike the mixture-of-Gaussians example). This explains why the distribution that minimizes the reverse KL divergence does not select one mode, as it does in the mixture-of-Gaussians example.

We'd greatly appreciate any comment, since we are genuinely interested in understanding the details of the methodology that you have developed. If we've made a mistake somewhere along the way, we'd love to understand it. I fixed some minor errors, but I didn't find any major mistake. We are not obsessed with the fact that the proposed approximate distribution does not minimize the reverse KL divergence (I am not a mathematician, and I don't even know how to derive an analytical formula for the approximate distribution). However, Macarena Piaggio (@macabot-sh) has replicated the MatchBox tutorial using exactly the same code available in the |
The issue is that this is not what "minimizing the KL divergence" means. The messages in EP and VMP minimize KL divergence of the approximate posterior to the true posterior (the thing that actually matters), not the KL divergence between the "exact message" (as you call it) and the approximate message. Messages are not distributions so it doesn't even make sense to ask for their KL divergence. |
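For reference, and in our own notation rather than a quote from the papers: the VMP objective is the divergence between a fully factorized approximate posterior and the true posterior,

KL(q || p) = ∫ q(x) log( q(x) / p(x | D) ) dx,   with q(x) = ∏_i q_i(x_i),

and no divergence between individual messages appears anywhere in that objective.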
Yes, you are right: message 8 (or 7) is a likelihood, so it is not a distribution. It is my homework to verify the KL divergence correctly. Thank you very much. As we mentioned before, we have tried to replicate the output of the MatchBox tutorial using the source code available in Infer.NET, modifying as few lines as possible. |
Hello again. Here we will follow your technical report "Divergence measures and message passing" [1], rephrasing it as closely as possible.

To recap, we want the best approximation q(x) of the true distribution p(x). First, we must write p(x) as a product of factors,

p(x) = ∏_a f_a(x).

This defines the preferred factorization, which is not unique. Then, each factor is approximated by a member of a family F of tractable distributions, giving the overall approximation

q(x) = ∏_a ~f_a(x).

Here we want the approximate factors ~f_a to be Gaussian. The best approximation of each factor depends on the rest of the network, giving a chicken-and-egg problem, which is solved by an iterative message-passing procedure: each factor sends its best approximation to the rest of the net, and then recomputes its best approximation based on the messages it receives. To make this tractable, we must assume that the approximations we've already made for the rest of the network are good enough to stand in for the true factors. Because of this factorized form, individual messages need not be normalized, and need not be proper distributions. In a fully factorized context we only need to propagate messages between a factor and its variables.

In this issue we are discussing whether the approximation of message 7 is the best one available within the Gaussian family. In particular, at each iteration of the message-passing procedure, the true posterior of variable s_k is proportional to the product of the messages arriving at s_k (message 1 times the exact message 7 in the code below). In the collaborative filtering version of Matchbox, message 1 for variable s_k is a Gaussian (its mean and standard deviation are set at the top of the code below). We can compute a good numerical approximation of the exact message 7 using a simple quadrature procedure based on a dense grid, also shown in the code below:

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
from scipy.special import kl_div # KL-Divergence
from scipy.optimize import basinhopping # For search global minimum
mu_uki = 1; sd_uki = 2 # Prior of U_ki
mu_tk = 1; sd_tk = 2 # Message 1 for tk
mu_6k = 2; sd_6k = 1 # Message 6
# Grid
step = 0.025
grid = list(np.arange(-19,19+step, step)); tk = grid; sk = grid; uki = grid;
# The proposed approximate message 7
mu_7sk = (mu_tk*mu_6k) / (sd_tk**2 + mu_tk**2)
sd_7sk = np.sqrt( (sd_6k**2) / (sd_tk**2 + mu_tk**2) )
message_7_approx = norm.pdf(sk, mu_7sk, sd_7sk )
## Exact message 7 (an integral for each value of s_k)
message_7_exact = []
for s in sk:
    integral = sum(norm.pdf(s*np.array(tk),mu_6k,sd_6k)*norm.pdf(np.array(tk),mu_tk,sd_tk))*step
    message_7_exact.append(integral)
# Exact message 1 (same as approx message 1)
message_1 = norm.pdf(sk, mu_uki, sd_uki)
# Posteriors (exact and approx)
exact_marginal = np.array(message_1) * np.array(message_7_exact)
exact_posterior = exact_marginal/sum(exact_marginal*step)
approx_marginal = np.array(message_1) * message_7_approx
approx_posterior = approx_marginal/sum(approx_marginal*step)
# Inverse KL minimization
def objective_QP(mu_sd):
    return sum(kl_div(norm.pdf(sk,*mu_sd)*step, exact_posterior*step))
mu_minQP, sd_minQP = basinhopping(objective_QP, [2,0.5])["x"]
# Plot
plt.plot(sk, exact_posterior, label="Exact posterior for sk")
plt.plot(sk, approx_posterior, label="MatchBox posterior for sk")
plt.plot(sk, norm.pdf(sk, mu_minQP, sd_minQP), label=r"Min KL(N($\mu,\sigma^2$)||Exact)")
plt.legend(loc="upper left", title="First iteration", fontsize=9)
plt.show()

Please let us know if we've made any wrong assumptions along the way. Thank you for your help.

[1] T. Minka. Divergence measures and message passing. |
Message 7 doesn't minimize that objective either. It seems that all of your confusion comes from a misunderstanding of the objective of Variational Message Passing. The objective is explained in the VMP paper that I referenced in my first reply. Questions about the mathematics of VMP should probably be asked on Cross-Validated rather than here. |
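To spell out, in our own notation, the update that the VMP paper derives from that objective: for a fully factorized q(x) = ∏_i q_i(x_i), the coordinate update for one variable is

q_j*(x_j) ∝ exp( E_{i≠j}[ log p(x) ] ),

with each factor contributing a message of the form exp( E[ log f_a(x) ] ). The approximate message 7 has exactly this exp-expected-log form for the product factor (as the AverageLog computation earlier in the thread suggests), so it is not, and is not intended to be, a reverse-KL projection of the "exact message".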
Hi everyone,
I've been looking for where message 7 (m⁎->s, as labelled in the Matchbox paper [1] factor graph) is implemented in the infer.net code, but I can't seem to find it. As a study exercise, @glandfried and I have been trying to implement the collaborative filtering version of the Matchbox recommender algorithm, as described in the Matchbox paper. We have been able to implement almost everything. However, our implementation of approximate message 7 differs greatly from the exact version of the message, the bimodal one. In contrast, we had no issues with the approximate message 2.
If anyone could help me find the infer.net implementation of the message so we can compare, I would appreciate it. So far I could only find approximate message 2 (m⁎->z) at infer/src/Runtime/Factors/Product.cs#ProductAverageLogarithm.

As a reminder, I copy here the factor graph and the approximation equations (approximate messages 2 and 7, respectively) from the original Matchbox paper. I interpret the message 7 approximation as

m⁎->s(s) ≈ N(s; μ_t μ_{z→⁎} / (σ²_t + μ²_t), σ²_{z→⁎} / (σ²_t + μ²_t)),

(μ_t and σ²_t denote the mean and variance of the (Gaussian) marginal distribution p(t); μ_{z→⁎} and σ²_{z→⁎} denote the mean and variance of message 6).
It would also be nice to get a hint to derive the approximations for 2 and 7 on our own (or a reference).
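One route to that derivation, checked numerically below: under the VMP update, the message from the product factor to s is proportional to exp(E_t[log N(s·t; μ_{z→⁎}, σ²_{z→⁎})]), which is a Gaussian in s and reproduces the approximation quoted above. The numbers and grids below are illustrative, not taken from the paper:

import numpy as np
from scipy.stats import norm

# Illustrative numbers: marginal of t = N(1, 2^2), message 6 (z -> *) = N(2, 1^2)
mu_t, var_t = 1.0, 4.0
mu_6, var_6 = 2.0, 1.0

# Dense grid over t to take the expectation E_t[.] under the marginal of t
t_step = 0.001
t = np.arange(-30, 30, t_step)
qt = norm.pdf(t, mu_t, np.sqrt(var_t))

s_step = 0.01
s_grid = np.arange(-3, 3, s_step)

# exp(E_t[log N(s*t; mu_6, var_6)]) evaluated on the s grid, then normalized
log_msg = np.array([np.sum(qt * norm.logpdf(s * t, mu_6, np.sqrt(var_6))) * t_step
                    for s in s_grid])
msg = np.exp(log_msg - log_msg.max())
msg = msg / (np.sum(msg) * s_step)

# The Gaussian form quoted in the paper for message 7, in moment form
Et2 = var_t + mu_t**2                 # E[t^2] = variance + mean^2
paper = norm.pdf(s_grid, mu_6 * mu_t / Et2, np.sqrt(var_6 / Et2))

print(np.max(np.abs(msg - paper)))    # close to zero: the two curves coincide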
Thanks in advance!
[1] Matchbox Paper: https://www.microsoft.com/en-us/research/publication/matchbox-large-scale-bayesian-recommendations/