In [1]:
!pip install ../../gentrain/.

[0mProcessing /Users/benkraling/code/thesis/gentrain
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: gentrain
  Building wheel for gentrain (setup.py) ... [?25ldone
[?25h  Created wheel for gentrain: filename=gentrain-0.1.2-py3-none-any.whl size=26911 sha256=2f1ad9c939adbed92f7c26d47c22f324c0da5ac3039b9d5c5b7bd21478a34f73
  Stored in directory: /private/var/folders/2h/923cq6912sqb0snfvqqfdnmm0000gn/T/pip-ephem-wheel-cache-5ilsv73n/wheels/cf/e4/57/91c03db2e8c043adeefe35dd0969d3049f61ae0218be0acc9f
Successfully built gentrain
[0mInstalling collected packages: gentrain
  Attempting uninstall: gentrain
    Found existing installation: gentrain 0.1.2
    Uninstalling gentrain-0.1.2:
      Successfully uninstalled gentrain-0.1.2
[0mSuccessfully installed gentrain-0.1.2


In [2]:
import plotly.io as pio
pio.renderers.default = "png"

In [3]:
import pandas as pd
from gentrain.encoding import get_mutation_sensitive_encodings
from gentrain.candidate_sourcing import bitwise_xor_candidates, and_or_lsh, get_hnsw_candidates
from gentrain.distance_matrix import get_bitwise_xor_distance_matrix
from gentrain.graph import build_mst, export_graph_gexf, get_outbreak_community_labels, build_graph
import networkx as nx
from sklearn.metrics.cluster import adjusted_rand_score
from collections import Counter
import numpy as np
import os
from gentrain.evaluation import get_candidate_evaluation_and_export_mst
from gentrain.distance_matrix import get_bitwise_xor_distance_matrix, get_kendall_tau_correlation, get_infection_recall, get_signed_rmse, get_signed_infection_rmse
import os
import shutil
import faiss

In [4]:
aggregate = "nrw_2022"
size = 10000

In [5]:
graph_path = f"graphs/{aggregate}/{size}"
if os.path.exists(f"graphs/{aggregate}/{size}") and os.path.isdir(f"graphs/{aggregate}/{size}"):
    shutil.rmtree(f"graphs/{aggregate}/{size}")
if not os.path.exists(f"graphs/{aggregate}"):
    os.mkdir(f"graphs/{aggregate}")
os.mkdir(f"graphs/{aggregate}/{size}")

In [6]:
sequences_df = pd.read_csv(f"../00_data_understanding_and_preparation/aggregates/{aggregate}/{size}/sequences_and_metadata.csv", delimiter=";").set_index("igs_id").sort_index()
sequences_count = len(sequences_df)
calculations_count = (sequences_count*(sequences_count-1))/2

In [7]:
distance_matrix_df = pd.read_csv(f"../01_algorithm_optimization/distance_matrices/{aggregate}/{size}/distance_matrix.csv", delimiter=";",
                                     index_col="Unnamed: 0").sort_index()
distance_matrix_df = distance_matrix_df[~distance_matrix_df.index.duplicated(keep='first')]
distance_matrix_df = distance_matrix_df.loc[sequences_df.index, sequences_df.index]
distance_matrix = distance_matrix_df.to_numpy()

In [8]:
gentrain_graph = build_graph(distance_matrix)
gentrain_mst = build_mst(gentrain_graph)
gentrain_community_labels = get_outbreak_community_labels(gentrain_mst)
datetime_sampling_dates = pd.to_datetime(sequences_df["date_of_sampling"])
numeric_dates = (datetime_sampling_dates - datetime_sampling_dates.min()).dt.days
export_graph_gexf(gentrain_mst, gentrain_community_labels, sequences_df, f"{graph_path}/gentrain")

mst generation time: 2501.11s


In [9]:
encodings = get_mutation_sensitive_encodings(sequences_df, exclude_indels=False, use_frequency_filtering=True, filter_N=True)
encodings_length = len(encodings[0])

execution time: 114.26s


In [10]:
encodings_length

6328

In [11]:
accurate_evaluation = {}
for calculation_rate in [0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]:
    print(calculation_rate)
    accurate_candidates, runtime = bitwise_xor_candidates(encodings, int(calculation_rate * calculations_count), "depth")
    accurate_evaluation[calculation_rate] = get_candidate_evaluation_and_export_mst(f"accurate_candidates_{calculation_rate}", accurate_candidates, graph_path, distance_matrix, gentrain_community_labels, sequences_df, runtime)

0.001
execution time xor distance calculation: 274.87s
execution time depth search: 225.53s
execution time 49995: 507.87s
mst generation time: 7.26s
0.01
execution time xor distance calculation: 274.71s
execution time depth search: 275.13s
execution time 499950: 559.07s
mst generation time: 32.86s
0.02
execution time xor distance calculation: 290.02s
execution time depth search: 278.47s
execution time 999900: 577.41s
mst generation time: 51.61s
0.03
execution time xor distance calculation: 299.91s
execution time depth search: 278.43s
execution time 1499850: 587.27s
mst generation time: 72.63s
0.04
execution time xor distance calculation: 278.81s
execution time depth search: 281.28s
execution time 1999800: 568.95s
mst generation time: 93.34s
0.05
execution time xor distance calculation: 2242.91s
execution time depth search: 2934.5s
execution time 2499750: 5186.39s
mst generation time: 1951.03s
0.06
execution time xor distance calculation: 493.77s
execution time depth search: 1186.86s
ex

In [12]:
approximate_evaluation = {}
for hash_length in [2000, 2500, 3000, 3500]:
    approximate_evaluation[hash_length] = {}
    for iterations in [1, 2, 4, 8, 16, 32, 64]:
        print(hash_length, iterations)
        approximate_lsh_candidates = and_or_lsh(encodings, hash_length, iterations)
        approximate_evaluation[hash_length][iterations] = get_candidate_evaluation_and_export_mst(f"approximate_candates_{hash_length}_{iterations}", approximate_lsh_candidates, graph_path, distance_matrix, gentrain_community_labels, sequences_df, 0)

2000 1
xor lsh execution time: 1.35
mst generation time: 0.74s
2000 2
xor lsh execution time: 2.93
mst generation time: 77.81s
2000 4
xor lsh execution time: 14.0
mst generation time: 82.33s
2000 8
xor lsh execution time: 11.04
mst generation time: 102.76s
2000 16
xor lsh execution time: 22.59
mst generation time: 155.11s
2000 32
xor lsh execution time: 53.42
mst generation time: 196.1s
2000 64
xor lsh execution time: 88.43
mst generation time: 237.19s
2000 128
xor lsh execution time: 177.48
mst generation time: 210.25s
2000 256
xor lsh execution time: 461.4
mst generation time: 388.45s
3000 1
xor lsh execution time: 1.93
mst generation time: 17.97s
3000 2
xor lsh execution time: 4.01
mst generation time: 27.52s
3000 4
xor lsh execution time: 16.75
mst generation time: 27.28s
3000 8
xor lsh execution time: 24.98
mst generation time: 46.55s
3000 16
xor lsh execution time: 50.08
mst generation time: 67.53s
3000 32
xor lsh execution time: 91.01
mst generation time: 79.55s
3000 64
xor lsh 

KeyboardInterrupt: 

In [13]:
approximate_evaluation

{2000: {1: {'computation_rate': np.float64(0.008665286528652866),
   'infection_detection_rate': np.float64(0.37361586178908923),
   'runtime': 0,
   'mean_edge_weight': np.float64(2.837505206164141),
   'max_edge_weight': np.float64(32.0),
   'subgraph_count': 5198,
   'adjusted_rand_index': 0.07369369907975616},
  2: {'computation_rate': np.float64(0.028596739673967395),
   'infection_detection_rate': np.float64(0.6331091657487391),
   'runtime': 0,
   'mean_edge_weight': np.float64(3.122219020172888),
   'max_edge_weight': np.float64(30.0),
   'subgraph_count': 3060,
   'adjusted_rand_index': 0.15781845617362872},
  4: {'computation_rate': np.float64(0.042232843284328435),
   'infection_detection_rate': np.float64(0.7484491854600267),
   'runtime': 0,
   'mean_edge_weight': np.float64(3.176170912078836),
   'max_edge_weight': np.float64(32.0),
   'subgraph_count': 1481,
   'adjusted_rand_index': 0.23528547363555594},
  8: {'computation_rate': np.float64(0.054993459345934595),
   'in

In [None]:
hnsw_evaluation = {}
for calculation_rate in [0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07,  0.08, 0.09, 0.1, 0.15, 0.2]:
    hnsw_candidates, runtime = get_hnsw_candidates(encodings, int(calculation_rate * calculations_count))
    hnsw_evaluation[calculation_rate] = get_candidate_evaluation_and_export_mst(f"hnsw_candidates_{calculation_rate}", hnsw_candidates, graph_path, distance_matrix, gentrain_community_labels, sequences_df, runtime)

execution time 49995: 23.09s
mst generation time: 0.07s
execution time 499950: 22.78s


## Evaluation

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
import pandas as pd
fig = go.Figure()

computation_rates = {}
ari = {}

for index, hash_length in enumerate([2000, 3000, 4000, 5000]):
    computation_rates[hash_length] = []
    ari[hash_length] = []
    for iterations in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
        if approximate_evaluation[hash_length][iterations]["computation_rate"] <= 0.1:
            computation_rates[hash_length].append(approximate_evaluation[hash_length][iterations]["computation_rate"])
            ari[hash_length].append(approximate_evaluation[hash_length][iterations]["adjusted_rand_index"])
            fig.add_trace(go.Scatter(
                x=[approximate_evaluation[hash_length][iterations]["computation_rate"]],
                y=[approximate_evaluation[hash_length][iterations]["adjusted_rand_index"]],
                mode='markers',
                showlegend=False,
                marker=dict(color=pc.qualitative.Plotly[index], size=8),
            ))

for index, hash_length in enumerate([2000, 3000, 4000, 5000]):
    fig.add_trace(go.Scatter(
        x=computation_rates[hash_length],
        y=ari[hash_length],
        mode='lines',
        name=hash_length,
        line=dict(color=pc.qualitative.Plotly[index], width=4, dash="solid"),
    ))

fig.update_layout(
    width=1000,
    height=800,
    xaxis=dict(
        title=dict(
            text='Calculation rate',
            standoff=40
        ),
        tickangle=0
    ),
    yaxis=dict(
        title=dict(
            text='Community ARI',
            standoff=40
        ),
        tickangle=0
    ),
    legend_title="Hash length",
    template='presentation',
    font=dict(size=30),
    margin=dict(l=120,r=0,t=0,b=100),
      legend=dict(
        itemwidth=80,
        x=0.7,
        y=0.1,
        xanchor='left',
        orientation='v',
        borderwidth=0
    ),
)
fig.show()
fig.write_image("figures/approximate_cs_hash_lengths_ari_nrw_2022.svg", width=1000, height=800)

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
import pandas as pd
fig = go.Figure()

computation_rates = {}
infection_detection_rates = {}

for index, hash_length in enumerate([2000, 3000, 4000, 5000]):
    computation_rates[hash_length] = []
    infection_detection_rates[hash_length] = []
    for iterations in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
        if approximate_evaluation[hash_length][iterations]["computation_rate"] < 0.1:
            computation_rates[hash_length].append(approximate_evaluation[hash_length][iterations]["computation_rate"])
            infection_detection_rates[hash_length].append(approximate_evaluation[hash_length][iterations]["infection_detection_rate"])
            fig.add_trace(go.Scatter(
                x=[approximate_evaluation[hash_length][iterations]["computation_rate"]],
                y=[approximate_evaluation[hash_length][iterations]["infection_detection_rate"]],
                mode='markers',
                showlegend=False,
                marker=dict(color=pc.qualitative.Plotly[index], size=8),
            ))

for index, hash_length in enumerate([2000, 3000, 4000, 5000]):
    fig.add_trace(go.Scatter(
        x=computation_rates[hash_length],
        y=infection_detection_rates[hash_length],
        mode='lines',
        name=hash_length,
        line=dict(color=pc.qualitative.Plotly[index], width=4, dash="solid"),
    ))



fig.update_layout(
    width=1000,
    height=800,
    xaxis=dict(
        title=dict(
            text='Calculation rate',
            standoff=40
        ),
        tickangle=0
    ),
    yaxis=dict(
        title=dict(
            text='Infection recall',
            standoff=40
        ),
        tickangle=0
    ),
    legend_title="Hash length",
    template='presentation',
    font=dict(size=30),
    margin=dict(l=120,r=0,t=0,b=100),
    legend=dict(
        itemwidth=80,
        x=0.7,
        y=0.1,
        xanchor='left',
        orientation='v',
        borderwidth=0
    ),
)
fig.show()
fig.write_image("figures/approximate_cs_hash_lengths_infection_recall_nrw_2022.svg", width=1000, height=800)

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
import pandas as pd
fig = go.Figure()

computation_rates = []
infection_detection_rates = []
for iterations in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
    if approximate_evaluation[3000][iterations]["computation_rate"] < 0.1:
        computation_rates.append(approximate_evaluation[3000][iterations]["computation_rate"])
        infection_detection_rates.append(approximate_evaluation[3000][iterations]["infection_detection_rate"])
        fig.add_trace(go.Scatter(
            x=[approximate_evaluation[3000][iterations]["computation_rate"]],
            y=[approximate_evaluation[3000][iterations]["infection_detection_rate"]],
            mode='markers',
            showlegend=False,
            marker=dict(color=pc.qualitative.Plotly[2], size=8),
        ))

fig.add_trace(go.Scatter(
    x=computation_rates,
    y=infection_detection_rates,
    mode='lines',
    name="AND OR LSH",
    line=dict(color=pc.qualitative.Plotly[2], width=4, dash="solid"),
))

computation_rates_hnsw = []
ari_hnsw = []
for computation_rate in [0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.15, 0.2]:
    computation_rates_hnsw.append(hnsw_evaluation[computation_rate]["computation_rate"])
    ari_hnsw.append(hnsw_evaluation[computation_rate]["infection_detection_rate"])
    fig.add_trace(go.Scatter(
        x=[hnsw_evaluation[computation_rate]["computation_rate"]],
        y=[hnsw_evaluation[computation_rate]["infection_detection_rate"]],
        mode='markers',
        showlegend=False,
        marker=dict(color=pc.qualitative.Plotly[6], size=8),
    ))
    
fig.add_trace(go.Scatter(
    x=computation_rates_hnsw,
    y=ari_hnsw,
    mode='lines',
    name="HNSW",
    line=dict(color=pc.qualitative.Plotly[6], width=4, dash="solid"),
))

ari_accurate = []
for computation_rate in [0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09]:
    ari_accurate.append(accurate_evaluation[computation_rate]["infection_detection_rate"])
    fig.add_trace(go.Scatter(
        x=[computation_rate],
        y=[accurate_evaluation[computation_rate]["infection_detection_rate"]],
        mode='markers',
        showlegend=False,
        marker=dict(color="grey", size=8),
    ))

fig.add_trace(go.Scatter(
    x=[0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09],
    y=ari_accurate,
    mode='lines',
    name="Accurate CS",
    line=dict(color="grey", width=4, dash="dash"),
))


fig.update_layout(
    width=1000,
    height=800,
    xaxis=dict(
        title=dict(
            text='Calculation rate',
            standoff=40
        ),
        tickangle=0
    ),
    yaxis=dict(
        title=dict(
            text='Infection recall',
            standoff=40
        ),
        tickangle=0
    ),
    legend_title="",
    template='presentation',
    font=dict(size=30),
    margin=dict(l=120,r=0,t=0,b=100),
    legend=dict(
        itemwidth=80,
        x=0.6,
        y=0.1,
        xanchor='left',
        orientation='v',
        borderwidth=0
    ),
)
fig.show()
fig.write_image("figures/approximate_cs_comparison_infection_recall_nrw_2022.svg", width=1000, height=800)

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
import pandas as pd
fig = go.Figure()

computation_rates = []
infection_detection_rates = []
for iterations in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
    if approximate_evaluation[3000][iterations]["computation_rate"] < 0.1:
        computation_rates.append(approximate_evaluation[3000][iterations]["computation_rate"])
        infection_detection_rates.append(approximate_evaluation[3000][iterations]["adjusted_rand_index"])
        fig.add_trace(go.Scatter(
            x=[approximate_evaluation[3000][iterations]["computation_rate"]],
            y=[approximate_evaluation[3000][iterations]["adjusted_rand_index"]],
            mode='markers',
            showlegend=False,
            marker=dict(color=pc.qualitative.Plotly[2], size=8),
        ))

fig.add_trace(go.Scatter(
    x=computation_rates,
    y=infection_detection_rates,
    mode='lines',
    name="AND OR LSH",
    line=dict(color=pc.qualitative.Plotly[2], width=4, dash="solid"),
))

computation_rates_hnsw = []
ari_hnsw = []
for computation_rate in [0.001, 0.01, 0.02, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.15, 0.2]:
    computation_rates_hnsw.append(hnsw_evaluation[computation_rate]["computation_rate"])
    ari_hnsw.append(hnsw_evaluation[computation_rate]["adjusted_rand_index"])
    fig.add_trace(go.Scatter(
        x=[hnsw_evaluation[computation_rate]["computation_rate"]],
        y=[hnsw_evaluation[computation_rate]["adjusted_rand_index"]],
        mode='markers',
        showlegend=False,
        marker=dict(color=pc.qualitative.Plotly[6], size=8),
    ))
    
fig.add_trace(go.Scatter(
    x=computation_rates_hnsw,
    y=ari_hnsw,
    mode='lines',
    name="HNSW",
    line=dict(color=pc.qualitative.Plotly[6], width=4, dash="solid"),
))

ari_accurate = []
for computation_rate in [0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09]:
    ari_accurate.append(accurate_evaluation[computation_rate]["adjusted_rand_index"])
    fig.add_trace(go.Scatter(
        x=[computation_rate],
        y=[accurate_evaluation[computation_rate]["adjusted_rand_index"]],
        mode='markers',
        showlegend=False,
        marker=dict(color="grey", size=8),
    ))

fig.add_trace(go.Scatter(
    x=[0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09],
    y=ari_accurate,
    mode='lines',
    name="Accurate CS",
    line=dict(color="grey", width=4, dash="dash"),
))


fig.update_layout(
    width=1000,
    height=800,
    xaxis=dict(
        title=dict(
            text='Calculation rate',
            standoff=40
        ),
        tickangle=0
    ),
    yaxis=dict(
        title=dict(
            text='Community ARI',
            standoff=40
        ),
        tickangle=0
    ),
    legend_title="",
    template='presentation',
    font=dict(size=30),
    margin=dict(l=120,r=0,t=0,b=100),
    legend=dict(
        itemwidth=80,
        x=0.6,
        y=0.1,
        xanchor='left',
        orientation='v',
        borderwidth=0
    ),
)
fig.show()
fig.write_image("figures/approximate_cs_comparison_ari_nrw_2022.svg", width=1000, height=800)