In [1]:
import numpy as np
from numba import njit, prange
import matplotlib.pyplot as plt

pi = np.pi
np.random.seed(123)

In [2]:
signal = np.concatenate([np.random.normal(m, 0, 2) for m in [1, 3]])
signal = signal.reshape((-1, 1))

In [3]:
Theta = np.linspace(1, 4, 3, endpoint = False).reshape(-1, 1)

In [4]:
def apart(signal, Theta, penalty):
    n_samples = signal.shape[0]
    n_states = Theta.shape[0]
    costs = np.empty((n_samples, n_states), dtype=np.float64)
    for k_state in range(n_states):
        for k_sample in range(n_samples):
            diff = np.abs(signal[k_sample] - Theta[k_state])
            costs[k_sample, k_state] = np.sum(np.fmin(diff, 2 * np.pi - diff))

    n_samples, n_states = costs.shape
    soc_array = np.empty((n_samples + 1, n_states), dtype=np.float64)
    state_array = np.empty((n_samples + 1, n_states), dtype=np.int32)
    soc_array[0] = 0
    state_array[0] = -1

    # Forward loop
    for t in range(1, n_samples + 1):
        for k_state in range(n_states):
            best_state = k_state
            best_soc = soc_array[t - 1][best_state]
            for k in range(n_states):
                if k != k_state:
                    soc = soc_array[t - 1][k]
                    if soc + penalty < best_soc:
                        best_state = k
                        best_soc = soc + penalty

            soc_array[t, k_state] = best_soc + costs[t - 1, k_state]
            state_array[t, k_state] = best_state

    # Backtracking
    end = n_samples
    state = np.argmin(soc_array[end])
    states = np.empty(n_samples, dtype=np.int32)
    while (state > -1) and (end > 0):
        states[end - 1] = state
        state = state_array[end, state]
        end -= 1
    return states, soc_array, state_array

In [5]:
penalty = 0.1
states, soc_array, state_array = apart(signal, Theta, penalty)
print(states)
print(soc_array)
print(state_array)

[0 0 2 2]
[[0.  0.  0. ]
 [0.  1.  2. ]
 [0.  1.1 2.1]
 [2.  1.1 0.1]
 [2.2 1.2 0.1]]
[[-1 -1 -1]
 [ 0  1  2]
 [ 0  0  0]
 [ 0  0  0]
 [ 2  2  2]]


In [6]:
# Version 1
def geo_d(x: np.ndarray, y: np.ndarray) -> float:
    diff = np.abs(x - y)
    return np.sum(np.fmin(diff, 2 * np.pi - diff))

def apart(y, Theta, penalty):
    T, D = y.shape
    M, _ = Theta.shape

    V = np.zeros((T + 1, M))
    s = np.empty((T + 1, M), dtype=np.int32)

    for t in range(1, T + 1):
        for k in range(M):
            V_candidates = np.zeros(M)
            for j in range(M):
                V_candidates[j] = V[t - 1][j] + penalty * (not np.array_equal(Theta[j], Theta[k])) + geo_d(y[t - 1], Theta[k])
            V[t][k] = min(V_candidates)
            s[t][k] = np.argmin(V_candidates)

    # Backtracking
    end = T
    state = np.argmin(V[end])
    states = np.empty(T, dtype=np.int32)
    while (state > -1) and (end > 0):
        states[end - 1] = state
        state = s[end, state]
        end -= 1
    return states, V, s



penalty = 0.1
states, soc_array, state_array = apart(signal, Theta, penalty)
print(states)
print(soc_array)
print(state_array)

[0 0 2 2]
[[0.  0.  0. ]
 [0.  1.  2. ]
 [0.  1.1 2.1]
 [2.  1.1 0.1]
 [2.2 1.2 0.1]]
[[        11          0 -698863171]
 [         0          1          2]
 [         0          0          0]
 [         0          0          0]
 [         2          2          2]]


In [7]:
# Version 2
def geo_d(theta, psi):
    diff = np.abs(psi - theta)
    return np.sum(np.square(np.minimum(diff, 2*pi - diff)))

