# filter matches

The next step is to filter matches based on two criteria:
- their energy: this measures how long and how diverged a match is. Long, weakly diverged matches (low energy) are given priority.
- their overlap: to be able to parallelize, we can only perform mergers whose alignments don't overlap.

The formula for the energy is:
$$
E = -L + \alpha C + \beta M
$$
where:
- $L$ is the alignment length. We can measure it as number of matches, to avoid counting indels.
- $C$ is the number of *cuts* we need to make if we want to merge the block on this alignment. This depends on whether the alignment extends to the beginning/end of the query/reference sequences: each end that does not reach the corresponding sequence end requires a cut. This number is between 0 and 4.
- $M$ is the number of mismatches in the alignment. It can be estimated from the divergence.
- $\alpha$ and $\beta$ are the `-a` and `-b` CLI parameters of pangraph. They act as weights that set the penalty given for each cut ($\alpha$) and each mismatch ($\beta$).

What we need to do is:
- for each alignment, evaluate the energy.
- sort alignments in increasing order of energy and only keep alignments with negative energy.
- progressively consider every alignment on this list, and accept alignments that do not overlap on the ref or query with previously accepted alignments.

In [1]:
from dataclasses import dataclass

# classes Hit and Alignment as in `alignment.rs`
@dataclass()
class Hit:
    """One side (query or reference) of a pairwise alignment.

    Mirrors the ``Hit`` type of ``alignment.rs``. ``start`` and ``stop``
    delimit the aligned region inside the block called ``name``, whose total
    length is ``length``.
    """

    name: str  # name of the block this side of the alignment lies on
    length: int  # total length of that block
    start: int  # start of the aligned region (0 means "from the block start")
    stop: int  # end of the aligned region (== length means "up to the block end")
    # A ``seq: str`` field is intentionally left out: it is probably never
    # used, and most of the time the sequence is not available anyway.

@dataclass()
class Alignment:
    """Pairwise alignment between a query block and a reference block.

    Mirrors the ``Alignment`` type of ``alignment.rs``. The filtering code in
    this notebook reads ``qry``, ``reff``, ``matches`` and ``divergence``;
    the examples here pass ``None`` for several of the remaining fields.
    """

    qry: Hit  # query side of the alignment
    reff: Hit  # reference side of the alignment
    matches: int  # number of matching positions; used as L in the energy
    length: int  # alignment length
    quality: int  # presumably the mapping quality -- confirm against alignment.rs
    orientation: str  # strand; "+" or "-" in the examples
    cigar: str  # CIGAR string, e.g. "10I40M10D"
    divergence: float  # used to estimate the number of mismatches
    align: float  # NOTE(review): meaning not shown in this notebook -- confirm

In [2]:
def energy(aln: Alignment, alpha: float, beta: float) -> float:
    """Return the merge energy E = -L + alpha * C + beta * M of an alignment.

    L is the number of matches, M = divergence * L is the estimated number of
    mismatches, and C (between 0 and 4) is the number of cuts: alignment ends
    that do not coincide with the start or end of their block. A lower energy
    means a longer, less diverged alignment that needs fewer cuts.
    """
    qry, ref = aln.qry, aln.reff
    # every alignment end that reaches the corresponding block end saves a cut
    reaches_block_end = (
        qry.start == 0,
        qry.stop == qry.length,
        ref.start == 0,
        ref.stop == ref.length,
    )
    n_cuts = 4 - sum(1 for flag in reaches_block_end if flag)

    n_matches = aln.matches
    n_mismatches = aln.divergence * n_matches
    return -n_matches + n_cuts * alpha + n_mismatches * beta


We can design a small test for the energy function:

In [3]:
import math

alpha = 10
beta = 10

div = 0.02

# alignment with two cuts: the query hit starts at the beginning of its block
# and the reference hit stops at the end of its block, so only the query end
# and the reference start need a cut
aln = Alignment(
    Hit("bl_1", 100, 0, 50),
    Hit("bl_2", 200, 150, 200),
    40,
    60,
    100,
    "+",
    "10I40M10D",
    div,
    0.1
)
e = energy(aln, alpha, beta)
print(f"energy = {e}")
# expected E = -L + alpha * C + beta * M with L = 40 matches, C = 2 cuts and
# M = div * L; the energy is a float, so compare with a tolerance rather than ==
assert math.isclose(e, 40 * (beta * div - 1) + alpha * 2)

energy = -12.0


