In [1]:
!pip install ../../gentrain/.

[0mProcessing /Users/benkraling/code/thesis/gentrain
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: gentrain
  Building wheel for gentrain (setup.py) ... [?25ldone
[?25h  Created wheel for gentrain: filename=gentrain-0.1.2-py3-none-any.whl size=26677 sha256=2c9d64e5977aeeeccf0a0aad3068439200148a673ac21eb5d8b08dc7fd97a0d1
  Stored in directory: /private/var/folders/2h/923cq6912sqb0snfvqqfdnmm0000gn/T/pip-ephem-wheel-cache-yuct7ri2/wheels/cf/e4/57/91c03db2e8c043adeefe35dd0969d3049f61ae0218be0acc9f
Successfully built gentrain
[0mInstalling collected packages: gentrain
  Attempting uninstall: gentrain
    Found existing installation: gentrain 0.1.2
    Uninstalling gentrain-0.1.2:
      Successfully uninstalled gentrain-0.1.2
[0mSuccessfully installed gentrain-0.1.2


In [2]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import re
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import adjusted_rand_score
from sklearn.preprocessing import LabelEncoder
import numpy as np
from collections import Counter
import pickle
from gentrain.evaluation import get_computation_rate_plot, candidate_evaluation_and_matrices, get_candidate_evaluation_and_export_mst
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.preprocessing import LabelEncoder
from gentrain.encoding import get_nucleotide_sensitive_encodings, get_mutation_sensitive_encodings, generate_one_hot_encoding
from gentrain.nextclade import get_mutations_from_dataframe
from gentrain.candidate_sourcing import bitwise_xor_candidates
from gentrain.graph import build_mst, export_graph_gexf, mean_edge_weight, get_outbreak_community_labels, build_graph
from scipy.spatial.distance import pdist
import community as community_louvain
import umap
import faiss
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances_argmin
import os
import shutil

In [3]:
# Dataset selection: aggregate identifier and sample size. Both values are
# interpolated into the input and output paths used by the following cells.
aggregate, size = "due_202203", 5000

In [4]:
# Start from an empty output directory for this aggregate/size so that graph
# exports from an earlier run cannot linger next to the new ones.
graph_path = f"graphs/{aggregate}/{size}"
if os.path.isdir(graph_path):
    shutil.rmtree(graph_path)
# makedirs creates any missing parent ("graphs/", "graphs/{aggregate}/") as well;
# the previous chain of os.mkdir calls raised FileNotFoundError when "graphs/"
# itself did not exist yet. Without exist_ok it still raises if graph_path is
# present as a regular file, matching the earlier behaviour.
os.makedirs(graph_path)

In [5]:
# Load sequences and metadata for the selected aggregate/size, indexed and
# ordered by igs_id.
sequences_df = (
    pd.read_csv(
        f"../00_data_understanding_and_preparation/aggregates/{aggregate}/{size}/sequences_and_metadata.csv",
        sep=";",
    )
    .set_index("igs_id")
    .sort_index()
)
# Number of rows (sequences) in the loaded frame.
sequences_count = sequences_df.shape[0]

In [6]:
# Derive mutations from the loaded sequences via gentrain's nextclade helper.
# NOTE(review): mutations_df is not referenced by any later cell visible here — confirm it is still needed.
mutations_df = get_mutations_from_dataframe(sequences_df)

In [7]:
# Load the stored pairwise distance matrix, keep only the first row for any
# repeated index label, then align rows and columns to the index of sequences_df.
distance_matrix_df = (
    pd.read_csv(
        f"../01_algorithm_optimization/distance_matrices/{aggregate}/{size}/distance_matrix.csv",
        sep=";",
        index_col="Unnamed: 0",
    )
    .sort_index()
    .loc[lambda frame: ~frame.index.duplicated(keep="first")]
    .loc[sequences_df.index, sequences_df.index]
)
# Plain ndarray view of the aligned matrix for the graph and evaluation helpers.
distance_matrix = distance_matrix_df.to_numpy()

In [8]:
# Reference result computed from the full distance matrix:
# graph -> MST -> outbreak community labels.
gentrain_graph = build_graph(distance_matrix)
gentrain_mst = build_mst(gentrain_graph)
gentrain_community_labels = get_outbreak_community_labels(gentrain_mst)
# Sampling dates as datetimes, and as whole days elapsed since the earliest sampling date.
# NOTE(review): neither variable is referenced again in the cells visible here — confirm they are still needed.
datetime_sampling_dates = pd.to_datetime(sequences_df["date_of_sampling"])
numeric_dates = (datetime_sampling_dates - datetime_sampling_dates.min()).dt.days
# Export the reference MST with its community labels and the sequence metadata under graph_path.
export_graph_gexf(gentrain_mst, gentrain_community_labels, sequences_df, f"{graph_path}/brute_force_mst")

mst generation time: 72.48s


In [9]:
# Number of distinct community labels in the reference result.
communities_count = len(set(gentrain_community_labels))

In [10]:
# Keep only the strict upper triangle (k=1 excludes the diagonal) so that every
# unordered pair of sequences is considered exactly once. The mask is built as
# bool from the start, which avoids allocating an n x n float64 array first.
mask = np.triu(np.ones(distance_matrix_df.shape, dtype=bool), k=1)
filtered = distance_matrix_df.where(mask)
# Pairs with a distance below 2; masked-out cells are NaN and compare as False.
infections_count = (filtered < 2).sum().sum()
# Number of non-NaN distances in the upper triangle.
distances_count = filtered.count().sum()

## Encoding

In [11]:
# Mutation-sensitive encodings with N filtering enabled (filter_N=True), frequency
# filtering disabled (use_frequency_filtering=False) and indels kept (exclude_indels=False).
encodings_N_frequency_filter = get_mutation_sensitive_encodings(sequences_df, exclude_indels=False, use_frequency_filtering=False, filter_N=True)

execution time: 103.19s


In [12]:
# Same encoding call as for encodings_N_frequency_filter, but with frequency
# filtering additionally enabled (use_frequency_filtering=True).
encodings_N_and_SNV_frequency_filter = get_mutation_sensitive_encodings(sequences_df, exclude_indels=False, use_frequency_filtering=True, filter_N=True)

execution time: 102.23s


## Accurate candidate search

### Depth search with N frequency filter

In [13]:
# Depth search on the N-filtered encodings. For each target computation rate the
# candidate limit is that share of all pairwise distances; the candidates are then
# evaluated against the reference result and one evaluation record is collected.
depth_search_N_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(
        encodings_N_frequency_filter,
        limit,
        "depth",
    )
    evaluation_record = get_candidate_evaluation_and_export_mst(
        "depth_N",
        candidates,
        graph_path,
        distance_matrix,
        gentrain_community_labels,
        gentrain_mst,
        list(sequences_df["Nextclade_pango"]),
        sequences_df,
        runtime,
    )
    depth_search_N_frequency_filter.append(evaluation_record)

execution time xor distance calculation: 83.23s
execution time depth search: 10.59s
execution time 624875: 95.97s
mst generation time: 4.42s
execution time xor distance calculation: 76.47s
execution time depth search: 11.24s
execution time 1249750: 89.64s
mst generation time: 8.42s
execution time xor distance calculation: 81.0s
execution time depth search: 12.11s
execution time 1874625: 95.05s
mst generation time: 12.89s
execution time xor distance calculation: 82.06s
execution time depth search: 16.4s
execution time 2499500: 100.44s
mst generation time: 19.11s


In [14]:
# Show the collected depth-search evaluation records (N filter) as a table, one row per computation rate.
pd.DataFrame(depth_search_N_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.05,0.695703,1.0,1.0,0.695703,95.97,0.965122,-0.358342,7.0,1072,0.207985,0.939424,0.050624
1,0.1,0.743299,1.0,1.0,0.743299,89.64,1.135038,-0.188427,6.0,639,0.222916,0.919901,0.031101
2,0.15,0.760276,1.0,1.0,0.760276,95.05,1.261499,-0.061966,13.0,369,0.251514,0.908294,0.019494
3,0.2,0.771006,1.0,1.0,0.771006,100.44,1.325453,0.001988,15.0,250,0.266668,0.896234,0.007434


### Breadth search with N frequency filter

In [15]:
# Breadth search on the N-filtered encodings, evaluated at each target computation rate.
# The label handed to the evaluation/export helper is "breadth_N". This cell used to
# pass "depth_N", copied from the depth-search cell, so the breadth runs carried the
# depth-search label (and would overwrite the depth-search exports if the helper
# derives file names from the label — TODO confirm against gentrain.evaluation).
breadth_search_N_frequency_filter = []
for computation_rate in [0.05, 0.1, 0.15, 0.2]:
    # Candidate limit: the requested share of all pairwise distances.
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(encodings_N_frequency_filter, limit, "breadth")
    breadth_search_N_frequency_filter.append(
        get_candidate_evaluation_and_export_mst(
            "breadth_N",
            candidates,
            graph_path,
            distance_matrix,
            gentrain_community_labels,
            gentrain_mst,
            list(sequences_df["Nextclade_pango"]),
            sequences_df,
            runtime,
        )
    )

matrix generation time: 56.02s
execution time distance collection: 21.98s
execution time breadth search: 16.42s
execution time 624875: 96.58s
mst generation time: 5.95s
matrix generation time: 62.99s
execution time distance collection: 23.92s
execution time breadth search: 17.23s
execution time 1249750: 106.43s
mst generation time: 10.99s
matrix generation time: 66.62s
execution time distance collection: 25.2s
execution time breadth search: 18.05s
execution time 1874625: 112.29s
mst generation time: 14.03s
matrix generation time: 56.77s
execution time distance collection: 21.46s
execution time breadth search: 17.95s
execution time 2499500: 99.0s
mst generation time: 25.02s


In [16]:
# Show the collected breadth-search evaluation records (N filter) as a table, one row per computation rate.
pd.DataFrame(breadth_search_N_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.04938,0.435922,1.0,1.0,0.435922,96.58,1.479376,0.155911,44.0,1,0.159989,0.8788,-0.01
1,0.09751,0.596675,1.0,1.0,0.596675,106.43,1.435347,0.111882,44.0,1,0.199714,0.8832,-0.0056
2,0.144389,0.694608,1.0,1.0,0.694608,112.29,1.412843,0.089378,44.0,1,0.23269,0.8754,-0.0134
3,0.190018,0.747806,1.0,1.0,0.747806,99.0,1.39842,0.074955,24.0,1,0.256396,0.889,0.0002


### Depth search with N and SNV frequency filter

In [17]:
# Depth search on the encodings built with both N and frequency filtering.
# One evaluation record is collected per target computation rate.
depth_search_N_and_SNV_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    # Candidate limit: the requested share of all pairwise distances.
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(
        encodings_N_and_SNV_frequency_filter,
        limit,
        "depth",
    )
    evaluation_record = get_candidate_evaluation_and_export_mst(
        "depth_N_and_SNV",
        candidates,
        graph_path,
        distance_matrix,
        gentrain_community_labels,
        gentrain_mst,
        list(sequences_df["Nextclade_pango"]),
        sequences_df,
        runtime,
    )
    depth_search_N_and_SNV_frequency_filter.append(evaluation_record)

execution time xor distance calculation: 62.48s
execution time depth search: 13.04s
execution time 624875: 77.81s
mst generation time: 6.02s
execution time xor distance calculation: 48.27s
execution time depth search: 11.97s
execution time 1249750: 62.06s
mst generation time: 8.98s
execution time xor distance calculation: 47.09s
execution time depth search: 11.19s
execution time 1874625: 60.01s
mst generation time: 13.69s
execution time xor distance calculation: 47.77s
execution time depth search: 11.5s
execution time 2499500: 61.04s
mst generation time: 20.57s


In [18]:
# Show the collected depth-search evaluation records (N and frequency filter) as a table, one row per computation rate.
pd.DataFrame(depth_search_N_and_SNV_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.05,0.767382,1.0,1.0,0.767382,77.81,1.14407,-0.179395,17.0,759,0.27503,0.952508,0.063708
1,0.1,0.825945,1.0,1.0,0.825945,62.06,1.26421,-0.059255,21.0,387,0.263431,0.932225,0.043425
2,0.15,0.869406,1.0,1.0,0.869406,60.01,1.312302,-0.011163,21.0,204,0.333407,0.92295,0.03415
3,0.2,0.908936,1.0,1.0,0.908936,61.04,1.336759,0.013294,21.0,125,0.336221,0.922218,0.033418


### Breadth search with N and SNV frequency filter

In [19]:
# Breadth search on the encodings built with both N and frequency filtering.
# One evaluation record is collected per target computation rate.
breadth_search_N_and_SNV_frequency_filter = []
for computation_rate in (0.05, 0.1, 0.15, 0.2):
    # Candidate limit: the requested share of all pairwise distances.
    limit = int(computation_rate * distances_count)
    candidates, runtime = bitwise_xor_candidates(
        encodings_N_and_SNV_frequency_filter,
        limit,
        "breadth",
    )
    evaluation_record = get_candidate_evaluation_and_export_mst(
        "breadth_N_and_SNV",
        candidates,
        graph_path,
        distance_matrix,
        gentrain_community_labels,
        gentrain_mst,
        list(sequences_df["Nextclade_pango"]),
        sequences_df,
        runtime,
    )
    breadth_search_N_and_SNV_frequency_filter.append(evaluation_record)

matrix generation time: 21.74s
execution time distance collection: 19.48s
execution time breadth search: 14.83s
execution time 624875: 58.08s
mst generation time: 6.07s
matrix generation time: 22.52s
execution time distance collection: 20.66s
execution time breadth search: 14.85s
execution time 1249750: 60.07s
mst generation time: 11.23s
matrix generation time: 25.13s
execution time distance collection: 20.54s
execution time breadth search: 15.56s
execution time 1874625: 63.33s
mst generation time: 13.53s
matrix generation time: 22.03s
execution time distance collection: 19.98s
execution time breadth search: 15.28s
execution time 2499500: 59.27s
mst generation time: 17.64s


In [20]:
# Show the collected breadth-search evaluation records (N and frequency filter) as a table, one row per computation rate.
pd.DataFrame(breadth_search_N_and_SNV_frequency_filter)

Unnamed: 0,computation_rate,infection_detection_rate,infection_recall,infection_precision,infection_f1,runtime,mean_edge_weight,mean_edge_weight_diff,max_edge_weight,subgraph_count,adjusted_rand_index,lineage_purity,lineage_purity_diff
0,0.04938,0.411334,1.0,1.0,0.411334,58.08,1.452671,0.129206,44.0,1,0.176665,0.9008,0.012
1,0.09751,0.590606,1.0,1.0,0.590606,60.07,1.410842,0.087377,44.0,1,0.22797,0.8928,0.004
2,0.144389,0.70391,1.0,1.0,0.70391,63.33,1.393739,0.070274,44.0,1,0.263546,0.8782,-0.0106
3,0.190018,0.785902,1.0,1.0,0.785902,59.27,1.386557,0.063093,44.0,1,0.294397,0.8686,-0.0202


In [21]:
# Depth-search results per filter variant: evaluation records keyed by their
# computation rate, together with the line stroke and colour used for plotting.
depth_search_evaluation = {
    "N frequency filter": dict(
        values={record["computation_rate"]: record for record in depth_search_N_frequency_filter},
        stroke="dash",
        color="black",
    ),
    "N & SNV frequency filter": dict(
        values={record["computation_rate"]: record for record in depth_search_N_and_SNV_frequency_filter},
        stroke="dot",
        color="black",
    ),
}

In [22]:
# Breadth-search results per filter variant: evaluation records keyed by their
# computation rate, together with the line stroke and colour used for plotting.
breadth_search_evaluation = {
    "N frequency filter": dict(
        values={record["computation_rate"]: record for record in breadth_search_N_frequency_filter},
        stroke="dash",
        color="black",
    ),
    "N & SNV frequency filter": dict(
        values={record["computation_rate"]: record for record in breadth_search_N_and_SNV_frequency_filter},
        stroke="dot",
        color="black",
    ),
}

### Infection recall for different filters and computation rates using depth search

In [23]:
# Plot the infection_detection_rate metric of both depth-search filter variants.
# NOTE(review): the label passed is "Infection recall" while the metric key is
# infection_detection_rate — confirm this labelling is intended.
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    depth_search_evaluation,
    "Infection recall",
    {
        "x": 0.85,
        "y": 0,
        "itemwidth": 60,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 30},
    },
)

sub_fig.show()

### Community ARI for different filters and computation rates using depth search

In [24]:
# Plot the adjusted_rand_index metric of both depth-search filter variants.
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    depth_search_evaluation,
    "Community ARI",
    {
        "x": 0.85,
        "y": 0,
        "itemwidth": 60,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 30},
    },
)

sub_fig.show()

### Infection recall for different filters and computation rates using breadth search

In [25]:
# Plot the infection_detection_rate metric of both breadth-search filter variants.
# NOTE(review): the label passed is "Infection recall" while the metric key is
# infection_detection_rate — confirm this labelling is intended.
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    breadth_search_evaluation,
    "Infection recall",
    legend={
        "x": 0.65,
        "y": 0.05,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 35},
    },
)
sub_fig.show()

### Community ARI for different filters and computation rates using breadth search

In [26]:
# Plot the adjusted_rand_index metric of both breadth-search filter variants.
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    breadth_search_evaluation,
    "Community ARI",
    legend={
        "x": 0.55,
        "y": 0.02,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 35},
    },
)
sub_fig.show()

