In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import lib.assembly_graph
import lib.plot
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import scipy as sp
from collections import defaultdict
from itertools import chain

from tqdm import tqdm

import sparse

np.random.seed(1)

In [None]:
## Investigate why an (n, n) sparse dgraph matrix times an (n,) depth vector
## produces an (n, 1) result. This doesn't seem to happen for dense arrays.
#
# Terminology in the comments below (cf. the two expand_dims lines):
#   "along the bottom" = vector broadcast as a row    -> column j scaled by rng[j]
#   "along the side"   = vector broadcast as a column -> row i scaled by rng[i]
#
# NOTE(review): the sparse observations below are all consistent with `*` being
# the *matrix product* for scipy.sparse matrix types (not elementwise as for
# ndarrays): (1, n) * (n, n) gives a (1, n) row, (n, n) * (1, n) is a dimension
# mismatch, diag(rng) * eye scales rows and eye * diag(rng) scales columns.
# `.multiply(...)` is the elementwise, broadcasting product. Confirm against
# the scipy.sparse docs.

n = 5

# Dense
eye = np.eye(n, k=1) # Offset diagonal of 1's
rng = np.arange(n)
# print(eye * rng)                     # Multiplies vector along the bottom 
# print(rng * eye)                     # Multiplies vector along the bottom 
# print(eye * np.expand_dims(rng, 0))  # Multiplies vector along the bottom
# print(eye * np.expand_dims(rng, 1))  # Multiplies vector along the side

# Sparse rng and sparse eye (rng is presumably a (1, n) row matrix - see NOTE)
eye = sp.sparse.csr_matrix(np.eye(n, k=1))
rng = sp.sparse.csr_matrix(np.arange(n))
rng_diag = sp.sparse.diags(np.arange(n), format='dia')  # = rng.multiply(sp.sparse.eye(n))
# print(eye * rng)                     # "ValueError: dimension mismatch"
print((rng * eye).toarray())           # (1, n) row vector (see NOTE above)
print(eye.multiply(rng).toarray())     # Multiplies rng vector along the bottom
print(rng.multiply(eye).toarray())     # Multiplies rng vector along the bottom
print(eye.T.multiply(rng).T.toarray())     # Multiplies rng vector along the side
print((rng_diag * eye).toarray())     # Multiplies rng vector along the side
print((eye * rng_diag).toarray())     # Multiplies rng vector along the bottom

In [None]:
## Same question as the previous cell, but with a pydata `sparse.COO` eye and a
## dense numpy rng. Per the observations below, `*` on a COO array broadcasts
## elementwise like numpy, so no unexpected reshaping happens here.

n = 5

# Dense (kept for reference: eye and rng are reassigned below before any active
# use, since the dense prints are commented out)
eye = np.eye(n, k=1) # Offset diagonal of 1's
rng = np.arange(n)
# print(eye * rng)                     # Multiplies vector along the bottom 
# print(rng * eye)                     # Multiplies vector along the bottom 
# print(eye * np.expand_dims(rng, 0))  # Multiplies vector along the bottom
# print(eye * np.expand_dims(rng, 1))  # Multiplies vector along the side

# Sparse (pydata COO) eye and dense rng
eye = sparse.COO.from_numpy(np.eye(n, k=1))
rng = np.arange(n)
# rng_diag = sp.sparse.diags(np.arange(n), format='dia')  # = rng.multiply(sp.sparse.eye(n))
print((eye * rng).todense())           # Multiplies rng vector along the bottom
print((rng * eye).todense())           # Multiplies rng vector along the bottom
print((eye * rng.reshape((n, 1))).todense())           # Multiplies rng vector along the side
print((np.multiply(eye.T, rng)).T.todense())           # Multiplies rng vector along the side

In [None]:
def simulate_sequence(n):
    """Return a random DNA string of length ``n``.

    Each base is drawn uniformly from A/C/G/T using NumPy's *global* RNG, so
    call ``np.random.seed`` beforehand for a reproducible sequence.
    """
    nucleotides = ['A', 'C', 'G', 'T']
    draws = np.random.choice(nucleotides, size=n)
    return ''.join(draws)


