In [1]:
# -*- coding: utf-8 -*-
import re
from typing import Dict, List, Tuple, Optional, Any
import pandas as pd

# ===================== Ⅰ. 从 FASTA 提取每类最高分 template =====================

def classify_template(name: str) -> str:
    """Classify a template name as HV/HJ/HC/LV/LJ/LC/Unknown by its prefix.

    Rules are tried in order, so the generic "IGH" rule only applies to names
    that start with neither "IGHV" nor "IGHJ".
    """
    ordered_rules = (
        ("IGHV", "HV"),
        ("IGHJ", "HJ"),
        ("IGH", "HC"),
        (("IGKV", "IGLV"), "LV"),
        (("IGKJ", "IGLJ"), "LJ"),
        (("IGKC", "IGLC"), "LC"),
    )
    # str.startswith accepts either a single prefix or a tuple of prefixes
    for prefixes, category in ordered_rules:
        if name.startswith(prefixes):
            return category
    return "Unknown"

def extract_template_info_from_fasta(fasta_path: str, output_csv_path: Optional[str] = None) -> pd.DataFrame:
    """
    Extract the template name, score and type from every header of a FASTA file.

    Headers are expected to look like ">IGHV1-46 ... score:123 ...". The first
    whitespace-separated token is the template name; the first token starting
    with "score:" supplies the score (None when absent or not an integer).

    Args:
        fasta_path: FASTA file to read (UTF-8).
        output_csv_path: if given, the table is also written there as CSV.

    Returns:
        DataFrame with columns "template", "type", "score" — one row per
        header. The columns are present even when no header was found.
    """
    columns = ["template", "type", "score"]
    records = []
    with open(fasta_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.startswith(">"):
                continue
            parts = line[1:].split()
            if not parts:
                # A bare ">" carries no template name; skip it instead of
                # failing on parts[0].
                continue
            template = parts[0]
            score = None
            for p in parts:
                if p.startswith("score:"):
                    try:
                        score = int(p.split(":", 1)[1])
                    except ValueError:
                        score = None
                    # Only the first "score:" token is considered
                    break
            records.append({"template": template, "type": classify_template(template), "score": score})

    # Passing explicit columns keeps the schema stable for an empty result
    df = pd.DataFrame(records, columns=columns)
    if output_csv_path:
        df.to_csv(output_csv_path, index=False)
        print(f"✅ 结果已保存到: {output_csv_path}")
    return df


# ===================== Ⅱ. 解析 Stitch 配置 & 读取六段序列 =====================

def parse_stitch_config_segments(cfg_path: str) -> Tuple[Dict[str, Dict[str, str]], str]:
    """
    Parse a Stitch configuration file for the six template FASTA paths and the run name.

    Returns:
      segments: {'heavy': {'V','J','C'}, 'light': {'V','J','C'}}  -> the FASTA path of each slot
      runname : value of the first `Runname:` line, e.g. '50ugS2P6'

    Raises:
      ValueError: if any of the six slots has no path, or no `Runname:` line was found.
    """
    segments = {'heavy': {'V': None, 'J': None, 'C': None},
                'light': {'V': None, 'J': None, 'C': None}}

    # Parser state: which chain section we are in, whether a Segment block is
    # open, and the Name/Path seen so far inside that block.
    current_chain: Optional[str] = None
    in_segment = False
    seg_name = None
    seg_path = None

    # "Name: <token>", "Path: <file ending in .fa/.fasta/.faa/.fna>", "Runname: <text>" (all case-insensitive)
    name_re = re.compile(r'^\s*Name\s*:\s*(\S+)\s*$', re.IGNORECASE)
    path_re = re.compile(r'^\s*Path\s*:\s*(.+?\.(?:fa|fasta|faa|fna))\s*$', re.IGNORECASE)
    run_re  = re.compile(r'^\s*Runname\s*:\s*(.+?)\s*$', re.IGNORECASE)

    runname: Optional[str] = None

    def commit():
        # Store the current path in the slot selected by (chain, segment name).
        # Does nothing until chain, name and path are all known, so it is safe
        # to call after every Name/Path line. The nonlocal names are only read here.
        nonlocal seg_name, seg_path
        if not (current_chain and seg_name and seg_path):
            return
        # Normalise: upper-case the name and drop surrounding quotes from both values
        n = seg_name.upper().strip().strip('"').strip("'")
        p = seg_path.strip().strip('"').strip("'")
        if current_chain == 'heavy':
            if n == 'IGHV': segments['heavy']['V'] = p
            elif n == 'IGHJ': segments['heavy']['J'] = p
            elif n == 'IGHC': segments['heavy']['C'] = p
        elif current_chain == 'light':
            # Lambda and kappa names share the same light-chain slots
            if n in ('IGLV', 'IGKV'): segments['light']['V'] = p
            elif n in ('IGLJ', 'IGKJ'): segments['light']['J'] = p
            elif n in ('IGLC', 'IGKC'): segments['light']['C'] = p

    with open(cfg_path, 'r', encoding='utf-8') as f:
        for raw in f:
            # Runname may appear anywhere in the file; only the first match is kept
            if runname is None:
                m_run = run_re.match(raw)
                if m_run:
                    runname = m_run.group(1).strip().strip('"').strip("'")

            line = raw.strip()

            # Chain headers switch the active chain and discard any open segment state
            if line.startswith('Heavy Chain->'):
                current_chain = 'heavy'; in_segment = False; seg_name = seg_path = None; continue
            if line.startswith('Light Chain->'):
                current_chain = 'light'; in_segment = False; seg_name = seg_path = None; continue

            # 'Segment->' opens a block; '<-' closes it (committing what was collected)
            if line.startswith('Segment->'):
                in_segment = True; seg_name = seg_path = None; continue
            if line == '<-':
                if in_segment: commit()
                in_segment = False; seg_name = seg_path = None; continue

            # Name/Path lines only count inside a Segment block of a known chain
            if not in_segment or current_chain not in ('heavy', 'light'):
                continue

            m_path = path_re.match(raw)
            if m_path:
                seg_path = m_path.group(1); commit(); continue
            m_name = name_re.match(raw)
            if m_name:
                seg_name = m_name.group(1); commit(); continue

    # The file ended while a Segment block was still open
    if in_segment:
        commit()

    missing = [(ch, rgn) for ch in ('heavy', 'light') for rgn in ('V','J','C') if not segments[ch][rgn]]
    if missing:
        raise ValueError(f"以下槽位未解析到 FASTA 路径: {missing}")
    if not runname:
        raise ValueError("Runname 未找到，请检查配置文件。")

    return segments, runname


def load_annotated_fasta(path: str, key_mode: str = "first_token") -> Dict[str, str]:
    """
    Read an annotated FASTA file and return {id: single-line text}.

    Every body line is stripped and the pieces are joined without a separator,
    so no whitespace is introduced at line breaks. Any NBSP (U+00A0) left in
    the joined text is replaced by a regular space.

    key_mode "first_token" uses the first whitespace-separated token of the
    header as the id; any other value uses the whole stripped header. A later
    record with the same id overwrites an earlier one.
    """
    entries: Dict[str, str] = {}
    active_id: Optional[str] = None
    chunks: List[str] = []

    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.startswith(">"):
                chunks.append(raw_line.strip())
                continue

            # A new header: flush the record collected so far (if any)
            if active_id is not None:
                entries[active_id] = "".join(chunks).replace("\xa0", " ")

            header = raw_line[1:].strip()
            active_id = header.split()[0] if key_mode == "first_token" else header
            chunks = []

    # Flush the final record
    if active_id is not None:
        entries[active_id] = "".join(chunks).replace("\xa0", " ")
    return entries

def extract_sequence_by_id(fasta_path: str, target_id: str, key_mode: str = "first_token") -> str:
    """Return the annotated text of `target_id` from a FASTA file (case-insensitive id match).

    Raises KeyError when no record id equals `target_id` ignoring case.
    """
    entries = load_annotated_fasta(fasta_path, key_mode=key_mode)

    # Upper-cased id -> original id; for ids differing only by case the last one wins
    by_upper = {}
    for original_id in entries:
        by_upper[original_id.upper()] = original_id

    matched = by_upper.get(target_id.upper())
    if matched is None:
        raise KeyError(f"ID '{target_id}' not found in {fasta_path}")
    return entries[matched]


# ---- 模板解析（V/J/C） ----

def _ctx(seq: str, i: int, span: int = 25) -> str:
    """Return the slice of `seq` from `span` characters before index `i` up to `span` after it, clamped to the string."""
    lo = i - span
    if lo < 0:
        lo = 0
    hi = min(len(seq), i + span)
    return seq[lo:hi]

def _parse_template_with_rules(seq: str, default_label: str, allowed_cdr: List[str]) -> Tuple[str, List[str]]:
    """
    Strictly parse an annotated template into (plain sequence, per-residue labels).

      - Only two bracket annotations are recognised: (Conserved X...) and (CDR1/2/3 Y...).
      - Outside brackets only the characters A-Z are accepted; each gets `default_label`.
      - Inside a bracket only A-Z characters are kept as residues.
      - A CDR tag not listed in `allowed_cdr` is labelled with `default_label` instead.

    Raises:
      ValueError: unclosed '(', unknown bracket annotation, bracket without
                  residues, unparsable CDR annotation, or an illegal character.
    """
    residues: List[str] = []
    tags: List[str] = []

    def emit(block_text: str, tag: str) -> None:
        # One label per residue
        residues.extend(block_text)
        tags.extend([tag] * len(block_text))

    pos = 0
    total = len(seq)
    while pos < total:
        symbol = seq[pos]

        # Outside brackets: a plain residue or an error
        if symbol != '(':
            if not ('A' <= symbol <= 'Z'):
                raise ValueError(f"非法字符 '{symbol}' (pos={pos}). context='{_ctx(seq, pos)}'")
            emit(symbol, default_label)
            pos += 1
            continue

        close = seq.find(')', pos + 1)
        if close == -1:
            raise ValueError(f"未闭合的 '(' (pos={pos}). context='{_ctx(seq, pos)}'")

        # Collapse whitespace runs so the annotation reads "<keyword> <residues>"
        annotation = re.sub(r"\s+", " ", seq[pos + 1:close].strip())

        if annotation.startswith("Conserved"):
            conserved_block = re.sub(r'[^A-Z]', '', annotation[len("Conserved"):])
            if not conserved_block:
                raise ValueError(f"(Conserved ...) 中未检测到氨基酸 (pos={pos}). token='{annotation}'")
            emit(conserved_block, "Conserved")
        elif annotation.startswith("CDR"):
            parsed = re.match(r"^CDR\s*([1-3])\s*([A-Za-z\s]*)$", annotation)
            if not parsed:
                raise ValueError(f"无法解析的 CDR 标注 (pos={pos}). token='{annotation}'")
            cdr_name = f"CDR{parsed.group(1)}"
            cdr_block = re.sub(r'[^A-Z]', '', parsed.group(2))
            if not cdr_block:
                raise ValueError(f"{cdr_name} 中未检测到氨基酸 (pos={pos}). token='{annotation}'")
            emit(cdr_block, cdr_name if cdr_name in allowed_cdr else default_label)
        else:
            # Any other bracket content is rejected
            raise ValueError(f"未知括号标注 (pos={pos}). token='{annotation}'. context='{_ctx(seq, pos)}'")

        pos = close + 1

    if len(residues) != len(tags):
        raise RuntimeError("解析后序列与标签长度不一致")
    return "".join(residues), tags

def parse_v_template_to_labels(seq: str) -> Tuple[str, List[str]]:
    """Parse a V template: default label "Variable"; CDR1, CDR2 and CDR3 annotations keep their own tag."""
    return _parse_template_with_rules(
        seq,
        default_label="Variable",
        allowed_cdr=["CDR1", "CDR2", "CDR3"],
    )

def parse_j_template_to_labels(seq: str) -> Tuple[str, List[str]]:
    """Parse a J template: default label "Variable"; only CDR3 annotations keep their own tag."""
    return _parse_template_with_rules(
        seq,
        default_label="Variable",
        allowed_cdr=["CDR3"],
    )

def parse_c_template_to_labels(seq: str) -> Tuple[str, List[str]]:
    """Parse a C template: default label "Constant"; no CDR tag is kept (CDR blocks get the default label)."""
    return _parse_template_with_rules(
        seq,
        default_label="Constant",
        allowed_cdr=[],
    )


def get_six_segments(cfg_path: str,
                     target_ids: Dict[str, Dict[str, str]],
                     key_mode: str = "first_token") -> Tuple[Dict[str, Dict[str, Tuple[str, List[str]]]], str]:
    """
    Load and parse the six templates named by `target_ids`.

    `target_ids[chain][region]` gives the record id to look up in the FASTA
    file that the configuration assigns to that slot (chain in heavy/light,
    region in V/J/C). Returns (six_segments, runname) where
    `six_segments[chain][region]` is a (sequence, labels) pair.
    """
    fasta_by_slot, runname = parse_stitch_config_segments(cfg_path)

    # Region -> parser; insertion order gives the V, J, C processing order
    region_parsers = {
        'V': parse_v_template_to_labels,
        'J': parse_j_template_to_labels,
        'C': parse_c_template_to_labels,
    }

    parsed: Dict[str, Dict[str, Tuple[str, List[str]]]] = {'heavy': {}, 'light': {}}
    for chain, per_region in parsed.items():
        for region, parse in region_parsers.items():
            annotated = extract_sequence_by_id(
                fasta_by_slot[chain][region],
                target_ids[chain][region],
                key_mode=key_mode,
            )
            residues, tags = parse(annotated)
            per_region[region] = (residues, tags)
    return parsed, runname


# ===================== Ⅲ. 拼接整链并提取 CDR & 前/后7 =====================

def _concat_chain(segments: Dict[str, Tuple[str, List[str]]]) -> Tuple[str, List[str]]:
    """
    Join the V, J and C segments (in that order) into one chain and merge their labels.

    segments: {'V': (seq, labels), 'J': (seq, labels), 'C': (seq, labels)}
    Raises RuntimeError when a segment's sequence and label list differ in length.
    """
    pieces: List[str] = []
    merged_labels: List[str] = []
    for region in ("V", "J", "C"):
        region_seq, region_labels = segments[region]
        if len(region_seq) == len(region_labels):
            pieces.append(region_seq)
            merged_labels += region_labels
            continue
        raise RuntimeError(f"{region} 段序列与标签长度不一致")
    return "".join(pieces), merged_labels

def _contiguous_ranges(labels: List[str], target: str) -> List[Tuple[int, int]]:
    """Return every maximal run of `target` in `labels` as an inclusive (start, end) index pair."""
    found: List[Tuple[int, int]] = []
    idx = 0
    total = len(labels)
    while idx < total:
        if labels[idx] != target:
            idx += 1
            continue
        # Start of a run: advance to the first position past it
        run_start = idx
        while idx < total and labels[idx] == target:
            idx += 1
        found.append((run_start, idx - 1))
    return found

def _slice_seq_labels(seq: str, labels: List[str], s: int, e: int) -> Tuple[str, List[str]]:
    """Return seq and labels restricted to the inclusive range [s, e]; empty results when s > e."""
    stop = e + 1
    return (seq[s:stop], labels[s:stop]) if s <= e else ("", [])

def _preceding_k(seq: str, labels: List[str], start_idx: int, k: int = 7) -> Tuple[str, List[str], Tuple[int, int]]:
    """Return up to k residues directly before start_idx, their labels, and the inclusive range (s, e).

    When start_idx <= 0 there is nothing before it: returns ("", [], (-1, -1)).
    """
    if start_idx > 0:
        window = (max(0, start_idx - k), start_idx - 1)
        piece, piece_labels = _slice_seq_labels(seq, labels, window[0], window[1])
        return piece, piece_labels, window
    return "", [], (-1, -1)

def _following_k(seq: str, labels: List[str], end_idx: int, k: int = 7) -> Tuple[str, List[str], Tuple[int, int]]:
    """Return up to k residues directly after end_idx, their labels, and the inclusive range (s, e).

    When end_idx is the last index (or beyond) there is nothing after it: returns ("", [], (-1, -1)).
    """
    last = len(seq) - 1
    if end_idx < last:
        window = (end_idx + 1, min(last, end_idx + k))
        piece, piece_labels = _slice_seq_labels(seq, labels, window[0], window[1])
        return piece, piece_labels, window
    return "", [], (-1, -1)

def extract_cdr_and_flanks(
    six_segments: Dict[str, Dict[str, Tuple[str, List[str]]]],
    k: int = 7
) -> Dict[str, Dict[str, Dict[str, Any]]]:
    """
    Extract CDR1/2/3 of the heavy and light chains together with up to k
    residues before the first CDR range and after the last one.

    Result layout (the key names contain "7" whatever the value of k):
      result[chain][cdr] = {
        "sequence": str,               # all ranges of that CDR joined together
        "ranges": List[(s,e)],         # inclusive index ranges on the joined V+J+C chain
        "preceding7_seq": str, "preceding7_labels": List[str], "preceding7_range": (s,e),
        "following7_seq": str, "following7_labels": List[str], "following7_range": (s,e),
      }
    A CDR that does not occur gets empty values and (-1, -1) ranges.
    """
    result: Dict[str, Dict[str, Dict[str, Any]]] = {}
    for chain in ("heavy", "light"):
        chain_seq, chain_labels = _concat_chain(six_segments[chain])
        per_cdr: Dict[str, Dict[str, Any]] = {}

        for cdr_tag in ("CDR1", "CDR2", "CDR3"):
            ranges = _contiguous_ranges(chain_labels, cdr_tag)

            if not ranges:
                before = ("", [], (-1, -1))
                after = ("", [], (-1, -1))
            else:
                # Flanks are taken around the outermost ends of all ranges
                before = _preceding_k(chain_seq, chain_labels, ranges[0][0], k=k)
                after = _following_k(chain_seq, chain_labels, ranges[-1][1], k=k)

            per_cdr[cdr_tag] = {
                "sequence": "".join(chain_seq[s:e+1] for s, e in ranges),
                "ranges": ranges,
                "preceding7_seq": before[0],
                "preceding7_labels": before[1],
                "preceding7_range": before[2],
                "following7_seq": after[0],
                "following7_labels": after[1],
                "following7_range": after[2],
            }
        result[chain] = per_cdr
    return result

def _pretty_print_cdr(info: Dict[str, Dict[str, Dict[str, Any]]], k: int = 7) -> None:
    """Print a readable report of the CDR data for the heavy and light chains.

    Args:
        info: mapping in the layout produced by extract_cdr_and_flanks, i.e.
            info[chain][cdr] with the keys "sequence", "ranges",
            "preceding7_*" and "following7_*".
        k: flank size, used only in the printed captions; the dictionary keys
            read here always contain "7".
    """
    for chain in ("heavy", "light"):
        # str has no UPPER attribute, so the former hasattr(chain, "UPPER")
        # branch could never run; upper() is the only path.
        print(f"\n=== {chain.upper()} ===")
        for cdr in ("CDR1", "CDR2", "CDR3"):
            item = info[chain][cdr]
            seq = item["sequence"]
            ranges = item["ranges"]
            pre7 = item["preceding7_seq"]; pre7_labels = item["preceding7_labels"]; ps, pe = item["preceding7_range"]
            fol7 = item["following7_seq"];  fol7_labels = item["following7_labels"]; fs, fe = item["following7_range"]
            print(f"- {cdr}:")
            print(f"  ranges: {ranges if ranges else '[]'}")
            print(f"  sequence ({len(seq)} aa): {seq if seq else 'N/A'}")
            print(f"  preceding{k} ({ps}..{pe}): {pre7 if pre7 else 'N/A'}")
            # Label lines are printed only when the flank is non-empty
            if pre7_labels: print(f"  preceding{k} labels: {pre7_labels}")
            print(f"  following{k} ({fs}..{fe}): {fol7 if fol7 else 'N/A'}")
            if fol7_labels: print(f"  following{k} labels: {fol7_labels}")


# ===================== Ⅳ. 使用示例（main） =====================
if __name__ == "__main__":
    # A) Read the template-matching FASTA and find the best-scoring template of each type
    fasta_file = r"E:\Data\stitch-v1.4.0-windows\results\SA58\report-monoclonal-tm.fasta"
    output_csv = r"E:\Data\stitch-v1.4.0-windows\results\SA58\template_search.csv"
    df = extract_template_info_from_fasta(fasta_file, output_csv_path=output_csv)

    # Headers without a parsable "score:" token carry a missing score. Drop
    # them before idxmax so that a type with no valid score is reported by the
    # "miss" check below instead of breaking the .loc lookup.
    scored = df.dropna(subset=["score"]) if "score" in df.columns else df
    if scored.empty:
        raise ValueError(f"{fasta_file} 中没有带有效 score 的模板记录")

    # Per type, take the row with the maximum score
    best_template_name = (scored.loc[scored.groupby("type")["score"].idxmax()]
                                .set_index("type")["template"]
                                .to_dict())

    # B) Stitch configuration file
    cfg_path = r"E:\Data\stitch-v1.4.0-windows\batchfiles\SA58.txt"

    # C) Build the six target ids from the best template names
    required = ["HV","HJ","HC","LV","LJ","LC"]
    miss = [k for k in required if k not in best_template_name or not best_template_name[k]]
    if miss:
        raise ValueError(f"从 {fasta_file} 提取的 best_template_name 缺少：{miss}")

    target_ids = {
        "heavy": {"V": best_template_name["HV"], "J": best_template_name["HJ"], "C": best_template_name["HC"]},
        "light": {"V": best_template_name["LV"], "J": best_template_name["LJ"], "C": best_template_name["LC"]},
    }

    # D) Load and parse the six segments
    template, runname = get_six_segments(cfg_path, target_ids)
    print("Runname:", runname)
    for chain in ("heavy", "light"):
        for region in ("V", "J", "C"):
            seq, labels = template[chain][region]
            print(f"{chain}.{region}: {target_ids[chain][region]} | len={len(seq)} | preview={seq}")

    # E) Extract CDR1/2/3 of both chains plus the 7 residues before/after (sequence + labels)
    cdr_info = extract_cdr_and_flanks(template, k=7)
    _pretty_print_cdr(cdr_info, k=7)


✅ 结果已保存到: E:\Data\stitch-v1.4.0-windows\results\SA58\template_search.csv
Runname: SA58
heavy.V: IGHV7-4-1 | len=98 | preview=QVQLVQSGSELKKPGASVKVSCKASGYTFTSYAMNWVRQAPGQGLEWMGWINTNTGNPTYAQGFTGRFVFSLDTSVSTAYLQICSLKAEDTAVYYCAR
heavy.J: IGHJ4 | len=15 | preview=YFDYWGQGTLVTVSS
heavy.C: IGHG1 | len=330 | preview=ASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEPKSCDKTHTCPPCPAPELLGGPSVFLFPPKPKDTLMISRTPEVTCVVVDVSHEDPEVKFNWYVDGVEVHNAKTKPREEQYNSTYRVVSVLTVLHQDWLNGKEYKCKVSNKALPAPIEKTISKAKGQPREPQVYTLPPSRDELTKNQVSLTCLVKGFYPSDIAVEWESNGQPENNYKTTPPVLDSDGSFFLYSKLTVDKSRWQQGNVFSCSVMHEALHNHYTQKSLSLSPGK
light.V: IGKV3-15 | len=95 | preview=EIVMTQSPATLSVSPGERATLSCRASQSVSSNLAWYQQKPGQAPRLLIYGASTRATGIPARFSGSGSGTEFTLTISSLQSEDFAVYYCQQYNNWP
light.J: IGKJ4 | len=12 | preview=LTFGGGTKVEIK
light.C: IGKC | len=107 | preview=RTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC

=== HEAVY ===
- CDR1:
  ranges: [(25,

In [2]:
# -*- coding: utf-8 -*-
import heapq
from typing import Dict, List, Tuple, Optional

def read_kmer_set_from_csv(file_path):
    """Read a k-mer count table and return {kmer: count} for the 7-mers only.

    The CSV must have the columns 'kmer' and 'count'. Despite the function
    name, the result is a dict. If a 7-mer occurs on several rows, the last
    row wins.

    Cells that pandas did not read as a string are skipped: read_csv turns
    texts such as "NA", "NAN" or "NULL" into a missing value (a float), and
    calling len() on it would raise TypeError. None of those texts has length
    7, so no 7-mer is lost.
    """
    import pandas as pd
    df = pd.read_csv(file_path)
    return {
        kmer: count
        for kmer, count in zip(df['kmer'], df['count'])
        if isinstance(kmer, str) and len(kmer) == 7
    }

def beam_search(start_kmer: str,
                start_kmer_template: Optional[str],
                kmers: Dict[str, int],
                template_sequence: str,
                template_labels: List[str],
                beam_width: int,
                max_iterations: int,
                top_n: int = 3,
                dist_threshold: int = 5,
                cdr_tail: int = 6,
                stop_sequence: Optional[str] = None,
                direction: int = 1,
                distance_guard: int = 1,                   # 1 = distance guard enabled; any other value (e.g. -1) disables it
                ban_sequences: Optional[List[str]] = None, # long blacklist sequences: I->L, then every 7-mer of them is banned
                template_weight_mode: int = 1              # 1 = weight by template; -1 = raw count only
                ) -> Dict[str, int]:
    """
    Beam search over fixed 7-mers, anchored on the template and extended in one direction.

    The anchor (start_kmer_template if given, else start_kmer) is located at
    every position where it occurs in template_sequence. From each position the
    sequence starts as start_kmer and grows by one residue per iteration:
    appended when direction == 1, prepended when direction == -1. Results of
    all start positions are merged, keeping the best score per sequence.

    Args:
        start_kmer: 7-residue seed of the assembled sequence.
        start_kmer_template: optional 7-residue text used instead of start_kmer
            to locate the start position(s) on the template.
        kmers: {7-mer: count} pool used for extension.
        template_sequence / template_labels: template residues and their
            per-residue labels (same length).
        beam_width: number of candidates kept per expansion and per iteration.
        max_iterations: maximum number of extension steps.
        top_n: number of sequences returned.
        dist_threshold: a new sequence is dropped when the Levenshtein distance
            between its newest 7-mer and the matching template window is >= this.
        cdr_tail: template positions within this many residues of a CDR-labelled
            position are exempt from the distance guard.
        stop_sequence: optional 5-residue text; as soon as the growing end
            equals it, the search from that start position returns that single
            sequence.
        direction: 1 = forward, -1 = reverse.
        distance_guard: see dist_threshold; active only when equal to 1.
        ban_sequences: sequences whose 7-mers (after I->L) may not be used or
            appear in the result.
        template_weight_mode:
          - 1 : a candidate whose extension residue equals the template residue
                is weighted (x10 on a "Conserved" position, x2 otherwise);
                a mismatch is not weighted.
          - -1: scores are the raw k-mer counts; no template multiplier.

    Returns:
        {sequence: score} with at most top_n entries, best score first;
        {} when the anchor does not occur in the template.
    """
    # -------------- checks --------------
    assert direction in (1, -1), "direction must be 1 (forward) or -1 (reverse)"
    assert len(start_kmer) == 7, "start_kmer must be length-7"
    if start_kmer_template is not None:
        assert len(start_kmer_template) == 7, "start_kmer_template must be length-7"
    if stop_sequence is not None:
        assert len(stop_sequence) == 5, "stop_sequence must be length-5"
    assert len(template_sequence) == len(template_labels), "template_sequence and template_labels must align"
    assert template_weight_mode in (1, -1), "template_weight_mode must be 1 or -1"

    # -------------- helpers --------------
    def levenshtein(a: str, b: str) -> int:
        # Edit distance computed with two rolling rows
        la, lb = len(a), len(b)
        if la == 0: return lb
        if lb == 0: return la
        prev = list(range(lb + 1)); curr = [0] * (lb + 1)
        for i in range(1, la + 1):
            curr[0] = i; ai = a[i - 1]
            for j in range(1, lb + 1):
                cost = 0 if ai == b[j - 1] else 1
                curr[j] = min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost)
            prev, curr = curr, prev
        return prev[lb]

    def is_conserved(label) -> bool:
        return str(label) == "Conserved"

    def is_cdr(label) -> bool:
        u = str(label).upper()
        return u.startswith("CDR1") or u.startswith("CDR2") or u.startswith("CDR3")

    def build_protected_indices(labels: List[str], length: int, tail: int) -> set:
        """Symmetric protection: every CDR position plus `tail` template positions on each side."""
        prot = set()
        for i, lab in enumerate(labels):
            if is_cdr(lab):
                s = max(0, i - tail)
                e = min(length - 1, i + tail)
                for j in range(s, e + 1):
                    prot.add(j)
        return prot

    def find_all_positions(seq: str, sub: str) -> List[int]:
        # All start indices of `sub` in `seq`, overlapping occurrences included
        out, i = [], seq.find(sub, 0)
        while i != -1:
            out.append(i)
            i = seq.find(sub, i + 1)
        return out

    # -------------- blacklist: banned 7-mers generated from the long sequences --------------
    # I->L is applied to the blacklist only; per the original author's note the
    # k-mer pool itself contains only L.
    banned7L: set = set()
    if ban_sequences:
        for long_seq in ban_sequences:
            if not long_seq:
                continue
            sL = long_seq.upper().replace('I', 'L')
            if len(sL) >= 7:
                for i in range(len(sL) - 6):
                    banned7L.add(sL[i:i+7])

    def contains_banned7(seq: str) -> bool:
        # True when any 7-residue window of `seq` (upper-cased) is banned
        if not banned7L:
            return False
        s = seq.upper()
        for i in range(0, len(s) - 6):
            if s[i:i+7] in banned7L:
                return True
        return False

    # -------------- basics --------------
    n_template = len(template_sequence)
    protected_idx = build_protected_indices(template_labels, n_template, cdr_tail)
    # Overlap lengths tried for every extension; a longer overlap has higher priority
    overlap_order = [6, 5, 4]
    overlap_priority = {6: 3, 5: 2, 4: 1}

    anchor = start_kmer_template if start_kmer_template is not None else start_kmer
    positions = find_all_positions(template_sequence, anchor)
    if not positions:
        return {}

    # Candidate scorer, controlled by template_weight_mode
    def adjust_score(cnt: int, use_template: bool, tgt_pos: int, ext_aa: str) -> int:
        """
        Adjust a k-mer count by template agreement, treating a template I as
        equal to an extension L.
        """
        if template_weight_mode == -1:
            return cnt

        # template_weight_mode == 1
        if use_template and 0 <= tgt_pos < n_template:
            # Residue the template has at the position being extended
            template_aa = template_sequence[tgt_pos]

            # Template I and extension L count as a match
            if template_aa == 'I' and ext_aa == 'L':
                template_aa = 'L'  # handle the template's I as L

            # Weight only on a match: x10 on a Conserved position, x2 otherwise
            if template_aa == ext_aa:
                return cnt * (10 if is_conserved(template_labels[tgt_pos]) else 2)

        return cnt

    # -------------- search from a single start position --------------
    def search_from_pos(pos: int) -> Dict[str, int]:
        def next_ext_pos(curr_len: int) -> int:
            # Template coordinate of the next residue to add, given the current
            # sequence length (7 = nothing added yet)
            added = curr_len - 7
            return (pos + 7 + added) if direction == 1 else (pos - 1 - added)

        def expand_fn(current_seq: str) -> List[Tuple[str, int, str]]:
            # Returns up to beam_width (kmer, adjusted score, extension residue)
            # tuples, at most one per extension residue.
            tgt_pos = next_ext_pos(len(current_seq))
            use_template = (0 <= tgt_pos < n_template)
            # extension residue -> (kmer, adjusted score, overlap priority)
            cand_by_aa: Dict[str, Tuple[str, int, int]] = {}

            for ov in overlap_order:
                if ov > len(current_seq):
                    continue

                if direction == 1:
                    # Forward: the k-mer must start with the last `ov` residues
                    anchor_sub = current_seq[-ov:]
                    ext_idx = -1 - (6 - ov)   # 6/5/4 -> -1/-2/-3
                    for kmer, cnt in kmers.items():
                        # Candidate-level blacklist
                        if banned7L and kmer.upper() in banned7L:
                            continue
                        if not kmer.startswith(anchor_sub):
                            continue
                        if abs(ext_idx) > len(kmer):
                            continue
                        ext_aa = kmer[ext_idx]

                        adj = adjust_score(cnt, use_template, tgt_pos, ext_aa)
                        pri = overlap_priority[ov]
                        # Per residue keep the higher priority, then the higher score
                        if ext_aa in cand_by_aa:
                            _, prev_sc, prev_pri = cand_by_aa[ext_aa]
                            if (pri > prev_pri) or (pri == prev_pri and adj > prev_sc):
                                cand_by_aa[ext_aa] = (kmer, adj, pri)
                        else:
                            cand_by_aa[ext_aa] = (kmer, adj, pri)

                else:
                    # Reverse: the k-mer must end with the first `ov` residues
                    anchor_sub = current_seq[:ov]
                    for kmer, cnt in kmers.items():
                        if banned7L and kmer.upper() in banned7L:
                            continue
                        if not kmer.endswith(anchor_sub):
                            continue
                        if len(kmer) <= ov:
                            continue
                        ext_idx = len(kmer) - ov - 1  # 6/5/4 -> 0/1/2
                        if ext_idx < 0:
                            continue
                        ext_aa = kmer[ext_idx]

                        adj = adjust_score(cnt, use_template, tgt_pos, ext_aa)
                        pri = overlap_priority[ov]
                        if ext_aa in cand_by_aa:
                            _, prev_sc, prev_pri = cand_by_aa[ext_aa]
                            if (pri > prev_pri) or (pri == prev_pri and adj > prev_sc):
                                cand_by_aa[ext_aa] = (kmer, adj, pri)
                        else:
                            cand_by_aa[ext_aa] = (kmer, adj, pri)

            pot = [(k, s, aa) for aa, (k, s, _) in cand_by_aa.items()]
            pot.sort(key=lambda x: -x[1])
            return pot[:beam_width]

        # Beam entries are (sequence, accumulated score); the seed starts at 0
        frontier: List[Tuple[str, int]] = [(start_kmer, 0)]

        for _ in range(max_iterations):
            new_frontier: List[Tuple[str, int]] = []
            for seq, sc in frontier:
                for kmer, ext_sc, aa in expand_fn(seq):
                    new_seq = (seq + aa) if direction == 1 else (aa + seq)
                    new_sc  = sc + ext_sc

                    # Sequence-level blacklist
                    if banned7L and len(new_seq) >= 7 and contains_banned7(new_seq):
                        continue

                    # Repeated 7-mer filter: drop when the newest 7-mer already
                    # occurs in the rest of the sequence
                    if len(new_seq) >= 7:
                        if direction == 1:
                            tail7 = new_seq[-7:]
                            if tail7 in new_seq[:-7]:
                                continue
                        else:
                            head7 = new_seq[:7]
                            if head7 in new_seq[7:]:
                                continue

                    # 5-mer stop: return this single sequence at once
                    if stop_sequence is not None and len(new_seq) >= 5:
                        if (direction == 1 and new_seq[-5:] == stop_sequence) or \
                           (direction == -1 and new_seq[:5] == stop_sequence):
                            return {new_seq: new_sc}

                    # 7-mer distance guard (inside the template and outside the
                    # protected positions; switched by distance_guard)
                    if distance_guard == 1:
                        tgt_pos = next_ext_pos(len(seq))  # template coordinate of the new residue (length before extension)
                        if (0 <= tgt_pos < n_template) and (tgt_pos not in protected_idx) and len(new_seq) >= 7:
                            if direction == 1:
                                q7 = new_seq[-7:]; t_start = max(0, tgt_pos - 6); t_end = tgt_pos + 1
                            else:
                                q7 = new_seq[:7];  t_start = tgt_pos;              t_end = min(n_template, tgt_pos + 7)
                            templ7 = template_sequence[t_start:t_end]
                            # Compared only when a full 7-residue template window exists
                            if len(templ7) == 7 and levenshtein(q7, templ7) >= dist_threshold:
                                continue

                    new_frontier.append((new_seq, new_sc))

            # No survivor: keep the previous frontier as the result
            if not new_frontier:
                break
            frontier = heapq.nlargest(beam_width, new_frontier, key=lambda x: x[1])

        best = heapq.nlargest(top_n, frontier, key=lambda x: x[1])
        return {seq: score for (seq, score) in best}

    # -------------- merge the results of all start positions --------------
    aggregated: Dict[str, int] = {}
    for p in positions:
        sub = search_from_pos(p)
        for s, v in sub.items():
            if (s not in aggregated) or (v > aggregated[s]):
                aggregated[s] = v

    if not aggregated:
        return {}
    return dict(sorted(aggregated.items(), key=lambda x: -x[1])[:top_n])


In [4]:
def beam_search_C(start_kmer: str,
                  start_kmer_template: Optional[str],
                  kmers: Dict[str, int],
                  template_sequence: str,
                  template_labels: List[str],
                  beam_width: int,
                  max_iterations: int,
                  top_n: int = 3,
                  dist_threshold: int = 5,
                  cdr_tail: int = 6,
                  stop_sequence: Optional[str] = None,
                  direction: int = 1,
                  distance_guard: int = 1,                    # 1 = enable the distance guard, -1 = disable it
                  ban_sequences: Optional[List[str]] = None,  # long blacklist sequences: I->L, then every 7-mer is banned
                  template_weight_mode: int = 1,              # 1 = template-weighted scores; -1 = raw counts only
                  min_overlap: int = 4                        # smallest overlap used (4, 5 or 6); range is [min_overlap..6]
                  ) -> Dict[str, int]:
    """
    Beam search over fixed-length 7-mers (automatic start localisation, one
    extension direction per call).

    The anchor (``start_kmer_template`` if given, otherwise ``start_kmer``) is
    located at every occurrence in ``template_sequence``; a separate beam
    search is run from each occurrence and the results are merged, keeping the
    best score per assembled sequence.

    Key parameters:
      - direction: 1 extends to the right (append), -1 extends to the left
        (prepend).
      - min_overlap: smallest overlap that may be used.
          =6 -> only overlap 6
          =5 -> overlaps 6, 5
          =4 -> overlaps 6, 5, 4
      - stop_sequence: optional 5-mer; as soon as a candidate ends (forward)
        or starts (reverse) with it, that single candidate is returned for
        the current start position.
      - distance_guard: when 1, a candidate is dropped if the 7 residues at
        its growing end differ from the aligned template window by a
        Levenshtein distance >= ``dist_threshold``. Positions within
        ``cdr_tail`` of a CDR-labelled position are exempt.

    Count thresholds:
      - overlap 6: count >= 1
      - overlap 5 / 4: count >= 2

    Special rules:
      - When overlap 6 is processed and the target template coordinate is
        valid, a synthetic candidate (count = 1) is added for the template
        residue if no candidate proposes it; it is then weighted like any
        other candidate (Conserved x10, otherwise x5).
      - After weighting (or not weighting), candidates whose adjusted score
        equals 1 are discarded before extension.

    Returns:
      Mapping ``{assembled_sequence: score}`` with at most ``top_n`` entries,
      ordered by descending score; empty if the anchor is not found in the
      template.

    Raises:
      AssertionError: if an argument violates the checks below.
    """
    # -------------- checks --------------
    assert direction in (1, -1), "direction must be 1 (forward) or -1 (reverse)"
    assert len(start_kmer) == 7, "start_kmer must be length-7"
    if start_kmer_template is not None:
        assert len(start_kmer_template) == 7, "start_kmer_template must be length-7"
    if stop_sequence is not None:
        assert len(stop_sequence) == 5, "stop_sequence must be length-5"
    assert len(template_sequence) == len(template_labels), "template_sequence and template_labels must align"
    assert template_weight_mode in (1, -1), "template_weight_mode must be 1 or -1"
    assert min_overlap in (4, 5, 6), "min_overlap must be one of {4,5,6}"

    # -------------- helpers --------------
    def levenshtein(a: str, b: str) -> int:
        """Two-row dynamic-programming edit distance (unit costs)."""
        la, lb = len(a), len(b)
        if la == 0: return lb
        if lb == 0: return la
        prev = list(range(lb + 1)); curr = [0] * (lb + 1)
        for i in range(1, la + 1):
            curr[0] = i; ai = a[i - 1]
            for j in range(1, lb + 1):
                cost = 0 if ai == b[j - 1] else 1
                curr[j] = min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost)
            prev, curr = curr, prev
        return prev[lb]

    def is_conserved(label) -> bool:
        return str(label) == "Conserved"

    def is_cdr(label) -> bool:
        u = str(label).upper()
        return u.startswith("CDR1") or u.startswith("CDR2") or u.startswith("CDR3")

    def build_protected_indices(labels: List[str], length: int, tail: int) -> set:
        """Symmetric protection: every CDR position plus `tail` positions on each side."""
        prot = set()
        for i, lab in enumerate(labels):
            if is_cdr(lab):
                s = max(0, i - tail)
                e = min(length - 1, i + tail)
                for j in range(s, e + 1):
                    prot.add(j)
        return prot

    def find_all_positions(seq: str, sub: str) -> List[int]:
        """Return the start index of every (possibly overlapping) occurrence of `sub`."""
        out, i = [], seq.find(sub, 0)
        while i != -1:
            out.append(i)
            i = seq.find(sub, i + 1)
        return out

    # -------------- blacklist: banned 7-mers derived from the long sequences --------------
    # I is mapped to L for the blacklist only (original note: the k-mer table itself contains only L).
    banned7L: set = set()
    if ban_sequences:
        for long_seq in ban_sequences:
            if not long_seq:
                continue
            sL = long_seq.upper().replace('I', 'L')
            if len(sL) >= 7:
                for i in range(len(sL) - 6):
                    banned7L.add(sL[i:i+7])

    def contains_banned7(seq: str) -> bool:
        """True if any 7-residue window of `seq` is blacklisted."""
        if not banned7L:
            return False
        s = seq.upper()
        for i in range(0, len(s) - 6):
            if s[i:i+7] in banned7L:
                return True
        return False

    # -------------- basics --------------
    n_template = len(template_sequence)
    protected_idx = build_protected_indices(template_labels, n_template, cdr_tail)

    # Overlaps to try, from largest to smallest, limited by min_overlap.
    overlap_order = [ov for ov in (6, 5, 4) if ov >= min_overlap]
    # Priority: the earlier an overlap appears in overlap_order, the higher its priority.
    overlap_priority = {ov: (len(overlap_order) - i) for i, ov in enumerate(overlap_order)}
    # Minimum count: 1 for overlap 6, 2 for overlaps 5 and 4.
    min_count_by_overlap = {ov: (1 if ov == 6 else 2) for ov in overlap_order}

    anchor = start_kmer_template if start_kmer_template is not None else start_kmer
    positions = find_all_positions(template_sequence, anchor)
    if not positions:
        return {}

    # Template weighting (I and L are treated as equivalent).
    def adjust_score(cnt: int, use_template: bool, tgt_pos: int, ext_aa: str) -> int:
        if template_weight_mode == -1:
            return cnt
        if use_template and 0 <= tgt_pos < n_template:
            templ_aa = template_sequence[tgt_pos]
            templ_norm = 'L' if templ_aa in ('I', 'L') else templ_aa
            ext_norm   = 'L' if ext_aa in ('I', 'L') else ext_aa
            if templ_norm == ext_norm:
                return cnt * (10 if is_conserved(template_labels[tgt_pos]) else 5)
        return cnt

    # -------------- search from a single start position --------------
    def search_from_pos(pos: int) -> Dict[str, int]:
        def next_ext_pos(curr_len: int) -> int:
            """Template coordinate of the residue added to a sequence of length `curr_len`."""
            added = curr_len - 7
            return (pos + 7 + added) if direction == 1 else (pos - 1 - added)

        def expand_fn(current_seq: str) -> List[Tuple[str, int, str]]:
            """Return up to beam_width (kmer, adjusted_score, residue) extensions, one per residue."""
            tgt_pos = next_ext_pos(len(current_seq))
            use_template = (0 <= tgt_pos < n_template)
            # residue -> (supporting kmer, adjusted score, overlap priority)
            cand_by_aa: Dict[str, Tuple[str, int, int]] = {}

            for ov in overlap_order:
                if ov > len(current_seq):
                    continue

                if direction == 1:
                    anchor_sub = current_seq[-ov:]
                    ext_idx = -1 - (6 - ov)   # ov=6/5/4 -> -1/-2/-3
                    for kmer, cnt in kmers.items():
                        if banned7L and kmer.upper() in banned7L:
                            continue
                        if not kmer.startswith(anchor_sub):
                            continue
                        if cnt < min_count_by_overlap[ov]:
                            continue
                        if abs(ext_idx) > len(kmer):
                            continue
                        ext_aa = kmer[ext_idx]

                        adj = adjust_score(cnt, use_template, tgt_pos, ext_aa)
                        # Filter: candidates whose adjusted score is exactly 1 are dropped.
                        if adj == 1:
                            continue

                        pri = overlap_priority[ov]
                        if ext_aa in cand_by_aa:
                            _, prev_sc, prev_pri = cand_by_aa[ext_aa]
                            # Higher-priority overlap wins; ties are broken by score.
                            if (pri > prev_pri) or (pri == prev_pri and adj > prev_sc):
                                cand_by_aa[ext_aa] = (kmer, adj, pri)
                        else:
                            cand_by_aa[ext_aa] = (kmer, adj, pri)

                    # At ov=6, add a synthetic candidate (count=1) for the template
                    # residue if no candidate proposes it.
                    if 6 in overlap_order and ov == 6 and use_template:
                        templ_aa = template_sequence[tgt_pos]
                        if templ_aa not in cand_by_aa:
                            synthetic_kmer = anchor_sub + templ_aa
                            if not (banned7L and synthetic_kmer.upper() in banned7L):
                                adj = adjust_score(1, use_template, tgt_pos, templ_aa)
                                # The same adj == 1 filter applies.
                                if adj != 1:
                                    pri = overlap_priority[6]
                                    cand_by_aa[templ_aa] = (synthetic_kmer, adj, pri)

                else:
                    anchor_sub = current_seq[:ov]
                    for kmer, cnt in kmers.items():
                        if banned7L and kmer.upper() in banned7L:
                            continue
                        if not kmer.endswith(anchor_sub):
                            continue
                        if len(kmer) <= ov:
                            continue
                        if cnt < min_count_by_overlap[ov]:
                            continue
                        ext_idx = len(kmer) - ov - 1  # 6/5/4 -> 0/1/2
                        if ext_idx < 0:
                            continue
                        ext_aa = kmer[ext_idx]

                        adj = adjust_score(cnt, use_template, tgt_pos, ext_aa)
                        # Filter: candidates whose adjusted score is exactly 1 are dropped.
                        if adj == 1:
                            continue

                        pri = overlap_priority[ov]
                        if ext_aa in cand_by_aa:
                            _, prev_sc, prev_pri = cand_by_aa[ext_aa]
                            if (pri > prev_pri) or (pri == prev_pri and adj > prev_sc):
                                cand_by_aa[ext_aa] = (kmer, adj, pri)
                        else:
                            cand_by_aa[ext_aa] = (kmer, adj, pri)

                    # At ov=6, add the synthetic template-residue candidate (reverse direction).
                    if 6 in overlap_order and ov == 6 and use_template:
                        templ_aa = template_sequence[tgt_pos]
                        if templ_aa not in cand_by_aa:
                            synthetic_kmer = templ_aa + anchor_sub
                            if not (banned7L and synthetic_kmer.upper() in banned7L):
                                adj = adjust_score(1, use_template, tgt_pos, templ_aa)
                                if adj != 1:
                                    pri = overlap_priority[6]
                                    cand_by_aa[templ_aa] = (synthetic_kmer, adj, pri)

            pot = [(k, s, aa) for aa, (k, s, _) in cand_by_aa.items()]
            pot.sort(key=lambda x: -x[1])
            return pot[:beam_width]

        frontier: List[Tuple[str, int]] = [(start_kmer, 0)]

        for _ in range(max_iterations):
            new_frontier: List[Tuple[str, int]] = []
            for seq, sc in frontier:
                for _kmer, ext_sc, aa in expand_fn(seq):
                    new_seq = (seq + aa) if direction == 1 else (aa + seq)
                    new_sc  = sc + ext_sc

                    # Sequence-level blacklist.
                    if banned7L and len(new_seq) >= 7 and contains_banned7(new_seq):
                        continue

                    # Reject a candidate whose newest 7-mer already occurs earlier in it.
                    if len(new_seq) >= 7:
                        if direction == 1:
                            tail7 = new_seq[-7:]
                            if tail7 in new_seq[:-7]:
                                continue
                        else:
                            head7 = new_seq[:7]
                            if head7 in new_seq[7:]:
                                continue

                    # 5-mer termination: return this candidate immediately.
                    if stop_sequence is not None and len(new_seq) >= 5:
                        if (direction == 1 and new_seq[-5:] == stop_sequence) or \
                           (direction == -1 and new_seq[:5] == stop_sequence):
                            return {new_seq: new_sc}

                    # 7-mer distance guard (inside the template and outside protected positions).
                    if distance_guard == 1:
                        # Template coordinate of the residue just added, i.e. computed from
                        # the length BEFORE extension. Using len(new_seq) here would shift
                        # the template window by one residue relative to q7.
                        tgt_pos = next_ext_pos(len(seq))
                        if (0 <= tgt_pos < n_template) and (tgt_pos not in protected_idx) and len(new_seq) >= 7:
                            if direction == 1:
                                q7 = new_seq[-7:]; t_start = max(0, tgt_pos - 6); t_end = tgt_pos + 1
                            else:
                                q7 = new_seq[:7];  t_start = tgt_pos;              t_end = min(n_template, tgt_pos + 7)
                            templ7 = template_sequence[t_start:t_end]
                            if len(templ7) == 7 and levenshtein(q7, templ7) >= dist_threshold:
                                continue

                    new_frontier.append((new_seq, new_sc))

            if not new_frontier:
                break
            frontier = heapq.nlargest(beam_width, new_frontier, key=lambda x: x[1])

        best = heapq.nlargest(top_n, frontier, key=lambda x: x[1])
        return {seq: score for (seq, score) in best}

    # -------------- aggregate over all start positions --------------
    aggregated: Dict[str, int] = {}
    for p in positions:
        sub = search_from_pos(p)
        for s, v in sub.items():
            if (s not in aggregated) or (v > aggregated[s]):
                aggregated[s] = v

    if not aggregated:
        return {}
    return dict(sorted(aggregated.items(), key=lambda x: -x[1])[:top_n])


In [5]:
import pandas as pd
import Levenshtein
import time

def suffix_prefix_intersect(s1: str, s2: str) -> str:
    """Return the longest string that is both a suffix of s1 and a prefix of s2.

    An empty string is returned when the two sequences do not overlap.
    """
    # Try the longest possible overlap first and shrink until one matches.
    for size in range(min(len(s1), len(s2)), 0, -1):
        head = s2[:size]
        if s1.endswith(head):
            return head
    return ""

def normalize(seq: str) -> str:
    """Treat I and L as equivalent by rewriting every 'I' as 'L'.

    Non-string input is converted with ``str()`` first.
    """
    text = seq if isinstance(seq, str) else str(seq)
    return "L".join(text.split("I"))

def get_candidates(query_sequence, kmer_set, max_distance=3, max_overlap=3, conserved_positions=None):
    """
    Collect the k-mers that are acceptable matches for ``query_sequence``.

    A k-mer is kept when, after I/L normalisation,
      * it equals the query at every index in ``conserved_positions``
        (no constraint when this is None or empty),
      * its Levenshtein distance to the query is <= ``max_distance``, and
      * ``max_overlap_count(query, kmer)`` is <= ``max_overlap``.

    Returns a list of ``(kmer, count, distance)`` tuples sorted by count,
    largest first; ties keep the iteration order of ``kmer_set``.
    """
    query = normalize(query_sequence)
    required = conserved_positions if conserved_positions is not None else []

    hits = []
    for kmer, count in kmer_set.items():
        probe = normalize(kmer)

        # Every conserved index must carry the same residue in both sequences.
        if not all(query[p] == probe[p] for p in required):
            continue

        dist = Levenshtein.distance(query, probe)
        if dist > max_distance:
            continue
        if max_overlap_count(query, probe) > max_overlap:
            continue
        hits.append((kmer, count, dist))

    # Highest count first (stable for equal counts).
    hits.sort(key=lambda hit: hit[1], reverse=True)
    return hits

def max_overlap_count(seq1, seq2):
    """
    Largest number of position-wise identical residues over all shifted
    alignments of the two sequences (shift >= 1 in either direction; the
    unshifted alignment is not considered). I and L count as identical.
    """
    first = normalize(seq1)
    second = normalize(seq2)

    def best_when_shifted(moved, fixed):
        # Slide `moved` left by 1..len(moved)-1 and count matches against `fixed`.
        top = 0
        for shift in range(1, len(moved)):
            same = sum(x == y for x, y in zip(moved[shift:], fixed))
            if same > top:
                top = same
        return top

    return max(best_when_shifted(first, second), best_when_shifted(second, first))

def get_seed(query_sequence, kmer_set, max_distance=3, max_overlap=3, conserved_positions=None):
    """
    Pick the best-scoring seed k-mer for ``query_sequence``.

    Candidates come from ``get_candidates``; each is scored as
    ``count * 2 ** (3 - distance)``. Returns the two-element list
    ``[kmer, score]`` of the first candidate reaching the maximum score,
    or None when there is no candidate.
    """
    pool = get_candidates(query_sequence, kmer_set, max_distance, max_overlap, conserved_positions)
    if not pool:
        return None

    scored = [[kmer, count * (2 ** (3 - distance))] for kmer, count, distance in pool]
    # max() returns the first maximal element, matching a strict ">" scan.
    return max(scored, key=lambda pair: pair[1])

In [7]:
def get_modified_template(template: str, target: str, placeholder: str = 'X') -> str:
    """
    在只允许“一段连续插入或删除”的前提下，将 template 的长度调整为与 target 一致，
    并选择使标准 Levenshtein（插入/删除/替换代价=1，完全相等才算匹配）距离最小的位置。
    若 target 更长：在 template 中插入 placeholder * n；
    若 target 更短：从 template 删除一段长度 n；
    若两者等长：直接返回 template（不改字符差异）。
    """

    def lev(a: str, b: str) -> int:
        # 标准 Levenshtein 距离（无任何等价字符规则）
        if len(a) < len(b):
            a, b = b, a
        if not b:
            return len(a)
        prev = list(range(len(b) + 1))
        for i, ca in enumerate(a, 1):
            curr = [i]
            for j, cb in enumerate(b, 1):
                cost_sub = 0 if ca == cb else 1
                curr.append(min(
                    prev[j] + 1,        # 删除 a 的一个字符
                    curr[j - 1] + 1,    # 向 a 插入一个字符
                    prev[j - 1] + cost_sub  # 替换/匹配
                ))
            prev = curr
        return prev[-1]

    len_t, len_g = len(template), len(target)
    if len_t == len_g:
        return template

    if len_g > len_t:
        # 需要在 template 插入 n 个占位符
        n = len_g - len_t
        best_i, best_d = 0, float('inf')
        # 在 target 中“假删去”长度 n 的一段，找与 template 距离最小的位置
        for i in range(len_g - n + 1):
            cand = target[:i] + target[i + n:]
            d = lev(cand, template)
            if d < best_d:
                best_d, best_i = d, i
        insert_pos = min(best_i, len_t)
        return template[:insert_pos] + (placeholder * n) + template[insert_pos:]
    else:
        # 需要从 template 删除 n 个字符（连续一段）
        n = len_t - len_g
        best_i, best_d = 0, float('inf')
        for i in range(len_t - n + 1):
            cand = template[:i] + template[i + n:]
            d = lev(cand, target)
            if d < best_d:
                best_d, best_i = d, i
        return template[:best_i] + template[best_i + n:]


In [8]:
def replace_cdr_in_template(
    template: Dict[str, Dict[str, Tuple[str, List[str]]]],
    new_cdr_seq: str,
    chain: str = "light",   # "light" or "heavy"
    region: str = "V",      # "V" or "J"
    cdr_tag: str = "CDR1"   # "CDR1" / "CDR2" / "CDR3"
) -> Dict[str, Dict[str, Tuple[str, List[str]]]]:
    """
    Swap the ``cdr_tag`` region of ``template[chain][region]`` for
    ``new_cdr_seq`` and keep the label list in sync. ``cdr_info`` is not
    recomputed.

    If the tag occurs in several non-adjacent blocks, the whole span from
    the start of the first block to the end of the last one is replaced as
    a single stretch.

    The segment is written back into ``template`` (in-place mutation) and the
    same ``template`` object is returned.

    Raises:
        KeyError: ``template`` has no ``chain`` / ``region`` segment.
        ValueError: sequence and labels differ in length, or the tag is absent.
    """
    # 1) Fetch the segment.
    segment_exists = chain in template and region in template[chain]
    if not segment_exists:
        raise KeyError(f"template 中不存在 {chain}.{region} 段")
    old_seq, old_labels = template[chain][region]
    if len(old_seq) != len(old_labels):
        raise ValueError(f"{chain}.{region} 序列与标签长度不一致: len(seq)={len(old_seq)} len(labels)={len(old_labels)}")

    # 2) Locate the tagged blocks via the module's _contiguous_ranges helper.
    blocks = _contiguous_ranges(old_labels, cdr_tag)
    if not blocks:
        raise ValueError(f"{chain}.{region} 段未找到 {cdr_tag}")
    first_block, last_block = blocks[0], blocks[-1]
    start, stop = first_block[0], last_block[1] + 1   # half-open span [start, stop)

    # 3) Normalise the replacement (strip whitespace, upper-case) and splice it in.
    replacement = re.sub(r"\s+", "", new_cdr_seq).upper()
    updated_seq = old_seq[:start] + replacement + old_seq[stop:]
    updated_labels = old_labels[:start] + [cdr_tag] * len(replacement) + old_labels[stop:]

    # 4) Write back.
    template[chain][region] = (updated_seq, updated_labels)
    return template

In [None]:
# Notebook cell: assemble the light chain (LC_assembly) step by step with
# template-guided beam searches. Relies on notebook globals defined in other
# cells: `template`, `cdr_info`, `beam_search`, `read_kmer_set_from_csv`.
# NOTE(review): every `seq = next(iter(result))` takes the first key of the
# returned dict and raises StopIteration when the search returned {}.
HC_assembly = ''
LC_assembly = ''
# NOTE(review): hard-coded absolute path; adjust for the local environment.
file_path = 'E:/Data/Fusion/dataset/mAbs/human/SA58/casanovo/50-cleaned/kmer_50_Casanovo.csv'
kmer_set = read_kmer_set_from_csv(file_path)

# ---- Light chain ----
# Step 1: from the first 7 template residues up to the first 'Conserved' position.
template_sequence, template_labels = template['light']['V']
conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']
start_query_sequence = template_sequence[0:7]
# NOTE(review): get_seed returns None when nothing matches; the [0] would then raise TypeError.
start_kmer = get_seed(start_query_sequence, kmer_set, max_distance=3, max_overlap=4, conserved_positions=None)[0]
start_kmer_template = template_sequence[0:7]
# 7-residue window ending on the first conserved position; index 6 of the window
# (the conserved residue itself) must match exactly.
end_query_sequence = template_sequence[(conser_position[0]-6):(conser_position[0]+1)]
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])

