In [1]:
!pip install ../../gentrain/.

[0mProcessing /Users/benkraling/code/thesis/gentrain
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: gentrain
  Building wheel for gentrain (setup.py) ... [?25ldone
[?25h  Created wheel for gentrain: filename=gentrain-0.1.2-py3-none-any.whl size=27325 sha256=3ed4983defee1a4742f5ec3ebbff71c0d24f630990a3fc8d3614fcdbe13e5ecd
  Stored in directory: /private/var/folders/2h/923cq6912sqb0snfvqqfdnmm0000gn/T/pip-ephem-wheel-cache-0lvamhr2/wheels/cf/e4/57/91c03db2e8c043adeefe35dd0969d3049f61ae0218be0acc9f
Successfully built gentrain
[0mInstalling collected packages: gentrain
  Attempting uninstall: gentrain
    Found existing installation: gentrain 0.1.2
    Uninstalling gentrain-0.1.2:
      Successfully uninstalled gentrain-0.1.2
[0mSuccessfully installed gentrain-0.1.2


In [2]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import re
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import adjusted_rand_score
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pickle
from gentrain.evaluation import get_computation_rate_plot, candidate_evaluation_and_matrices, get_candidate_evaluation_and_export_mst, get_baseline_evaluation_and_export_mst
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.preprocessing import LabelEncoder
from gentrain.encoding import get_nucleotide_sensitive_encodings, get_mutation_sensitive_encodings, generate_one_hot_encoding
from gentrain.nextclade import get_mutations_from_dataframe
from gentrain.candidate_sourcing import bitwise_xor_candidates
from gentrain.graph import build_mst, export_graph_gexf, mean_edge_weight, get_outbreak_community_labels, build_graph
from scipy.spatial.distance import pdist
import community as community_louvain
import umap
import faiss
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances_argmin
import os
import shutil

In [3]:
# Dataset selection: both values are interpolated into the input and output paths of the cells below.
size = 5000
aggregate = "nrw_2022"

In [4]:
# Output directory for all graph exports of this run.
graph_path = f"graphs/{aggregate}/{size}"

# Start from an empty directory so exports from a previous run cannot mix with the new ones.
if os.path.isdir(graph_path):
    shutil.rmtree(graph_path)

# makedirs also creates missing parents (including the top-level "graphs" folder),
# so the cell works on a fresh checkout as well.
os.makedirs(graph_path)

In [5]:
# Load the sequence/metadata table of the selected aggregate, indexed and sorted by igs_id.
sequences_df = (
    pd.read_csv(
        f"../00_data_understanding_and_preparation/aggregates/{aggregate}/{size}/sequences_and_metadata.csv",
        delimiter=";",
    )
    .set_index("igs_id")
    .sort_index()
)
sequences_count = sequences_df.shape[0]

In [6]:
# Derive mutations_df from the loaded sequences via gentrain.nextclade.get_mutations_from_dataframe.
# NOTE(review): mutations_df is not referenced by any later cell visible in this notebook — confirm it is still needed.
mutations_df = get_mutations_from_dataframe(sequences_df)

In [7]:
# Load the precomputed distance matrix and align it with sequences_df:
#   1. sort rows by index,
#   2. drop repeated row labels (first occurrence wins),
#   3. select rows and columns in the order of sequences_df.index.
distance_matrix_df = (
    pd.read_csv(
        f"../01_algorithm_optimization/distance_matrices/{aggregate}/{size}/distance_matrix.csv",
        delimiter=";",
        index_col="Unnamed: 0",
    )
    .sort_index()
    .loc[lambda frame: ~frame.index.duplicated(keep='first')]
    .loc[sequences_df.index, sequences_df.index]
)
distance_matrix = distance_matrix_df.to_numpy()

In [8]:
# Reference result on the full distance matrix: graph -> MST -> community labels.
gentrain_graph = build_graph(distance_matrix)
gentrain_mst = build_mst(gentrain_graph)
gentrain_community_labels = get_outbreak_community_labels(gentrain_mst)

# Sampling dates expressed as day offsets from the earliest sample.
# NOTE(review): numeric_dates is not used by any later cell visible in this notebook.
datetime_sampling_dates = pd.to_datetime(sequences_df["date_of_sampling"])
earliest_sampling_date = datetime_sampling_dates.min()
numeric_dates = (datetime_sampling_dates - earliest_sampling_date).dt.days

# Export the reference MST next to the candidate-based exports written by later cells.
export_graph_gexf(
    gentrain_mst,
    gentrain_community_labels,
    sequences_df,
    f"{graph_path}/brute_force",
)

mst generation time: 97.62s


In [9]:
# Keep every unordered pair exactly once: strictly upper triangle (k=1 excludes the diagonal).
mask = np.triu(np.ones(distance_matrix_df.shape, dtype=bool), k=1)
filtered = distance_matrix_df.where(mask)

# Masked-out cells are NaN: they never compare below 2 and are not counted as distances.
infections_count = filtered.lt(2).to_numpy().sum()
distances_count = filtered.notna().to_numpy().sum()

## Encoding

In [10]:
# Encodings with filter_N enabled and frequency filtering disabled; indels are kept.
encodings_N_frequency_filter = get_mutation_sensitive_encodings(
    sequences_df,
    exclude_indels=False,
    use_frequency_filtering=False,
    filter_N=True,
)

execution time: 65.54s


In [11]:
# Encodings with both filter_N and frequency filtering enabled; indels are kept.
encodings_N_and_SNV_frequency_filter = get_mutation_sensitive_encodings(
    sequences_df,
    exclude_indels=False,
    use_frequency_filtering=True,
    filter_N=True,
)

execution time: 62.25s


## Accurate candidate search

### Depth search with N frequency filter

In [12]:
# Depth search on the N-filtered encodings: one evaluation record per computation rate.
depth_search_N_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    # Candidate budget = requested share of all pairwise distances.
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(
        encodings_N_frequency_filter, limit, "depth"
    )
    evaluation = get_candidate_evaluation_and_export_mst(
        "depth_N",
        candidates,
        graph_path,
        distance_matrix,
        gentrain_community_labels,
        gentrain_mst,
        list(sequences_df["clade"]),
        sequences_df,
        runtime,
    )
    depth_search_N_frequency_filter.append(evaluation)

execution time xor distance calculation: 76.67s
execution time depth search: 7.9s
execution time 624875: 86.2s
mst generation time: 1.32s
execution time xor distance calculation: 77.69s
execution time depth search: 8.91s
execution time 1249750: 88.23s
mst generation time: 2.89s
execution time xor distance calculation: 77.62s
execution time depth search: 9.6s
execution time 1874625: 89.06s
mst generation time: 11.56s
execution time xor distance calculation: 83.82s
execution time depth search: 10.4s
execution time 2499500: 96.06s
mst generation time: 15.83s


In [13]:
# Show the depth-search evaluations (N frequency filter) as a table, one row per computation rate.
pd.DataFrame(depth_search_N_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.05,0.912987,1.0,1.0,0.912987,86.2,3.233542,15.0,358,0.372452,0.983308,0.028308
1,0.1,0.969776,1.0,1.0,0.969776,88.23,3.331951,19.0,152,0.626862,0.968332,0.013332
2,0.15,0.986541,1.0,1.0,0.986541,89.06,3.314154,19.0,111,0.776623,0.959975,0.004975
3,0.2,0.990909,1.0,1.0,0.990909,96.06,3.337897,22.0,84,0.817148,0.954268,-0.000732


### Breadth search with N frequency filter

In [14]:
# Breadth search on the N-filtered encodings: one evaluation record per computation rate.
breadth_search_N_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    # Candidate budget = requested share of all pairwise distances.
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(
        encodings_N_frequency_filter, limit, "breadth"
    )
    evaluation = get_candidate_evaluation_and_export_mst(
        "breadth_N",
        candidates,
        graph_path,
        distance_matrix,
        gentrain_community_labels,
        gentrain_mst,
        list(sequences_df["clade"]),
        sequences_df,
        runtime,
    )
    breadth_search_N_frequency_filter.append(evaluation)

matrix generation time: 65.81s
execution time distance collection: 16.54s
execution time breadth search: 12.81s
execution time 624875: 97.35s
mst generation time: 4.06s
matrix generation time: 65.56s
execution time distance collection: 13.91s
execution time breadth search: 12.11s
execution time 1249750: 93.82s
mst generation time: 6.77s
matrix generation time: 63.71s
execution time distance collection: 13.7s
execution time breadth search: 12.46s
execution time 1874625: 91.72s
mst generation time: 9.44s
matrix generation time: 61.82s
execution time distance collection: 13.14s
execution time breadth search: 11.92s
execution time 2499500: 88.75s
mst generation time: 12.33s


In [15]:
# Show the breadth-search evaluations (N frequency filter) as a table, one row per computation rate.
pd.DataFrame(breadth_search_N_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.04938,0.775207,1.0,1.0,0.775207,97.35,3.634947,74.0,1,0.383423,0.9678,0.0128
1,0.09751,0.880756,1.0,1.0,0.880756,93.82,3.563973,74.0,1,0.488546,0.9646,0.0096
2,0.144389,0.928571,1.0,1.0,0.928571,91.72,3.521184,74.0,1,0.585905,0.9644,0.0094
3,0.190018,0.951948,1.0,1.0,0.951948,88.75,3.49828,74.0,1,0.662671,0.96,0.005


### Depth search with N and SNV frequency filter

In [16]:
# Depth search on the N- and SNV-frequency-filtered encodings: one evaluation record per computation rate.
depth_search_N_and_SNV_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    # Candidate budget = requested share of all pairwise distances.
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(
        encodings_N_and_SNV_frequency_filter, limit, "depth"
    )
    evaluation = get_candidate_evaluation_and_export_mst(
        "depth_N_and_SNV",
        candidates,
        graph_path,
        distance_matrix,
        gentrain_community_labels,
        gentrain_mst,
        list(sequences_df["clade"]),
        sequences_df,
        runtime,
    )
    depth_search_N_and_SNV_frequency_filter.append(evaluation)

execution time xor distance calculation: 50.18s
execution time depth search: 8.46s
execution time 624875: 60.22s
mst generation time: 3.89s
execution time xor distance calculation: 49.93s
execution time depth search: 8.94s
execution time 1249750: 60.4s
mst generation time: 6.58s
execution time xor distance calculation: 51.24s
execution time depth search: 8.84s
execution time 1874625: 61.62s
mst generation time: 9.57s
execution time xor distance calculation: 50.82s
execution time depth search: 8.99s
execution time 2499500: 61.45s
mst generation time: 13.07s


In [17]:
# Show the depth-search evaluations (N and SNV frequency filter) as a table, one row per computation rate.
pd.DataFrame(depth_search_N_and_SNV_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.05,0.942385,1.0,1.0,0.942385,60.22,3.213547,22.0,357,0.422696,0.972921,0.017921
1,0.1,0.978985,1.0,1.0,0.978985,60.4,3.304535,20.0,171,0.601214,0.969247,0.014247
2,0.15,0.992798,1.0,1.0,0.992798,61.62,3.317772,22.0,116,0.770286,0.960163,0.005163
3,0.2,0.994923,1.0,1.0,0.994923,61.45,3.322666,21.0,94,0.853284,0.959715,0.004715


### Breadth search with N and SNV frequency filter

In [19]:
# Breadth search on the N- and SNV-frequency-filtered encodings: one evaluation record per computation rate.
breadth_search_N_and_SNV_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    # Candidate budget = requested share of all pairwise distances.
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(
        encodings_N_and_SNV_frequency_filter, limit, "breadth"
    )
    evaluation = get_candidate_evaluation_and_export_mst(
        "breadth_N_and_SNV",
        candidates,
        graph_path,
        distance_matrix,
        gentrain_community_labels,
        gentrain_mst,
        list(sequences_df["clade"]),
        sequences_df,
        runtime,
    )
    breadth_search_N_and_SNV_frequency_filter.append(evaluation)

matrix generation time: 36.98s
execution time distance collection: 19.28s
execution time breadth search: 13.31s
execution time 624875: 71.52s
mst generation time: 3.24s
matrix generation time: 36.69s
execution time distance collection: 14.64s
execution time breadth search: 13.17s
execution time 1249750: 66.49s
mst generation time: 7.13s
matrix generation time: 35.58s
execution time distance collection: 13.17s
execution time breadth search: 11.51s
execution time 1874625: 62.03s
mst generation time: 9.65s
matrix generation time: 33.87s
execution time distance collection: 12.91s
execution time breadth search: 11.46s
execution time 2499500: 59.98s
mst generation time: 12.41s


In [20]:
# Show the breadth-search evaluations (N and SNV frequency filter) as a table, one row per computation rate.
pd.DataFrame(breadth_search_N_and_SNV_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.04938,0.808501,1.0,1.0,0.808501,71.52,3.612482,74.0,1,0.389345,0.9634,0.0084
1,0.09751,0.914758,1.0,1.0,0.914758,66.49,3.556531,74.0,1,0.492223,0.9622,0.0072
2,0.144389,0.956198,1.0,1.0,0.956198,62.03,3.518604,74.0,1,0.59107,0.959,0.004
3,0.190018,0.971547,1.0,1.0,0.971547,59.98,3.503081,74.0,1,0.653711,0.9586,0.0036


In [21]:
# Depth-search series handed to get_computation_rate_plot: per filter setting, the evaluation
# records keyed by their computation_rate, plus the stroke/color entries for that series.
depth_search_evaluation = {
    "N frequency filter": {
        "values": {
            record["computation_rate"]: record
            for record in depth_search_N_frequency_filter
        },
        "stroke": "dash",
        "color": "black",
    },
    "N & SNV frequency filter": {
        "values": {
            record["computation_rate"]: record
            for record in depth_search_N_and_SNV_frequency_filter
        },
        "stroke": "dot",
        "color": "black",
    },
}

In [22]:
# Breadth-search series handed to get_computation_rate_plot: per filter setting, the evaluation
# records keyed by their computation_rate, plus the stroke/color entries for that series.
breadth_search_evaluation = {
    "N frequency filter": {
        "values": {
            record["computation_rate"]: record
            for record in breadth_search_N_frequency_filter
        },
        "stroke": "dash",
        "color": "black",
    },
    "N & SNV frequency filter": {
        "values": {
            record["computation_rate"]: record
            for record in breadth_search_N_and_SNV_frequency_filter
        },
        "stroke": "dot",
        "color": "black",
    },
}

### Infection recall for different filters and computation rates using depth search

In [23]:
# Depth search: infection_detection_rate of both filter settings, labelled "Infection recall".
# The legend settings are passed positionally as the fourth argument.
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    depth_search_evaluation,
    "Infection recall",
    {
        "x": 0.85,
        "y": 0,
        "itemwidth": 60,
        "xanchor": 'left',
        "yanchor": 'bottom',
        "font": {"size": 30},
    },
)

sub_fig.show()

### Community ARI for different filters and computation rates using depth search

In [24]:
# Depth search: adjusted_rand_index of both filter settings, labelled "Community ARI".
# The legend settings are passed positionally as the fourth argument.
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    depth_search_evaluation,
    "Community ARI",
    {
        "x": 0.85,
        "y": 0,
        "itemwidth": 60,
        "xanchor": 'left',
        "yanchor": 'bottom',
        "font": {"size": 30},
    },
)

sub_fig.show()

### Infection recall for different filters and computation rates using breadth search

In [25]:
# Breadth search: infection_detection_rate of both filter settings, labelled "Infection recall".
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    breadth_search_evaluation,
    "Infection recall",
    legend={
        "x": 0.65,
        "y": 0.05,
        "xanchor": 'left',
        "yanchor": 'bottom',
        "font": {"size": 35},
    },
)
sub_fig.show()

### Community ARI for different filters and computation rates using breadth search

In [26]:
# Breadth search: adjusted_rand_index of both filter settings, labelled "Community ARI".
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    breadth_search_evaluation,
    "Community ARI",
    legend={
        "x": 0.55,
        "y": 0.02,
        "xanchor": 'left',
        "yanchor": 'bottom',
        "font": {"size": 35},
    },
)
sub_fig.show()

In [27]:
# Depth vs. breadth search, both on the N & SNV frequency-filtered encodings: evaluation records
# keyed by their computation_rate, plus the stroke/color entries for each series.
depth_vs_breadth_search_evaluation = {
    "Depth search": {
        "values": {
            record["computation_rate"]: record
            for record in depth_search_N_and_SNV_frequency_filter
        },
        "stroke": "solid",
        "color": "blue",
    },
    "Breadth search": {
        "values": {
            record["computation_rate"]: record
            for record in breadth_search_N_and_SNV_frequency_filter
        },
        "stroke": "solid",
        "color": "green",
    },
}

### Infection recall for different computation rates using depth search and breadth search (N and SNV frequency filter)

In [28]:
# Depth vs. breadth search: infection_detection_rate, labelled "Infection recall".
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    depth_vs_breadth_search_evaluation,
    "Infection recall",
    legend={
        "x": 0.7,
        "y": 0.1,
        "xanchor": 'left',
        "yanchor": 'bottom',
        "font": {"size": 35},
    },
)
sub_fig.show()

### Community ARI for different computation rates using depth search and breadth search (N and SNV frequency filter)

In [29]:
# Depth vs. breadth search: adjusted_rand_index, labelled "Community ARI".
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    depth_vs_breadth_search_evaluation,
    "Community ARI",
    legend={
        "x": 0.7,
        "y": 0.1,
        "xanchor": 'left',
        "yanchor": 'bottom',
        "font": {"size": 35},
    },
)
sub_fig.show()