In [1]:
import os
# Size of the polars thread pool. Set before `import polars` so the value is
# already in the environment when the library loads; adjust to the machine.
os.environ["POLARS_MAX_THREADS"] = "128"

import polars as pl
import polars.selectors as cs

# Enable the global string cache so Categorical columns created in separate
# calls can be compared and joined with each other.
pl.enable_string_cache()

In [2]:
# Print the Python and polars versions used to produce the outputs below.
%load_ext watermark
%watermark -vp polars

Python implementation: CPython
Python version       : 3.11.3
IPython version      : 8.13.2

polars: 0.20.6



Protein-protein alignments are computed using `mmseqs2` (v13.45111):

```bash
mmseqs createdb FASTAFILE seqDB
mmseqs search seqDB seqDB resDB tmp -s 7.5 -c 0.3 # be sure to set your thread count!
mmseqs convertalis seqDB seqDB resDB RESULTSFILE --format-output query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,qlen,tlen,qcov,tcov 
```

In [3]:
# Column name -> dtype for the headerless mmseqs2 results file, listed in the
# same order as the `--format-output` fields. `read_protein_alignments` assigns
# these names positionally, so the key order must not change.
# NOTE(review): Int16 tops out at 32,767 and Float32 stores e-values below
# roughly 1e-45 as 0.0 -- confirm both limits are acceptable for the data.
cols = {
    "query": pl.String,
    "target": pl.String,
    "fident": pl.Float32,
    "alnlen": pl.Int16,
    "mismatch": pl.Int16,
    "gapopen": pl.Int16,
    "qstart": pl.Int16,
    "qend": pl.Int16,
    "tstart": pl.Int16,
    "tend": pl.Int16,
    "evalue": pl.Float32,
    "bits": pl.Float32,
    "qlen": pl.Int16,
    "tlen": pl.Int16,
    "qcov": pl.Float32,
    "tcov": pl.Float32,
}

def get_genome(col: str, is_train_dataset: bool = False) -> pl.Expr:
    """Build an expression that derives a genome name from a protein-name column.

    Args:
        col: Name of the column holding protein names.
        is_train_dataset: If True, the genome name is the protein name with its
            trailing ``_<digits>`` suffix removed. If False (default), it is the
            text before the first ``|`` in the protein name.

    Returns:
        A polars expression yielding the genome name as Categorical in both
        modes, so the result can be compared and joined against other
        Categorical genome columns regardless of which mode produced it.
    """
    if is_train_dataset:
        # Proteins are named `<scaffold>_<n>`: strip the numeric suffix.
        genome = pl.col(col).str.extract(r"(.*)?_\d+$")
    else:
        # For multi-scaffold viruses (vMAGs) the name carries several
        # `|`-separated identifiers; the first one stands for all scaffolds.
        genome = pl.col(col).str.split("|").list.get(0)

    # Cast once here so both branches return the same dtype.
    return genome.cast(pl.Categorical)

def read_protein_alignments(file: str, test_run: bool=True) -> pl.DataFrame:
    """Load an mmseqs2 alignment table and keep only hits between different genomes.

    `file` is a headerless, tab-separated table whose columns are, in order:
    query, target, fident, alnlen, mismatch, gapopen, qstart, qend, tstart,
    tend, evalue, bits, qlen, tlen, qcov, tcov
    (the `--format-output` list passed to `mmseqs convertalis`, mmseqs2 v13.45111).

    Args:
        file: Path to the alignment table.
        test_run: If True (the default here), read only the first 5,000,000
            rows; if False, read the whole file.

    Returns:
        The alignments without the `mismatch` and `gapopen` columns, with
        `query_genome` / `target_genome` added, rows whose two genomes are
        equal removed, and `query` / `target` cast to Categorical.
    """
    row_limit = 5_000_000 if test_run else None

    raw_hits = pl.scan_csv(
        file,
        has_header=False,
        separator="\t",
        new_columns=list(cols),
        dtypes=cols,
        n_rows=row_limit,
    )

    # Label each side with its genome, then discard within-genome hits.
    different_genomes = pl.col("query_genome") != pl.col("target_genome")
    cross_genome_hits = (
        raw_hits
        .drop(["mismatch", "gapopen"])
        .with_columns(
            query_genome=get_genome("query"),
            target_genome=get_genome("target"),
        )
        .filter(different_genomes)
    )

    # Protein names are cast only after materializing the filtered frame.
    protein_dtypes = {"query": pl.Categorical, "target": pl.Categorical}
    return cross_genome_hits.collect().cast(protein_dtypes)

