In [1]:
import numpy as np
import pandas as pd
import tqdm

In [2]:
def translate(dna_string):
    dna_string = dna_string.replace("A", "0")
    dna_string = dna_string.replace("T", "1")
    dna_string = dna_string.replace("G", "2")
    dna_string = dna_string.replace("C", "3")
    seq = list(dna_string)
    return np.array(seq).astype(int)

In [3]:
dataset_string = pd.read_csv("data/Xtr0.csv", index_col=0)
dataset = [
    translate(dna_string)
    for dna_string in dataset_string.values[:, 0]
]

In [4]:
states = np.arange(4)
observations = np.arange(4)

pi = np.random.random(4)
pi /= pi.sum()

A = np.random.random((4, 4))
A /= A.sum(axis=1)[:, None]

B = np.random.random((4, 4))
B /= B.sum(axis=1)[:, None]

def generate(A, B, pi, T):
    x = []
    y = []
    for t in range(T):
        if t == 0:
            x.append(np.random.choice(states, p=pi))
        else:
            x.append(np.random.choice(states, p=A[x[-1]]))
        y.append(np.random.choice(observations, p=B[x[-1]]))
    x = np.array(x)
    y = np.array(y)
    return x, y

x, y = generate(A, B, pi, 100)

In [5]:
def loglikelihood(y, A, B, pi):
    T = len(y)
    n_s = A.shape[0]
    
    norm = np.empty(T)

    alpha = np.empty((T, n_s))
    alpha[0] = pi * B[:, y[0]]
    norm[0] = 1/alpha[0].sum()
    for t in range(T-1):
        alpha[t+1] = B[:, y[t+1]] * A.T.dot(alpha[t])
        norm[t+1] = 1/alpha[t+1].sum()
        alpha[t+1] *= norm[t+1]

    loglike = alpha[T-1].sum() - np.log(norm).sum()
    return loglike

In [6]:
def forward_backward(y, A, B, pi):
    T = len(y)
    n_s = A.shape[0]
    
    norm = np.empty(T)

    alpha = np.empty((T, n_s))
    alpha[0] = pi * B[:, y[0]]
    norm[0] = 1/alpha[0].sum()
    for t in range(T-1):
        alpha[t+1] = B[:, y[t+1]] * A.T.dot(alpha[t])
        norm[t+1] = 1/alpha[t+1].sum()
        alpha[t+1] *= norm[t+1]

    beta = np.empty((T, n_s))
    beta[T-1] = norm[T-1] * 1
    for t in range(T-2, -1, -1):
        beta[t] = A.dot(beta[t+1] * B[:, y[t+1]])
        beta[t] *= norm[t]

    return alpha, beta, norm

In [7]:
def update(y, A, B, pi, alpha, beta, norm):
    T = len(y)
    n_s = A.shape[0]
    n_o = B.shape[1]
    
    new_pi = np.empty(n_s)
    new_A = np.empty((n_s, n_s))
    new_B = np.empty((n_s, n_o))
    
    for i in range(n_s):
        new_pi[i] = alpha[0, i] * beta[0, i] / alpha[0].dot(beta[0])
    
    for i in range(n_s):
        for j in range(n_s):
            new_A[i, j] = sum(
                [alpha[t, i] * A[i, j] * B[j, y[t+1]] * beta[t+1, j] for t in range(T-1)]
            ) / sum(
                [alpha[t, i] * beta[t, i] / norm[t] for t in range(T-1)]
            )
    for j in range(n_s):
        for k in range(n_o):
            new_B[j, k] = sum(
                [alpha[t, j] * beta[t, j] / norm[t] for t in range(T) if y[t] == k]
            ) / sum(
                [alpha[t, j] * beta[t, j] / norm[t] for t in range(T)]
            )
            
    #new_pi = np.ones(n_s) / n_s
    return new_A, new_B, new_pi

In [8]:
def baum_welch_step(y, A, B, pi):
    alpha, beta, norm = forward_backward(y, A, B, pi)
    return update(y, A, B, pi, alpha, beta, norm)

