# In [1]:
# -*- coding: utf-8 -*-
import re
from typing import Dict, List, Tuple, Optional, Any
import pandas as pd
import os

# ===================== I. Extract top-scoring templates by class from FASTA =====================

def classify_template(name: str) -> str:
    """Map a template name to its segment class using only the name prefix.

    IGHV -> "HV", IGHJ -> "HJ", any other IGH* -> "HC",
    IGKV/IGLV -> "LV", IGKJ/IGLJ -> "LJ", IGKC/IGLC -> "LC";
    every other name yields "Unknown".
    """
    # Order matters: the specific heavy-chain prefixes must be tried before
    # the generic "IGH" fallback.
    prefix_table = (
        (("IGHV",), "HV"),
        (("IGHJ",), "HJ"),
        (("IGH",), "HC"),
        (("IGKV", "IGLV"), "LV"),
        (("IGKJ", "IGLJ"), "LJ"),
        (("IGKC", "IGLC"), "LC"),
    )
    for prefixes, segment_class in prefix_table:
        if name.startswith(prefixes):
            return segment_class
    return "Unknown"

def extract_template_info_from_fasta(fasta_path: str, output_csv_path: Optional[str] = None) -> pd.DataFrame:
    """
    Parse template name, score, and type from FASTA headers.

    Expected header form: ">IGHV1-46 ... score:123 ..."

    Args:
        fasta_path: FASTA file to scan; only lines starting with ">" are used.
        output_csv_path: if given (and non-empty), the table is also written
            there as CSV without the index column.

    Returns:
        DataFrame with columns ``template`` (first header token), ``type``
        (result of ``classify_template``) and ``score`` (integer after the
        first ``score:`` token, or None when absent or not an integer).
        The three columns are present even when no header was found.
    """
    records = []
    with open(fasta_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.startswith(">"):
                continue
            parts = line[1:].split()
            if not parts:
                # A bare ">" carries no template name; previously this raised
                # IndexError on parts[0]. There is nothing to record, so skip it.
                continue
            template = parts[0]
            score = None
            for p in parts:
                if p.startswith("score:"):
                    try:
                        score = int(p.split(":", 1)[1])
                    except ValueError:
                        score = None
                    # Only the first "score:" token is considered.
                    break
            records.append({"template": template, "type": classify_template(template), "score": score})

    # Explicit columns keep the schema stable for inputs without any header.
    df = pd.DataFrame(records, columns=["template", "type", "score"])
    if output_csv_path:
        df.to_csv(output_csv_path, index=False)
        print(f"✅ Saved to: {output_csv_path}")
    return df


# ===================== II. Parse Stitch config & read six segments =====================

def parse_stitch_config_segments(cfg_path: str) -> Tuple[Dict[str, Dict[str, str]], str]:
    """
    Scan a Stitch config file line by line and collect the six template FASTA paths.

    Returns:
      segments: {'heavy': {'V','J','C'}, 'light': {'V','J','C'}}  -> corresponding FASTA paths
      runname : value of the first "Runname:" line, e.g. '50ugS2P6'

    Raises:
      ValueError: if any of the six slots has no path, or if no Runname line was found.

    Recognised structure:
      - A chain header is any line ending in "Heavy Chain->" / "Light Chain->"
        (arbitrary prefix such as "Mus" or "Human", case-insensitive).
      - Inside a chain, a block opened by "Segment->" and closed by "<-" may
        contain a "Name:" line and a "Path:" line, in either order.
      - Only paths ending in .fa/.fasta/.faa/.fna match the Path pattern;
        other Path lines are ignored.
      - Heavy names IGHV/IGHJ/IGHC and light names IGKV|IGLV / IGKJ|IGLJ /
        IGKC|IGLC fill the V/J/C slots; other names are ignored. A later
        segment with the same slot overwrites an earlier one.
    """
    segments = {'heavy': {'V': None, 'J': None, 'C': None},
                'light': {'V': None, 'J': None, 'C': None}}

    # Parser state: which chain we are in, whether we are inside a Segment
    # block, and the Name/Path seen so far in that block.
    current_chain: Optional[str] = None
    in_segment = False
    seg_name = None
    seg_path = None

    name_re = re.compile(r'^\s*Name\s*:\s*(\S+)\s*$', re.IGNORECASE)
    path_re = re.compile(r'^\s*Path\s*:\s*(.+?\.(?:fa|fasta|faa|fna))\s*$', re.IGNORECASE)
    run_re  = re.compile(r'^\s*Runname\s*:\s*(.+?)\s*$', re.IGNORECASE)

    # Support arbitrary prefixes like "Mus Heavy Chain->", "Human Light Chain->", etc.
    heavy_hdr_re = re.compile(r'^\s*.*\bHeavy\s+Chain->\s*$', re.IGNORECASE)
    light_hdr_re = re.compile(r'^\s*.*\bLight\s+Chain->\s*$', re.IGNORECASE)

    # Support optional spaces before/after "Segment->" and "<-".
    # NOTE: unlike the patterns above, these two are case-sensitive.
    seg_enter_re = re.compile(r'^\s*Segment\s*->\s*$')
    seg_leave_re = re.compile(r'^\s*<-\s*$')

    runname: Optional[str] = None

    def commit():
        # Store the current (name, path) pair into its slot once chain, name
        # and path are all known; otherwise do nothing. Safe to call repeatedly.
        # (seg_name / seg_path are only read here, never reassigned.)
        nonlocal seg_name, seg_path
        if not (current_chain and seg_name and seg_path):
            return
        # Normalise: upper-case the name and drop surrounding quotes from both.
        n = seg_name.upper().strip().strip('"').strip("'")
        p = seg_path.strip().strip('"').strip("'")
        if current_chain == 'heavy':
            if n == 'IGHV': segments['heavy']['V'] = p
            elif n == 'IGHJ': segments['heavy']['J'] = p
            elif n == 'IGHC': segments['heavy']['C'] = p
        elif current_chain == 'light':
            if n in ('IGLV', 'IGKV'): segments['light']['V'] = p
            elif n in ('IGLJ', 'IGKJ'): segments['light']['J'] = p
            elif n in ('IGLC', 'IGKC'): segments['light']['C'] = p

    with open(cfg_path, 'r', encoding='utf-8') as f:
        for raw in f:
            # Runname can appear anywhere in the file; the first match wins.
            if runname is None:
                m_run = run_re.match(raw)
                if m_run:
                    runname = m_run.group(1).strip().strip('"').strip("'")

            line = raw.strip()

            # Detect chain header (resets any half-read segment)
            if heavy_hdr_re.match(line):
                current_chain = 'heavy'; in_segment = False; seg_name = seg_path = None
                continue
            if light_hdr_re.match(line):
                current_chain = 'light'; in_segment = False; seg_name = seg_path = None
                continue

            # Enter/exit a Segment block
            if seg_enter_re.match(line):
                in_segment = True; seg_name = seg_path = None
                continue
            if seg_leave_re.match(line):
                if in_segment: commit()
                in_segment = False; seg_name = seg_path = None
                continue

            # Name/Path lines only count inside a Segment block of a known chain.
            if not in_segment or current_chain not in ('heavy', 'light'):
                continue

            # commit() is attempted after each Name/Path line so the slot is
            # filled as soon as both are known, whichever comes first.
            m_path = path_re.match(raw)
            if m_path:
                seg_path = m_path.group(1); commit(); continue
            m_name = name_re.match(raw)
            if m_name:
                seg_name = m_name.group(1); commit(); continue

    # File ended inside a Segment block without a closing "<-".
    if in_segment:
        commit()

    missing = [(ch, rgn) for ch in ('heavy', 'light') for rgn in ('V','J','C') if not segments[ch][rgn]]
    if missing:
        raise ValueError(f"The following slots did not resolve to FASTA paths: {missing}")
    if not runname:
        raise ValueError("Runname not found; please check the config file.")

    return segments, runname


def load_annotated_fasta(path: str, key_mode: str = "first_token") -> Dict[str, str]:
    """
    Read an annotated FASTA file and return {id: single-line sequence}.

    Args:
        path: FASTA file, read as UTF-8.
        key_mode: "first_token" keys each record by the first
            whitespace-delimited token of its header; any other value keys it
            by the whole stripped header.

    Returns:
        Mapping from record key to its sequence. If a key occurs more than
        once, the last record wins. A header with nothing after ">" is stored
        under the empty-string key in both modes.

    Notes:
      - Each sequence line is stripped and the pieces are joined with no
        separator, so line breaks never introduce spaces.
      - Any no-break space (U+00A0) left inside a line is replaced by a plain
        space (it is not removed).
      - Lines before the first header are discarded.
    """
    records: Dict[str, str] = {}
    cur_key: Optional[str] = None
    buf: List[str] = []

    def _push() -> None:
        # Flush the record being accumulated, if any.
        if cur_key is not None:
            records[cur_key] = "".join(buf).replace("\xa0", " ")

    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith(">"):
                _push()
                head = line[1:].strip()
                if key_mode == "first_token":
                    # head.split() is empty for a bare ">" header; indexing it
                    # used to raise IndexError. Fall back to "" so both key
                    # modes treat an empty header the same way.
                    tokens = head.split()
                    cur_key = tokens[0] if tokens else ""
                else:
                    cur_key = head
                buf = []
            else:
                buf.append(line.strip())
    _push()
    return records

def extract_sequence_by_id(fasta_path: str, target_id: str, key_mode: str = "first_token") -> str:
    """Return the annotated sequence stored under `target_id` in `fasta_path`.

    The ID comparison is case-insensitive. The whole FASTA is loaded on each
    call. Raises KeyError when no record matches.
    """
    records = load_annotated_fasta(fasta_path, key_mode=key_mode)

    # Upper-cased ID -> ID as spelled in the file (later duplicates win).
    spelled_as = {}
    for record_id in records:
        spelled_as[record_id.upper()] = record_id

    wanted = target_id.upper()
    if wanted not in spelled_as:
        raise KeyError(f"ID '{target_id}' not found in {fasta_path}")
    return records[spelled_as[wanted]]


# ---- Template parsing (V/J/C) ----

def _ctx(seq: str, i: int, span: int = 25) -> str:
    """Return the slice of `seq` from `span` characters before index `i` up to
    `span` characters after it, clipped to the string bounds (used in error
    messages)."""
    left = i - span
    if left < 0:
        left = 0
    right = i + span
    if right > len(seq):
        right = len(seq)
    return seq[left:right]

def _parse_template_with_rules(seq: str, default_label: str, allowed_cdr: List[str]) -> Tuple[str, List[str]]:
    """
    Strictly parse an annotated template into (plain sequence, per-residue labels).

      - Only two bracketed forms are understood: (Conserved X...) and (CDR1/2/3 Y...).
      - Residues outside brackets must be A-Z and receive `default_label`.
      - Residues of a (Conserved ...) group are labelled "Conserved".
      - Residues of a (CDRn ...) group are labelled "CDRn" when that tag is in
        `allowed_cdr`, otherwise `default_label`.
      - Inside a group, every character other than A-Z is dropped.
      - Unclosed or unknown brackets, empty groups and illegal characters raise
        ValueError (with context where available).
    """
    residues: List[str] = []
    tags: List[str] = []
    pos = 0
    total = len(seq)

    while pos < total:
        symbol = seq[pos]

        if symbol != '(':
            # Outside brackets: only A-Z is accepted (no inter-line spaces here).
            if not ('A' <= symbol <= 'Z'):
                raise ValueError(f"Illegal character '{symbol}' (pos={pos}). context='{_ctx(seq, pos)}'")
            residues.append(symbol)
            tags.append(default_label)
            pos += 1
            continue

        close = seq.find(')', pos + 1)
        if close == -1:
            raise ValueError(f"Unclosed '(' (pos={pos}). context='{_ctx(seq, pos)}'")

        # Collapse whitespace runs so the annotation keyword can be matched reliably.
        annotation = re.sub(r"\s+", " ", seq[pos + 1:close].strip())

        if annotation.startswith("Conserved"):
            letters = re.sub(r'[^A-Z]', '', annotation[len("Conserved"):])
            if not letters:
                raise ValueError(f"No amino acids detected in (Conserved ...) (pos={pos}). token='{annotation}'")
            group_label = "Conserved"
        elif annotation.startswith("CDR"):
            parsed = re.match(r"^CDR\s*([1-3])\s*([A-Za-z\s]*)$", annotation)
            if parsed is None:
                raise ValueError(f"Unrecognized CDR annotation (pos={pos}). token='{annotation}'")
            cdr_name = f"CDR{parsed.group(1)}"
            letters = re.sub(r'[^A-Z]', '', parsed.group(2))
            if not letters:
                raise ValueError(f"No amino acids detected in {cdr_name} (pos={pos}). token='{annotation}'")
            # CDR tags not permitted for this segment type fall back to the default label.
            group_label = cdr_name if cdr_name in allowed_cdr else default_label
        else:
            # Any other bracket form: strict error
            raise ValueError(f"Unknown bracket annotation (pos={pos}). token='{annotation}'. context='{_ctx(seq, pos)}'")

        residues.extend(letters)
        tags.extend([group_label] * len(letters))
        pos = close + 1

    if len(residues) != len(tags):
        raise RuntimeError("Parsed sequence length does not match label length")
    return "".join(residues), tags

def parse_v_template_to_labels(seq: str) -> Tuple[str, List[str]]:
    """Parse an annotated V template: residues default to "Variable" and the
    CDR1, CDR2 and CDR3 tags are all kept."""
    return _parse_template_with_rules(
        seq,
        default_label="Variable",
        allowed_cdr=["CDR1", "CDR2", "CDR3"],
    )

def parse_j_template_to_labels(seq: str) -> Tuple[str, List[str]]:
    """Parse an annotated J template: residues default to "Variable" and only
    the CDR3 tag is kept (other CDR tags fall back to "Variable")."""
    return _parse_template_with_rules(
        seq,
        default_label="Variable",
        allowed_cdr=["CDR3"],
    )

def parse_c_template_to_labels(seq: str) -> Tuple[str, List[str]]:
    """Parse an annotated C template: residues default to "Constant" and no
    CDR tag is kept (any CDR group is labelled "Constant")."""
    return _parse_template_with_rules(
        seq,
        default_label="Constant",
        allowed_cdr=[],
    )


def get_six_segments(cfg_path: str,
                     target_ids: Dict[str, Dict[str, str]],
                     key_mode: str = "first_token") -> Tuple[Dict[str, Dict[str, Tuple[str, List[str]]]], str]:
    """Load and parse the six chosen templates (V/J/C for heavy and light).

    `target_ids[chain][region]` names the record to take from the FASTA that
    the config assigns to that chain/region. Returns
    ({chain: {region: (sequence, labels)}}, runname).
    """
    seg_paths, runname = parse_stitch_config_segments(cfg_path)
    parsed: Dict[str, Dict[str, Tuple[str, List[str]]]] = {'heavy': {}, 'light': {}}

    for chain, per_chain in parsed.items():
        for region in ('V', 'J', 'C'):
            source_fasta = seg_paths[chain][region]
            record_id = target_ids[chain][region]
            annotated = extract_sequence_by_id(source_fasta, record_id, key_mode=key_mode)

            # Each region type has its own labelling rules.
            if region == 'V':
                parser = parse_v_template_to_labels
            elif region == 'J':
                parser = parse_j_template_to_labels
            else:
                parser = parse_c_template_to_labels

            seq, labels = parser(annotated)
            per_chain[region] = (seq, labels)

    return parsed, runname


# ===================== III. Concatenate full chains and extract CDRs & flanks =====================

def _concat_chain(segments: Dict[str, Tuple[str, List[str]]]) -> Tuple[str, List[str]]:
    """
    Join the V, J and C segments (in that order) into one chain.

    segments: {'V': (seq, labels), 'J': (seq, labels), 'C': (seq, labels)}
    Returns (full sequence, full label list). Raises RuntimeError if a
    segment's sequence and labels differ in length.
    """
    pieces: List[str] = []
    merged_labels: List[str] = []
    for region in ("V", "J", "C"):
        region_seq, region_labels = segments[region]
        if len(region_seq) != len(region_labels):
            raise RuntimeError(f"Sequence length does not match label length in segment {region}")
        pieces.append(region_seq)
        merged_labels += region_labels
    chain_seq = "".join(pieces)
    return chain_seq, merged_labels

def _contiguous_ranges(labels: List[str], target: str) -> List[Tuple[int, int]]:
    """List every maximal run of `target` in `labels` as an inclusive (start, end) pair."""
    found: List[Tuple[int, int]] = []
    run_start: Optional[int] = None
    for idx, label in enumerate(labels):
        hit = (label == target)
        if hit and run_start is None:
            # A new run opens here.
            run_start = idx
        elif not hit and run_start is not None:
            # The run ended on the previous element.
            found.append((run_start, idx - 1))
            run_start = None
    # A run still open at the end extends to the last element.
    if run_start is not None:
        found.append((run_start, len(labels) - 1))
    return found

def _slice_seq_labels(seq: str, labels: List[str], s: int, e: int) -> Tuple[str, List[str]]:
    """Cut the inclusive range [s, e] out of both `seq` and `labels`;
    an inverted range (s > e) yields ("", [])."""
    if e < s:
        return "", []
    stop = e + 1
    return seq[s:stop], labels[s:stop]

def _preceding_k(seq: str, labels: List[str], start_idx: int, k: int = 7) -> Tuple[str, List[str], Tuple[int, int]]:
    """Return up to `k` residues directly before `start_idx`, their labels, and
    the inclusive (s, e) range they occupy. When nothing precedes `start_idx`
    the result is ("", [], (-1, -1))."""
    if start_idx <= 0:
        return "", [], (-1, -1)
    # Clip the window at the start of the sequence.
    window = (max(0, start_idx - k), start_idx - 1)
    flank_seq, flank_labels = _slice_seq_labels(seq, labels, window[0], window[1])
    return flank_seq, flank_labels, window

def _following_k(seq: str, labels: List[str], end_idx: int, k: int = 7) -> Tuple[str, List[str], Tuple[int, int]]:
    """Return up to `k` residues directly after `end_idx`, their labels, and the
    inclusive (s, e) range they occupy. When nothing follows `end_idx` the
    result is ("", [], (-1, -1))."""
    last = len(seq) - 1
    if end_idx >= last:
        return "", [], (-1, -1)
    # Clip the window at the end of the sequence.
    window = (end_idx + 1, min(last, end_idx + k))
    flank_seq, flank_labels = _slice_seq_labels(seq, labels, window[0], window[1])
    return flank_seq, flank_labels, window

def extract_cdr_and_flanks(
    six_segments: Dict[str, Dict[str, Tuple[str, List[str]]]],
    k: int = 7
) -> Dict[str, Dict[str, Dict[str, Any]]]:
    """
    Collect CDR1/2/3 of the heavy and light chains together with their flanks.

    Each chain is first assembled from its V, J and C segments. For every CDR
    tag, all runs carrying that label are located; the CDR sequence is the
    concatenation of those runs. The flanks are the (up to) k residues before
    the first run and after the last run.

    Return structure (key names contain "7" regardless of `k`):
      result[chain][cdr] = {
        "sequence": str,
        "ranges": List[(s,e)],
        "preceding7_seq": str, "preceding7_labels": List[str], "preceding7_range": (s,e),
        "following7_seq": str, "following7_labels": List[str], "following7_range": (s,e),
      }
    A CDR that does not occur gets an empty sequence, empty ranges and
    ("", [], (-1, -1)) for both flanks.
    """
    result: Dict[str, Dict[str, Dict[str, Any]]] = {}
    for chain in ("heavy", "light"):
        chain_seq, chain_labels = _concat_chain(six_segments[chain])
        per_cdr: Dict[str, Dict[str, Any]] = {}

        for cdr_tag in ("CDR1", "CDR2", "CDR3"):
            ranges = _contiguous_ranges(chain_labels, cdr_tag)
            # Joining zero runs naturally gives the empty string.
            cdr_seq = "".join(chain_seq[s:e + 1] for s, e in ranges)

            if not ranges:
                before = ("", [], (-1, -1))
                after = ("", [], (-1, -1))
            else:
                before = _preceding_k(chain_seq, chain_labels, ranges[0][0], k=k)
                after = _following_k(chain_seq, chain_labels, ranges[-1][1], k=k)

            per_cdr[cdr_tag] = {
                "sequence": cdr_seq,
                "ranges": ranges,
                "preceding7_seq": before[0],
                "preceding7_labels": before[1],
                "preceding7_range": before[2],
                "following7_seq": after[0],
                "following7_labels": after[1],
                "following7_range": after[2],
            }
        result[chain] = per_cdr
    return result

def _pretty_print_cdr(info: Dict[str, Dict[str, Dict[str, Any]]], k: int = 7) -> None:
    """Print a human-readable report of the CDRs and flanks in `info`
    (the structure produced by extract_cdr_and_flanks). `k` only affects the
    "preceding{k}" / "following{k}" captions. Label lines are printed only for
    non-empty flanks."""
    for chain in ("heavy", "light"):
        print(f"\n=== {chain.upper()} ===")
        for cdr in ("CDR1", "CDR2", "CDR3"):
            entry = info[chain][cdr]
            cdr_seq = entry["sequence"]
            cdr_ranges = entry["ranges"]
            before_seq = entry["preceding7_seq"]
            before_labels = entry["preceding7_labels"]
            before_lo, before_hi = entry["preceding7_range"]
            after_seq = entry["following7_seq"]
            after_labels = entry["following7_labels"]
            after_lo, after_hi = entry["following7_range"]

            report = [
                f"- {cdr}:",
                f"  ranges: {cdr_ranges if cdr_ranges else '[]'}",
                f"  sequence ({len(cdr_seq)} aa): {cdr_seq if cdr_seq else 'N/A'}",
                f"  preceding{k} ({before_lo}..{before_hi}): {before_seq if before_seq else 'N/A'}",
            ]
            if before_labels:
                report.append(f"  preceding{k} labels: {before_labels}")
            report.append(f"  following{k} ({after_lo}..{after_hi}): {after_seq if after_seq else 'N/A'}")
            if after_labels:
                report.append(f"  following{k} labels: {after_labels}")

            for text in report:
                print(text)


# In [2]:
# -*- coding: utf-8 -*-
import heapq
from typing import Dict, List, Tuple, Optional

def read_kmer_set_from_csv(file_path):
    """Load k-mer counts from a CSV and keep only the 7-mers.

    Args:
        file_path: CSV with at least the columns ``kmer`` and ``count``.

    Returns:
        dict mapping each length-7 k-mer string to its count. Despite the
        function name, the result is a dict, not a set. If a k-mer occurs in
        several rows, the last row wins.
    """
    import pandas as pd
    df = pd.read_csv(file_path)
    # pandas turns cells such as "NA", "NaN" or "" into float NaN, and calling
    # len() on those raised TypeError. Only real strings can be 7-mers, so
    # anything else is skipped.
    return {
        kmer: count
        for kmer, count in zip(df['kmer'], df['count'])
        if isinstance(kmer, str) and len(kmer) == 7
    }

def beam_search(start_kmer: str,
                start_kmer_template: Optional[str],
                kmers: Dict[str, int],
                template_sequence: str,
                template_labels: List[str],
                beam_width: int,
                max_iterations: int,
                top_n: int = 3,
                dist_threshold: int = 5,
                cdr_tail: int = 6,
                stop_sequence: Optional[str] = None,
                direction: int = 1,
                distance_guard: int = 1,                   # 1 = enable distance guard, -1 = disable
                ban_sequences: Optional[List[str]] = None, # Long-sequence blacklist: convert I→L then generate all 7-mers
                template_weight_mode: int = 1,             # 1 = template-weighted; -1 = counts only, no weighting
                min_overlap: int = 4                       # Minimal overlap ∈ {4,5,6}
                ) -> Dict[str, int]:
    """
    Beam search over fixed 7-mers: anchor on the template, then extend one
    residue per round in a single direction.

    Overview:
      1. The anchor (start_kmer_template if given, else start_kmer) is located
         at every position where it occurs in template_sequence. If it does
         not occur, {} is returned.
      2. For each anchor position, the search starts from start_kmer with
         score 0 and runs up to max_iterations rounds. In each round every
         sequence in the frontier is extended by one residue (appended when
         direction == 1, prepended when direction == -1) using k-mers that
         overlap its end, and the frontier is cut back to beam_width.
      3. Results of all anchors are merged (best score per sequence) and the
         top_n highest-scoring sequences are returned.

    Params:
      - start_kmer: length-7 seed sequence.
      - start_kmer_template: optional length-7 string used only to locate the
        anchor position(s) in the template.
      - kmers: {7-mer: count}; every entry is scanned at each extension.
      - template_sequence / template_labels: template residues and their
        per-residue labels (same length).
      - beam_width: max candidates per expansion and max frontier size.
      - max_iterations: max number of extension rounds.
      - top_n: number of sequences returned (per anchor and overall).
      - dist_threshold: distance-guard cut-off; a candidate is dropped when
        the Levenshtein distance is >= this value.
      - cdr_tail: template positions within this many residues of any
        CDR-labelled position are exempt from the distance guard.
      - stop_sequence: any length (>=1); in forward mode compared to the
        suffix, in reverse mode to the prefix. On the first match the search
        for that anchor returns just that one sequence.
      - direction: 1 (forward) or -1 (reverse).
      - distance_guard: 1 enables the guard; any other value disables it.
      - ban_sequences: long sequences whose 7-mers (after I→L) are banned both
        as candidate k-mers and as substrings of assembled sequences.
      - template_weight_mode:
          1 : template-weighted scoring (Conserved×10, non-Conserved×2; mismatches not up-weighted)
         -1 : rely solely on kmer counts (no template multiplier)
      - min_overlap: minimal overlap used for extension (4/5/6)
          * 6 -> use only 6-overlap
          * 5 -> use 6 and 5
          * 4 -> use 6, 5, 4

    Returns: {assembled_sequence: score}

    NOTE(review): argument validation uses `assert`, so it is skipped when
    Python runs with -O.
    NOTE(review): the distance guard compares raw characters; unlike the
    template weighting it applies no I/L equivalence.
    """
    # -------------- checks --------------
    assert direction in (1, -1), "direction must be 1 (forward) or -1 (reverse)"
    assert len(start_kmer) == 7, "start_kmer must be length-7"
    if start_kmer_template is not None:
        assert len(start_kmer_template) == 7, "start_kmer_template must be length-7"
    if stop_sequence is not None:
        assert isinstance(stop_sequence, str) and len(stop_sequence) >= 1, "stop_sequence must be a non-empty string"
    assert len(template_sequence) == len(template_labels), "template_sequence and template_labels must align"
    assert template_weight_mode in (1, -1), "template_weight_mode must be 1 or -1"
    assert min_overlap in (4, 5, 6), "min_overlap must be one of {4,5,6}"

    # -------------- helpers --------------
    def levenshtein(a: str, b: str) -> int:
        # Edit distance computed with two rolling rows of the DP table.
        la, lb = len(a), len(b)
        if la == 0: return lb
        if lb == 0: return la
        prev = list(range(lb + 1)); curr = [0] * (lb + 1)
        for i in range(1, la + 1):
            curr[0] = i; ai = a[i - 1]
            for j in range(1, lb + 1):
                cost = 0 if ai == b[j - 1] else 1
                curr[j] = min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost)
            prev, curr = curr, prev
        # After the final swap the completed row is in `prev`.
        return prev[lb]

    def is_conserved(label) -> bool:
        return str(label) == "Conserved"

    def is_cdr(label) -> bool:
        u = str(label).upper()
        return u.startswith("CDR1") or u.startswith("CDR2") or u.startswith("CDR3")

    def build_protected_indices(labels: List[str], length: int, tail: int) -> set:
        """Symmetric protection: for each CDR, protect 'tail' template positions on both sides."""
        prot = set()
        for i, lab in enumerate(labels):
            if is_cdr(lab):
                s = max(0, i - tail)
                e = min(length - 1, i + tail)
                for j in range(s, e + 1):
                    prot.add(j)
        return prot

    def find_all_positions(seq: str, sub: str) -> List[int]:
        # All start indices of `sub` in `seq`, overlapping matches included.
        out, i = [], seq.find(sub, 0)
        while i != -1:
            out.append(i)
            i = seq.find(sub, i + 1)
        return out

    # -------------- Blacklist: derive banned 7-mers from long sequences (I→L for blacklist only; kmers themselves have L) --------------
    banned7L: set = set()
    if ban_sequences:
        for long_seq in ban_sequences:
            if not long_seq:
                continue
            sL = long_seq.upper().replace('I', 'L')
            if len(sL) >= 7:
                for i in range(len(sL) - 6):
                    banned7L.add(sL[i:i+7])

    def contains_banned7(seq: str) -> bool:
        # True if any 7-residue window of `seq` (upper-cased) is banned.
        if not banned7L:
            return False
        s = seq.upper()
        for i in range(0, len(s) - 6):
            if s[i:i+7] in banned7L:
                return True
        return False

    # -------------- basics --------------
    n_template = len(template_sequence)
    protected_idx = build_protected_indices(template_labels, n_template, cdr_tail)

    # Build overlap list and priority based on min_overlap (larger overlap = higher priority)
    overlap_order = list(range(6, min_overlap - 1, -1))  # e.g., min=5 -> [6,5]
    # Higher overlap gets higher rank score
    overlap_priority = {ov: (len(overlap_order) - i) for i, ov in enumerate(overlap_order)}

    anchor = start_kmer_template if start_kmer_template is not None else start_kmer
    positions = find_all_positions(template_sequence, anchor)
    if not positions:
        # The anchor does not occur in the template: nothing to search from.
        return {}

    # A score adjuster controlled by template_weight_mode
    def adjust_score(cnt: int, use_template: bool, tgt_pos: int, ext_aa: str) -> int:
        """
        Adjust score by template match, with I/L equivalence considered.
        """
        if template_weight_mode == -1:
            return cnt

        if use_template and 0 <= tgt_pos < n_template:
            template_aa = template_sequence[tgt_pos]
            # I/L equivalence: a template I counts as matched by an extension L
            if template_aa == 'I' and ext_aa == 'L':
                template_aa = 'L'
            if template_aa == ext_aa:
                return cnt * (10 if is_conserved(template_labels[tgt_pos]) else 2)
        # Mismatch, or extension position outside the template: raw count.
        return cnt

    # -------------- Single-anchor search --------------
    def search_from_pos(pos: int) -> Dict[str, int]:
        def next_ext_pos(curr_len: int) -> int:
            # Template coordinate of the residue that would be added to a
            # sequence of length curr_len grown from the anchor at `pos`.
            added = curr_len - 7
            return (pos + 7 + added) if direction == 1 else (pos - 1 - added)

        def expand_fn(current_seq: str) -> List[Tuple[str, int, str]]:
            # Returns up to beam_width (kmer, adjusted_score, extension_residue)
            # tuples, at most one per distinct extension residue.
            tgt_pos = next_ext_pos(len(current_seq))
            use_template = (0 <= tgt_pos < n_template)
            # extension residue -> (kmer, adjusted score, overlap priority)
            cand_by_aa: Dict[str, Tuple[str, int, int]] = {}

            for ov in overlap_order:
                if ov > len(current_seq):
                    continue

                if direction == 1:
                    anchor_sub = current_seq[-ov:]
                    ext_idx = -1 - (6 - ov)   # 6/5/4 -> -1/-2/-3
                    for kmer, cnt in kmers.items():
                        # Candidate-level blacklist
                        if banned7L and kmer.upper() in banned7L:
                            continue
                        if not kmer.startswith(anchor_sub):
                            continue
                        if abs(ext_idx) > len(kmer):
                            continue
                        ext_aa = kmer[ext_idx]

                        adj = adjust_score(cnt, use_template, tgt_pos, ext_aa)
                        pri = overlap_priority[ov]
                        # Per residue keep the candidate with the larger
                        # overlap; ties are broken by the adjusted score.
                        if ext_aa in cand_by_aa:
                            _, prev_sc, prev_pri = cand_by_aa[ext_aa]
                            if (pri > prev_pri) or (pri == prev_pri and adj > prev_sc):
                                cand_by_aa[ext_aa] = (kmer, adj, pri)
                        else:
                            cand_by_aa[ext_aa] = (kmer, adj, pri)

                else:
                    anchor_sub = current_seq[:ov]
                    for kmer, cnt in kmers.items():
                        if banned7L and kmer.upper() in banned7L:
                            continue
                        if not kmer.endswith(anchor_sub):
                            continue
                        if len(kmer) <= ov:
                            continue
                        ext_idx = len(kmer) - ov - 1  # 6/5/4 -> 0/1/2
                        if ext_idx < 0:
                            continue
                        ext_aa = kmer[ext_idx]

                        adj = adjust_score(cnt, use_template, tgt_pos, ext_aa)
                        pri = overlap_priority[ov]
                        # Same selection rule as in the forward branch.
                        if ext_aa in cand_by_aa:
                            _, prev_sc, prev_pri = cand_by_aa[ext_aa]
                            if (pri > prev_pri) or (pri == prev_pri and adj > prev_sc):
                                cand_by_aa[ext_aa] = (kmer, adj, pri)
                        else:
                            cand_by_aa[ext_aa] = (kmer, adj, pri)

            # The final ranking uses the adjusted score only (priority is dropped).
            pot = [(k, s, aa) for aa, (k, s, _) in cand_by_aa.items()]
            pot.sort(key=lambda x: -x[1])
            return pot[:beam_width]

        frontier: List[Tuple[str, int]] = [(start_kmer, 0)]

        for _ in range(max_iterations):
            new_frontier: List[Tuple[str, int]] = []
            for seq, sc in frontier:
                # `kmer` is not used below; only the extension residue and its score are.
                for kmer, ext_sc, aa in expand_fn(seq):
                    new_seq = (seq + aa) if direction == 1 else (aa + seq)
                    new_sc  = sc + ext_sc

                    # Sequence-level blacklist
                    if banned7L and len(new_seq) >= 7 and contains_banned7(new_seq):
                        continue

                    # Duplicate 7-mer filter: drop the extension if its newest
                    # 7-mer already occurs in the rest of the sequence.
                    if len(new_seq) >= 7:
                        if direction == 1:
                            tail7 = new_seq[-7:]
                            if tail7 in new_seq[:-7]:
                                continue
                        else:
                            head7 = new_seq[:7]
                            if head7 in new_seq[7:]:
                                continue

                    # ---- Termination check (arbitrary-length stop_sequence) ----
                    # The first sequence reaching the stop motif ends this
                    # anchor's search and is its only result.
                    if stop_sequence:
                        slen = len(stop_sequence)
                        if len(new_seq) >= slen:
                            if (direction == 1 and new_seq[-slen:] == stop_sequence) or \
                               (direction == -1 and new_seq[:slen] == stop_sequence):
                                return {new_seq: new_sc}

                    # 7-mer distance guard (within template & non-protected sites; governed by distance_guard)
                    if distance_guard == 1:
                        tgt_pos = next_ext_pos(len(seq))  # template coordinate of the new AA (before extension)
                        if (0 <= tgt_pos < n_template) and (tgt_pos not in protected_idx) and len(new_seq) >= 7:
                            # Compare the newest 7-mer with the template window
                            # ending (forward) or starting (reverse) at tgt_pos.
                            if direction == 1:
                                q7 = new_seq[-7:]; t_start = max(0, tgt_pos - 6); t_end = tgt_pos + 1
                            else:
                                q7 = new_seq[:7];  t_start = tgt_pos;              t_end = min(n_template, tgt_pos + 7)
                            templ7 = template_sequence[t_start:t_end]
                            # The guard only applies when a full 7-residue window exists.
                            if len(templ7) == 7 and levenshtein(q7, templ7) >= dist_threshold:
                                continue

                    new_frontier.append((new_seq, new_sc))

            # No extension survived: keep the previous frontier as the result.
            if not new_frontier:
                break
            frontier = heapq.nlargest(beam_width, new_frontier, key=lambda x: x[1])

        best = heapq.nlargest(top_n, frontier, key=lambda x: x[1])
        return {seq: score for (seq, score) in best}

    # -------------- Multi-anchor aggregation --------------
    # Keep the best score seen for each sequence across all anchor positions.
    aggregated: Dict[str, int] = {}
    for p in positions:
        sub = search_from_pos(p)
        for s, v in sub.items():
            if (s not in aggregated) or (v > aggregated[s]):
                aggregated[s] = v

    if not aggregated:
        return {}
    return dict(sorted(aggregated.items(), key=lambda x: -x[1])[:top_n])


