In [8]:
from pathlib import Path
from typing import Optional, List, Dict, Tuple
import pandas as pd
from IPython.display import display

# ----------------------------------------------------------------------
# User configuration for the alignment QC cell.
# ----------------------------------------------------------------------

# Input alignment (edit to point at your own FASTA file).
FASTA_PATH = Path("./mv.fa")

# Inclusive filter bounds; a value of None switches that bound off.
#   MIN_LEN / MAX_LEN           -> bounds on raw sequence length
#   MIN_GAP_FRAC / MAX_GAP_FRAC -> bounds on the fraction of '-' characters
MIN_LEN: Optional[int] = None
MAX_LEN: Optional[int] = None
MIN_GAP_FRAC: Optional[float] = None
MAX_GAP_FRAC: Optional[float] = None

# CSV export of the summary and per-sequence tables is off unless
# SAVE_FILES is True; files are written next to the input FASTA.
SAVE_FILES = False
OUT_DIR = FASTA_PATH.parent
# ----------------------------------------------------------------------

def read_fasta(path: Path):
    """Lazily iterate over a FASTA file, producing ``(header, sequence)`` pairs.

    The header is the text after '>' with surrounding whitespace removed; the
    sequence is the concatenation of its lines, each stripped of surrounding
    whitespace. Empty lines are ignored, and any lines that appear before the
    first header are discarded.
    """
    current_header = None
    pieces: List[str] = []
    with open(path, "r") as handle:
        for text in handle:
            text = text.rstrip("\r\n")
            if text == "":
                continue
            if not text.startswith(">"):
                pieces.append(text.strip())
                continue
            # A new header closes the previous record (if any).
            if current_header is not None:
                yield current_header, "".join(pieces)
            current_header = text[1:].strip()
            pieces = []
    if current_header is not None:
        yield current_header, "".join(pieces)

def compute_metrics(path: Path) -> pd.DataFrame:
    """Compute per-sequence composition metrics for every record in a FASTA file.

    Sequences are upper-cased before counting. '-' is counted as a gap; any
    character other than A, C, G, T or '-' (for example ambiguity codes such
    as N, R, Y) is counted as "non-ATGC".

    Returns:
        One row per record with the columns listed in ``columns`` below. The
        sorted set of all characters seen across the file is stored in
        ``df.attrs["observed_chars"]``. A file with no records yields an empty
        frame that still carries every column, so downstream filtering and
        summary code can index those columns without a KeyError.
    """
    columns = [
        "sequence_id",
        "length",
        "gap_count",
        "gap_fraction",
        "non_ATGC_count",
        "non_ATGC_fraction",
        "non_ATGC_chars",
        "only_ATGC_and_gaps",
    ]
    canonical = frozenset("ACGT-")
    rows: List[Dict] = []
    global_chars = set()
    for sid, seq in read_fasta(path):
        s = seq.upper()
        chars = set(s)
        global_chars |= chars
        length = len(s)
        gap_count = s.count("-")
        # str.count runs in C; everything that is not a canonical symbol is
        # "non-ATGC", so subtract the canonical counts from the total length.
        non_atgc_count = length - sum(s.count(ch) for ch in "ACGT-")
        non_atgc_chars = sorted(chars - canonical)
        rows.append({
            "sequence_id": sid,
            "length": length,
            "gap_count": gap_count,
            "gap_fraction": (gap_count / length) if length else 0.0,
            "non_ATGC_count": non_atgc_count,
            "non_ATGC_fraction": (non_atgc_count / length) if length else 0.0,
            "non_ATGC_chars": ",".join(non_atgc_chars),
            "only_ATGC_and_gaps": (non_atgc_count == 0),
        })
    # Explicit columns keep the schema stable even when `rows` is empty.
    df = pd.DataFrame(rows, columns=columns)
    df.attrs["observed_chars"] = "".join(sorted(global_chars))
    return df