# 7-mers of both constant regions are banned while assembling the variable region.
ban_seqs = [template['light']['C'][0], template['heavy']['C'][0]]
if seed is None:
    # No end seed: run without a stop sequence. A 7-mer covering indices 0..6 needs
    # conser_position[0]-6 extensions to end on the conserved position.
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[0]-6, top_n = 5, direction=1,distance_guard=1, ban_sequences=ban_seqs, template_weight_mode=1)
    seq = next(iter(result))
    # Accept only if the last residue equals the template's conserved residue.
    if seq[-1] == template_sequence[conser_position[0]]:
        LC_assembly = seq
    else:
        print('check the template information and kmers')  
else:
    # End seed found: stop on end_kmer[2:] (the last 5 residues when the seed is a 7-mer).
    end_kmer = seed[0] 
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[0]-6, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
    seq = next(iter(result))
    LC_assembly = seq

# Step 2: CDR1. Search backwards (direction=-1) from the 'following7_seq' anchor of
# CDR1, with the distance guard and template weighting disabled, then join the
# result to the assembly through their suffix/prefix overlap.
query_sequence = cdr_info['light']['CDR1']['following7_seq']
start_kmer = get_seed(query_sequence, kmer_set, max_distance=3, max_overlap=4, conserved_positions=None)[0]
idx = cdr_info['light']['CDR1']['following7_range'][1]-6
result = beam_search(start_kmer, query_sequence, kmer_set, template_sequence, template_labels, beam_width=1, max_iterations=idx, top_n = 5, direction=-1,distance_guard=-1, ban_sequences=ban_seqs, template_weight_mode=-1)
seq = next(iter(result))
Len = len(suffix_prefix_intersect(LC_assembly, seq))
if Len>=1:
    LC_assembly = LC_assembly+seq[Len:]
    # Skips the overlap plus 3 further residues and drops the trailing 7 residues.
    # NOTE(review): the +3 offset is unexplained here; presumably the residues between
    # the conserved position and the CDR1 start -- confirm.
    CDR1 = seq[Len+3:len(seq)-7]

