In [None]:
import os
import sys
import argparse
from pathlib import Path
from dataclasses import dataclass
from typing import cast, List, Optional, Dict
from result import Ok, Err, Result
from strictyaml import load, Map, YAMLError, Str  # , Str, Int  # type: ignore
import polars as pl
from pprint import pprint

In [318]:
def construct_file_list(
    results_dir: Path, glob_pattern: Path
) -> Result[List[Path], str]:
    """
        Function `construct_file_list()` uses the provided path of wildcards
        to expand out all available files to be compiled downstream.

    Args:
        - `results_dir: Path`: A Pathlib path instance recording the results
        "root" directory, which is the top of the READ-ZAP results hierarchy.
        - `ivar_pattern: Path`: A Pathlib path instance containing wildcards
        that can be expanded to the desired files.

    Returns:
        - `Result[List[Path], str]`: A Result type instance containing either
        a list of paths to the desired files, or an error message string.
    """

    # collect a list of all the files to search
    files_to_query = list(results_dir.glob(str(glob_pattern)))

    # make sure that there aren't duplicates
    try:
        set(files_to_query)
    except ValueError as message:
        return Err(f"Redudant files present that may corrupt results:\n{message}")

    if len(files_to_query) == 0:
        return Err(
            f"No files found that match the wildcard path:\n{glob_pattern}\nwithin {results_dir}"
        )

    return Ok(files_to_query)


def _check_cleanliness(path_parts: tuple[str, ...]) -> bool:
    """ """

    for part in path_parts:
        if "._" in part:
            return False
    return True


def compile_data_with_io(file_list: List[Path]) -> Result[pl.LazyFrame, str]:
    """
        Function `compile_data_with_io()` takes the list of paths and
        reads each file, parsing it with Polars, and writing it into one
        large temporary TSV file. This method for compiling all files into
        one involves a great deal of read-write, but it also avoids potential
        type mismatch issues between a type schema inferred for one dataframe
        and the type schema inferred for the next. Downstream, the new
        TSV can be parsed into a single dataframe where it's possible to
        infer a type scheme from many rows.

    Args:
        - `file_list: List[Path]`: A list of paths, where each path points to
        a TSV file generated by iVar.

    Returns:
        - `pl.LazyFrame`: A Polars LazyFrame to be queries and transformed
        downstream.
    """

    # Double check that a tempfile from a previous run isn't present
    if os.path.isfile("tmp.tsv"):
        os.remove("tmp.tsv")
        os.remove("contigs_long_table.arrow")

    if len(file_list) == 0:
        return Err("No files found to compile data from.")

    # compile all tables into one large temporary table
    for i, file in enumerate(file_list):
        with open("tmp.tsv", "a", encoding="utf-8") as temp:
            # Parse out information in the file path to add into the dataframe
            # NOTE: this hardcoding will eventually be replaced with config params
            amplicon = str(file.parent).split("results/amplicon_")[1].split("/")[0]
            simplename = os.path.basename(file).replace(".tsv", "")
            sample_id = simplename.split("_")[0]
            contig = simplename.split("_")[-1]
            if len(sample_id) == 1:
                continue

            # quick test of whether the header should be written. It is only written
            # the first time
            write_header = bool(i == 0)

            # Read the csv, modify it, and write it onto the growing temporary tsv
            pl.read_csv(
                file, separator="\t", raise_if_empty=True, null_values=["NA", ""]
            ).with_columns(
                pl.lit(sample_id).alias("Sample ID"),
                pl.lit(amplicon).alias("Amplicon"),
                pl.lit(contig).alias("Contig"),
                pl.lit(f"{amplicon}-{sample_id}-{contig}").alias(
                    "Amplicon-Sample-Contig"
                ),
            ).write_csv(temp, separator="\t", include_header=write_header)

    # lazily scan the new tmp tsv for usage downstream
    pl.scan_csv("tmp.tsv", separator="\t", infer_schema_length=1500).sort(
        "POS"
    ).sink_ipc("contigs_long_table.arrow", compression="zstd")

    all_contigs = pl.scan_ipc("contigs_long_table.arrow").lazy()
    os.remove("tmp.tsv")

    return Ok(all_contigs)


