In [1]:
import math
import numpy as np
np.set_printoptions(threshold=np.inf)
import pandas as pd
import matplotlib.pyplot as plt

from scipy.stats import median_abs_deviation as mad

In [2]:
class WBS:
    """Change-point detection in the mean of a 1-D series via CUSUM scans over many sub-intervals.

    NOTE(review): the name and structure suggest this is a port of the
    "Wild Binary Segmentation" procedure - confirm against the reference used.

    Typical use::

        model = WBS(x)
        model.wbs()           # fills ``wbs_res`` and ``results``
        model.changepoint()   # fills ``changepoints`` (and ``changepoints_th``)

    Attributes set by the methods:
        x, n            input series (float ndarray) and its length.
        M               number of scanned intervals (recomputed for the fixed grid).
        intervals       DataFrame with 1-based, inclusive columns ``s`` and ``e``.
        wbs_res         per-interval best split: ``s, e, cpt, CUSUM, abs.CUSUM``.
        results         one row per detected split: ``s, e, cpt, CUSUM, min.th, scale``.
                        ``cpt`` is 1-based (last observation before the change).
        changepoints    0-based change-points picked by the information criterion,
                        or ``None`` when the criterion prefers no change-point.
        changepoints_th sorted 0-based change-points whose ``min.th`` exceeds the threshold.
        sigma, threshold  noise-scale estimate and threshold used by ``changepoint``.
    """

    def __init__(self, x: np.ndarray, M: int = 5000, rand_interval: bool = True) -> None:
        """Validate the input and draw the intervals to scan.

        Args:
            x: 1-D numeric series (array, list or pandas Series); must have at
                least two elements, no NaN and non-zero variance.
            M: requested number of intervals; must be positive.
            rand_interval: draw intervals at random (True) or use a fixed grid (False).

        Raises:
            AssertionError: if any of the checks above fails.
        """
        # Coerce once so slicing / arithmetic below behave the same for lists,
        # Series and ndarrays.
        self.x = np.asarray(x, dtype=float)
        self.M = M
        self.rand_interval = rand_interval

        self.n = len(self.x)

        assert self.n > 1, "x must have at least 2 elements"
        assert not np.isnan(self.x).any(), "x must not contain NaN values"
        assert np.var(self.x) != 0, "x must not be constant"

        assert not np.isnan(M), "M must not be NaN"
        assert M > 0, "M must be positive"

        self.intervals = self.intervals_init()

        self.results = []

    def intervals_init(self) -> pd.DataFrame:
        """Build the table of intervals to scan and store it in ``self.intervals``.

        With ``rand_interval`` the bounds are drawn from the global NumPy RNG
        (seed it with ``np.random.seed`` for reproducibility). Otherwise all
        pairs of ``m`` roughly equally spaced end points are used and
        ``self.M`` is reset to the resulting ``m * (m - 1) / 2``.

        Returns:
            DataFrame with integer columns ``s`` and ``e`` (1-based, inclusive).
        """
        if self.rand_interval:
            # NOTE(review): for n > 2 the start is drawn from 2..n-1, so an
            # interval never begins at the first observation - confirm intended.
            starts = np.ceil(np.random.uniform(size=self.M) * (self.n - 2)).astype(int) + 1
            ends = starts + 1 + np.ceil(
                np.random.uniform(size=self.M) * (self.n - 1 - starts)
            ).astype(int)
            bounds = np.column_stack((starts, ends))
        else:
            # Smallest m with m * (m - 1) / 2 >= M, capped at the series length.
            m = math.ceil(0.5 * (math.sqrt(8 * self.M + 1) + 1))
            m = min(self.n, m)
            self.M = int(m * (m - 1) / 2)
            end_points = np.round(
                np.concatenate(([1], np.linspace(2, self.n - 1, m - 2), [self.n]))
            ).astype(int)
            # Every pair (i < j) of end points, in row-major order.
            left, right = np.triu_indices(m, k=1)
            bounds = np.column_stack((end_points[left], end_points[right]))

        self.intervals = pd.DataFrame(bounds, columns=['s', 'e'], dtype=int)
        return self.intervals

    @staticmethod
    def ipi_arg_max(res: list, n: int) -> tuple:
        """Locate the entry of ``res[0:n-1]`` with the largest absolute value.

        When several entries tie for the maximum, a middle one is selected.
        The tie scan advances the position before testing it, so a tie at
        position 0 is not counted during that scan.

        Returns:
            ``(position, res[position])`` with a 0-based position.
        """
        max_count = 0
        max_fabs = -1
        ipargmax = 0
        for i in range(n - 1):
            abs_res = abs(res[i])
            if abs_res > max_fabs:
                ipargmax = i
                max_fabs = abs_res
                max_count = 1
            elif abs_res == max_fabs:
                max_count += 1

        if max_count > 1:
            # Walk to the ceil(count / 2)-th tied entry.
            max_count = max_count // 2 + (max_count % 2)
            k = 0
            i = 0
            while (i < (n - 1)) and (k < max_count):
                i += 1
                if abs(res[i]) == max_fabs:
                    k += 1
            ipargmax = i

        return ipargmax, res[ipargmax]

    @staticmethod
    def wbs_ipi(x: np.ndarray, n: int) -> tuple:
        """Compute the CUSUM contrast at every split of ``x[0:n]`` and return the best one.

        The contrast for a split after position ``i`` is the weighted left sum
        minus the weighted right sum; both are updated recursively in O(n).

        Returns:
            ``(position, value)`` as produced by ``ipi_arg_max``.
        """
        one_over_n = 1.0 / n
        n_squared = n * n

        iminus = 1.0 / math.sqrt(n_squared - n) * sum(x[1:])
        iplus = math.sqrt(1.0 - one_over_n) * x[0]
        res = [iplus - iminus]

        for i in range(1, n - 1):
            iplusone_inv = 1.0 / (i + 1.0)
            factor = math.sqrt((n - i - 1.0) * i * iplusone_inv / (n - i))
            iplus = iplus * factor + x[i] * math.sqrt(iplusone_inv - one_over_n)
            iminus = iminus / factor - x[i] / math.sqrt(n_squared * iplusone_inv - n)
            res.append(iplus - iminus)

        return WBS.ipi_arg_max(res, n)

    def bs_rec(self, x: np.ndarray, s: int, e: int, minth: float = -1., scale: int = 0) -> None:
        """Recursively split ``x[s-1:e]`` at its best CUSUM point (no random intervals).

        Appends one ``[s, e, cpt, CUSUM, min.th, scale]`` row to ``self.results``
        per split; ``minth`` carries the smallest |CUSUM| seen on the path so
        far (a negative value means "not set yet").
        """
        n = e - s + 1
        if n <= 1:
            return

        ipargmax, ipmax = WBS.wbs_ipi(x[s - 1:e], n)
        cptcand = ipargmax + s

        if minth > abs(ipmax) or minth < 0:
            minth = abs(ipmax)

        self.results.append([s, e, cptcand, ipmax, minth, scale])

        self.bs_rec(x, s, cptcand, minth, scale + 1)
        self.bs_rec(x, cptcand + 1, e, minth, scale + 1)

    def wbs_rec(self, s: int, e: int, index: list, indexn: int, minth: float = -1, scale: int = 1) -> None:
        """Recursively split ``[s, e]`` using the pre-scanned intervals in ``self.wbs_res``.

        Args:
            s, e: 1-based inclusive bounds of the current segment.
            index: labels of ``self.wbs_res`` rows lying inside ``[s, e]``,
                ordered by decreasing ``abs.CUSUM`` (so ``index[0]`` is the best).
            indexn: number of leading entries of ``index`` to consider.
            minth: smallest |CUSUM| on the path so far; negative means unset.
            scale: recursion depth recorded with each split.
        """
        n = e - s + 1
        if n <= 1:
            return

        if indexn <= 0:
            # No pre-scanned interval fits inside this segment.
            self.bs_rec(self.x, s, e, minth, scale)
            return

        ipargmax, ipmax = WBS.wbs_ipi(self.x[s - 1:e], n)

        best = index[0]
        best_abs = self.wbs_res.at[best, 'abs.CUSUM']

        if np.abs(ipmax) < best_abs:
            # A pre-scanned interval beats the whole segment: report that
            # interval's own bounds and CUSUM so the row is self-consistent
            # with the change-point and min.th taken from it.
            cptcand = self.wbs_res.at[best, 'cpt']
            if minth > best_abs or minth < 0:
                minth = best_abs
            self.results.append([
                self.wbs_res.at[best, 's'],
                self.wbs_res.at[best, 'e'],
                cptcand,
                self.wbs_res.at[best, 'CUSUM'],
                minth,
                scale,
            ])
        else:
            cptcand = ipargmax + s
            if minth > np.abs(ipmax) or minth < 0:
                minth = np.abs(ipmax)
            self.results.append([s, e, cptcand, ipmax, minth, scale])

        # Hand each child only the intervals fully contained in it; the
        # relative (descending abs.CUSUM) order of the labels is preserved.
        candidates = index[:indexn]
        starts = self.wbs_res.loc[candidates, 's'].to_numpy()
        ends = self.wbs_res.loc[candidates, 'e'].to_numpy()
        in_left = (starts >= s) & (ends <= cptcand)
        in_right = (starts >= cptcand + 1) & (ends <= e)
        indexnl = [label for label, keep in zip(candidates, in_left) if keep]
        indexnr = [label for label, keep in zip(candidates, in_right) if keep]

        if len(indexnl) > 0:
            self.wbs_rec(s, cptcand, indexnl, len(indexnl), minth, scale + 1)
        else:
            self.bs_rec(self.x, s, cptcand, minth, scale + 1)

        if len(indexnr) > 0:
            self.wbs_rec(cptcand + 1, e, indexnr, len(indexnr), minth, scale + 1)
        else:
            self.bs_rec(self.x, cptcand + 1, e, minth, scale + 1)

    def wbs(self) -> None:
        """Scan every interval, then run the recursive segmentation on the whole series.

        Sets ``self.wbs_res`` (one row per interval) and ``self.results``
        (one row per split, sorted by ``cpt``). Safe to call repeatedly.
        """
        # Start from a clean list: a previous run leaves a DataFrame here.
        self.results = []

        scanned = []
        for s, e in self.intervals[['s', 'e']].to_numpy():
            ipargmax, ipmax = WBS.wbs_ipi(self.x[s - 1:e], e - s + 1)
            scanned.append([s, e, ipargmax + s, ipmax, abs(ipmax)])

        self.wbs_res = pd.DataFrame(scanned, columns=['s', 'e', 'cpt', 'CUSUM', 'abs.CUSUM'])
        largest_cusum_index = self.wbs_res.sort_values(by='abs.CUSUM', ascending=False).index.to_list()

        self.wbs_rec(1, self.n, largest_cusum_index, len(largest_cusum_index))

        self.results = pd.DataFrame(
            self.results, columns=['s', 'e', 'cpt', 'CUSUM', 'min.th', 'scale']
        ).sort_values(by=['cpt'])

    @staticmethod
    def means_between_changepoints(x: np.ndarray, changepoints: list) -> np.ndarray:
        """Return the piecewise-constant fit of ``x`` given 0-based change-points.

        Each change-point is the last index of its segment. The result has the
        same length as ``x`` and holds the mean of the segment each point is in.
        """
        cps = np.array(sorted(changepoints), dtype=int)
        seg_start = np.concatenate(([0], cps + 1)).astype(int)
        seg_end = np.concatenate((cps, [len(x) - 1])).astype(int)

        means = np.array([np.mean(x[a:b + 1]) for a, b in zip(seg_start, seg_end)])

        return np.repeat(means, seg_end - seg_start + 1)

    @staticmethod
    def ssic_penalty(n: int, cpt: list, alpha: float, ssic_type: str) -> float:
        """Penalty term ``len(cpt) * pen`` with ``pen = log(n)**alpha`` or ``n**alpha``.

        Raises:
            ValueError: if ``ssic_type`` is neither ``"log"`` nor ``"power"``.
        """
        if ssic_type == "log":
            pen = np.log(n) ** alpha
        elif ssic_type == "power":
            pen = n ** alpha
        else:
            raise ValueError(f"ssic_type must be 'log' or 'power', got {ssic_type!r}")

        return len(cpt) * pen

    def changepoint(self, threshold: float = None, threshold_const: float = 1.3, Kmax: int = 50, alpha: float = 1.01, ssic_type: str = "log") -> None:
        """Select change-points from ``self.results`` (requires a prior ``wbs()`` call).

        Two selections are stored, both as 0-based indices:

        * ``self.changepoints`` - the first ``k`` candidates (ranked by
          decreasing ``min.th``, at most ``Kmax``) minimising
          ``n/2 * log(RSS/n) + penalty``; ``None`` when ``k == 0`` wins.
        * ``self.changepoints_th`` - sorted candidates whose ``min.th`` exceeds
          the threshold (``threshold`` if given, otherwise
          ``sigma * threshold_const * sqrt(2 * log(n))``).

        ``self.results`` itself is left untouched, so the call is repeatable.
        NOTE: if a candidate set fits the data exactly (RSS == 0) the criterion
        evaluates to ``-inf`` for that set.

        Raises:
            RuntimeError: if ``wbs()`` has not been run yet.
        """
        if not isinstance(self.results, pd.DataFrame):
            raise RuntimeError("wbs() must be called before changepoint()")

        # MAD of the first differences, rescaled to estimate the noise std.
        diffs = np.diff(self.x)
        sigma = np.median(np.abs(diffs - np.median(diffs))) * 1.4826 / np.sqrt(2)

        if threshold is not None:
            th = threshold
        else:
            th = sigma * threshold_const * np.sqrt(2 * np.log(self.n))

        self.sigma = sigma
        self.threshold = th

        # Rank on a copy; results stays sorted by 'cpt' as wbs() left it.
        ranked = self.results.sort_values(by=['min.th'], ascending=False)

        self.changepoints_th = sorted(
            int(c) - 1 for c in ranked.loc[ranked['min.th'] > th, 'cpt']
        )

        changepoints = [c - 1 for c in ranked['cpt'].tolist()[0:Kmax]]

        ic_curve = np.zeros(len(changepoints) + 1)
        for k in range(len(changepoints), -1, -1):
            fitted = WBS.means_between_changepoints(self.x, changepoints[:k])
            min_log_likelihood = self.n / 2 * np.log(np.sum((self.x - fitted) ** 2) / self.n)
            ic_curve[k] = min_log_likelihood + WBS.ssic_penalty(
                self.n, changepoints[:k], alpha=alpha, ssic_type=ssic_type
            )

        best_k = np.argmin(ic_curve)
        self.changepoints = None if best_k == 0 else changepoints[:best_k]



In [3]:
# Synthetic series: four blocks of 50 unit-variance normal draws whose mean
# alternates 0, 1, 0, 1. Seeded so the notebook reproduces on a fresh kernel.
np.random.seed(42)
x = np.concatenate([np.random.normal(block_mean, 1, 50) for block_mean in (0, 1, 0, 1)])

In [4]:
# Fit on the synthetic series and select change-points with the
# information criterion (default settings).
wbs = WBS(x)
wbs.wbs()
wbs.changepoint()

# Report the selected change-points as 0-based indices into x.
print("IC changepoints:", wbs.changepoints)

IC changepoints: [49, 149, 99]
