# Setup

In [None]:
import numpy as np
import argparse
from multiprocessing import Pool
import time
import os, shutil

# Reading Sequence from Fasta

In [None]:
def read_sequences(fasta_path):
    """Read the sequences from a FASTA file that uses two lines per record.

    Every even-numbered line (1-based) is taken as a sequence and stripped of
    surrounding whitespace; odd-numbered lines (the record headers) are
    skipped. Records whose sequence is wrapped over several lines are not
    supported by this reader.

    Args:
        fasta_path: Path of the FASTA file to read.

    Returns:
        A tuple ``(sequences, seq_count, seq_len)`` where ``sequences`` is the
        list of sequence strings, ``seq_count`` is its length and ``seq_len``
        is the length of the first sequence (0 when the file holds none).
    """
    # The context manager guarantees the handle is released even on error.
    with open(fasta_path, 'r') as fasta_file:
        sequences = [
            line.strip()
            for line_no, line in enumerate(fasta_file, start=1)
            if line_no % 2 == 0
        ]
    seq_count = len(sequences)
    # Guard against an empty file instead of failing with IndexError.
    seq_len = len(sequences[0]) if sequences else 0
    return sequences, seq_count, seq_len

# Reading Motif Length from Txt

In [None]:
def read_motiflen(txt_path):
    """Read the motif length from the first line of a text file.

    Args:
        txt_path: Path of the text file whose first line holds an integer.

    Returns:
        The motif length as an ``int``.

    Raises:
        ValueError: If the first line is not a valid integer.
    """
    # The context manager closes the handle, which the original never did.
    with open(txt_path, 'r') as motiflen_file:
        return int(motiflen_file.readline().strip())

# Calculate PWM

In [None]:
def calc_pwm(sequences, location, seq_count, motif_len, skip_flag, skip_idx):
    """Build a position weight matrix from the current motif sites.

    For every motif position the bases A, C, G and T are counted over the
    windows ``sequences[idx][location[idx]:location[idx] + motif_len]``.
    Counts are normalised per position, a small pseudo-count is added to
    every cell, and each row is normalised again so that it sums to 1.
    Characters other than upper-case A/C/G/T are ignored.

    Args:
        sequences: Sequence strings, indexable by sequence number.
        location: Start offset of the motif window in each sequence.
        seq_count: Number of sequences to consider.
        motif_len: Width of the motif (number of PWM rows).
        skip_flag: When true, the sequence at ``skip_idx`` is left out.
        skip_idx: Index of the sequence to leave out (used only when
            ``skip_flag`` is true).

    Returns:
        A ``(motif_len, 4)`` NumPy array of probabilities; columns are in
        A, C, G, T order.
    """
    pseudo_count = 0.0001
    column_of = {'A': 0, 'C': 1, 'G': 2, 'T': 3}  # A C G T -> column order

    counts = np.zeros((motif_len, 4), dtype=float)
    for idx in range(seq_count):
        if skip_flag and idx == skip_idx:
            continue
        start = location[idx]
        sequence = sequences[idx]
        for loc in range(motif_len):
            column = column_of.get(sequence[start + loc])
            if column is not None:
                counts[loc, column] += 1

    row_totals = counts.sum(axis=1, keepdims=True)
    # A row can be empty (e.g. the only sequence is held out, or the column
    # holds only non-ACGT characters). Dividing 0 by 0 would poison the PWM
    # with NaN, so such rows stay at zero and become uniform once the
    # pseudo-count is added below.
    pwm = np.divide(
        counts,
        row_totals,
        out=np.zeros_like(counts),
        where=row_totals > 0,
    )

    pwm += pseudo_count
    return pwm / pwm.sum(axis=1, keepdims=True)

# Information Content

