In [40]:
import json
import re
from collections import Counter, defaultdict
from dataclasses import dataclass
from enum import Enum
from math import sqrt
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import pandas as pd
from scipy import stats

In [39]:
def split_var_name(name: str) -> Tuple[str, int, str]:
    """Split a variant name such as ``S94A`` or ``C-15X`` into its parts.

    The name is matched (case-insensitively) from its start as: one or more
    letters, then a run of digits/minus signs, then one or more letters,
    ``/`` or ``*``.

    Returns a ``(ref, pos, alt)`` tuple where ``pos`` is converted to ``int``
    (so promoter positions like ``-15`` come back negative).

    Raises ``ValueError`` if the name does not follow that pattern or the
    position part is not a valid integer.
    """
    match = re.match(r"([A-Z]+)([-0-9]+)([A-Z/\*]+)", name, re.I)
    if match is None:
        # Fail with a message that names the offending variant instead of an
        # opaque "'NoneType' object has no attribute 'groups'".
        raise ValueError(f"Cannot parse variant name: {name!r}")
    ref, pos, alt = match.groups()
    return ref, int(pos), alt

In [2]:
# --- Analysis parameters ---
CONF = 0.95  # default confidence level picked up by confidence_interval()
cov_threshold = 15  # minimum QC coverage for a run to be kept
minor_is_susceptible = False  # passed to Classifier: how to treat "r" calls
ignore_drugs = {"Ciprofloxacin", "Ofloxacin"}  # drugs dropped before classification
# Panel identifier -> display label (not used in the cells shown here;
# presumably for plots/tables further down -- verify).
panel_names = dict(
    previous="Walker et al.",
    who2021="WHO only",
    hunt2019="Mykrobe",
    hall2022="Combined",
)

# --- Input data ---
prediction_dir = Path("../../results/amr_predictions/")
csvs = list(prediction_dir.glob("*.csv"))
qc = pd.read_csv("../../results/qc.csv")
phenotypes = pd.read_csv("../../docs/samplesheet.csv", index_col="run")
who_results = pd.read_csv("../../docs/who-results.csv")

In [3]:
# Read every per-panel prediction table and stack them into one frame with a
# fresh 0..n-1 index.
frames = [pd.read_csv(csv_path) for csv_path in csvs]
df = pd.concat(frames).reset_index(drop=True)

In [4]:
# Runs whose QC coverage meets the configured threshold.
well_covered = qc[qc["coverage"] >= cov_threshold]
passed_qc = set(well_covered["run"])
len(passed_qc)

8160

In [5]:
# Keep only predictions for runs that passed QC.
df = df[df["run"].isin(passed_qc)]

In [6]:
class Prediction(Enum):
    """A drug-susceptibility call, keyed by its single-letter code."""

    Resistant = "R"
    Susceptible = "S"
    MinorResistance = "r"

    def __str__(self) -> str:
        # Render as the bare code ("R", "S" or "r") rather than "Prediction.X".
        return str(self.value)
    
class Classification(Enum):
    """Outcome of comparing a predicted call against the expected call.

    Values are the usual two-letter confusion-matrix abbreviations.
    """

    TruePositive = "TP"
    FalsePositive = "FP"
    TrueNegative = "TN"
    FalseNegative = "FN"

    def __str__(self) -> str:
        # Render as the bare abbreviation, e.g. "TP".
        return self.value


class Classifier:
    """Turn an (expected, predicted) pair of calls into a Classification.

    Resistance is treated as the positive condition. ``minor_is_susceptible``
    decides whether ``Prediction.MinorResistance`` is grouped with the
    susceptible calls or with the resistant ones (the default).
    """

    def __init__(
        self,
        minor_is_susceptible: bool = False,
    ):
        self.minor_is_susceptible = minor_is_susceptible
        self.susceptible = {Prediction.Susceptible}
        self.resistant = {Prediction.Resistant}
        # Minor-resistance calls go into exactly one of the two groups.
        minor_group = self.susceptible if self.minor_is_susceptible else self.resistant
        minor_group.add(Prediction.MinorResistance)

    def _is_susceptible(self, call: Prediction) -> bool:
        """Return True/False for a known call; raise for anything else."""
        if call in self.susceptible:
            return True
        if call in self.resistant:
            return False
        raise NotImplementedError(f"Don't know how to classify {call} calls yet")

    def from_predictions(
        self, y_true: Prediction, y_pred: Prediction
    ) -> Classification:
        """Classify ``y_pred`` against ``y_true``.

        Raises NotImplementedError if either call belongs to neither group;
        ``y_true`` is checked first.
        """
        expected_susceptible = self._is_susceptible(y_true)
        called_susceptible = self._is_susceptible(y_pred)

        # (expected susceptible?, called susceptible?) -> outcome
        outcomes = {
            (True, True): Classification.TrueNegative,
            (True, False): Classification.FalsePositive,
            (False, False): Classification.TruePositive,
            (False, True): Classification.FalseNegative,
        }
        return outcomes[(expected_susceptible, called_susceptible)]

# A (point estimate, lower bound, upper bound) triple. All three entries are
# None when the statistic's denominator is zero.
_Rate = Tuple[Optional[float], Optional[float], Optional[float]]


@dataclass
class ConfusionMatrix:
    """Counts of a binary confusion matrix plus the statistics derived from it.

    Every rate method returns ``(estimate, lower, upper)`` where the bounds
    come from ``confidence_interval``; when the denominator is zero the
    statistic is undefined and ``(None, None, None)`` is returned.
    """

    tp: int = 0
    tn: int = 0
    fp: int = 0
    fn: int = 0

    @staticmethod
    def _proportion(n_s: int, n_f: int) -> _Rate:
        """Return n_s / (n_s + n_f) with its confidence bounds."""
        try:
            # The division comes first so an empty denominator short-circuits
            # before the interval is computed.
            estimate = n_s / (n_s + n_f)
            lwr_bound, upr_bound = confidence_interval(n_s=n_s, n_f=n_f)
        except ZeroDivisionError:
            return None, None, None
        return estimate, lwr_bound, upr_bound

    def ravel(self) -> Tuple[int, int, int, int]:
        """Return the matrix as a flattened tuple.
        The order of return is TN, FP, FN, TP
        """
        return self.tn, self.fp, self.fn, self.tp

    def as_matrix(self) -> np.ndarray:
        """Returns a 2x2 matrix [[TN, FP], [FN, TP]]"""
        return np.array([[self.tn, self.fp], [self.fn, self.tp]])

    def num_positive(self) -> int:
        """Number of TPs and FNs - i.e. actual condition positive"""
        return self.tp + self.fn

    def num_negative(self) -> int:
        """Number of TNs and FPs - i.e. actual condition negative"""
        return self.tn + self.fp

    def ppv(self) -> _Rate:
        """Positive predictive value, TP / (TP + FP). Also known as precision"""
        return self._proportion(self.tp, self.fp)

    def npv(self) -> _Rate:
        """Negative predictive value, TN / (TN + FN)"""
        return self._proportion(self.tn, self.fn)

    def sensitivity(self) -> _Rate:
        """TP / (TP + FN). Also known as recall and true positive rate (TPR)"""
        return self._proportion(self.tp, self.fn)

    def specificity(self) -> _Rate:
        """TN / (TN + FP). Also known as selectivity and true negative rate (TNR)"""
        return self._proportion(self.tn, self.fp)

    def fnr(self) -> _Rate:
        """FN / (TP + FN). False negative rate or VME (very major error rate)"""
        return self._proportion(self.fn, self.tp)

    def fpr(self) -> _Rate:
        """FP / (FP + TN). False positive rate or ME (major error rate)"""
        return self._proportion(self.fp, self.tn)

    def f_score(self, beta: float = 1.0) -> Optional[float]:
        """Weighted harmonic mean of precision and recall.
        When beta is set to 0, you get precision. When beta is set to 1, you get the
        unweighted F-score which is the harmonic mean of precision and recall. Setting
        beta to 2 weighs recall twice as much as precision. Setting beta to 0.5 weighs
        precision twice as much as recall.

        Returns None when precision or recall is undefined, or when the
        weighted sum in the denominator is zero.
        """
        # Computed straight from the counts: only the point estimates are
        # needed here, not the confidence bounds.
        try:
            ppv = self.tp / (self.tp + self.fp)
            tpr = self.tp / self.num_positive()
        except ZeroDivisionError:
            return None
        beta2 = beta ** 2
        denominator = (beta2 * ppv) + tpr
        if denominator == 0:
            return None

        return ((beta2 + 1) * ppv * tpr) / denominator

    @staticmethod
    def from_series(s: pd.Series) -> "ConfusionMatrix":
        """Build a matrix from a Series keyed by "TP"/"FP"/"FN"/"TN".

        Missing keys count as 0. Values are converted to plain ``int`` so that
        a zero denominator raises ZeroDivisionError (numpy integers would give
        NaN instead and bypass the None handling of the rate methods).
        """
        tp = int(s.get("TP", 0))
        fp = int(s.get("FP", 0))
        fn = int(s.get("FN", 0))
        tn = int(s.get("TN", 0))
        return ConfusionMatrix(tp=tp, fn=fn, fp=fp, tn=tn)


def confidence_interval(n_s: int, n_f: int, conf: float = CONF) -> Tuple[float, float]:
    """Calculate the Wilson score interval for a binomial proportion.

    Equation taken from
    https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Wilson_score_interval

    n_s: Number of successes or, in the case of confusion matrix statistics, the numerator
    n_f: Number of failures or, in the case of confusion matrix statistics, the denominator minus the numerator
    conf: the confidence level. i.e. 0.95 is 95% confidence

    Returns the (lower, upper) bounds of the interval.

    Raises ValueError if ``conf`` is not strictly between 0 and 1, and
    ZeroDivisionError if there are no observations (``n_s + n_f == 0``); the
    ConfusionMatrix rate methods catch the latter to report an undefined
    statistic.

    Uses ``sqrt`` (math) and ``stats`` (scipy), which must be imported in the
    notebook's import cell.
    """
    if not 0.0 < conf < 1.0:
        raise ValueError(f"conf must be strictly between 0 and 1, got {conf!r}")
    n = n_f + n_s
    if n == 0:
        # Raised explicitly so numpy integer inputs behave like Python ints
        # instead of silently producing NaN bounds.
        raise ZeroDivisionError("cannot compute a confidence interval from zero observations")
    z = stats.norm.ppf(1 - (1 - conf) / 2)  # two-sided
    z2 = z ** 2
    nz2 = n + z2
    centre = (n_s + (0.5 * z2)) / nz2
    half_width = (z / nz2) * sqrt(((n_s * n_f) / n) + (z2 / 4))
    return centre - half_width, centre + half_width


def round_up_to_base(x, base=10):
    """Round ``x`` up to the nearest multiple of ``base``, returned as an int."""
    shortfall = (base - x) % base  # distance from x to the next multiple
    return int(x + shortfall)


def round_down_to_base(x, base=10):
    """Round ``x`` down to the nearest multiple of ``base``, returned as an int."""
    excess = x % base  # distance from the previous multiple up to x
    return int(x - excess)

In [7]:
# Drop predictions for the drugs listed in ignore_drugs.
df = df[~df["drug"].isin(ignore_drugs)]

In [8]:
# Classify each prediction against the phenotype recorded for that run/drug.
# Rows with no recorded phenotype get None.
classifier = Classifier(minor_is_susceptible=minor_is_susceptible)
classifications = []
for _, row in df.iterrows():
    # Phenotype columns are looked up by the lower-cased drug name.
    phenotype = phenotypes.at[row["run"], row["drug"].lower()]
    if pd.isna(phenotype):
        classifications.append(None)
        continue
    outcome = classifier.from_predictions(
        Prediction(phenotype), Prediction(row["prediction"])
    )
    classifications.append(str(outcome))

In [9]:
# `df` is the result of earlier filtering, so setting a column on it in place
# can trigger pandas' SettingWithCopyWarning. `assign` returns a new frame
# with the column added and makes the cell safe to re-run.
df = df.assign(classification=classifications)

In [10]:
# Show the predictions together with their classification column.
df

Unnamed: 0,run,biosample,bioproject,panel,drug,prediction,classification
0,ERR036186,SAMEA897802,PRJEB2358,who2021,Delamanid,S,
1,ERR036186,SAMEA897802,PRJEB2358,who2021,Kanamycin,S,
2,ERR036186,SAMEA897802,PRJEB2358,who2021,Amikacin,S,
3,ERR036186,SAMEA897802,PRJEB2358,who2021,Ethambutol,S,TN
4,ERR036186,SAMEA897802,PRJEB2358,who2021,Ethionamide,S,
...,...,...,...,...,...,...,...
324499,SRR7131298,SAMN09090624,PRJNA438921,hall2022,Levofloxacin,S,
324500,SRR7131298,SAMN09090624,PRJNA438921,hall2022,Pyrazinamide,S,
324501,SRR7131298,SAMN09090624,PRJNA438921,hall2022,Linezolid,S,
324502,SRR7131298,SAMN09090624,PRJNA438921,hall2022,Rifampicin,S,TN


Now, let's look at cases where, for Isoniazid, the WHO catalogue calls S and the default Mykrobe catalogue calls R - i.e., stuff "missing" from the WHO catalogue

In [72]:
drug = "Isoniazid"
# (run, drug) pairs the WHO panel got wrong as false negatives. Note the
# query is not restricted to `drug`, so every drug is included.
who_fn_rows = df.query("panel == 'who2021' and classification == 'FN'")
who_fns = set(zip(who_fn_rows["run"], who_fn_rows["drug"]))
len(who_fns)

1926

In [74]:
# WHO false negatives that the hunt2019 panel did NOT classify as FN.
hunt_not_fn = df.query("panel == 'hunt2019' and classification != 'FN'")
mykrobe_tps = {
    pair for pair in zip(hunt_not_fn["run"], hunt_not_fn["drug"]) if pair in who_fns
}
len(mykrobe_tps)

440

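This means 382/721 (~53%) of the WHO INH FNs were called TP by Mykrobe's default catalogue. (Note: these figures do not match the outputs shown above, which give 440 of 1,926 (~23%) because the cells are not restricted to Isoniazid; the 382/721 figures appear to come from an earlier, Isoniazid-only run and should be re-checked.)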

Let's get a list of all of these mutations

In [76]:
# Map run accession (file name up to the first ".") -> hunt2019 report path.
mykrobe_jsons = {
    report.name.split(".")[0]: report
    for report in Path("../../results/amr_predictions/hunt2019/").rglob("*.mykrobe.json")
}
len(mykrobe_jsons)

8611

In [91]:
# Map run accession (file name up to the first ".") -> who2021 report path.
who_jsons = {
    report.name.split(".")[0]: report
    for report in Path("../../results/amr_predictions/who2021/").rglob("*.mykrobe.json")
}
len(who_jsons)

8611

In [99]:
# For every (run, drug) in mykrobe_tps, collect the variants listed under
# "called_by" in that run's hunt2019 report.
all_variants = []
for run, drug in mykrobe_tps:
    with open(mykrobe_jsons[run]) as handle:
        report = json.load(handle)
    for m in report[run]["susceptibility"][drug]["called_by"]:
        # Keep only the part of the variant name before its last "-".
        all_variants.append((m.rsplit("-", maxsplit=1)[0], drug))
        if "rrs" in m:
            print(run, m, drug)

ERR1035350 rrs_A514X-A1472359T Streptomycin
ERR181861 rrs_C513X-C1472358T Streptomycin
SRR6831774 rrs_A1401X-A1473246G Kanamycin
ERR190365 rrs_C513X-C1472358T Streptomycin
ERR181907 rrs_C513X-C1472358T Streptomycin
SRR6831774 rrs_A1401X-A1473246G Amikacin
SRR6831774 rrs_A1401X-A1473246G Capreomycin
ERR163962 rrs_A514X-A1472359T Streptomycin
ERR1035344 rrs_A514X-A1472359T Streptomycin
ERR245836 rrs_A514X-A1472359T Streptomycin
SRR1723458 rrs_C517X-C1472362T Streptomycin
ERR036188 rrs_C513X-C1472358T Streptomycin
ERR181930 rrs_A514X-A1472359T Streptomycin
ERR163958 rrs_C513X-C1472358T Streptomycin
ERR176629 rrs_C513X-C1472358T Streptomycin


In [80]:
# Tally each (mutation, drug) pair, most frequent first. Only the top 1000
# distinct pairs are kept.
variant_counts = Counter(all_variants)
c = variant_counts.most_common(1000)
c

[(('fabG1_C-15X', 'Isoniazid'), 264),
 (('inhA_S94A', 'Isoniazid'), 60),
 (('inhA_I194T', 'Isoniazid'), 40),
 (('fabG1_CTG607CTA', 'Isoniazid'), 26),
 (('ahpC_G-48A', 'Isoniazid'), 17),
 (('fabG1_T-8X', 'Isoniazid'), 14),
 (('katG_W191G', 'Isoniazid'), 10),
 (('inhA_I21T', 'Isoniazid'), 10),
 (('ahpC_C-57T', 'Isoniazid'), 9),
 (('pncA_V139M', 'Pyrazinamide'), 7),
 (('rrs_C513X', 'Streptomycin'), 6),
 (('katG_D142G', 'Isoniazid'), 6),
 (('rrs_A514X', 'Streptomycin'), 5),
 (('embA_C-16G', 'Ethambutol'), 5),
 (('ahpC_C-72T', 'Isoniazid'), 5),
 (('katG_S315I', 'Isoniazid'), 5),
 (('fabG1_G-17T', 'Isoniazid'), 4),
 (('katG_P232R', 'Isoniazid'), 4),
 (('gid_A80P', 'Streptomycin'), 4),
 (('katG_W191R', 'Isoniazid'), 4),
 (('katG_S315G', 'Isoniazid'), 4),
 (('katG_A109V', 'Isoniazid'), 3),
 (('gid_S149R', 'Streptomycin'), 2),
 (('pncA_H57P', 'Pyrazinamide'), 2),
 (('gid_L26F', 'Streptomycin'), 2),
 (('katG_S315R', 'Isoniazid'), 2),
 (('gid_S136P', 'Streptomycin'), 2),
 (('katG_L141F', 'Isoniaz

In [83]:
# Write the (mutation, drug) tallies as a tab-separated table.
with open("../../docs/who_fns.tsv", "w") as out:
    print("\t".join(["mutation", "drug", "num_fn"]), file=out)
    for (mut, drug), count in c:
        print(mut, drug, count, sep="\t", file=out)

Get all RIF FPs from WHO

In [92]:
# For every WHO-panel Rifampicin false positive, record each variant listed
# under "called_by" together with the call that was made ("R" or "r").
rif_fps = []
who_rif_fp_rows = df.query(
    "panel == 'who2021' and classification == 'FP' and drug =='Rifampicin'"
)
for _, row in who_rif_fp_rows.iterrows():
    run = row["run"]
    pred = row["prediction"]
    with open(who_jsons[run]) as handle:
        report = json.load(handle)
    for m in report[run]["susceptibility"]["Rifampicin"]["called_by"]:
        # Keep only the part of the variant name before its last "-".
        rif_fps.append((m.rsplit("-", maxsplit=1)[0], pred))

In [93]:
# Inspect the collected (mutation, prediction) pairs.
rif_fps

[('rpoB_S450L', 'R'),
 ('rpoB_T427P', 'r'),
 ('rpoB_T427P', 'r'),
 ('rpoB_H445D', 'R'),
 ('rpoB_L452P', 'r'),
 ('rpoB_H445N', 'R'),
 ('rpoB_H445D', 'R'),
 ('rpoB_H445N', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_L452P', 'R'),
 ('rpoB_L452P', 'R'),
 ('rpoB_H445S', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_H445N', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_L452P', 'R'),
 ('rpoB_H445Y', 'r'),
 ('rpoB_S450L', 'r'),
 ('rpoB_H445N', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_H445N', 'R'),
 ('rpoB_H445N', 'R'),
 ('rpoB_H445N', 'R'),
 ('rpoB_H445N', 'R'),
 ('rpoB_H445N', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_S450L', 'r'),
 ('rpoB_D435Y', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_H445D', 'R'),
 ('rpoB_L452P', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_S450L', 'R'),
 ('rpoB_H445N', 'R'),
 ('rpoB_H445N', 'R'),
 ('rpoB_L430P', 'R'),
 ('rpoB_T427P', 'r'),
 ('rpoB_H4

In [94]:
# Write the Rifampicin false-positive variants as a tab-separated table
# (the path is relative to the notebook's working directory).
with open("who_rif_fps.tsv", "w") as out:
    print("mutation\tprediction", file=out)
    for mut, pred in rif_fps:
        print(mut, pred, sep="\t", file=out)

Load the WHO catalogue and see if these variants exist in the catalogue

In [48]:
# Index the WHO panel by (first column, variant position). Each entry stores
# (ref, alt, last column as int, fourth column); the lookup cell treats the
# first column as the gene and the fourth as the drug name.
catalogue = defaultdict(list)
with open("../../docs/who-panel.tsv") as handle:
    next(handle)  # skip the header row
    for raw_line in handle:
        columns = raw_line.rstrip().split("\t")
        ref, pos, alt = split_var_name(columns[1])
        entry = (ref, alt, int(columns[-1]), columns[3])
        catalogue[(columns[0], pos)].append(entry)
        

In [54]:
# Check each tallied variant against the WHO catalogue.
# `c` holds ((mutation, drug), count) items, so the key is unpacked here and
# each variant is compared against its OWN drug, not a leftover global `drug`.
n_not_in_cat = 0
for (mut, mut_drug), count in c:
    # Split on the first underscore only: "<gene>_<variant>".
    gene, var = mut.split("_", maxsplit=1)
    ref, pos, alt = split_var_name(var)
    cat_vars = catalogue.get((gene, pos))
    if cat_vars is None:
        # Nothing catalogued at this gene/position at all.
        print(f"{mut} not in catalogue")
        n_not_in_cat += count
        continue
    for v in cat_vars:
        if (v[0], v[1]) != (ref, alt):
            continue  # same position, different ref/alt
        if v[-1] == mut_drug.lower():
            print(f"[X] {mut} exactly found in catalogue: {v}")
        else:
            print(f"{mut} exactly found for different drug: {v}")

fabG1_C-15X not in catalogue
inhA_S94A exactly found for different drug: ('S', 'A', 2, 'ethionamide')
[X] inhA_S94A exactly found in catalogue: ('S', 'A', 3, 'isoniazid')
inhA_I194T exactly found for different drug: ('I', 'T', 3, 'ethionamide')
[X] inhA_I194T exactly found in catalogue: ('I', 'T', 3, 'isoniazid')
fabG1_CTG607CTA not in catalogue
[X] ahpC_G-48A exactly found in catalogue: ('G', 'A', 3, 'isoniazid')
fabG1_T-8X not in catalogue
[X] katG_W191G exactly found in catalogue: ('W', 'G', 3, 'isoniazid')
inhA_I21T exactly found for different drug: ('I', 'T', 3, 'ethionamide')
[X] inhA_I21T exactly found in catalogue: ('I', 'T', 3, 'isoniazid')
[X] ahpC_C-57T exactly found in catalogue: ('C', 'T', 3, 'isoniazid')
[X] katG_D142G exactly found in catalogue: ('D', 'G', 3, 'isoniazid')
[X] katG_S315I exactly found in catalogue: ('S', 'I', 3, 'isoniazid')
[X] ahpC_C-72T exactly found in catalogue: ('C', 'T', 3, 'isoniazid')
[X] katG_S315G exactly found in catalogue: ('S', 'G', 3, 'ison

In [51]:
# Summed counts of the variants whose gene/position had no catalogue entry.
n_not_in_cat

309