In [77]:
from pomegranate import *
import matplotlib.pyplot as plot
import re
import pysam
import math
import pandas as pd

In [2]:
def make_global_alignment_model(target, name = None):
    model = HiddenMarkovModel(name = name)
    s = {}
    
    # add states
    #ss = State(None, name=f"{name}:B")
    i0 = State(DiscreteDistribution({ 'A': 0.25, 'C': 0.25, 'G': 0.25, 'T': 0.25 }), name=f"{name}:I0")

    model.add_state(i0)

    #s[ss.name] = ss
    s[i0.name] = i0
    
    for c in range(len(target)):
        dc = State(None, name=f"{name}:D{c+1}")
        
        mc = State(DiscreteDistribution({
            "A": 0.94 if target[c] == 'A' else 0.02,
            "C": 0.94 if target[c] == 'C' else 0.02,
            "G": 0.94 if target[c] == 'G' else 0.02,
            "T": 0.94 if target[c] == 'T' else 0.02,
        }), name=f"{name}:M{c+1}")
        
        ic = State(DiscreteDistribution({ 'A': 0.25, 'C': 0.25, 'G': 0.25, 'T': 0.25 }), name=f"{name}:I{c+1}")

        model.add_states([mc, ic, dc])
        
        s[dc.name] = dc
        s[mc.name] = mc
        s[ic.name] = ic
        
    # add transitions
    model.add_transition(model.start, s[f'{name}:I0'], 0.05)
    model.add_transition(model.start, s[f'{name}:D1'], 0.05)
    model.add_transition(model.start, s[f'{name}:M1'], 0.90)
    
    model.add_transition(s[f'{name}:I0'], s[f'{name}:I0'], 0.70)
    model.add_transition(s[f'{name}:I0'], s[f'{name}:D1'], 0.15)
    model.add_transition(s[f'{name}:I0'], s[f'{name}:M1'], 0.15)
    
    for c in range(1, len(target)):
        model.add_transition(s[f'{name}:D{c}'], s[f'{name}:D{c+1}'], 0.15)
        model.add_transition(s[f'{name}:D{c}'], s[f'{name}:I{c}'],   0.70)
        model.add_transition(s[f'{name}:D{c}'], s[f'{name}:M{c+1}'], 0.15)
        
        model.add_transition(s[f'{name}:I{c}'], s[f'{name}:D{c+1}'], 0.15)
        model.add_transition(s[f'{name}:I{c}'], s[f'{name}:I{c}'],   0.15)
        model.add_transition(s[f'{name}:I{c}'], s[f'{name}:M{c+1}'], 0.70)

        model.add_transition(s[f'{name}:M{c}'], s[f'{name}:D{c+1}'], 0.05)
        model.add_transition(s[f'{name}:M{c}'], s[f'{name}:I{c}'],   0.05)
        model.add_transition(s[f'{name}:M{c}'], s[f'{name}:M{c+1}'], 0.90)
    
    model.add_transition(s[f'{name}:D{len(target)}'], s[f'{name}:I{len(target)}'], 0.70)
    model.add_transition(s[f'{name}:D{len(target)}'], model.end, 0.30)

    model.add_transition(s[f'{name}:I{len(target)}'], s[f'{name}:I{len(target)}'], 0.15)
    model.add_transition(s[f'{name}:I{len(target)}'], model.end, 0.85)

    model.add_transition(s[f'{name}:M{len(target)}'], s[f'{name}:I{len(target)}'], 0.90)
    model.add_transition(s[f'{name}:M{len(target)}'], model.end, 0.10)
    
    model.bake(merge = 'None')
    
    return model

In [3]:
def make_random_repeat_model(name = 'random'):
    model = HiddenMarkovModel(name = name)
    
    # add states
    ri = State(DiscreteDistribution({ 'A': 0.25, 'C': 0.25, 'G': 0.25, 'T': 0.25 }), name=f"{name}:RI")
    rda = State(None, name=f"{name}:RDA")
    rdb = State(None, name=f"{name}:RDB")
    
    model.add_states([ri, rda, rdb])
    
    # add transitions
    model.add_transition(model.start, rda, 0.5)
    model.add_transition(model.start, ri, 0.5)
    
    model.add_transition(ri, ri, 0.8)
    model.add_transition(ri, rda, 0.10)
    model.add_transition(ri, model.end, 0.10)
    
    model.add_transition(rdb, ri, 0.5)
    model.add_transition(rdb, model.end, 0.5)
    
    model.bake(merge = 'None')
    
    return model

