# splitting matches

I want to implement the `split_matches` function. This function takes a match and splits it into smaller alignments if it contains large indels.

There is only **one parameter** to this function, which is the threshold block length of pangraph. This is usually set to 100bp, but it can be controlled by the user ([parameter](https://neherlab.github.io/pangraph/cli/build/#Options)).

Let's consider the following cigar:
```
324M 3I 54M 3D 13M 300D 200M 25I 4M 100I 30M 200D 150M
```
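As a reminder of how such a string is read, here is a minimal sketch that tokenizes the cigar above into (length, operation) pairs and counts how many positions it spans on each sequence. The helper name `cigar_ops` is an assumption made for this illustration only. Matches (`M`, `=`, `X`) advance both sequences, insertions (`I`) only the query, and deletions (`D`) only the reference.

```python
import re

def cigar_ops(cigar: str) -> list[tuple[int, str]]:
    """Split a cigar string into (length, operation) pairs (hypothetical helper)."""
    return [(int(n), op) for n, op in re.findall(r"(\d+)([MID=X])", cigar)]

cigar = "324M 3I 54M 3D 13M 300D 200M 25I 4M 100I 30M 200D 150M".replace(" ", "")
ops = cigar_ops(cigar)

# M/=/X consume both sequences, I only the query, D only the reference
qry_len = sum(n for n, op in ops if op in "MI=X")
ref_len = sum(n for n, op in ops if op in "MD=X")
print(len(ops), qry_len, ref_len)
```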

We want to start from any region with matches and extend it left and right, until we hit an indel of at least 100bp (or whatever the threshold length is). At this point we stop the match and look for the next match region, from which a new match starts.

In the previous example this gives:
```
| keep               | no   | keep        | no   | too short | no   | keep |
| 324M 3I 54M 3D 13M | 300D | 200M 25I 4M | 100I | 30M       | 200D | 150M |
```

Basically we do the following:
- parse the cigar string starting from the first match.
- continue until we hit an indel of at least 100bp (the threshold length).
- at this point, check whether the region contains at least 100bp of matches in total.
  - If so, keep this region and output the corresponding sub-alignment.
  - If not, discard it.
- move to the next match and repeat until the end of the cigar string.
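The steps above can be sketched as a single labelling pass over the example cigar. This is only an illustration under simplifying assumptions: `label_segments` is a hypothetical helper, it breaks on any single indel whose length is at least the threshold, and it does not trim leading or trailing short indels from the kept stretches.

```python
import re

def label_segments(cigar: str, thr_len: int = 100) -> list[tuple[str, str]]:
    """Label stretches of a cigar as 'keep', 'too short' or 'no' (hypothetical helper)."""
    ops = [(int(n), op) for n, op in re.findall(r"(\d+)([MID=X])", cigar)]
    labelled, current, n_matches = [], [], 0

    def flush():
        # close the current stretch and label it by its total number of matches
        nonlocal current, n_matches
        if current:
            label = "keep" if n_matches >= thr_len else "too short"
            labelled.append((" ".join(current), label))
        current, n_matches = [], 0

    for n, op in ops:
        if op in "ID" and n >= thr_len:
            # long indel: it closes the current stretch and is itself discarded
            flush()
            labelled.append((f"{n}{op}", "no"))
        else:
            current.append(f"{n}{op}")
            if op in "M=X":
                n_matches += n
    flush()
    return labelled

for segment, label in label_segments("324M3I54M3D13M300D200M25I4M100I30M200D150M"):
    print(f"{label:10s} {segment}")
```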

Here is an example implementation in Python.

In [2]:
from dataclasses import dataclass

# classes Hit and Alignment as in `alignment.rs`
@dataclass
class Hit:
    name: str
    length: int
    start: int
    stop: int
    # seq: str -> we can probably remove this field. I think we never use it and most of the time we don't have it.

@dataclass
class Alignment:
    qry: Hit
    reff: Hit
    matches: int
    length: int
    quality: int
    orientation: str
    cigar: str
    divergence: float
    align: float

In [3]:
from itertools import groupby
from typing import Generator, Tuple

def parse_cigar(cigar: str) -> Generator[Tuple[int, str], None, None]:
    """utility function to parse cigar strings"""
    cig_iter = groupby(cigar, lambda c: c.isdigit())
    for _, n in cig_iter:
        yield int("".join(n)), "".join(next(cig_iter)[1])

def keep_groups(cigar: str, thr_len: int) -> list[Tuple[int, int]]:
    """
    Given a cigar string, returns a list of tuples (start_index, end_index)
    of cigar elements to keep (both indices included). A kept group must:
    - always start and end with a match.
    - have at least `thr_len` matches.
    - contain no run of insertions (or of deletions) between two matches
      whose total length is `thr_len` or more.
    """

    cig = list(parse_cigar(cigar))
    groups = []
    g_start, last_match = None, None
    M_sum, I_sum, D_sum = 0, 0, 0
    for i, (n, op) in enumerate(cig):

        #  discard leading indels
        if (g_start is None):
            if op in "ID":
                continue
            else:
                g_start = i
                
        # add length of matches and indels
        if op in "M=X":
            M_sum += n
            # reset n. indels from last match
            I_sum = 0
            D_sum = 0
            last_match = i
        elif op == "I":
            I_sum += n
        elif op == "D":
            D_sum += n


        # if the indel run is too long, split groups
        if max(I_sum, D_sum) >= thr_len:
            if (last_match is not None) and (M_sum >= thr_len):
                # add only if at least thr_len matches, otherwise discard.
                groups.append((g_start, last_match))
            g_start, last_match = None, None
            M_sum, I_sum, D_sum = 0, 0, 0

    # add last one
    if (last_match is not None) and (M_sum >= thr_len):
        groups.append((g_start, last_match))

    return groups

    

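We test `keep_groups` independently on the following cigar, with threshold length = 100bp. The bottom row gives the index of each cigar element: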

```
| no                   | keep                 | no      | keep              | no   | keep
| 10I 20D 10M 20I 190D | 40M 1D 1I 40M 1I 40M | 1D 100I | 200M 60I 60D 140M | 200D | 40M 2I 70M
  0   1   2   3   4      5   6  7  8   9  10    11 12     13   14  15  16     17     18  19 20
```

In [4]:

test_cigar = "10I 20D 10M 20I 190D 40M 1D 1I 40M 1I 40M 1D 100I 200M 60I 60D 140M 200D 40M 2I 70M"
test_cigar = test_cigar.replace(" ", "")
expected = [(5,10), (13, 16), (18, 20)]
result = keep_groups(test_cigar, 100)
assert result == expected, f"Expected {expected}, got {result}"

In [5]:
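def cigar_position_start(cigar: str, cigar_idx: int, accepted_op: str) -> int:
    """Number of positions consumed by `accepted_op` operations before element `cigar_idx`."""
    pos = 0
    for i, (n, op) in enumerate(parse_cigar(cigar)):
        if i == cigar_idx:
            return pos
        if op in accepted_op:
            pos += n
    raise IndexError(f"cigar index {cigar_idx} out of range")

def cigar_position_end(cigar: str, cigar_idx: int, accepted_op: str) -> int:
    """Number of positions consumed by `accepted_op` operations up to and including element `cigar_idx`."""
    pos = 0
    for i, (n, op) in enumerate(parse_cigar(cigar)):
        if op in accepted_op:
            pos += n
        if i == cigar_idx:
            return pos
    raise IndexError(f"cigar index {cigar_idx} out of range")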


def group_positions(cigar: str, start_idx: int, end_idx: int) -> Tuple[int, int, int, int]:
    """Given the indices of the first and last cigar elements of a group (both included),
    returns the start and end position of the interval in the query and reference sequences.
    Positions are relative to the start of the alignment and use Python 0-based indexing
    (right end excluded).
    """
    
    qry_beg = cigar_position_start(cigar, start_idx, "MI=X")
    qry_end = cigar_position_end(cigar, end_idx, "MI=X")
    ref_beg = cigar_position_start(cigar, start_idx, "MD=X")
    ref_end = cigar_position_end(cigar, end_idx, "MD=X")

    return qry_beg, qry_end, ref_beg, ref_end



We test it in the following example:

```
                            |-------------------------------|
idx 0   1      2   3    4    5     6  7       8   9    10    11    12  13  14
CG  3I  6M     3I  4M   4D   5M    2I 7M      3D  4I   5M    5D    3M  3I  2M

        0    5     6            18    19       28      29           41
ref --- MMMMMM --- MMMM DDDD MMMMM -- MMMMMMM DDD ---- MMMMM DDDDD MMM --- MM
qry III MMMMMM III MMMM ---- MMMMM II MMMMMMM --- IIII MMMMM ----- MMM III MM
    0                15      16            29     30      38       39      46
```
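The coordinates annotated in the diagram can be reproduced by walking the cigar once and accumulating the positions consumed on each sequence. The sketch below is a standalone illustration (it re-parses the cigar with a regular expression instead of using `parse_cigar`). Note that the diagram labels the first and last position of each block, while the code stores half-open intervals, so an end label `38` in the diagram corresponds to an end coordinate of `39` here.

```python
import re

cigar = "3I 6M 3I 4M 4D 5M 2I 7M 3D 4I 5M 5D 3M 3I 2M".replace(" ", "")
ops = [(int(n), op) for n, op in re.findall(r"(\d+)([MID=X])", cigar)]

# Walk the cigar once, recording the half-open interval that each element
# covers on the query and on the reference.
qry_pos, ref_pos = 0, 0
spans = []  # one (qry_start, qry_end, ref_start, ref_end) tuple per cigar element
for n, op in ops:
    q_step = n if op in "MI=X" else 0  # deletions do not advance the query
    r_step = n if op in "MD=X" else 0  # insertions do not advance the reference
    spans.append((qry_pos, qry_pos + q_step, ref_pos, ref_pos + r_step))
    qry_pos += q_step
    ref_pos += r_step

# Interval spanned by cigar elements 5..10 (both included)
start_idx, end_idx = 5, 10
interval = (
    spans[start_idx][0], spans[end_idx][1],
    spans[start_idx][2], spans[end_idx][3],
)
print(interval)  # (16, 39, 14, 34)
```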

In [6]:
cigar = "3I 6M 3I 4M 4D 5M 2I 7M 3D 4I 5M 5D 3M 3I 2M"
cigar = cigar.replace(" ", "")
start_idx, end_idx = 5, 10
expected = (16, 39, 14, 34)
result = group_positions(cigar, start_idx, end_idx)
assert result == expected, f"Expected {expected}, got {result}"

In [16]:

def generate_subalignment(aln: Alignment, start_cigar_idx: int, end_cigar_idx: int) -> Alignment:
    """Given an alignment and a pair of cigar indices, returns a new alignment that spans only
    the cigar elements between the two indices (both included). Sequence positions use
    Python 0-based indexing, right end excluded."""

    qs, qe, rs, re = group_positions(aln.cigar, start_cigar_idx, end_cigar_idx)

    rs, re = aln.reff.start + rs, aln.reff.start + re
    if aln.orientation == "+":
        qs, qe = aln.qry.start + qs, aln.qry.start + qe
    elif aln.orientation == "-":
        qs, qe = aln.qry.stop - qe, aln.qry.stop - qs
    else:
        raise ValueError(f"Invalid orientation {aln.orientation}")
    qry = Hit(aln.qry.name, aln.qry.length, qs, qe)
    reff = Hit(aln.reff.name, aln.reff.length, rs, re)

    sub_cigar = list(parse_cigar(aln.cigar))[start_cigar_idx:end_cigar_idx+1]
    matches = sum(n for n, op in sub_cigar if (op in "M=X"))
    length = sum(n for n, op in sub_cigar)
    quality = aln.quality
    orientation = aln.orientation
    cigar = "".join(f"{n}{op}" for n, op in sub_cigar)
    # TODO: possibly recalculate divergence from the sequence + cigar string?
    divergence = aln.divergence
    align = aln.align

    return Alignment(qry, reff, matches, length, quality, orientation, cigar, divergence, align)


def side_patches(aln: Alignment, thr_len: int) -> Alignment:
    """If lateral overhangs are shorter than thr_len, add them to the alignment
    to avoid excessive fragmentation. The alignment is modified in place and returned."""

    # check ref
    rs, re, rL = aln.reff.start, aln.reff.stop, aln.reff.length
    
    if (rs > 0) and (rs < thr_len):
        # append left side reference patch
        delta_L = rs
        aln.reff.start = 0
        aln.length += delta_L
        aln.cigar = f"{delta_L}D" + aln.cigar
        
    if (re < rL) and (rL - re < thr_len):
        # append right reference patch
        delta_L = rL - re
        aln.reff.stop = rL
        aln.length += delta_L
        aln.cigar = aln.cigar + f"{delta_L}D"

    # check query
    qs, qe, qL = aln.qry.start, aln.qry.stop, aln.qry.length
    
    if (qs > 0) and (qs < thr_len):
        # append query start
        delta_L = qs
        aln.qry.start = 0
        aln.length += delta_L
        extra_ins = f"{delta_L}I"
        if aln.orientation == "+":
            aln.cigar = extra_ins + aln.cigar
        else:
            aln.cigar = aln.cigar + extra_ins

    if (qe < qL) and (qL - qe < thr_len):
        # append query end
        delta_L = qL - qe
        aln.qry.stop = qL
        aln.length += delta_L
        extra_ins = f"{delta_L}I"
        if aln.orientation == "+":
            aln.cigar = aln.cigar + extra_ins
        else:
            aln.cigar = extra_ins + aln.cigar


    return aln



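def split_matches(aln: Alignment, thr_len: int) -> list[Alignment]:
    """Splits an alignment at every indel run of length >= thr_len and returns the
    sub-alignments with at least thr_len matches, after patching short side overhangs."""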
    kept_groups = keep_groups(aln.cigar, thr_len)
    sub_aln = [generate_subalignment(aln, start, end) for start, end in kept_groups]
    sub_aln = [side_patches(aln, thr_len) for aln in sub_aln] 
    return sub_aln


We test it on this example, with threshold length = 10bp:
```
CG      3I  6M     3I  3M  4D   5M    14I            7M      3D  4I   5M    5D    3M  3I
    
100 +       0                      17                18                            40
ref     --- MMMMMM --- MMM DDDD MMMMM -------------- MMMMMMM DDD ---- MMMMM DDDDD MMM ---
qry     III MMMMMM III MMM ---- MMMMM IIIIIIIIIIIIII MMMMMMM --- IIII MMMMM ----- MMM III
200 +       3                      19                34                            52
groups      |-----------------------|                |------------------------------|
```

In [17]:
cg = "3I 6M 3I 3M 4D 5M 14I 7M 3D 4I 5M 5D 3M 3I"
cg = cg.replace(" ", "")

kwargs = {
    "quality": 10,
    "orientation": "+",
    "divergence": 0.1,
    "align": None,
}

aln = Alignment(
    Hit("qry", 500, 200, 256),
    Hit("ref", 500, 100, 141),
    matches=sum(n for n, op in parse_cigar(cg) if op in "M=X"),
    length=sum(n for n, op in parse_cigar(cg)),
    cigar=cg,
    **kwargs
)

expected = [
    Alignment(
        Hit("qry", 500, 203, 220),
        Hit("ref", 500, 100, 118),
        matches=14,
        length=21,
        cigar="6M3I3M4D5M",
        **kwargs
    ),
    Alignment(
        Hit("qry", 500, 234, 253),
        Hit("ref", 500, 118, 141),
        matches=15,
        length=27,
        cigar="7M3D4I5M5D3M",
        **kwargs
    ),
]
results = split_matches(aln, thr_len=10)

assert len(results) == len(expected), f"Expected {len(expected)} alignments, got {len(results)}"
for i, (res, exp) in enumerate(zip(results, expected)):
    for field in ["qry", "reff", "matches", "length", "cigar", "quality", "orientation", "divergence", "align"]:
        e, r = getattr(exp, field), getattr(res, field)
        assert e == r, f"{i} - Expected {field} to be {e}, got {r}"

We also perform the same test when the mapping is on the reverse strand:
```
CG      3I  6M     3I  3M  4D   5M    14I            7M      3D  4I   5M    5D    3M  3I
    
100 +       0                      17                18                            40
ref     --- MMMMMM --- MMM DDDD MMMMM -------------- MMMMMMM DDD ---- MMMMM DDDDD MMM ---
qry     III MMMMMM III MMM ---- MMMMM IIIIIIIIIIIIII MMMMMMM --- IIII MMMMM ----- MMM III
200 +       52                     36                21                             3
groups      |-----------------------|                |------------------------------|
```
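On the reverse strand the cigar still runs left to right along the reference, but it runs backwards along the query, so offsets measured along the cigar count back from the query stop. The sketch below isolates that conversion for the two groups of the diagram. `to_query_coords` is a hypothetical helper written for this illustration; it mirrors the orientation branch of `generate_subalignment`.

```python
def to_query_coords(qry_start: int, qry_stop: int, beg: int, end: int, orientation: str) -> tuple[int, int]:
    """Convert cigar-relative query offsets [beg, end) into query coordinates
    (hypothetical helper, for illustration only)."""
    if orientation == "+":
        return qry_start + beg, qry_start + end
    if orientation == "-":
        # offsets count back from the stop, and start/end swap roles
        return qry_stop - end, qry_stop - beg
    raise ValueError(f"Invalid orientation {orientation}")

# cigar-relative query offsets of the two kept groups; query hit at [200, 256)
groups = [(3, 20), (34, 53)]
fwd = [to_query_coords(200, 256, b, e, "+") for b, e in groups]
rev = [to_query_coords(200, 256, b, e, "-") for b, e in groups]
print(fwd)  # [(203, 220), (234, 253)]
print(rev)  # [(236, 253), (203, 222)]
```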

In [18]:
cg = "3I 6M 3I 3M 4D 5M 14I 7M 3D 4I 5M 5D 3M 3I"
cg = cg.replace(" ", "")

kwargs = {
    "quality": 10,
    "orientation": "-",
    "divergence": 0.1,
    "align": None,
}

aln = Alignment(
    Hit("qry", 500, 200, 256),
    Hit("ref", 500, 100, 141),
    matches=sum(n for n, op in parse_cigar(cg) if op in "M=X"),
    length=sum(n for n, op in parse_cigar(cg)),
    cigar=cg,
    **kwargs
)

expected = [
    Alignment(
        Hit("qry", 500, 236, 253),
        Hit("ref", 500, 100, 118),
        matches=14,
        length=21,
        cigar="6M3I3M4D5M",
        **kwargs
    ),
    Alignment(
        Hit("qry", 500, 203, 222),
        Hit("ref", 500, 118, 141),
        matches=15,
        length=27,
        cigar="7M3D4I5M5D3M",
        **kwargs
    ),
]
results = split_matches(aln, thr_len=10)

assert len(results) == len(expected), f"Expected {len(expected)} alignments, got {len(results)}"
for i, (res, exp) in enumerate(zip(results, expected)):
    for field in ["qry", "reff", "matches", "length", "cigar", "quality", "orientation", "divergence", "align"]:
        e, r = getattr(exp, field), getattr(res, field)
        assert e == r, f"{i} - Expected {field} to be {e}, got {r}"

## test side patches


We also test a case where the side patch function should act, with threshold length = 10bp:
```
CG      3I  3D  6M     3I  3M  4D   5M    14I            7M      3D  4I   5M    5D    3M  4I   12D
    
            0   3                      20                21                            43                55
ref     --- DDD MMMMMM --- MMM DDDD MMMMM -------------- MMMMMMM DDD ---- MMMMM DDDDD MMM ---- DDDDDDDDDDDD
qry     III --- MMMMMM III MMM ---- MMMMM IIIIIIIIIIIIII MMMMMMM --- IIII MMMMM ----- MMM IIII ------------
200 +           3                      19                34                            52   56
groups          |-----------------------|                |------------------------------|
side patch  |---------------------------|                |-----------------------------------|
```
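The rule that decides whether a flank is absorbed can be stated on its own: an overhang is added to the alignment only if it is non-empty and strictly shorter than the threshold. The sketch below applies that rule to the intervals of the diagram. `absorbed_overhangs` is a hypothetical helper for this illustration and is not part of the implementation above.

```python
def absorbed_overhangs(start: int, stop: int, length: int, thr_len: int) -> tuple[int, int]:
    """Return the (left, right) overhang lengths that would be absorbed for the
    interval [start, stop) on a sequence of the given length (hypothetical helper)."""
    left = start if 0 < start < thr_len else 0
    right = length - stop if 0 < length - stop < thr_len else 0
    return left, right

# first group on the reference: [3, 21) out of 56 -> the 3bp on the left are absorbed
ref_first = absorbed_overhangs(3, 21, 56, 10)
# second group on the reference: [21, 44) out of 56 -> the 12bp on the right are too long
ref_second = absorbed_overhangs(21, 44, 56, 10)
# second group on the query: [234, 253) out of 257 -> the 4bp on the right are absorbed
qry_second = absorbed_overhangs(234, 253, 257, 10)
print(ref_first, ref_second, qry_second)  # (3, 0) (0, 0) (0, 4)
```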

In [22]:
cg = "3I 3D 6M 3I 3M 4D 5M 14I 7M 3D 4I 5M 5D 3M 4I 12D"
cg = cg.replace(" ", "")

kwargs = {
    "quality": 10,
    "orientation": "+",
    "divergence": 0.1,
    "align": None,
}

aln = Alignment(
    Hit("qry", 257, 200, 257),
    Hit("ref", 56, 0, 56),
    matches=sum(n for n, op in parse_cigar(cg) if op in "M=X"),
    length=sum(n for n, op in parse_cigar(cg)),
    cigar=cg,
    **kwargs
)

expected = [
    Alignment(
        Hit("qry", 257, 203, 220),
        Hit("ref", 56, 0, 21),
        matches=14,
        length=24,
        cigar="3D6M3I3M4D5M",
        **kwargs
    ),
    Alignment(
        Hit("qry", 257, 234, 257),
        Hit("ref", 56, 21, 44),
        matches=15,
        length=31,
        cigar="7M3D4I5M5D3M4I",
        **kwargs
    ),
]
results = split_matches(aln, thr_len=10)

assert len(results) == len(expected), f"Expected {len(expected)} alignments, got {len(results)}"
for i, (res, exp) in enumerate(zip(results, expected)):
    for field in ["qry", "reff", "matches", "length", "cigar", "quality", "orientation", "divergence", "align"]:
        e, r = getattr(exp, field), getattr(res, field)
        assert e == r, f"{i} - Expected {field} to be {e}, got {r}"

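We also perform a similar test when the mapping is on the reverse strand. Here the final deletion is only 5bp long, so it does not split the alignment and is instead absorbed as a right-side patch of the reference: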
```
CG          3I  3D  6M     3I  3M  4D   5M    14I            7M      3D  4I   5M    5D    3M  4I   5D
        
                0   3                      20                21                            43         48
ref         --- DDD MMMMMM --- MMM DDDD MMMMM -------------- MMMMMMM DDD ---- MMMMM DDDDD MMM ---- DDDDD
qry         III --- MMMMMM III MMM ---- MMMMM IIIIIIIIIIIIII MMMMMMM --- IIII MMMMM ----- MMM IIII -----
200 +       56      53                     37                22                             4    0
groups              |-----------------------|                |------------------------------|
side patch  |-------------------------------|                |-------------------------------      -----|
```
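For query patches, the side of the cigar that receives the extra insertion depends on the orientation: on the reverse strand the query end is aligned to the beginning of the cigar. The sketch below isolates this rule. `patch_query_cigar` is a hypothetical helper for this illustration; it only edits the cigar string and leaves coordinates and lengths aside.

```python
def patch_query_cigar(cigar: str, start_patch: int, end_patch: int, orientation: str) -> str:
    """Add insertions for absorbed query overhangs (hypothetical helper).

    start_patch / end_patch are the overhang lengths at the query start / end,
    or 0 if the overhang is not absorbed."""
    if orientation == "+":
        before, after = start_patch, end_patch
    else:
        # reverse strand: the query end sits at the beginning of the cigar
        before, after = end_patch, start_patch
    if before:
        cigar = f"{before}I" + cigar
    if after:
        cigar = cigar + f"{after}I"
    return cigar

# forward strand, 4bp overhang at the query end: insertion appended
fwd = patch_query_cigar("7M3D4I5M5D3M", 0, 4, "+")
# reverse strand, 3bp overhang at the query end: insertion prepended
rev = patch_query_cigar("3D6M3I3M4D5M", 0, 3, "-")
print(fwd, rev)  # 7M3D4I5M5D3M4I 3I3D6M3I3M4D5M
```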

In [27]:
cg = "3I 3D 6M 3I 3M 4D 5M 14I 7M 3D 4I 5M 5D 3M 4I 5D"
cg = cg.replace(" ", "")

kwargs = {
    "quality": 10,
    "orientation": "-",
    "divergence": 0.1,
    "align": None,
}

aln = Alignment(
    Hit("qry", 257, 200, 257),
    Hit("ref", 49, 0, 49),
    matches=sum(n for n, op in parse_cigar(cg) if op in "M=X"),
    length=sum(n for n, op in parse_cigar(cg)),
    cigar=cg,
    **kwargs
)

expected = [
    Alignment(
        Hit("qry", 257, 237, 257),
        Hit("ref", 49, 0, 21),
        matches=14,
        length=27,
        cigar="3I3D6M3I3M4D5M",
        **kwargs
    ),
    Alignment(
        Hit("qry", 257, 204, 223),
        Hit("ref", 49, 21, 49),
        matches=15,
        length=32,
        cigar="7M3D4I5M5D3M5D",
        **kwargs
    ),
]
results = split_matches(aln, thr_len=10)

assert len(results) == len(expected), f"Expected {len(expected)} alignments, got {len(results)}"
for i, (res, exp) in enumerate(zip(results, expected)):
    for field in ["qry", "reff", "matches", "length", "cigar", "quality", "orientation", "divergence", "align"]:
        e, r = getattr(exp, field), getattr(res, field)
        assert e == r, f"{i} - Expected {field} to be {e}, got {r}"