def read_genome_info(file: str) -> pl.DataFrame:
    """Load the per-genome protein counts.

    `file` is a tab-delimited table with a header row; two of its columns
    are read:
        - genome: genome name (the same naming that `get_genome` produces)
        - ptns: number of proteins encoded by that genome

    Returns:
        A two-column frame with `genome` as Categorical and `ptns` as Int16.
    """
    wanted_columns = ["genome", "ptns"]
    column_dtypes = {"genome": pl.Categorical, "ptns": pl.Int16}

    raw_info = pl.read_csv(file, separator="\t", columns=wanted_columns)
    return raw_info.cast(column_dtypes)

def compute_aai(ptn_aln_file: str, genome_info_file: str, test_run: bool = False) -> pl.DataFrame:
    """Compute pairwise average amino-acid identity (AAI) between genomes.

    Args:
        ptn_aln_file: mmseqs2 alignment table, read by `read_protein_alignments`.
        genome_info_file: Per-genome protein counts, read by `read_genome_info`.
        test_run: Forwarded to `read_protein_alignments`.

    Returns:
        One row per (query_genome, target_genome) pair with columns
        `aai`, `shared_genes`, `query_ptns`, `target_ptns`, `query_shared`
        and `target_shared`.
    """
    alignments = read_protein_alignments(ptn_aln_file, test_run)
    genome_sizes = read_genome_info(genome_info_file)

    genome_pair = ["query_genome", "target_genome"]
    per_query = ["query", *genome_pair]
    per_target = ["target", *genome_pair]

    # A hit is kept only if, within its genome pair, it has the top bit score
    # and top coverage for its query protein AND for its target protein.
    is_query_best = (
        (pl.col("bits") == pl.col("query_best_bits")) &
        (pl.col("qcov") == pl.col("query_best_cov"))
    )
    is_target_best = (
        (pl.col("bits") == pl.col("target_best_bits")) &
        (pl.col("tcov") == pl.col("target_best_cov"))
    )

    best_hits = (
        alignments
        .lazy()
        .with_columns(
            query_best_bits=pl.col("bits").max().over(*per_query),
            query_best_cov=pl.col("qcov").max().over(*per_query),
            target_best_bits=pl.col("bits").max().over(*per_target),
            target_best_cov=pl.col("tcov").max().over(*per_target),
        )
        .filter(is_query_best & is_target_best)
    )

    # AAI is the mean identity over the retained hits of each genome pair.
    pair_aai = (
        best_hits
        .group_by(genome_pair)
        .agg(
            aai=pl.col("fident").mean(),
            shared_genes=pl.min_horizontal(pl.n_unique("query", "target")),
        )
        .collect()
    )

    # Attach each genome's protein count so the shared genes can be expressed
    # as a proportion of either genome.
    query_sizes = genome_sizes.rename({"ptns": "query_ptns"})
    target_sizes = genome_sizes.rename({"ptns": "target_ptns"})
    with_sizes = (
        pair_aai
        .join(query_sizes, left_on="query_genome", right_on="genome")
        .join(target_sizes, left_on="target_genome", right_on="genome")
    )

    aai = with_sizes.with_columns(
        query_shared=(pl.col("shared_genes") / pl.col("query_ptns")).cast(pl.Float32),
        target_shared=(pl.col("shared_genes") / pl.col("target_ptns")).cast(pl.Float32),
    )

    return aai

def replace_ext(file: str, ext: str) -> str:
    """Return `file` with its extension replaced by `ext`.

    Only a dot in the final path component counts as the start of an
    extension, so a dot in a directory name (``run.v2/aai``) or a leading
    ``./`` never causes part of the path to be cut off. A file without an
    extension simply gets ``.{ext}`` appended.

    Args:
        file: Path whose extension should be replaced.
        ext: New extension, without the leading dot.

    Returns:
        The path with the new extension.
    """
    stem, _ = os.path.splitext(file)
    return f"{stem}.{ext}"