def collect_kmers(sequences, k=None):
    """Return every k-mer in ``sequences`` together with its reverse complement.

    Parameters
    ----------
    sequences : iterable of str
        Sequences to scan.
    k : int, optional
        k-mer length. The original version silently read a notebook-global
        ``k``; when ``k`` is omitted that global is still used, for backward
        compatibility.

    Returns
    -------
    set of str

    Raises
    ------
    TypeError
        If ``k`` is omitted and no global ``k`` is defined.
    """
    if k is None:
        try:
            k = globals()['k']
        except KeyError:
            raise TypeError(
                "collect_kmers() requires `k` (no global `k` is defined)"
            ) from None

    all_kmers = set()
    for seq in sequences:
        # A sequence of length L has L - k + 1 k-mers; the +1 keeps the last one.
        for i in range(len(seq) - k + 1):
            kmer = seq[i: i + k]
            # Skip the reverse-complement call for k-mers already collected
            # (either directly or as the RC of an earlier k-mer).
            if kmer not in all_kmers:
                kmer_rc = lib.assembly_graph.reverse_complement(kmer)
                all_kmers |= {kmer, kmer_rc}

    return all_kmers


def build_seed_from_one_sequence(sequence, k):
    """Build a de Bruijn-style seed graph from a single sequence.

    Every k-mer of ``sequence`` is linked to the k-mer that starts one base
    later. Both k-mers at the final position are included. (The previous loop
    stopped one position early and dropped the last edge, so a sequence of
    length ``k + 1`` produced no edges.)

    Returns
    -------
    defaultdict(set)
        Maps each k-mer to the set of k-mers that follow it.
    """
    seed = defaultdict(set)
    kmers = [sequence[i: i + k] for i in range(len(sequence) - k + 1)]
    for kmer, next_kmer in zip(kmers, kmers[1:]):
        seed[kmer].add(next_kmer)
    return seed


def all_nodes(links):
    """Return the set of every node in ``links``.

    This covers each key (source) plus every node listed in any value (target).
    """
    nodes = set(links)
    for targets in links.values():
        nodes.update(targets)
    return nodes


def build_depth_from_seed(seed, depth_fn):
    """Assign a depth to every node of ``seed`` by calling ``depth_fn(node)``.

    Nodes are visited in sorted order. ``depth_fn`` may be stateful (the
    random-sequence example draws from NumPy's global RNG). Iterating the raw
    node set would make the node -> depth assignment depend on Python's
    per-process string-hash randomisation, so ``np.random.seed`` alone would
    not make results reproducible. Assumes node labels are mutually comparable
    (they are k-mer strings throughout this notebook).

    Returns
    -------
    defaultdict(int)
        Node -> depth; nodes not in ``seed`` read as depth 0.
    """
    depth = defaultdict(int)
    depth.update((unitig, depth_fn(unitig)) for unitig in sorted(all_nodes(seed)))
    return depth

In [None]:
def mean_depth(r, l, d, weight):
    """Weighted mean of incoming-right ``r``, incoming-left ``l`` and own depth ``d``.

    The node's own depth counts ``weight`` times; each side counts once.
    """
    # TODO: Confirm that this broadcasts correctly when r/l and d differ in shape.
    total = r + l + weight * d
    return total / (2 + weight)

def nan_to_num(x, value=0):
    """Return ``x`` with every NaN replaced by ``value`` (default 0).

    Unlike ``np.nan_to_num``, infinities are left untouched. The previous
    version ignored ``value`` and always substituted 0.
    """
    return np.where(np.isnan(x), value, x)

In [None]:
# Inspect scipy's COO -> CSR conversion method (the cell only displays its repr).
# NOTE(review): looks like leftover exploration; it assigns nothing, so no later cell depends on it.
sp.sparse.coo_matrix.tocsr

In [None]:
def _index_dgraph_from_seed(seed, depth):
    """Shared setup for the ``build_*_dgraph_from_seed`` builders.

    Expands ``seed`` via ``lib.assembly_graph.build_full_from_seed_graph``,
    adds reverse-complement depths, and enumerates the directed edges
    ``l -> r`` (``r`` in ``right[l]``) as integer matrix positions.

    Returns
    -------
    depth : pd.Series
        Float depth per unitig.
    idx : pd.Series
        Maps unitig (index named ``'unitig'``) to its row/column position,
        in the order of ``depth``.
    ii, jj : list
        Parallel lists of row (from) and column (to) positions, one per edge.
    """
    right, left = lib.assembly_graph.build_full_from_seed_graph(seed)
    depth = pd.Series(lib.assembly_graph.add_reverse_complement_depth(depth)).astype(float)
    # Every unitig with a depth must appear somewhere in the graph.
    assert (depth.index.isin(right.keys()) | depth.index.isin(left.keys())).all()

    # Position of each unitig in depth order. Built directly rather than via
    # reset_index().squeeze(), which collapses to a scalar for a 1-unitig graph.
    idx = pd.Series(np.arange(len(depth)), index=depth.index.rename('unitig'))

    ii, jj = [], []
    for l in depth.index:
        # .get: per the assert above, a unitig may appear only as a key of
        # `left`, i.e. have no right neighbours.
        for r in right.get(l, ()):
            ii.append(idx[l])
            jj.append(idx[r])
    return depth, idx, ii, jj