Next we need to implement a function to decide, given the set of previously accepted alignments, whether a new alignment is compatible. This requires that the matching interval does not overlap the matching interval of any other accepted alignment, both on the reference and on the query.

I implemented this in a quick way by defining intervals and an operation to check that they do not overlap. But I think in Rust intervals might already be implemented in the nextalign module. You can implement the `is_compatible` function in any way you prefer; it just has to verify that the no-overlap condition holds with respect to every previously accepted alignment.

Here I do it by having a dictionary that links each block name with its list of accepted intervals. For each alignment I check that its matching intervals on the reference block and on the query block are compatible with the intervals already accepted on those blocks. If so, I add the alignment to the list of accepted alignments and record its two intervals in the dictionary.

In [4]:
from collections import defaultdict


@dataclass
class Interval:
    """Interval between the coordinates ``s`` and ``e``, endpoints included.

    NOTE(review): ``overlap`` treats intervals as closed, so two intervals that
    only share an endpoint count as overlapping. The ``Hit.stop`` values stored
    here appear to be exclusive (a hit whose ``stop`` equals the block length
    reaches the block end), so adjacent alignments end up rejected as
    conflicting -- confirm this conservatism is intended.
    """

    s: int  # start coordinate
    e: int  # end coordinate

    def overlap(self, i):
        """Return True if interval ``i`` shares at least one position with ``self``."""
        # i starts inside self:   self |-----|       or   |--------|
        #                         i       |-----|           |----|
        # i ends inside self:     self     |-----|
        #                         i     |-----|
        # i covers self's start:  self     |---|
        #                         i     |----------|
        return (
            self.s <= i.s <= self.e
            or self.s <= i.e <= self.e
            or i.s <= self.s <= i.e
        )


def no_overlap(int_list: list[Interval], j: Interval):
    """Return True when ``j`` overlaps none of the intervals in ``int_list``.

    An empty ``int_list`` is trivially compatible. Scanning stops at the first
    overlapping interval.
    """
    return not any(accepted.overlap(j) for accepted in int_list)


def is_compatible(aln: Alignment, accepted_intervals: dict[str, list[Interval]]) -> bool:
    """Return True if ``aln`` overlaps no previously accepted interval.

    The reference hit and the query hit are each checked against the
    intervals already accepted on their own block.

    Args:
        aln: candidate alignment.
        accepted_intervals: block name -> intervals accepted so far on that
            block. It is only read: lookups use ``.get`` so that checking a
            candidate does not insert empty lists into a ``defaultdict``, and
            a plain ``dict`` works as well.

    Returns:
        True if neither hit of ``aln`` overlaps an accepted interval.

    NOTE(review): when ``aln.qry.name == aln.reff.name`` the query and
    reference intervals of ``aln`` itself are not checked against each other.
    """
    ref_interval = Interval(aln.reff.start, aln.reff.stop)
    qry_interval = Interval(aln.qry.start, aln.qry.stop)
    return no_overlap(accepted_intervals.get(aln.reff.name, []), ref_interval) and no_overlap(
        accepted_intervals.get(aln.qry.name, []), qry_interval
    )

Here are some tests for these functions, but they might be superfluous if we implement them in other ways.

In [5]:
# test overlap
ov = Interval(100, 200).overlap(Interval(210, 390))
assert not ov

ov = Interval(100, 220).overlap(Interval(210, 390))
assert ov

no_ov = no_overlap([Interval(100, 200), Interval(300, 400)], Interval(210, 290))
assert no_ov

no_ov = no_overlap([Interval(100, 200), Interval(300, 400)], Interval(210, 390))
assert not no_ov

# test compatibility function

# both hits fall in the gaps between the accepted intervals of their blocks
accepted_intervals = defaultdict(list)
accepted_intervals["block_0"] = [Interval(100, 200), Interval(300, 400)]
accepted_intervals["block_1"] = [Interval(200, 300), Interval(400, 500)]

aln = Alignment(
    Hit("block_0", 1000, 210, 290),
    Hit("block_1", 1000, 310, 390),
    80,
    80,
    10,
    "-",
    "80M",  # 80 aligned positions on each side, consistent with matches/length
    0.05,
    None,
)

assert is_compatible(aln, accepted_intervals)


# the query hit (block_0, 310-390) falls inside the accepted interval 300-400,
# even though the reference hit would be compatible
accepted_intervals = defaultdict(list)
accepted_intervals["block_0"] = [Interval(100, 200), Interval(300, 400)]
accepted_intervals["block_1"] = [Interval(200, 300), Interval(400, 500)]