def save_aai(aai: pl.DataFrame, out_file: str, parquet: bool = True):
    """Write the AAI table to disk and print the path that was written.

    Args:
        aai: Table to save.
        out_file: Output path; its extension is replaced to match the format.
        parquet: If True (default), write a `.parquet` file; otherwise write
            a tab-separated `.tsv` file.
    """
    if not parquet:
        out_file = replace_ext(out_file, "tsv")
        aai.write_csv(out_file, separator="\t")
    else:
        # The table is generally very large, so parquet is the default
        # because it is much more storage efficient.
        out_file = replace_ext(out_file, "parquet")
        aai.write_parquet(out_file)

    print(f"Saved AAI to {out_file}")

In [4]:
# Input paths: per-genome protein counts and the mmseqs2 alignment table.
genome_info_file = "genome_info.tsv"
ptn_aln_file = "protein_alignments.tsv"

# test_run=True reads only the first 5,000,000 alignment rows;
# pass test_run=False to process the whole file.
aai = compute_aai(ptn_aln_file, genome_info_file, test_run=True)
aai

query_genome,target_genome,aai,shared_genes,query_ptns,target_ptns,query_shared,target_shared
cat,cat,f32,u32,i16,i16,f32,f32
"""IMGVR_UViG_271…","""IMGVR_UViG_255…",0.419,1,121,66,0.008264,0.015152
"""IMGVR_UViG_271…","""IMGVR_UViG_330…",0.391,1,121,63,0.008264,0.015873
"""IMGVR_UViG_271…","""IMGVR_UViG_330…",0.398,1,121,85,0.008264,0.011765
"""IMGVR_UViG_271…","""IMGVR_UViG_330…",0.376,1,121,57,0.008264,0.017544
"""IMGVR_UViG_271…","""IMGVR_UViG_330…",0.377,1,121,81,0.008264,0.012346
"""IMGVR_UViG_271…","""IMGVR_UViG_330…",0.344,1,121,72,0.008264,0.013889
"""IMGVR_UViG_271…","""IMGVR_UViG_330…",0.322,1,121,59,0.008264,0.016949
"""IMGVR_UViG_330…","""IMGVR_UViG_330…",0.758,1,41,39,0.02439,0.025641
"""IMGVR_UViG_330…","""IMGVR_UViG_258…",0.734,1,41,41,0.02439,0.02439
"""IMGVR_UViG_330…","""IMGVR_UViG_273…",0.221,1,211,62,0.004739,0.016129


# Virus clustering based on AAI

