In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import lib.assembly_graph
import lib.plot
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import scipy as sp

from tqdm import tqdm

In [None]:
def build_from_seed(seed, depth):
    """Build a dense right-neighbour adjacency matrix and a depth vector.

    Parameters
    ----------
    seed
        Seed graph description, forwarded unchanged to
        ``lib.assembly_graph.build_full_from_seed_graph``.
    depth
        Mapping of unitig -> observed depth. It is passed through
        ``lib.assembly_graph.add_reverse_complement_depth`` here.
        NOTE(review): the notebook cells also call that helper before passing
        ``depth`` in, so it is applied twice -- confirm it is idempotent.

    Returns
    -------
    (adjacency, full_depth)
        ``full_depth`` is a float ``pd.Series`` indexed by unitig.
        ``adjacency`` is a square float ``pd.DataFrame`` over that same index
        in which entry ``[u, v]`` is 1 when ``v`` is listed in ``right[u]``
        and 0 otherwise.

    Raises
    ------
    AssertionError
        If some unitig with a depth is a key of neither the ``right`` nor the
        ``left`` mapping returned by the graph builder.
    """
    right, left = lib.assembly_graph.build_full_from_seed_graph(seed)
    full_depth = pd.Series(
        lib.assembly_graph.add_reverse_complement_depth(depth)
    ).astype(float)
    unitigs = full_depth.index

    # Every unitig that carries a depth must be known to the graph.
    in_graph = unitigs.isin(right.keys()) | unitigs.isin(left.keys())
    assert in_graph.all()

    n_unitigs = len(full_depth)
    adjacency = pd.DataFrame(
        np.zeros((n_unitigs, n_unitigs)), index=unitigs, columns=unitigs
    )
    # Mark an edge from each unitig to every one of its right-hand neighbours.
    for unitig in unitigs:
        adjacency.loc[unitig, right[unitig]] = 1
    return adjacency, full_depth

In [None]:
def initialize_messages(dgraph, depth):
    """Bootstrap the rightward/leftward message matrices from the adjacency.

    Starting from ``send_to_r = dgraph`` and ``send_to_l = dgraph.T``, two
    identical normalisation rounds are applied (labelled "Step -1" and
    "Step 0" in the original notebook). In each round:

    * every column of the opposite-direction matrix is divided by its column
      sum and multiplied by ``depth`` for that column label;
    * the result is transposed to become the next message matrix;
    * NaN entries (from dividing an all-zero column by its zero sum) are
      replaced by 0.

    Parameters
    ----------
    dgraph
        Adjacency ``pd.DataFrame`` (``build_from_seed`` produces one whose
        entry ``[u, v]`` is 1 when ``v`` lies to the right of ``u``).
    depth
        ``pd.Series`` of depths, aligned with the columns of ``dgraph``.

    Returns
    -------
    (send_to_r, send_to_l)
        The two message DataFrames after both rounds. ``dgraph`` itself is
        not modified.
    """

    def _normalisation_round(to_right, to_left):
        # Column sums are computed from the current messages before either
        # matrix is replaced.
        inflow_from_left = to_right.sum()
        inflow_from_right = to_left.sum()
        next_to_right = to_left.div(inflow_from_right, axis=1).mul(depth, axis=1).T
        next_to_left = to_right.div(inflow_from_left, axis=1).mul(depth, axis=1).T
        return next_to_right.fillna(0), next_to_left.fillna(0)

    messages = (dgraph, dgraph.T)
    # Two passes: "Step -1" followed by "Step 0".
    for _ in range(2):
        messages = _normalisation_round(*messages)
    return messages

