In [1]:
# these are modules specific to this project, but I didn't make an actual package
from model import GNNModule, GNNConfig
from data import DataConfig, GraphDataModule

from pathlib import Path


import torch
import itertools as it
import polars as pl
import polars.selectors as cs
from torch_geometric.data import HeteroData

In [2]:
%load_ext watermark
%watermark -vp torch,polars,torch_geometric

Python implementation: CPython
Python version       : 3.10.14
IPython version      : 8.24.0

torch          : 2.2.2
polars         : 0.20.26
torch_geometric: 2.5.2



In [3]:
# Column names for the taxonomic ranks, in the order they appear in the
# ";"-separated `gtdbtk_classification` strings (domain first, species last).
host_taxa_levels = [
    f"host_{level}"
    for level in ("domain", "phylum", "class", "order", "family", "genus", "species")
]

# One row per distinct host record, with the lineage string additionally
# expanded into one `host_<rank>` column per taxonomic rank.
host_info = (
    pl.read_csv("../data/supplementary_table_8.tsv", separator="\t")
    .rename({"host_species": "host_name", "gtdbtk_classification": "host_taxonomy"})
    .with_columns(
        # split the lineage on ";" and name each piece after its rank
        pl.col("host_taxonomy")
        .str.split(";")
        .list.to_struct(fields=host_taxa_levels)
        .alias("host_lineage")
    )
    .unnest("host_lineage")
    # keep only the host-related columns before de-duplicating
    .select(cs.starts_with("host"))
    .unique()
)

host_info