In [94]:
def build_full_model(adapters):
    # rewrite these three lines such that when we create each global alignment model,
    # we immediately connect it to the random repeat model.
    full_model = make_random_repeat_model()
    for k in adapters:
        full_model.add_model(make_global_alignment_model(adapters[k], k))
        
    full_model.bake(merge = 'None')
    
    rda = None
    rdb = None
    txb = None

    for s in full_model.states:
        if "random:RDA" in s.name:
            rda = s
        elif "random:RDB" in s.name:
            rdb = s
        elif "10x_Adapter-start" is s.name:
            txb = s
            
    for s in full_model.states:
        if "start" in s.name and "random" not in s.name:
            #print(s.name)
            full_model.add_transition(rda, s, 1.0/len(adapters))

        if "end" in s.name and "random" not in s.name:
            print(adapters[re.split("[:-]", s.name)[0]])
            print(len(adapters[re.split("[:-]", s.name)[0]]))
            
            if re.match("^[A-P]-", s.name):
                print(f'{s.name} -> 10x')
            
                full_model.add_transition(s, txb, 0.95)
                full_model.add_transition(s, rdb, 0.05)
            else:
                print(f'{s.name} -> rdb')
                
                full_model.add_transition(s, rdb, 1.0/len(adapters))

    full_model.bake()
    
    return full_model

In [5]:
def plot(model):
    l = {}
    for s in model.states:
        l[s] = s.name

    model.plot(labels=l)

In [6]:
def annotate(full_model, seq):
    logp, path = full_model.viterbi(seq)

    ppath = []
    for p, (idx, state) in enumerate(path[1:-1]):
        if "start" in state.name or ":RD" in state.name:
            ppath.append("\n")
            
        ppath.append(f'{state.name} ({idx} {p}) ')
                
    return logp, ppath

In [7]:
def reverse_complement(seq):
    complement = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'}
    
    bases = list(seq) 
    bases = reversed([complement.get(base,base) for base in bases])
    bases = ''.join(bases)

    return bases

In [8]:
# with pysam.FastxFile("sirv.fasta") as fh:
#     for q in fh:
#         adapters[q.name] = q.sequence

In [9]:
# "Poly_A": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
adapters = {
    "10x_Adapter": "TCTACACGACGCTCTTCCGATCT",
    "5'_TSO": "TTTCTTATATGGG",
    "Poly_A": "AAAAAAAAAA",
    "3'_Adapter": "GTACTCTGCGTTGATACCACTGCTT",
    "A": "AGCTTACTTGTGAAGA",
    "B": "ACTTGTAAGCTGTCTA",
    "C": "ACTCTGTCAGGTCCGA",
    "D": "ACCTCCTCCTCCAGAA",
    "E": "AACCGGACACACTTAG",
    "F": "AGAGTCCAATTCGCAG",
    "G": "AATCAAGGCTTAACGG",
    "H": "ATGTTGAATCCTAGCG",
    "I": "AGTGCGTTGCGAATTG",
    "J": "AATTGCGTAGTTGGCC",
    "K": "ACACTTGGTCGCAATC",
    "L": "AGTAAGCCTTCGTGTC",
    "M": "ACCTAGATCAGAGCCT",
    "N": "AGGTATGCCGGTTAAG",
    "O": "AAGTCACCGGCACCTT",
    "P": "ATGAAGTGGCTCGAGA"
}

In [95]:
full_model = build_full_model(adapters)

AAAAAAAAAA
10
Poly_A-end -> rdb
ATGAAGTGGCTCGAGA
16
P-end -> 10x
AAGTCACCGGCACCTT
16
O-end -> 10x
AGGTATGCCGGTTAAG
16
N-end -> 10x
ACCTAGATCAGAGCCT
16
M-end -> 10x
AGTAAGCCTTCGTGTC
16
L-end -> 10x
ACACTTGGTCGCAATC
16
K-end -> 10x
AATTGCGTAGTTGGCC
16
J-end -> 10x
AGTGCGTTGCGAATTG
16
I-end -> 10x
ATGTTGAATCCTAGCG
16
H-end -> 10x
AATCAAGGCTTAACGG
16
G-end -> 10x
AGAGTCCAATTCGCAG
16
F-end -> 10x
AACCGGACACACTTAG
16
E-end -> 10x
ACCTCCTCCTCCAGAA
16
D-end -> 10x
ACTCTGTCAGGTCCGA
16
C-end -> 10x
ACTTGTAAGCTGTCTA
16
B-end -> 10x
AGCTTACTTGTGAAGA
16
A-end -> 10x
TTTCTTATATGGG
13
5'_TSO-end -> rdb
GTACTCTGCGTTGATACCACTGCTT
25
3'_Adapter-end -> rdb
TCTACACGACGCTCTTCCGATCT
23
10x_Adapter-end -> rdb