In [None]:
def iterate_messages(send_to_r, send_to_l, depth, weight=1.0):
    """Run one message-passing update and return the new messages and depth.

    Parameters
    ----------
    send_to_r, send_to_l
        Current rightward and leftward message DataFrames.
    depth
        Current depth ``pd.Series``, aligned with the message columns.
    weight
        Weight given to the current depth when it is averaged with the two
        column sums of the incoming messages.

    Returns
    -------
    (send_to_r, send_to_l, depth)
        The updated message DataFrames and the updated depth Series. The
        updated depth is rescaled so that its sum equals the sum of the
        input ``depth``.
    """
    inflow_from_left = send_to_r.sum()
    inflow_from_right = send_to_l.sum()

    # Weighted average of the two inflow totals and the current depth.
    blended = (inflow_from_right + inflow_from_left + (weight * depth)) / (2 + weight)
    # Rescale so the total depth is conserved across the update.
    updated_depth = blended * (depth.sum() / blended.sum())

    # Split the updated depth across edges in proportion to what arrived from
    # the opposite side; 0/0 (no inflow at all) becomes 0.
    share_of_right_inflow = send_to_l.div(inflow_from_right, axis=1)
    share_of_left_inflow = send_to_r.div(inflow_from_left, axis=1)
    next_to_r = share_of_right_inflow.mul(updated_depth, axis=1).T.fillna(0)
    next_to_l = share_of_left_inflow.mul(updated_depth, axis=1).T.fillna(0)
    return next_to_r, next_to_l, updated_depth

# Tall saw-horse

In [None]:
# Tall saw-horse: two two-unitig paths (depth 9 and depth 1) merge at ACCGG
# (depth 10), which then splits into two outgoing two-unitig paths (9 and 1).
seed = {
    'AACCG': ['ACCGG'],
    'ACCGG': ['CCGGG', 'CCGGA'],
    'TACCG': ['ACCGG'],
    'TAACC': ['AACCG'],
    'TTACC': ['TACCG'],
    'CCGGG': ['CGGGT'],
    'CCGGA': ['CGGAT'],
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    'AACCG': 9,
    'ACCGG': 10,
    'CCGGG': 9,
    'CCGGA': 1,
    'TACCG': 1,
    'TAACC': 9,
    'TTACC': 1,
    'CGGGT': 9,
    'CGGAT': 1,
})
dgraph, depth0 = build_from_seed(seed, observed_depth)
send_to_r, send_to_l = initialize_messages(dgraph, depth0)

# Iterate message passing until the depth vector stops moving.
depth = depth0
thresh = 1e-5  # stop once the L2 norm of the depth update drops below this
# The context manager closes the progress bar when the loop exits.
with tqdm(position=0, leave=True) as tbar:
    while True:
        send_to_r, send_to_l, new_depth = iterate_messages(
            send_to_r, send_to_l, depth, weight=1.0
        )
        delta = new_depth - depth
        change = np.sqrt(np.sum(np.square(delta)))
        depth = new_depth
        tbar.update()
        tbar.set_postfix({'change': change})
        if change < thresh:
            break

sns.heatmap(send_to_r + send_to_l)
depth

In [None]:
# NOTE(review): this looks like the plain "Saw-horse" case (the heading of
# that name appears after this cell) -- confirm the heading placement.
# Two single unitigs (depth 3 and depth 1) merge at ACCGG (depth 4), which
# splits into two outgoing unitigs (depth 3 and depth 1).
seed = {
    'AACCG': ['ACCGG'],
    'ACCGG': ['CCGGG', 'CCGGA'],
    'TACCG': ['ACCGG'],
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    'AACCG': 3,
    'ACCGG': 4,
    'CCGGG': 3,
    'CCGGA': 1,
    'TACCG': 1,
})
dgraph, depth0 = build_from_seed(seed, observed_depth)
send_to_r, send_to_l = initialize_messages(dgraph, depth0)