# NOTE(review): CDR1 is only bound when Len >= 1; otherwise this line raises NameError
# (or silently reuses a value left over from an earlier run of the notebook).
if len(CDR1) != len(cdr_info['light']['CDR1']['sequence']):
    # Assembled CDR1 length differs from the template's: resize the template CDR1
    # and refresh the derived variables.
    new_cdr1_LV = get_modified_template(cdr_info['light']['CDR1']['sequence'],CDR1)  
    template = replace_cdr_in_template(template, new_cdr1_LV, chain="light", region="V", cdr_tag="CDR1")
    template_sequence, template_labels = template['light']['V']
    conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']


# Step 3: extend forward from the assembly's last 7 residues to the third conserved position.
start_kmer = LC_assembly[-7:]
start_kmer_template = cdr_info['light']['CDR1']['following7_seq']
end_query_sequence = template_sequence[(conser_position[2]-6):(conser_position[2]+1)]
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])
# NOTE(review): no None check on seed from here on (unlike step 1).
end_kmer = seed[0] 
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[2]-len(LC_assembly)+1, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
# seq starts with the 7-residue start k-mer already present in the assembly.
LC_assembly = LC_assembly+seq[7:]


# Step 4: extend forward to the 'preceding7_seq' anchor of CDR3.
start_kmer = LC_assembly[-7:]
start_kmer_template = template_sequence[(conser_position[2]-6):(conser_position[2]+1)]
end_query_sequence = cdr_info['light']['CDR3']['preceding7_seq']
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])
end_kmer = seed[0] 
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[3]-len(LC_assembly)+1, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
LC_assembly = LC_assembly+seq[7:]