# In [3]:
# -*- coding: utf-8 -*-
import heapq
from typing import Dict, List, Tuple, Optional

def beam_search_C(start_kmer: str,
                  start_kmer_template: Optional[str],
                  kmers: Dict[str, int],
                  template_sequence: str,
                  template_labels: List[str],
                  beam_width: int,
                  max_iterations: int,
                  top_n: int = 3,
                  dist_threshold: int = 5,
                  cdr_tail: int = 6,
                  stop_sequence: Optional[str] = None,
                  direction: int = 1,                    # 1=forward, -1=reverse
                  distance_guard: int = 1,               # 1=enable distance guard, -1=disable
                  ban_sequences: Optional[List[str]] = None,  # Long-sequence blacklist: convert I→L, then generate all 7-mers
                  template_weight_mode: int = 1,         # 1=template-weighted; -1=use count only (no weighting)
                  min_overlap: int = 4                   # Minimal overlap (∈{4,5,6}); actual range is [min_overlap..6]
                  ) -> Dict[str, int]:
    """
    Template-guided beam search over fixed 7-mers, extending one residue per step.

    The search is anchored at every occurrence of ``start_kmer_template`` (or of
    ``start_kmer`` when no template anchor is given) inside ``template_sequence``
    and run once per anchor; per-sequence scores are merged by keeping the maximum.

    Parameters:
      - start_kmer: 7-residue seed the assembled sequence starts from.
      - start_kmer_template: optional 7-residue string used to locate the anchor
        position(s) in the template instead of ``start_kmer``.
      - kmers: mapping k-mer -> count used as extension evidence.
      - template_sequence / template_labels: aligned template residues and labels
        ("Conserved", "CDR1"/"CDR2"/"CDR3" prefixes are the labels this code reads).
      - beam_width: candidates kept per expansion and per iteration.
      - max_iterations: maximum number of residues added.
      - top_n: maximum number of sequences returned.
      - dist_threshold: a step is rejected when the edit distance between the newest
        7 residues and the aligned template 7-mer is >= this value.
      - cdr_tail: number of template positions protected on each side of every
        CDR-labelled position (the distance guard is skipped there).
      - stop_sequence: optional 5-mer; the first extension that ends (forward) or
        starts (reverse) with it is returned immediately as the only result.
      - direction: 1 extends to the right, -1 extends to the left.
      - distance_guard: the guard runs only when this equals 1.
      - ban_sequences: long sequences whose 7-mers (after I->L) are forbidden both as
        evidence k-mers and anywhere inside an extended sequence.
      - template_weight_mode: 1 multiplies the count of a candidate that agrees with
        the template (x10 at "Conserved" positions, x5 elsewhere); -1 uses raw counts.
      - min_overlap: smallest overlap used.
          =6 -> only overlap=6
          =5 -> use overlap=6,5
          =4 -> use overlap=6,5,4

    Count thresholds:
      - overlap=6: count >= 1
      - overlap=5/4: count >= 2

    Special rule:
      - At overlap 6, when the target template coordinate is valid and the template
        residue is not among the candidates (I/L treated as equal), a synthetic
        candidate with count 1 is added and weighted like any other candidate.
      - Candidates whose final adjusted score is exactly 1 are discarded.

    Returns:
      Dict mapping assembled sequence -> accumulated score, at most ``top_n`` entries,
      ordered by descending score. Empty when the anchor is not found in the template.

    Raises:
      AssertionError for invalid ``direction``, seed/stop lengths, misaligned
      template/labels, ``template_weight_mode`` or ``min_overlap``.

    Fix relative to the previous revision:
      - The distance guard now uses the template coordinate of the residue that was
        just added (``next_ext_pos(len(seq))``). It previously used the coordinate
        one step further in the extension direction, so the newest 7-mer was compared
        with a template window shifted by one and the protected-site test looked at
        the wrong index.
      - Template weighting uses the same I/L key as the other helpers, so lowercase
        i/l are treated consistently.
    """
    # -------------- checks --------------
    assert direction in (1, -1), "direction must be 1 (forward) or -1 (reverse)"
    assert len(start_kmer) == 7, "start_kmer must be length-7"
    if start_kmer_template is not None:
        assert len(start_kmer_template) == 7, "start_kmer_template must be length-7"
    if stop_sequence is not None:
        assert len(stop_sequence) == 5, "stop_sequence must be length-5"
    assert len(template_sequence) == len(template_labels), "template_sequence and template_labels must align"
    assert template_weight_mode in (1, -1), "template_weight_mode must be 1 or -1"
    assert min_overlap in (4, 5, 6), "min_overlap must be one of {4,5,6}"

    # -------------- helpers --------------
    def il_norm(s: str) -> str:
        # I/L equivalence: normalize to L (handle both cases)
        return s.replace('I', 'L').replace('i', 'l')

    def il_key(ch: str) -> str:
        # Single-character I/L equivalence key
        return 'L' if ch in ('I', 'L', 'i', 'l') else ch

    def levenshtein(a: str, b: str) -> int:
        # Edit distance under I/L normalization (two-row dynamic programme)
        a = il_norm(a)
        b = il_norm(b)
        la, lb = len(a), len(b)
        if la == 0:
            return lb
        if lb == 0:
            return la
        prev = list(range(lb + 1))
        curr = [0] * (lb + 1)
        for i in range(1, la + 1):
            curr[0] = i
            ai = a[i - 1]
            for j in range(1, lb + 1):
                cost = 0 if ai == b[j - 1] else 1
                curr[j] = min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost)
            prev, curr = curr, prev
        return prev[lb]

    def is_conserved(label) -> bool:
        return str(label) == "Conserved"

    def is_cdr(label) -> bool:
        u = str(label).upper()
        return u.startswith("CDR1") or u.startswith("CDR2") or u.startswith("CDR3")

    def build_protected_indices(labels: List[str], length: int, tail: int) -> set:
        """Symmetric protection: for each CDR position, protect 'tail' template positions on both sides."""
        prot = set()
        for i, lab in enumerate(labels):
            if is_cdr(lab):
                s = max(0, i - tail)
                e = min(length - 1, i + tail)
                prot.update(range(s, e + 1))
        return prot

    def find_all_positions(seq: str, sub: str) -> List[int]:
        # All (possibly overlapping) start indices of `sub` in `seq`
        out, i = [], seq.find(sub, 0)
        while i != -1:
            out.append(i)
            i = seq.find(sub, i + 1)
        return out

    # -------------- Blacklist: every 7-mer of the banned long sequences, I->L normalized --------------
    banned7L: set = set()
    if ban_sequences:
        for long_seq in ban_sequences:
            if not long_seq:
                continue
            sL = il_norm(long_seq.upper())
            if len(sL) >= 7:
                for i in range(len(sL) - 6):
                    banned7L.add(sL[i:i+7])

    def is_banned_kmer(kmer: str) -> bool:
        # Whole-k-mer blacklist test used for evidence and synthetic k-mers
        return bool(banned7L) and il_norm(kmer.upper()) in banned7L

    def contains_banned7(seq: str) -> bool:
        # True when any 7-residue window of `seq` is blacklisted
        if not banned7L:
            return False
        s = il_norm(seq.upper())
        for i in range(0, len(s) - 6):
            if s[i:i+7] in banned7L:
                return True
        return False

    # -------------- basics --------------
    n_template = len(template_sequence)
    protected_idx = build_protected_indices(template_labels, n_template, cdr_tail)

    # Overlaps tried at every step, largest first
    overlap_order = [ov for ov in (6, 5, 4) if ov >= min_overlap]
    # Larger overlap -> larger priority value
    overlap_priority = {ov: (len(overlap_order) - i) for i, ov in enumerate(overlap_order)}
    # Minimal counts: 6→1; 5/4→2
    min_count_by_overlap = {ov: (1 if ov == 6 else 2) for ov in overlap_order}

    anchor = start_kmer_template if start_kmer_template is not None else start_kmer
    positions = find_all_positions(template_sequence, anchor)
    if not positions:
        return {}

    # Template-weighted scoring (with I/L equivalence)
    def adjust_score(cnt: int, use_template: bool, tgt_pos: int, ext_aa: str) -> int:
        if template_weight_mode == -1:
            return cnt
        if use_template and 0 <= tgt_pos < n_template:
            if il_key(template_sequence[tgt_pos]) == il_key(ext_aa):
                return cnt * (10 if is_conserved(template_labels[tgt_pos]) else 5)
        return cnt

    # -------------- single-anchor search --------------
    def search_from_pos(pos: int) -> Dict[str, int]:
        def next_ext_pos(curr_len: int) -> int:
            # Template coordinate of the residue that extends a sequence of length curr_len
            added = curr_len - 7
            return (pos + 7 + added) if direction == 1 else (pos - 1 - added)

        def expand_fn(current_seq: str) -> List[Tuple[str, int, str]]:
            tgt_pos = next_ext_pos(len(current_seq))
            use_template = (0 <= tgt_pos < n_template)
            # extension residue -> (supporting kmer, adjusted score, overlap priority)
            cand_by_aa: Dict[str, Tuple[str, int, int]] = {}

            def offer(ext_aa: str, kmer: str, cnt: int, ov: int) -> None:
                # Register a candidate; higher overlap priority wins, then higher score.
                adj = adjust_score(cnt, use_template, tgt_pos, ext_aa)
                # Discard candidates with adjusted score 1
                if adj == 1:
                    return
                pri = overlap_priority[ov]
                prev = cand_by_aa.get(ext_aa)
                if prev is None or pri > prev[2] or (pri == prev[2] and adj > prev[1]):
                    cand_by_aa[ext_aa] = (kmer, adj, pri)

            for ov in overlap_order:
                if ov > len(current_seq):
                    continue

                if direction == 1:
                    anchor_sub = current_seq[-ov:]
                    ext_idx = -1 - (6 - ov)   # ov=6/5/4 -> -1/-2/-3
                    for kmer, cnt in kmers.items():
                        if not kmer.startswith(anchor_sub):
                            continue
                        if cnt < min_count_by_overlap[ov]:
                            continue
                        if abs(ext_idx) > len(kmer):
                            continue
                        if is_banned_kmer(kmer):
                            continue
                        offer(kmer[ext_idx], kmer, cnt, ov)
                    synthetic_builder = lambda aa, sub=anchor_sub: sub + aa
                else:
                    anchor_sub = current_seq[:ov]
                    for kmer, cnt in kmers.items():
                        if not kmer.endswith(anchor_sub):
                            continue
                        if len(kmer) <= ov:
                            continue
                        if cnt < min_count_by_overlap[ov]:
                            continue
                        if is_banned_kmer(kmer):
                            continue
                        ext_idx = len(kmer) - ov - 1  # 6/5/4 -> 0/1/2
                        offer(kmer[ext_idx], kmer, cnt, ov)
                    synthetic_builder = lambda aa, sub=anchor_sub: aa + sub

                # ★ At overlap 6, add the template residue as a synthetic count-1 candidate
                #   when no candidate (up to I/L) already proposes it.
                if ov == 6 and use_template:
                    templ_aa = template_sequence[tgt_pos]
                    cand_keys_norm = {il_key(k) for k in cand_by_aa}
                    if il_key(templ_aa) not in cand_keys_norm:
                        synthetic_kmer = synthetic_builder(templ_aa)
                        if not is_banned_kmer(synthetic_kmer):
                            offer(templ_aa, synthetic_kmer, 1, 6)

            pot = [(k, s, aa) for aa, (k, s, _) in cand_by_aa.items()]
            pot.sort(key=lambda x: -x[1])
            return pot[:beam_width]

        frontier: List[Tuple[str, int]] = [(start_kmer, 0)]

        for _ in range(max_iterations):
            new_frontier: List[Tuple[str, int]] = []
            for seq, sc in frontier:
                # Template coordinate of the residue this step adds to `seq`
                added_pos = next_ext_pos(len(seq))
                for _kmer, ext_sc, aa in expand_fn(seq):
                    new_seq = (seq + aa) if direction == 1 else (aa + seq)
                    new_sc = sc + ext_sc

                    # Sequence-level blacklist
                    if banned7L and len(new_seq) >= 7 and contains_banned7(new_seq):
                        continue

                    # Duplicate 7-mer filter (no I/L normalization here): reject a step whose
                    # newest 7-mer already occurs in the rest of the sequence.
                    if len(new_seq) >= 7:
                        if direction == 1:
                            tail7 = new_seq[-7:]
                            if tail7 in new_seq[:-7]:
                                continue
                        else:
                            head7 = new_seq[:7]
                            if head7 in new_seq[7:]:
                                continue

                    # 5-mer termination: the first hit is returned as the only result
                    if stop_sequence is not None and len(new_seq) >= 5:
                        if (direction == 1 and new_seq[-5:] == stop_sequence) or \
                           (direction == -1 and new_seq[:5] == stop_sequence):
                            return {new_seq: new_sc}

                    # 7-mer distance guard (within template & at non-protected sites).
                    # The window is aligned on the residue just added, so the newest
                    # 7 residues are compared with the template 7-mer they overlay.
                    if distance_guard == 1:
                        tgt_pos = added_pos
                        if (0 <= tgt_pos < n_template) and (tgt_pos not in protected_idx) and len(new_seq) >= 7:
                            if direction == 1:
                                q7 = new_seq[-7:]
                                t_start = max(0, tgt_pos - 6)
                                t_end = tgt_pos + 1
                            else:
                                q7 = new_seq[:7]
                                t_start = tgt_pos
                                t_end = min(n_template, tgt_pos + 7)
                            templ7 = template_sequence[t_start:t_end]
                            # Windows truncated by a template edge are not checked
                            if len(templ7) == 7 and levenshtein(q7, templ7) >= dist_threshold:
                                continue

                    new_frontier.append((new_seq, new_sc))

            if not new_frontier:
                break
            frontier = heapq.nlargest(beam_width, new_frontier, key=lambda x: x[1])

        best = heapq.nlargest(top_n, frontier, key=lambda x: x[1])
        return {seq: score for (seq, score) in best}

    # -------------- multi-anchor aggregation --------------
    aggregated: Dict[str, int] = {}
    for p in positions:
        sub = search_from_pos(p)
        for s, v in sub.items():
            if (s not in aggregated) or (v > aggregated[s]):
                aggregated[s] = v

    if not aggregated:
        return {}
    return dict(sorted(aggregated.items(), key=lambda x: -x[1])[:top_n])