def build_coo_dgraph_from_seed(seed, depth):
    """Build the adjacency matrix as a pydata ``sparse.COO`` array.

    Returns
    -------
    (dgraph, depth_values, unitigs)
        ``dgraph`` is (n, n) and nonzero at ``(i, j)`` for each edge
        ``i -> j``. ``depth_values`` is the float depth ndarray and
        ``unitigs`` labels rows/columns.
    """
    depth, idx, ii, jj = _index_dgraph_from_seed(seed, depth)
    items = [((i, j), 1) for i, j in zip(ii, jj)]
    dgraph = sparse.COO.from_iter(items, shape=(len(idx), len(idx)))
    return dgraph, depth.values, idx.index


def build_csr_dgraph_from_seed(seed, depth):
    """Like ``build_coo_dgraph_from_seed`` but returns a ``scipy.sparse`` CSR matrix."""
    depth, idx, ii, jj = _index_dgraph_from_seed(seed, depth)
    n = len(idx)
    dgraph = sp.sparse.coo_matrix((np.ones_like(ii), (ii, jj)), shape=(n, n)).tocsr()
    return dgraph, depth.values, idx.index


def build_dgraph_from_seed(seed, depth):
    """Like ``build_coo_dgraph_from_seed`` but returns a dense numpy adjacency matrix."""
    dgraph, depth_values, unitigs = build_coo_dgraph_from_seed(seed, depth)
    return dgraph.todense(), depth_values, unitigs

In [None]:
def initialize_messages(dgraph, depth):
    """Seed the dense message matrices from an adjacency matrix and node depths.

    Starts from unit messages on every edge (``dgraph`` rightwards, its
    transpose leftwards). It then applies the same propagation round twice
    (originally written out as "Step -1" and "Step 0"). In each round a node
    splits its depth over its neighbours in proportion to the normalised
    messages it received from that side. 0/0 entries for nodes with no
    neighbours on a side become NaN and are zeroed by ``nan_to_num``.

    Returns
    -------
    (send_to_r, send_to_l)
    """
    send_to_r = dgraph
    send_to_l = dgraph.T
    for _ in range(2):  # Step -1, then Step 0.
        total_from_l = send_to_r.sum(0)
        total_from_r = send_to_l.sum(0)
        proportions_r = nan_to_num(send_to_l / total_from_r)
        proportions_l = nan_to_num(send_to_r / total_from_l)
        # Both new matrices are computed from the previous round's proportions.
        send_to_r, send_to_l = (
            np.multiply(depth, proportions_r).T,
            np.multiply(depth, proportions_l).T,
        )
    return send_to_r, send_to_l

def iterate_messages(
    send_to_r,
    send_to_l,
    depth,
    new_depth_fn=mean_depth,
    weight=1.0,
):
    """Run one round of dense message passing.

    1. Sum the messages each node receives from its left and right sides.
    2. Combine them with the current depth via
       ``new_depth_fn(from_right, from_left, depth, weight)``, then rescale so
       the total depth is unchanged.
    3. Split the rescaled depth into new messages using the normalised
       incoming messages (NaNs from 0/0 are zeroed by ``nan_to_num``).

    Returns
    -------
    (send_to_r, send_to_l, depth) for the next round.
    """
    inbound_left = send_to_r.sum(0)
    inbound_right = send_to_l.sum(0)

    blended = new_depth_fn(inbound_right, inbound_left, depth, weight)
    # Rescale so there's no overall loss of depth.
    conserved = blended * (depth.sum() / blended.sum())

    share_right = nan_to_num(send_to_l / inbound_right)
    share_left = nan_to_num(send_to_r / inbound_left)
    return (
        np.multiply(conserved, share_right).T,
        np.multiply(conserved, share_left).T,
        conserved,
    )