def apply_filters(df: pd.DataFrame,
                  min_len: Optional[int],
                  max_len: Optional[int],
                  min_gap_frac: Optional[float],
                  max_gap_frac: Optional[float]) -> pd.DataFrame:
    """Flag each row as kept or dropped according to optional inclusive bounds.

    Args:
        df: frame with at least ``length`` and ``gap_fraction`` columns (the
            columns are only read for bounds that are not None).
        min_len, max_len: inclusive bounds on ``length``; None disables.
        min_gap_frac, max_gap_frac: inclusive bounds on ``gap_fraction``;
            None disables.

    Returns:
        A copy of *df* with two extra columns:
          ``keep``        -- True when the row satisfies every active bound.
          ``drop_reason`` -- comma-separated labels of the bounds the row
                             failed ('' for kept rows), ordered
                             len<, len>, gap_frac<, gap_frac>.
        The keep flag and the reasons come from the same comparisons, so a
        dropped row always carries at least one reason (including rows whose
        value cannot be compared, e.g. NaN). The input frame is not modified.
    """
    df = df.copy()

    # Each entry: (mask of rows that satisfy the bound, label used when they don't).
    checks: List[Tuple[pd.Series, str]] = []
    if min_len is not None:
        checks.append((df["length"] >= min_len, f"len<{min_len}"))
    if max_len is not None:
        checks.append((df["length"] <= max_len, f"len>{max_len}"))
    if min_gap_frac is not None:
        checks.append((df["gap_fraction"] >= min_gap_frac, f"gap_frac<{min_gap_frac}"))
    if max_gap_frac is not None:
        checks.append((df["gap_fraction"] <= max_gap_frac, f"gap_frac>{max_gap_frac}"))

    keep = pd.Series(True, index=df.index)
    row_reasons: List[List[str]] = [[] for _ in range(len(df))]
    for passed, label in checks:
        keep &= passed
        for i, ok in enumerate(passed.tolist()):
            if not ok:
                row_reasons[i].append(label)

    df["keep"] = keep
    df["drop_reason"] = [",".join(r) for r in row_reasons]
    return df

def global_summary(df: pd.DataFrame, filt: pd.DataFrame, fasta_path: Path,
                   min_len, max_len, min_gap_frac, max_gap_frac) -> pd.DataFrame:
    """Build a two-column (``metric``, ``value``) table summarising a QC run.

    *df* is the per-sequence metrics frame (``length``, ``gap_count``,
    ``non_ATGC_count`` columns plus ``attrs["observed_chars"]``); *filt* is
    the same data with a boolean ``keep`` column. At most ten distinct lengths
    are listed, followed by "..." when there are more. ``min_length`` and
    ``max_length`` are None for an empty frame. The four filter bounds are
    echoed back unchanged.
    """
    lengths = sorted(df["length"].unique().tolist())
    shown_lengths = lengths[:10]
    if len(lengths) > 10:
        shown_lengths = shown_lengths + ["..."]
    has_rows = len(df) > 0
    kept_mask = filt["keep"]

    pairs = [
        ("file", fasta_path.name),
        ("sequences_total", int(len(df))),
        ("all_same_length", bool(len(lengths) == 1)),
        ("unique_lengths", shown_lengths),
        ("min_length", int(df["length"].min()) if has_rows else None),
        ("max_length", int(df["length"].max()) if has_rows else None),
        ("with_gaps", int(df["gap_count"].gt(0).sum())),
        ("with_non_ATGC", int(df["non_ATGC_count"].gt(0).sum())),
        ("observed_chars", df.attrs.get("observed_chars", "")),
        ("kept_after_filters", int(kept_mask.sum())),
        ("dropped_after_filters", int((~kept_mask).sum())),
        ("filters_min_len", min_len),
        ("filters_max_len", max_len),
        ("filters_min_gap_frac", min_gap_frac),
        ("filters_max_gap_frac", max_gap_frac),
    ]
    return pd.DataFrame([{"metric": name, "value": val} for name, val in pairs])

# ---- Run QC ----
# The same four bounds feed both the filter step and the summary echo.
filter_bounds = (MIN_LEN, MAX_LEN, MIN_GAP_FRAC, MAX_GAP_FRAC)
df_metrics = compute_metrics(FASTA_PATH)
df_filtered = apply_filters(df_metrics, *filter_bounds)
summary_df = global_summary(df_metrics, df_filtered, FASTA_PATH, *filter_bounds)