In [4]:
import pandas as pd
import Levenshtein
import time

def suffix_prefix_intersect(s1: str, s2: str) -> str:
    """Return the longest string that is both a suffix of s1 and a prefix of s2.

    An empty string is returned when the two do not overlap at all.
    """
    # Try the longest possible overlap first and shrink until one fits.
    for size in range(min(len(s1), len(s2)), 0, -1):
        head = s2[:size]
        if s1.endswith(head):
            return head
    return ""

def normalize(seq: str) -> str:
    """Map every uppercase 'I' to 'L' so the two residues compare as equal.

    Non-string input is converted with str() first; lowercase letters are left as they are.
    """
    text = seq if isinstance(seq, str) else str(seq)
    return text.translate({ord('I'): 'L'})

def get_candidates(query_sequence, kmer_set, max_distance=3, max_overlap=3, conserved_positions=None):
    """
    Collect the k-mers of `kmer_set` that are close enough to `query_sequence`.

    A k-mer is kept when, after I/L normalization of both strings:
      - it equals the query at every index listed in `conserved_positions`
        (no positional constraint when this is None or empty),
      - its Levenshtein distance to the query is <= `max_distance`, and
      - `max_overlap_count(query, kmer)` is <= `max_overlap`.

    Returns a list of (kmer, count, distance) tuples ordered by count, highest first;
    k-mers with equal counts keep the iteration order of `kmer_set`.
    """
    query_norm = normalize(query_sequence)
    required = conserved_positions or ()

    hits = []
    for kmer, count in kmer_set.items():
        kmer_norm = normalize(kmer)

        # Conserved sites are checked first; one mismatch rules the k-mer out.
        if any(query_norm[site] != kmer_norm[site] for site in required):
            continue

        distance = Levenshtein.distance(query_norm, kmer_norm)
        if distance > max_distance:
            continue
        if max_overlap_count(query_norm, kmer_norm) > max_overlap:
            continue
        hits.append((kmer, count, distance))

    # Stable sort: highest count first
    hits.sort(key=lambda hit: hit[1], reverse=True)
    return hits

def max_overlap_count(seq1, seq2):
    """
    Return the largest number of position-wise identical characters obtained when one
    sequence is slid against the other by at least one position (the unshifted
    alignment is not considered). Both inputs are I/L-normalized first.
    """
    first = normalize(seq1)
    second = normalize(seq2)

    def identical_at(shifted, fixed, offset):
        # Matches between shifted[offset:] and the start of fixed, over their common length
        return sum(1 for x, y in zip(shifted[offset:], fixed) if x == y)

    best = 0
    # Slide the first sequence over the second ...
    for offset in range(1, len(first)):
        best = max(best, identical_at(first, second, offset))
    # ... then the second over the first.
    for offset in range(1, len(second)):
        best = max(best, identical_at(second, first, offset))
    return best

def get_seed(query_sequence, kmer_set, max_distance=3, max_overlap=3, conserved_positions=None):
    """
    Pick the highest-scoring candidate k-mer for `query_sequence`.

    Candidates come from `get_candidates` with the same arguments; each one is scored
    as count * (2 ** (3 - distance)). The result is the two-element list
    [kmer, score] of the first candidate reaching the maximum score, or None when
    there is no candidate.
    """
    ranked = get_candidates(query_sequence, kmer_set, max_distance, max_overlap, conserved_positions)
    if not ranked:
        return None

    scored = [(kmer, count * (2 ** (3 - distance))) for kmer, count, distance in ranked]

    winner = None
    top_score = float("-inf")
    for kmer, weight in scored:
        # Strict comparison: on ties the earlier candidate is kept
        if weight > top_score:
            winner, top_score = [kmer, weight], weight
    return winner


In [5]:
def get_modified_template(template: str, target: str, placeholder: str = 'X') -> str:
    """
    Make `template` as long as `target` using a single contiguous insertion or deletion.

    The position of the indel is the first one that minimizes the plain Levenshtein
    distance (unit costs, exact character equality) between the two strings once the
    length difference has been removed.

    - `target` longer:  `placeholder` repeated n times is inserted into `template`.
    - `target` shorter: one block of n consecutive characters is removed from `template`.
    - equal lengths:    `template` is returned unchanged (mismatches are not repaired).
    """

    def edit_distance(x: str, y: str) -> int:
        # Row-by-row Levenshtein; the shorter string indexes the columns.
        if len(x) < len(y):
            x, y = y, x
        if not y:
            return len(x)
        row = list(range(len(y) + 1))
        for r, cx in enumerate(x, 1):
            nxt = [r]
            for c, cy in enumerate(y, 1):
                nxt.append(min(
                    row[c] + 1,                          # drop cx
                    nxt[c - 1] + 1,                      # add cy
                    row[c - 1] + (0 if cx == cy else 1)  # match / substitute
                ))
            row = nxt
        return row[-1]

    def without_block(text: str, start: int, width: int) -> str:
        # `text` with `width` characters removed starting at `start`
        return text[:start] + text[start + width:]

    gap = len(target) - len(template)
    if gap == 0:
        return template

    if gap > 0:
        # Find where cutting `gap` characters out of target best reproduces template;
        # the placeholders go at that same offset. min() keeps the first best offset.
        offsets = range(len(template) + 1)
        cut = min(offsets, key=lambda i: edit_distance(without_block(target, i, gap), template))
        return template[:cut] + (placeholder * gap) + template[cut:]

    # Template is too long: remove the block whose deletion best reproduces target.
    width = -gap
    offsets = range(len(target) + 1)
    cut = min(offsets, key=lambda i: edit_distance(without_block(template, i, width), target))
    return without_block(template, cut, width)


In [6]:
def replace_cdr_in_template(
    template: Dict[str, Dict[str, Tuple[str, List[str]]]],
    new_cdr_seq: str,
    chain: str = "light",   # "light" or "heavy"
    region: str = "V",      # "V" or "J"
    cdr_tag: str = "CDR1"   # "CDR1" / "CDR2" / "CDR3"
) -> Dict[str, Dict[str, Tuple[str, List[str]]]]:
    """
    Swap the `cdr_tag` stretch of template[chain][region] for `new_cdr_seq`.

    Whitespace is stripped from `new_cdr_seq` and it is upper-cased before use. The
    replaced span runs from the start of the first `cdr_tag` block to the (inclusive)
    end of the last one, so several separated blocks are replaced as one interval.
    Every inserted residue is labelled `cdr_tag`; no cdr_info is recomputed.

    The entry template[chain][region] is overwritten in place and the same `template`
    object is returned.

    Raises:
        KeyError:   chain/region is not present in `template`.
        ValueError: sequence and label lengths differ, or `cdr_tag` does not occur.
    """
    # Locate the segment
    if not (chain in template and region in template[chain]):
        raise KeyError(f"{chain}.{region} segment not found in template")
    seq, labels = template[chain][region]
    if len(seq) != len(labels):
        raise ValueError(f"{chain}.{region} sequence/label length mismatch: len(seq)={len(seq)} len(labels)={len(labels)}")

    # Span covered by the tag (block boundaries come from _contiguous_ranges)
    blocks = _contiguous_ranges(labels, cdr_tag)
    if not blocks:
        raise ValueError(f"{cdr_tag} not found in {chain}.{region} segment")
    first = blocks[0][0]
    last = blocks[-1][1]

    # Clean the replacement, then splice sequence and labels around the span
    replacement = re.sub(r"\s+", "", new_cdr_seq).upper()
    head_seq, tail_seq = seq[:first], seq[last + 1:]
    head_lab, tail_lab = labels[:first], labels[last + 1:]
    template[chain][region] = (
        head_seq + replacement + tail_seq,
        head_lab + [cdr_tag] * len(replacement) + tail_lab,
    )
    return template