# Iterate message passing until the depth vector stops moving.
depth = depth0
thresh = 1e-5  # stop once the L2 norm of the depth update drops below this
# The context manager closes the progress bar when the loop exits.
with tqdm(position=0, leave=True) as tbar:
    while True:
        send_to_r, send_to_l, new_depth = iterate_messages(
            send_to_r, send_to_l, depth, weight=1.0
        )
        delta = new_depth - depth
        change = np.sqrt(np.sum(np.square(delta)))
        depth = new_depth
        tbar.update()
        tbar.set_postfix({'change': change})
        if change < thresh:
            break

sns.heatmap(send_to_r + send_to_l)
depth

# Saw-horse

# Cycle w/ Switch-back

In [None]:
# Seven-unitig cycle with uniform depth 1. GGTAC is followed directly by its
# own reverse complement, GTACC.
seed = {
    'ACCCG': ['CCCGG'],
    'CCCGG': ['CCGGT'],
    'CCGGT': ['CGGTA'],
    'CGGTA': ['GGTAC'],
    'GGTAC': ['GTACC'],
    'GTACC': ['TACCC'],
    'TACCC': ['ACCCG'],
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    'ACCCG': 1,
    'CCCGG': 1,
    'CCGGT': 1,
    'CGGTA': 1,
    'GGTAC': 1,
    'GTACC': 1,
    'TACCC': 1,
})
dgraph, depth0 = build_from_seed(seed, observed_depth)
send_to_r, send_to_l = initialize_messages(dgraph, depth0)

# Iterate message passing until the depth vector stops moving.
depth = depth0
thresh = 1e-5  # stop once the L2 norm of the depth update drops below this
# The context manager closes the progress bar when the loop exits.
with tqdm(position=0, leave=True) as tbar:
    while True:
        send_to_r, send_to_l, new_depth = iterate_messages(
            send_to_r, send_to_l, depth, weight=1.0
        )
        delta = new_depth - depth
        change = np.sqrt(np.sum(np.square(delta)))
        depth = new_depth
        tbar.update()
        tbar.set_postfix({'change': change})
        if change < thresh:
            break

sns.heatmap(send_to_r + send_to_l)
depth

# Six-cycle

In [None]:
# Six-unitig cycle; every unitig has depth 1 except CCGGA, observed at 2.
seed = {
    'ACCCG': ['CCCGG'],
    'CCCGG': ['CCGGA'],
    'CCGGA': ['CGGAC'],
    'CGGAC': ['GGACC'],
    'GGACC': ['GACCC'],
    'GACCC': ['ACCCG']
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    'ACCCG': 1,
    'CCCGG': 1,
    'CCGGA': 2,
    'CGGAC': 1,
    'GGACC': 1,
    'GACCC': 1,
})
dgraph, depth0 = build_from_seed(seed, observed_depth)
send_to_r, send_to_l = initialize_messages(dgraph, depth0)

# Iterate message passing until the depth vector stops moving.
depth = depth0
thresh = 1e-5  # stop once the L2 norm of the depth update drops below this
# The context manager closes the progress bar when the loop exits.
with tqdm(position=0, leave=True) as tbar:
    while True:
        send_to_r, send_to_l, new_depth = iterate_messages(
            send_to_r, send_to_l, depth, weight=1.0
        )
        delta = new_depth - depth
        change = np.sqrt(np.sum(np.square(delta)))
        depth = new_depth
        tbar.update()
        tbar.set_postfix({'change': change})
        if change < thresh:
            break

sns.heatmap(send_to_r + send_to_l)
depth

# Six-cycle w/ Spur

In [None]:
# Six-unitig cycle with one extra branch: CCCGG continues into the cycle
# (CCGGA) and also into a dead-end unitig (CCGGC). All depths are 1.
seed = {
    'ACCCG': ['CCCGG'],
    'CCCGG': ['CCGGA', 'CCGGC'],
    'CCGGA': ['CGGAC'],
    'CGGAC': ['GGACC'],
    'GGACC': ['GACCC'],
    'GACCC': ['ACCCG']
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    'ACCCG': 1,
    'CCCGG': 1,
    'CCGGA': 1,
    'CGGAC': 1,
    'GGACC': 1,
    'GACCC': 1,
    'CCGGC': 1,
})
dgraph, depth0 = build_from_seed(seed, observed_depth)
send_to_r, send_to_l = initialize_messages(dgraph, depth0)

