In [13]:
import numpy.linalg as LA
from scipy.optimize import minimize
from math import inf
from sklearn.preprocessing import normalize
from pyclustering.cluster.kmedians import kmedians
from itertools import permutations

from tqdm import tqdm
from Synthetic import *


def count_transition(x, a, y, trajectories):
    cnt = 0
    H = len(trajectories[0])
    for trajectory in trajectories:
        cnt += sum(1 for i in range(H-2) if i % 3 == 0 and trajectory[i:i + 3] == [x, a, y])
    return cnt


def count_visitation(x, a, trajectories):
    cnt = 0
    H = len(trajectories[0])
    for trajectory in trajectories:
        cnt += sum(1 for i in range(H-1) if i % 2 == 0 and trajectory[i:i + 2] == [x, a])
    return cnt


def count_transition_latent(s, a, v, trajectories, f):
    cnt = 0
    H = len(trajectories[0])
    for trajectory in trajectories:
        cnt += sum(1 for i in range(H-2) if i % 2 == 0 and f[trajectory[i]] == s and trajectory[i + 1] == a and f[
            trajectory[i + 2]] == v)
    return cnt


def count_transition_mixed1(x, a, s, trajectories, f):
    cnt = 0
    H = len(trajectories[0])
    for trajectory in trajectories:
        cnt += sum(1 for i in range(H-2) if i % 2 == 0 and trajectory[i] == x and trajectory[i + 1] == a and f[
            trajectory[i + 2]] == s)
    return cnt


def count_transition_mixed2(s, a, x, trajectories, f):
    cnt = 0
    H = len(trajectories[0])
    for trajectory in trajectories:
        cnt += sum(1 for i in range(H-2) if
                   i % 2 == 0 and f[trajectory[i]] == s and trajectory[i + 1] == a and trajectory[i + 2] == x)
    return cnt


def low_rank(N, r=1):
    U, S, V = LA.svd(N, full_matrices=False)
    Nr = np.zeros((len(U), len(V)))
    for i in range(r):
        Nr += S[i] * np.outer(U.T[i], V[i])
    return Nr


def count_error(f, f_1, perm, n):
    cnt = 0
    for x in range(n):
        if f[x] != perm[f_1[x]]:
            cnt += 1
    return cnt


def error_rate(f, f_1, n, S):
    error = n
    for perm_ in permutations(range(S)):
        perm = {}
        for s in range(S):
            perm[s] = perm_[s]
        error = min(error, count_error(f, f_1, perm, n))

    return error / n


def init_spectral(env, trajectories):
    n, S, A, H = env.n, env.S, env.A, env.H
    T = len(trajectories)

    # Collect trimmed, low-rank approx, empirical transition matrices
    transition_matrices = []
    for a in range(A):
        # Collect empirical transition matrices
        transition_matrix_a = np.zeros([n, n])
        visitations_a = np.zeros([n])
        for x in tqdm(range(n)):
            visitations_a[x] = count_visitation(x, a, trajectories)
            for y in range(n):
                transition_matrix_a[x, y] = count_transition(x, a, y, trajectories)
        # Trimming!
        contexts_ordered = np.argsort(visitations_a)
        ratio = (T*H)/(n*A)
        num_trimmed = int(np.floor(n * np.exp(- ratio * np.log(ratio))))
        if num_trimmed > 0:
            contexts_trimmed = contexts_ordered[-num_trimmed:]
            for x, y in zip(contexts_trimmed, contexts_trimmed):
                transition_matrix_a[x][y] = 0

        # Low-rank approximation
        transition_matrices.append(low_rank(transition_matrix_a, r=S))

    M_in = np.concatenate(tuple(transition_matrices), axis=1)
    M_out = np.concatenate(tuple(transition_matrices), axis=0).T
    M = np.concatenate((M_in, M_out), axis=1)

    # l1-normalize rows
    row_sums = M.sum(axis=1)
    row_sums[row_sums == 0] = 1
    M = M / row_sums[:, np.newaxis]
    # M = normalize(M, norm='l1', axis=1)

    # S-median clustering to the rows
    initial_medians = M[:S, :]
    # initial_medians = np.random.randn(S, 2*n*A)
    kmedians_instance = kmedians(M, initial_medians)
    kmedians_instance.process()
    clusters = kmedians_instance.get_clusters()

    f_1 = {}
    for x in range(n):
        for s in range(S):
            if x in clusters[s]:
                f_1[x] = s

    return f_1


