# Complete candidate search (nrw_2022, 2,500 sequences)
This notebook evaluates the complete candidate search on the aggregate nrw_2022 (2,500 sequences), covering both breadth search and depth search.
The resulting distance matrices are evaluated in terms of infection recall. In addition, the MST structure is evaluated based on community ARI (adjusted Rand index), lineage purity and mean edge weight. The MSTs are also exported so that they can be compared with the MST of the optimized algorithm with respect to the distribution of outbreak-related attributes.

In [1]:
!pip install ../../gentrain/.

[0mProcessing /Users/benkraling/code/thesis/gentrain
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: gentrain
  Building wheel for gentrain (setup.py) ... [?25ldone
[?25h  Created wheel for gentrain: filename=gentrain-0.1.2-py3-none-any.whl size=26677 sha256=08221447055efe7d0be9ed549afa6ca5d2958d264e692ea6390f9c39c17c9ee7
  Stored in directory: /private/var/folders/2h/923cq6912sqb0snfvqqfdnmm0000gn/T/pip-ephem-wheel-cache-gw556wm2/wheels/cf/e4/57/91c03db2e8c043adeefe35dd0969d3049f61ae0218be0acc9f
Successfully built gentrain
[0mInstalling collected packages: gentrain
  Attempting uninstall: gentrain
    Found existing installation: gentrain 0.1.2
    Uninstalling gentrain-0.1.2:
      Successfully uninstalled gentrain-0.1.2
[0mSuccessfully installed gentrain-0.1.2


In [2]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import re
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import adjusted_rand_score
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pickle
from gentrain.evaluation import get_computation_rate_plot, candidate_evaluation_and_matrices, get_candidate_evaluation_and_export_mst
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.preprocessing import LabelEncoder
from gentrain.encoding import get_nucleotide_sensitive_encodings, get_mutation_sensitive_encodings, generate_one_hot_encoding
from gentrain.nextclade import get_mutations_from_dataframe
from gentrain.candidate_sourcing import bitwise_xor_candidates
from gentrain.graph import build_mst, export_graph_gexf, mean_edge_weight, get_outbreak_community_labels, build_graph
from scipy.spatial.distance import pdist
import community as community_louvain
import umap
import faiss
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances_argmin
import os
import shutil

In [3]:
# Which prepared aggregate to evaluate and how many sequences it contains;
# both values are interpolated into every input/output path below.
aggregate, size = "nrw_2022", 2_500

In [4]:
# Output directory for the GEXF graphs exported by this notebook.
# It is wiped and recreated on every run so graphs from earlier runs never
# mix with the current results.
graph_path = f"graphs/{aggregate}/{size}"
if os.path.isdir(graph_path):  # isdir already implies existence
    shutil.rmtree(graph_path)
# makedirs also creates any missing parents (e.g. "graphs/" on a fresh
# checkout), where os.mkdir would raise FileNotFoundError.
os.makedirs(graph_path)

In [5]:
# Sequences plus metadata of the prepared aggregate, indexed and sorted by
# igs_id so that row order matches the distance matrix alignment below.
sequences_csv_path = (
    f"../00_data_understanding_and_preparation/aggregates/{aggregate}/{size}/sequences_and_metadata.csv"
)
sequences_df = (
    pd.read_csv(sequences_csv_path, delimiter=";")
    .set_index("igs_id")
    .sort_index()
)
sequences_count = sequences_df.shape[0]

In [6]:
# Derive a mutations frame from the sequences via gentrain.nextclade.get_mutations_from_dataframe.
# NOTE(review): mutations_df is not referenced by any later cell of this notebook — confirm it is still needed.
mutations_df = get_mutations_from_dataframe(sequences_df)

In [7]:
# Full pairwise distance matrix produced during algorithm optimization.
# The first CSV column holds the row labels (exported as "Unnamed: 0").
# Rows are sorted, duplicate row labels are dropped (first one wins), and
# rows/columns are then aligned to the igs_id order of sequences_df.
distance_matrix_df = (
    pd.read_csv(
        f"../01_algorithm_optimization/distance_matrices/{aggregate}/{size}/distance_matrix.csv",
        delimiter=";",
        index_col="Unnamed: 0",
    )
    .sort_index()
    .loc[lambda frame: ~frame.index.duplicated(keep="first")]
    .loc[sequences_df.index, sequences_df.index]
)
# Plain ndarray view consumed by the gentrain graph/evaluation helpers.
distance_matrix = distance_matrix_df.to_numpy()

In [8]:
# Reference graph/MST built from the complete distance matrix; it (and its
# community labels) is passed as the reference to every candidate evaluation below.
gentrain_graph = build_graph(distance_matrix)
gentrain_mst = build_mst(gentrain_graph)
gentrain_community_labels = get_outbreak_community_labels(gentrain_mst)

# Sampling dates as datetimes and as whole days since the earliest sample.
# NOTE(review): neither variable is used by a later cell here — confirm before removing.
datetime_sampling_dates = pd.to_datetime(sequences_df["date_of_sampling"])
earliest_sampling_date = datetime_sampling_dates.min()
numeric_dates = datetime_sampling_dates.sub(earliest_sampling_date).dt.days

brute_force_graph_file = f"{graph_path}/brute_force"
export_graph_gexf(gentrain_mst, gentrain_community_labels, sequences_df, brute_force_graph_file)

mst generation time: 13.72s


In [9]:
# Keep only the strict upper triangle (k=1 drops the diagonal) so each
# unordered sequence pair is counted once; all other cells become NaN.
mask = np.triu(np.ones(distance_matrix_df.shape, dtype=bool), k=1)
filtered = distance_matrix_df.where(mask)
# Pairs closer than 2; NaN cells compare False and therefore do not count.
infections_count = filtered.lt(2).to_numpy().sum()
# Number of (non-NaN) pairwise distances — the base for the computation rates below.
distances_count = filtered.notna().to_numpy().sum()

## Encoding

In [10]:
# Mutation-sensitive encodings with N filtering enabled and no frequency
# filtering; indels are not excluded.
encodings_N_frequency_filter = get_mutation_sensitive_encodings(
    sequences_df,
    exclude_indels=False,
    use_frequency_filtering=False,
    filter_N=True,
)

execution time: 49.23s


In [11]:
# Same as above, but with frequency filtering switched on as well
# (the "N and SNV frequency filter" variant); indels are not excluded.
encodings_N_and_SNV_frequency_filter = get_mutation_sensitive_encodings(
    sequences_df,
    exclude_indels=False,
    use_frequency_filtering=True,
    filter_N=True,
)

execution time: 43.29s


## Accurate candidate search

### Depth search with N frequency filter

In [12]:
# Depth search on the N-filtered encodings. Each computation rate caps the
# number of candidate pairs at that fraction of all pairwise distances.
depth_search_N_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(encodings_N_frequency_filter, limit, "depth")
    evaluation = get_candidate_evaluation_and_export_mst(
        "depth_N",
        candidates,
        graph_path,
        distance_matrix,
        gentrain_community_labels,
        gentrain_mst,
        list(sequences_df["clade"]),
        sequences_df,
        runtime,
    )
    depth_search_N_frequency_filter.append(evaluation)

execution time xor distance calculation: 16.23s
execution time depth search: 1.96s
execution time 156187: 18.55s
mst generation time: 0.27s
execution time xor distance calculation: 14.73s
execution time depth search: 2.77s
execution time 312375: 17.94s
mst generation time: 1.26s
execution time xor distance calculation: 14.82s
execution time depth search: 2.49s
execution time 468562: 17.69s
mst generation time: 2.11s
execution time xor distance calculation: 13.96s
execution time depth search: 2.34s
execution time 624750: 16.74s
mst generation time: 2.43s


In [13]:
# One row per computation rate.
pd.DataFrame(data=depth_search_N_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.05,0.901632,1.0,1.0,0.901632,18.55,3.530692,-0.296959,14.0,216,0.480901,0.987386,0.028186
1,0.1,0.969697,1.0,1.0,0.969697,17.94,3.662017,-0.165634,15.0,91,0.701814,0.970637,0.011437
2,0.15,0.992075,1.0,1.0,0.992075,17.69,3.674908,-0.152743,16.0,61,0.840369,0.965264,0.006064
3,0.2,0.99627,1.0,1.0,0.99627,16.74,3.697023,-0.130628,26.0,48,0.883093,0.958079,-0.001121


### Breadth search with N frequency filter

In [14]:
# Breadth search on the N-filtered encodings. The reported computation_rate
# may fall slightly below the requested one (see the table below).
breadth_search_N_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(encodings_N_frequency_filter, limit, "breadth")
    evaluation = get_candidate_evaluation_and_export_mst(
        "breadth_N",
        candidates,
        graph_path,
        distance_matrix,
        gentrain_community_labels,
        gentrain_mst,
        list(sequences_df["clade"]),
        sequences_df,
        runtime,
    )
    breadth_search_N_frequency_filter.append(evaluation)

matrix generation time: 9.29s
execution time distance collection: 3.3s
execution time breadth search: 2.96s
execution time 156187: 16.05s
mst generation time: 0.87s
matrix generation time: 9.9s
execution time distance collection: 4.23s
execution time breadth search: 3.12s
execution time 312375: 17.72s
mst generation time: 1.05s
matrix generation time: 9.49s
execution time distance collection: 3.39s
execution time breadth search: 2.97s
execution time 468562: 16.32s
mst generation time: 1.68s
matrix generation time: 9.31s
execution time distance collection: 3.6s
execution time breadth search: 2.95s
execution time 624750: 16.34s
mst generation time: 2.62s


In [15]:
# One row per computation rate.
pd.DataFrame(data=breadth_search_N_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.049775,0.760839,1.0,1.0,0.760839,16.05,4.067787,0.240136,74.0,1,0.417044,0.9556,-0.0036
1,0.097519,0.862937,1.0,1.0,0.862937,17.72,3.991236,0.163585,74.0,1,0.514469,0.9536,-0.0056
2,0.144773,0.911888,1.0,1.0,0.911888,16.32,3.95026,0.122609,74.0,1,0.6127,0.9508,-0.0084
3,0.190036,0.937529,1.0,1.0,0.937529,16.34,3.923489,0.095838,74.0,1,0.663349,0.9516,-0.0076


### Depth search with N and SNV frequency filter

In [16]:
# Depth search on the N- and SNV-frequency-filtered encodings.
depth_search_N_and_SNV_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(encodings_N_and_SNV_frequency_filter, limit, "depth")
    evaluation = get_candidate_evaluation_and_export_mst(
        "depth_N_and_SNV",
        candidates,
        graph_path,
        distance_matrix,
        gentrain_community_labels,
        gentrain_mst,
        list(sequences_df["clade"]),
        sequences_df,
        runtime,
    )
    depth_search_N_and_SNV_frequency_filter.append(evaluation)

execution time xor distance calculation: 9.94s
execution time depth search: 2.22s
execution time 156187: 12.6s
mst generation time: 0.52s
execution time xor distance calculation: 10.31s
execution time depth search: 2.2s
execution time 312375: 12.93s
mst generation time: 1.3s
execution time xor distance calculation: 9.7s
execution time depth search: 2.21s
execution time 468562: 12.34s
mst generation time: 2.09s
execution time xor distance calculation: 9.89s
execution time depth search: 2.29s
execution time 624750: 12.65s
mst generation time: 2.78s


In [17]:
# One row per computation rate.
pd.DataFrame(data=depth_search_N_and_SNV_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.05,0.919814,1.0,1.0,0.919814,12.6,3.544318,-0.283333,22.0,212,0.492901,0.980603,0.021403
1,0.1,0.975758,1.0,1.0,0.975758,12.93,3.632286,-0.195365,22.0,112,0.692157,0.967541,0.008341
2,0.15,0.99021,1.0,1.0,0.99021,12.34,3.674092,-0.153559,19.0,76,0.787432,0.958882,-0.000318
3,0.2,0.995338,1.0,1.0,0.995338,12.65,3.71018,-0.117471,19.0,54,0.87212,0.959217,1.7e-05


### Breadth search with N and SNV frequency filter

In [18]:
# Breadth search on the N- and SNV-frequency-filtered encodings.
breadth_search_N_and_SNV_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(encodings_N_and_SNV_frequency_filter, limit, "breadth")
    evaluation = get_candidate_evaluation_and_export_mst(
        "breadth_N_and_SNV",
        candidates,
        graph_path,
        distance_matrix,
        gentrain_community_labels,
        gentrain_mst,
        list(sequences_df["clade"]),
        sequences_df,
        runtime,
    )
    breadth_search_N_and_SNV_frequency_filter.append(evaluation)

matrix generation time: 5.38s
execution time distance collection: 4.33s
execution time breadth search: 2.59s
execution time 156187: 12.77s
mst generation time: 0.83s
matrix generation time: 5.68s
execution time distance collection: 3.78s
execution time breadth search: 2.93s
execution time 312375: 12.88s
mst generation time: 1.39s
matrix generation time: 5.42s
execution time distance collection: 3.74s
execution time breadth search: 3.28s
execution time 468562: 12.89s
mst generation time: 2.15s
matrix generation time: 6.0s
execution time distance collection: 3.79s
execution time breadth search: 2.87s
execution time 624750: 13.12s
mst generation time: 2.32s


In [19]:
# One row per computation rate.
pd.DataFrame(data=breadth_search_N_and_SNV_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.049775,0.772494,1.0,1.0,0.772494,12.77,4.043497,0.215846,74.0,1,0.439836,0.954,-0.0052
1,0.097519,0.882051,1.0,1.0,0.882051,12.88,3.986595,0.158944,74.0,1,0.558261,0.9488,-0.0104
2,0.144773,0.931935,1.0,1.0,0.931935,12.89,3.961144,0.133493,74.0,1,0.627435,0.9484,-0.0108
3,0.190036,0.958508,1.0,1.0,0.958508,13.12,3.936335,0.108683,74.0,1,0.68758,0.9476,-0.0116


In [20]:
# Plot series for the depth-search filter comparison: for each filter, the
# evaluations keyed by their reported computation rate plus line styling.
depth_search_evaluation = {
    label: {
        "values": {run["computation_rate"]: run for run in runs},
        "stroke": stroke,
        "color": "black",
    }
    for label, runs, stroke in (
        ("N frequency filter", depth_search_N_frequency_filter, "dash"),
        ("N & SNV frequency filter", depth_search_N_and_SNV_frequency_filter, "dot"),
    )
}

In [21]:
# Plot series for the breadth-search filter comparison (same layout as the
# depth-search series).
breadth_search_evaluation = {
    label: {
        "values": {run["computation_rate"]: run for run in runs},
        "stroke": stroke,
        "color": "black",
    }
    for label, runs, stroke in (
        ("N frequency filter", breadth_search_N_frequency_filter, "dash"),
        ("N & SNV frequency filter", breadth_search_N_and_SNV_frequency_filter, "dot"),
    )
}

### Infection recall for different filters and computation rates using depth search

In [22]:
# Legend anchored bottom-left at x=0.85 inside the plot area.
depth_recall_legend = {
    "x": 0.85,
    "y": 0,
    "itemwidth": 60,
    "xanchor": "left",
    "yanchor": "bottom",
    "font": {"size": 30},
}
# The infection_detection_rate metric is what is labelled "Infection recall".
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    depth_search_evaluation,
    "Infection recall",
    depth_recall_legend,
)

sub_fig.show()

### Community ARI for different filters and computation rates using depth search

In [23]:
# Same legend placement as the depth-search recall plot.
depth_ari_legend = {
    "x": 0.85,
    "y": 0,
    "itemwidth": 60,
    "xanchor": "left",
    "yanchor": "bottom",
    "font": {"size": 30},
}
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    depth_search_evaluation,
    "Community ARI",
    depth_ari_legend,
)

sub_fig.show()

### Infection recall for different filters and computation rates using breadth search

In [24]:
breadth_recall_legend = {
    "x": 0.65,
    "y": 0.05,
    "xanchor": "left",
    "yanchor": "bottom",
    "font": {"size": 35},
}
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    breadth_search_evaluation,
    "Infection recall",
    legend=breadth_recall_legend,
)
sub_fig.show()

### Community ARI for different filters and computation rates using breadth search

In [25]:
breadth_ari_legend = {
    "x": 0.55,
    "y": 0.02,
    "xanchor": "left",
    "yanchor": "bottom",
    "font": {"size": 35},
}
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    breadth_search_evaluation,
    "Community ARI",
    legend=breadth_ari_legend,
)
sub_fig.show()

In [26]:
# Depth vs. breadth search, both on the N & SNV frequency-filtered encodings.
depth_vs_breadth_search_evaluation = {
    label: {
        "values": {run["computation_rate"]: run for run in runs},
        "stroke": "solid",
        "color": color,
    }
    for label, runs, color in (
        ("Depth search", depth_search_N_and_SNV_frequency_filter, "blue"),
        ("Breadth search", breadth_search_N_and_SNV_frequency_filter, "green"),
    )
}

### Infection recall for different computation rates using depth search and breadth search (N and SNV frequency filter)

In [27]:
depth_vs_breadth_recall_legend = {
    "x": 0.7,
    "y": 0.1,
    "xanchor": "left",
    "yanchor": "bottom",
    "font": {"size": 35},
}
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    depth_vs_breadth_search_evaluation,
    "Infection recall",
    legend=depth_vs_breadth_recall_legend,
)
sub_fig.show()

### Community ARI for different computation rates using depth search and breadth search (N and SNV frequency filter)

In [28]:
depth_vs_breadth_ari_legend = {
    "x": 0.7,
    "y": 0.1,
    "xanchor": "left",
    "yanchor": "bottom",
    "font": {"size": 35},
}
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    depth_vs_breadth_search_evaluation,
    "Community ARI",
    legend=depth_vs_breadth_ari_legend,
)
sub_fig.show()