In [103]:
def inspect_transitions_from(model, state_name):
    q = model.dense_transition_matrix()

    a = pd.DataFrame(q, columns=[x.name for x in model.states])
    a.index = [x.name for x in model.states]

    for s in model.states:
        q = a.loc[state_name, s.name]
        if q > 0.0:
            print(f'{state_name} -> {s.name} = {q}')
            
def inspect_transitions_to(model, state_name):
    q = model.dense_transition_matrix()

    a = pd.DataFrame(q, columns=[x.name for x in model.states])
    a.index = [x.name for x in model.states]

    for s in model.states:
        q = a.loc[s.name, state_name]
        if q > 0.0:
            print(f'{s.name} -> {state_name} = {q}')

print("from:")
inspect_transitions_to(full_model, 'A:M16')

print("to:")
inspect_transitions_from(full_model, 'A:M16')

from:
A:I15 -> A:M16 = 0.7
A:M15 -> A:M16 = 0.9
A:D15 -> A:M16 = 0.15
to:
A:M16 -> A:I16 = 0.9
A:M16 -> random:RDB = 0.10000000000000002


In [12]:
# samfile = pysam.AlignmentFile("SM-KM1PN.m64020_201213_022403.corrected.bam", "rb", check_sq=False)

# read_sequences = []
# for read in samfile:
#     read_sequences.append(read.query_sequence)
    
#     if len(read_sequences) > 10:
#         break

# seqs = [list(x) for x in read_sequences]
        
# full_model.fit(sequences=seqs, algorithm='baum-welch')

In [34]:
samfile = pysam.AlignmentFile("SM-KM1PN.m64020_201213_022403.corrected.bam", "rb", check_sq=False)

i = 0
for read in samfile:
    flogp = -math.inf
    for seq in [read.query_sequence, reverse_complement(read.query_sequence)]:
        logp, ppath = annotate(full_model, seq)
                
        if logp > flogp:
            flogp = logp
            fppath = ppath

    print(read.query_name)
    print(flogp)
    print("".join(fppath))
    print("")
        
    i += 1
    
    if i > 10:
        break
        
samfile.close()

m64020_201213_022403/1/ccs
-15111.23438861891

random:RDA (676 0) 
10x_Adapter-start (1000 1) 10x_Adapter:D1 (1001 2) 10x_Adapter:M2 (35 3) 10x_Adapter:M3 (40 4) 10x_Adapter:M4 (41 5) 10x_Adapter:M5 (42 6) 10x_Adapter:M6 (43 7) 10x_Adapter:M7 (44 8) 10x_Adapter:M8 (45 9) 10x_Adapter:M9 (46 10) 10x_Adapter:M10 (25 11) 10x_Adapter:M11 (26 12) 10x_Adapter:M12 (27 13) 10x_Adapter:M13 (28 14) 10x_Adapter:M14 (29 15) 10x_Adapter:M15 (30 16) 10x_Adapter:M16 (31 17) 10x_Adapter:M17 (32 18) 10x_Adapter:M18 (33 19) 10x_Adapter:M19 (34 20) 10x_Adapter:M20 (36 21) 10x_Adapter:M21 (37 22) 10x_Adapter:M22 (38 23) 10x_Adapter:M23 (39 24) 10x_Adapter:I23 (16 25) 
random:RDB (1024 26) random:RI (674 27) random:RI (674 28) random:RI (674 29) random:RI (674 30) random:RI (674 31) random:RI (674 32) random:RI (674 33) random:RI (674 34) random:RI (674 35) random:RI (674 36) random:RI (674 37) random:RI (674 38) random:RI (674 39) random:RI (674 40) random:RI (674 41) random:RI (674 42) random:RI (674 43) 