In [7]:
# ====== Save results to files ======
import csv

def write_fasta(path: str, header: str, seq: str) -> None:
    """Write a one-record FASTA file: a '>header' line followed by the unwrapped sequence.

    Trailing newlines of `seq` are collapsed so the file ends with exactly one.
    """
    with open(path, "w", encoding="utf-8") as handle:
        print(f">{header}", file=handle)
        print(seq.rstrip("\n"), file=handle)

def write_labels_csv(path: str, sequence: str, labels: list) -> None:
    """Write one CSV row per residue with the columns pos (1-based), aa_template, label.

    `sequence` and `labels` must have the same length; this is checked with an assert
    before the file is opened.
    """
    assert len(sequence) == len(labels), "template and label lengths do not match"
    with open(path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(["pos", "aa_template", "label"])
        writer.writerows(
            [pos, residue, label]
            for pos, (residue, label) in enumerate(zip(sequence, labels), start=1)
        )


In [8]:
from pathlib import Path
# ===================== IV. Usage example (main) =====================
if __name__ == "__main__":
    # A) From tm.fasta, obtain the highest-scoring template name for each type
    fasta_file = r"/data/stitch-v1.4.0-windows/results/SA58/report-monoclonal-tm.fasta"
    Path("/results/SA58/Levenshtein").mkdir(parents=True, exist_ok=True)
    output_csv = r"/results/SA58/Levenshtein/template_search.csv"
    df = extract_template_info_from_fasta(fasta_file, output_csv_path=output_csv)

    # Pick the max score within each group.
    # Headers without a parsable "score:" field carry a missing score. They are dropped
    # first so that a type whose templates are all unscored is reported by the explicit
    # "Missing ..." check below instead of failing inside idxmax / .loc.
    scored = df.dropna(subset=["score"]) if "score" in df.columns else df
    if scored.empty:
        best_template_name = {}
    else:
        best_template_name = (scored.loc[scored.groupby("type")["score"].idxmax()]
                                    .set_index("type")["template"]
                                    .to_dict())

    # B) Path to the configuration file
    cfg_path = r"/data/stitch-v1.4.0-windows/batchfiles/SA58.txt"

    # C) Compose six IDs using the best template names
    required = ["HV","HJ","HC","LV","LJ","LC"]
    miss = [k for k in required if k not in best_template_name or not best_template_name[k]]
    if miss:
        raise ValueError(f"Missing in best_template_name extracted from {fasta_file}: {miss}")

    target_ids = {
        "heavy": {"V": best_template_name["HV"], "J": best_template_name["HJ"], "C": best_template_name["HC"]},
        "light": {"V": best_template_name["LV"], "J": best_template_name["LJ"], "C": best_template_name["LC"]},
    }

    # D) Read and parse the six segments
    template, runname = get_six_segments(cfg_path, target_ids)
    print("Runname:", runname)
    for chain in ("heavy", "light"):
        for region in ("V", "J", "C"):
            seq, labels = template[chain][region]
            print(f"{chain}.{region}: {target_ids[chain][region]} | len={len(seq)} | preview={seq}")

    # E) Extract CDR1/2/3 and the preceding/following k=7 residues (sequence + labels) for heavy/light chains
    cdr_info = extract_cdr_and_flanks(template, k=7)
    _pretty_print_cdr(cdr_info, k=7)


✅ Saved to: /results/SA58/Levenshtein/template_search.csv
Runname: SA58
heavy.V: IGHV7-4-1 | len=98 | preview=QVQLVQSGSELKKPGASVKVSCKASGYTFTSYAMNWVRQAPGQGLEWMGWINTNTGNPTYAQGFTGRFVFSLDTSVSTAYLQICSLKAEDTAVYYCAR
heavy.J: IGHJ4 | len=15 | preview=YFDYWGQGTLVTVSS
heavy.C: IGHG1 | len=330 | preview=ASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEPKSCDKTHTCPPCPAPELLGGPSVFLFPPKPKDTLMISRTPEVTCVVVDVSHEDPEVKFNWYVDGVEVHNAKTKPREEQYNSTYRVVSVLTVLHQDWLNGKEYKCKVSNKALPAPIEKTISKAKGQPREPQVYTLPPSRDELTKNQVSLTCLVKGFYPSDIAVEWESNGQPENNYKTTPPVLDSDGSFFLYSKLTVDKSRWQQGNVFSCSVMHEALHNHYTQKSLSLSPGK
light.V: IGKV3-15 | len=95 | preview=EIVMTQSPATLSVSPGERATLSCRASQSVSSNLAWYQQKPGQAPRLLIYGASTRATGIPARFSGSGSGTEFTLTISSLQSEDFAVYYCQQYNNWP
light.J: IGKJ4 | len=12 | preview=LTFGGGTKVEIK
light.C: IGKC | len=107 | preview=RTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC

=== HEAVY ===
- CDR1:
  ranges: [(25, 32)]
  sequenc

In [9]:
# Accumulators for the assembled sequences; both start empty.
HC_assembly = ''
LC_assembly = ''
file_path = '/data/Fusion/dataset/mAbs/human/SA58/casanovo/50-cleaned/kmer_50_Casanovo.csv'
out_path = '/results/SA58/Levenshtein'
# kmer_set is consumed below as a mapping kmer -> count (get_candidates iterates kmer_set.items()).
kmer_set = read_kmer_set_from_csv(file_path)

# Light chain
template_sequence, template_labels = template['light']['V']
# Indices of all template positions whose label is exactly 'Conserved'.
conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']
# Start seed: best-scoring k-mer within edit distance 2 of the first 7 template residues.
# NOTE(review): get_seed returns None when no candidate qualifies; the trailing [0]
# then raises TypeError.
start_query_sequence = template_sequence[0:7]
start_kmer = get_seed(start_query_sequence, kmer_set, max_distance=2, max_overlap=5, conserved_positions=None)[0]
start_kmer_template = template_sequence[0:7]
# End target: the 7 template residues that end at the first 'Conserved' position.
# conserved_positions=[6] requires candidates to match that last residue (I/L-normalized).
end_query_sequence = template_sequence[(conser_position[0]-6):(conser_position[0]+1)]
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])

# Constant-region sequences of both chains, passed to beam_search as ban_sequences.
ban_seqs = [template['light']['C'][0], template['heavy']['C'][0]]
if seed is None:
    # No end seed: search forward without a stop sequence and accept the first returned
    # sequence only if its last residue equals the template residue at the first
    # 'Conserved' position.
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[0]-6, top_n = 5, direction=1,distance_guard=1, ban_sequences=ban_seqs, template_weight_mode=1)
    # First key of the returned dict; raises StopIteration when the dict is empty.
    seq = next(iter(result))
    if seq[-1] == template_sequence[conser_position[0]]:
        LC_assembly = seq
    else:
        print('check the template information and kmers')  
else:
    # End seed available: it is handed to the forward search as stop_sequence.
    end_kmer = seed[0] 
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[0]-6, top_n = 5, stop_sequence=end_kmer,direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
    seq = next(iter(result))
    LC_assembly = seq


# CDR1 stays None until one of the two strategies below actually assembles it, so a
# failed attempt can neither hit an unbound name nor silently reuse a value left over
# from an earlier execution of the notebook.
CDR1 = None
try:
    # Strategy 1: seed on the 7 residues that follow CDR1 and extend in reverse
    # (direction=-1), stopping at the last two residues of LC_assembly.
    query_sequence = cdr_info['light']['CDR1']['following7_seq']
    next_start_kmer_template = cdr_info['light']['CDR1']['following7_seq']
    start_kmer = get_seed(query_sequence, kmer_set, max_distance=2, max_overlap=5, conserved_positions=[2])[0]
    idx = cdr_info['light']['CDR1']['following7_range'][1]-6
    result = beam_search(start_kmer, query_sequence, kmer_set, template_sequence, template_labels, beam_width=1, max_iterations=idx, top_n = 5, direction=-1, stop_sequence=LC_assembly[-2:], distance_guard=-1, ban_sequences=ban_seqs, template_weight_mode=-1)
    seq = next(iter(result))
    # The reverse walk is accepted only when it overlaps LC_assembly by exactly 2 residues.
    Len = len(suffix_prefix_intersect(LC_assembly, seq))
    if Len==2:
        LC_assembly = LC_assembly+seq[Len:]
        CDR1 = seq[Len+3:len(seq)-7]
except Exception:
    # Strategy 2 (any ordinary failure above, e.g. no seed or an empty result):
    # extend forward from the current tail of LC_assembly instead.
    # `except Exception` lets KeyboardInterrupt / SystemExit propagate.
    start_kmer = LC_assembly[-7:]
    start_kmer_template = template_sequence[(conser_position[0]-6):(conser_position[0]+1)]
    end_query_sequence = template_sequence[conser_position[1]:(conser_position[1]+7)]
    next_start_kmer_template = template_sequence[conser_position[1]:(conser_position[1]+7)]
    try:
        end_kmer = get_seed(end_query_sequence, kmer_set, max_distance=2, max_overlap=5, conserved_positions=[2])[0]
        idx = cdr_info['light']['CDR1']['following7_range'][1]-6
        # NOTE(review): the template anchor passed here is query_sequence, while the
        # start_kmer_template computed just above is never used - confirm which one
        # is intended before changing it.
        result = beam_search(start_kmer, query_sequence, kmer_set, template_sequence, template_labels, beam_width=1, max_iterations=idx, top_n = 5, direction=1, stop_sequence=end_kmer[:2], distance_guard=-1, ban_sequences=ban_seqs, template_weight_mode=-1)
        seq = next(iter(result))
        LC_assembly = LC_assembly+seq[7:-2]+end_kmer
        CDR1 = seq[10:-4]
    except Exception:
        print('check the template information and kmers')

if CDR1 is None:
    # Neither strategy produced a CDR1 (this includes an overlap different from 2 in
    # strategy 1); continuing would only fail later or use stale data.
    raise RuntimeError("Light-chain CDR1 could not be assembled; check the template information and kmers")

# When the assembled CDR1 length differs from the template CDR1, resize the template
# CDR1 with a single indel and refresh the derived template variables.
if len(CDR1) != len(cdr_info['light']['CDR1']['sequence']):
    new_cdr1_LV = get_modified_template(cdr_info['light']['CDR1']['sequence'],CDR1)
    template = replace_cdr_in_template(template, new_cdr1_LV, chain="light", region="V", cdr_tag="CDR1")
    template_sequence, template_labels = template['light']['V']
    conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']


# Forward extension from the current tail of LC_assembly towards the 7-residue window
# that ends at the third 'Conserved' position (conser_position[2]).
# NOTE(review): get_seed may return None, in which case seed[0] raises TypeError.
start_kmer = LC_assembly[-7:]
start_kmer_template = next_start_kmer_template[-7:]
end_query_sequence = template_sequence[(conser_position[2]-6):(conser_position[2]+1)]
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])
end_kmer = seed[0] 
# The stop sequence is the end seed without its first two residues.
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[2]-len(LC_assembly)+1, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
# First key of the returned dict; seq[7:] drops the 7-residue start k-mer that is
# already the tail of LC_assembly.
seq = next(iter(result))
LC_assembly = LC_assembly+seq[7:]


# Next forward extension: anchored on the window ending at conser_position[2], aiming
# at a seed for the 7 residues that precede CDR3; the iteration budget is derived
# from conser_position[3].
start_kmer = LC_assembly[-7:]
start_kmer_template = template_sequence[(conser_position[2]-6):(conser_position[2]+1)]
end_query_sequence = cdr_info['light']['CDR3']['preceding7_seq']
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])
end_kmer = seed[0] 
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[3]-len(LC_assembly)+1, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
LC_assembly = LC_assembly+seq[7:]

# Extend forward across CDR3: anchored on the 7 residues preceding CDR3 and stopped by
# a seed found for the 7 residues following it.
# NOTE(review): get_seed may return None, in which case seed[0] raises TypeError.
start_kmer = LC_assembly[-7:]
start_kmer_template = cdr_info['light']['CDR3']['preceding7_seq']
end_query_sequence = cdr_info['light']['CDR3']['following7_seq']
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[0,1,3])
end_kmer = seed[0]
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(LC_assembly)+30, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))

# CDR3 stays None unless the end seed is really contained in the result, so a failed
# extension can neither hit an unbound name nor reuse a stale CDR3 from an earlier run.
CDR3 = None
if end_kmer in seq:
    Len = len(seq) - len(end_kmer) - len(start_kmer)
    # Everything between the 7-residue start k-mer and the last 7 residues
    CDR3 = seq[7:len(seq)-7]
    LC_assembly = LC_assembly+seq[7:]

if CDR3 is None:
    raise RuntimeError("Light-chain CDR3 could not be assembled; check the template information and kmers")

if len(CDR3) != len(cdr_info['light']['CDR3']['sequence']):
    # Length changed: resize the template CDR3 with a single indel, refresh the derived
    # variables, and build the V_J template from V plus the J residues that are not
    # labelled 'CDR3'.
    new_cdr3_LV = get_modified_template(cdr_info['light']['CDR3']['sequence'],CDR3)
    template = replace_cdr_in_template(template, new_cdr3_LV, chain="light", region="V", cdr_tag="CDR3")
    template_sequence, template_labels = template['light']['V']
    conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']
    seq, labels = template['light']['J']
    seq_no_cdr3 = ''.join(aa for aa, lab in zip(seq, labels) if lab != 'CDR3')
    label_no_cdr3 = [lab for lab in labels if lab != 'CDR3']
    template['light']['V_J'] = ( template['light']['V'][0]+seq_no_cdr3, template['light']['V'][1]+label_no_cdr3)
else:
    # Same length: V_J is the plain concatenation of the V and J segments.
    template['light']['V_J'] = ( template['light']['V'][0]+template['light']['J'][0], template['light']['V'][1]+template['light']['J'][1])


template_sequence, template_labels = template['light']['V_J']
start_kmer = LC_assembly[-7:]
start_kmer_template = cdr_info['light']['CDR3']['following7_seq']
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(template_sequence) - len(LC_assembly), top_n = 5,direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
LC_assembly = LC_assembly+seq[7:]

start_kmer = LC_assembly[-7:]
start_kmer_template = template['light']['J'][0][-7:]
template_sequence = template['light']['V_J'][0] + template['light']['C'][0]
template_labels = template['light']['V_J'][1] + template['light']['C'][1]
result = beam_search_C(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(template['light']['C'][0]), top_n = 5,direction=1,distance_guard=1, template_weight_mode=1)
seq = next(iter(result))
if len(seq)-7 == len(template['light']['C'][0]):
    LC_assembly = LC_assembly+seq[7:]
    LC_template = template['light']['V_J'][0] + template['light']['C'][0]
    LC_label = template['light']['V_J'][1] + template['light']['C'][1]

if len(LC_assembly) != len(LC_template):
    raise ValueError(
        f"Length mismatch: LC_assembly({len(LC_assembly)}) != LC_template({len(LC_template)}). "
    )

lc_list = list(LC_assembly)
for i, (aa_asm, aa_tpl) in enumerate(zip(LC_assembly, LC_template)):
    if aa_tpl == 'I' and aa_asm == 'L':
        lc_list[i] = 'I'
LC_assembly = ''.join(lc_list)

print(LC_assembly)
print(len(LC_assembly))

# 1) LC_assembly -> FASTA
assembly_fasta = os.path.join(out_path, "LC_assembly.fasta")
write_fasta(assembly_fasta, "LC", LC_assembly)

# 2) LC_template -> FASTA
template_fasta = os.path.join(out_path, "LC_template.fasta")
write_fasta(template_fasta, "LC_template", LC_template)

# 3) LC_label -> CSV (with position and template amino acid)
label_csv = os.path.join(out_path, "LC_label.csv")
write_labels_csv(label_csv, LC_assembly, LC_label)

print("Saved:")
print(" -", assembly_fasta)
print(" -", template_fasta)
print(" -", label_csv)


EVVMTQSPASLSVSPGERATLSCRARASLGLSTDLAWYQQRPGQAPRLLIYGASTRATGIPARFSGSGSGTEFTLTISSLQSEDSAVYYCQQYSNWPLTFGGGTKVEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDSALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC
216
Saved:
 - /results/SA58/Levenshtein/LC_assembly.fasta
 - /results/SA58/Levenshtein/LC_template.fasta
 - /results/SA58/Levenshtein/LC_label.csv


In [10]:
# Heavy chain
# Step 1: start the heavy-chain assembly from the V-region template and extend it
# up to the first template position labelled 'Conserved'.
template_sequence, template_labels = template['heavy']['V']
# Indices of every template position whose label is 'Conserved'.
conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']
# The first 7 template residues are used both as the seed query and as the template k-mer.
start_query_sequence = template_sequence[0:7]
# NOTE(review): get_seed can return None (see the `seed is None` check below); in that
# case this `[0]` raises TypeError — confirm a start seed is always expected to exist.
start_kmer = get_seed(start_query_sequence, kmer_set, max_distance=2, max_overlap=5, conserved_positions=None)[0]
start_kmer_template = template_sequence[0:7]
# 7-residue window that ends on the first conserved position (index 6 of the window).
end_query_sequence = template_sequence[(conser_position[0]-6):(conser_position[0]+1)]
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])

# Constant-region sequences of both chains are passed to beam_search as `ban_sequences`.
ban_seqs = [template['light']['C'][0], template['heavy']['C'][0]]
if seed is None:
    # No end seed: search without a stop sequence and accept the result only when its
    # last residue equals the template residue at the first conserved position.
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[0]-6, top_n = 5, direction=1,distance_guard=1, ban_sequences=ban_seqs, template_weight_mode=1)
    seq = next(iter(result))
    if seq[-1] == template_sequence[conser_position[0]]:
        HC_assembly = seq
    else:
        # NOTE(review): HC_assembly is left unassigned on this path, so later statements
        # fail with NameError (or reuse a value from a previous run of this cell).
        print('check the template information and kmers')  
else:
    # End seed found: stop the search on the seed k-mer without its first two residues.
    end_kmer = seed[0] 
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[0]-6, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
    # The first item yielded by iterating the result is taken as the chosen sequence.
    seq = next(iter(result))
    HC_assembly = seq

# Step 2: extend the assembly across CDR1.
# Preferred route: seed on the 7-mer that follows CDR1 and search backwards
# (direction=-1) until the last two residues of the current assembly are reached.
try:
    query_sequence = cdr_info['heavy']['CDR1']['following7_seq']
    next_start_kmer_template = cdr_info['heavy']['CDR1']['following7_seq']
    start_kmer = get_seed(query_sequence, kmer_set, max_distance=2, max_overlap=5, conserved_positions=[2])[0]
    idx = cdr_info['heavy']['CDR1']['following7_range'][1]-6
    result = beam_search(start_kmer, query_sequence, kmer_set, template_sequence, template_labels, beam_width=1, max_iterations=idx, top_n = 5, direction=-1, stop_sequence=HC_assembly[-2:], distance_guard=-1, ban_sequences=ban_seqs, template_weight_mode=-1)
    seq = next(iter(result))
    # Overlap between the end of the assembly and the start of the backward result.
    Len = len(suffix_prefix_intersect(HC_assembly, seq))
    if Len==2:
        HC_assembly = HC_assembly+seq[Len:]
        CDR1 = seq[Len+3:len(seq)-7]
    # NOTE(review): when Len != 2 nothing is assigned and no fallback runs, so CDR1 below
    # is either undefined or a value left over from an earlier cell — confirm intent.