def initialize_messages_sparse(dgraph, depth):
    """Sparse-input counterpart of ``initialize_messages`` (work in progress).

    Applies the same two propagation rounds ("Step -1" then "Step 0").
    ``dgraph`` is the (n, n) adjacency from ``build_csr_dgraph_from_seed``.
    ``depth`` is the depth array that builder returns (``depth.values``, a
    plain ndarray).

    The ``print(type(...))`` calls are kept: they show how the container type
    changes after the first round, presumably to debug the sparse path.

    Returns
    -------
    (send_to_r, send_to_l)
    """
    send_to_r = dgraph
    send_to_l = dgraph.T
    print(type(send_to_r))
    for step in (-1, 0):
        total_from_l = send_to_r.sum(0)
        total_from_r = send_to_l.sum(0)
        proportions_r = nan_to_num(send_to_l / total_from_r)
        proportions_l = nan_to_num(send_to_r / total_from_l)
        # np.multiply rather than depth.multiply(...): depth is an ndarray,
        # which has no .multiply method. Matches iterate_messages_sparse.
        send_to_r, send_to_l = (
            np.multiply(depth, proportions_r).T,
            np.multiply(depth, proportions_l).T,
        )
        if step == -1:
            print(type(send_to_r))
    return send_to_r, send_to_l

def iterate_messages_sparse(
    send_to_r,
    send_to_l,
    depth,
    new_depth_fn=mean_depth,
    weight=1.0,
):
    """Run one round of message passing on the sparse-path intermediates.

    Identical arithmetic to the dense ``iterate_messages``:
    1. sum incoming messages per side;
    2. combine them with ``depth`` via ``new_depth_fn`` and rescale to the old
       total;
    3. split the result into new messages by the normalised incoming
       messages.

    Returns
    -------
    (send_to_r, send_to_l, depth) for the next round.
    """
    received_from_left = send_to_r.sum(0)
    received_from_right = send_to_l.sum(0)

    candidate = new_depth_fn(received_from_right, received_from_left, depth, weight)
    # Rescale so there's no overall loss of depth.
    rebalanced = candidate * (depth.sum() / candidate.sum())

    fraction_right = nan_to_num(send_to_l / received_from_right)
    fraction_left = nan_to_num(send_to_r / received_from_left)
    outgoing_right = np.multiply(rebalanced, fraction_right).T
    outgoing_left = np.multiply(rebalanced, fraction_left).T
    return outgoing_right, outgoing_left, rebalanced


def run_message_passing_sparse(seed, observed_depth, thresh=1e-3, max_iter=None, weight=1.0):
    """Sparse-input counterpart of ``run_message_passing`` (work in progress).

    Parameters
    ----------
    seed, observed_depth
        Passed to ``build_csr_dgraph_from_seed``.
    thresh : float
        Stop once the Euclidean norm of the change in the depth vector between
        consecutive iterations is below this.
    max_iter : int or None
        Optional cap on iterations. ``None`` (default) iterates until
        convergence, as before.
    weight : float
        Forwarded to ``iterate_messages_sparse`` (used by ``mean_depth``).

    Returns
    -------
    (depth, send_to_r, send_to_l, depth0)
        Final/initial depth Series and message DataFrames, labelled by unitig.
    """
    import warnings

    dgraph, depth0, idx = build_csr_dgraph_from_seed(seed, observed_depth)
    send_to_r, send_to_l = initialize_messages_sparse(dgraph, depth0)

    depth = depth0
    converged = False
    change = np.nan
    n_iter = 0
    # The context manager closes the progress bar even if an iteration raises.
    with tqdm(position=0, leave=True) as tbar:
        while max_iter is None or n_iter < max_iter:
            send_to_r, send_to_l, new_depth = iterate_messages_sparse(
                send_to_r, send_to_l, depth, new_depth_fn=mean_depth, weight=weight,
            )
            change = np.sqrt(np.sum(np.square(new_depth - depth)))
            depth = new_depth
            n_iter += 1
            tbar.update()
            tbar.set_postfix({'change': change})
            if change < thresh:
                converged = True
                break
            # A NaN change (e.g. the depth rescale dividing by a zero total)
            # never satisfies `change < thresh`; stop instead of spinning forever.
            if not np.isfinite(change):
                break
    if converged:
        print("CONVERGED")
    else:
        warnings.warn(
            f"Message passing stopped without converging after {n_iter} "
            f"iterations (last change={change!r})."
        )

    def _dense(send):
        # scipy.sparse / pydata-sparse results expose .todense(); ndarray and
        # np.matrix do not.
        return np.asarray(send.todense() if hasattr(send, 'todense') else send)

    # Recover labels
    send_to_r, send_to_l = [
        pd.DataFrame(_dense(send), index=idx, columns=idx) for send in [send_to_r, send_to_l]
    ]
    # With np.matrix intermediates the depth may come back as shape (1, n).
    depth = pd.Series(np.asarray(depth).ravel(), index=idx)
    depth0 = pd.Series(depth0, index=idx)
    return depth, send_to_r, send_to_l, depth0

