# Explore graph-refined coverage estimation

Coverage estimates for unitigs on an assembly graph
are inherently noisy because sampling is finite.

- Strategy for using assembly graph to improve binning of contigs
    - Coverage estimates for contigs are noisy, especially for short contigs.  Contigs that are adjacent on the assembly graph carry additional information that could be incorporated into coverage estimates.  The problem is that many splits represent strain variation, so adjacent contigs may actually have very _different_ coverages.  One thing we do expect is that "true" coverages are conserved across branches: when a single upstream path splits into two downstream paths, the coverages of the two should sum to the coverage of the one.  In theory, this should make it possible to improve empirical coverage estimates on the graph, thereby enabling binning of shorter contigs.
- Proposed algorithm:
    - Really this should be a pretty easy optimization problem to formulate...
    - "Latent" coverage at a node is defined as the mean latent coverage of all neighbors (should be a reasonably quick matrix multiplication formulation for this...) plus some local perturbation; minimize both the magnitude of the perturbations as well as the difference between latent coverages and observed coverages
    - Consider also fragmenting really long unitigs
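
One way the proposed optimization could be written down is as a penalized least-squares problem. This is only a sketch under assumptions that are not in the notes above: both terms are penalized quadratically, the graph is treated as undirected, and a single trade-off parameter is used. The names `smooth_coverage` and `lam` are hypothetical. With `M` the row-normalized adjacency matrix (so `M @ x` is the mean of each node's neighbors), the perturbation is `x - M @ x`, and minimizing `||x - y||^2 + lam * ||x - M @ x||^2` has a closed-form solution.

```python
import numpy as np

def smooth_coverage(y_obs, adjacency, lam=1.0):
    """Hypothetical helper: least-squares latent coverage on a graph.

    Minimizes ||x - y_obs||^2 + lam * ||x - M @ x||^2, where M @ x is
    the mean latent coverage of each node's neighbors.
    """
    y_obs = np.asarray(y_obs, dtype=float)
    adjacency = np.asarray(adjacency, dtype=float)
    n = y_obs.shape[0]
    degree = adjacency.sum(axis=1, keepdims=True)
    # Row-normalize so that (M @ x)[i] is the mean of x over neighbors of i.
    M = np.divide(adjacency, degree,
                  out=np.zeros_like(adjacency), where=degree > 0)
    # D @ x is the local perturbation at each node.
    D = np.eye(n) - M
    # Assumption: isolated nodes have no perturbation penalty.
    D[degree.ravel() == 0] = 0.0
    # Setting the gradient to zero gives (I + lam * D^T D) x = y_obs.
    A = np.eye(n) + lam * D.T @ D
    return np.linalg.solve(A, y_obs)

# Undirected adjacency for a five-node graph in which node 2
# connects to nodes 0, 1, 3 and 4 (illustrative values only).
adjacency = np.array([
    [0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0],
    [1, 1, 0, 1, 1],
    [0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0],
])
y_obs = np.array([2.0, 0.5, 2.0, 1.0, 3.0])
x_hat = smooth_coverage(y_obs, adjacency, lam=1.0)
print(x_hat)
```

Note that this follows the "mean of all neighbors" wording literally, so it pulls adjacent nodes toward each other and does not encode the expectation that coverage sums across a branch.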

In [None]:
import numpy as np

## Example 1

Node labels

```
   \             /
 (0)\           /(3)
     \   (2)   /
      ---------
     /         \
 (1)/           \(4)
   /             \
```


Node "expected" coverages
```
   \             /
  x0\           /x3
     \    x2   /
      ---------
     /         \
  x1/           \x4
   /             \
```

Node coverage error (both true and observation error)
```
   \             /
  e0\           /e3
     \    e2   /
      ---------
     /         \
  e1/           \e4
   /             \
```

Node observed coverage
```
    \             /
x0+e0\           /x3+e3
      \  x2+e2  /
       ---------
      /         \
x1+e1/           \x4+e4
    /             \
```

Turns out this is a [Hidden Markov Random Field](https://en.wikipedia.org/wiki/Hidden_Markov_random_field).
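
The conservation expectation from the strategy notes can be checked directly on this example: if the true coverages were known, we would expect x0 + x1 = x2 = x3 + x4. The sketch below computes how far a set of coverages is from satisfying that at each junction. The helper name `junction_residuals`, the junction list, and the coverage values are all assumptions made for illustration.

```python
import numpy as np

# Illustrative coverages for nodes 0..4 (assumed values, not real data).
x = np.array([2.0, 0.5, 2.0, 1.0, 3.0])

# Each junction is (nodes on its left side, nodes on its right side).
junctions = [((0, 1), (2,)), ((2,), (3, 4))]

def junction_residuals(x, junctions):
    """Hypothetical helper: left-side sum minus right-side sum per junction."""
    return np.array([
        x[list(left)].sum() - x[list(right)].sum()
        for left, right in junctions
    ])

residuals = junction_residuals(x, junctions)
print(residuals)  # zero at a junction means coverage is conserved there
```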

In [None]:
# Two directed adjacency matrices

# (E0) One point estimate for the true coverage
# is the observed coverage.

# A point estimate for the coverage "from"
# a connected node on the left (equivalently right)
# is the coverage of that node minus the
# coverage that _it_ shares with other nodes to its right
# (equivalently left)

# The expected coverage for each node is
# the mean of the sum of the coverage "from"
# nodes that connect on the left and the
# sum "from" the right.

# (E1) Another point estimate for the true
# coverage is recursively defined as
# the mean of any other point estimates

# (E2) Another point estimate for the true
# coverage of each node is the weighted
# mean of point estimates,
# in particular point estimates obtained
# from E0 and E1 above,
# alternatively E0 and E2 (recursively)

In [None]:
def norm_col(x):
    return np.nan_to_num(x / x.sum(0, keepdims=True))

def norm_row(x):
    return np.nan_to_num(x / x.sum(1, keepdims=True))

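def contrib_from_side(x, a):
    # x: (1, n) row vector of current estimates.
    # a: (n, n) adjacency matrix for one side.
    # Split each neighbor's coverage among the nodes it connects to,
    # in proportion to their current estimates, then sum the shares
    # arriving at each node.  Nodes with no neighbors on this side
    # get NaN.
    return np.where(a.sum(1) != 0,
                    (norm_col(a * x.T) @ x.T).flatten(),
                    np.nan)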

def contrib_from_neighbors(x, left, right):
    from_left = contrib_from_side(x, left)
    from_right = contrib_from_side(x, right)
    
    # Fill NAs with the other side so that means
    # will only reflect the connected sides.
    from_left = np.where(~np.isnan(from_left),
                         from_left, from_right)
    from_right = np.where(~np.isnan(from_right),
                          from_right, from_left)
    
    return (from_left + from_right) / 2

In [None]:
# Nodes that connect on the left.
L = np.array([
#    0  1  2  3  4
    [0, 0, 0, 0, 0],  # 0
    [0, 0, 0, 0, 0],  # 1
    [1, 1, 0, 0, 0],  # 2
    [0, 0, 1, 0, 0],  # 3
    [0, 0, 1, 0, 0],  # 4
])

# Nodes that connect on the right
R = np.array([
#    0  1  2  3  4
    [0, 0, 1, 0, 0],  # 0
    [0, 0, 1, 0, 0],  # 1
    [0, 0, 0, 1, 1],  # 2
    [0, 0, 0, 0, 0],  # 3
    [0, 0, 0, 0, 0],  # 4
])


learning_rate = 0.5

# Initial coverage point estimates (e.g. from observed coverage)
y_obs = np.array([[2.0, 0.5, 2.0, 1.0, 3.0]])

# How heavily does each node's observed coverage affect its
# estimate versus the effect of its neighbors?
# Should be between 0 and 1
# (?) How to convert unitig length to a value
# between 0 and 1?  It's obvious that
# it should be saturating with length,
# but probably not to 1.0
# since there's always coverage noise even with really long
# contigs.
obs_weight = np.array([0.25, 0.25, 0.75, 0.25, 0.5])
y_iter = y_obs
for i in range(100):
    y_next = ((obs_weight * y_obs)
              + ((1 - obs_weight)
                 * (learning_rate * contrib_from_neighbors(y_iter, L, R)
                    + (1 - learning_rate) * y_iter)))
    delta = y_next - y_iter
    y_iter = y_next
print(y_iter, delta)
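
One candidate answer to the open question in the comments above (how to turn unitig length into a weight that saturates below 1.0) is a hyperbolic saturation curve. This is only a sketch: the function name `length_to_obs_weight`, the parameters `half_length` and `max_weight`, and their default values are hypothetical and have not been fit to any data.

```python
import numpy as np

def length_to_obs_weight(length, half_length=1000.0, max_weight=0.9):
    """Hypothetical mapping from unitig length to an observation weight.

    The weight rises with length, reaches max_weight / 2 at half_length,
    and approaches (but never reaches) max_weight for very long unitigs.
    """
    length = np.asarray(length, dtype=float)
    return max_weight * length / (length + half_length)

# Made-up unitig lengths, for illustration only.
unitig_lengths = np.array([100, 1000, 10000, 100000])
weights = length_to_obs_weight(unitig_lengths)
print(weights)
```

Capping at `max_weight` below 1.0 keeps some influence from neighbors even for very long unitigs, which matches the intuition that coverage noise never fully disappears.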

In [None]:
# Coverage point estimates.
Y_itr = np.array([1.5, 1.0, 2.0, 1.0, 3.0])


# Neighbor-based point estimates
def calculate_Y_nbr(x, L, R):
    return (((L @ x) + (R @ x)) / 2)


# Initial
print(Y_itr)

def step_Y_itr(x, L, R):
    # Average the current estimate with the neighbor-based estimate.
    x_nbr = calculate_Y_nbr(x, L, R)
    return 0.5 * x + 0.5 * x_nbr

# Iterate
Y_itr = step_Y_itr(Y_itr, L, R)
print(Y_itr)

# Iterate
Y_itr = step_Y_itr(Y_itr, L, R)
print(Y_itr)

# Iterate
Y_itr = step_Y_itr(Y_itr, L, R)
print(Y_itr)

In [None]:
# Observed coverages on each node.
#             0  1  2  3  4
X = np.array([1, 1, 2, 1, 1])