In [None]:
import torch
import zoib
import numpy as np

In [None]:
# Reference ZOIB used by the log_prob checks below: p=0.5, q=0.3 and
# concentration1 = concentration0 = 1, each as a 1-element float tensor.
zoib_fit = zoib.ZOIBeta(
    p=torch.tensor([0.5]),
    q=torch.tensor([0.3]),
    concentration1=torch.tensor([1.0]),
    concentration0=torch.tensor([1.0]),
)

In [None]:
# log_prob at y=0 (numerically log(0.5)); compare with a tolerance rather than exact float equality.
np.testing.assert_allclose(zoib_fit.log_prob(torch.tensor([0.])).item(), -0.6931471824645996, rtol=1e-6)

In [None]:
# log_prob at y=1 (numerically log(0.5) + log(0.3)); compare with a tolerance rather than exact float equality.
np.testing.assert_allclose(zoib_fit.log_prob(torch.tensor([1.])).item(), -1.8971199989318848, rtol=1e-6)

In [None]:
# Interior point y=0.4 (numerically log(0.5) + log(0.7), same as the other interior points tested).
np.testing.assert_allclose(zoib_fit.log_prob(torch.tensor([0.4])).item(), -1.0498220920562744, rtol=1e-6)

In [None]:
# Interior point y=0.9 (numerically log(0.5) + log(0.7), same as the other interior points tested).
np.testing.assert_allclose(zoib_fit.log_prob(torch.tensor([.9])).item(), -1.0498220920562744, rtol=1e-6)

In [None]:
# Interior point y=0.2 (numerically log(0.5) + log(0.7), same as the other interior points tested).
np.testing.assert_allclose(zoib_fit.log_prob(torch.tensor([.2])).item(), -1.0498220920562744, rtol=1e-6)

In [None]:
# Test full loss: zoib_loss on one sequence of 5 timesteps. Every timestep
# predicts the same parameters as zoib_fit (presumably ordered
# [p, q, concentration1, concentration0] — TODO confirm against zoib_loss),
# and the targets are the five values whose log_probs are checked above.
pred = torch.tensor([[[0.5, 0.3, 1., 1.],
                      [0.5, 0.3, 1., 1.],
                      [0.5, 0.3, 1., 1.],
                      [0.5, 0.3, 1., 1.],
                      [0.5, 0.3, 1., 1.]]])  # shape (1, 5, 4)
true_y = torch.tensor([[0., 1., 0.4, 0.9, 0.2]])  # shape (1, 5)

# NOTE(review): zoib_loss presumably flattens its inputs to
# (batch_size*timesteps) x 1 internally — confirm.
log_probs = zoib.zoib_loss(
    pred,
    true_y,
    return_mean=False)
# Per-step losses (positive, i.e. roughly the negated log_probs above).
# Compare element-wise: calling `.all()` on each side would reduce both to a
# single bool and make the check vacuous. reshape(-1) accepts a (5,), (1, 5)
# or (5, 1) return shape.
expected_log_probs = np.array([0.6931472, 1.89712, 1.0497648, 1.0498629, 1.0498054])
np.testing.assert_allclose(
    log_probs.detach().numpy().reshape(-1), expected_log_probs, rtol=1e-5)

mean_loss = zoib.zoib_loss(
    pred,
    true_y,
    return_mean=True)
# Tolerance instead of exact float equality.
np.testing.assert_allclose(mean_loss.item(), 1.1479400396347046, rtol=1e-6)
# return_mean=True should agree with averaging the per-step losses.
np.testing.assert_allclose(mean_loss.item(), log_probs.detach().mean().item(), rtol=1e-5)

# Real example from training: log_prob(0) only cares about p, but the expected value cares about all four parameters.

In [None]:
# Parameters from a real training prediction, in the order they are
# unpacked into ZOIBeta below: p, q, concentration1, concentration0.
t = [
    0.48095,     # p
    2.3518e-05,  # q
    565.11,      # concentration1
    866.65,      # concentration0
]

In [None]:
# ZOIB built from the training example `t` (ordered p, q, concentration1,
# concentration0); each slice is a 1-element list, so each parameter is a
# 1-element tensor.
zoib_fit_other = zoib.ZOIBeta(
    p=torch.tensor(t[0:1]),
    q=torch.tensor(t[1:2]),
    concentration1=torch.tensor(t[2:3]),
    concentration0=torch.tensor(t[3:4]),
)

In [None]:
# Evaluate at y=0 with a float tensor, matching the float inputs used with
# zoib_fit above (torch.tensor([0]) would be an int64 tensor).
zoib_fit_other.log_prob(torch.tensor([0.]))

In [None]:
def zoib_expected(t):
    """Expected value of a zero-one-inflated Beta (ZOIB) from its parameters.

    E = q*(1-p) + (1-p)*(1-q) * conc1/(conc1+conc0)
      = prob_1_given_not0*(1-prob_0) + prob_beta * expect_val_beta

    where p is prob_0 and q is prob_1_given_not0.

    Args:
        t: torch tensor (any device, may require grad) or array-like whose
            last axis holds [p, q, concentration1, concentration0]. Any
            leading shape is allowed, e.g. (n, 4) or (batch, timesteps, 4).

    Returns:
        numpy array of expected values with shape ``t.shape[:-1]``.
    """
    if isinstance(t, torch.Tensor):
        # detach() drops autograd history; cpu() lets CUDA tensors convert too.
        params = t.detach().cpu().numpy()
    else:
        params = np.asarray(t, dtype=float)
    # Index the last axis so 2-D and 3-D (batch, timesteps, 4) inputs both work.
    p, q = params[..., 0], params[..., 1]
    conc1, conc0 = params[..., 2], params[..., 3]
    prob_1 = q * (1 - p)
    prob_beta = (1 - p) * (1 - q)
    beta_expected = conc1 / (conc0 + conc1)
    return prob_1 + prob_beta * beta_expected

In [None]:
# Expected value for the single training example `t`, as a batch of one row.
zoib_expected(torch.tensor(t).unsqueeze(0))