# Render both tables in the notebook output.
print("Alignment QC — Global Summary")
display(summary_df)

print("\nAlignment QC — Per-sequence (with keep/drop flags)")
per_sequence_view = (
    df_filtered
    .sort_values(["keep", "length"], ascending=[False, False])
    .reset_index(drop=True)
)
display(per_sequence_view)

# Optional CSV export next to the input file.
if SAVE_FILES:
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    for frame, filename in (
        (summary_df, "alignment_qc_summary.csv"),
        (df_filtered, "alignment_qc_per_sequence.csv"),
    ):
        frame.to_csv(OUT_DIR / filename, index=False)


Alignment QC — Global Summary


Unnamed: 0,metric,value
0,file,mv.fa
1,sequences_total,226
2,all_same_length,False
3,unique_lengths,"[0, 15893, 15894]"
4,min_length,0
5,max_length,15894
6,with_gaps,0
7,with_non_ATGC,207
8,observed_chars,ACGHKMNRSTWY
9,kept_after_filters,226



Alignment QC — Per-sequence (with keep/drop flags)


Unnamed: 0,sequence_id,length,gap_count,gap_fraction,non_ATGC_count,non_ATGC_fraction,non_ATGC_chars,only_ATGC_and_gaps,keep,drop_reason
0,SRR30155677,15894,0,0.0,0,0.000000,,True,True,
1,SRR30155678,15894,0,0.0,3,0.000189,N,False,True,
2,SRR30155679,15894,0,0.0,13,0.000818,"N,Y",False,True,
3,SRR30155680,15894,0,0.0,13,0.000818,"N,R,Y",False,True,
4,SRR30155681,15894,0,0.0,2312,0.145464,"N,S",False,True,
...,...,...,...,...,...,...,...,...,...,...
221,SRR25426232,0,0,0.0,0,0.000000,,True,True,
222,SRR25426246,0,0,0.0,0,0.000000,,True,True,
223,SRR25426257,0,0,0.0,0,0.000000,,True,True,
224,SRR25426267,0,0,0.0,0,0.000000,,True,True,


In [12]:
# Tally sequences per raw length: one row per distinct `length` value, with
# the count of rows sharing it in the `length` column.
df_filtered.groupby('length').agg({'length':'count'})

Unnamed: 0_level_0,length
length,Unnamed: 1_level_1
0,9
15893,2
15894,215


In [16]:
#!/usr/bin/env python3
"""
Extract gene/CDS-aligned FASTAs from a multiple sequence alignment (MSA)
using coordinates from NC_001498.1 (Measles virus) GenBank.

Enhancements:
- Keep ONLY sequences whose raw FASTA length equals the GenBank sequence length.
- Optionally drop sequences with > MAX_N_COUNT ambiguous 'N' bases.
- Optionally trim trailing stop codon from CDS slices (based on the reference CDS).

Outputs:
- One FASTA per feature in OUTPUT_DIR
- _index.tsv manifest
- _filter_report.tsv (dropped sequences + reasons)
- filtered MSA snapshot (optional)
"""

from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import re
import csv

# Third-party
try:
    from Bio import SeqIO, Entrez
    from Bio.SeqFeature import CompoundLocation
except Exception:
    SeqIO = None
    Entrez = None
    CompoundLocation = None

# --------------------- CONFIG ---------------------
# Reference genome and file locations (edit for your machine).
ACCESSION = "NC_001498.1"
MSA_PATH = Path("/home/anton/git/overlapTools/mv.fa")
OUTPUT_DIR = Path("/home/anton/git/overlapTools/msa")

# GenBank feature type to cut out of the alignment: "gene" or "CDS".
FEATURE_LEVEL = "CDS"
# When True, minus-strand features are emitted in coding orientation.
REVCOMP_NEG_STRAND = True

