Skip to content

Commit

Permalink
[feat]: add EPIG implementation in active heuristics
Browse files Browse the repository at this point in the history
  • Loading branch information
reeshipaul committed Jul 19, 2023
1 parent b96d09b commit 73259ca
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions baal/active/heuristics/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,3 +725,70 @@ def reorder_indices(self, scores_list):
ranks = ranks[::-1]
ranks = _shuffle_subset(ranks, self.shuffle_prop)
return ranks

class EPIG(AbstractHeuristic):
    """
    Expected Predictive Information Gain (EPIG) heuristic.

    Each pool example is scored by the estimated mutual information between
    its prediction and the predictions on a set of target inputs, averaged
    over those target inputs.

    Args:
        shuffle_prop: Forwarded unchanged to the parent heuristic.
        reverse: Accepted for signature compatibility only. It is ignored;
            the parent is always initialised with ``reverse=True``.
        reduction: Forwarded unchanged to the parent heuristic.

    References:
        https://arxiv.org/abs/2304.08151
        Code from https://github.com/fbickfordsmith/epig
    """

    def __init__(self, shuffle_prop=DEPRECATED, reverse=False, reduction="none"):
        # `reverse` is deliberately not forwarded: this heuristic always asks
        # the parent for reverse=True ordering.
        super().__init__(shuffle_prop=shuffle_prop, reverse=True, reduction=reduction)

    def _conditional_epig_from_probs(self, predictions, targets):
        """
        Estimate the pairwise information gain between pool and target inputs.

        Args:
            predictions: Class probabilities for the pool, laid out as
                [N_p, K, Cl] (examples, stochastic samples, classes).
            targets: Class probabilities for the target inputs, laid out as
                [N_t, K, Cl]. K and Cl must match ``predictions``.
                NOTE(review): confirm callers supply this axis order.

        Returns:
            torch.Tensor of shape [N_p, N_t]; entry (i, j) is the KL divergence
            between the joint predictive of pool example i and target example j
            and the product of their marginal predictives.
        """
        # Convert to float32 tensors (accepts ndarrays as well as tensors).
        probs_pool = torch.as_tensor(predictions, dtype=torch.float32)
        probs_targ = torch.as_tensor(targets, dtype=torch.float32)

        # Put the sampling axis first so that averaging over dim=0 marginalises
        # over the K stochastic samples.
        probs_pool = probs_pool.permute(1, 0, 2)  # [K, N_p, Cl]
        probs_targ = probs_targ.permute(1, 0, 2)  # [K, N_t, Cl]

        # Add singleton axes so the product below broadcasts to every
        # (pool example, target example, pool class, target class) combination.
        probs_pool = probs_pool[:, :, None, :, None]  # [K, N_p, 1, Cl, 1]
        probs_targ = probs_targ[:, None, :, None, :]  # [K, 1, N_t, 1, Cl]

        # Estimate the joint predictive distribution.
        probs_pool_targ_joint = torch.mean(probs_pool * probs_targ, dim=0)  # [N_p, N_t, Cl, Cl]

        # Estimate the marginal predictive distributions.
        probs_pool = torch.mean(probs_pool, dim=0)  # [N_p, 1, Cl, 1]
        probs_targ = torch.mean(probs_targ, dim=0)  # [1, N_t, 1, Cl]

        # Estimate the product of the marginal predictive distributions.
        probs_pool_targ_indep = probs_pool * probs_targ  # [N_p, N_t, Cl, Cl]

        # KL divergence between probs_pool_targ_joint and probs_pool_targ_indep.
        # The log-ratio is only evaluated where the joint is strictly positive;
        # the remaining entries stay at 0 and therefore contribute nothing to
        # the sum (the 0 * log 0 = 0 convention), which avoids NaNs from log(0).
        nonzero_joint = probs_pool_targ_joint > 0  # [N_p, N_t, Cl, Cl]
        log_term = torch.zeros_like(probs_pool_targ_joint)  # [N_p, N_t, Cl, Cl]
        log_term[nonzero_joint] = torch.log(probs_pool_targ_joint[nonzero_joint]) - torch.log(
            probs_pool_targ_indep[nonzero_joint]
        )
        scores = torch.sum(probs_pool_targ_joint * log_term, dim=(-2, -1))  # [N_p, N_t]
        return scores

    def compute_score(self, predictions, targets):
        """
        Compute the score according to the heuristic.

        Args:
            predictions (ndarray): Pool probabilities, [N_p, K, Cl].
            targets (ndarray): Target-input probabilities, [N_t, K, Cl].

        Returns:
            ndarray of shape [N_p,]: the pairwise scores averaged over the
            target inputs.

        Raises:
            AssertionError: If either input is not 3-dimensional.
        """
        # The permutation in _conditional_epig_from_probs only handles rank 3.
        assert predictions.ndim == 3, "EPIG expects predictions of shape [N_p, K, Cl]."
        assert targets.ndim == predictions.ndim, "targets must have the same rank as predictions."

        scores = self._conditional_epig_from_probs(predictions, targets)  # [N_p, N_t]
        epig_scores = torch.mean(scores, dim=-1)  # [N_p,]
        return epig_scores.detach().cpu().numpy()

0 comments on commit 73259ca

Please sign in to comment.