In [1]:
import os
from pathlib import Path
from typing import TypedDict

import tables as tb

import numpy as np
import torch
from torch_geometric.data import HeteroData
from torch_geometric import utils

os.environ["POLARS_MAX_THREADS"] = "128"
import polars as pl # noqa: F401

In [2]:
# Record the Python/IPython version and the versions of the key packages used below.
%load_ext watermark
%watermark -vp torch,torch_geometric,polars,tables,numpy

Python implementation: CPython
Python version       : 3.10.14
IPython version      : 8.24.0

torch          : 2.2.2
torch_geometric: 2.5.2
polars         : 1.1.0
tables         : 3.9.2
numpy          : 1.26.4



## Virus-host edges

This information is available in `Supplementary Table 8` of the PST manuscript.

In [3]:
# Load the virus-host table and count how many distinct viruses share each host label.
virus_host_info = (
    pl.read_csv("supplementary_table_8.tsv", separator="\t")
    .with_columns(
        n_viruses = pl.col("virus_accession").n_unique().over("host_label"),
    )
)

# Integer node ids follow sorted accession order, so they are identical on every run.
virus_ids = {
    v: i
    for i, v in enumerate(virus_host_info["virus_accession"].unique().sort())
}

host_ids = {
    h: i
    for i, h in enumerate(virus_host_info["host_accession"].unique().sort())
}

# maintain_order=True keeps the iteration order of host2label deterministic between runs.
host_label_pairs = (
    virus_host_info
    .select("host_accession", "host_label")
    .unique(maintain_order=True)
)
# A host accession with more than one label would otherwise be silently collapsed in the dict below.
assert host_label_pairs["host_accession"].n_unique() == host_label_pairs.height, (
    "some host_accession values map to more than one host_label"
)
host2label = {
    row["host_accession"]: row["host_label"]
    for row in host_label_pairs.iter_rows(named=True)
}

# replace_strict maps accession -> id directly to UInt32 and raises on any unmapped accession.
virus_host_info = virus_host_info.with_columns(
    virus_id = pl.col("virus_accession").replace_strict(virus_ids, return_dtype=pl.UInt32),
    host_id = pl.col("host_accession").replace_strict(host_ids, return_dtype=pl.UInt32),
)

virus_host_info