def _try_parse_int(value: str) -> Optional[int]:
    """
    Helper function that handles the possibility that a read support
    cannot be parsed as an integer from the FASTA defline and returns
    `None` instead of raising an unrecoverable error.
    """
    try:
        return int(value)
    except ValueError:
        return None


def _try_parse_identifier(defline: str, amplicon: str) -> Optional[str]:
    """ """

    items = defline.split("_")
    sample_id = items[1]
    (contig,) = [item for item in items if "contig" in item]

    identifier = f"{amplicon}-{sample_id}-{contig}"

    return identifier


def _is_valid_utf8(fasta_line: str) -> bool:
    """ """
    try:
        fasta_line.encode("utf-8").decode("utf-8")
        return True
    except UnicodeDecodeError:
        return False


def generate_seq_dict(
    fasta_path: Path, input_fasta: List[str], split_char: str
) -> Optional[Dict[Optional[str], Optional[int]]]:
    """
    Placeholder
    """

    # make sure the lines can be decoded
    decodable = [_is_valid_utf8(line) for line in input_fasta]
    if False in decodable:
        return None

    (amplicon,) = [
        item.replace("amplicon_", "")
        for item in os.path.normpath(fasta_path).split(os.sep)
        if "amplicon" in item
    ]
    deflines = [line for line in input_fasta if line.startswith(">")]
    supports = [
        _try_parse_int(line.split(split_char)[-1])
        for line in input_fasta
        if line.startswith(">")
    ]
    identifiers = [_try_parse_identifier(defline, amplicon) for defline in deflines]

    assert len(deflines) == len(
        identifiers
    ), "Mismatch between the number of deflines and number of sequences"

    seq_dict = dict(zip(identifiers, supports))

    return seq_dict


def compile_contig_depths(fasta_list: List[Path]) -> pl.LazyFrame:
    """
    Placeholder
    """

    seq_dicts = []

    for fasta in fasta_list:
        with open(fasta, "r", encoding="utf-8") as fasta_contents:
            try:
                fasta_lines = fasta_contents.readlines()
            except UnicodeDecodeError:
                print(
                    f"The FASTA at the following path could not be decoded to utf-8:\n{fasta}"
                )
                continue
            seq_dict = generate_seq_dict(fasta, fasta_lines, "_")
            if seq_dict is None:
                print(
                    f"The FASTA at the following path could not be decoded to utf-8:\n{fasta}"
                )
                continue
            seq_dicts.append(seq_dict)

    identifiers = [list(d.keys())[0] for d in seq_dicts]
    supports = [list(d.values())[0] for d in seq_dicts]

    depth_df = pl.LazyFrame(
        {"Amplicon-Sample-Contig": identifiers, "Depth of Coverage": supports}
    )

    return depth_df


def generate_gene_df(gene_bed: Path) -> pl.LazyFrame:
    """
    Placeholder
    """

    gene_df = (
        pl.scan_csv(gene_bed, separator="\t", has_header=False, new_columns=["Ref", "Start Position", "Stop Position", "NAME", "INDEX", "SENSE", "Gene"])
        .with_columns(
            pl.col("NAME")
            .str.replace("_RIGHT", "")
            .str.replace("_LEFT", "")
            .alias("Amplicon")
        )
        .unique(subset="Amplicon", keep="first", maintain_order=True)
        .drop(["Ref", "NAME", "Stop Position", "INDEX", "SENSE"])
        .join(
            pl.scan_csv(gene_bed, separator="\t", has_header=False, new_columns=["Ref", "Start Position", "Stop Position", "NAME", "INDEX", "SENSE", "Gene"])
            .with_columns(
                pl.col("NAME")
                .str.replace("_RIGHT", "")
                .str.replace("_LEFT", "")
                .alias("Amplicon")
            )
            .unique(subset="NAME", keep="last", maintain_order=True)
            .drop(["Ref", "Start Position", "NAME", "Gene", "INDEX", "SENSE"]),
            on="Amplicon", how="left"
        )
    )

    return gene_df


