In [2]:
from enum import Enum
from typing import Optional, Tuple
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
import re

In [3]:
# Materialise as a list: a `map` object is a one-shot iterator, so re-running the
# processing cell (or iterating a second time) would silently see no files at all.
table_files = [Path(p) for p in snakemake.input.tables]
# table_files = sorted(Path("../../results/mutation_concordance/mykrobe/").rglob("*.csv"))

In [4]:
class Calls(Enum):
    """Call values that may appear in the Illumina/Nanopore call columns.

    ``Calls(value)`` parses a raw string such as "REF"; a string that is not one
    of the values below raises ValueError.
    """

    Ref = "REF"
    Alt = "ALT"
    Null = "NULL"  # presumably "no call made" — confirm against the caller's tables
    Minor = "HET"  # raw value is "HET"; referred to as Minor in code
    Filtered = "FILT"  # pairs containing this are skipped (Classifier.classify returns None)

class Classification(Enum):
    """Confusion-matrix category for one (truth, prediction) pair of calls.

    Members compare by their short code, so sorting yields FN, FP, TN, TP.
    """

    TruePositive = "TP"
    FalsePositive = "FP"
    TrueNegative = "TN"
    FalseNegative = "FN"

    def __str__(self) -> str:
        # Render as the short code (e.g. "FP") instead of "Classification.FalsePositive".
        return f"{self.value}"

    def __lt__(self, other):
        mine, theirs = str(self), str(other)
        return mine < theirs

    @staticmethod
    def from_pair(y: Calls, y_hat: Calls) -> "Classification":
        """Classify a truth call ``y`` against a predicted call ``y_hat``.

        Only the twelve pairs listed below are defined. Any other pair (e.g. a
        HET or FILT prediction, or a FILT truth) raises KeyError.
        """
        tp, fp = Classification.TruePositive, Classification.FalsePositive
        tn, fn = Classification.TrueNegative, Classification.FalseNegative
        lookup = {
            # truth: REF
            (Calls.Ref, Calls.Ref): tn,
            (Calls.Ref, Calls.Alt): fp,
            (Calls.Ref, Calls.Null): fn,
            # truth: ALT
            (Calls.Alt, Calls.Alt): tp,
            (Calls.Alt, Calls.Ref): fn,
            (Calls.Alt, Calls.Null): fn,
            # truth: HET (minor allele)
            (Calls.Minor, Calls.Alt): tp,
            (Calls.Minor, Calls.Ref): fn,
            (Calls.Minor, Calls.Null): fn,
            # truth: NULL
            (Calls.Null, Calls.Alt): fp,
            (Calls.Null, Calls.Ref): fn,
            (Calls.Null, Calls.Null): tn,
        }
        return lookup[y, y_hat]


class Classifier:
    """Turn raw (Illumina, Nanopore) call strings into a Classification.

    HET and NULL calls may be remapped to another call value before
    classification, e.g. ``Classifier(treat_minor_as="REF", treat_null_as="FILT")``.
    """

    def __init__(self, treat_minor_as: str = "HET", treat_null_as: str = "NULL"):
        """
        Args:
            treat_minor_as: call value substituted for every "HET" call.
            treat_null_as: call value substituted for every "NULL" call.

        Raises:
            ValueError: if either option is not a valid ``Calls`` value. This is
                checked here so a misconfigured option fails immediately, rather
                than on the first HET/NULL row, or never if no such row occurs.
        """
        for option in (treat_minor_as, treat_null_as):
            Calls(option)  # raises ValueError for an unknown call string
        self.minor = treat_minor_as
        self.null = treat_null_as

    def convert(self, call: str) -> str:
        """Apply the HET/NULL remapping; any other call string is returned unchanged."""
        return {
            "HET": self.minor,
            "NULL": self.null
        }.get(call, call)

    def classify(self, illumina_call: str, nanopore_call: str) -> Optional[Classification]:
        """Classify a Nanopore call against the Illumina (truth) call.

        Returns None when either call is FILT after remapping. Raises ValueError
        for an unknown call string. Raises KeyError for pairs that
        ``Classification.from_pair`` does not map.
        """
        truth = Calls(self.convert(illumina_call))
        prediction = Calls(self.convert(nanopore_call))
        if Calls.Filtered in (truth, prediction):
            return None
        return Classification.from_pair(truth, prediction)


