In [2]:
import numpy as np
import pandas as pd
import tqdm
from sklearn.preprocessing import OneHotEncoder

In [3]:
np.random.seed(63)

states = np.arange(4)
observations = np.arange(4)

pi_th = np.random.random(4)
pi_th /= pi_th.sum()

A_th = np.random.random((4, 4))
i_zero = [0, 0, 0, 1, 1, 1, 2, 2, 3, 3]
j_zero = [0, 2, 3, 0, 1, 3, 1, 2, 1, 2]
A_th[i_zero, j_zero] = 0
A_th /= A_th.sum(axis=1)[:, None]

B_th = np.random.random((4, 4))
B_th /= B_th.sum(axis=1)[:, None]

def generate(A, B, pi, T):
    x = []
    y = []
    for t in range(T):
        if t == 0:
            x.append(np.random.choice(states, p=pi))
        else:
            x.append(np.random.choice(states, p=A[x[-1]]))
        y.append(np.random.choice(observations, p=B[x[-1]]))
    x = np.array(x)
    y = np.array(y)
    return x, y

x1, y1 = generate(A_th, B_th, pi_th, 250)
x2, y2 = generate(A_th, B_th, pi_th, 150)

In [4]:
def loglikelihood(y, A, B, pi):
    T = len(y)
    n_s = A.shape[0]
    
    norm = np.empty(T)

    alpha = np.empty((T, n_s))
    alpha[0] = pi * B[:, y[0]]
    norm[0] = 1/alpha[0].sum()
    for t in range(T-1):
        alpha[t+1] = B[:, y[t+1]] * A.T.dot(alpha[t])
        norm[t+1] = 1/alpha[t+1].sum()
        alpha[t+1] *= norm[t+1]

    loglike = alpha[T-1].sum() - np.log(norm).sum()
    return loglike

In [5]:
def forward_backward(y, A, B, pi):
    T = len(y)
    n_s = A.shape[0]
    
    norm = np.empty(T)

    alpha = np.empty((T, n_s))
    alpha[0] = pi * B[:, y[0]]
    norm[0] = 1/alpha[0].sum()
    alpha[0] *= norm[0]
    for t in range(T-1):
        alpha[t+1] = B[:, y[t+1]] * A.T.dot(alpha[t])
        norm[t+1] = 1/alpha[t+1].sum()
        alpha[t+1] *= norm[t+1]

    beta = np.empty((T, n_s))
    beta[T-1] = norm[T-1] * 1
    beta[T-1] *= norm[T-1]
    for t in range(T-2, -1, -1):
        beta[t] = A.dot(beta[t+1] * B[:, y[t+1]])
        beta[t] *= norm[t]

    return alpha, beta, norm

In [6]:
def update(y, A, B, pi, alpha, beta, norm):
    T = len(y)
    n_s = A.shape[0]
    n_o = B.shape[1]
    
    new_pi = np.empty(n_s)
    new_A = np.empty((n_s, n_s))
    new_B = np.empty((n_s, n_o))
    
    for i in range(n_s):
        new_pi[i] = alpha[0, i] * beta[0, i] / alpha[0].dot(beta[0])
    
    for i in range(n_s):
        for j in range(n_s):
            new_A[i, j] = sum(
                [alpha[t, i] * A[i, j] * B[j, y[t+1]] * beta[t+1, j] for t in range(T-1)]
            ) / sum(
                [alpha[t, i] * beta[t, i] / norm[t] for t in range(T-1)]
            )
    for j in range(n_s):
        for k in range(n_o):
            new_B[j, k] = sum(
                [alpha[t, j] * beta[t, j] / norm[t] for t in range(T) if y[t] == k]
            ) / sum(
                [alpha[t, j] * beta[t, j] / norm[t] for t in range(T)]
            )
            
    return new_A, new_B, new_pi

In [7]:
def baum_welch_step(y, A, B, pi):
    alpha, beta, norm = forward_backward(y, A, B, pi)
    return update(y, A, B, pi, alpha, beta, norm)

def baum_welch(y, A0, B0, pi0, iterations=100):
    A, B, pi = A0, B0, pi0
    for it in tqdm.trange(iterations):
        A, B, pi = baum_welch_step(y, A, B, pi)
    return A, B, pi

