# Algorithm Optimization (nrw_2022, 1,250 sequences)
This notebook contains the evaluation of the optimized algorithm concerning the aggregate nrw_2022 with 1,250 sequences. The distance matrices were evaluated in terms of correlation, errors, as well as infection recall and precision.


In [1]:
!pip install ../../gentrain/.

[0mProcessing /Users/benkraling/code/thesis/gentrain
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: gentrain
  Building wheel for gentrain (setup.py) ... [?25ldone
[?25h  Created wheel for gentrain: filename=gentrain-0.1.2-py3-none-any.whl size=27270 sha256=64e08be85d29f1b4554f69c7e7fc487225d9adfc6bc628421829e7557b5dceed
  Stored in directory: /private/var/folders/2h/923cq6912sqb0snfvqqfdnmm0000gn/T/pip-ephem-wheel-cache-swol0eh1/wheels/cf/e4/57/91c03db2e8c043adeefe35dd0969d3049f61ae0218be0acc9f
Successfully built gentrain
[0mInstalling collected packages: gentrain
  Attempting uninstall: gentrain
    Found existing installation: gentrain 0.1.2
    Uninstalling gentrain-0.1.2:
      Successfully uninstalled gentrain-0.1.2
[0mSuccessfully installed gentrain-0.1.2


In [2]:
import pandas as pd
import numpy as np
from scipy.stats import kendalltau, pearsonr, spearmanr
from sklearn.metrics import mean_squared_error
from gentrain.evaluation import get_lineage_purity
from gentrain.distance_matrix import get_infection_recall, get_kendall_tau_correlation, get_signed_infection_rmse, get_signed_rmse, get_infection_recall, get_infection_precision, get_infection_f1, median_distance
from gentrain.graph import build_graph, build_mst, get_outbreak_community_labels, export_graph_gexf, mean_edge_weight
import plotly.express as px
import plotly.colors as pc
from sklearn.metrics import adjusted_rand_score
import os
import shutil

In [3]:
aggregate = "nrw_2022"
size = 1250

In [4]:
graph_path = f"graphs/{aggregate}/{size}"
if os.path.exists(f"graphs/{aggregate}/{size}") and os.path.isdir(f"graphs/{aggregate}/{size}"):
    shutil.rmtree(f"graphs/{aggregate}/{size}")
if not os.path.exists(f"graphs/{aggregate}"):
    os.mkdir(f"graphs/{aggregate}")
os.mkdir(f"graphs/{aggregate}/{size}")

In [5]:
sequences_df = pd.read_csv(f"../00_data_understanding_and_preparation/aggregates/{aggregate}/{size}/sequences_and_metadata.csv", delimiter=";", index_col="Unnamed: 0").sort_index()

In [6]:
distance_matrix_gentrain_df = pd.read_csv(f"../00_data_understanding_and_preparation/aggregates/{aggregate}/{size}/distance_matrix.csv", delimiter=";", index_col="Unnamed: 0").sort_index()
distance_matrix_optimized_df = pd.read_csv(f"./distance_matrices/{aggregate}/{size}/distance_matrix.csv", delimiter=";", index_col="Unnamed: 0").sort_index()
triu_mask = np.triu(np.ones_like(distance_matrix_gentrain_df, dtype=bool), k=1)
distance_matrix_gentrain = distance_matrix_gentrain_df.to_numpy()
distance_matrix_optimized = distance_matrix_optimized_df.to_numpy()

In [7]:
gentrain_graph = build_graph(distance_matrix_gentrain)
gentrain_mst = build_mst(gentrain_graph)
gentrain_community_labels = get_outbreak_community_labels(gentrain_mst)
optimized_graph = build_graph(distance_matrix_optimized)
optimized_mst = build_mst(optimized_graph)
optimized_community_labels = get_outbreak_community_labels(optimized_mst)

mst generation time: 2.23s
mst generation time: 2.42s


In [8]:
median_distance(distance_matrix_gentrain)

np.float64(27.0)

In [9]:
export_graph_gexf(gentrain_mst, gentrain_community_labels, sequences_df, f"{graph_path}/gentrain")
export_graph_gexf(optimized_mst, gentrain_community_labels, sequences_df, f"{graph_path}/optimized")

In [10]:
gentrain_mean_edge_weight = mean_edge_weight(gentrain_mst)
optimized_mean_edge_weight = mean_edge_weight(optimized_mst)
optimized_mean_edge_weight - gentrain_mean_edge_weight

np.float64(-0.550840672538035)

In [11]:
gentrain_mean_edge_weight

np.float64(4.849319455564454)

In [12]:
optimized_purity = get_lineage_purity(list(sequences_df["clade"]), optimized_community_labels)
gentrain_purity = get_lineage_purity(list(sequences_df["clade"]), gentrain_community_labels)
optimized_purity - gentrain_purity

np.float64(0.0007999999999999119)

In [13]:
evaluation = []
corr, _ = kendalltau(distance_matrix_gentrain, distance_matrix_optimized)
scores = {
    "signed_rmse": get_signed_rmse(distance_matrix_gentrain, distance_matrix_optimized),
    "signed_infection_rmse": get_signed_infection_rmse(distance_matrix_gentrain, distance_matrix_optimized),
    "correlation": corr,
    "infection_recall": get_infection_recall(distance_matrix_gentrain, distance_matrix_optimized),
    "infection_precision": get_infection_precision(distance_matrix_gentrain, distance_matrix_optimized),
    "infection_f1": get_infection_f1(distance_matrix_gentrain, distance_matrix_optimized),
    "ari_com": adjusted_rand_score(gentrain_community_labels, optimized_community_labels),
    "lineage_purity":get_lineage_purity(list(sequences_df["clade"]), optimized_community_labels),
    "mean_edge_weight": mean_edge_weight(optimized_mst)
}
evaluation.append(scores)
pd.DataFrame(evaluation)

Unnamed: 0,signed_rmse,signed_infection_rmse,correlation,infection_recall,infection_precision,infection_f1,ari_com,lineage_purity,mean_edge_weight
0,-1.877005,-0.40134,0.962203,0.983221,0.54461,0.700957,0.306935,0.952,4.298479


In [14]:
gentrain_flatten = distance_matrix_gentrain.flatten()
optimized_flatten = distance_matrix_optimized.flatten()
data = pd.DataFrame({
    "gentrain": gentrain_flatten,
    "optimized": optimized_flatten
})


data = data.groupby(["gentrain", "optimized"]).size().reset_index(name="count")

fig = px.scatter(data, x="gentrain", y="optimized", color="count", color_continuous_scale=[[0, "#f1f1f1"], [1, "#000000"]], labels={"count": "Count"})

fig.add_shape(
    type="line",
    x0=min(gentrain_flatten.min(), optimized_flatten.min()),
    y0=min(gentrain_flatten.min(), optimized_flatten.min()),
    x1=max(gentrain_flatten.max(), optimized_flatten.max()),
    y1=max(gentrain_flatten.max(), optimized_flatten.max()),
    line=dict(color="black", width=2, dash="dash"),
    name="y = x",
)
fig.update_layout(width=1000, height=1000,template="presentation",font=dict(size=30),  xaxis=dict(
        title=dict(
            text="GENTRAIN algorithm distance",
            standoff=40
        ),
        tickangle=0
    ),
    margin=dict(l=120,r=0,t=0,b=120),
    yaxis=dict(
        title=dict(
            text="Optimized algorithm distance",
            standoff=40
        ),
        tickangle=0
    ),)
fig.show()