def run_message_passing(seed, observed_depth, thresh=1e-3, max_iter=None, weight=1.0):
    """Run dense message passing on ``seed`` until the depth estimates converge.

    Parameters
    ----------
    seed, observed_depth
        Passed to ``build_dgraph_from_seed``.
    thresh : float
        Stop once the Euclidean norm of the change in the depth vector between
        consecutive iterations is below this.
    max_iter : int or None
        Optional cap on iterations. ``None`` (default) iterates until
        convergence, as before.
    weight : float
        Forwarded to ``iterate_messages`` (used by ``mean_depth``).

    Returns
    -------
    (depth, send_to_r, send_to_l, depth0)
        Final and initial depths as Series and the final message matrices as
        DataFrames, all labelled by unitig.
    """
    import warnings

    dgraph, depth0, idx = build_dgraph_from_seed(seed, observed_depth)
    send_to_r, send_to_l = initialize_messages(dgraph, depth0)

    depth = depth0
    converged = False
    change = np.nan
    n_iter = 0
    # The context manager closes the progress bar even if an iteration raises.
    with tqdm(position=0, leave=True) as tbar:
        while max_iter is None or n_iter < max_iter:
            send_to_r, send_to_l, new_depth = iterate_messages(
                send_to_r, send_to_l, depth, new_depth_fn=mean_depth, weight=weight,
            )
            change = np.sqrt(np.sum(np.square(new_depth - depth)))
            depth = new_depth
            n_iter += 1
            tbar.update()
            tbar.set_postfix({'change': change})
            if change < thresh:
                converged = True
                break
            # A NaN change (e.g. the depth rescale dividing by a zero total)
            # never satisfies `change < thresh`; stop instead of spinning forever.
            if not np.isfinite(change):
                break
    if converged:
        print("CONVERGED")
    else:
        warnings.warn(
            f"Message passing stopped without converging after {n_iter} "
            f"iterations (last change={change!r})."
        )

    # Recover labels
    send_to_r, send_to_l = [
        pd.DataFrame(send, index=idx, columns=idx) for send in [send_to_r, send_to_l]
    ]
    depth = pd.Series(depth, index=idx)
    depth0 = pd.Series(depth0, index=idx)
    return depth, send_to_r, send_to_l, depth0

# Random Sequence

In [None]:
# Simulate a random sequence, build its k-mer graph, and run dense message passing.
# Reseed here so this cell gives the same sequence when re-run on its own.
np.random.seed(1)
sequence = simulate_sequence(int(1e5))

# k = 5: there are at most 4**5 = 1024 distinct k-mers, so the dense (n, n)
# message matrices stay small.
seed = build_seed_from_one_sequence(sequence, 5)

# Simulate depth: each k-mer gets exp(2 * N(0, 1)), i.e. a log-normal depth.
depth_fn = lambda kmer: np.exp(np.random.randn() * 2)
# NOTE(review): unlike the toy examples below, observed_depth is not wrapped in
# add_reverse_complement_depth here; the dgraph builders apply it themselves.
# Confirm which is intended.
observed_depth = build_depth_from_seed(seed, depth_fn)

depth, send_to_r, send_to_l, depth0 = run_message_passing(seed, observed_depth, thresh=1e-4)
depth_table = pd.DataFrame(dict(old_depth=depth0, new_depth=depth))