In [8]:
def baum_welch_step_dataset2(dataset, A, B, pi):
    
    n_s = A.shape[0]
    n_o = B.shape[1]
    
    new_A_num = np.zeros((n_s, n_s))
    new_A_den = np.zeros((n_s, n_s))
    new_B_num = np.zeros((n_s, n_o))
    new_B_den = np.zeros((n_s, n_o))
    new_pi_num = np.zeros(n_s)
    new_pi_den = np.zeros(n_s)
    
    for y in dataset:
        T = len(y)
        alpha, beta, norm = forward_backward(y, A, B, pi)
        
        for i in range(n_s):
            new_pi_num[i] += alpha[t, i] * beta[t, i] / norm[t]
            new_pi_den[i] += 1

        for i in range(n_s):
            for j in range(n_s):
                new_A_num[i, j] += sum(
                    [alpha[t, i] * A[i, j] * B[j, y[t+1]] * beta[t+1, j] for t in range(T-1)]
                )
                new_A_den[i, j] += sum(
                    [alpha[t, i] * beta[t, i] / norm[t] for t in range(T-1)]
                )
        for j in range(n_s):
            for k in range(n_o):
                new_B_num[j, k] += sum(
                    [alpha[t, j] * beta[t, j] / norm[t] for t in range(T-1) if y[t] == k]
                )
                new_B_den[j, k] += sum(
                    [alpha[t, j] * beta[t, j] / norm[t] for t in range(T-1)]
                )

    return new_A_num/new_A_den, new_B_num/new_B_den, new_pi_num/new_pi_den

def baum_welch_dataset2(dataset, A0, B0, pi0, iterations=100):
    A, B, pi = A0, B0, pi0
    for it in tqdm.trange(iterations):
        A, B, pi = baum_welch_step_dataset2(dataset, A, B, pi)
    return A, B, pi

In [76]:
def baum_welch_step_dataset(dataset, A, B, pi):
    
    n_s = A.shape[0]
    n_o = B.shape[1]
    
    new_A_num = np.zeros((n_s, n_s))
    new_A_den = np.zeros((n_s, n_s))
    new_B_num = np.zeros((n_s, n_o))
    new_B_den = np.zeros((n_s, n_o))
    new_pi_num = np.zeros(n_s)
    new_pi_den = np.zeros(n_s)
    
    for y in dataset:
        T = len(y)
        alpha, beta, norm = forward_backward(y, A, B, pi)
        
        new_A_num += (
            alpha[:-1, :, None] * A[None, :, :] * B[:, y[1:]].T[:, None, :] * beta[1:, None, :]
        ).sum(axis=0)
        new_A_den += (
            alpha[:-1, :, None] * beta[:-1, :, None] / norm[:-1, None, None]
        ).sum(axis=0)
        
        onehot_encoder = OneHotEncoder(sparse=False, categories="auto")
        y_binary = onehot_encoder.fit_transform(y[:, None]).astype(int)
        new_B_num += (
            y_binary[:-1, None, :] * alpha[:-1, :, None] * beta[:-1, :, None] / norm[:-1, None, None]
        ).sum(axis=0)
        new_B_den += (
            alpha[:-1, :, None] * beta[:-1, :, None] / norm[:-1, None, None]
        ).sum(axis=0)
        
        new_pi_num += alpha[0, :] * beta[0, :] / norm[0]
        new_pi_den += 1

    return new_A_num/new_A_den, new_B_num/new_B_den, new_pi_num/new_pi_den

def baum_welch_dataset(dataset, A0, B0, pi0, iterations=100):
    A, B, pi = A0, B0, pi0
    for it in tqdm.trange(iterations):
        A, B, pi = baum_welch_step_dataset(dataset, A, B, pi)
    return A, B, pi