aln = Alignment(
    Hit("block_0", 1000, 310, 390),
    Hit("block_1", 1000, 310, 390),
    80,
    80,
    10,
    "-",
    "80M",  # 80 aligned positions on each side, consistent with matches/length
    0.05,
    None,
)

assert not is_compatible(aln, accepted_intervals)


Finally I combine these to get the final `filter_matches` function. It takes as input the list of alignments and the two command line parameters `-a` and `-b`, and returns the list of filtered alignments, ordered by energy.

In [20]:

def update_intervals(aln: Alignment, accepted_intervals: dict[str, list[Interval]]):
    """Record the reference and query intervals of an accepted alignment.

    Args:
        aln: alignment that has just been accepted.
        accepted_intervals: block name -> intervals accepted on that block;
            updated in place. ``setdefault`` creates the list for a block that
            has not been seen yet, so a plain ``dict`` works as annotated, not
            only a ``defaultdict``.
    """
    ref_interval = Interval(aln.reff.start, aln.reff.stop)
    qry_interval = Interval(aln.qry.start, aln.qry.stop)
    accepted_intervals.setdefault(aln.reff.name, []).append(ref_interval)
    accepted_intervals.setdefault(aln.qry.name, []).append(qry_interval)



def filter_matches(alignments: list[Alignment], alpha: float, beta: float) -> list[Alignment]:
    """Select a set of mutually non-overlapping, low-energy alignments.

    Alignments whose energy is not negative are discarded. The remaining ones
    are scanned in increasing order of energy, and each is accepted if it does
    not overlap (on its reference or query block) any alignment accepted before.

    Args:
        alignments: candidate alignments.
        alpha: weight of each cut in the energy (``-a`` CLI parameter).
        beta: weight of each mismatch in the energy (``-b`` CLI parameter).

    Returns:
        The accepted alignments, ordered by increasing energy. Alignments with
        equal energy keep their input order, because ``sorted`` is stable.
    """
    # evaluate each energy only once and keep it next to its alignment
    scored = [(energy(aln, alpha, beta), aln) for aln in alignments]
    # keep negative energies only; sort on the energy alone, which also avoids
    # comparing (unorderable) Alignment objects when two energies are equal
    candidates = sorted((pair for pair in scored if pair[0] < 0), key=lambda pair: pair[0])

    # iteratively accept alignments if they do not overlap with previously accepted ones
    accepted_aln = []
    accepted_intervals = defaultdict(list)
    for _, aln in candidates:
        if is_compatible(aln, accepted_intervals):
            accepted_aln.append(aln)
            update_intervals(aln, accepted_intervals)

    return accepted_aln


In [21]:
alpha = 10
beta = 10

# Fields that filter_matches never reads are left unset (None) for brevity.
kwargs = {
    "length": None,
    "quality": None,
    "orientation": None,
    "align": None,
}

# Every hit below covers 100 positions strictly inside a 500-long block, so
# each alignment needs C = 4 cuts and has 100 matches:
#     E = -100 + 4 * alpha + beta * (divergence * 100)

# E = -100 + 40 + 50 = -10
aln_0 = Alignment(
    qry=Hit("bl0", 500, 100, 200),
    reff=Hit("bl1", 500, 200, 300),
    matches=100,
    cigar="100M",
    divergence=0.05,
    **kwargs,
)

# E = -100 + 40 + 20 = -40
aln_1 = Alignment(
    qry=Hit("bl2", 500, 100, 200),
    reff=Hit("bl3", 500, 200, 300),
    matches=100,
    cigar="100M",
    divergence=0.02,
    **kwargs,
)

# E = -10, but its query hit on bl2 (150-250) overlaps aln_1 (100-200)
aln_2 = Alignment(
    qry=Hit("bl2", 500, 150, 250),
    reff=Hit("bl4", 500, 200, 300),
    matches=100,
    cigar="100M",
    divergence=0.05,
    **kwargs,
)

# E = -100 + 40 + 100 = 40 > 0, so it is discarded
aln_3 = Alignment(
    qry=Hit("bl5", 500, 100, 200),
    reff=Hit("bl6", 500, 200, 300),
    matches=100,
    cigar="100M",
    divergence=0.1,
    **kwargs,
)

alignments = [aln_0, aln_1, aln_2, aln_3]

filt_aln = filter_matches(alignments, alpha, beta)
# lowest energy first; aln_2 loses to aln_1 on bl2, aln_3 has positive energy
expected = [aln_1, aln_0]
assert filt_aln == expected