except Exception:
    # `except Exception` (not a bare `except`) so that KeyboardInterrupt/SystemExit still
    # propagate; lookup/seed/search failures (e.g. get_seed returning None) land here.
    # Fallback route: search forwards from the end of the assembly towards the 7-mer that
    # starts at the second conserved position.
    start_kmer = HC_assembly[-7:]
    start_kmer_template = template_sequence[(conser_position[0]-6):(conser_position[0]+1)]
    end_query_sequence = template_sequence[conser_position[1]:(conser_position[1]+7)]
    next_start_kmer_template = template_sequence[conser_position[1]:(conser_position[1]+7)]
    try:
        end_kmer = get_seed(end_query_sequence, kmer_set, max_distance=2, max_overlap=5, conserved_positions=[2])[0]
        idx = cdr_info['heavy']['CDR1']['following7_range'][1]-6
        # NOTE(review): `query_sequence` (not the `start_kmer_template` computed above) is
        # passed as the template k-mer here — looks unintended, verify before changing.
        result = beam_search(start_kmer, query_sequence, kmer_set, template_sequence, template_labels, beam_width=1, max_iterations=idx, top_n = 5, direction=1, stop_sequence=end_kmer[:2], distance_guard=-1, ban_sequences=ban_seqs, template_weight_mode=-1)
        seq = next(iter(result))
        HC_assembly = HC_assembly+seq[7:-2]+end_kmer
        CDR1 = seq[10:-4]
    except Exception:
        print('check the template information and kmers')

# If the assembled CDR1 length differs from the template's, rebuild the V template with a
# CDR1 of the assembled length and refresh everything derived from it.
if len(CDR1) != len(cdr_info['heavy']['CDR1']['sequence']):
    new_cdr1_HV = get_modified_template(cdr_info['heavy']['CDR1']['sequence'],CDR1)  
    template = replace_cdr_in_template(template, new_cdr1_HV, chain="heavy", region="V", cdr_tag="CDR1")
    template_sequence, template_labels = template['heavy']['V']
    conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']

# Step 3: extend the assembly across CDR2 up to the third conserved position.
# Look for seeds for the 7-mers immediately before and after CDR2.
query_sequence1 = cdr_info['heavy']['CDR2']['preceding7_seq']
seed1 = get_seed(query_sequence1, kmer_set, max_distance=2, max_overlap=5, conserved_positions=None)
query_sequence2 = cdr_info['heavy']['CDR2']['following7_seq']
seed2 = get_seed(query_sequence2, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=None)
# Branch A is taken only when both flanking seeds exist and equal the template 7-mers
# after replacing I with L (the comparison is done on the I->L substituted query).
if seed1 and seed2 and seed1[0]==query_sequence1.replace('I','L') and seed2[0]==query_sequence2.replace('I','L'):
    # A1: from the end of the current assembly to the 7-mer preceding CDR2.
    # NOTE(review): the template k-mer used here is the last 7 residues of the V template —
    # confirm this is intended rather than the 7-mer matching HC_assembly[-7:].
    start_kmer_template = template_sequence[-7:]
    start_kmer = HC_assembly[-7:]
    end_query_sequence = cdr_info['heavy']['CDR2']['preceding7_seq']
    end_kmer = get_seed(end_query_sequence, kmer_set, max_distance=2, max_overlap=5)[0] 
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=100, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
    seq = next(iter(result))
    # The first 7 residues of the result repeat the start k-mer, so only the rest is appended.
    HC_assembly = HC_assembly+seq[7:]

    # A2: across CDR2 itself, from the preceding 7-mer to the following 7-mer.
    start_query_sequence = cdr_info['heavy']['CDR2']['preceding7_seq']
    start_kmer = get_seed(start_query_sequence, kmer_set, max_distance=2, max_overlap=5, conserved_positions=None)[0]
    start_kmer_template = start_query_sequence
    end_query_sequence = cdr_info['heavy']['CDR2']['following7_seq']
    end_kmer = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=None)[0] 
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=100, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
    seq = next(iter(result))
    HC_assembly = HC_assembly+seq[7:]
    # CDR2 is the result with 7 residues removed from each end.
    CDR2 = seq[7:-7]
    # On a CDR2 length change, rebuild the V template and the values derived from it.
    if len(CDR2) != len(cdr_info['heavy']['CDR2']['sequence']):
        new_cdr2_HV = get_modified_template(cdr_info['heavy']['CDR2']['sequence'],CDR2)  
        template = replace_cdr_in_template(template, new_cdr2_HV, chain="heavy", region="V", cdr_tag="CDR2")
        template_sequence, template_labels = template['heavy']['V']
        conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']


    # A3: from the 7-mer following CDR2 to the window ending on the third conserved position.
    start_kmer = HC_assembly[-7:]
    start_kmer_template = cdr_info['heavy']['CDR2']['following7_seq']
    end_query_sequence = template_sequence[(conser_position[2]-6):(conser_position[2]+1)]
    seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])
    # NOTE(review): unlike branch B, `seed` is not checked for None before indexing.
    end_kmer = seed[0] 
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[2]-len(HC_assembly)+1, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
    seq = next(iter(result))
    HC_assembly = HC_assembly+seq[7:]
else:
    # Branch B: one search from the current assembly end straight to the window ending on
    # the third conserved position; CDR2 is then sliced out of the assembly by coordinates.
    start_kmer_template = next_start_kmer_template[-7:]
    start_kmer = HC_assembly[-7:]
    end_query_sequence = template_sequence[(conser_position[2]-6):(conser_position[2]+1)]
    seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])
    # NOTE(review): when `seed` is falsy nothing in this branch runs, so HC_assembly is not
    # extended and CDR2 is not (re)assigned.
    if seed:
        end_kmer = seed[0] 
        result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=100, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
        seq = next(iter(result))
        HC_assembly = HC_assembly+seq[7:]
        # Slice from the CDR2 start coordinate; the negative stop index drops as many
        # residues as lie beyond the start of `following7_range` in the current assembly.
        CDR2 = HC_assembly[cdr_info['heavy']['CDR2']['ranges'][0][0]:-(len(HC_assembly)-cdr_info['heavy']['CDR2']['following7_range'][0])]
        if len(CDR2) != len(cdr_info['heavy']['CDR2']['sequence']):
            new_cdr2_HV = get_modified_template(cdr_info['heavy']['CDR2']['sequence'],CDR2)  
            template = replace_cdr_in_template(template, new_cdr2_HV, chain="heavy", region="V", cdr_tag="CDR2")
            template_sequence, template_labels = template['heavy']['V']
            conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']


# Step 4: from the third conserved position to the 7-mer preceding CDR3.
start_kmer = HC_assembly[-7:]
start_kmer_template = template_sequence[(conser_position[2]-6):(conser_position[2]+1)]
end_query_sequence = cdr_info['heavy']['CDR3']['preceding7_seq']
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])
# NOTE(review): `seed` is indexed without a None check; a missing seed raises TypeError.
end_kmer = seed[0] 
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[3]-len(HC_assembly)+1, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
# The first 7 residues of the result repeat the start k-mer, so only the rest is appended.
HC_assembly = HC_assembly+seq[7:]

# Step 5: across CDR3, from its preceding 7-mer to its following 7-mer.
start_kmer = HC_assembly[-7:]
start_kmer_template = cdr_info['heavy']['CDR3']['preceding7_seq']
end_query_sequence = cdr_info['heavy']['CDR3']['following7_seq']
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[0,1,3])
end_kmer = seed[0] 
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(HC_assembly)+30, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
# Accept the result only when it contains the complete end k-mer.
# NOTE(review): otherwise HC_assembly is not extended and CDR3 keeps whatever value it had
# before (undefined, or left over from an earlier cell). `Len` is not read again in this cell.
if end_kmer in seq:
    Len = len(seq) - len(end_kmer) - len(start_kmer)
    # CDR3 is the result with 7 residues removed from each end.
    CDR3 = seq[7:len(seq)-7]
    HC_assembly = HC_assembly+seq[7:]

# Build the combined V+J template. When the assembled CDR3 length differs from the
# template's, the V template is rebuilt first and the J residues labelled 'CDR3' are
# dropped before concatenation; otherwise V and J are concatenated unchanged.
if len(CDR3) != len(cdr_info['heavy']['CDR3']['sequence']):
    new_cdr3_HV = get_modified_template(cdr_info['heavy']['CDR3']['sequence'],CDR3)  
    template = replace_cdr_in_template(template, new_cdr3_HV, chain="heavy", region="V", cdr_tag="CDR3")
    template_sequence, template_labels = template['heavy']['V']
    conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']
    seq, labels = template['heavy']['J']
    seq_no_cdr3 = ''.join(aa for aa, lab in zip(seq, labels) if lab != 'CDR3')
    label_no_cdr3 = [lab for lab in labels if lab != 'CDR3']
    template['heavy']['V_J'] = ( template['heavy']['V'][0]+seq_no_cdr3, template['heavy']['V'][1]+label_no_cdr3)
else:
    template['heavy']['V_J'] = ( template['heavy']['V'][0]+template['heavy']['J'][0], template['heavy']['V'][1]+template['heavy']['J'][1])


# Step 6: extend through the rest of the V+J template. No stop sequence is given; the
# iteration budget is the number of template residues not yet covered by the assembly.
template_sequence, template_labels = template['heavy']['V_J']
start_kmer = HC_assembly[-7:]
start_kmer_template = cdr_info['heavy']['CDR3']['following7_seq']
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(template_sequence) - len(HC_assembly), top_n = 5,direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
HC_assembly = HC_assembly+seq[7:]


# Step 7: extend from the end of the J region through the constant (C) region.
start_kmer = HC_assembly[-7:]
start_kmer_template = template['heavy']['J'][0][-7:]
template_sequence = template['heavy']['V_J'][0] + template['heavy']['C'][0]
template_labels = template['heavy']['V_J'][1] + template['heavy']['C'][1]

# The full-length template and labels do not depend on the search result, so they are
# defined unconditionally. Previously they were only assigned when the C-region extension
# succeeded, which led to a NameError (or silent reuse of values from an earlier run of
# this cell) instead of the intended length-mismatch error below.
HC_template = template['heavy']['V_J'][0] + template['heavy']['C'][0]
HC_label = template['heavy']['V_J'][1] + template['heavy']['C'][1]

result = beam_search_C(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(template['heavy']['C'][0]), top_n = 5,direction=1,distance_guard=1, template_weight_mode=1, min_overlap=6)
seq = next(iter(result))
# Accept the extension only when, after dropping the 7-residue start k-mer, it is exactly
# as long as the C-region template.
if len(seq)-7 == len(template['heavy']['C'][0]):
    HC_assembly = HC_assembly+seq[7:]

# The position-wise I/L correction below requires equal lengths.
if len(HC_assembly) != len(HC_template):
    raise ValueError(
        f"Length mismatch: HC_assembly({len(HC_assembly)}) != HC_template({len(HC_template)}). "
        "Abort I/L correction."
    )

# Wherever the template has 'I' and the assembly has 'L', take the template's 'I'.
hc_list = list(HC_assembly)
for i, (aa_asm, aa_tpl) in enumerate(zip(HC_assembly, HC_template)):
    if aa_tpl == 'I' and aa_asm == 'L':
        hc_list[i] = 'I'
HC_assembly = ''.join(hc_list)

print(HC_assembly)
print(len(HC_assembly))

# 1) HC_assembly -> FASTA
assembly_fasta = os.path.join(out_path, "HC_assembly.fasta")
write_fasta(assembly_fasta, "HC", HC_assembly)

# 2) HC_template -> FASTA
template_fasta = os.path.join(out_path, "HC_template.fasta")
write_fasta(template_fasta, "HC_template", HC_template)

# 3) HC_label -> CSV (with position and template amino acid)
label_csv = os.path.join(out_path, "HC_label.csv")
write_labels_csv(label_csv, HC_assembly, HC_label)

print("Saved:")
print(" -", assembly_fasta)
print(" -", template_fasta)
print(" -", label_csv)


QVQLAQSGSELRKPGASVKVSCDTSGHSFTSNALHWVRQAPGQGLEWMGWVNTDTGTPTYAQGFTGRFVFSLDTSARTAYLQISSLKADDTAVFYCARERDYSDYFFDYWGQGTLVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEPKSCDKTHTCPPCPAPELLGGPSVFLFPPKPKDTLMISRTPEVTCVVVDVSHEDPEVKFNWYVDGVEVHNAKTKPREEQYNSTYRVVSVLTVLHQDWLNGKEYKCKVSNKALPAPIEKTISKAKGQPREPQVYTLPPSRDELTKNQVSLTCLVKGFYPSDIAVEWESNGQPENNYKTTPPVLDSDGSFFLYSKLTVDKSRWQQGNVFSCSVMHEALHNHYTQKSLSLSPGK
450
Saved:
 - /results/SA58/Levenshtein/HC_assembly.fasta
 - /results/SA58/Levenshtein/HC_template.fasta
 - /results/SA58/Levenshtein/HC_label.csv


In [None]:
# =========================
# Configuration (edit as needed)
# =========================
# Imported here so this cell does not depend on an earlier cell having imported Path.
from pathlib import Path

BASE_PATH = Path(r"/results/SA58/Levenshtein")
PROCESS1_PATH = Path(r"/data/Fusion/dataset/mAbs/human/SA58/process1")
# Single source of truth for the FASTA location: the file written below is the same file
# that run_pipeline later receives, so the path is derived once from BASE_PATH.
FASTA_PATH = BASE_PATH / "Fusion_Casanovo.fasta"  # Your FASTA file
AB = "SA58"  # Antibody name (used in I/L assignment output filename)

SEARCHGUI_JAR = Path(r"/code/SearchGUI-4.3.15/SearchGUI-4.3.15.jar")
PEPTIDESHAKER_JAR = Path(r"/code/PeptideShaker-3.0.11/PeptideShaker-3.0.11.jar")
PARAMS_PATH = Path(r"/code/SearchGUI-4.3.15/para")  # Directory containing *.par parameter files

# Write FASTA
# Both assembled chains (from the heavy- and light-chain cells) go into one two-record file.
fasta_path = FASTA_PATH  # kept as an alias for any later cell that reads `fasta_path`
with open(fasta_path, 'w', encoding='utf-8') as f:
    f.write(">HC\n")
    f.write(HC_assembly + "\n")
    f.write(">LC\n")
    f.write(LC_assembly + "\n")

# =========================
# Executable script starts here
# =========================
import os
import shutil
import subprocess
import sys
import logging
import pandas as pd
import re
from datetime import datetime
from pyteomics import mgf, mass
import numpy as np
from collections import OrderedDict

def setup_logging(base_path, log_to_file=True, log_level=logging.INFO):
    """Configure the root logger to write to the console and, optionally, a file.

    Every handler already attached to the root logger is removed AND closed first, so
    re-running the cell in Jupyter neither duplicates log lines nor leaks the file
    handle of a previous ``automation_log.txt`` handler.

    Args:
        base_path: Directory in which ``automation_log.txt`` is created.
        log_to_file: When true, also log to ``<base_path>/automation_log.txt``.
        log_level: Level applied to the root logger and to both handlers.

    Raises:
        OSError: If the log file cannot be opened (e.g. ``base_path`` does not exist).
    """
    logger = logging.getLogger()
    logger.setLevel(log_level)
    # Clear existing handlers to avoid duplicate logs; close them so that an earlier
    # FileHandler releases its open file instead of lingering until garbage collection.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    fmt = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Console
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(log_level)
    ch.setFormatter(fmt)
    logger.addHandler(ch)

    # File
    if log_to_file:
        log_file = Path(base_path) / 'automation_log.txt'
        fh = logging.FileHandler(log_file, encoding='utf-8')
        fh.setLevel(log_level)
        fh.setFormatter(fmt)
        logger.addHandler(fh)

    logging.info(f"Script started at {datetime.now()}")

def get_enzyme_from_filename(filename):
    """Return the enzyme whose spelling variant occurs in an .mgf filename.

    Enzymes are tried in a fixed order and the first one with a matching variant
    wins, so a name such as ``Trypsin-ASPN`` resolves to ``'ASPN'``. Matching is
    case-sensitive substring search. A warning is logged and ``None`` is returned
    when no variant is found.
    """
    # (canonical enzyme name, accepted spellings), checked top to bottom.
    lookup = (
        ('ASPN', ('ASPN', 'AspN')),
        ('Chymotrypsin', ('Chymotrypsin', 'chymo')),
        ('Elastase', ('Elastase',)),
        ('Pepsin', ('Pepsin', 'pepsin')),
        ('Trypsin', ('Trypsin', 'trypsin')),
    )
    underscored = filename.replace('-', '_')
    for canonical, spellings in lookup:
        for spelling in spellings:
            # Try the underscore-normalised name first, then the hyphenated spelling
            # against the raw name.
            if spelling in underscored:
                return canonical
            if spelling.replace('_', '-') in filename:
                return canonical
    logging.warning(f"No recognized enzyme in filename: {filename}")
    return None

def copy_fasta_file(fasta_path, ethcd_path):
    """Copy the FASTA file into the EThcD folder as ``Fusion_Casanovo.fasta``.

    Args:
        fasta_path: Source FASTA file (``str`` or ``Path``).
        ethcd_path: Destination directory (``str`` or ``Path``).

    Raises:
        FileNotFoundError: If ``fasta_path`` does not exist.
    """
    fasta_file = Path(fasta_path)
    if not fasta_file.exists():
        logging.error(f"FASTA file not found: {fasta_path}")
        raise FileNotFoundError(f"FASTA file not found: {fasta_path}")

    # Coerce to Path so a plain string destination works, as it does for the source and
    # for the other helpers in this cell that wrap their arguments in Path(...).
    ethcd_fasta = Path(ethcd_path) / 'Fusion_Casanovo.fasta'
    shutil.copy(fasta_file, ethcd_fasta)
    logging.info(f"Copied FASTA file to {ethcd_fasta}")

def create_directory_structure(base_path):
    r"""Ensure ``<base_path>/searchGUI/Fusion/EThcD`` and its per-enzyme subfolders exist.

    Existing directories are left in place. Returns the ``EThcD`` directory path.
    """
    ethcd_path = Path(base_path) / 'searchGUI' / 'Fusion' / 'EThcD'
    fusion_path = ethcd_path.parent

    # Parent first, then the EThcD folder; each one is logged even if it already existed.
    for directory in (fusion_path, ethcd_path):
        directory.mkdir(parents=True, exist_ok=True)
        logging.info(f"Created directory: {directory}")

    # One subfolder per supported enzyme.
    for enzyme_name in ('ASPN', 'Chymotrypsin', 'Elastase', 'Pepsin', 'Trypsin'):
        enzyme_dir = ethcd_path / enzyme_name
        enzyme_dir.mkdir(exist_ok=True)
        logging.info(f"Created enzyme directory: {enzyme_dir}")

    return ethcd_path