# NCBI Entrez settings (only used when LOCAL_GENBANK is None).
NCBI_EMAIL = "anton@nekrut.org"        # NCBI requires a contact address
NCBI_API_KEY = None                    # optional
LOCAL_GENBANK: Optional[Path] = None   # e.g. Path("/path/NC_001498.1.gb") to read from disk instead

# Filtering / export switches.
WRITE_FILTERED_MSA = True
MAX_N_COUNT: Optional[int] = 100       # drop rows with more than this many 'N's; None disables the check
TRIM_TRAILING_STOP: bool = True        # CDS only: drop the last codon when the reference CDS ends in a stop

STOP_CODONS = {"TAA", "TAG", "TGA"}

# --------------------- Helpers ---------------------
@dataclass
class FeatureSlice:
    """One GenBank gene/CDS feature reduced to what the column slicer needs."""

    name: str                           # display name chosen from the feature's qualifiers
    kind: str                           # feature type: 'gene' or 'CDS'
    strand: int                         # +1, -1, or 0 when unknown/mixed
    intervals: List[Tuple[int, int]]    # 1-based, inclusive (start, end) pieces
    qualifiers: Dict[str, List[str]]    # copy of the raw GenBank qualifiers

def read_fasta(path: Path) -> List[Tuple[str, str]]:
    """Read a FASTA file into a list of ``(id, sequence)`` tuples.

    The id is the full header text after '>' with surrounding whitespace
    stripped. All whitespace inside sequence lines is discarded: stray spaces,
    tabs or trailing blanks are never alignment characters, and keeping them
    would inflate the observed length and cause a spurious ``length_mismatch``
    in the length filter. Whitespace-only lines are skipped, and lines that
    appear before the first header are ignored.
    """
    records: List[Tuple[str, str]] = []
    cur_id: Optional[str] = None
    chunks: List[str] = []
    with open(path, "r") as fh:
        for raw in fh:
            line = raw.rstrip()
            if not line:
                continue
            if line.startswith(">"):
                if cur_id is not None:
                    records.append((cur_id, "".join(chunks)))
                cur_id = line[1:].strip()
                chunks = []
            else:
                # Remove every whitespace character, not just line ends.
                chunks.append("".join(line.split()))
    if cur_id is not None:
        records.append((cur_id, "".join(chunks)))
    return records

def find_reference_row(msa: List[Tuple[str, str]], accession_substr: str = ACCESSION) -> Optional[int]:
    """Return the index of the first record whose header contains *accession_substr*.

    Matching is a plain substring test on the full header text. Returns None
    when no header matches.
    """
    hits = (idx for idx, (header, _seq) in enumerate(msa) if accession_substr in header)
    return next(hits, None)

def build_pos_to_col_map_from_reference(ref_seq: str) -> Dict[int, int]:
    """Map each ungapped reference position (1-based) to its alignment column (0-based).

    Only '-' is treated as a gap; every other character advances the
    reference coordinate by one.
    """
    residue_columns = (col for col, ch in enumerate(ref_seq) if ch != "-")
    return dict(enumerate(residue_columns, start=1))

def build_pos_to_col_map_identity(length: int) -> Dict[int, int]:
    """Return the trivial mapping position p -> column p - 1 for p in 1..length.

    Only meaningful for an ungapped alignment whose width equals the genome length.
    """
    return dict(zip(range(1, length + 1), range(length)))

def alignment_has_any_gaps(msa: List[Tuple[str, str]]) -> bool:
    """Return True if any sequence in *msa* contains at least one '-' character."""
    for _header, sequence in msa:
        if "-" in sequence:
            return True
    return False