# Iterate message passing until the depth vector stops moving.
depth = depth0
thresh = 1e-5  # stop once the L2 norm of the depth update drops below this
# The context manager closes the progress bar when the loop exits.
with tqdm(position=0, leave=True) as tbar:
    while True:
        send_to_r, send_to_l, new_depth = iterate_messages(
            send_to_r, send_to_l, depth, weight=1.0
        )
        delta = new_depth - depth
        change = np.sqrt(np.sum(np.square(delta)))
        depth = new_depth
        tbar.update()
        tbar.set_postfix({'change': change})
        if change < thresh:
            break

sns.heatmap(send_to_r + send_to_l)
depth

# Double-six-cycle

In [None]:
# Two six-unitig cycles that share the single unitig CGGAC (depth 3): the top
# cycle is observed at depth 1 and the bottom cycle at depth 2.
seed = {
    # Top cycle
    'GGACC': ['GACCC'],
    'GACCC': ['ACCCG'],
    'ACCCG': ['CCCGG'],
    'CCCGG': ['CCGGA'],
    'CCGGA': ['CGGAC'],

    # Link
    'CGGAC': ['GGACC', 'GGACT'],

    # Bottom cycle
    'GGACT': ['GACTC'],
    'GACTC': ['ACTCG'],
    'ACTCG': ['CTCGG'],
    'CTCGG': ['TCGGA'],
    'TCGGA': ['CGGAC'],
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    # Top cycle
    'GGACC': 1,
    'GACCC': 1,
    'ACCCG': 1,
    'CCCGG': 1,
    'CCGGA': 1,

    # Link
    'CGGAC': 3,

    # Bottom cycle
    'GGACT': 2,
    'GACTC': 2,
    'ACTCG': 2,
    'CTCGG': 2,
    'TCGGA': 2,
})
dgraph, depth0 = build_from_seed(seed, observed_depth)
send_to_r, send_to_l = initialize_messages(dgraph, depth0)

# Iterate message passing until the depth vector stops moving.
depth = depth0
thresh = 1e-5  # stop once the L2 norm of the depth update drops below this
# The context manager closes the progress bar when the loop exits.
with tqdm(position=0, leave=True) as tbar:
    while True:
        send_to_r, send_to_l, new_depth = iterate_messages(
            send_to_r, send_to_l, depth, weight=1.0
        )
        delta = new_depth - depth
        change = np.sqrt(np.sum(np.square(delta)))
        depth = new_depth
        tbar.update()
        tbar.set_postfix({'change': change})
        if change < thresh:
            break

sns.heatmap(send_to_r + send_to_l)
depth

# Lonely-stick

In [None]:
# Lonely stick: a single edge GGACC -> GACCT whose two unitigs were observed
# at different depths (1 and 2).
seed = {
    'GGACC': ['GACCT'],
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    'GGACC': 1,
    'GACCT': 2,
})
dgraph, depth0 = build_from_seed(seed, observed_depth)
send_to_r, send_to_l = initialize_messages(dgraph, depth0)

# Iterate message passing until the depth vector stops moving.
depth = depth0
thresh = 1e-5  # stop once the L2 norm of the depth update drops below this
# The context manager closes the progress bar when the loop exits.
with tqdm(position=0, leave=True) as tbar:
    while True:
        send_to_r, send_to_l, new_depth = iterate_messages(
            send_to_r, send_to_l, depth, weight=1.0
        )
        delta = new_depth - depth
        change = np.sqrt(np.sum(np.square(delta)))
        depth = new_depth
        tbar.update()
        tbar.set_postfix({'change': change})
        if change < thresh:
            break

sns.heatmap(send_to_r + send_to_l)
depth