# Splitting matches

I want to implement the `split_matches` function. This function takes a match and splits it into smaller alignments if it contains large indels.

There is only **one parameter** to this function, which is the threshold block length of pangraph. Usually this is set to 100bp, but can be controlled by the user ([parameter](https://neherlab.github.io/pangraph/cli/build/#Options)).

Let's consider the following cigar:
```
324M 3I 54M 3D 13M 300D 200M 25I 4M 100I 30M 200D 150M
```

We want to start from any region with matches and extend it left and right until we hit an indel of 100bp or more (that is, at least the threshold length). At this point we stop the match and look for the next match region from which to start a new one.

In the previous example this gives:
```
| keep               | no   | keep        | no   | too short | no   | keep |
| 324M 3I 54M 3D 13M | 300D | 200M 25I 4M | 100I | 30M       | 200D | 150M |
```
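
To make the cut points concrete, here is a minimal sketch that tokenizes the example cigar and lists the single indels at or above the threshold. The helper `tokenize` and the variable names are hypothetical, not part of pangraph, and the sketch only looks at individual cigar elements; it does not sum adjacent indels.

```python
import re

def tokenize(cigar: str) -> list[tuple[int, str]]:
    """Hypothetical helper: split a cigar string into (length, op) pairs."""
    return [(int(n), op) for n, op in re.findall(r"(\d+)([MID=X])", cigar)]

cigar = "324M3I54M3D13M300D200M25I4M100I30M200D150M"
thr_len = 100  # assumed threshold block length, in bp

# Indels whose length reaches the threshold are where the alignment is cut.
breakpoints = [
    (i, n, op)
    for i, (n, op) in enumerate(tokenize(cigar))
    if op in ("I", "D") and n >= thr_len
]
print(breakpoints)  # [(5, 300, 'D'), (9, 100, 'I'), (11, 200, 'D')]
```

The three reported elements are exactly the `no` columns that separate the kept regions in the table above.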

Basically we do the following:
- parse the cigar string starting from the first match.
- we continue through until we hit an indel of at least 100bp (the threshold).
- at this point we check whether we have at least 100bp of matches in total.
  - If so, we keep this region and output the corresponding sub-alignment.
  - If not, we discard it.
- we then move to the next match and repeat until the end of the cigar string.

Here is an example implementation in Python.

In [2]:
from dataclasses import dataclass

# classes Hit and Alignment as in `alignment.rs`
@dataclass
class Hit:
    name: str
    length: int
    start: int
    stop: int
    # seq: str -> we probably do not need this at this stage?

@dataclass
class Alignment:
    qry: Hit
    reff: Hit
    matches: int
    length: int
    quality: int
    orientation: str
    cigar: str
    divergence: float
    align: float

In [40]:
from itertools import groupby
from typing import Generator, Tuple

def parse_cigar(cigar: str) -> Generator[Tuple[int, str], None, None]:
    """utility function to parse cigar strings"""
    cig_iter = groupby(cigar, lambda c: c.isdigit())
    for _, n in cig_iter:
        yield int("".join(n)), "".join(next(cig_iter)[1])

def keep_groups(cigar: str, thr_len: int) -> list[tuple[int, int]]:
    """
    Given a cigar string, returns a list of tuples (start_index, end_index)
    of cigar elements to keep. A kept group must:
    - always start and end with a match.
    - have at least `thr_len` matches.
    - contain no run of insertions or of deletions (each summed since the
      last match) that reaches `thr_len`.
    """

    cig = list(parse_cigar(cigar))
    groups = []
    g_start, last_match = None, None
    M_sum, I_sum, D_sum = 0, 0, 0
    for i, (n, op) in enumerate(cig):

        #  discard leading indels
        if (g_start is None):
            if op in "ID":
                continue
            else:
                g_start = i
                
        # add length of matches and indels
        if op in "M=X":
            M_sum += n
            # reset n. indels from last match
            I_sum = 0
            D_sum = 0
            last_match = i
        elif op == "I":
            I_sum += n
        elif op == "D":
            D_sum += n


        # if too long indel, split groups
        if max(I_sum, D_sum) >= thr_len:
            if (last_match is not None) and (M_sum >= thr_len):
                # add only if at least `thr_len` matches, otherwise discard.
                groups.append((g_start, last_match))
            g_start, last_match = None, None
            M_sum, I_sum, D_sum = 0, 0, 0

    # add last one
    if (last_match is not None) and (M_sum >= thr_len):
        groups.append((g_start, last_match))

    return groups

    

Test `keep_groups` independently

```
| no                   | keep                 | no      | keep              | no   | keep
| 10I 20D 10M 20I 190D | 40M 1D 1I 40M 1I 40M | 1D 100I | 200M 60I 60D 140M | 200D | 40M 2I 70M
  0   1   2   3   4      5   6  7  8   9  10    11 12     13   14  15  16     17     18  19 20
```
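
The table can be read gap by gap: between two matches, insertions and deletions are tallied separately, and a split happens when either tally reaches the threshold. The sketch below reproduces that reading on the test cigar; `indel_runs` is a hypothetical helper written for illustration only.

```python
import re

def indel_runs(cigar: str) -> list[tuple[int, int]]:
    """Hypothetical helper: (inserted, deleted) totals for each run of indels."""
    runs, ins, dels = [], 0, 0
    for n, op in re.findall(r"(\d+)([MID=X])", cigar):
        if op == "I":
            ins += int(n)
        elif op == "D":
            dels += int(n)
        else:
            # a match closes the current run of indels, if any
            if ins or dels:
                runs.append((ins, dels))
            ins, dels = 0, 0
    if ins or dels:  # trailing indels after the last match
        runs.append((ins, dels))
    return runs

cigar = "10I20D10M20I190D40M1D1I40M1I40M1D100I200M60I60D140M200D40M2I70M"
thr_len = 100  # assumed threshold, in bp

runs = indel_runs(cigar)
splits = [run for run in runs if max(run) >= thr_len]
print(splits)  # [(20, 190), (100, 1), (0, 200)]
```

Note that `60I 60D` does not trigger a split, since neither tally reaches 100 on its own. Whether the group preceding a split is kept still depends on its total number of matches.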

In [41]:

test_cigar = "10I 20D 10M 20I 190D 40M 1D 1I 40M 1I 40M 1D 100I 200M 60I 60D 140M 200D 40M 2I 70M"
test_cigar = test_cigar.replace(" ", "")
expected = [(5,10), (13, 16), (18, 20)]
result = keep_groups(test_cigar, 100)
assert result == expected, f"Expected {expected}, got {result}"

In [42]:
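def cigar_position_start(cigar: str, cigar_idx: int, accepted_op: str) -> int:
    """0-based offset of the first base of cigar element `cigar_idx`,
    counting only the operations listed in `accepted_op`."""
    pos = 0
    for i, (n, op) in enumerate(parse_cigar(cigar)):
        if i == cigar_idx:
            return pos
        if op in accepted_op:
            pos += n
    raise IndexError(f"cigar index {cigar_idx} out of range")

def cigar_position_end(cigar: str, cigar_idx: int, accepted_op: str) -> int:
    """0-based inclusive offset of the last base of cigar element `cigar_idx`,
    counting only the operations listed in `accepted_op`."""
    pos = 0
    for i, (n, op) in enumerate(parse_cigar(cigar)):
        if op in accepted_op:
            pos += n
        if i == cigar_idx:
            return pos - 1
    raise IndexError(f"cigar index {cigar_idx} out of range")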


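def group_positions(cigar: str, start_idx: int, end_idx: int) -> Tuple[int, int, int, int]:
    """Offsets (relative to the alignment start, end inclusive) spanned on the
    query and on the reference by cigar elements start_idx..end_idx."""
    # M, =, X and I consume query bases; M, =, X and D consume reference bases
    qry_beg = cigar_position_start(cigar, start_idx, "MI=X")
    qry_end = cigar_position_end(cigar, end_idx, "MI=X")
    ref_beg = cigar_position_start(cigar, start_idx, "MD=X")
    ref_end = cigar_position_end(cigar, end_idx, "MD=X")

    return qry_beg, qry_end, ref_beg, ref_end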

def generate_subalignment(aln: Alignment, start_cigar_idx: int, end_cigar_idx: int) -> Alignment:
    """Given an alignment and a pair of indices, returns a new alignment that spans only the
    interval between the two selected cigar indices (extremes included)."""

    qs, qe, rs, re = group_positions(aln.cigar, start_cigar_idx, end_cigar_idx)

    qry = Hit(aln.qry.name, aln.qry.length, aln.qry.start + qs, aln.qry.start + qe)
    reff = Hit(aln.reff.name, aln.reff.length, aln.reff.start + rs, aln.reff.start + re)

    # select the cigar elements whose *index* falls in the kept interval
    ops = list(parse_cigar(aln.cigar))[start_cigar_idx : end_cigar_idx + 1]
    matches = sum(n for n, op in ops if op in "M=X")
    length = sum(n for n, _ in ops)
    quality = aln.quality
    orientation = aln.orientation
    # rebuild the cigar from the selected elements (not by slicing characters)
    cigar = "".join(f"{n}{op}" for n, op in ops)
    divergence = aln.divergence
    align = aln.align

    return Alignment(qry, reff, matches, length, quality, orientation, cigar, divergence, align)


    

def split_matches(aln: Alignment, thr_len: int) -> list[Alignment]:
    
    kept_groups = keep_groups(aln.cigar, thr_len)

    return [generate_subalignment(aln, start, end) for start, end in kept_groups]
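
As a sanity check on the coordinate bookkeeping, here is a self-contained sketch that computes the same kind of inclusive offsets in a single pass. The helper `span_offsets` is hypothetical and only meant as a cross-check. It assumes offsets are measured from the start of the alignment on both sequences; how they map onto `Hit.start` and `Hit.stop` for reverse-strand alignments depends on a coordinate convention that is not specified here.

```python
import re

QRY_OPS = set("MI=X")  # operations that consume query bases
REF_OPS = set("MD=X")  # operations that consume reference bases

def span_offsets(cigar: str, start_idx: int, end_idx: int) -> tuple[int, int, int, int]:
    """Hypothetical helper: 0-based inclusive (qry_beg, qry_end, ref_beg, ref_end)
    offsets covered by cigar elements start_idx..end_idx."""
    ops = [(int(n), op) for n, op in re.findall(r"(\d+)([MID=X])", cigar)]
    if not 0 <= start_idx <= end_idx < len(ops):
        raise IndexError("cigar indices out of range")
    q = r = 0  # running query / reference offsets
    q_beg = r_beg = 0
    for i, (n, op) in enumerate(ops):
        if i == start_idx:
            q_beg, r_beg = q, r
        if op in QRY_OPS:
            q += n
        if op in REF_OPS:
            r += n
        if i == end_idx:
            break
    return q_beg, q - 1, r_beg, r - 1

# elements: 0:10M 1:5I 2:20M 3:3D 4:7M
offsets = span_offsets("10M5I20M3D7M", 2, 4)
print(offsets)  # (15, 41, 10, 39)
```

The group starts after 15 query bases (10M + 5I) and 10 reference bases (10M), and spans 27 query bases and 30 reference bases, which gives the inclusive ends 41 and 39.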