def likelihood_improvement(env, trajectories, f_1):
    # likelihood_improvement
    n, S, A, H = env.n, env.S, env.A, env.H
    T = len(trajectories)

    f_final = f_1
    num_iter = int(np.floor(np.log(n * A)))

    for _ in tqdm(range(num_iter)):
        # estimated latent transition matrices
        Ns = [np.zeros((S, S)) for _ in range(A)]
        for a in range(A):
            for s in range(S):
                for k in range(S):
                    Ns[a][s][k] = count_transition_latent(s, a, k, trajectories, f_final)

        # likelihood improvement
        f_ = {}
        for x in range(n):
            likelihoods = []
            for j in range(S):
                # number of visitations to j
                N2 = 0
                for a in range(A):
                    tmp = np.sum(Ns[a], axis=0)
                    N2 += tmp[j]
                for a in range(A):
                    likelihood = 0
                    # number of visitations from (j, a)
                    tmp = np.sum(Ns[a], axis=1)
                    N1 = tmp[j]
                    if N1 == 0 or N2 == 0:
                        print(f"N1={N1}, N2={N2}")
                        likelihood = -inf
                    else:
                        for s in range(S):
                            # estimate of p and p_bwd
                            p_estimated = Ns[a][j][s] / N1  # ((j, a) -> s) / ((j, a) -> X)
                            p_bwd_estimated = Ns[a][s][j] / N2  # (j <- (s, a)) / (j <- X)
                            if p_estimated == 0 or p_bwd_estimated == 0:
                                print(f"p_estimated={p_estimated}, p_bwd_estimated={p_bwd_estimated}")
                                likelihood = -inf
                                continue
                            # number of visitations (x, a) -> s
                            N3 = count_transition_mixed1(x, a, s, trajectories, f_final)
                            # number of visitations (s, a) -> x
                            N4 = count_transition_mixed2(s, a, x, trajectories, f_final)

                            # compute likelihood
                            likelihood += (N3 * np.log(p_estimated)) + (N4 * np.log(p_bwd_estimated))
                likelihoods.append(likelihood)

            # new cluster
            f_[x] = np.argmax(likelihoods)
        f_final = f_
    return f_final

In [10]:
T = 50
env = Synthetic()
# true clusters
f = {}
for s in range(env.S):
    cluster = env.partitions[s]
    for x in range(cluster.start, cluster.start + cluster.n):
        f[x] = s
# obtain trajectories
trajectories = generate_trajectories(T, env)

# initial spectral clustering
f_1 = init_spectral(env, trajectories)
init_err_rate = error_rate(f, f_1, env.n, env.S)
print("Error rate after initial clustering is ", init_err_rate)

100%|██████████| 200/200 [00:50<00:00,  3.97it/s]
100%|██████████| 200/200 [00:50<00:00,  3.98it/s]


Error rate after initial clustering is  0.49


In [14]:
# likelihood_improvement
f_final = likelihood_improvement(env, trajectories, f_1)
final_err_rate = error_rate(f, f_1, env.n, env.S)
print("Final error rate is ", final_err_rate)

 20%|██        | 1/5 [00:04<00:19,  4.86s/it]

N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0

 40%|████      | 2/5 [00:06<00:08,  2.97s/it]

N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0

 60%|██████    | 3/5 [00:08<00:04,  2.37s/it]

p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0

 80%|████████  | 4/5 [00:09<00:02,  2.07s/it]

p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0

100%|██████████| 5/5 [00:11<00:00,  2.29s/it]

p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
p_estimated=0.0, p_bwd_estimated=0.0
p_estimated=0.0, p_bwd_estimated=0.0
N1=0.0, N2=0.0
N1=0.0, N2=0.0
Final error rate is  0.49





In [4]:
f_1

{0: 0,
 1: 0,
 2: 0,
 3: 0,
 4: 0,
 5: 0,
 6: 0,
 7: 0,
 8: 0,
 9: 0,
 10: 0,
 11: 0,
 12: 0,
 13: 0,
 14: 0,
 15: 0,
 16: 0,
 17: 0,
 18: 0,
 19: 0,
 20: 0,
 21: 0,
 22: 0,
 23: 0,
 24: 0,
 25: 0,
 26: 0,
 27: 0,
 28: 0,
 29: 0,
 30: 0,
 31: 0,
 32: 0,
 33: 0,
 34: 0,
 35: 0,
 36: 0,
 37: 0,
 38: 0,
 39: 0,
 40: 0,
 41: 0,
 42: 0,
 43: 0,
 44: 0,
 45: 0,
 46: 0,
 47: 0,
 48: 0,
 49: 0,
 50: 0,
 51: 0,
 52: 0,
 53: 0,
 54: 0,
 55: 0,
 56: 0,
 57: 0,
 58: 0,
 59: 0,
 60: 0,
 61: 0,
 62: 0,
 63: 0,
 64: 0,
 65: 0,
 66: 0,
 67: 0,
 68: 0,
 69: 0,
 70: 0,
 71: 0,
 72: 0,
 73: 0,
 74: 0,
 75: 0,
 76: 0,
 77: 0,
 78: 0,
 79: 0,
 80: 0,
 81: 0,
 82: 0,
 83: 0,
 84: 0,
 85: 0,
 86: 0,
 87: 0,
 88: 0,
 89: 0,
 90: 0,
 91: 0,
 92: 0,
 93: 0,
 94: 0,
 95: 0,
 96: 0,
 97: 0,
 98: 0,
 99: 0,
 100: 1,
 101: 1,
 102: 1,
 103: 1,
 104: 1,
 105: 1,
 106: 1,
 107: 1,
 108: 1,
 109: 1,
 110: 1,
 111: 1,
 112: 1,
 113: 1,
 114: 1,
 115: 1,
 116: 1,
 117: 1,
 118: 1,
 119: 1,
 120: 1,
 121: 1,
 122: 1,
 12