In [105]:
import numpy as np
from scipy.stats import entropy
from scipy.special import rel_entr

class REL_DS:
    def __init__(self):
        self.forecast_bins = np.arange(0.01, 1, 0.01)
        self.data = {p: [] for p in self.forecast_bins}
        self.eps = 1E-8

    def add_pair(self, forecast, observation):
        # Find the closest forecast bin for the given forecast
        bin_idx = np.argmin(np.abs(self.forecast_bins - forecast))
        closest_bin = self.forecast_bins[bin_idx]
        self.data[closest_bin].append(observation)

    def compute_rel_ds(self):
        N = sum([len(observations) for observations in self.data.values()])
        total = 0

        for forecast_prob, observations in self.data.items():
            n_k = len(observations)
            if n_k == 0:
                continue

            # Empirical PMF
            o_bar = [1 - np.mean(observations) + self.eps, np.mean(observations)+self.eps]
            f_k = [self.eps + (1 - forecast_prob), self.eps + forecast_prob]

            # KL divergence calculation
            kl_div = rel_entr(o_bar, f_k).sum()
            total += n_k * kl_div

        return total / N


In [106]:
class RES_DS:
    def __init__(self):
        self.forecast_obs_pmf = {}  # Joint PMF p(f, o)
        self.N = 0  # Total forecasts
        self.overall_o = np.array([0, 0])  # Marginal PMF p(o)
        self.eps = 1E-8  # To prevent log(0)

    def add_pair(self, forecast, observation):
        self.N += 1
        self.overall_o[observation] += 1
        
        if forecast not in self.forecast_obs_pmf:
            self.forecast_obs_pmf[forecast] = np.array([0, 0])
        self.forecast_obs_pmf[forecast][observation] += 1

    def compute_entropy(self, pmf):
        normalized_pmf = pmf / np.sum(pmf)
        entropy = -np.sum(normalized_pmf * np.log(normalized_pmf + self.eps))
        return entropy
    
    def compute_conditional_entropy(self,observations, forecasts):
        # Joint distribution of O and F
        joint_pmf, _, _ = np.histogram2d(observations, forecasts, bins=(len(np.unique(observations)), len(np.unique(forecasts))), density=True)
        
        # Marginal distribution of O
        obs_pmf, _ = np.histogram(observations, bins=len(np.unique(observations)), density=True)
        
        # Marginal distribution of F
        forecast_pmf, _ = np.histogram(forecasts, bins=len(np.unique(forecasts)), density=True)
    
        # Compute conditional entropy using the formula
        conditional_entropy = 0
        for i in range(joint_pmf.shape[0]):
            for j in range(joint_pmf.shape[1]):
                if joint_pmf[i][j] > 0:  # to avoid log(0)
                    conditional_entropy += joint_pmf[i][j] * np.log(forecast_pmf[j] / joint_pmf[i][j])
    
        return conditional_entropy

    def compute_res_ds(self,o,f):
        # Calculate H(O)
        h_o = self.compute_entropy(self.overall_o)
        
        # Calculate H(O|F)
        # h_o_given_f = 0
        # for forecast, obs_pmf in self.forecast_obs_pmf.items():
        #     p_f = np.sum(obs_pmf) / self.N
            # conditional_entropy = self.compute_entropy(obs_pmf)
            # conditional_entropy = entropy(obs_pmf,condition=)
            # conditional_entropy = self.conditional_entropy(obs_pmf,forecast)
            # h_o_given_f += p_f * conditional_entropy
        h_o_given_f = self.compute_conditional_entropy(o,f)
        
        # Calculate RES
        res = (h_o - h_o_given_f)/self.N
        
        return res


In [107]:
def compute_unc_ds(o, eps=1E-8):
    """
    Compute the UNC_DS metric for the overall empirical PMF of observations.
    
    Parameters:
        - o : 1-D vector
    
    
    Returns:
    - unc_ds: The UNC_DS metric.
    """
    # overall_o: The overall empirical PMF of observations (2-element vector)
    overall_o = [1 - np.mean(o) + eps, np.mean(o)+eps]
    
    # Ensure the PMF is normalized
    overall_o_normalized = overall_o / np.sum(overall_o)

    # Compute entropy
    entropy = -np.sum(overall_o_normalized * np.log(overall_o_normalized + eps))
    
    return entropy


In [108]:
def compute_ds(o_series, f_series,eps=1E-8):
    """
    Compute the DS metric for a time series of observed and forecasted probabilities.
    
    Parameters:
    - o_series: A list of observed probabilities (2-element vectors) for each time step.
    - f_series: A list of forecasted probabilities (2-element vectors) for each time step.
    
    Returns:
    - ds: The DS metric.
    """
    assert len(o_series) == len(f_series), "Mismatch in length of observation and forecast series"
    
    N = len(o_series)
    total_kl_divergence = 0.0
    
    for t in range(N):
        # Compute KL divergence for the t-th time step
        # kl_div_t = kl_divergence(o_series[t], f_series[t])
        # print(o_series[t],f_series[t])
        # kl_div_t = entropy(o_series[t]+eps, f_series[t]+eps,2)
        kl_div_t = rel_entr(o_series[t]+eps,f_series[t]+eps)

        total_kl_divergence += kl_div_t

    ds = total_kl_divergence / N
    return ds


In [109]:
def generate_time_series(seed=42,num_samples=100):
    rng = np.random.default_rng(seed=seed)

    o = rng.choice([0, 1], size=num_samples, p=[0.5, 0.5])
    
    f = np.zeros(num_samples)
    
    for i in range(num_samples):
        if o[i] == 1:
            f[i] = rng.choice(np.arange(0.6, 0.98, 0.01))
        else:
            f[i] = rng.choice(np.arange(0.02, 0.45, 0.01))
    return o, f

o, f = generate_time_series()
o = np.round(o, 2)
f = np.round(f, 2)
time_series = list(zip(f, o))
# print(o)
# print(f)

In [110]:
UNC = compute_unc_ds(o)
print(UNC)

0.691346079073826


In [111]:
res_calculator = RES_DS()
for forecast, observation in time_series:  # assuming time_series is a list of forecast/observation pairs
    res_calculator.add_pair(forecast, observation)

res_ds_value = res_calculator.compute_res_ds(o,f)
print(res_ds_value)

0.8386900774619516


In [112]:

rel_calculator = REL_DS()
for forecast, observation in time_series:  # assuming time_series is a list of forecast/observation pairs
    rel_calculator.add_pair(forecast, observation)

rel_ds_value = rel_calculator.compute_rel_ds()
print(rel_ds_value)



0.2671843712870085


In [113]:
ds = compute_ds(o,f)
print(ds)

0.12625846495741266


In [114]:
print(UNC - res_ds_value + rel_ds_value)

0.11984037289888289


In [115]:
# log2(e) = 1.442695

In [116]:
raise Exception 

Exception: 