# Step 5: CDR3. Extend from its preceding anchor to its following anchor; the end
# seed must match the query at window indices 0, 1 and 3.
start_kmer = LC_assembly[-7:]
start_kmer_template = cdr_info['light']['CDR3']['preceding7_seq']
end_query_sequence = cdr_info['light']['CDR3']['following7_seq']
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[0,1,3])
end_kmer = seed[0] 
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(LC_assembly)+30, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
if end_kmer in seq:
    # NOTE(review): Len is computed here but not read again in this cell.
    Len = len(seq) - len(end_kmer) - len(start_kmer)
    # CDR3 = everything between the leading and trailing 7 residues.
    CDR3 = seq[7:len(seq)-7]
    LC_assembly = LC_assembly+seq[7:]

# NOTE(review): CDR3 is only bound when end_kmer occurs in seq; otherwise NameError.
if len(CDR3) != len(cdr_info['light']['CDR3']['sequence']):
    # CDR3 length changed: resize the V-segment CDR3, then build the V_J template
    # from V plus the J residues that are not labelled 'CDR3'.
    new_cdr3_LV = get_modified_template(cdr_info['light']['CDR3']['sequence'],CDR3)  
    template = replace_cdr_in_template(template, new_cdr3_LV, chain="light", region="V", cdr_tag="CDR3")
    template_sequence, template_labels = template['light']['V']
    conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']
    seq, labels = template['light']['J']
    seq_no_cdr3 = ''.join(aa for aa, lab in zip(seq, labels) if lab != 'CDR3')
    label_no_cdr3 = [lab for lab in labels if lab != 'CDR3']
    template['light']['V_J'] = ( template['light']['V'][0]+seq_no_cdr3, template['light']['V'][1]+label_no_cdr3)