host_accession,host_name,host_label,host_taxonomy,host_domain,host_phylum,host_class,host_order,host_family,host_genus,host_species
str,str,str,str,str,str,str,str,str,str,str
"""GCA_005222125.1""","""Spiroplasma melliferum""","""Spiroplasma melliferum""","""d__Bacteria;p__Bacillota;c__Ba…","""d__Bacteria""","""p__Bacillota""","""c__Bacilli""","""o__Mycoplasmatales""","""f__Mycoplasmataceae""","""g__Spiroplasma""","""s__Spiroplasma melliferum"""
"""GCA_900475035.1""","""Streptococcus""","""Streptococcus sp.""","""d__Bacteria;p__Bacillota;c__Ba…","""d__Bacteria""","""p__Bacillota""","""c__Bacilli""","""o__Lactobacillales""","""f__Streptococcaceae""","""g__Streptococcus""","""s__Streptococcus pyogenes"""
"""GCA_003987795.1""","""Curvibacter sp.""","""Curvibacter sp.""","""d__Bacteria;p__Pseudomonadota;…","""d__Bacteria""","""p__Pseudomonadota""","""c__Gammaproteobacteria""","""o__Burkholderiales""","""f__Burkholderiaceae_B""","""g__Curvibacter""","""s__Curvibacter sp003987795"""
"""GCA_003790525.1""","""Escherichia coli 4s""","""Escherichia coli""","""d__Bacteria;p__Pseudomonadota;…","""d__Bacteria""","""p__Pseudomonadota""","""c__Gammaproteobacteria""","""o__Enterobacterales""","""f__Enterobacteriaceae""","""g__Escherichia""","""s__Escherichia coli"""
"""GCA_000220485.1""","""Klebsiella pneumoniae 51503""","""Klebsiella pneumoniae""","""d__Bacteria;p__Pseudomonadota;…","""d__Bacteria""","""p__Pseudomonadota""","""c__Gammaproteobacteria""","""o__Enterobacterales""","""f__Enterobacteriaceae""","""g__Klebsiella""","""s__Klebsiella pneumoniae"""
…,…,…,…,…,…,…,…,…,…,…
"""GCA_008479505.2""","""Salmonella enterica subsp. ent…","""Salmonella enterica""","""d__Bacteria;p__Pseudomonadota;…","""d__Bacteria""","""p__Pseudomonadota""","""c__Gammaproteobacteria""","""o__Enterobacterales""","""f__Enterobacteriaceae""","""g__Salmonella""","""s__Salmonella enterica"""
"""GCA_000025625.1""","""Natrialba magadii""","""Natrialba magadii""","""d__Archaea;p__Halobacteriota;c…","""d__Archaea""","""p__Halobacteriota""","""c__Halobacteria""","""o__Halobacteriales""","""f__Natrialbaceae""","""g__Natrialba""","""s__Natrialba magadii"""
"""GCA_024229555.1""","""Alteromonas sp.""","""Alteromonas sp.""","""d__Bacteria;p__Pseudomonadota;…","""d__Bacteria""","""p__Pseudomonadota""","""c__Gammaproteobacteria""","""o__Enterobacterales_A""","""f__Alteromonadaceae""","""g__Alteromonas""","""s__Alteromonas sp002691625"""
"""GCA_002079225.1""","""Escherichia coli STEC O179""","""Escherichia coli""","""d__Bacteria;p__Pseudomonadota;…","""d__Bacteria""","""p__Pseudomonadota""","""c__Gammaproteobacteria""","""o__Enterobacterales""","""f__Enterobacteriaceae""","""g__Escherichia""","""s__Escherichia coli"""


In [4]:
def load_everything(knowledge_graph_file: Path, ckptfile: Path) -> tuple[GNNModule, GraphDataModule]:
    """Rebuild a trained model and its data module from a checkpoint.

    The checkpoint must contain the keys ``"datamodule_hyper_parameters"``,
    ``"hyper_parameters"`` and ``"state_dict"``. The ``"file"`` entry of the
    stored data-module hyperparameters is overridden with
    ``knowledge_graph_file`` before the data config is validated.

    Args:
        knowledge_graph_file: Path of the knowledge graph to hand to the data module.
        ckptfile: Path of the checkpoint to read (tensors are mapped to CPU).

    Returns:
        ``(model, datamodule)`` with the checkpoint's state dict loaded into the model.
    """
    # NOTE(review): torch.load unpickles the file; only open checkpoints from a
    # trusted source.
    checkpoint = torch.load(ckptfile, map_location="cpu")

    # point the stored data config at the requested knowledge graph
    data_hparams = checkpoint["datamodule_hyper_parameters"]
    data_hparams["file"] = knowledge_graph_file
    datamodule = GraphDataModule(DataConfig.model_validate(data_hparams))

    # the model is constructed from its config plus the graph metadata
    model = GNNModule(
        GNNConfig.model_validate(checkpoint["hyper_parameters"]),
        datamodule.data.metadata(),
    )
    model.load_state_dict(checkpoint["state_dict"])
    return model, datamodule

def edge_index_to_set_of_tuples(edge_index: torch.Tensor) -> set[tuple[int, int]]:
    """Convert a ``(2, E)`` edge index into a set of ``(source, target)`` int pairs.

    Args:
        edge_index: 2-D tensor whose first dimension has size 2; column ``j``
            holds the two endpoints of edge ``j``.

    Returns:
        The set of distinct ``(edge_index[0, j], edge_index[1, j])`` pairs as
        Python ints. Duplicate columns collapse into a single entry.

    Raises:
        ValueError: if ``edge_index`` is not 2-D with a first dimension of size 2.
    """
    # Check the rank too: a 1-D tensor of length 2 also has shape[0] == 2 but
    # cannot be split into (source, target) columns.
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        raise ValueError("edge_index must have shape (2, E)")

    # A single tolist() call converts every element at once instead of going
    # through one 0-d tensor per element.
    return {(int(src), int(dst)) for src, dst in edge_index.t().tolist()}

def enumerate_all_virus_host_pairs(dataset: HeteroData) -> tuple[torch.Tensor, torch.Tensor]:
    """Enumerate every (virus, host) index pair and flag the known ``infects`` edges.

    The number of viruses and hosts is taken as ``max index + 1`` along each
    row of ``dataset["infects"].edge_index``, so nodes whose index is larger
    than any index appearing in an edge are not enumerated.

    Args:
        dataset: Graph whose ``"infects"`` edge store exposes a ``(2, E)``
            ``edge_index`` (row 0 is used as virus indices, row 1 as host indices).

    Returns:
        ``(full_edge_index, edge_label)`` where ``full_edge_index`` is a
        contiguous int64 tensor of shape ``(2, n_viruses * n_hosts)`` ordered
        with the virus index varying slowest, and ``edge_label`` is a bool
        tensor of the same length that is ``True`` where the pair is present in
        ``dataset["infects"].edge_index``.

    Raises:
        ValueError: if the edge index is not of shape ``(2, E)``.
    """
    true_edge_index = dataset["infects"].edge_index
    if true_edge_index.ndim != 2 or true_edge_index.shape[0] != 2:
        raise ValueError("edge_index must have shape (2, E)")

    # results are built on the CPU with int64 indices, regardless of where the
    # graph currently lives
    true_edge_index = true_edge_index.to(device="cpu", dtype=torch.long)
    n_viruses, n_hosts = (int(n) for n in true_edge_index.amax(1) + 1)

    # Row-major cartesian product: each virus index is repeated once per host,
    # and the host indices cycle within each virus.
    virus_ids = torch.arange(n_viruses).repeat_interleave(n_hosts)
    host_ids = torch.arange(n_hosts).repeat(n_viruses)
    full_edge_index = torch.stack([virus_ids, host_ids])

    # Mark known edges in a dense (virus, host) grid; flattening it row-major
    # yields labels in the same order as the enumerated pairs.
    is_true_edge = torch.zeros((n_viruses, n_hosts), dtype=torch.bool)
    is_true_edge[true_edge_index[0], true_edge_index[1]] = True
    edge_label = is_true_edge.flatten()

    return full_edge_index, edge_label

In [5]:
# name shared by the knowledge-graph file and the trained checkpoint
dataset_name = "pst-large"

model, dataset = load_everything(
    knowledge_graph_file=Path(f"../data/knowledge_graphs/{dataset_name}_knowledge_graph.pt"),
    ckptfile=Path(f"trained_models/{dataset_name}_trained_model.ckpt"),
)

In [6]:
# distinct index pairs taken from the columns of the "infects" edge index
true_edges = edge_index_to_set_of_tuples(dataset.data["infects"].edge_index)
len(true_edges)

31484

In [7]:
# We need to create a fully-enumerated edge index for all pairs of viruses and
# hosts in the knowledge graph. This will be used to evaluate the model.
# We also keep track of which edges are actually in the knowledge graph.
full_edge_index, full_edge_label = enumerate_all_virus_host_pairs(dataset.data)
full_edge_index.shape

torch.Size([2, 4237520])

In [8]:
@torch.no_grad()
def predict(
    model: GNNModule,
    dataset: HeteroData,
    inference_edge_index: torch.Tensor,
    inference_edge_labels: torch.Tensor,
    evaluate_test_only: bool = True,
) -> pl.DataFrame:
    """Score candidate virus-host edges and return the predictions as a table.

    The model is put in eval mode and called once on the full
    ``inference_edge_index``; any restriction to test viruses is applied to
    the outputs afterwards.

    Args:
        model: Trained model, called as ``model(dataset, inference_edge_index)``.
        dataset: Graph providing ``dataset["virus"].test_mask`` and
            ``dataset["host"].label`` / ``dataset["host"].y``.
        inference_edge_index: ``(2, E)`` tensor; row 0 indexes viruses and
            row 1 indexes hosts.
        inference_edge_labels: Length-``E`` tensor of ground-truth edge labels.
        evaluate_test_only: If True, keep only edges whose virus is flagged in
            the virus test mask. If False, keep every edge and add a boolean
            ``test_mask`` column instead.

    Returns:
        A DataFrame with columns ``virus_id``, ``host_id``, ``logit``,
        ``proba`` (sigmoid of the logit), ``label`` and ``host``, plus
        ``test_mask`` when ``evaluate_test_only`` is False.
    """
    model.eval()

    # per-edge flag: does the edge's virus belong to the test split?
    virus_test_mask = dataset["virus"].test_mask
    inference_edge_test_mask = virus_test_mask[inference_edge_index[0]]

    logit = model(dataset, inference_edge_index)

    if evaluate_test_only:
        # need to clone so that the .numpy call works later
        inference_edge_index = inference_edge_index[:, inference_edge_test_mask].clone()
        inference_edge_labels = inference_edge_labels[inference_edge_test_mask]
        logit = logit[inference_edge_test_mask]

    proba = torch.sigmoid(logit)

    # NOTE(review): assumes dataset["host"].label supports array-style indexing
    # by the values of dataset["host"].y - confirm against the data module.
    host_store = dataset["host"]
    results = pl.DataFrame({
        "virus_id": inference_edge_index[0].numpy(),
        "host_id": inference_edge_index[1].numpy(),
        "logit": logit.numpy(),
        "proba": proba.numpy(),
        "label": inference_edge_labels.numpy(),
        "host": host_store.label[host_store.y[inference_edge_index[1]]],
    })

    if not evaluate_test_only:
        # polars DataFrames do not support `df["col"] = values`; new columns
        # have to be attached with with_columns.
        results = results.with_columns(
            pl.Series("test_mask", inference_edge_test_mask.numpy())
        )

    return results

def summarize_predictions(results: pl.DataFrame) -> pl.DataFrame:
    """Collapse edge-level predictions to one row per virus and host label.

    Uses the notebook-level ``host_info`` table and ``host_taxa_levels`` list.

    Steps:
      1. For every ``(virus_id, host)`` pair keep the maximum ``logit`` and
         ``proba`` and whether any contributing edge was labelled true, then
         attach the host lineage columns from ``host_info``.
      2. For every virus, gather the unique lineage values of its true hosts
         together with the maximum ``logit``/``proba`` among those true hosts.
      3. Join (2) onto (1) with the ``_true`` suffix and replace each
         ``<rank>_true`` list by a boolean telling whether the row's ``<rank>``
         value is among the virus's true values at that rank.

    Returns:
        The joined table sorted by ``virus_id``. Viruses without any true-labelled
        row are absent because the joins are inner joins.
    """
    # lineage lookup keyed by host label
    host_lineages = host_info.drop("host_taxonomy", "host_accession").unique()

    per_host = (
        results
        .group_by("virus_id", "host")
        .agg(pl.max("logit", "proba"), pl.any("label"))
        .join(host_lineages, left_on="host", right_on="host_label")
        .sort(["virus_id", "logit"], descending=[False, True])
    )

    # keep track of all possible host labels for each virus
    unique_true_lineages = [pl.col(rank).unique() for rank in host_taxa_levels]
    known_hosts = (
        per_host
        .lazy()
        .filter(pl.col("label"))
        .group_by("virus_id")
        .agg(*unique_true_lineages, pl.max("logit", "proba"))
        .collect()
    )

    # each expression keeps the name `<rank>_true`, so the list column is
    # overwritten by the boolean membership result
    rank_matches = [
        pl.col(f"{rank}_true").list.contains(pl.col(rank))
        for rank in host_taxa_levels
    ]

    return (
        per_host
        .join(known_hosts, on="virus_id", suffix="_true")
        .with_columns(*rank_matches)
        .sort(["virus_id"])
    )

In [9]:
# score every enumerated pair; predict keeps only test-virus edges by default
results = predict(
    model=model,
    dataset=dataset.data,
    inference_edge_index=full_edge_index,
    inference_edge_labels=full_edge_label,
)
summary = summarize_predictions(results)
summary

virus_id,host,logit,proba,label,host_name,host_domain,host_phylum,host_class,host_order,host_family,host_genus,host_species,host_domain_true,host_phylum_true,host_class_true,host_order_true,host_family_true,host_genus_true,host_species_true,logit_true,proba_true
i64,str,f32,f32,bool,str,str,str,str,str,str,str,str,bool,bool,bool,bool,bool,bool,bool,f32,f32
0,"""Bacteroides intestinalis""",-1.421236,0.194468,false,"""Bacteroides intestinalis""","""d__Bacteria""","""p__Bacteroidota""","""c__Bacteroidia""","""o__Bacteroidales""","""f__Bacteroidaceae""","""g__Bacteroides""","""s__Bacteroides intestinalis""",true,false,false,false,false,false,false,-1.63951,0.162532
0,"""Kitasatospora aureofaciens""",-1.550044,0.17508,false,"""Kitasatospora aureofaciens""","""d__Bacteria""","""p__Actinomycetota""","""c__Actinomycetia""","""o__Streptomycetales""","""f__Streptomycetaceae""","""g__Kitasatospora""","""s__Kitasatospora aureofaciens""",true,false,false,false,false,false,false,-1.63951,0.162532
0,"""Wolbachia sp.""",-1.63951,0.162532,true,"""Wolbachia sp.""","""d__Bacteria""","""p__Pseudomonadota""","""c__Alphaproteobacteria""","""o__Rickettsiales""","""f__Anaplasmataceae""","""g__Wolbachia""","""s__Wolbachia_SPECIES""",true,true,true,true,true,true,true,-1.63951,0.162532
0,"""Pseudomonas grimontii""",-2.078492,0.111205,false,"""Pseudomonas grimontii""","""d__Bacteria""","""p__Pseudomonadota""","""c__Gammaproteobacteria""","""o__Pseudomonadales""","""f__Pseudomonadaceae""","""g__Pseudomonas_E""","""s__Pseudomonas_E grimontii""",true,true,false,false,false,false,false,-1.63951,0.162532
0,"""Vibrio pelagius""",-2.19977,0.099771,false,"""Vibrio pelagius""","""d__Bacteria""","""p__Pseudomonadota""","""c__Gammaproteobacteria""","""o__Enterobacterales_A""","""f__Vibrionaceae""","""g__Vibrio""","""s__Vibrio pelagius""",true,true,false,false,false,false,false,-1.63951,0.162532
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
5263,"""Kitasatospora aureofaciens""",-7.388738,0.000618,false,"""Kitasatospora aureofaciens""","""d__Bacteria""","""p__Actinomycetota""","""c__Actinomycetia""","""o__Streptomycetales""","""f__Streptomycetaceae""","""g__Kitasatospora""","""s__Kitasatospora aureofaciens""",true,false,false,false,false,false,false,2.821479,0.943826
5263,"""Bacteroides intestinalis""",-7.461622,0.000574,false,"""Bacteroides intestinalis""","""d__Bacteria""","""p__Bacteroidota""","""c__Bacteroidia""","""o__Bacteroidales""","""f__Bacteroidaceae""","""g__Bacteroides""","""s__Bacteroides intestinalis""",true,false,false,false,false,false,false,2.821479,0.943826
5263,"""Wolbachia sp.""",-7.678973,0.000462,false,"""Wolbachia sp.""","""d__Bacteria""","""p__Pseudomonadota""","""c__Alphaproteobacteria""","""o__Rickettsiales""","""f__Anaplasmataceae""","""g__Wolbachia""","""s__Wolbachia_SPECIES""",true,true,false,false,false,false,false,2.821479,0.943826
5263,"""Vibrio pelagius""",-7.898198,0.000371,false,"""Vibrio pelagius""","""d__Bacteria""","""p__Pseudomonadota""","""c__Gammaproteobacteria""","""o__Enterobacterales_A""","""f__Vibrionaceae""","""g__Vibrio""","""s__Vibrio pelagius""",true,true,true,false,false,false,false,2.821479,0.943826


In [10]:
def true_host_recall(
    summary: pl.DataFrame,
    threshold: float = 0.9,
    taxon_rank: str | list[str] = "all",
) -> dict[str, float]:
    """Fraction of viruses with a correct host prediction at or above ``threshold``.

    For each requested rank, counts the distinct viruses that have at least one
    row with ``proba >= threshold`` and a true ``<rank>_true`` flag, divided by
    the number of distinct viruses in ``summary``.

    Args:
        summary: Table with ``virus_id``, ``proba`` and boolean
            ``host_<rank>_true`` columns.
        threshold: Minimum probability for a prediction to count.
        taxon_rank: ``"all"`` for every rank in the notebook-level
            ``host_taxa_levels``, or a single rank name, or a list of rank
            names. The ``host_`` prefix is optional.

    Returns:
        Mapping from ``host_<rank>`` to recall. Every value is NaN when
        ``summary`` contains no viruses.
    """
    if taxon_rank == "all":
        ranks = list(host_taxa_levels)
    elif isinstance(taxon_rank, str):
        ranks = [taxon_rank]
    else:
        ranks = list(taxon_rank)

    # accept both "species" and "host_species"
    ranks = [rank if rank.startswith("host_") else f"host_{rank}" for rank in ranks]

    total_viruses = summary["virus_id"].n_unique()
    if total_viruses == 0:
        # recall is undefined without any viruses; avoid dividing by zero
        return {rank: float("nan") for rank in ranks}

    confident = summary.filter(pl.col("proba") >= threshold)

    recall: dict[str, float] = {}
    for rank in ranks:
        predicted_viruses = (
            confident.filter(pl.col(f"{rank}_true"))
            ["virus_id"]
            .n_unique()
        )
        recall[rank] = predicted_viruses / total_viruses

    return recall

In [11]:
# recall at every taxonomic rank for predictions with probability >= 0.90
true_host_recall(summary, threshold=0.90, taxon_rank="all")

{'host_domain': 0.7096577017114915,
 'host_phylum': 0.6693154034229829,
 'host_class': 0.6644254278728606,
 'host_order': 0.6014669926650367,
 'host_family': 0.5831295843520783,
 'host_genus': 0.5317848410757946,
 'host_species': 0.45965770171149145}

In [12]:
def confidence_of_true_hosts(
    summary: pl.DataFrame,
    threshold: float = 0.9,
    mirror_iphop: bool = True,
    taxon_rank: str = "species",
    as_histogram: bool = True,
) -> pl.DataFrame:
    """Best predicted probability per virus among its correct hosts.

    Args:
        summary: Table with ``virus_id``, ``proba`` and boolean
            ``host_<rank>_true`` columns.
        threshold: Probabilities below this value are replaced by 0.0 when
            ``mirror_iphop`` is True; otherwise unused.
        mirror_iphop: Zero out sub-threshold probabilities so that they are
            counted separately from reported hits.
        taxon_rank: Rank whose ``<rank>_true`` column selects the correct
            rows. The ``host_`` prefix is optional.
        as_histogram: If True, return a 100-bin histogram of the per-virus
            maxima with empty bins removed. If False, return the per-virus
            table itself.

    Returns:
        Either the histogram table or a ``(virus_id, proba)`` table sorted by
        ``virus_id``.
    """
    if not taxon_rank.startswith("host_"):
        taxon_rank = f"host_{taxon_rank}"

    lazy_summary = summary.lazy()
    if mirror_iphop:
        # iphop does not report hits below the threshold, but this method will
        # so we can keep track of these separately
        lazy_summary = lazy_summary.with_columns(
            proba=(
                pl.when(pl.col("proba") < threshold)
                .then(0.0)
                .otherwise(pl.col("proba"))
            )
        )

    best_true_proba = (
        lazy_summary
        .filter(pl.col(f"{taxon_rank}_true"))
        .group_by("virus_id")
        .agg(pl.max("proba"))
        .select("virus_id", "proba")
        .sort("virus_id")
        .collect()
    )

    if as_histogram:
        return (
            best_true_proba
            ["proba"]
            .hist(bin_count=100)
            .filter(pl.col("count") > 0)
        )

    # already an eager DataFrame at this point, so it is returned directly
    return best_true_proba

The histogram below mirrors the species-level recall computed above: 752/1636 = 0.4597 of the test viruses have a correct host predicted at or above a threshold of 0.9.

In [13]:
# species-level confidence histogram, with sub-threshold hits collapsed to 0
confidence_of_true_hosts(
    summary,
    threshold=0.9,
    mirror_iphop=True,
    taxon_rank="species",
    as_histogram=True,
)

break_point,category,count
f64,cat,u32
0.02,"""(-0.01, 0.02]""",884
0.92,"""(0.89, 0.92]""",203
0.95,"""(0.92, 0.95]""",266
0.98,"""(0.95, 0.98]""",231
1.01,"""(0.98, 1.01]""",52


To complete the comparison, repeat the above analyses for each of the trained models and include the `iPHoP` results from its output. We specifically lower the confidence threshold of `iPHoP` to its minimum (`-m 75`), and then manually filter the confidences as above with `polars`.
```bash
iphop predict -f FASTAFILE -m 75
```