In [None]:
# Compare observed depths with depths after message passing (x-axis truncated at 200).
bins = np.linspace(0, 200)
fig, ax = plt.subplots()
# Pass an explicit 2-D array (one column per series) and label each column, so
# the two overlaid histograms can be told apart.
ax.hist(
    depth_table.to_numpy(),
    bins=bins,
    alpha=0.5,
    histtype='stepfilled',
    label=list(depth_table.columns),
)
ax.set(xlabel='depth', ylabel='number of unitigs', title='Depth before and after message passing')
ax.legend();

# Tall saw-horse

In [None]:
# Tall saw-horse: two paths merge into ACCGG (depth 10) and split again.
#   depth-9 path: TAACC -> AACCG -> ACCGG -> CCGGG -> CGGGT
#   depth-1 path: TTACC -> TACCG -> ACCGG -> CCGGA -> CGGAT
seed = {
    'AACCG': ['ACCGG'],
    'ACCGG': ['CCGGG', 'CCGGA'],
    'TACCG': ['ACCGG'],
    'TAACC': ['AACCG'],
    'TTACC': ['TACCG'],
    'CCGGG': ['CGGGT'],
    'CCGGA': ['CGGAT'],
}
# NOTE(review): observed_depth is already passed through
# add_reverse_complement_depth here, and build_dgraph_from_seed (called by
# run_message_passing) applies it again. Confirm that helper is idempotent.
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    'AACCG': 9,
    'ACCGG': 10,
    'CCGGG': 9,
    'CCGGA': 1,
    'TACCG': 1,
    'TAACC': 9,
    'TTACC': 1,
    'CGGGT': 9,
    'CGGAT': 1,
})

depth, send_to_r, send_to_l, depth0 = run_message_passing(seed, observed_depth)
# Heatmap: sum of the rightward and leftward message matrices.
sns.heatmap(send_to_r + send_to_l)
# Table: observed depth vs. depth after message passing.
pd.DataFrame(dict(old_depth=depth0, new_depth=depth))

In [None]:
# Short saw-horse: AACCG (3) and TACCG (1) both feed ACCGG (4), which splits
# into CCGGG (3) and CCGGA (1).
# NOTE(review): this looks like the plain "Saw-horse" case; the "Saw-horse"
# heading further down has no cell of its own.
seed = {
    'AACCG': ['ACCGG'],
    'ACCGG': ['CCGGG', 'CCGGA'],
    'TACCG': ['ACCGG'],
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    'AACCG': 3,
    'ACCGG': 4,
    'CCGGG': 3,
    'CCGGA': 1,
    'TACCG': 1,
})

depth, send_to_r, send_to_l, depth0 = run_message_passing(seed, observed_depth)
# Heatmap: sum of the rightward and leftward message matrices.
sns.heatmap(send_to_r + send_to_l)
# Table: observed depth vs. depth after message passing.
pd.DataFrame(dict(old_depth=depth0, new_depth=depth))

# Saw-horse

# Cycle w/ Switch-back

In [None]:
# Seven-node cycle, every node at depth 1:
#   ACCCG -> CCCGG -> CCGGT -> CGGTA -> GGTAC -> GTACC -> TACCC -> ACCCG
# GTACC is the reverse complement of GGTAC, so the edge GGTAC -> GTACC crosses
# onto the opposite strand (presumably the "switch-back" in the heading).
seed = {
    'ACCCG': ['CCCGG'],
    'CCCGG': ['CCGGT'],
    'CCGGT': ['CGGTA'],
    'CGGTA': ['GGTAC'],
    'GGTAC': ['GTACC'],
    'GTACC': ['TACCC'],
    'TACCC': ['ACCCG'],
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    'ACCCG': 1,
    'CCCGG': 1,
    'CCGGT': 1,
    'CGGTA': 1,
    'GGTAC': 1,
    'GTACC': 1,
    'TACCC': 1,
})

depth, send_to_r, send_to_l, depth0 = run_message_passing(seed, observed_depth)
# Heatmap: sum of the rightward and leftward message matrices.
sns.heatmap(send_to_r + send_to_l)
# Table: observed depth vs. depth after message passing.
pd.DataFrame(dict(old_depth=depth0, new_depth=depth))

# Six-Cycle