else:
    # Same length: V_J is the plain concatenation of the V and J segments.
    template['light']['V_J'] = ( template['light']['V'][0]+template['light']['J'][0], template['light']['V'][1]+template['light']['J'][1])


# Step 6: finish the J region on the combined V_J template (no stop sequence;
# bounded by the number of template residues not yet assembled).
template_sequence, template_labels = template['light']['V_J']
start_kmer = LC_assembly[-7:]
start_kmer_template = cdr_info['light']['CDR3']['following7_seq']
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(template_sequence) - len(LC_assembly), top_n = 5,direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
LC_assembly = LC_assembly+seq[7:]

# Step 7: constant region, via beam_search_C on the V_J + C template, anchored on
# the last 7 residues of the J segment. No blacklist is passed for this step.
start_kmer = LC_assembly[-7:]
start_kmer_template = template['light']['J'][0][-7:]
template_sequence = template['light']['V_J'][0] + template['light']['C'][0]
template_labels = template['light']['V_J'][1] + template['light']['C'][1]
#ban_seqs = [template['heavy']['C'][0]]
#result = beam_search_C(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(template['light']['C'][0]), top_n = 5,direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1,min_overlap=5)
result = beam_search_C(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(template['light']['C'][0]), top_n = 5,direction=1,distance_guard=1, template_weight_mode=1,min_overlap=5)
seq = next(iter(result))
# Accept the constant region only when its full template length was assembled.
if len(seq)-7 == len(template['light']['C'][0]):
    LC_assembly = LC_assembly+seq[7:]
    LC_template = template['light']['V_J'][0] + template['light']['C'][0]
    LC_label = template['light']['V_J'][1] + template['light']['C'][1]