def copy_mgf_files(source_dir, ethcd_path):
    """Copy ``*.mgf`` files into the matching ``EThcD/<enzyme>`` subfolders.

    The enzyme is derived from each filename; files whose enzyme is not recognised,
    or whose copy fails, are logged and skipped.

    Args:
        source_dir: Directory that is searched (non-recursively) for ``*.mgf`` files.
        ethcd_path: EThcD directory that contains one subfolder per enzyme.

    Returns:
        True if at least one file was copied, otherwise False. (Previously True was
        returned even when every file had been skipped, so callers went on to run the
        searches with nothing to search.)

    Raises:
        FileNotFoundError: If ``source_dir`` does not exist.
    """
    source_path = Path(source_dir)
    if not source_path.exists():
        logging.error(f"Source directory not found: {source_path}")
        raise FileNotFoundError(f"Source directory not found: {source_path}")

    mgf_files = list(source_path.glob('*.mgf'))
    if not mgf_files:
        logging.warning(f"No .mgf files found in {source_path}")
        return False

    ethcd_root = Path(ethcd_path)
    copied_any = False
    for mgf_file in mgf_files:
        enzyme = get_enzyme_from_filename(mgf_file.name)
        if not enzyme:
            logging.warning(f"No recognized enzyme in filename: {mgf_file.name}")
            continue

        ethcd_dest = ethcd_root / enzyme / mgf_file.name
        try:
            shutil.copy(mgf_file, ethcd_dest)
        except OSError as e:
            # Best effort: one unreadable/unwritable file must not abort the others.
            logging.error(f"Failed to copy {mgf_file} to {ethcd_dest}: {str(e)}")
            continue
        logging.info(f"Copied {mgf_file} to {ethcd_dest}")
        copied_any = True

    return copied_any

def _derive_peptide(modified_seq):
    """Turn a 'Modified Sequence' value into the plain peptide string used downstream.

    I is replaced by L, the ``NH2-`` / ``-COOH`` termini and a trailing ``-deam`` are
    removed, and modified residues are encoded as lower-case letters
    (``Q<deam>`` -> ``q``, ``N<deam>`` -> ``n``, ``M<ox>`` -> ``m``; ``C<cmm>`` -> ``C``).
    Missing values yield an empty string.
    """
    if pd.isna(modified_seq):
        return ''
    tmp = str(modified_seq)
    tmp = tmp.replace('I', 'L')
    tmp = re.sub(r'^NH2-', '', tmp)
    tmp = re.sub(r'-COOH$', '', tmp)
    tmp = tmp.replace('Q<deam>', 'q')
    tmp = tmp.replace('C<cmm>', 'C')
    tmp = tmp.replace('M<ox>', 'm')
    tmp = tmp.replace('N<deam>', 'n')
    tmp = re.sub(r'-deam$', '', tmp)
    return tmp


def merge_psm_reports(base_path):
    """Merge all A_Default_PSM_Report.txt files into a single CSV.

    Every ``A_Default_PSM_Report.txt`` below ``<base_path>/searchGUI/Fusion`` is read as
    tab-separated text, the tables are concatenated, a derived ``Peptide`` column is
    added, and the result is written to ``Default PSM Report All.csv`` in that folder.

    Nothing is written (a warning/error is logged instead) when no report is found, a
    required column is missing, or no usable peptide could be derived. Write failures
    are logged, not raised.
    """
    fusion_path = Path(base_path) / 'searchGUI' / 'Fusion'
    all_csv = fusion_path / 'Default PSM Report All.csv'

    logging.info(f"Searching for PSM reports in {fusion_path}")
    # Sorted so the row order of the merged file does not depend on filesystem order.
    psm_files = sorted(fusion_path.rglob('A_Default_PSM_Report.txt'))
    if not psm_files:
        logging.warning(f"No A_Default_PSM_Report.txt files found in {fusion_path} or its subdirectories")
        return

    logging.info(f"Found {len(psm_files)} PSM report files: {', '.join(str(f) for f in psm_files)}")

    dfs = [pd.read_csv(f, sep='\t', encoding='utf-8') for f in psm_files]
    merged_df = pd.concat(dfs, ignore_index=True)

    required_cols = ['Protein(s)', 'Spectrum Title', 'Modified Sequence']
    if not all(col in merged_df.columns for col in required_cols):
        logging.error(f"Required columns missing in merged data. Available columns: {list(merged_df.columns)}")
        return

    merged_df['Peptide'] = merged_df['Modified Sequence'].apply(_derive_peptide)
    # _derive_peptide maps missing input to '' (never NaN), so "nothing usable" has to be
    # detected as all-empty; the former isna() test could never be true for real rows.
    if merged_df['Peptide'].eq('').all():
        logging.warning("All Peptide data derived as invalid")
        return

    try:
        all_csv.parent.mkdir(parents=True, exist_ok=True)
        merged_df.to_csv(all_csv, index=False, encoding='utf-8')
        logging.info(f"Saved full merged data to {all_csv} with {len(merged_df)} rows and columns {list(merged_df.columns)}")
    except Exception as e:
        logging.error(f"Failed to write full CSV to {all_csv}: {str(e)}")
        if not os.access(all_csv.parent, os.W_OK):
            logging.error(f"No write permission for directory: {all_csv.parent}")

def run_dbsearch(base_path, process1_path, ab, mass_tol=0.02):
    """Perform I/L assignment based on spectra and PSM data.

    Spectra are read from ``<process1_path>/spectrum_<ab>_EThcD.mgf`` and matched by
    title to the rows of ``<base_path>/searchGUI/Fusion/Default PSM Report All.csv``.
    For every 'L' in a matched peptide, peaks within ``mass_tol`` of the z / z-dot
    fragment m/z minus 29.039 (scored as isoleucine) or minus 43.054 (scored as leucine)
    are looked up; the residue is rewritten as 'i' or 'l' when one side scores better.
    Rows with at least one assignment are written, with ``I/L_pep`` and ``I/L_pos``
    columns added, to ``<base_path>/I-L/<ab>_EThcD_wIL_DBsearch_v3_002.csv``.

    Args:
        base_path: Results directory containing ``searchGUI/Fusion``.
        process1_path: Directory containing the EThcD MGF file.
        ab: Antibody name used in the input and output filenames.
        mass_tol: Maximum absolute m/z difference for a peak to count as a match.

    Raises:
        FileNotFoundError: If the MGF file or the merged PSM CSV is missing.
    """
    path = Path(process1_path)
    mgf_file = path / f'spectrum_{ab}_EThcD.mgf'
    psm_file = Path(base_path) / 'searchGUI' / 'Fusion' / 'Default PSM Report All.csv'
    output_dir = Path(base_path) / 'I-L'
    output_file = output_dir / f'{ab}_EThcD_wIL_DBsearch_v3_002.csv'

    # Residue masses handed to pyteomics; lower-case letters are the modified residues
    # produced when the Peptide column was derived.
    mass_AA = {
        "A": 71.03711, "R": 156.10111, "N": 114.04293, "n": 115.02695,
        "D": 115.02694, "C": 160.03065, "E": 129.04259, "Q": 128.05858,
        "q": 129.0426, "G": 57.02146, "H": 137.05891, "I": 113.08406,
        "L": 113.08406, "K": 128.09496, "M": 131.04049, "m": 147.0354,
        "F": 147.06841, "P": 97.05276, "S": 87.03203, "T": 101.04768,
        "W": 186.07931, "Y": 163.06333, "V": 99.06841, "p": 111.032029
    }

    # Read MGF
    if not mgf_file.exists():
        logging.error(f"MGF file not found: {mgf_file}")
        raise FileNotFoundError(f"MGF file not found: {mgf_file}")
    with open(mgf_file, 'r') as file:
        sps = mgf.read(file, convert_arrays=1, read_charges=False, dtype='float32', use_index=False)
        list_of_spectras = list(sps)

    # Read PSM
    if not psm_file.exists():
        logging.error(f"PSM file not found: {psm_file}")
        raise FileNotFoundError(f"PSM file not found: {psm_file}")
    denovo = pd.read_csv(psm_file)
    de_pep = list(denovo['Peptide'])
    de_scan = list(denovo['Spectrum Title'])

    # Title -> first row carrying it (same row list.index() would return), built once so
    # each spectrum costs a dict lookup instead of a scan over all PSM rows.
    first_row = {}
    for row, title in enumerate(de_scan):
        first_row.setdefault(title, row)

    def closest_delta(target, peaks):
        """Smallest |target - peak| below mass_tol, or None when no peak is that close."""
        deltas = [abs(target - x) for x in peaks if abs(target - x) < mass_tol]
        return min(deltas) if deltas else None

    # I/L assignment
    w_scan, w_pos, w_pep = [], [], []

    for sp in list_of_spectras:
        param = sp['params']
        if 'title' not in param:
            continue
        loc = first_row.get(param['title'])
        if loc is None:
            continue
        pep = de_pep[loc]
        # Empty peptides are read back from the CSV as NaN (float); skip anything that is
        # not a string instead of failing on the `in` test.
        if not isinstance(pep, str) or 'L' not in pep:
            continue

        pep_l = list(pep)
        mz = [round(x, 4) for x in sp['m/z array']]
        w_pos_tmp = []
        for i, residue in enumerate(pep_l):
            if residue != 'L':
                continue
            z = round(mass.fast_mass(pep[i:], ion_type='z', charge=1, aa_mass=mass_AA), 4)
            z1 = round(mass.fast_mass(pep[i:], ion_type='z-dot', charge=1, aa_mass=mass_AA), 4)
            # Candidate m/z values in evaluation order: I via z, I via z-dot,
            # L via z, L via z-dot.
            candidates = (
                round(z - 29.039, 4),
                round(z1 - 29.039, 4),
                round(z - 43.054, 4),
                round(z1 - 43.054, 4),
            )
            # Scores default to 1; only the FIRST candidate that has a peak within
            # tolerance is scored (this mirrors the original if/elif chain).
            # NOTE(review): as a consequence an isoleucine match always wins over a
            # leucine match even when the latter is closer — confirm this is intended.
            scores = [1, 1, 1, 1]
            for slot, target in enumerate(candidates):
                delta = closest_delta(target, mz)
                if delta is not None:
                    scores[slot] = delta
                    break
            best_i = min(scores[0], scores[1])
            best_l = min(scores[2], scores[3])
            if best_i < best_l:
                w_pos_tmp.append(i)
                pep_l[i] = 'i'
            elif best_i > best_l:
                w_pos_tmp.append(i)
                pep_l[i] = 'l'
        if w_pos_tmp:
            # Written back so a later spectrum with the same title sees the updated peptide.
            de_pep[loc] = ''.join(pep_l)
            w_pep.append(''.join(pep_l))
            w_scan.append(param['title'])
            w_pos.append(' '.join(map(str, w_pos_tmp)))

    # Write back and export
    loc = [first_row[x] for x in w_scan]
    dat = denovo.iloc[loc].copy()
    dat.loc[:, 'I/L_pep'] = w_pep
    dat.loc[:, 'I/L_pos'] = w_pos

    output_dir.mkdir(parents=True, exist_ok=True)
    dat.to_csv(output_file, index=False)
    logging.info(f"Saved I/L assignment results to {output_file} with {len(dat)} rows")

def run_search_commands(base_path, ethcd_path, searchgui_jar, peptideshaker_jar, params_path):
    """Run SearchGUI and PeptideShaker commands and stream their outputs.

    First a target/decoy database is generated from
    ``<base_path>/searchGUI/Fusion/EThcD/Fusion_Casanovo.fasta`` with FastaCLI. Then, for
    every enzyme folder under ``ethcd_path`` that contains an ``.mgf`` file, SearchCLI
    (OMSSA), PeptideShakerCLI and ReportCLI are run in sequence. Enzyme folders without
    ``.mgf`` files are skipped with a warning.

    Raises:
        FileNotFoundError: If the FASTA, the decoy FASTA or a ``.par`` file is missing.
        subprocess.CalledProcessError: If any Java command exits with a non-zero code.
    """
    import threading  # local import: only needed to drain stderr while stdout is streamed

    enzymes = ['ASPN', 'Chymotrypsin', 'Elastase', 'Pepsin', 'Trypsin']
    base_fasta = Path(base_path) / 'searchGUI' / 'Fusion'
    ethcd_path = Path(ethcd_path)
    params_path = Path(params_path)

    def run_command(cmd, description):
        """Run one command, echoing stdout line by line and stderr at the end."""
        # The context manager closes both pipes and reaps the process on every exit path.
        with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) as process:
            logging.info(f"Started {description}")
            print(f"\nRunning {description}...")

            # stderr is drained on a helper thread: reading stdout to EOF first while the
            # child fills the stderr pipe would block both sides forever.
            stderr_chunks = []

            def drain_stderr():
                stderr_chunks.append(process.stderr.read())

            drainer = threading.Thread(target=drain_stderr, daemon=True)
            drainer.start()

            for stdout_line in process.stdout:
                print(stdout_line.strip())
                logging.info(f"{description} stdout: {stdout_line.strip()}")

            drainer.join()
            stderr_text = ''.join(stderr_chunks)
            if stderr_text:
                print(stderr_text.strip(), file=sys.stderr)
                logging.error(f"{description} stderr: {stderr_text.strip()}")

            return_code = process.wait()
            if return_code != 0:
                logging.error(f"{description} failed with return code {return_code}")
                raise subprocess.CalledProcessError(return_code, cmd)
            logging.info(f"Successfully completed {description}")

    # Generate decoy for EThcD
    for frag_type in ['EThcD']:
        fasta_file = base_fasta / frag_type / 'Fusion_Casanovo.fasta'
        if not fasta_file.exists():
            logging.error(f"FASTA file not found for FastaCLI: {fasta_file}")
            raise FileNotFoundError(f"FASTA file not found: {fasta_file}")

        fasta_cmd = [
            'java', '-cp', str(searchgui_jar),
            'eu.isas.searchgui.cmd.FastaCLI',
            '-in', str(fasta_file),
            '-decoy', '1'
        ]
        run_command(fasta_cmd, f"FastaCLI for {frag_type}")

    # Run SearchGUI & PeptideShaker per enzyme (EThcD only)
    for enzyme in enzymes:
        for frag_type, path in [('EThcD', ethcd_path)]:
            enzyme_path = path / enzyme
            # Sorted so the file that gets searched does not depend on filesystem order.
            mgf_files = sorted(enzyme_path.glob('*.mgf'))
            if not mgf_files:
                logging.warning(f"No .mgf files found in {enzyme_path}")
                continue

            # Only the first file is handed to SearchCLI; make the others visible.
            mgf_file = mgf_files[0]
            if len(mgf_files) > 1:
                logging.warning(
                    f"{len(mgf_files)} .mgf files found in {enzyme_path}; only {mgf_file.name} is searched"
                )
            fasta_file = base_fasta / frag_type / 'Fusion_Casanovo_concatenated_target_decoy.fasta'
            param_file = params_path / f'{enzyme}{"-EThcD" if frag_type == "EThcD" else ""}.par'
            output_folder = enzyme_path
            psdb_file = enzyme_path / f'{enzyme}.psdb'

            if not fasta_file.exists():
                logging.error(f"Decoy FASTA file not found: {fasta_file}")
                raise FileNotFoundError(f"Decoy FASTA file not found: {fasta_file}")
            if not param_file.exists():
                logging.error(f"Parameter file not found: {param_file}")
                raise FileNotFoundError(f"Parameter file not found: {param_file}")

            search_cmd = [
                'java', '-cp', str(searchgui_jar),
                'eu.isas.searchgui.cmd.SearchCLI',
                '-spectrum_files', str(mgf_file),
                '-fasta_file', str(fasta_file),
                '-output_folder', str(output_folder),
                '-id_params', str(param_file),
            ]
            # EThcD uses OMSSA in this setup
            search_cmd.extend(['-omssa', '1'])
            run_command(search_cmd, f"SearchCLI for {enzyme} ({frag_type})")

            peptide_cmd = [
                'java', '-cp', str(peptideshaker_jar),
                'eu.isas.peptideshaker.cmd.PeptideShakerCLI',
                '-reference', 'A',
                '-fasta_file', str(fasta_file),
                '-identification_files', str(enzyme_path),
                '-spectrum_files', str(enzyme_path),
                '-id_params', str(param_file),
                '-out', str(psdb_file)
            ]
            run_command(peptide_cmd, f"PeptideShakerCLI for {enzyme} ({frag_type})")

            report_cmd = [
                'java', '-cp', str(peptideshaker_jar),
                'eu.isas.peptideshaker.cmd.ReportCLI',
                '-reports', '3',
                '-in', str(psdb_file),
                '-out_reports', str(enzyme_path)
            ]
            run_command(report_cmd, f"ReportCLI for {enzyme} ({frag_type})")

def run_pipeline(
    base_path: Path,
    process1_path: Path,
    fasta_path: Path,
    ab: str,
    searchgui_jar: Path,
    peptideshaker_jar: Path,
    params_path: Path):
    """Run the whole EThcD search workflow from a notebook cell.

    Steps, in order: configure logging, validate that the base directory and
    the three tool paths exist, build the working directory tree, copy the
    FASTA and the .mgf files, run the search commands (only when .mgf files
    were copied), merge the PSM reports, run the DB search, and finally remove
    the EThcD working directory.

    Parameters
    ----------
    base_path : Path
        Working root; must already exist.
    process1_path : Path
        Directory whose ``EThcD`` sub-folder is used as the .mgf source.
    fasta_path : Path
        FASTA file handed to ``copy_fasta_file``.
    ab : str
        Passed through unchanged to ``run_dbsearch``.
    searchgui_jar, peptideshaker_jar, params_path : Path
        Tool/parameter locations; each must exist.

    Raises
    ------
    FileNotFoundError
        If ``base_path`` or any of the three required paths is missing.
    """
    def _abort_missing(message: str) -> None:
        # Log first, then raise with the very same text.
        logging.error(message)
        raise FileNotFoundError(message)

    setup_logging(base_path)

    if not base_path.exists():
        _abort_missing(f"Base directory does not exist: {base_path}")

    # Where the EThcD .mgf files are taken from
    mgf_source = Path(process1_path) / 'EThcD'

    # All external dependencies have to be present before any work starts
    for required in (searchgui_jar, peptideshaker_jar, params_path):
        if not required.exists():
            _abort_missing(f"Required path does not exist: {required}")

    ethcd_path = create_directory_structure(base_path)
    copy_fasta_file(fasta_path, ethcd_path)

    # Searches only make sense when at least one .mgf was copied
    have_spectra = copy_mgf_files(mgf_source, ethcd_path)
    if have_spectra:
        run_search_commands(base_path, ethcd_path, searchgui_jar, peptideshaker_jar, params_path)
    else:
        logging.warning("Skipping search commands due to no valid .mgf files")

    # Report merging and I/L assignment run regardless of the branch above
    merge_psm_reports(base_path)
    run_dbsearch(base_path, process1_path, ab)

    logging.info("Script completed")

    # ======================================================
    # ⚠ NOTE:
    # If your disk space is sufficient and you would like to keep all EThcD
    # intermediate files, comment out the block below that deletes the
    # EThcD directory.
    # ======================================================
    try:
        shutil.rmtree(ethcd_path)
        logging.info(f"Deleted EThcD directory and its contents: {ethcd_path}")
    except Exception as cleanup_error:
        # Best effort: a failed cleanup is reported but never aborts the run.
        logging.error(f"Failed to delete EThcD directory {ethcd_path}: {cleanup_error}")