In [None]:
# Six-node cycle ACCCG -> CCCGG -> CCGGA -> CGGAC -> GGACC -> GACCC -> ACCCG.
# Every node has depth 1 except CCGGA (depth 2).
seed = {
    'ACCCG': ['CCCGG'],
    'CCCGG': ['CCGGA'],
    'CCGGA': ['CGGAC'],
    'CGGAC': ['GGACC'],
    'GGACC': ['GACCC'],
    'GACCC': ['ACCCG']
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    'ACCCG': 1,
    'CCCGG': 1,
    'CCGGA': 2,
    'CGGAC': 1,
    'GGACC': 1,
    'GACCC': 1,
})

depth, send_to_r, send_to_l, depth0 = run_message_passing(seed, observed_depth)
# Heatmap: sum of the rightward and leftward message matrices.
sns.heatmap(send_to_r + send_to_l)
# Table: observed depth vs. depth after message passing.
pd.DataFrame(dict(old_depth=depth0, new_depth=depth))

# Six-cycle w/ Spur

In [None]:
# Six-node cycle ACCCG -> CCCGG -> CCGGA -> CGGAC -> GGACC -> GACCC -> ACCCG,
# plus a spur CCCGG -> CCGGC (CCGGC has no successors listed in `seed`).
# Every node has depth 1.
seed = {
    'ACCCG': ['CCCGG'],
    'CCCGG': ['CCGGA', 'CCGGC'],
    'CCGGA': ['CGGAC'],
    'CGGAC': ['GGACC'],
    'GGACC': ['GACCC'],
    'GACCC': ['ACCCG']
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    'ACCCG': 1,
    'CCCGG': 1,
    'CCGGA': 1,
    'CGGAC': 1,
    'GGACC': 1,
    'GACCC': 1,
    'CCGGC': 1,
})

depth, send_to_r, send_to_l, depth0 = run_message_passing(seed, observed_depth)
# Heatmap: sum of the rightward and leftward message matrices.
sns.heatmap(send_to_r + send_to_l)
# Table: observed depth vs. depth after message passing.
pd.DataFrame(dict(old_depth=depth0, new_depth=depth))

# Double-six-cycle

In [None]:
# Two six-node cycles sharing CGGAC:
#   top (depth 1):    CGGAC -> GGACC -> GACCC -> ACCCG -> CCCGG -> CCGGA -> CGGAC
#   bottom (depth 2): CGGAC -> GGACT -> GACTC -> ACTCG -> CTCGG -> TCGGA -> CGGAC
# The shared node CGGAC has depth 3 = 1 + 2.
seed = {
    # Top cycle
    'GGACC': ['GACCC'],
    'GACCC': ['ACCCG'],
    'ACCCG': ['CCCGG'],
    'CCCGG': ['CCGGA'],
    'CCGGA': ['CGGAC'],
    
    # Link (shared node; branches into both cycles)
    'CGGAC': ['GGACC', 'GGACT'],
    
    # Bottom cycle
    'GGACT': ['GACTC'],
    'GACTC': ['ACTCG'],
    'ACTCG': ['CTCGG'],
    'CTCGG': ['TCGGA'],
    'TCGGA': ['CGGAC'],
    
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    # Top cycle
    'GGACC': 1,
    'GACCC': 1,
    'ACCCG': 1,
    'CCCGG': 1,
    'CCGGA': 1,
    
    # Link
    'CGGAC': 3,
    
    # Bottom cycle
    'GGACT': 2, 
    'GACTC': 2,
    'ACTCG': 2,
    'CTCGG': 2,
    'TCGGA': 2,
})

depth, send_to_r, send_to_l, depth0 = run_message_passing(seed, observed_depth)
# Heatmap: sum of the rightward and leftward message matrices.
sns.heatmap(send_to_r + send_to_l)
# Table: observed depth vs. depth after message passing.
pd.DataFrame(dict(old_depth=depth0, new_depth=depth))

# Lonely-stick

In [None]:
# Lonely stick: a single edge GGACC -> GACCT with mismatched depths (1 and 2).
seed = {
    # Single edge
    'GGACC': ['GACCT'],
}
observed_depth = lib.assembly_graph.add_reverse_complement_depth({
    # Single edge
    'GGACC': 1,
    'GACCT': 2,
})

depth, send_to_r, send_to_l, depth0 = run_message_passing(seed, observed_depth)
# Heatmap: sum of the rightward and leftward message matrices.
sns.heatmap(send_to_r + send_to_l)
# Table: observed depth vs. depth after message passing.
pd.DataFrame(dict(old_depth=depth0, new_depth=depth))