def fetch_or_read_genbank(accession: str, local_file: Optional[Path], email: Optional[str], api_key: Optional[str]):
    """Load the GenBank record for *accession* with Biopython.

    If *local_file* is given it is parsed directly. Otherwise the record is
    downloaded from NCBI nuccore via Entrez (``rettype="gbwithparts"``), which
    requires *email*; *api_key* is applied only when truthy.

    Raises:
        RuntimeError: when Biopython (or its Entrez module) is unavailable, or
            when a download is needed but no *email* was supplied.
    """
    if SeqIO is None:
        raise RuntimeError("Biopython is required (pip install biopython) to parse GenBank.")

    if local_file is not None:
        # Offline path: parse the on-disk file and return immediately.
        with open(local_file, "r") as gb_file:
            return SeqIO.read(gb_file, "genbank")

    # Online path: validate prerequisites before touching the network.
    if Entrez is None:
        raise RuntimeError("Biopython Entrez is required to fetch GenBank from NCBI (pip install biopython).")
    if not email:
        raise RuntimeError("Please set NCBI_EMAIL with your email to use NCBI Entrez.")

    Entrez.email = email
    if api_key:
        Entrez.api_key = api_key

    query = dict(db="nuccore", id=accession, rettype="gbwithparts", retmode="text")
    with Entrez.efetch(**query) as response:
        return SeqIO.read(response, "genbank")

def _location_to_intervals(location) -> List[Tuple[int, int]]:
    """Convert a Biopython location into 1-based inclusive ``(start, end)`` pairs.

    Compound (join) locations yield one pair per part, in the order Biopython
    stores them; any other location yields a single pair.
    """
    is_compound = CompoundLocation is not None and isinstance(location, CompoundLocation)
    pieces = location.parts if is_compound else [location]
    return [(int(piece.start) + 1, int(piece.end)) for piece in pieces]


def extract_features(record, level: str = "gene") -> List[FeatureSlice]:
    """Collect every feature of type *level* ('gene' or 'CDS') as a FeatureSlice.

    The name is the first value of the first qualifier present among
    gene, gene_synonym, product, note, locus_tag; when none is present (or it
    is empty) a ``"{level}_at_{start}_{end}"`` placeholder is used. Strand is
    the location's strand, or 0 when Biopython reports None.

    Raises:
        ValueError: if *level* is not 'gene' or 'CDS'.
    """
    if level not in {"gene", "CDS"}:
        raise ValueError("level must be 'gene' or 'CDS'")

    name_keys = ("gene", "gene_synonym", "product", "note", "locus_tag")
    slices: List[FeatureSlice] = []
    for feature in (f for f in record.features if f.type == level):
        intervals = _location_to_intervals(feature.location)
        strand = int(feature.location.strand or 0)
        quals = feature.qualifiers
        label = next((quals[key][0] for key in name_keys if key in quals), None)
        if not label:
            label = f"{level}_at_{intervals[0][0]}_{intervals[-1][1]}"
        slices.append(
            FeatureSlice(
                name=label,
                kind=level,
                strand=strand,
                intervals=intervals,
                qualifiers=dict(quals),
            )
        )
    return slices

def sanitize_name(name: str) -> str:
    """Make *name* safe for use in a filename.

    Each run of whitespace becomes a single underscore, then every character
    outside ``A-Z a-z 0-9 _ . + -`` is replaced by an underscore.
    """
    without_spaces = re.sub(r"\s+", "_", name)
    return re.sub(r"[^A-Za-z0-9_.+-]", "_", without_spaces)

def collect_columns_for_feature(pos_to_col: Dict[int, int], feature: FeatureSlice, revcomp_neg: bool) -> List[int]:
    """List the alignment columns covered by *feature*, in output order.

    Each interval is normalised so start <= end. Without flipping, intervals
    are visited in their stored order and walked ascending. When
    ``feature.strand == -1`` and *revcomp_neg* is set, the result is
    equivalent to visiting the intervals in their stored order and walking
    each one from its end down to its start (i.e. descending positions).

    Raises:
        KeyError: if a genomic position has no entry in *pos_to_col*.
    """
    flip = feature.strand == -1 and revcomp_neg
    ordered_parts = list(reversed(feature.intervals)) if flip else feature.intervals
    columns: List[int] = []
    for a, b in ordered_parts:
        lo, hi = (b, a) if a > b else (a, b)
        for position in range(lo, hi + 1):
            if position not in pos_to_col:
                raise KeyError(
                    f"Genomic position {position} not present in mapping. "
                    f"Include a reference row ({ACCESSION}) or use an ungapped alignment."
                )
            columns.append(pos_to_col[position])
    return columns[::-1] if flip else columns