print(LC_assembly)
print(len(LC_assembly))

EVVMTQSPASLSVSPGERATLSCRARASLGLSTDLAWYQQRPGQAPRLLLYGASTRATGLPARFSGSGSGTEFTLTLSSLQSEDSAVYYCQQYSNWPLTFGGGTKVELKRTVAAPSVFLFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDSALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC
216


In [34]:
# Heavy chain
# Notebook cell: assemble the heavy chain (HC_assembly) with the same staged
# beam-search procedure as the light-chain cell, plus an extra CDR2 stage.
# Relies on notebook globals: `template`, `cdr_info`, `kmer_set`, `beam_search`,
# and on HC_assembly having been initialised by the previous cell.
# NOTE(review): every `seq = next(iter(result))` raises StopIteration when the
# search returned {}.
# Step 1: from the first 7 template residues up to the first 'Conserved' position.
template_sequence, template_labels = template['heavy']['V']
conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']
start_query_sequence = template_sequence[0:7]
# NOTE(review): get_seed returns None when nothing matches; the [0] would then raise TypeError.
start_kmer = get_seed(start_query_sequence, kmer_set, max_distance=3, max_overlap=4, conserved_positions=None)[0]
start_kmer_template = template_sequence[0:7]
# 7-residue window ending on the first conserved position (window index 6 must match).
end_query_sequence = template_sequence[(conser_position[0]-6):(conser_position[0]+1)]
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])