In [10]:
def baum_welch_dna(dataset, iterations=100):
    np.random.seed(63)
    A_guess = np.array([
        [  0,   1,   0,   0,], # exon 1 => exon 2
        [  0,   0,   1,   0,], # exon 2 => exon 3
        [0.5,   0,   0, 0.5,], # exon 3 => exon 1 or intron
        [0.5,   0,   0, 0.5,],
    ])
    A0 = np.random.random((4, 4)) * A_guess
    A0 /= A0.sum(axis=1)[:, None]
    
    B0 = np.random.random((4, 4)) / 4
    B0 /= B0.sum(axis=1)[:, None]
    
    pi0 = np.random.random(4)
    pi0 /= pi0.sum()
    
    print("\nInitial values")
    print(A0)
    print(B0)
    print(pi0)

    A, B, pi = baum_welch_dataset(dataset, A0, B0, pi0, iterations)
    
    print("Estimation")
    print(A)
    print(B)
    print(pi)
    
    return A, B, pi

In [11]:
def translate(dna_string):
    dna_string = dna_string.replace("A", "0")
    dna_string = dna_string.replace("T", "1")
    dna_string = dna_string.replace("G", "2")
    dna_string = dna_string.replace("C", "3")
    seq = list(dna_string)
    return np.array(seq).astype(int)

In [12]:
n_samples = 10
iterations = 100
for d in range(3):
    dataset_string = pd.read_csv("data/Xtr{}.csv".format(d), index_col=0)
    dataset = [
        translate(dna_string)
        for dna_string in dataset_string.values[:, 0]
    ]
    A, B, pi = baum_welch_dna(dataset[:n_samples], iterations)

  6%|▌         | 6/100 [00:00<00:01, 52.52it/s]


Initial values
[[0.         1.         0.         0.        ]
 [0.         0.         1.         0.        ]
 [0.54819171 0.         0.         0.45180829]
 [0.63018735 0.         0.         0.36981265]]
[[0.28091143 0.00260081 0.43141308 0.28507468]
 [0.27725454 0.26157888 0.20000889 0.26115768]
 [0.09972905 0.41371784 0.01206804 0.47448506]
 [0.24094784 0.26725151 0.34782094 0.14397971]]
[0.19152753 0.29893667 0.17995657 0.32957923]