def apart(y, Theta, lda):
    T, D = y.shape
    M, _ = Theta.shape

    V = np.zeros((T + 1, M))
    s = -1*np.ones((T + 1, M), dtype=np.int32)
    for t in range(1, T + 1):
        for k in range(M):
            V_candidates = V[t-1] + lda * np.any(Theta[k] != Theta, axis=1) + geo_d(Theta[k], y[t-1])
            best_idx = np.argmin(V_candidates)
            V[t][k] = V_candidates[best_idx]
            s[t][k] = best_idx

    # Backtracking
    states = np.empty(T, dtype=np.int32)
    state = np.argmin(V[T])
    for t in reversed(range(T)):
        states[t] = state
        state = state_array[t + 1, state]
    
    return states, V, s



penalty = 0.1
states, soc_array, state_array = apart(signal, Theta, penalty)
print(states)
print(soc_array)
print(state_array)

[0 0 2 2]
[[0.  0.  0. ]
 [0.  1.  4. ]
 [0.  1.1 4.1]
 [4.  1.1 0.1]
 [4.2 1.2 0.1]]
[[-1 -1 -1]
 [ 0  1  2]
 [ 0  0  0]
 [ 0  0  0]
 [ 2  2  2]]


In [8]:
# verison 3

@njit
def geo_d(theta, psi):
    diff = np.abs(psi - theta)
    return np.sum(np.square(np.minimum(diff, 2*pi - diff)))

@njit(parallel=True)
def apart(y, Theta, lda):
    T = y.shape[0]
    M = Theta.shape[0]

    V = np.zeros((T + 1, M))
    s = -1 * np.ones((T + 1, M), dtype=np.int32)

    for t in range(1, T + 1):
        for k in prange(M):
            V_candidates = np.zeros(M)
            for m in prange(M):
                if np.array_equal(Theta[m], Theta[k]):
                    V_candidates[m] = V[t - 1][m] + geo_d(Theta[k], y[t - 1])
                else:
                    V_candidates[m] = V[t - 1][m] + geo_d(Theta[k], y[t - 1]) + lda

            best_idx = np.argmin(V_candidates)
            V[t][k] = V_candidates[best_idx]
            s[t][k] = best_idx

    # Backtracking
    states = np.empty(T, dtype=np.int32)
    state = np.argmin(V[T])
    for t in range(T - 1, -1, -1):
        states[t] = state
        state = s[t + 1, state]

    return states, V, s



penalty = 0.1
states, soc_array, state_array = apart(signal, Theta, penalty)
print(states)
print(soc_array)
print(state_array)

[0 0 2 2]
[[0.  0.  0. ]
 [0.  1.  4. ]
 [0.  1.1 4.1]
 [4.  1.1 0.1]
 [4.2 1.2 0.1]]
[[-1 -1 -1]
 [ 0  1  2]
 [ 0  0  0]
 [ 0  0  0]
 [ 2  2  2]]


In [9]:
states[1:]

array([0, 2, 2])

In [10]:
states[:-1] != states[1:]

array([False,  True, False])

In [11]:
# Version 2
def d2(theta, psi):
    diff = np.abs(psi - theta)
    return np.sum(np.square(np.minimum(diff, 2*pi - diff)))

def apart(y, Theta, lda):
    T = y.shape[0]
    M = Theta.shape[0]

    V = np.zeros((T + 1, M))
    s = -1 * np.ones((T + 1, M), dtype=np.int32)
    for t in range(1, T + 1):
        for k in range(M):
            V_candidates = V[t-1] + lda * np.any(Theta[k] != Theta, axis=1) + d2(Theta[k], y[t-1])
            best_idx = np.argmin(V_candidates)
            V[t][k] = V_candidates[best_idx]
            s[t][k] = best_idx

    # Backtracking
    states = np.zeros(T, dtype=np.int32)
    state = np.argmin(V[T])
    for t in reversed(range(T)):
        states[t] = state
        state = s[t + 1][state]
    
    chpnts = np.arange(len(y) - 1)[states[:-1] != states[1:]]
    return chpnts, Theta[states]



penalty = 0.1
chpnts, signal_mean = apart(signal, Theta, penalty)
print(chpnts)
print(signal_mean)

[1]
[[1.]
 [1.]
 [3.]
 [3.]]