def _revcomp(seq: str) -> str:
    comp = str.maketrans("ACGTUacgtu", "TGCAAtgcaa")
    return "".join('-' if ch == '-' else ch.translate(comp) for ch in seq[::-1])

def feature_ref_coding_seq(record, feature: FeatureSlice, revcomp_neg: bool) -> str:
    """Return the reference sequence of *feature*, upper-cased and without '-'.

    One piece of ``record.seq`` is taken per 1-based inclusive interval and
    the pieces are concatenated. For a minus-strand feature with *revcomp_neg*
    set, the intervals are concatenated in reverse list order and the result
    is passed through ``_revcomp`` to obtain coding orientation.
    """
    flip = feature.strand == -1 and revcomp_neg
    ordered = list(reversed(feature.intervals)) if flip else feature.intervals
    pieces = [str(record.seq[start - 1:end]) for start, end in ordered]
    joined = "".join(pieces).upper()
    if flip:
        joined = _revcomp(joined)
    return joined.replace("-", "")

def slice_alignment_columns(msa: List[Tuple[str, str]], columns: List[int],
                            feature: FeatureSlice, revcomp_neg: bool) -> List[Tuple[str, str]]:
    """Extract *columns* (0-based, in the given order) from every row of *msa*.

    *columns* is expected in output order as produced by
    ``collect_columns_for_feature``. For a minus-strand feature with
    *revcomp_neg* set, that order is already reversed relative to the genome
    (descending positions), so reading the columns in order yields the
    reverse strand and only a per-base complement is still needed. Reversing
    the slice again here would undo that and produce the plain complement
    in forward order instead of the reverse complement.

    The complement is IUPAC-aware (R<->Y, K<->M, B<->V, D<->H; S, W, N and
    '-' unchanged; U -> A), preserving case.

    Returns:
        A list of ``(id, sliced_sequence)`` in the same order as *msa*.
    """
    flip = feature.strand == -1 and revcomp_neg
    comp = str.maketrans(
        "ACGTURYKMBVDHSWNacgturykmbvdhswn",
        "TGCAAYRMKVBHDSWNtgcaayrmkvbhdswn",
    )
    out: List[Tuple[str, str]] = []
    for sid, seq in msa:
        subseq = "".join(seq[c] for c in columns)
        if flip:
            # Column order already supplies the reversal; complement only.
            subseq = subseq.translate(comp)
        out.append((sid, subseq))
    return out

def write_fasta(records: List[Tuple[str, str]], path: Path, header_suffix: str = "") -> None:
    """Write *records* to *path* as FASTA, wrapping sequences at 80 characters.

    Each header is ``{id} {header_suffix}`` with surrounding whitespace
    trimmed, so an empty suffix leaves just the id. Parent directories are
    created as needed and an existing file is overwritten. Empty sequences
    produce a header line only.
    """
    width = 80
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as out:
        for seq_id, sequence in records:
            out.write(">" + f"{seq_id} {header_suffix}".strip() + "\n")
            out.writelines(
                sequence[pos:pos + width] + "\n" for pos in range(0, len(sequence), width)
            )