In [None]:
def calc_information_content(pwm, base_pwm, motif_len):
    """Return the relative entropy (in bits) of a PWM against a background.

    Sums ``pwm[i][j] * log2(pwm[i][j] / base_pwm[i][j])`` over the first
    ``motif_len`` rows and the four base columns. Cells whose PWM value is
    exactly zero contribute nothing and are skipped.

    Args:
        pwm: Motif matrix, indexable as ``pwm[row][column]``.
        base_pwm: Background matrix with the same indexing as ``pwm``.
        motif_len: Number of rows to include in the sum.

    Returns:
        The accumulated information content.
    """
    total = 0
    for position in range(motif_len):
        for column in range(4):
            observed = pwm[position][column]
            if observed == 0:
                # Treat 0 * log(0) as 0 rather than evaluating log2(0).
                continue
            background = base_pwm[position][column]
            total += observed * np.log2(observed / background)
    return total


# Markov Chain

In [None]:
def exec_chain(sequences, seq_len, seq_count, motif_len, epoch_count):
    """Run one Gibbs-sampling chain and return the best motif it visited.

    Starting from random motif sites, each epoch holds out one randomly
    chosen sequence, builds a PWM from the remaining sites, scores every
    window of the held-out sequence by the ratio of its PWM likelihood to a
    uniform background likelihood, and samples a new site for that sequence
    in proportion to those scores. The PWM of all sites is then evaluated
    with ``calc_information_content`` and the best state is remembered.

    Args:
        sequences: Sequence strings, each at least ``seq_len`` long.
        seq_len: Length of every sequence.
        seq_count: Number of sequences.
        motif_len: Width of the motif being searched for.
        epoch_count: Number of sampling iterations to perform.

    Returns:
        A tuple ``(best_ic, best_loc, best_pwm)``: the highest information
        content seen, the site offsets that produced it and the matching PWM.
        ``best_loc`` and ``best_pwm`` are ``None`` when ``epoch_count`` is 0.
    """
    # np.NINF was removed in NumPy 2.0; -np.inf works on every version.
    best_ic = -np.inf
    best_loc = None
    best_pwm = None  # motif_len rows, 4 columns

    # Uniform background; rows share one list object but are never mutated.
    base_pwm = [[0.25] * 4] * motif_len
    column_of = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    window_count = seq_len - motif_len + 1

    location = np.random.randint(low=0, high=window_count, size=seq_count)

    for _ in range(epoch_count):
        skip_idx = np.random.randint(low=0, high=seq_count)
        pwm = calc_pwm(sequences, location, seq_count, motif_len, True, skip_idx)
        held_out = sequences[skip_idx]

        prob = np.empty(window_count)
        for start in range(window_count):  # per substring
            q = 1
            p = 1
            for j in range(motif_len):  # per motif position
                column = column_of.get(held_out[start + j])
                # Characters other than A/C/G/T leave both products unchanged.
                if column is not None:
                    q *= pwm[j][column]
                    p *= base_pwm[j][column]
            prob[start] = q / p
        prob = prob / np.sum(prob)

        location[skip_idx] = np.random.choice(window_count, p=prob)
        pwm = calc_pwm(sequences, location, seq_count, motif_len, False, -1)

        ic = calc_information_content(pwm, base_pwm, motif_len)
        if ic > best_ic:
            best_ic = ic
            # `location` is updated in place on later epochs, so a snapshot
            # is required; storing the array itself would silently turn
            # best_loc into the final sites instead of the best ones.
            best_loc = location.copy()
            best_pwm = pwm

    print('best_ic', best_ic)

    return best_ic, best_loc, best_pwm



# Gibbs Sampler