# 7-mers of both constant regions are banned while assembling the variable region.
ban_seqs = [template['light']['C'][0], template['heavy']['C'][0]]
if seed is None:
    # No end seed: run without a stop sequence and validate the final residue.
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[0]-6, top_n = 5, direction=1,distance_guard=1, ban_sequences=ban_seqs, template_weight_mode=1)
    seq = next(iter(result))
    if seq[-1] == template_sequence[conser_position[0]]:
        HC_assembly = seq
    else:
        print('check the template information and kmers')  
else:
    # End seed found: stop on end_kmer[2:] (the last 5 residues when the seed is a 7-mer).
    end_kmer = seed[0] 
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[0]-6, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
    seq = next(iter(result))
    HC_assembly = seq


# Step 2: CDR1. Search backwards from the 'following7_seq' anchor of CDR1 (guard and
# template weighting disabled) and join through the suffix/prefix overlap.
query_sequence = cdr_info['heavy']['CDR1']['following7_seq']
start_kmer = get_seed(query_sequence, kmer_set, max_distance=3, max_overlap=4, conserved_positions=None)[0]
idx = cdr_info['heavy']['CDR1']['following7_range'][1]-6
result = beam_search(start_kmer, query_sequence, kmer_set, template_sequence, template_labels, beam_width=1, max_iterations=idx, top_n = 5, direction=-1,distance_guard=-1, ban_sequences=ban_seqs, template_weight_mode=-1)
seq = next(iter(result))
Len = len(suffix_prefix_intersect(HC_assembly, seq))
if Len>=1:
    HC_assembly = HC_assembly+seq[Len:]
    # Skips the overlap plus 3 further residues and drops the trailing 7 residues.
    # NOTE(review): the +3 offset is unexplained here -- confirm.
    CDR1 = seq[Len+3:len(seq)-7]