@dataclass
class ConfusionMatrix:
    """Counts of true/false positives/negatives plus rates derived from them.

    A rate whose denominator is zero (e.g. precision with no positive
    predictions) is undefined and is returned as NaN instead of raising
    ZeroDivisionError. This way a sample with no relevant calls does not abort
    a summary.
    """

    tp: int = 0
    tn: int = 0
    fp: int = 0
    fn: int = 0

    @staticmethod
    def _ratio(numerator: int, denominator: int) -> float:
        """Return numerator / denominator, or NaN when the denominator is 0."""
        return numerator / denominator if denominator else float("nan")

    def ravel(self) -> Tuple[int, int, int, int]:
        """Return the matrix as a flattened tuple.
        The order of return is TN, FP, FN, TP
        """
        return self.tn, self.fp, self.fn, self.tp

    def precision(self) -> float:
        """Also known as positive predictive value (PPV). NaN if tp + fp == 0."""
        return self._ratio(self.tp, self.tp + self.fp)

    def recall(self) -> float:
        """Also known as true positive rate (TPR). NaN if tp + fn == 0."""
        return self._ratio(self.tp, self.tp + self.fn)

    def fnr(self) -> float:
        """False negative rate (1 - recall). NaN if tp + fn == 0."""
        return 1 - self.recall()

    def fpr(self) -> float:
        """False positive rate. NaN if fp + tn == 0."""
        return self._ratio(self.fp, self.fp + self.tn)

    @staticmethod
    def from_counter(c: Counter) -> "ConfusionMatrix":
        """Build a matrix from per-Classification counts.

        Missing keys count as 0, so a plain dict works as well as a Counter.
        """
        return ConfusionMatrix(
            tp=c.get(Classification.TruePositive, 0),
            tn=c.get(Classification.TrueNegative, 0),
            fp=c.get(Classification.FalsePositive, 0),
            fn=c.get(Classification.FalseNegative, 0)
        )

In [5]:
# Interactive alternative when running outside Snakemake:
# classifier = Classifier(treat_minor_as="REF", treat_null_as="FILT")
minor_policy = snakemake.params.treat_minor_as
null_policy = snakemake.params.treat_null_as
classifier = Classifier(treat_minor_as=minor_policy, treat_null_as=null_policy)

In [13]:
outstream = open(snakemake.output[0], "w")  # closed in the final cell
data = {}
# Splits the tail of a mutation string around its (possibly negative) numeric
# position, leaving REF and ALT. NOTE(review): assumes the text after the last "-"
# contains exactly one run of digits (i.e. "<REF><pos><ALT>") — confirm against the
# mutation tables, otherwise the 2-way unpack below fails.
rgx = re.compile(r"-?\d+")
indel_counts = {
    Classification.FalseNegative: 0,
    Classification.FalsePositive: 0,
    Classification.TruePositive: 0,
    Classification.TrueNegative: 0
}
for file in table_files:
    # Sample name is the file name up to its first ".".
    sample = file.name.split(".")[0]
    with open(file) as f:
        # Skip the header. The default avoids StopIteration on a zero-byte table,
        # which then simply yields an all-zero ConfusionMatrix.
        next(f, None)
        c = Counter()
        for row in f:
            row = row.rstrip()
            if not row:
                continue  # tolerate blank lines instead of failing the 3-way unpack
            mut, illumina_call, nanopore_call = row.split(",")
            ref, alt = rgx.split(mut.split("-")[-1])
            is_indel = len(ref) != len(alt)
            clf = classifier.classify(illumina_call, nanopore_call)
            if clf is None:
                continue  # at least one call is FILT after remapping
            if clf in (Classification.FalseNegative, Classification.FalsePositive):
                # One line per disagreement: "<mutation>\t<sample>\t<FN|FP>".
                print(f"{mut}\t{sample}\t{str(clf)}", file=outstream)
            if is_indel:
                indel_counts[clf] += 1
            c[clf] += 1
        data[sample] = ConfusionMatrix.from_counter(c)

In [None]:
# Total misclassifications across every sample's confusion matrix.
fps = sum(cm.fp for cm in data.values())
fns = sum(cm.fn for cm in data.values())

In [None]:
print(f"Total of {fps} FPs and {fns} FNs across all samples", file=outstream)
indel_summary = (
    f"{indel_counts[Classification.FalsePositive]} FPs, "
    f"{indel_counts[Classification.FalseNegative]} FNs, and "
    f"{indel_counts[Classification.TruePositive]} TPs are indels"
)
# Previously this line went only to stdout, so the report lacked the indel breakdown
# that sits alongside the totals. Write it to the report too, and keep echoing it to
# stdout so it still appears in the executed notebook's output.
print(indel_summary, file=outstream)
print(indel_summary)
outstream.close()