def aggregate_haplotype_df(long_contigs: pl.LazyFrame, gene_bed: Path) -> pl.LazyFrame:
    """
    Placeholder
    """

    # construct data frame mapping genes to amplicons
    gene_df = generate_gene_df(gene_bed)

    # construct long dataframe
    long_df = (
        long_contigs.select(
            [
                "REGION",
                "POS",
                "REF",
                "ALT",
                "REF_AA",
                "ALT_CODON",
                "ALT_AA",
                "Amplicon",
                "Sample ID",
                "Contig",
                "Amplicon-Sample-Contig",
            ]
        )
        .with_columns(
            pl.concat_str(
                [pl.col("REF"), pl.col("POS"), pl.col("ALT")], separator="-"
            ).alias("NUC_SUB")
        )
        .drop(["REF", "POS", "ALT"])
        .join(gene_df, how="left", on="Amplicon")
        .with_columns(
            pl.concat_str(
                [pl.col("REF_AA"), pl.col("ALT_CODON"), pl.col("ALT_AA")], separator="-"
            ).alias("AA_SUB")
        )
        .with_columns(
            pl.concat_str([pl.col("Gene"), pl.col("AA_SUB")], separator=": ").alias(
                "AA_SUB"
            )
        )
        .with_columns(
            (pl.col("ALT_AA").is_null() & pl.col("REF_AA").is_null()).alias("Noncoding")
        )
        .with_columns(
            (
                (pl.col("ALT_AA") == pl.col("REF_AA")) & (pl.col("Noncoding") == False)
            ).alias("Synonymous")
        )
        .with_columns(
            (
                (pl.col("ALT_AA") != pl.col("REF_AA")) & (pl.col("Noncoding") == False)
            ).alias("Nonsynonymous")
        )
        .drop(["REF_AA", "ALT_AA", "ALT_CODON"])
    )

    # aggregate into short dataframe
    short_df = (
        long_df.unique(subset="Amplicon-Sample-Contig", maintain_order=True)
        .select(["Amplicon-Sample-Contig"])
        .join(
            long_df.select(
                [
                    "Amplicon",
                    "Sample ID",
                    "Amplicon-Sample-Contig"
                ]
            ).with_columns(
                pl.concat_str(
                [pl.col("Amplicon"), pl.col("Sample ID")], separator="-"
            ).alias("Amplicon-Sample")
            ),
            on="Amplicon-Sample-Contig",
            how="left",
        )
        .join(gene_df, how="left", on="Amplicon")
        .join(
            long_df.select(["Amplicon-Sample-Contig", "NUC_SUB"])
            .group_by("Amplicon-Sample-Contig", maintain_order=True)
            .agg(pl.col("NUC_SUB"))
            .with_columns(
                pl.col("NUC_SUB").list.join(", ").alias("Nucleotide Substitutions")
            )
            .drop("NUC_SUB"),
            on="Amplicon-Sample-Contig",
            how="left",
        )
        .join(
            long_df.select(["Amplicon-Sample-Contig", "NUC_SUB"])
            .group_by("Amplicon-Sample-Contig", maintain_order=True)
            .agg(pl.col("NUC_SUB").count())
            .with_columns(pl.col("NUC_SUB").alias("Nuc Mut Count"))
            .drop("NUC_SUB"),
            on="Amplicon-Sample-Contig",
            how="left",
        )
        .join(
            long_df.select(["Amplicon-Sample-Contig", "AA_SUB", "Synonymous"])
            .filter(pl.col("Synonymous"))
            .group_by("Amplicon-Sample-Contig", maintain_order=True)
            .agg(pl.col("AA_SUB"))
            .with_columns(
                pl.col("AA_SUB").list.join(", ").alias("Synonymous Mutations")
            )
            .drop("AA_SUB"),
            on="Amplicon-Sample-Contig",
            how="left",
        )
        .join(
            long_df.select(["Amplicon-Sample-Contig", "AA_SUB", "Synonymous"])
            .filter(pl.col("Synonymous"))
            .group_by("Amplicon-Sample-Contig", maintain_order=True)
            .agg(pl.col("AA_SUB").count())
            .with_columns(pl.col("AA_SUB").alias("Syn count"))
            .drop("AA_SUB"),
            on="Amplicon-Sample-Contig",
            how="left",
        )
        .join(
            long_df.select(["Amplicon-Sample-Contig", "AA_SUB", "Nonsynonymous"])
            .filter(pl.col("Nonsynonymous"))
            .group_by("Amplicon-Sample-Contig", maintain_order=True)
            .agg(pl.col("AA_SUB"))
            .with_columns(
                pl.col("AA_SUB").list.join(", ").alias("Nonsynonymous Mutations")
            )
            .drop("AA_SUB"),
            on="Amplicon-Sample-Contig",
            how="left",
        )
        .join(
            long_df.select(["Amplicon-Sample-Contig", "AA_SUB", "Nonsynonymous"])
            .filter(pl.col("Nonsynonymous"))
            .group_by("Amplicon-Sample-Contig", maintain_order=True)
            .agg(pl.col("AA_SUB").count())
            .with_columns(pl.col("AA_SUB").alias("Nonsyn count"))
            .drop("AA_SUB"),
            on="Amplicon-Sample-Contig",
            how="left",
        )
    )

    return short_df