In [27]:
# Depth versus breadth search, both on the N and frequency filtered encodings:
# evaluation records keyed by computation rate, plus the line styling per search mode.
depth_vs_breadth_search_evaluation = {
    "Depth search": dict(
        values={record["computation_rate"]: record for record in depth_search_N_and_SNV_frequency_filter},
        stroke="solid",
        color="blue",
    ),
    "Breadth search": dict(
        values={record["computation_rate"]: record for record in breadth_search_N_and_SNV_frequency_filter},
        stroke="solid",
        color="green",
    ),
}

### Infection recall for different computation rates using depth search and breadth search (N and SNV frequency filter)

In [28]:
# Plot the infection_detection_rate metric for depth search versus breadth search.
# NOTE(review): the label passed is "Infection recall" while the metric key is
# infection_detection_rate — confirm this labelling is intended.
sub_fig = get_computation_rate_plot(
    "infection_detection_rate",
    depth_vs_breadth_search_evaluation,
    "Infection recall",
    legend={
        "x": 0.7,
        "y": 0.1,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 35},
    },
)
sub_fig.show()

### Community ARI for different computation rates using depth search and breadth search (N and SNV frequency filter)

In [29]:
# Plot the adjusted_rand_index metric for depth search versus breadth search.
sub_fig = get_computation_rate_plot(
    "adjusted_rand_index",
    depth_vs_breadth_search_evaluation,
    "Community ARI",
    legend={
        "x": 0.7,
        "y": 0.1,
        "xanchor": "left",
        "yanchor": "bottom",
        "font": {"size": 35},
    },
)
sub_fig.show()