def baum_welch(y, A0, B0, pi0, iterations=100):
    A, B, pi = A0, B0, pi0
    for it in tqdm.trange(iterations):
        A, B, pi = baum_welch_step(y, A, B, pi)
    return A, B, pi

In [9]:
def baum_welch_step_dataset(dataset, A, B, pi):
    
    n_s = A.shape[0]
    n_o = B.shape[1]
    
    new_A_num = np.zeros((n_s, n_s))
    new_A_den = np.zeros((n_s, n_s))
    new_B_num = np.zeros((n_s, n_o))
    new_B_den = np.zeros((n_s, n_o))
    
    for y in dataset:
        T = len(y)
        alpha, beta, norm = forward_backward(y, A, B, pi)

        for i in range(n_s):
            for j in range(n_s):
                new_A_num[i, j] += sum(
                    [alpha[t, i] * A[i, j] * B[j, y[t+1]] * beta[t+1, j] for t in range(T-1)]
                )
                new_A_den[i, j] += sum(
                    [alpha[t, i] * beta[t, i] / norm[t] for t in range(T-1)]
                )
        for j in range(n_s):
            for k in range(n_o):
                new_B_num[j, k] += sum(
                    [alpha[t, j] * beta[t, j] / norm[t] for t in range(T) if y[t] == k]
                )
                new_B_den[j, k] += sum(
                    [alpha[t, j] * beta[t, j] / norm[t] for t in range(T)]
                )

    return new_A_num/new_A_den, new_B_num/new_B_den, [0.25, 0.25, 0.25, 0.25]

def baum_welch_dataset(y, A0, B0, pi0, iterations=100):
    A, B, pi = A0, B0, pi0
    for it in tqdm.trange(iterations):
        A, B, pi = baum_welch_step_dataset(y, A, B, pi)
    return A, B, pi

In [10]:
def baum_welch_dna(y, iterations=100):
    A_guess = np.array([
        [  0,   1,   0,   0,], # exon 1 => exon 2
        [  0,   0,   1,   0,], # exon 2 => exon 3
        [0.5,   0,   0, 0.5,], # exon 3 => exon 1 or intron
        [0.5,   0,   0, 0.5,],
    ])
    A0 = np.random.random((4, 4)) * A_guess
    A0 /= A0.sum(axis=1)[:, None]
    
    B0 = np.random.random((4, 4)) / 4
    B0 /= B0.sum(axis=1)[:, None]
    
    pi0 = np.random.random(4)
    pi0 /= pi0.sum()
    
    print(A0)
    print(B0)
    print(pi0)
    
    if type(y) != type([]):
        y = [y]
    return baum_welch_dataset(y, A0, B0, pi0, iterations)

In [11]:
new_A, new_B, new_pi = baum_welch_dna(dataset[:10], 100)
print(new_A)
print(new_B)
print(new_pi)

  2%|▏         | 2/100 [00:00<00:08, 11.66it/s]

[[0.         1.         0.         0.        ]
 [0.         0.         1.         0.        ]
 [0.68724832 0.         0.         0.31275168]
 [0.50241789 0.         0.         0.49758211]]
[[0.17266948 0.51251705 0.22207537 0.0927381 ]
 [0.33847022 0.04182091 0.49670836 0.12300051]
 [0.57097777 0.10182772 0.30127992 0.02591459]
 [0.29083936 0.41171089 0.21290062 0.08454913]]
[0.08884133 0.36061344 0.29969832 0.25084691]


100%|██████████| 100/100 [00:06<00:00, 15.81it/s]

[[0.         1.         0.         0.        ]
 [0.         0.         1.         0.        ]
 [0.3413248  0.         0.         0.6586752 ]
 [0.39080701 0.         0.         0.60919299]]
[[2.61930577e-01 2.49944032e-01 2.15091484e-01 2.73033907e-01]
 [1.86882336e-01 5.02135784e-02 2.23948215e-01 5.38955870e-01]
 [5.43629305e-01 2.76186237e-01 3.52228794e-04 1.79832230e-01]
 [6.79022981e-02 3.21890362e-01 4.64550633e-01 1.45656707e-01]]
[0.25, 0.25, 0.25, 0.25]