def compute_crude_NS_ratio(short_df: pl.LazyFrame) -> pl.LazyFrame:
    """
    Placeholder
    """

    new_df = short_df.with_columns(
        (
            (
                (
                    pl.col("Nonsyn count")
                    / (pl.col("Stop Position") - pl.col("Start Position"))
                )
                / (
                    pl.col("Syn count")
                    / (pl.col("Stop Position") - pl.col("Start Position"))
                )
            ).alias("Crude N/S Ratio")
        )
    )

    return new_df

In [319]:
ivar_list_result = construct_file_list(results_dir, ivar_pattern)

In [320]:
ivar_list = ivar_list_result.unwrap()

In [321]:
clean_ivar_list = [file for file in ivar_list if _check_cleanliness(file.parts)]

In [322]:
all_contigs_result = compile_data_with_io(clean_ivar_list)
all_contigs = all_contigs_result.unwrap()

In [323]:
short_df = aggregate_haplotype_df(all_contigs, gene_bed)

In [324]:
fasta_list_result = construct_file_list(results_dir, fasta_pattern)

In [325]:
clean_fasta_list = [
    file for file in fasta_list_result.unwrap() if _check_cleanliness(file.parts)
]

In [326]:
depth_df = compile_contig_depths(clean_fasta_list)

In [327]:
final_df = compute_crude_NS_ratio(
    short_df.join(depth_df, on="Amplicon-Sample-Contig", how="left")
)

In [335]:
(
    final_df.collect()
    .partition_by("Sample ID", "Amplicon", maintain_order=True)
)



[shape: (111, 15)
 ┌────────────┬────────────┬────────────┬────────────┬───┬────────────┬────────┬────────────┬───────┐
 │ Amplicon-S ┆ Amplicon   ┆ Sample ID  ┆ Amplicon-S ┆ … ┆ Nonsynonym ┆ Nonsyn ┆ Depth of   ┆ Crude │
 │ ample-Cont ┆ ---        ┆ ---        ┆ ample      ┆   ┆ ous        ┆ count  ┆ Coverage   ┆ N/S   │
 │ ig         ┆ str        ┆ str        ┆ ---        ┆   ┆ Mutations  ┆ ---    ┆ ---        ┆ Ratio │
 │ ---        ┆            ┆            ┆ str        ┆   ┆ ---        ┆ u32    ┆ i64        ┆ ---   │
 │ str        ┆            ┆            ┆            ┆   ┆ str        ┆        ┆            ┆ f64   │
 ╞════════════╪════════════╪════════════╪════════════╪═══╪════════════╪════════╪════════════╪═══════╡
 │ QIAseq_221 ┆ QIAseq_221 ┆ SRR2101960 ┆ QIAseq_221 ┆ … ┆ null       ┆ null   ┆ 21         ┆ null  │
 │ -SRR210196 ┆            ┆ 7          ┆ -SRR210196 ┆   ┆            ┆        ┆            ┆       │
 │ 07-contig1 ┆            ┆            ┆ 07         ┆   ┆      