# =========================
# ▶️ Run
# =========================
# Launch the pipeline with the notebook-level configuration constants
# (BASE_PATH, PROCESS1_PATH, ... are defined earlier in the notebook).
run_pipeline(BASE_PATH, PROCESS1_PATH, FASTA_PATH, AB, SEARCHGUI_JAR, PEPTIDESHAKER_JAR, PARAMS_PATH)

2025-11-18 11:25:36,963 - INFO - Script started at 2025-11-18 11:25:36.963679
2025-11-18 11:25:36,964 - INFO - Created directory: /results/SA58/Levenshtein/searchGUI/Fusion
2025-11-18 11:25:36,965 - INFO - Created directory: /results/SA58/Levenshtein/searchGUI/Fusion/EThcD
2025-11-18 11:25:36,965 - INFO - Created enzyme directory: /results/SA58/Levenshtein/searchGUI/Fusion/EThcD/ASPN
2025-11-18 11:25:36,965 - INFO - Created enzyme directory: /results/SA58/Levenshtein/searchGUI/Fusion/EThcD/Chymotrypsin
2025-11-18 11:25:36,966 - INFO - Created enzyme directory: /results/SA58/Levenshtein/searchGUI/Fusion/EThcD/Elastase
2025-11-18 11:25:36,967 - INFO - Created enzyme directory: /results/SA58/Levenshtein/searchGUI/Fusion/EThcD/Pepsin
2025-11-18 11:25:36,967 - INFO - Created enzyme directory: /results/SA58/Levenshtein/searchGUI/Fusion/EThcD/Trypsin
2025-11-18 11:25:36,968 - INFO - Copied FASTA file to /results/SA58/Levenshtein/searchGUI/Fusion/EThcD/Fusion_Casanovo.fasta
2025-11-18 11:25:37

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import pandas as pd
import numpy as np
import re

"""Discriminating Isoleucine from Leucine"""


def compute_boundaries_from_labels(
    labels,
    cdr_labels=('CDR1', 'CDR2', 'CDR3'),
    const_label='Constant'
):
    """
    Derive the seven region boundaries from a per-position label list.

    Parameters
    ----------
    labels : list[str]
        Per-position region labels for one chain (1-based positions implied).
        e.g. ["FR1","FR1",...,"CDR1","CDR1",...,"FR2",...,"CONSTANT",...]
    cdr_labels : tuple[str, str, str]
        The exact label names for CDR1/2/3 in `labels` (case-sensitive).
    const_label : str
        The label name for the constant region in `labels`. Compared
        case-insensitively, so 'Constant' and 'CONSTANT' are equivalent.

    Returns
    -------
    list[int]
        [CDR1_start, CDR1_end,
         CDR2_start, CDR2_end,
         CDR3_start, CDR3_end,
         CONSTANT_start]
        All positions are 1-based indices.

    Raises
    ------
    ValueError
        If any CDR label or the constant label does not occur in `labels`.
    """
    # Normalize every entry to str so non-string labels compare cleanly
    labels_norm = [str(x) for x in labels]

    out = []
    for cdr in cdr_labels:
        # Positions are collected in ascending order, so first/last == min/max
        pos = [i for i, lab in enumerate(labels_norm, start=1) if lab == cdr]
        if not pos:
            raise ValueError(f"Label '{cdr}' not found in provided labels.")
        # In well-formed data a CDR is contiguous; first/last are start/end
        out.extend([pos[0], pos[-1]])

    # Constant-region spelling varies ('Constant' vs 'CONSTANT'); ignore case
    const_key = str(const_label).casefold()
    const_start = next(
        (i for i, lab in enumerate(labels_norm, start=1) if lab.casefold() == const_key),
        None,
    )
    if const_start is None:
        raise ValueError(f"Label '{const_label}' not found in provided labels.")
    out.append(const_start)  # first CONSTANT residue

    return out


# ===== Read DB-search results and pre-process peptides =====
dat = pd.read_csv('/results/SA58/Levenshtein/I-L/SA58_EThcD_wIL_DBsearch_v3_002.csv')
# 'I/L_pep' holds the peptide strings; lowercase i/l in them are the evidence counted later
peptide = dat['I/L_pep'].astype(str).copy()

# Map the lowercase markers back to plain residues: p -> Q, m -> M, q -> Q, n -> N
# (keep lowercase i/l as evidence for counting)
# NOTE(review): 'p' is mapped to 'Q', not 'P' — presumably 'p' marks a Q-derived
# modification (e.g. pyro-Glu); confirm against the DB-search output convention.
peptide = peptide.str.replace('p', 'Q').str.replace('m', 'M') \
                 .str.replace('q', 'Q').str.replace('n', 'N')

# ===== Reference sequences =====
# HC_assembly / LC_assembly / HC_label / LC_label come from earlier notebook cells.
HC = HC_assembly
LC = LC_assembly
# I->L collapsed copies used for substring matching (I and L are treated as equal)
HC1 = HC.replace('I', 'L')
LC1 = LC.replace('I', 'L')

# Seven 1-based boundaries per chain: CDR1/2/3 start+end and CONSTANT start
HCCR = compute_boundaries_from_labels(HC_label)
LCCR = compute_boundaries_from_labels(LC_label)

# ===== Build Region vector =====
def build_regions(seq_len, ccr):
    """Expand seven 1-based boundaries into one region label per residue.

    ``ccr`` is [CDR1_start, CDR1_end, CDR2_start, CDR2_end, CDR3_start,
    CDR3_end, CONSTANT_start]; ``seq_len`` is the chain length. The result
    lists FR1, CDR1, FR2, CDR2, FR3, CDR3, FR4 and CONSTANT labels in
    sequence order. A non-positive width simply contributes no labels.
    """
    widths = (
        ('FR1', ccr[0] - 1),
        ('CDR1', ccr[1] - ccr[0] + 1),
        ('FR2', ccr[2] - ccr[1] - 1),
        ('CDR2', ccr[3] - ccr[2] + 1),
        ('FR3', ccr[4] - ccr[3] - 1),
        ('CDR3', ccr[5] - ccr[4] + 1),
        ('FR4', ccr[6] - ccr[5] - 1),
        ('CONSTANT', seq_len - ccr[6] + 1),
    )
    labels = []
    for region_name, width in widths:
        labels.extend([region_name] * width)
    return labels

def init_chain_df(seq, regions):
    """Create the per-residue scoring table for one chain.

    One row per residue of ``seq``. The five integer counters ('w-I', 'w-L',
    'T-I', 'T-L', 'chy&pep-L') start at zero, 'Region' takes ``regions`` as
    given, and 'Newseq' starts as the original residue characters.
    """
    length = len(seq)
    counter_names = ('w-I', 'w-L', 'T-I', 'T-L', 'chy&pep-L')
    # Each counter gets its own zero array so the columns never share storage
    columns = {name: np.zeros(length, dtype=int) for name in counter_names}
    columns['Region'] = regions
    columns['Newseq'] = list(seq)  # initialize with original residue characters
    return pd.DataFrame(columns)

# Per-residue scoring tables for the light and heavy chain
lc = init_chain_df(LC, build_regions(len(LC), LCCR))
hc = init_chain_df(HC, build_regions(len(HC), HCCR))

# ===== Tally peptide evidence on HC/LC (count lowercase i/l only) =====
for Pep in peptide:
    # Collapse every I/L variant to 'L' so the peptide can be located in HC1/LC1
    pep_norm = (Pep.replace('I', 'L')
                    .replace('i', 'L')
                    .replace('l', 'L'))
    # HC
    # str.find returns the FIRST occurrence only; repeats elsewhere get no credit
    start0 = HC1.find(pep_norm)  # 0-based; -1 if not found
    if start0 >= 0:
        star = start0 + 1  # convert to 1-based
        for j, aa in enumerate(Pep, start=1):
            # star + j - 2 == start0 + (j - 1): 0-based row of the j-th peptide residue
            if aa == 'i':
                hc.at[star + j - 2, 'w-I'] += 1
            if aa == 'l':
                hc.at[star + j - 2, 'w-L'] += 1
    # LC (checked independently of the HC result; a peptide may score on both)
    start0 = LC1.find(pep_norm)
    if start0 >= 0:
        star = start0 + 1
        for j, aa in enumerate(Pep, start=1):
            if aa == 'i':
                lc.at[star + j - 2, 'w-I'] += 1
            if aa == 'l':
                lc.at[star + j - 2, 'w-L'] += 1

# ===== Mark theoretical I/L residues =====
# HC_template / LC_template come from earlier notebook cells.
# NOTE(review): positions are compared index-by-index, so this assumes each template
# is position-aligned with (and not longer than) its assembled chain — confirm.
THC = HC_template
TLC = LC_template

for i, aa in enumerate(THC, start=1):
    if aa == 'I':
        hc.at[i-1, 'T-I'] = 1
    if aa == 'L':
        hc.at[i-1, 'T-L'] = 1

for i, aa in enumerate(TLC, start=1):
    if aa == 'I':
        lc.at[i-1, 'T-I'] = 1
    if aa == 'L':
        lc.at[i-1, 'T-L'] = 1

# ===== Casanovo data: clean peptide strings and count chymotrypsin/pepsin L-termini =====
casa = pd.read_csv('/data/Fusion/dataset/mAbs/human/SA58/casanovo/50-cleaned/SA58_casanovo_stitch_HCD.csv')
# Keep only rows whose 'Denovo Score' is at least 90
casa = casa[casa['Denovo Score'] >= 90].copy()

# Regex for the four parenthesised mass-shift annotations:
# (+.98), (+15.99), (+57.02), (-17.03). Read by rm_mods below.
pattern = r'\((?:\+\.98|\+15\.99|\+57\.02|-17\.03)\)'

def rm_mods(s):
    """Strip modification annotations from a peptide string.

    ``None`` becomes the empty string; any other value is converted with
    ``str``. Every match of the module-level regex ``pattern`` is removed and
    the cleaned string is returned.
    """
    text = '' if s is None else str(s)
    return re.sub(pattern, '', text)

# Remove modification annotations so peptides are plain residue strings
casa['Peptide'] = casa['Peptide'].astype(str).map(rm_mods)

# For every peptide ending in 'L' that maps onto a chain, credit the chain position of
# that C-terminal residue when the row's enzyme is chymotrypsin or pepsin.
for idx, row in casa.iterrows():
    pep = row['Peptide']
    if not pep:
        continue
    if pep[-1] != 'L':
        continue

    # NOTE(review): pep is searched in the I->L collapsed chains without being collapsed
    # itself — presumably the cleaned Casanovo peptides contain no 'I'; confirm.
    # Peptide end position (1-based)
    start0 = HC1.find(pep)
    if start0 >= 0:
        end_pos = start0 + len(pep)  # 1-based end
        # NOTE(review): enzyme is read from the SECOND column by position — verify the CSV layout
        enzyme = str(row.iloc[1])
        if re.search(r'(chymo|pepsin)', enzyme, flags=re.IGNORECASE):
            hc.at[end_pos - 1, 'chy&pep-L'] += 1
        # A heavy-chain hit takes precedence: the light chain is not checked for this peptide
        continue

    start0 = LC1.find(pep)
    if start0 >= 0:
        end_pos = start0 + len(pep)  # 1-based end
        enzyme = str(row.iloc[1])
        if re.search(r'(chymo|pepsin)', enzyme, flags=re.IGNORECASE):
            lc.at[end_pos - 1, 'chy&pep-L'] += 1

# ===== Decide Newseq residue by residue (HC and LC rules) =====
def decide_chain_il(df, i, is_heavy: bool):
    """
    Resolve one residue of ``df['Newseq']`` to 'I' or 'L' in place.

    df: hc or lc; columns ['w-I','w-L','T-I','T-L','chy&pep-L','Region','Newseq']
    i: 0-based index
    is_heavy: True for HC, False for LC

    Residues other than I/L are left untouched. Two rule sets exist:

    * Heavy-chain CDR positions before the constant region (``HCCR[6]``):
      a strict majority of w-I vs w-L wins; on a tie or without evidence the
      template decides; without a template, 'I' is written when no
      chymotrypsin/pepsin L-terminus was counted.
    * Everything else (heavy non-CDR, heavy constant, whole light chain):
      the template decides first; otherwise the evidence ('I' only when
      w-I > w-L, else 'L'); otherwise 'I' when the chy&pep-L count is zero.

    When no rule fires, the current value is kept.
    """
    if df.at[i, 'Newseq'] not in ('I', 'L'):
        return

    n_ile, n_leu = df.at[i, 'w-I'], df.at[i, 'w-L']
    t_ile, t_leu = df.at[i, 'T-I'], df.at[i, 'T-L']
    chy_leu = df.at[i, 'chy&pep-L']
    region = df.at[i, 'Region']

    has_reads = (n_ile + n_leu) != 0
    has_template = (t_ile + t_leu) != 0
    # Evidence-first rules apply only to heavy-chain CDR residues in the variable part
    cdr_rules = is_heavy and (i + 1) < HCCR[6] and 'CDR' in region

    call = None
    if cdr_rules and has_reads and n_ile > n_leu:
        call = 'I'
    elif cdr_rules and has_reads and n_ile < n_leu:
        call = 'L'
    elif has_template:
        call = 'I' if t_ile == 1 else 'L'
    elif not cdr_rules and has_reads:
        call = 'I' if n_ile > n_leu else 'L'
    elif chy_leu == 0:
        call = 'I'

    if call is not None:
        df.at[i, 'Newseq'] = call

# Apply decisions
# Every residue of each chain is resolved in place in the 'Newseq' column.
for i in range(len(hc)):
    decide_chain_il(hc, i, is_heavy=True)

for i in range(len(lc)):
    decide_chain_il(lc, i, is_heavy=False)

# ===== Make transposed tables=====
# After transposing, each column is one residue and each row one score/annotation.
hc_t = hc.T.copy()
hc_t.columns = list(HC)  # column names are the original residues
lc_t = lc.T.copy()
lc_t.columns = list(LC)

# Final I/L-resolved sequences; later cells read hc_newseq_str / lc_newseq_str
hc_newseq_str = ''.join(hc_t.loc['Newseq'])
lc_newseq_str = ''.join(lc_t.loc['Newseq'])
print("HC Newseq:", hc_newseq_str)
print("LC Newseq:", lc_newseq_str)

# Write CSVs
hc_t.to_csv('/results/SA58/Levenshtein/I-L/IL-HC-score-v4.csv', index=True, encoding='utf-8')
lc_t.to_csv('/results/SA58/Levenshtein/I-L/IL-LC-score-v4.csv', index=True, encoding='utf-8')

# Write fasta
# Two records, ">HC" then ">LC", each sequence on a single line.
fasta_path = Path(r'/results/SA58/Levenshtein/I-L/Fusion_Casanovo_IL.fasta')
with open(fasta_path, 'w', encoding='utf-8') as f:
    f.write(">HC\n")
    f.write(hc_newseq_str + "\n")
    f.write(">LC\n")
    f.write(lc_newseq_str + "\n")


In [None]:
# -*- coding: utf-8 -*-
"""
Coverage depth (7-mer) for LC/HC with region coloring.

Outputs:
- coverage-depth.pdf / coverage-depth.png
- coverage-depth.txt
"""

from pathlib import Path
import pandas as pd
import numpy as np
import re
import matplotlib.pyplot as plt

# ---------------------------
# Paths (change as needed)
# ---------------------------
in_dir = Path('/data/Fusion/dataset/mAbs/human/SA58/casanovo/50-cleaned')
out_dir = Path('/results/SA58/Levenshtein/I-L')
in_csv = in_dir / 'cleaned_peptides.csv'
out_pdf = out_dir / 'coverage-depth.pdf'
out_png = out_dir / 'coverage-depth.png'
out_txt = out_dir / 'coverage-depth.txt'

# ---------------------------
# Sequences (I/L-resolved strings produced by the previous cell)
#   - HC/LC_orig are used for writing to coverage-depth.txt
#   - HC/LC (I->L) are used for coverage matching
# NOTE: HC and LC are rebound here to the I->L collapsed forms and therefore
# shadow the HC/LC variables of the previous cell.
# ---------------------------
HC_orig = hc_newseq_str
LC_orig = lc_newseq_str
HC = HC_orig.replace('I', 'L')
LC = LC_orig.replace('I', 'L')

# ---------------------------
# Boundaries/labels
# HC_label/LC_label (per-position region lists) must already be defined:
# the boundary vectors ind1/ind2 below are computed directly from them,
# so there is no working fallback when the labels are None.
# ---------------------------


def compute_boundaries_from_labels(labels, cdr_labels=('CDR1','CDR2','CDR3'), const_label='CONSTANT'):
    """Return [CDR1_s, CDR1_e, CDR2_s, CDR2_e, CDR3_s, CDR3_e, CONSTANT_s] (1-based).

    CDR labels are matched exactly. The constant region is any label whose
    upper-cased text starts with 'CONST'; ``const_label`` is accepted for
    signature compatibility but is not consulted. Raises ValueError when a
    CDR label or the constant region is absent.
    """
    names = [str(item) for item in labels]
    bounds = []
    for wanted in cdr_labels:
        first = last = None
        for position, name in enumerate(names, start=1):
            if name == wanted:
                if first is None:
                    first = position
                last = position
        if first is None:
            raise ValueError(f"Label '{wanted}' not found.")
        bounds += [first, last]

    # The first constant-like label marks the start of the constant region
    for position, name in enumerate(names, start=1):
        if name.upper().startswith('CONST'):
            bounds.append(position)
            return bounds
    raise ValueError("Constant region not found.")


# Seven 1-based boundaries per chain; ind1 (light) and ind2 (heavy) are the
# aliases used by the plotting and summary code below.
LCCR = compute_boundaries_from_labels(LC_label)
ind1 = LCCR
HCCR = compute_boundaries_from_labels(HC_label)
ind2 = HCCR

# ---------------------------
# Load peptides and remove modification annotations
# ---------------------------
def rm_mods(s: str) -> str:
    """Return ``s`` without the annotations (+57.02), (+15.99), (+.98) and (-17.03).

    ``None`` is treated as an empty string; other values are passed through ``str``.
    """
    text = str(s) if s is not None else ''
    annotation = re.compile(r'\((?:\+57\.02|\+15\.99|\+\.98|-17\.03)\)')
    return annotation.sub('', text)

# Read the peptide table, drop missing 'Peptide' values and strip modification annotations
df = pd.read_csv(in_csv)
peptides = df['Peptide'].dropna().astype(str).map(rm_mods)

# ---------------------------
# Coverage calculation (7-mer), matching the R while-loop
# ---------------------------
def add_coverage_for_chain(seq_L: str, peptides_list: pd.Series, cov_array: np.ndarray, kmer: int = 7) -> int:
    """
    - Slide a 7-mer window over each peptide;
    - First match in a contiguous block: +1 for all 7 positions;
    - Subsequent matches in the same block: +1 at the 7-mer end position only;
    - If mismatch and we were in a block (k!=0): j += 7; else j += 1; then reset k=0.
    """
    count = 0
    for pep in peptides_list:
        nAA = len(pep)
        j = 0
        k = 0
        while j <= nAA - kmer:
            mer = pep[j:j+kmer]
            start = seq_L.find(mer)
            if start != -1:
                count += 1
                j += 1
                k += 1
                if k == 1:
                    cov_array[start:start+kmer] += 1
                else:
                    cov_array[start + kmer - 1] += 1
            else:
                if k != 0:
                    j += kmer
                else:
                    j += 1
                k = 0
    return count

