
Commit

minor bug fixes
MarcusMNoack committed Sep 22, 2023
1 parent 5d754b8 commit 1ec275c
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions fvgp/gp.py
@@ -1637,6 +1637,9 @@ def mutual_information(self,joint,m1,m2):
The first marginal distribution
m2 : np.ndarray
The second marginal distribution
+Return:
+-------
+mutual information : float
"""
return self.entropy(m1) + self.entropy(m2) - self.entropy(joint)
###########################################################################
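
For reference, the identity used here, MI(A;B) = H(A) + H(B) - H(A,B), can be checked with a minimal standalone sketch. This is illustrative only, not fvgp's internal entropy implementation; gaussian_entropy and the 2x2 example covariance below are assumptions for the sketch, using the zero-mean Gaussian entropy 0.5*log((2*pi*e)^k det(S)).

    import numpy as np

    def gaussian_entropy(S):
        # differential entropy of a zero-mean Gaussian with covariance S
        S = np.atleast_2d(S)
        k = len(S)
        _, logdet = np.linalg.slogdet(S)
        return 0.5 * (k * np.log(2.0 * np.pi * np.e) + logdet)

    # a 2x2 joint covariance and its two 1x1 marginal blocks
    joint = np.array([[1.0, 0.5],
                      [0.5, 2.0]])
    m1, m2 = joint[:1, :1], joint[1:, 1:]
    mi = gaussian_entropy(m1) + gaussian_entropy(m2) - gaussian_entropy(joint)
    print(mi)  # positive, because the off-diagonal coupling is nonzero
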
@@ -1670,7 +1673,8 @@ def gp_mutual_information(self,x_pred, x_out = None):
joint_covariance = \
np.asarray(np.block([[K,k],\
[k.T, kk]]))
-return self.mutual_information(joint_covariance, kk, K)
+return {"x":x_pred,
+"mutual information":self.mutual_information(joint_covariance, kk, K)}

###########################################################################
def gp_total_correlation(self,x_pred, x_out = None):
@@ -1705,8 +1709,10 @@ def gp_total_correlation(self,x_pred, x_out = None):

prod_covariance = np.asarray(np.block([[K, k * 0.],\
[k.T * 0., kk * np.identity(len(kk))]]))
-return self.kl_div(np.zeros((len(joint_covariance))),np.zeros((len(joint_covariance))),joint_covariance,prod_covariance)


return {"x":x_pred,
"total correlation":self.kl_div(np.zeros((len(joint_covariance))),np.zeros((len(joint_covariance))),joint_covariance,prod_covariance)}
###########################################################################
def shannon_information_gain(self, x_pred, x_out = None):
"""
@@ -1735,7 +1741,7 @@ def shannon_information_gain(self, x_pred, x_out = None):
if x_out is not None: x_pred = self._cartesian_product(x_pred,x_out)

return {"x": x_pred,
"sig":np.exp(-self.gp_total_correlation(x_pred, x_out = None))}
"sig":np.exp(-self.gp_total_correlation(x_pred, x_out = None)["total correlation"])}

###########################################################################
def shannon_information_gain_vec(self, x_pred, x_out = None):
@@ -1759,7 +1765,7 @@ def shannon_information_gain_vec(self, x_pred, x_out = None):

sig = np.zeros((len(x_pred)))
for i in range(len(x_pred)):
-sig[i] = np.exp(-self.gp_mutual_information(x_pred[i].reshape(1,len(x_pred[i])), x_out = None))
+sig[i] = np.exp(-self.gp_mutual_information(x_pred[i].reshape(1,len(x_pred[i])), x_out = None)["mutual information"])

return {"x": x_pred,
"sig(x)":sig}