100%|██████████| 100/100 [00:02<00:00, 49.79it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

Estimation
[[0.         1.         0.         0.        ]
 [0.         0.         1.         0.        ]
 [0.26296646 0.         0.         0.73703354]
 [0.24723052 0.         0.         0.75276948]]
[[4.21564794e-03 9.24168892e-08 2.56331972e-01 7.39452288e-01]
 [4.61894839e-01 2.57029640e-01 4.26576457e-03 2.76809757e-01]
 [1.92338325e-01 1.55122981e-01 3.66416755e-01 2.86121939e-01]
 [2.56775762e-01 3.29383731e-01 3.26164047e-01 8.76764593e-02]]
[5.00719479e-25 9.07281066e-15 4.40676182e-07 4.16918117e+00]

Initial values
[[0.         1.         0.         0.        ]
 [0.         0.         1.         0.        ]
 [0.54819171 0.         0.         0.45180829]
 [0.63018735 0.         0.         0.36981265]]
[[0.28091143 0.00260081 0.43141308 0.28507468]
 [0.27725454 0.26157888 0.20000889 0.26115768]
 [0.09972905 0.41371784 0.01206804 0.47448506]
 [0.24094784 0.26725151 0.34782094 0.14397971]]
[0.19152753 0.29893667 0.17995657 0.32957923]


100%|██████████| 100/100 [00:02<00:00, 45.58it/s]
  6%|▌         | 6/100 [00:00<00:01, 51.08it/s]

Estimation
[[0.         1.         0.         0.        ]
 [0.         0.         1.         0.        ]
 [0.48346029 0.         0.         0.51653971]
 [0.15063821 0.         0.         0.84936179]]
[[7.12170306e-02 1.51675850e-01 4.59942987e-01 3.17164132e-01]
 [1.79586182e-08 1.70568690e-01 3.46250183e-03 8.25968791e-01]
 [3.05612295e-01 2.12199178e-01 6.32800285e-02 4.18908499e-01]
 [3.22285956e-01 2.20790914e-01 3.28066209e-01 1.28856920e-01]]
[3.34456901e-12 4.77713443e-12 2.25575633e+00 2.47685397e+00]

Initial values
[[0.         1.         0.         0.        ]
 [0.         0.         1.         0.        ]
 [0.54819171 0.         0.         0.45180829]
 [0.63018735 0.         0.         0.36981265]]
[[0.28091143 0.00260081 0.43141308 0.28507468]
 [0.27725454 0.26157888 0.20000889 0.26115768]
 [0.09972905 0.41371784 0.01206804 0.47448506]
 [0.24094784 0.26725151 0.34782094 0.14397971]]
[0.19152753 0.29893667 0.17995657 0.32957923]


100%|██████████| 100/100 [00:01<00:00, 50.32it/s]

Estimation
[[0.         1.         0.         0.        ]
 [0.         0.         1.         0.        ]
 [0.38338466 0.         0.         0.61661534]
 [0.22920493 0.         0.         0.77079507]]
[[1.52075643e-01 1.78674557e-01 3.17861365e-01 3.51388435e-01]
 [1.01225036e-01 1.70184374e-01 1.03373855e-04 7.28487216e-01]
 [3.42915317e-01 4.03783490e-01 2.59367743e-04 2.53041825e-01]
 [2.15777171e-01 2.23709875e-01 5.21349262e-01 3.91636926e-02]]
[8.61173400e-45 1.35291046e+00 7.64336400e-01 1.60384291e+00]





Marginalized count kernel (Tsuda 2002)

In [87]:
def first_order_marginalized_count(y, A, B, pi):
    alpha, beta, norm = forward_backward(y, A, B, pi)
    onehot_encoder = OneHotEncoder(sparse=False, categories="auto")
    y_binary = onehot_encoder.fit_transform(y[:, None]).astype(int)
    return (y_binary[:, :, None] * alpha[:, None, :]).mean(axis=0).flatten()

def first_order_mc_features(dataset, A, B, pi):
    return np.array([
        marginalized_count(y, A, B, pi)
        for y in dataset
    ])

def second_order_marginalized_count(y, A, B, pi):
    alpha, beta, norm = forward_backward(y, A, B, pi)
    ksi = alpha[:-1, :, None] * A[None, :, :] * B[:, y[1:]].T[:, None, :] * beta[1:, None, :]
    onehot_encoder = OneHotEncoder(sparse=False, categories="auto")
    y_binary = onehot_encoder.fit_transform(y[:, None]).astype(int)
    # axes : t, i1, k1, i2, k2
    return (
        y_binary[:-1, :, None, None, None] *
        y_binary[1:, None, None, :, None] *
        ksi[:, None, :, None, :]
    ).mean(axis=0).flatten()

def second_order_mc_features(dataset, A, B, pi):
    return np.array([
        second_order_marginalized_count(y, A, B, pi)
        for y in dataset
    ])

def mc_features(dataset, A, B, pi, order):
    if order == 1:
        return first_order_mc_features(dataset, A, B, pi)
    elif order == 2:
        return second_order_mc_features(dataset, A, B, pi)

In [88]:
for d in tqdm.trange(3):
    for order in [1, 2]:
        A = np.array(pd.read_csv("data/HMM_{}_A.csv".format(d), index_col=0))
        B = np.array(pd.read_csv("data/HMM_{}_B.csv".format(d), index_col=0))
        pi = np.array(pd.read_csv("data/HMM_{}_pi.csv".format(d), index_col=0))[:, 0]

        train_string = pd.read_csv("data/Xtr{}.csv".format(d), index_col=0)
        train_num = [
            translate(dna_string)
            for dna_string in train_string.values[:, 0]
        ]
        Xtr = mc_features(train_num, A, B, pi, order)
        pd.DataFrame(Xtr).to_csv("data/Xtr{}_HMM_MCK{}.csv".format(d, order), index=False, header=False, sep=" ")

        test_string = pd.read_csv("data/Xte{}.csv".format(d), index_col=0)
        test_num = [
            translate(dna_string)
            for dna_string in test_string.values[:, 0]
        ]
        Xte = mc_features(test_num, A, B, pi, order)
        pd.DataFrame(Xte).to_csv("data/Xte{}_HMM_MCK{}.csv".format(d, order), index=False, header=False, sep=" ")

100%|██████████| 3/3 [00:37<00:00, 12.46s/it]