dataset,virus_accession,virus_species,host_accession,host_species,host_label,gtdbtk_classification,n_viruses,virus_id,host_id
str,str,str,str,str,str,str,u32,u32,u32
"""test""","""FN436268""","""Acaryochloris phage A-HIS1""","""GCA_000018105.1""","""Acaryochloris marina strain MB…","""Acaryochloris marina""","""d__Bacteria;p__Cyanobacteriota…",1,49,39
"""test""","""MT241607""","""Achromobacter phage AMA2""","""GCA_902859635.1""","""Achromobacter marplatensis""","""Achromobacter marplatensis""","""d__Bacteria;p__Pseudomonadota;…",1,5042,790
"""test""","""MT708550""","""Achromobacter phage Mano""","""GCA_002991505.1""","""Achromobacter sp.""","""Achromobacter sp.""","""d__Bacteria;p__Pseudomonadota;…",2,5143,307
"""test""","""MW269554""","""Achromobacter phage vB_AchrS_A…","""GCA_002991505.1""","""Achromobacter sp.""","""Achromobacter sp.""","""d__Bacteria;p__Pseudomonadota;…",2,5239,307
"""test""","""MH746817""","""Achromobacter phage vB_Ade_ART""","""GCA_013267375.1""","""Achromobacter denitrificans PR…","""Achromobacter denitrificans""","""d__Bacteria;p__Pseudomonadota;…",1,4305,454
…,…,…,…,…,…,…,…,…,…
"""train""","""GCF_000847165.1""","""Yersinia phage phiYeO3-12""","""GCA_025758635.1""","""Yersinia enterocolitica (type …","""Yersinia enterocolitica""","""d__Bacteria;p__Pseudomonadota;…",18,205,627
"""train""","""GCF_001505135.1""","""Yersinia phage vB_YenM_TG1""","""GCA_025758635.1""","""Yersinia enterocolitica""","""Yersinia enterocolitica""","""d__Bacteria;p__Pseudomonadota;…",18,1796,627
"""train""","""GCF_001470735.1""","""Yersinia phage vB_YenP_AP10""","""GCA_025758635.1""","""Yersinia enterocolitica""","""Yersinia enterocolitica""","""d__Bacteria;p__Pseudomonadota;…",18,1587,627
"""train""","""GCF_000926875.1""","""Yersinia phage vB_YenP_AP5""","""GCA_025758635.1""","""Yersinia enterocolitica""","""Yersinia enterocolitica""","""d__Bacteria;p__Pseudomonadota;…",18,1383,627


In [4]:
# Split virus accessions by the `dataset` column of the supplementary table.
train_viruses = set(virus_host_info.filter(pl.col("dataset") == "train")["virus_accession"])
test_viruses = set(virus_host_info.filter(pl.col("dataset") == "test")["virus_accession"])
# A virus in both splits would leak test information into training.
assert train_viruses.isdisjoint(test_viruses), "some viruses appear in both the train and test splits"

print("Train viruses:", len(train_viruses))
print("Test viruses:", len(test_viruses))

Train viruses: 3628
Test viruses: 1636


In [5]:
EdgeType = tuple[str, str, str]
IdDict = dict[str, int]
FilePath = str | Path

class NodeEmbedding(TypedDict):
    data: torch.Tensor
    names: list[str]

def get_node_embeddings(
    node_embed_file: FilePath,
    node_types: tuple[str, ...] = ("host", "virus"),
) -> dict[str, NodeEmbedding]:
    """Load embedding matrices and node names for each node type from a PyTables HDF5 file.

    For every node type ``t`` the file must contain two root-level arrays:
    ``{t}_data`` (embeddings, one row per node) and ``{t}_names`` (byte-string
    node names, in the same order as the rows of ``{t}_data``).

    Args:
        node_embed_file: path to the HDF5 file.
        node_types: node types to read. Defaults to ``("host", "virus")``.

    Returns:
        Mapping from node type to ``{"data": tensor, "names": list[str]}``.

    Raises:
        ValueError: if a node type's number of embedding rows differs from its number of names.
    """
    storage: dict[str, NodeEmbedding] = {}

    with tb.open_file(node_embed_file) as fp:
        for node_type in node_types:
            data = torch.from_numpy(fp.root[f"{node_type}_data"][:])
            # names are stored as bytes; decode so they compare equal to accession strings
            names = [name.decode() for name in fp.root[f"{node_type}_names"][:]]
            if data.shape[0] != len(names):
                raise ValueError(
                    f"{node_type}: {data.shape[0]} embedding rows but {len(names)} names"
                )
            storage[node_type] = {"data": data, "names": names}

    return storage

def _get_virus_virus_edges(
    virus_virus_edge_file: FilePath, 
    virus_ids: IdDict,
) -> torch.Tensor:
    """Read virus-virus edges from a TSV file and return an undirected, self-looped edge index.

    Each non-blank line holds two virus accessions separated by a tab. Accessions
    are mapped to integer node ids through ``virus_ids``.

    Args:
        virus_virus_edge_file: path to the two-column TSV edge list.
        virus_ids: accession -> virus node id mapping.

    Returns:
        A ``(2, E)`` long tensor with every edge in both directions plus one self
        loop for each of the ``len(virus_ids)`` virus nodes. The result is
        coalesced by ``to_undirected``.

    Raises:
        KeyError: if an accession is missing from ``virus_ids``.
        ValueError: if a line does not have exactly two tab-separated fields.
    """
    src: list[int] = []
    tgt: list[int] = []

    with open(virus_virus_edge_file) as fp:
        for line in fp:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing newline at EOF)
                continue
            s, t = line.split("\t")
            src.append(virus_ids[s])
            tgt.append(virus_ids[t])

    # explicit dtype so an empty edge file still yields an index (long) tensor of shape (2, 0)
    edge_index = torch.tensor([src, tgt], dtype=torch.long)

    # Without num_nodes, PyG infers N from the largest id present in the edges, and
    # viruses with larger ids would get no self loop.
    num_nodes = len(virus_ids)
    edge_index, _ = utils.add_self_loops(edge_index, num_nodes=num_nodes)
    edge_index = utils.to_undirected(edge_index, num_nodes=num_nodes)

    return edge_index

def _get_virus_host_edges(
    virus_host_edge_file: FilePath, 
    virus_ids: IdDict, 
    host_ids: IdDict,
) -> torch.Tensor:
    """Read known virus -> host infection edges from a TSV file.

    The first two tab-separated fields of each non-blank line are a virus
    accession and a host accession. Any further fields are ignored.

    Args:
        virus_host_edge_file: path to the TSV edge list.
        virus_ids: accession -> virus node id mapping.
        host_ids: accession -> host node id mapping.

    Returns:
        A ``(2, E)`` long tensor whose row 0 holds virus ids and row 1 holds host
        ids, coalesced (sorted, duplicate pairs removed).

    Raises:
        KeyError: if an accession is missing from ``virus_ids`` / ``host_ids``.
        ValueError: if a line has fewer than two tab-separated fields.
    """
    viruses: list[int] = []
    hosts: list[int] = []

    with open(virus_host_edge_file) as fp:
        for line in fp:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing newline at EOF)
                continue
            virus, host, *_ = line.split("\t")
            viruses.append(virus_ids[virus])
            hosts.append(host_ids[host])

    # explicit dtype so an empty edge file still yields an index (long) tensor of shape (2, 0)
    edge_index = torch.tensor([viruses, hosts], dtype=torch.long)

    edge_index = utils.coalesce(edge_index)

    return edge_index


def get_edges(
    virus_virus_edge_file: FilePath, 
    virus_host_edge_file: FilePath, 
    virus_ids: IdDict, 
    host_ids: IdDict
) -> dict[EdgeType, torch.Tensor]:
    """Build the edge index of every relation in the virus-host knowledge graph.

    Returns a dict keyed by ``(src_type, relation, dst_type)``. In insertion order,
    it holds the virus-virus ``related_to`` edges, the virus-host ``infects`` edges,
    and the host-virus ``rev_infects`` edges (the ``infects`` index with its two
    rows swapped).
    """
    related_to = _get_virus_virus_edges(virus_virus_edge_file, virus_ids)
    infects = _get_virus_host_edges(virus_host_edge_file, virus_ids, host_ids)
    # Reverse relation so that message passing can also flow from hosts back to viruses.
    rev_infects = infects.flip(0)

    return {
        ("virus", "related_to", "virus"): related_to,
        ("virus", "infects", "host"): infects,
        ("host", "rev_infects", "virus"): rev_infects,
    }

In [6]:
def _align_to_ids(
    data: torch.Tensor,
    names: list[str],
    ids: IdDict,
    node_type: str,
) -> tuple[torch.Tensor, list[str]]:
    """Reorder embedding rows and names so that row ``ids[name]`` holds ``name``.

    Edge indices are built from ``ids``, so node features must use the same
    numbering. This is a no-op when the rows are already in id order.

    Raises:
        ValueError: if ``names`` is not exactly the set of keys of ``ids``.
    """
    if len(names) != len(ids) or set(names) != set(ids):
        raise ValueError(
            f"{node_type} embedding names do not match the {node_type} id mapping"
        )
    positions = [ids[name] for name in names]
    if positions == list(range(len(positions))):
        return data, names
    # positions is a permutation of 0..n-1; argsort gives the row that belongs at each id
    order = torch.argsort(torch.tensor(positions, dtype=torch.long))
    return data[order], [names[i] for i in order.tolist()]


def create_knowledge_graph(
    node_embed_file: FilePath, 
    virus_virus_edge_file: FilePath, 
    virus_host_edge_file: FilePath,
    virus_ids: IdDict,
    host_ids: IdDict,
    host2label: dict[str, str],
    test_viruses: set[str],
) -> HeteroData:
    """Assemble the virus-host knowledge graph as a PyG ``HeteroData`` object.

    Node features come from ``node_embed_file``. Edges come from the two edge files
    via ``get_edges``. Host and virus embedding rows are reordered, if needed, to
    follow ``host_ids`` / ``virus_ids`` so that they line up with the edge indices.

    Node attributes:
        host: ``x``, ``names``, ``y`` (label id of each host) and ``label`` (label
            name of each label id; label ids follow sorted label names).
        virus: ``x``, ``names``, ``train_mask`` and ``test_mask`` (a virus is in the
            test split iff its name is in ``test_viruses``).

    Edge attributes:
        every relation: ``edge_index``.
        ``infects`` / ``rev_infects``: ``train_mask`` / ``test_mask`` copied from the
            split of each edge's virus endpoint.

    Raises:
        ValueError: if a node type's embedding names do not match its id mapping.
        KeyError: if a host name is missing from ``host2label``.
    """
    x_dict = get_node_embeddings(node_embed_file)
    edge_index_dict = get_edges(
        virus_virus_edge_file, 
        virus_host_edge_file, 
        virus_ids, 
        host_ids
    )

    graph = HeteroData()

    # Sorted so label ids do not depend on the iteration order of host2label.
    host_label_ids = {
        label: i for i, label in enumerate(sorted(set(host2label.values())))
    }
    id_maps: dict[str, IdDict] = {"host": host_ids, "virus": virus_ids}

    ### NODE STUFF
    for node_type, node_data in x_dict.items():
        x, names = node_data["data"], node_data["names"]
        if node_type in id_maps:
            x, names = _align_to_ids(x, names, id_maps[node_type], node_type)

        graph[node_type].x = x
        graph[node_type].names = np.array(names)

        if node_type == "host":
            graph[node_type].y = torch.tensor(
                [host_label_ids[host2label[name]] for name in names],
                dtype=torch.long,
            )
            graph[node_type].label = np.array(list(host_label_ids))

        if node_type == "virus":
            # every embedded virus not listed as test is a training virus
            test_mask = torch.tensor(
                [name in test_viruses for name in names], dtype=torch.bool
            )
            graph[node_type].train_mask = ~test_mask
            graph[node_type].test_mask = test_mask

    ### EDGE STUFF
    for edge_type, edge_index in edge_index_dict.items():
        graph[edge_type].edge_index = edge_index

        src_type, relation, _ = edge_type
        if "infects" in relation:
            # The virus is the source of "infects" edges but the target of the
            # flipped "rev_infects" edges, so pick the row that holds virus ids.
            virus_index = edge_index[0] if src_type == "virus" else edge_index[1]
            graph[edge_type].train_mask = graph["virus"].train_mask[virus_index]
            graph[edge_type].test_mask = graph["virus"].test_mask[virus_index]

    return graph

In [7]:
# Knowledge graph built from the PST-large embeddings and virus-virus edges.
dataset = "pst-large"
graph = create_knowledge_graph(
    node_embed_file=f"node_embeddings/{dataset}_node_embeddings.h5",
    virus_virus_edge_file=f"edgelist/virus-virus/{dataset}_virus-virus_edges.tsv",
    virus_host_edge_file="edgelist/known_virus-host_edges.tsv",
    virus_ids=virus_ids,
    host_ids=host_ids,
    host2label=host2label,
    test_viruses=test_viruses,
)

graph

HeteroData(
  host={
    x=[805, 1280],
    names=[805],
    y=[805],
    label=[581],
  },
  virus={
    x=[5264, 1280],
    names=[5264],
    train_mask=[5264],
    test_mask=[5264],
  },
  (virus, related_to, virus)={ edge_index=[2, 136352] },
  (virus, infects, host)={
    edge_index=[2, 31484],
    train_mask=[31484],
    test_mask=[31484],
  },
  (host, rev_infects, virus)={
    edge_index=[2, 31484],
    train_mask=[31484],
    test_mask=[31484],
  }
)

In [8]:
# Knowledge graph built from the ESM-large embeddings and virus-virus edges.
dataset = "esm-large"
graph = create_knowledge_graph(
    node_embed_file=f"node_embeddings/{dataset}_node_embeddings.h5",
    virus_virus_edge_file=f"edgelist/virus-virus/{dataset}_virus-virus_edges.tsv",
    virus_host_edge_file="edgelist/known_virus-host_edges.tsv",
    virus_ids=virus_ids,
    host_ids=host_ids,
    host2label=host2label,
    test_viruses=test_viruses,
)

graph

HeteroData(
  host={
    x=[805, 640],
    names=[805],
    y=[805],
    label=[581],
  },
  virus={
    x=[5264, 640],
    names=[5264],
    train_mask=[5264],
    test_mask=[5264],
  },
  (virus, related_to, virus)={ edge_index=[2, 142680] },
  (virus, infects, host)={
    edge_index=[2, 31484],
    train_mask=[31484],
    test_mask=[31484],
  },
  (host, rev_infects, virus)={
    edge_index=[2, 31484],
    train_mask=[31484],
    test_mask=[31484],
  }
)

In [9]:
# Knowledge graph built from the CHERRY embeddings and virus-virus edges.
dataset = "cherry"
graph = create_knowledge_graph(
    node_embed_file=f"node_embeddings/{dataset}_node_embeddings.h5",
    virus_virus_edge_file=f"edgelist/virus-virus/{dataset}_virus-virus_edges.tsv",
    virus_host_edge_file="edgelist/known_virus-host_edges.tsv",
    virus_ids=virus_ids,
    host_ids=host_ids,
    host2label=host2label,
    test_viruses=test_viruses,
)

graph

HeteroData(
  host={
    x=[805, 256],
    names=[805],
    y=[805],
    label=[581],
  },
  virus={
    x=[5264, 256],
    names=[5264],
    train_mask=[5264],
    test_mask=[5264],
  },
  (virus, related_to, virus)={ edge_index=[2, 515804] },
  (virus, infects, host)={
    edge_index=[2, 31484],
    train_mask=[31484],
    test_mask=[31484],
  },
  (host, rev_infects, virus)={
    edge_index=[2, 31484],
    train_mask=[31484],
    test_mask=[31484],
  }
)