def write_simple_fasta(records: List[Tuple[str, str]], path: Path) -> None:
    """Write *records* as plain FASTA (header is the id only), 80 characters per line.

    Used for the filtered-MSA snapshot. Parent directories are created as
    needed and an existing file is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    lines: List[str] = []
    for seq_id, sequence in records:
        lines.append(f">{seq_id}\n")
        lines.extend(sequence[i:i + 80] + "\n" for i in range(0, len(sequence), 80))
    with open(path, "w") as out:
        out.writelines(lines)

# --------------------- Main ---------------------
def main(
    msa_path: Path = MSA_PATH,
    feature_level: str = FEATURE_LEVEL,
    output_dir: Path = OUTPUT_DIR,
    accession: str = ACCESSION,
    local_genbank: Optional[Path] = LOCAL_GENBANK,
    email: Optional[str] = NCBI_EMAIL,
    api_key: Optional[str] = NCBI_API_KEY,
    revcomp_neg: bool = REVCOMP_NEG_STRAND,
):
    """Filter the MSA and write one alignment FASTA per GenBank feature.

    Steps:
      1. Load the GenBank record for *accession* (local file or NCBI) and use
         its sequence length as the required raw row length.
      2. Drop MSA rows whose raw length (gaps included) differs from that, or
         that contain more than ``MAX_N_COUNT`` 'N' characters when that limit
         is set. Write ``_filter_report.tsv`` and, if ``WRITE_FILTERED_MSA``,
         a snapshot of the kept rows.
      3. Map genomic positions to alignment columns: from a kept row whose
         header contains *accession* when one exists, otherwise an identity map
         (allowed only when no kept row contains '-').
      4. For each *feature_level* feature, slice the kept rows, optionally drop
         the reference stop codon (CDS only, ``TRIM_TRAILING_STOP``), write a
         FASTA, and list all outputs in ``_index.tsv``.

    Raises:
        RuntimeError: if no rows survive filtering, or if positions cannot be
            mapped safely (no reference row and a gapped alignment).
    """
    # Fetch GenBank to learn expected genome length
    record = fetch_or_read_genbank(accession, local_file=local_genbank, email=email, api_key=api_key)
    expected_len = len(record.seq)

    # Read MSA, then filter by exact raw length and by N-count (if configured)
    msa_all = read_fasta(msa_path)
    kept: List[Tuple[str, str]] = []
    filtered_out: List[Tuple[str, int, int, str]] = []  # (id, observed_len, expected_len, reasons)

    for sid, seq in msa_all:
        reasons = []
        obs_len = len(seq)
        if obs_len != expected_len:
            reasons.append("length_mismatch")
        if MAX_N_COUNT is not None and seq.upper().count("N") > MAX_N_COUNT:
            reasons.append(f"too_many_N(>{MAX_N_COUNT})")
        if reasons:
            filtered_out.append((sid, obs_len, expected_len, ",".join(reasons)))
        else:
            kept.append((sid, seq))

    # Report filtering
    print(f"[Filter] GenBank expected length = {expected_len}")
    print(f"[Filter] Input sequences: {len(msa_all)}")
    print(f"[Filter] Kept: {len(kept)}")
    print(f"[Filter] Dropped: {len(filtered_out)}")
    if filtered_out:
        print("[Filter] Dropped (id  observed_len  expected_len  reasons) — first 20:")
        for sid, ol, el, reason in filtered_out[:20]:
            print(f"  {sid}\t{ol}\t{el}\t{reason}")
        if len(filtered_out) > 20:
            print(f"  ... ({len(filtered_out) - 20} more)")
    output_dir.mkdir(parents=True, exist_ok=True)
    filter_report_path = output_dir / "_filter_report.tsv"
    with open(filter_report_path, "w", newline="") as fh:
        w = csv.writer(fh, delimiter="\t")
        w.writerow(["sequence_id", "observed_length", "expected_length", "reasons"])
        for sid, ol, el, reason in filtered_out:
            w.writerow([sid, ol, el, reason])
    print(f"[Filter] Report written: {filter_report_path}")

    if WRITE_FILTERED_MSA:
        filtered_msa_path = output_dir / f"{msa_path.stem}.len{expected_len}.filtered.fa"
        write_simple_fasta(kept, filtered_msa_path)
        print(f"[Filter] Filtered MSA written: {filtered_msa_path}")

    if not kept:
        raise RuntimeError("No sequences remain after filtering; aborting.")

    # Build position→column mapping
    ref_idx = find_reference_row(kept, accession)
    if ref_idx is not None:
        pos_to_col = build_pos_to_col_map_from_reference(kept[ref_idx][1])
    else:
        if alignment_has_any_gaps(kept):
            raise RuntimeError(
                "No row containing the accession found and gaps detected in remaining alignment; "
                "cannot safely map genome positions to alignment columns.\n"
                f"Either include a reference row for {accession} in the MSA or ensure the alignment is ungapped."
            )
        pos_to_col = build_pos_to_col_map_identity(expected_len)

    # Extract features & write outputs (using kept sequences)
    features = extract_features(record, level=feature_level)
    index_rows = []
    for feat in features:
        try:
            cols = collect_columns_for_feature(pos_to_col, feat, revcomp_neg=revcomp_neg)
        except KeyError as ke:
            print(f"Skipping feature {feat.name}: {ke}")
            continue

        # Optionally trim the trailing stop codon for CDS features, based on
        # the reference CDS. Only valid when the columns are in coding
        # orientation: an un-flipped minus-strand feature lists its columns in
        # genomic order, so its last three columns are not the stop codon.
        trimmed = False
        in_coding_orientation = feat.strand != -1 or revcomp_neg
        if TRIM_TRAILING_STOP and feat.kind == "CDS" and in_coding_orientation:
            ref_cds = feature_ref_coding_seq(record, feat, revcomp_neg=revcomp_neg)
            if len(ref_cds) >= 3 and ref_cds[-3:] in STOP_CODONS:
                cols = cols[:-3]
                trimmed = True

        sliced = slice_alignment_columns(kept, cols, feat, revcomp_neg=revcomp_neg)
        # Report the genomic span of the whole feature. Compound locations on
        # the minus strand may list their parts from highest to lowest, so use
        # min/max over all parts rather than first-start / last-end.
        start = min(min(s, e) for s, e in feat.intervals)
        end = max(max(s, e) for s, e in feat.intervals)
        suffix_parts = [f"{feat.kind}={feat.name}", f"strand={feat.strand}", f"coord={start}-{end}"]
        if trimmed:
            suffix_parts.append("trimmed_stop=1")
        suffix = ";".join(suffix_parts)
        out_name = f"{sanitize_name(feat.kind)}__{sanitize_name(feat.name)}__{start}_{end}.fa"
        out_path = output_dir / out_name
        write_fasta(sliced, out_path, header_suffix=suffix)
        index_rows.append({
            "name": feat.name,
            "kind": feat.kind,
            "strand": feat.strand,
            "intervals": ";".join([f"{s}-{e}" for s, e in feat.intervals]),
            "trimmed_stop": int(trimmed),
            "output_fasta": str(out_path),
        })

    idx_path = output_dir / "_index.tsv"
    with open(idx_path, "w", newline="") as fh:
        w = csv.DictWriter(
            fh,
            fieldnames=["name", "kind", "strand", "intervals", "trimmed_stop", "output_fasta"],
            delimiter="\t"
        )
        w.writeheader()
        for row in index_rows:
            w.writerow(row)

    print(f"Wrote {len(index_rows)} feature-alignments to {output_dir}")
    print(f"Index: {idx_path}")

if __name__ == "__main__":
    main()


[Filter] GenBank expected length = 15894
[Filter] Input sequences: 226
[Filter] Kept: 164
[Filter] Dropped: 62
[Filter] Dropped (id  observed_len  expected_len  reasons) — first 20:
  SRR30155681	15894	15894	too_many_N(>100)
  SRR30155687	15894	15894	too_many_N(>100)
  SRR30155690	15894	15894	too_many_N(>100)
  SRR30155695	15894	15894	too_many_N(>100)
  SRR30155708	15894	15894	too_many_N(>100)
  SRR30155713	15894	15894	too_many_N(>100)
  SRR30155720	15894	15894	too_many_N(>100)
  SRR30155721	15894	15894	too_many_N(>100)
  SRR30155727	15894	15894	too_many_N(>100)
  SRR30155731	15894	15894	too_many_N(>100)
  SRR30155732	15894	15894	too_many_N(>100)
  SRR30155736	15894	15894	too_many_N(>100)
  SRR30155738	15894	15894	too_many_N(>100)
  SRR30155739	15894	15894	too_many_N(>100)
  SRR30155740	15894	15894	too_many_N(>100)
  SRR30155743	15894	15894	too_many_N(>100)
  SRR30155746	15894	15894	too_many_N(>100)
  SRR30155748	15894	15894	too_many_N(>100)
  SRR30155749	15894	15894	too_many_N(>100)
 