# Complete candidate search (due_202203, 2,500 sequences)
This notebook contains the evaluation of the complete candidate search concerning the aggregate due_202203 with 2,500 sequences. Breadth search and depth search results were taken into account.
The distance matrices were evaluated in terms of infection recall. Furthermore, the MST structure was evaluated based on community ARI, lineage purity, and mean edge weight. MSTs were generated to compare the results with the MST of the optimized algorithm regarding the distribution of outbreak-related attributes.

In [1]:
!pip install ../../gentrain/.

[0mProcessing /Users/benkraling/code/thesis/gentrain
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: gentrain
  Building wheel for gentrain (setup.py) ... [?25ldone
[?25h  Created wheel for gentrain: filename=gentrain-0.1.2-py3-none-any.whl size=26677 sha256=e032d785b567b78c4b77580ea6ea720a96437621b092a80493063dde1a7fd095
  Stored in directory: /private/var/folders/2h/923cq6912sqb0snfvqqfdnmm0000gn/T/pip-ephem-wheel-cache-btqjssb_/wheels/cf/e4/57/91c03db2e8c043adeefe35dd0969d3049f61ae0218be0acc9f
Successfully built gentrain
[0mInstalling collected packages: gentrain
  Attempting uninstall: gentrain
    Found existing installation: gentrain 0.1.2
    Uninstalling gentrain-0.1.2:
      Successfully uninstalled gentrain-0.1.2
[0mSuccessfully installed gentrain-0.1.2


In [2]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import re
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import adjusted_rand_score
from sklearn.preprocessing import LabelEncoder
import numpy as np
from collections import Counter
import pickle
from gentrain.evaluation import get_computation_rate_plot, candidate_evaluation_and_matrices, get_candidate_evaluation_and_export_mst
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.preprocessing import LabelEncoder
from gentrain.encoding import get_nucleotide_sensitive_encodings, get_mutation_sensitive_encodings, generate_one_hot_encoding
from gentrain.nextclade import get_mutations_from_dataframe
from gentrain.candidate_sourcing import bitwise_xor_candidates
from gentrain.graph import build_mst, export_graph_gexf, mean_edge_weight, get_outbreak_community_labels, build_graph
from scipy.spatial.distance import pdist
import community as community_louvain
import umap
import faiss
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances_argmin
import os
import shutil

In [3]:
# Data set evaluated in this notebook; both values are used to build the input and export paths below.
aggregate = "due_202203"  # name of the aggregate
size = 2_500              # number of sequences in the aggregate

In [4]:
# Export directory for all MSTs written by this notebook.
graph_path = f"graphs/{aggregate}/{size}"

# Start from an empty directory so exports from an earlier run cannot be mixed with the current ones.
if os.path.isdir(graph_path):
    shutil.rmtree(graph_path)

# makedirs creates every missing parent as well, including the top-level "graphs" directory,
# which a plain os.mkdir chain starting at "graphs/<aggregate>" cannot create on a fresh checkout.
os.makedirs(graph_path)

In [5]:
# Sequences and metadata of the aggregate, indexed and ordered by their igs_id.
sequences_df = (
    pd.read_csv(
        f"../00_data_understanding_and_preparation/aggregates/{aggregate}/{size}/sequences_and_metadata.csv",
        delimiter=";",
    )
    .set_index("igs_id")
    .sort_index()
)
sequences_count = sequences_df.shape[0]

In [6]:
# Derive the mutation table from the sequence frame with gentrain's nextclade helper.
# NOTE(review): mutations_df is not referenced again in the visible cells — confirm it is still needed.
mutations_df = get_mutations_from_dataframe(sequences_df)

In [7]:
# Load the precomputed distance matrix, keep the first row for any duplicated index label and
# align rows and columns with the order of sequences_df.
distance_matrix_df = (
    pd.read_csv(
        f"../01_algorithm_optimization/distance_matrices/{aggregate}/{size}/distance_matrix.csv",
        delimiter=";",
        index_col="Unnamed: 0",
    )
    .sort_index()
    .loc[lambda frame: ~frame.index.duplicated(keep="first")]
    .loc[sequences_df.index, sequences_df.index]
)
distance_matrix = distance_matrix_df.to_numpy()

In [8]:
# Reference MST and its community labels, built from the full distance matrix.
gentrain_graph = build_graph(distance_matrix)
gentrain_mst = build_mst(gentrain_graph)
gentrain_community_labels = get_outbreak_community_labels(gentrain_mst)

# Sampling dates as day offsets from the earliest sampling date.
datetime_sampling_dates = pd.to_datetime(sequences_df["date_of_sampling"])
numeric_dates = datetime_sampling_dates.sub(datetime_sampling_dates.min()).dt.days

export_graph_gexf(
    gentrain_mst,
    gentrain_community_labels,
    sequences_df,
    f"{graph_path}/brute_force_mst",
)

mst generation time: 8.97s


In [9]:
# Number of distinct community labels in the reference MST.
communities_count = len(set(gentrain_community_labels))
communities_count

111

In [10]:
# Keep only the strict upper triangle so every unordered pair is counted exactly once.
mask = np.triu(np.ones(distance_matrix_df.shape, dtype=bool), k=1)
filtered = distance_matrix_df.where(mask)

# Pairs with a distance below 2, and the total number of pairs that have a distance value.
infections_count = filtered.lt(2).sum().sum()
distances_count = filtered.notna().sum().sum()

## Encoding

In [11]:
# Encodings with the N filter enabled and frequency filtering disabled (indels included).
encodings_N_frequency_filter = get_mutation_sensitive_encodings(
    sequences_df,
    exclude_indels=False,
    use_frequency_filtering=False,
    filter_N=True,
)

execution time: 40.61s


In [12]:
# Encodings with both the N filter and frequency filtering enabled (indels included).
encodings_N_and_SNV_frequency_filter = get_mutation_sensitive_encodings(
    sequences_df,
    exclude_indels=False,
    use_frequency_filtering=True,
    filter_N=True,
)

execution time: 45.29s


## Accurate candidate search

### Depth search with N frequency filter

In [13]:
# Depth search on the N-filtered encodings, evaluated at four computation rates.
depth_search_N_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    # Limit handed to the candidate search: the given share of all pairwise distances.
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(encodings_N_frequency_filter, limit, "depth")
    depth_search_N_frequency_filter.append(
        get_candidate_evaluation_and_export_mst(
            "depth_N",
            candidates,
            graph_path,
            distance_matrix,
            gentrain_community_labels,
            gentrain_mst,
            list(sequences_df["clade"]),
            sequences_df,
            runtime,
        )
    )

execution time xor distance calculation: 10.0s
execution time depth search: 1.99s
execution time 156187: 12.38s
mst generation time: 0.31s
execution time xor distance calculation: 9.45s
execution time depth search: 2.09s
execution time 312375: 11.92s
mst generation time: 1.24s
execution time xor distance calculation: 9.05s
execution time depth search: 2.13s
execution time 468562: 11.58s
mst generation time: 1.93s
execution time xor distance calculation: 9.18s
execution time depth search: 2.1s
execution time 624750: 11.67s
mst generation time: 2.42s


In [14]:
# Tabulate the evaluations collected for the depth search on the N-filtered encodings (one row per computation rate).
pd.DataFrame(depth_search_N_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.05,0.750871,1.0,1.0,0.750871,12.38,1.092953,-0.468351,6.0,641,0.355879,0.998457,0.024057
1,0.1,0.796227,1.0,1.0,0.796227,11.92,1.287209,-0.274096,8.0,397,0.42479,0.997222,0.022822
2,0.15,0.808558,1.0,1.0,0.808558,11.58,1.428055,-0.133249,7.0,258,0.446889,0.996928,0.022528
3,0.2,0.816617,1.0,1.0,0.816617,11.67,1.486513,-0.074791,7.0,194,0.47627,0.995287,0.020887


### Breadth search with N frequency filter

In [15]:
# Breadth search on the N-filtered encodings, evaluated at four computation rates.
breadth_search_N_frequency_filter = []
for computation_rate in [0.05, 0.1, 0.15, 0.2]:
    # Limit handed to the candidate search: the given share of all pairwise distances.
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(encodings_N_frequency_filter, limit, "breadth")
    breadth_search_N_frequency_filter.append(
        get_candidate_evaluation_and_export_mst(
            # Label this run "breadth_N" (it was "depth_N", copied from the depth-search cell),
            # matching the "depth_N_and_SNV" / "breadth_N_and_SNV" naming of the other runs, so the
            # breadth exports do not reuse the label of the depth-search exports.
            "breadth_N",
            candidates,
            graph_path,
            distance_matrix,
            gentrain_community_labels,
            gentrain_mst,
            list(sequences_df["clade"]),
            sequences_df,
            runtime,
        )
    )

matrix generation time: 5.86s
execution time distance collection: 3.43s
execution time breadth search: 2.63s
execution time 156187: 12.35s
mst generation time: 0.51s
matrix generation time: 5.34s
execution time distance collection: 4.08s
execution time breadth search: 2.96s
execution time 312375: 12.83s
mst generation time: 2.32s
matrix generation time: 5.47s
execution time distance collection: 3.38s
execution time breadth search: 3.04s
execution time 468562: 12.38s
mst generation time: 2.07s
matrix generation time: 5.19s
execution time distance collection: 3.44s
execution time breadth search: 2.88s
execution time 624750: 11.93s
mst generation time: 2.38s


In [16]:
# Tabulate the evaluations collected for the breadth search on the N-filtered encodings (one row per computation rate).
pd.DataFrame(breadth_search_N_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.049775,0.461826,1.0,1.0,0.461826,12.35,1.754102,0.192797,11.0,1,0.175794,0.9936,0.0192
1,0.097519,0.632929,1.0,1.0,0.632929,12.83,1.701881,0.140576,11.0,1,0.242678,0.9936,0.0192
2,0.144773,0.735921,1.0,1.0,0.735921,12.38,1.67431,0.113005,11.0,1,0.31019,0.9908,0.0164
3,0.190036,0.795337,1.0,1.0,0.795337,11.93,1.652301,0.090996,11.0,1,0.366199,0.9904,0.016


### Depth search with N and SNV frequency filter

In [17]:
# Depth search on the N- and SNV-frequency-filtered encodings, evaluated at four computation rates.
depth_search_N_and_SNV_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    # Limit handed to the candidate search: the given share of all pairwise distances.
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(encodings_N_and_SNV_frequency_filter, limit, "depth")
    depth_search_N_and_SNV_frequency_filter.append(
        get_candidate_evaluation_and_export_mst(
            "depth_N_and_SNV",
            candidates,
            graph_path,
            distance_matrix,
            gentrain_community_labels,
            gentrain_mst,
            list(sequences_df["clade"]),
            sequences_df,
            runtime,
        )
    )

execution time xor distance calculation: 6.51s
execution time depth search: 2.01s
execution time 156187: 8.9s
mst generation time: 0.66s
execution time xor distance calculation: 6.04s
execution time depth search: 1.96s
execution time 312375: 8.37s
mst generation time: 1.11s
execution time xor distance calculation: 6.07s
execution time depth search: 2.01s
execution time 468562: 8.45s
mst generation time: 1.74s
execution time xor distance calculation: 6.49s
execution time depth search: 2.03s
execution time 624750: 8.9s
mst generation time: 2.23s


In [18]:
# Tabulate the evaluations collected for the depth search on the N- and SNV-filtered encodings (one row per computation rate).
pd.DataFrame(depth_search_N_and_SNV_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.05,0.763278,1.0,1.0,0.763278,8.9,1.371645,-0.189659,16.0,458,0.480052,0.991248,0.016848
1,0.1,0.857448,1.0,1.0,0.857448,8.37,1.476184,-0.08512,16.0,241,0.671509,0.99015,0.01575
2,0.15,0.913737,1.0,1.0,0.913737,8.45,1.505115,-0.05619,16.0,193,0.695646,0.98946,0.01506
3,0.2,0.941144,1.0,1.0,0.941144,8.9,1.549684,-0.01162,16.0,123,0.75775,0.987169,0.012769


### Breadth search with N and SNV frequency filter

In [19]:
# Breadth search on the N- and SNV-frequency-filtered encodings, evaluated at four computation rates.
breadth_search_N_and_SNV_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    # Limit handed to the candidate search: the given share of all pairwise distances.
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(encodings_N_and_SNV_frequency_filter, limit, "breadth")
    breadth_search_N_and_SNV_frequency_filter.append(
        get_candidate_evaluation_and_export_mst(
            "breadth_N_and_SNV",
            candidates,
            graph_path,
            distance_matrix,
            gentrain_community_labels,
            gentrain_mst,
            list(sequences_df["clade"]),
            sequences_df,
            runtime,
        )
    )

matrix generation time: 2.44s
execution time distance collection: 3.13s
execution time breadth search: 2.42s
execution time 156187: 8.38s
mst generation time: 0.61s
matrix generation time: 1.98s
execution time distance collection: 3.44s
execution time breadth search: 2.66s
execution time 312375: 8.49s
mst generation time: 1.31s
matrix generation time: 1.99s
execution time distance collection: 3.49s
execution time breadth search: 2.48s
execution time 468562: 8.33s
mst generation time: 1.51s
matrix generation time: 2.1s
execution time distance collection: 3.64s
execution time breadth search: 2.63s
execution time 624750: 8.78s
mst generation time: 2.43s


In [20]:
# Tabulate the evaluations collected for the breadth search on the N- and SNV-filtered encodings (one row per computation rate).
pd.DataFrame(breadth_search_N_and_SNV_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.049775,0.406554,1.0,1.0,0.406554,8.38,1.713485,0.152181,12.0,1,0.229914,0.9916,0.0172
1,0.097519,0.596751,1.0,1.0,0.596751,8.49,1.662025,0.10072,11.0,1,0.283506,0.9912,0.0168
2,0.144773,0.708921,1.0,1.0,0.708921,8.33,1.64914,0.087835,11.0,1,0.316049,0.9896,0.0152
3,0.190036,0.785651,1.0,1.0,0.785651,8.78,1.635534,0.07423,11.0,1,0.374757,0.9868,0.0124


In [21]:
# Plot input for the depth search: one entry per filter, with its evaluations keyed by computation rate.
depth_search_evaluation = {
    "N frequency filter": {
        "values": {
            result["computation_rate"]: result
            for result in depth_search_N_frequency_filter
        },
        "stroke": "dash",
        "color": "black",
    },
    "N & SNV frequency filter": {
        "values": {
            result["computation_rate"]: result
            for result in depth_search_N_and_SNV_frequency_filter
        },
        "stroke": "dot",
        "color": "black",
    },
}

In [22]:
# Plot input for the breadth search: one entry per filter, with its evaluations keyed by computation rate.
breadth_search_evaluation = {
    "N frequency filter": {
        "values": {
            result["computation_rate"]: result
            for result in breadth_search_N_frequency_filter
        },
        "stroke": "dash",
        "color": "black",
    },
    "N & SNV frequency filter": {
        "values": {
            result["computation_rate"]: result
            for result in breadth_search_N_and_SNV_frequency_filter
        },
        "stroke": "dot",
        "color": "black",
    },
}

### Infection recall for different filters and computation rates using depth search

In [23]:
# NOTE(review): the plotted key is "infection_detection_rate" while the title reads "Infection recall";
# the result tables also contain a separate "infection_recall" column — confirm this is the intended metric.
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    depth_search_evaluation,
    "Infection recall",
    {
        "x": 0.85,
        "y": 0,
        "itemwidth": 60,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 30},
    },
)

sub_fig.show()

### Community ARI for different filters and computation rates using depth search

In [24]:
# Community ARI per computation rate for both filters (depth search).
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    depth_search_evaluation,
    "Community ARI",
    {
        "x": 0.85,
        "y": 0,
        "itemwidth": 60,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 30},
    },
)

sub_fig.show()

### Infection recall for different filters and computation rates using breadth search

In [25]:
# "infection_detection_rate" per computation rate for both filters (breadth search).
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    breadth_search_evaluation,
    "Infection recall",
    legend={
        "x": 0.65,
        "y": 0.05,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 35},
    },
)
sub_fig.show()

### Community ARI for different filters and computation rates using breadth search

In [26]:
# Community ARI per computation rate for both filters (breadth search).
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    breadth_search_evaluation,
    "Community ARI",
    legend={
        "x": 0.55,
        "y": 0.02,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 35},
    },
)
sub_fig.show()

In [27]:
# Plot input comparing depth and breadth search, both on the N- and SNV-frequency-filtered encodings.
depth_vs_breadth_search_evaluation = {
    "Depth search": {
        "values": {
            result["computation_rate"]: result
            for result in depth_search_N_and_SNV_frequency_filter
        },
        "stroke": "solid",
        "color": "blue",
    },
    "Breadth search": {
        "values": {
            result["computation_rate"]: result
            for result in breadth_search_N_and_SNV_frequency_filter
        },
        "stroke": "solid",
        "color": "green",
    },
}

### Infection recall for different computation rates using depth search and breadth search (N and SNV frequency filter)

In [28]:
# "infection_detection_rate" per computation rate, depth search versus breadth search.
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    depth_vs_breadth_search_evaluation,
    "Infection recall",
    legend={
        "x": 0.7,
        "y": 0.1,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 35},
    },
)
sub_fig.show()

### Community ARI for different computation rates using depth search and breadth search (N and SNV frequency filter)

In [29]:
# Community ARI per computation rate, depth search versus breadth search.
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    depth_vs_breadth_search_evaluation,
    "Community ARI",
    legend={
        "x": 0.7,
        "y": 0.1,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 35},
    },
)
sub_fig.show()