# In [None]:
import warnings
from collections.abc import Iterable

import numpy as np

# Alternatives kept for reference (silence the warnings instead of raising):
# np.seterr(divide='ignore')
# warnings.filterwarnings('ignore')
# Escalate every warning to an exception, process-wide. With numpy's default
# floating-point error handling ('warn'), this turns RuntimeWarnings such as
# log2(0) or division by zero into raised errors instead of silent inf/nan.
# NOTE(review): this also affects any other code running in the same process.
warnings.simplefilter('error')

class DKL:
    """Divergence (Kullback-Leibler) score for binary probability forecasts.

    Provides the mean divergence score DKL(o || f) in bits, its decomposition
    DS = REL - DSC + UNC (reliability, discrimination/resolution, uncertainty),
    and Brier-score helpers.

    Parameters
    ----------
    o : array_like
        Observations. The uncertainty term is the binary entropy of their
        mean, so they are expected to be binary (0/1) outcomes.
    f : array_like
        Forecast probabilities, required to lie strictly inside (0, 1).
    """

    def __init__(self, o, f):
        # Convert once so list inputs support boolean masking and arithmetic.
        self.o = np.asarray(o)
        self.f = np.asarray(f)

        assert np.max(self.f) < 1.0, "forecast probabilities must be < 1"
        assert np.min(self.f) > 0.0, "forecast probabilities must be > 0"

        # Unique forecast probabilities (the forecast "bins")
        self.k = np.unique(self.f)

        # Unique observation classes
        self.c = np.unique(self.o)

        # Observed base rate (frequency of obs)
        self.o_bar = np.mean(self.o)

    @classmethod
    def compute_info_gain(cls, o, f1, f2, from_components=False):
        """Return DKL(o || f1) - DKL(o || f2), the bits gained by using f2 over f1.

        Positive when ``f2`` has the lower (better) mean divergence score.
        ``f1``/``f2`` are not range-checked, so probabilities of exactly 0 or 1
        are allowed. ``from_components`` is accepted but currently ignored.
        """
        dkl1 = cls.__compute_dkl(o, f1)[1]
        dkl2 = cls.__compute_dkl(o, f2)[1]
        return dkl1 - dkl2

    @staticmethod
    def __compute_dkl(T, P):
        """Element-wise binary KL divergence DKL(t || p) in bits.

        Generalised so ``T`` is the "truth" and ``P`` the "prediction", both
        allowed anywhere in [0, 1] (no 0 < f < 1 assumption). Scalars and
        sequences are broadcast against each other.

        Conventions: 0 * log(0 / q) = 0, and w * log(w / 0) = +inf for w > 0,
        so a prediction of exactly 0 or 1 contradicted by the truth has
        infinite divergence.

        Returns
        -------
        (ndarray, float)
            The per-element divergences and their mean.

        Raises
        ------
        ValueError
            If ``T`` and ``P`` cannot be broadcast together.
        """
        def term(w, q):
            # One summand w * log2(w / q) with the limiting conventions above;
            # the guards also avoid log2(0) / division-by-zero warnings.
            if w == 0:
                return 0.0
            if q == 0:
                return np.inf
            return w * np.log2(w / q)

        t_arr, p_arr = np.broadcast_arrays(
            np.atleast_1d(np.asarray(T, dtype=float)),
            np.atleast_1d(np.asarray(P, dtype=float)),
        )

        all_dkl = [
            term(1.0 - t, 1.0 - p) + term(t, p)
            for t, p in zip(t_arr.ravel(), p_arr.ravel())
        ]
        return np.array(all_dkl), np.mean(all_dkl)

    def compute_dkl(self, from_components=False):
        """Return the mean divergence score DKL(o || f) in bits.

        If ``from_components`` is true, rebuild it from the decomposition
        REL - DSC + UNC instead of computing it directly.
        """
        if from_components:
            U = self.compute_unc()
            R = self.compute_rel()
            D = self.compute_dsc()
            return R - D + U
        _, raw_dkl = self.__compute_dkl(self.o, self.f)
        # nan -> 0 and +/-inf -> largest finite float
        return np.nan_to_num(raw_dkl)

    def compute_dsc(self):
        """Return the discrimination (resolution) term in bits.

        DSC = (1/N) * sum_k n_k * DKL(o_bar_k || o_bar), where bin k holds the
        forecasts equal to the k-th unique probability, n_k is its size and
        o_bar_k the observed frequency within it. Prints per-bin diagnostics.
        """
        N = len(self.o)
        K = len(self.k)
        ok_bar_1d = np.zeros([K])
        dsc_1d = np.zeros([K])
        nk_1d = np.zeros([K])
        dkl_1d = np.zeros([K])
        for ik, k in enumerate(self.k):
            in_bin = self.f == k
            ok_bar_1d[ik] = np.mean(self.o[in_bin])
            nk_1d[ik] = np.count_nonzero(in_bin)
            _, dkl_1d[ik] = self.__compute_dkl(ok_bar_1d[ik], self.o_bar)
            dsc_1d[ik] = nk_1d[ik] * dkl_1d[ik]
        print(f"{dsc_1d=}, {ok_bar_1d=}, {nk_1d=}, {dkl_1d=}")
        return np.sum(dsc_1d) / N

    def compute_unc(self):
        """Return the uncertainty term: binary entropy of the base rate, in bits.

        Uses 0 * log2(0) = 0, so a record in which every observation is the
        same class has zero uncertainty instead of producing nan (or raising
        when warnings are escalated to errors).
        """
        p = self.o_bar
        term1 = (1 - p) * np.log2(1 - p) if p < 1 else 0.0
        term2 = p * np.log2(p) if p > 0 else 0.0
        return -(term1 + term2)

    def compute_rel(self):
        """Return the reliability term in bits.

        REL = (1/N) * sum_k n_k * DKL(o_bar_k || f_k), where bin k holds the
        forecasts equal to the k-th unique probability f_k, n_k is its size
        and o_bar_k the observed frequency within it. Prints per-bin
        diagnostics.
        """
        N = len(self.o)
        K = len(self.k)
        rel_1d = np.zeros([K])
        ok_bar_1d = np.zeros([K])
        fk_list = []
        nk_1d = np.zeros([K])
        dkl_1d = np.zeros([K])
        for ik, k in enumerate(self.k):
            in_bin = self.f == k
            ok_bar_1d[ik] = np.mean(self.o[in_bin])
            fk_list.append(self.f[in_bin])
            nk_1d[ik] = len(fk_list[ik])
            # Every forecast in the bin equals k, so one evaluation suffices.
            _, dkl_1d[ik] = self.__compute_dkl(ok_bar_1d[ik], k)
            rel_1d[ik] = nk_1d[ik] * dkl_1d[ik]
        print(f"{rel_1d=}, {ok_bar_1d=}, {fk_list=}, {nk_1d=}, {dkl_1d=}")
        return np.sum(rel_1d) / N

    def compute_bs(self):
        """Return the Brier score: mean squared difference between o and f."""
        return np.mean((self.o - self.f) ** 2)

    def compute_bss(self):
        """Return the Brier skill score relative to forecasting the base rate.

        BSS = 1 - BS / (o_bar * (1 - o_bar)). The reference score is 0 (so the
        division is undefined) when every observation is the same class.
        """
        bs = self.compute_bs()
        bs_unc = self.o_bar * (1 - self.o_bar)
        return 1 - (bs / bs_unc)

    def compute_skill_score(self, return_components=True):
        """Return the divergence skill score (DSC - REL) / UNC.

        With ``return_components`` (default) returns ``(SS, DSC/UNC, REL/UNC)``;
        otherwise only ``SS``. Undefined when UNC is 0 (constant observations).
        """
        U = self.compute_unc()
        R = self.compute_rel()
        D = self.compute_dsc()
        D_ss = D / U
        R_ss = R / U
        SS = D_ss - R_ss
        if return_components:
            return SS, D_ss, R_ss
        return SS

    def compute_info_gain_over_climo(self):
        """Return DSC - REL: bits gained over always forecasting the base rate."""
        R = self.compute_rel()
        D = self.compute_dsc()
        return D - R