In [None]:
def gibbs_sampler(sequences, seq_len, seq_count, motif_len, chain_count, epoch_count, log_path):
    """Run several independent chains sequentially and keep the best result.

    Each chain is executed with ``exec_chain``. After every chain one CSV row
    ``chain,best_ic_so_far,runtime`` is appended to the log file.

    Args:
        sequences: Sequence strings.
        seq_len: Length of every sequence.
        seq_count: Number of sequences.
        motif_len: Width of the motif being searched for.
        chain_count: Number of chains to run.
        epoch_count: Number of sampling iterations per chain.
        log_path: Path of the CSV log file; it is overwritten.

    Returns:
        A tuple ``(best_ic, best_loc, best_pwm, total_time)`` holding the best
        information content over all chains, its site offsets and PWM, and
        the summed wall-clock runtime of the chains in seconds.
    """
    # np.NINF was removed in NumPy 2.0; -np.inf works on every version.
    best_ic = -np.inf
    best_loc = None
    best_pwm = None  # motif_len rows, 4 columns

    total_time = 0
    # The context manager closes the log even if a chain raises.
    with open(log_path, 'w') as log_file:
        # Every record ends with a newline; without it the header and all
        # rows run together on one line and the CSV cannot be parsed.
        log_file.write('Chain,Best_IC,Runtime\n')

        for chain in range(chain_count):
            print('Chain', chain)
            start_time = time.time()
            ic, loc, pwm = exec_chain(sequences, seq_len, seq_count, motif_len, epoch_count)
            if ic > best_ic:
                best_ic = ic
                best_loc = loc
                best_pwm = pwm
            runtime = time.time() - start_time
            print('runtime', runtime)
            total_time += runtime
            log_file.write(f'{chain},{best_ic},{runtime}\n')
            # Flush per chain so progress survives an interrupted long run.
            log_file.flush()

    return best_ic, best_loc, best_pwm, total_time

In [None]:
def _exec_chain_timed(sequences, seq_len, seq_count, motif_len, epoch_count):
    """Run one chain in a worker process and append its runtime in seconds."""
    # Forked workers inherit an identical copy of NumPy's global RNG state;
    # reseeding from OS entropy keeps the chains from being duplicates.
    np.random.seed()
    start_time = time.time()
    ic, loc, pwm = exec_chain(sequences, seq_len, seq_count, motif_len, epoch_count)
    return ic, loc, pwm, time.time() - start_time


def gibbs_sampler2(sequences, seq_len, seq_count, motif_len, chain_count, epoch_count, log_path):
    """Run several independent chains in parallel and keep the best result.

    The chains are distributed over a ``multiprocessing.Pool``. One CSV row
    ``chain,chain_ic,runtime`` is written to the log file per chain.

    Args:
        sequences: Sequence strings.
        seq_len: Length of every sequence.
        seq_count: Number of sequences.
        motif_len: Width of the motif being searched for.
        chain_count: Number of chains to run.
        epoch_count: Number of sampling iterations per chain.
        log_path: Path of the CSV log file; it is overwritten.

    Returns:
        A tuple ``(best_ic, best_loc, best_pwm, best_rt)`` holding the best
        information content over all chains, its site offsets and PWM, and
        the runtime in seconds of the chain that produced it.
    """
    # np.NINF was removed in NumPy 2.0; -np.inf works on every version.
    best_ic = -np.inf
    best_loc = None
    best_pwm = None  # motif_len rows, 4 columns
    best_rt = None

    pool_args = [[sequences, seq_len, seq_count, motif_len, epoch_count]] * chain_count
    # exec_chain returns only three values, so the timed wrapper supplies the
    # fourth (runtime) that the loop below reads. The context manager shuts
    # the worker processes down once the results are in.
    with Pool() as pool:
        result = pool.starmap(_exec_chain_timed, pool_args)

    with open(log_path, 'w') as log_file:
        # Every record ends with a newline so the log is a valid CSV.
        log_file.write('Chain,Best_IC,Runtime\n')
        for chain, (ic, loc, pwm, runtime) in enumerate(result):
            if ic > best_ic:
                best_ic = ic
                best_loc = loc
                best_pwm = pwm
                best_rt = runtime
            log_file.write(f'{chain},{ic},{runtime}\n')

    return best_ic, best_loc, best_pwm, best_rt

# Run a Single Dataset