# Per-position depth arrays, filled in place by add_coverage_for_chain;
# count_lc / count_hc are the numbers of matching 7-mer windows per chain.
lc_cov = np.zeros(len(LC), dtype=int)
hc_cov = np.zeros(len(HC), dtype=int)
count_lc = add_coverage_for_chain(LC, peptides, lc_cov)
count_hc = add_coverage_for_chain(HC, peptides, hc_cov)

# ---------------------------
# Build per-position DataFrames (Region from labels if provided; otherwise from boundaries)
# ---------------------------
def pos_to_region(pos, ind):
    """Return the region name for 1-based ``pos`` given the seven boundaries ``ind``.

    ``ind`` is [CDR1_s, CDR1_e, CDR2_s, CDR2_e, CDR3_s, CDR3_e, CONSTANT_s].
    CDR ranges are inclusive; framework regions are the open gaps between
    them; anything not matched by an earlier rule is 'CONSTANT'.
    """
    # Rules are checked in order and evaluated lazily, first match wins
    ordered_rules = (
        ('FR1', lambda: pos < ind[0]),
        ('CDR1', lambda: ind[0] <= pos <= ind[1]),
        ('FR2', lambda: ind[1] < pos < ind[2]),
        ('CDR2', lambda: ind[2] <= pos <= ind[3]),
        ('FR3', lambda: ind[3] < pos < ind[4]),
        ('CDR3', lambda: ind[4] <= pos <= ind[5]),
        ('FR4', lambda: ind[5] < pos < ind[6]),
    )
    for region_name, matches in ordered_rules:
        if matches():
            return region_name
    return 'CONSTANT'

def normalize_const(x: str) -> str:
    """Map any label starting with 'const' (any case) to 'CONSTANT'; return others as ``str(x)``."""
    label = str(x)
    if label.lower().startswith('const'):
        return 'CONSTANT'
    return label

# Light-chain table: 1-based position, depth and region per residue.
# Region comes from LC_label (constant-region spelling normalized); the
# pos_to_region branch is only reached when LC_label is None.
# NOTE(review): the boundaries above are computed from LC_label/HC_label, so a None
# label would presumably fail before this point — confirm whether the fallback is live.
df1 = pd.DataFrame({
    'Position': np.arange(1, len(lc_cov)+1, dtype=int),
    'Depth': lc_cov.astype(int),
    'Region': ([normalize_const(lab) for lab in LC_label]
               if LC_label is not None
               else [pos_to_region(p, ind1) for p in range(1, len(lc_cov)+1)])
})

# Heavy-chain table, built the same way from hc_cov / HC_label / ind2
df2 = pd.DataFrame({
    'Position': np.arange(1, len(hc_cov)+1, dtype=int),
    'Depth': hc_cov.astype(int),
    'Region': ([normalize_const(lab) for lab in HC_label]
               if HC_label is not None
               else [pos_to_region(p, ind2) for p in range(1, len(hc_cov)+1)])
})

# ---------------------------
# Plot: two stacked panels, ggplot-like styling and a single legend
# ---------------------------
# Bar colour per region name (hex; most entries carry an explicit 'FF' alpha suffix).
# Framework regions use greens, CDRs oranges, the constant region blue.
region_colors = {
    'FR1': '#C7E9C0FF',
    'CDR1': '#FDAE6BFF',
    'FR2': '#A1D99BFF',
    'CDR2': '#FD8D3CFF',
    'FR3': '#74C476FF',
    'CDR3': '#E6550DFF',
    'FR4': '#31A354FF',
    'CONSTANT': '#4D85BD'
}

# ggplot-like minimal theme
# NOTE: rcParams.update changes matplotlib's global defaults for the rest of the session.
plt.rcParams.update({
    "axes.facecolor": "white",
    "figure.facecolor": "white",
    "grid.color": "#d9d9d9",
    "grid.linestyle": "-",
    "grid.linewidth": 0.8,
    "axes.edgecolor": "#e6e6e6",
    "axes.grid": True,
    "axes.axisbelow": True,
    "xtick.color": "#4d4d4d",
    "ytick.color": "#4d4d4d",
    "axes.labelcolor": "#4d4d4d",
    "text.color": "#4d4d4d",
})

def draw_panel(ax, df, title, const_start, tag):
    """Draw one coverage panel: coloured depth bars, reference lines and a tag.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    df : DataFrame with 'Position', 'Depth' and 'Region' columns.
    title : panel title.
    const_start : 1-based position of the first constant-region residue; a
        vertical line is drawn on the bar edge just before it.
    tag : short panel label (e.g. 'A'/'B') placed in the upper-left corner.
    """
    # Regions without a configured colour fall back to grey
    colors = [region_colors.get(r, '#999999') for r in df['Region']]
    ax.bar(df['Position'], df['Depth'], color=colors, width=1.0, align='center', edgecolor='none')

    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Position', fontsize=11)
    ax.set_ylabel('Depth', fontsize=11)

    # Bars are centred on their position with width 1.0, so the bar at
    # position n spans [n - 0.5, n + 0.5]; the upper limit must include the
    # right half of the last bar.
    ax.set_xlim(0.5, len(df) + 0.5)
    # Keep a minimum axis height of 5 so very shallow coverage stays readable
    ymax = max(5, float(df['Depth'].max()))
    ax.set_ylim(0, ymax * 1.05)

    dashseq = (0, (6, 6, 2, 6))  # dot-dash
    mean_depth = float(df['Depth'].mean())
    ax.axhline(mean_depth, linestyle=dashseq, color='black', linewidth=1.1)
    # -0.5 puts the line on the left edge of the first constant-region bar
    ax.axvline(const_start - 0.5, linestyle=dashseq, color='black', linewidth=1.1)

    ax.text(len(df) * 0.82, ymax * 0.92, f"Mean Depth = {mean_depth:.2f}",
            ha='left', va='top', fontsize=10)

    # panel tag (A/B)
    ax.text(0.01, 0.96, tag, transform=ax.transAxes,
            fontsize=12, fontweight='bold', va='top', ha='left')

# Two stacked panels: light chain on top (A), heavy chain below (B)
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(18, 10), sharex=False)
draw_panel(ax1, df1, 'Light Chain', const_start=ind1[6], tag='A')
draw_panel(ax2, df2, 'Heavy Chain', const_start=ind2[6], tag='B')

# single legend on the right
from matplotlib.lines import Line2D
handles = [Line2D([0], [0], color=clr, lw=8, label=lab) for lab, clr in region_colors.items()]
fig.legend(
    handles=handles,
    labels=[h.get_label() for h in handles],
    loc='center left', bbox_to_anchor=(0.985, 0.5),
    frameon=False, ncol=1, title='region'
)

fig.tight_layout(rect=[0.02, 0.02, 0.97, 0.98])
# The figure files are the first outputs of this cell: make sure the
# destination directory exists before writing.
out_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(out_pdf, bbox_inches='tight')
fig.savefig(out_png, dpi=500, bbox_inches='tight')
plt.close(fig)

# ---------------------------
# Means (print like the R script)
# ---------------------------
# "*_var" means cover the variable part only, i.e. positions before CONSTANT start
mean_lc_all = float(lc_cov.mean())
mean_lc_var = float(lc_cov[:ind1[6]-1].mean()) if (ind1[6]-1) > 0 else float('nan')
mean_hc_all = float(hc_cov.mean())
mean_hc_var = float(hc_cov[:ind2[6]-1].mean()) if (ind2[6]-1) > 0 else float('nan')
print([mean_lc_all, mean_lc_var, mean_hc_all, mean_hc_var])
print([mean_lc_all, mean_hc_all])

# ---------------------------
# Write coverage-depth.txt (identical structure as R: LC first, then HC)
# ---------------------------
# Each table is transposed so that every row is one attribute and every
# column one residue position.
Df1 = pd.DataFrame({
    'Posion': df1['Position'].to_list(),  # keep R's spelling "Posion"
    'LC': list(LC_orig),
    'Depth': df1['Depth'].astype(int).to_list(),
    'Region': df1['Region'].to_list()
}).T
Df1.index = ['Posion', 'LC', 'Depth', 'Region']

Df2 = pd.DataFrame({
    'Posion': df2['Position'].to_list(),
    'HC': list(HC_orig),
    'Depth': df2['Depth'].astype(int).to_list(),
    'Region': df2['Region'].to_list()
}).T
Df2.index = ['Posion', 'HC', 'Depth', 'Region']

# out_txt lives in out_dir, so that is the directory that must exist
out_dir.mkdir(parents=True, exist_ok=True)
Df1.to_csv(out_txt, sep='\t', header=False)
Df2.to_csv(out_txt, sep='\t', header=False, mode='a')  # append HC below LC

print(f"Saved: {out_pdf}\nSaved: {out_png}\nSaved: {out_txt}")


In [None]:
# Jupyter cell: I/L-aware k-mer assembly for BOTH HC and LC, with robust logging (log-only)

import logging
from pathlib import Path
from typing import List, Dict, Optional
import pandas as pd

# =========================
# Parameters (edit as needed)
# =========================
K = 7                           # k-mer length; only k-mers of exactly this length are kept
KMERS_CSV = Path("/data/Fusion/dataset/mAbs/human/SA58/casanovo/50-cleaned/kmer_50_Casanovo.csv")
LOG_PATH = Path("/results/SA58/Levenshtein/I-L/stdout.log")

# Sequences & boundaries you already computed earlier in the notebook
SEQ_HC = hc_newseq_str          # heavy-chain sequence string (with I)
SEQ_LC = lc_newseq_str          # light-chain sequence string (with I)
OPEN_LENGTH_HC = int(HCCR[6])   # CONSTANT start (1-based) for HC
OPEN_LENGTH_LC = int(LCCR[6])   # CONSTANT start (1-based) for LC

RELAXED_MIN_OVERLAP = 2         # allow down to 2 after entering conserved region

# =========================
# Robust logging for notebooks (safe across re-runs)
# =========================
# Record layout: "<date time>,<milliseconds> <LEVEL> <message>"
LOG_FMT = "%(asctime)s,%(msecs)03d %(levelname)s %(message)s"
DATE_FMT = "%Y-%m-%d %H:%M:%S"

def setup_logger(log_path: Path, level=logging.DEBUG) -> logging.Logger:
    """Return the "assembler" logger wired to a single file handler.

    Any handlers left over from a previous notebook run are flushed, closed
    and detached first, so repeated calls never stack handlers. The logger
    does not propagate to the root logger. The file handler uses
    LOG_FMT/DATE_FMT and opens ``log_path`` lazily on the first record.
    """
    log = logging.getLogger("assembler")
    log.setLevel(level)
    log.propagate = False

    # Detach stale handlers (prevents "Bad file descriptor" on re-run)
    for stale in tuple(log.handlers):
        try:
            stale.flush()
            stale.close()
        except Exception:
            # Best effort: a handler that cannot be closed is still removed
            pass
        log.removeHandler(stale)

    # delay=True postpones opening the file until something is logged
    file_handler = logging.FileHandler(log_path, encoding="utf-8", delay=True)
    file_handler.setLevel(level)
    file_handler.setFormatter(logging.Formatter(LOG_FMT, DATE_FMT))
    log.addHandler(file_handler)
    return log

# File logger shared by both assembly runs below (records go to LOG_PATH)
logger = setup_logger(LOG_PATH)

# =========================
# Helpers
# =========================
def read_kmers_csv(path: Path, k: int) -> List[str]:
    """
    Load k-mers from a CSV file.

    The column named 'kmer' (case-insensitive, surrounding whitespace ignored)
    is used when present, otherwise the first column. Missing values are
    dropped, entries are stripped and upper-cased, and only strings of
    exactly length ``k`` are returned, in file order.
    """
    table = pd.read_csv(path)

    chosen: Optional[str] = None
    for column in table.columns:
        if str(column).strip().lower() == "kmer":
            chosen = column
            break
    if chosen is None:
        chosen = table.columns[0]

    cleaned = table[chosen].dropna().astype(str).str.strip().str.upper()
    return [entry for entry in cleaned.tolist() if len(entry) == k]

def build_prefix_indices(kmers_norm: List[str], min_overlap: int, max_overlap: int) -> Dict[int, Dict[str, List[int]]]:
    """
    Group k-mer indices by their leading characters, once per overlap length.

    For every overlap length in [min_overlap, max_overlap] the result holds a
    mapping from the k-mer prefix of that length to the indices (in input
    order) of all k-mers starting with it:
    {overlap_len: {prefix: [indices into kmers_norm]}}
    """
    table: Dict[int, Dict[str, List[int]]] = {}
    for width in range(min_overlap, max_overlap + 1):
        by_prefix: Dict[str, List[int]] = {}
        for position, kmer in enumerate(kmers_norm):
            by_prefix.setdefault(kmer[:width], []).append(position)
        table[width] = by_prefix
    return table

def assemble_greedy_multi_overlap_ILaware(
    seq: str,
    kmers: List[str],
    k: int,
    open_length: int,
    relaxed_min_overlap: int,
    logger: logging.Logger,
    label: str = "HC",
) -> str:
    """
    Greedy right extension with I/L-aware matching and dynamic overlap relaxation.

    The contig is seeded with the first k residues of `seq` and repeatedly
    extended while an unused k-mer overlaps its current suffix.

    - Matching uses normalized form (I->L); contig text uses ORIGINAL template (keeps 'I').
    - A matching k-mer only licenses the extension: the appended residues are
      taken from `seq` itself, never from the k-mer.
    - Before contig length > open_length: try overlaps from (k-1) down to 4.
    - After contig length > open_length: allow overlaps down to relaxed_min_overlap (e.g., 2).
    - Do not reuse kmers (by index); CSV order is respected.
    - Log lines are prefixed with [label] to distinguish HC/LC.

    Parameters
    ----------
    seq : template sequence; stripped and upper-cased before use.
    kmers : candidate k-mers; upper-cased before use.
    k : k-mer length.
    open_length : contig length beyond which the smaller overlaps are enabled.
    relaxed_min_overlap : smallest overlap allowed once past `open_length`.
    logger : receives all progress messages.
    label : chain tag used as log prefix.

    Returns
    -------
    str
        The assembled contig (a prefix of the normalized-case `seq`).

    Raises
    ------
    AssertionError
        If `seq` is shorter than `k`.
    """
    # sanitize numerics (defensive in notebooks)
    open_length = int(open_length)
    relaxed_min_overlap = int(relaxed_min_overlap)
    k = int(k)

    seq = str(seq).strip().upper()
    assert len(seq) >= k, f"[{label}] Sequence length must be >= k."
    seed = seq[:k]
    seed_norm = seed.replace("I", "L")

    kmers = [s.upper() for s in kmers]
    kmers_norm = [s.replace("I", "L") for s in kmers]

    # Prefix lookup for every overlap that may ever be tried (relaxed minimum .. k-1)
    prefix_idx = build_prefix_indices(kmers_norm, min_overlap=relaxed_min_overlap, max_overlap=k - 1)
    base_overlaps = list(range(k - 1, 3, -1))  # e.g., k=7 -> [6,5,4]

    contig = seed
    contig_norm = seed_norm       # I->L twin of contig, used for all lookups
    used = set()                  # indices of k-mers already consumed
    warned_small = False          # the "small overlaps disabled" warning is emitted once
    noted_conserved = False       # the "entering conserved region" note is emitted once

    logger.info(f"[{label}] starting to assemble kmers to contigs...")
    logger.debug(f"[{label}] kmer 1: starting with kmer: {seed}")

    while True:
        # Decide which overlaps we try this round
        if len(contig) > open_length:
            if not noted_conserved:
                logger.info(
                    f"[{label}] contig length {len(contig)} > {open_length}: entering conserved region; "
                    f"low variability expected. Enabling smaller overlaps down to {relaxed_min_overlap} "
                    f"and using template residues for growth."
                )
                noted_conserved = True
            overlaps_to_try = list(range(k - 1, relaxed_min_overlap - 1, -1))  # e.g., 6..2
        else:
            overlaps_to_try = base_overlaps[:]  # e.g., 6..4
            if not warned_small:
                # If <4 overlaps could apply now, warn once but keep them disabled until threshold
                possible = any(
                    prefix_idx.get(o, {}).get(contig_norm[-o:], [])
                    for o in range(relaxed_min_overlap, 4)  # 2,3
                )
                if possible:
                    logger.warning(
                        f"[{label}] contig length {len(contig)} ≤ {open_length}: overlaps <4 (2/3) exist "
                        f"but are disabled until entering conserved region (> {open_length})."
                    )
                    warned_small = True

        # Attempt extension using the prioritized overlaps
        extended = False
        for o in overlaps_to_try:
            suffix = contig_norm[-o:]  # normalized suffix
            candidates = prefix_idx.get(o, {}).get(suffix, [])
            # First not-yet-used candidate in input order
            next_idx = next((ci for ci in candidates if ci not in used), None)
            if next_idx is None:
                continue

            # The k-mer is consumed here, even if no template residues remain below
            used.add(next_idx)
            kmer_log = kmers[next_idx]            # upper-cased original k-mer (not I->L normalized); logging only
            add_len = k - o                       # number of new residues to append

            start_pos = len(contig)
            end_pos = min(len(seq), start_pos + add_len)
            if end_pos <= start_pos:
                # Contig already spans the whole template: nothing left to add
                logger.debug(f"[{label}] no template residues left to append; stopping at length {len(contig)}")
                extended = False
                break

            append_seq = seq[start_pos:end_pos]   # append from ORIGINAL template (keeps I)
            contig += append_seq
            contig_norm += append_seq.replace("I", "L")

            logger.debug(f"[{label}] found overlap (o={o}) kmer: {kmer_log}, contig grew to: {contig}")
            extended = True
            break  # after extension, restart from highest priority

        # No overlap produced an extension in this round: assembly is finished
        if not extended:
            break

    logger.debug(f"[{label}] finished with contig:{contig}")
    return contig

# =========================
# Run both chains (append logs to the same file)
# =========================
kmers = read_kmers_csv(KMERS_CSV, K)

# The returned contigs are discarded on purpose: this cell is log-only.
# Each call keeps its own "used" set, so both chains draw from the full k-mer list.
_ = assemble_greedy_multi_overlap_ILaware(
    seq=SEQ_HC,
    kmers=kmers,
    k=K,
    open_length=OPEN_LENGTH_HC,
    relaxed_min_overlap=RELAXED_MIN_OVERLAP,
    logger=logger,
    label="HC",
)

_ = assemble_greedy_multi_overlap_ILaware(
    seq=SEQ_LC,
    kmers=kmers,
    k=K,
    open_length=OPEN_LENGTH_LC,
    relaxed_min_overlap=RELAXED_MIN_OVERLAP,
    logger=logger,
    label="LC",
)

print(f"Assembly log written to: {LOG_PATH.resolve()}")

# -------- SAFE CLEANUP FOR NOTEBOOK RE-RUNS --------
# Flush, close and detach every handler so the log file is released;
# close errors are deliberately ignored (best effort).
for h in list(logger.handlers):
    try:
        h.flush(); h.close()
    except Exception:
        pass
    logger.removeHandler(h)
# (Alternatively: logging.shutdown())