We can prune the pairwise AAI table above so that only genome pairs related at a chosen taxonomic level (genus or family) are kept for clustering. The thresholds below were used [previously](https://github.com/snayfach/MGV/tree/master/aai_cluster), but you can substitute your own.

We view this table as a graph in which the nodes are genomes and the edges connect genomes with detectable AAI. The edge weight is the product of the AAI and the minimum proportion of shared genes between the two genomes. This penalizes high AAI values that are based on only a few shared genes, so any downstream clusters will contain viruses that are more holistically related.

In [5]:
def prepare_for_clustering(aai_df: pl.DataFrame, cluster_level: str) -> pl.DataFrame:
    """Filter the pairwise AAI table and weight the surviving edges.

    A genome pair is kept when its AAI reaches the level's minimum AAI and
    it either shares a large enough proportion of genes (the smaller of the
    two `*_shared` columns) or shares enough genes in absolute terms.

    Args:
        aai_df: Table containing `aai`, `shared_genes` and the `*_shared`
            proportion columns.
        cluster_level: Either "genus" or "family".

    Returns:
        The kept rows with `min_shared` and `edge_weight`
        (`aai * min_shared`) added, sorted by `edge_weight` descending.

    Raises:
        ValueError: If `cluster_level` is neither "genus" nor "family".
    """
    if cluster_level not in ("genus", "family"):
        raise ValueError(f"Invalid cluster level: {cluster_level} | Must be 'genus' or 'family'")

    # (minimum AAI, minimum shared proportion, minimum shared gene count)
    if cluster_level == "genus":
        aai_cutoff, shared_cutoff, gene_cutoff = 0.4, 0.2, 16
    else:
        aai_cutoff, shared_cutoff, gene_cutoff = 0.2, 0.1, 8

    similar_enough = pl.col("aai") >= aai_cutoff
    enough_overlap = (
        (pl.col("min_shared") >= shared_cutoff) |
        (pl.col("shared_genes") >= gene_cutoff)
    )

    weighted_edges = (
        aai_df
        .lazy()
        .with_columns(min_shared=pl.min_horizontal(cs.ends_with("_shared")))
        .filter(similar_enough & enough_overlap)
        .with_columns(edge_weight=pl.col("aai") * pl.col("min_shared"))
        .sort("edge_weight", descending=True)
    )

    return weighted_edges.collect()

def save_for_clustering(filtered_aai: pl.DataFrame, out_file: str):
    """Write the weighted edge list as a headerless, tab-separated `.abc` file.

    Args:
        filtered_aai: Table containing `query_genome`, `target_genome`
            and `edge_weight`.
        out_file: Output path; its extension is replaced with `abc`.
    """
    out_file = replace_ext(out_file, "abc")
    print(f"Saving for mcl clustering to {out_file}")

    # Only the two node labels and the weight are written, one edge per line.
    edge_columns = ["query_genome", "target_genome", "edge_weight"]
    edge_list = filtered_aai.select(edge_columns)
    edge_list.write_csv(out_file, separator="\t", include_header=False)

The only columns needed for clustering (which define a text-based tabular graph) are `query_genome`, `target_genome`, and `edge_weight`. The above function `save_for_clustering` will save the dataframe for you.

In [6]:
# Apply the genus-level thresholds; pass "family" for the family-level set.
filtered_aai = prepare_for_clustering(aai, "genus")
filtered_aai

query_genome,target_genome,aai,shared_genes,query_ptns,target_ptns,query_shared,target_shared,min_shared,edge_weight
cat,cat,f32,u32,i16,i16,f32,f32,f32,f32
"""IMGVR_UViG_330…","""IMGVR_UViG_330…",1.0,1,2,2,0.5,0.5,0.5,0.5
"""IMGVR_UViG_330…","""IMGVR_UViG_330…",0.976,1,2,2,0.5,0.5,0.5,0.488
"""IMGVR_UViG_330…","""IMGVR_UViG_330…",0.967,1,2,2,0.5,0.5,0.5,0.4835
"""IMGVR_UViG_330…","""IMGVR_UViG_330…",0.959,1,2,2,0.5,0.5,0.5,0.4795
"""IMGVR_UViG_330…","""IMGVR_UViG_330…",0.956,1,2,2,0.5,0.5,0.5,0.478
"""IMGVR_UViG_330…","""IMGVR_UViG_330…",0.953,1,2,2,0.5,0.5,0.5,0.4765
"""IMGVR_UViG_330…","""IMGVR_UViG_330…",0.945,1,2,2,0.5,0.5,0.5,0.4725
"""IMGVR_UViG_330…","""IMGVR_UViG_330…",0.937,1,2,2,0.5,0.5,0.5,0.4685
"""IMGVR_UViG_330…","""IMGVR_UViG_330…",0.935,1,2,2,0.5,0.5,0.5,0.4675
"""IMGVR_UViG_330…","""IMGVR_UViG_330…",0.919,1,2,2,0.5,0.5,0.5,0.4595


You can use any graph clustering method you prefer, but for the PST manuscript, we clustered genomes on the basis of AAI using the Markov clustering algorithm ([mcl](https://github.com/micans/mcl)).

We use the following setup for mcl clustering:

```bash
mcxload -abc AAI_GRAPH.abc -o AAI_GRAPH.mci -write-tab AAI_GRAPH.tab
mcl AAI_GRAPH.mci -I 2.0 -use-tab AAI_GRAPH.tab -o AAI_genome_clusters.txt
```

Where `AAI_GRAPH.abc` is the file produced from the `save_for_clustering` function.