# NOTE(review): CDR1 is only bound above when Len >= 1; otherwise this uses whatever
# CDR1 an earlier cell left behind (e.g. the light-chain value) or raises NameError.
if len(CDR1) != len(cdr_info['heavy']['CDR1']['sequence']):
    new_cdr1_HV = get_modified_template(cdr_info['heavy']['CDR1']['sequence'],CDR1)  
    template = replace_cdr_in_template(template, new_cdr1_HV, chain="heavy", region="V", cdr_tag="CDR1")
    template_sequence, template_labels = template['heavy']['V']
    conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']

# Step 3: extend forward to the 'preceding7_seq' anchor of CDR2 (skipped when no seed is found).
start_kmer_template = cdr_info['heavy']['CDR1']['following7_seq']
start_kmer = HC_assembly[-7:]
end_query_sequence = cdr_info['heavy']['CDR2']['preceding7_seq']
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])
if seed:
    end_kmer = seed[0] 
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=100, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
    seq = next(iter(result))
    HC_assembly = HC_assembly+seq[7:]


# Step 4: CDR2. The start k-mer is re-seeded from the CDR2 'preceding7_seq' query
# (not taken from HC_assembly[-7:]) and extended to the 'following7_seq' anchor.
start_query_sequence = cdr_info['heavy']['CDR2']['preceding7_seq']
start_kmer = get_seed(start_query_sequence, kmer_set, max_distance=3, max_overlap=4, conserved_positions=None)[0]
start_kmer_template = start_query_sequence
end_query_sequence = cdr_info['heavy']['CDR2']['following7_seq']
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])
if seed:
    end_kmer = seed[0] 
    result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=100, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
    seq = next(iter(result))
    HC_assembly = HC_assembly+seq[7:]
    # CDR2 = everything between the leading and trailing 7 residues.
    CDR2 = seq[7:-7]
# NOTE(review): CDR2 is only bound when a seed was found; otherwise NameError.
if len(CDR2) != len(cdr_info['heavy']['CDR2']['sequence']):
    new_cdr2_HV = get_modified_template(cdr_info['heavy']['CDR2']['sequence'],CDR2)  
    template = replace_cdr_in_template(template, new_cdr2_HV, chain="heavy", region="V", cdr_tag="CDR2")
    template_sequence, template_labels = template['heavy']['V']
    conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']


# Step 5: extend forward to the third conserved position.
start_kmer = HC_assembly[-7:]
start_kmer_template = cdr_info['heavy']['CDR2']['following7_seq']
end_query_sequence = template_sequence[(conser_position[2]-6):(conser_position[2]+1)]
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])
# NOTE(review): no None check on seed from here on.
end_kmer = seed[0] 
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[2]-len(HC_assembly)+1, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
# seq starts with the 7-residue start k-mer already present in the assembly.
HC_assembly = HC_assembly+seq[7:]


# Step 6: extend forward to the 'preceding7_seq' anchor of CDR3.
start_kmer = HC_assembly[-7:]
start_kmer_template = template_sequence[(conser_position[2]-6):(conser_position[2]+1)]
end_query_sequence = cdr_info['heavy']['CDR3']['preceding7_seq']
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[6])
end_kmer = seed[0] 
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=conser_position[3]-len(HC_assembly)+1, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
HC_assembly = HC_assembly+seq[7:]

# Step 7: CDR3. Extend from its preceding anchor to its following anchor; the end
# seed must match the query at window indices 0, 1 and 3.
start_kmer = HC_assembly[-7:]
start_kmer_template = cdr_info['heavy']['CDR3']['preceding7_seq']
end_query_sequence = cdr_info['heavy']['CDR3']['following7_seq']
seed = get_seed(end_query_sequence, kmer_set,
                max_distance=2, max_overlap=5, conserved_positions=[0,1,3])
end_kmer = seed[0] 
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(HC_assembly)+30, top_n = 5, stop_sequence=end_kmer[2:],direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
if end_kmer in seq:
    # NOTE(review): Len is computed here but not read again in this cell.
    Len = len(seq) - len(end_kmer) - len(start_kmer)
    CDR3 = seq[7:len(seq)-7]
    HC_assembly = HC_assembly+seq[7:]

# NOTE(review): CDR3 is only bound above when end_kmer occurs in seq; otherwise a
# stale value from an earlier cell is used, or NameError is raised.
if len(CDR3) != len(cdr_info['heavy']['CDR3']['sequence']):
    # CDR3 length changed: resize the V-segment CDR3, then build the V_J template
    # from V plus the J residues that are not labelled 'CDR3'.
    new_cdr3_HV = get_modified_template(cdr_info['heavy']['CDR3']['sequence'],CDR3)  
    template = replace_cdr_in_template(template, new_cdr3_HV, chain="heavy", region="V", cdr_tag="CDR3")
    template_sequence, template_labels = template['heavy']['V']
    conser_position = [i for i in range(len(template_labels)) if template_labels[i] == 'Conserved']
    seq, labels = template['heavy']['J']
    seq_no_cdr3 = ''.join(aa for aa, lab in zip(seq, labels) if lab != 'CDR3')
    label_no_cdr3 = [lab for lab in labels if lab != 'CDR3']
    template['heavy']['V_J'] = ( template['heavy']['V'][0]+seq_no_cdr3, template['heavy']['V'][1]+label_no_cdr3)
else:
    # Same length: V_J is the plain concatenation of the V and J segments.
    template['heavy']['V_J'] = ( template['heavy']['V'][0]+template['heavy']['J'][0], template['heavy']['V'][1]+template['heavy']['J'][1])


# Step 8: finish the J region on the combined V_J template (no stop sequence).
template_sequence, template_labels = template['heavy']['V_J']
start_kmer = HC_assembly[-7:]
start_kmer_template = cdr_info['heavy']['CDR3']['following7_seq']
result = beam_search(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(template_sequence) - len(HC_assembly), top_n = 5,direction=1,distance_guard=1,ban_sequences=ban_seqs, template_weight_mode=1)
seq = next(iter(result))
HC_assembly = HC_assembly+seq[7:]


# Step 9: constant region, via beam_search_C on the V_J + C template with overlap 6
# only (min_overlap=6) and no blacklist.
start_kmer = HC_assembly[-7:]
start_kmer_template = template['heavy']['J'][0][-7:]
template_sequence = template['heavy']['V_J'][0] + template['heavy']['C'][0]
template_labels = template['heavy']['V_J'][1] + template['heavy']['C'][1]
result = beam_search_C(start_kmer, start_kmer_template, kmer_set, template_sequence, template_labels, beam_width=5, max_iterations=len(template['heavy']['C'][0]), top_n = 5,direction=1,distance_guard=1, template_weight_mode=1,min_overlap=6)
seq = next(iter(result))
# Accept the constant region only when its full template length was assembled.
if len(seq)-7 == len(template['heavy']['C'][0]):
    HC_assembly = HC_assembly+seq[7:]
    HC_template = template['heavy']['V_J'][0] + template['heavy']['C'][0]
    HC_label = template['heavy']['V_J'][1] + template['heavy']['C'][1]

print(HC_assembly)
print(len(HC_assembly))

QVQLAQSGSELRKPGASVKVSCDTSGHSFTSNALHWVRQAPGQGLEWMGWVNTDTGTPTYAQGFTGRFVFSLDTSARTAYLQLSSLKADDTAVFYCARERDYSDYFFDYWGQGTLVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYLCNVNHKPSNTKVDKKVEPKSCDKTHTCPPCPAPELLGGPSVFLFPPKPKDTLMLSRTPEVTCVVVDVSHEDPEVKFNWYVDGVEVHNAKTKPREEQYNSTYRVVSVLTVLHQDWLNGKEYKCKVSNKALPAPLEKTLSKAKGQPREPQVYTLPPSRDELTKNQVSLTCLVKGFYPSDLAVEWESNGQPENNYKTTPPVLDSDGSFFLYSKLTVDKSRWQQGNVFSCSVMHEALHNHYTQKSLSLSPGK
450