In [None]:
def execute(dataset_no, epoch_count, chain_count,
            base_data_dir='/content/drive/MyDrive/CS466/Project/Data/Dataset-2',
            base_result_dir='/content/drive/MyDrive/CS466/Project/Result/Dataset-2'):
    """Run the Gibbs sampler on one dataset and write its result files.

    The result directory for the dataset is recreated from scratch, the
    dataset's input files are copied into it, the sampler is run, and the
    following files are written next to the copies: ``<dataset_no>.log``,
    ``ic.txt``, ``runtime.txt``, ``predictedmotif.txt`` and
    ``predictedsites.txt``.

    Args:
        dataset_no: Dataset identifier; also the name of its sub-directory.
        epoch_count: Number of sampling iterations per chain.
        chain_count: Number of chains to run.
        base_data_dir: Directory containing one sub-directory per dataset.
        base_result_dir: Directory under which results are written.
    """
    data_dir = os.path.join(base_data_dir, str(dataset_no))
    result_dir = os.path.join(base_result_dir, str(dataset_no))

    # Start from an empty result directory. makedirs also creates missing
    # parents, which a plain mkdir would fail on.
    if os.path.exists(result_dir):
        shutil.rmtree(result_dir)
    os.makedirs(result_dir)

    # Keep the inputs beside the predictions.
    for filename in ('sequences.fa', 'motiflength.txt', 'motif.txt', 'sites.txt'):
        shutil.copyfile(os.path.join(data_dir, filename),
                        os.path.join(result_dir, filename))

    sequences, seq_count, seq_len = read_sequences(os.path.join(data_dir, 'sequences.fa'))
    motif_len = read_motiflen(os.path.join(data_dir, 'motiflength.txt'))

    log_path = os.path.join(result_dir, str(dataset_no) + '.log')

    ic, loc, pwm, runtime = gibbs_sampler(sequences, seq_len, seq_count, motif_len,
                                          chain_count, epoch_count, log_path)

    with open(os.path.join(result_dir, 'ic.txt'), 'w') as ic_file:
        ic_file.write(str(ic))

    with open(os.path.join(result_dir, 'runtime.txt'), 'w') as runtime_file:
        runtime_file.write(str(runtime))

    with open(os.path.join(result_dir, 'predictedmotif.txt'), 'w') as motif_file:
        # The header needs its own line; without the newline the motif
        # length and the first matrix value fuse into a single token.
        motif_file.write('>MOTIF1 ' + str(motif_len) + '\n')
        for i in range(motif_len):
            motif_file.write(' '.join(str(pwm[i][j]) for j in range(4)) + '\n')
        motif_file.write('<')

    with open(os.path.join(result_dir, 'predictedsites.txt'), 'w') as sites_file:
        sites_file.write(','.join(str(loc[i]) for i in range(seq_count)))

# Command Line Arguments

In [None]:
def parse_args():
    """Parse the command-line options of the motif finder.

    Options:
        -d / --dataset_no: Dataset number to process (required).
        -e / --epoch_count: Sampling iterations per chain (default 10000).
        -c / --chain_count: Number of chains to run (default 100).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("-d", "--dataset_no", type=int, required=True)
    # A default on a required option can never take effect, so the two
    # tuning options are optional and fall back to their declared defaults.
    parser.add_argument("-e", "--epoch_count", type=int, default=10000)
    parser.add_argument("-c", "--chain_count", type=int, default=100)

    return parser.parse_args()

# Main

In [None]:
def main():
    """Run the sampler on datasets 1 through 70 with fixed settings.

    Command-line parsing is bypassed here: the epoch and chain counts are
    hard-coded and every dataset number in the range is processed in turn.
    """
    chain_count = 10
    epoch_count = 10000
    first_dataset, last_dataset = 1, 70

    for dataset_no in range(first_dataset, last_dataset + 1):
        execute(dataset_no, epoch_count, chain_count)

In [None]:
# Guard the entry point so that importing this module (for example by a
# spawned multiprocessing worker) does not re-run the whole pipeline.
if __name__ == '__main__':
    main()

Chain 0
best_ic 10.869659446375117
runtime 33.60573387145996
Chain 1
best_ic 10.979885450914143
runtime 33.484482288360596
Chain 2
best_ic 15.964660292578046
runtime 33.54975342750549
Chain 3
best_ic 11.237832521119893
runtime 33.42374348640442
Chain 4
best_ic 11.386995562292997
runtime 33.50162196159363
Chain 5
best_ic 10.825824212265482
runtime 33.40541100502014
Chain 6
best_ic 10.515850388515792
runtime 33.57588076591492
Chain 7
best_ic 10.481872330291008
runtime 33.60708427429199
Chain 8
best_ic 11.393828880039065
runtime 33.56342935562134
Chain 9
best_ic 10.638603928568125
runtime 33.66911768913269
