# Frequency correlation plots for simulated populations

Another attempt at calculating clade frequencies from tip-to-clade mappings without using a full tree.

In [1]:
import altair as alt
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import pearsonr
import seaborn as sns

%matplotlib inline

In [2]:
# Plot styling: seaborn's white theme first, then the custom "huddlej" matplotlib style on top.
sns.set_style(style="white")
plt.style.use(style="huddlej")

In [3]:
# High-resolution figures with light-weight, large text.
mpl.rcParams.update({
    "savefig.dpi": 200,
    "figure.dpi": 200,
    "font.weight": 300,
    "axes.labelweight": 300,
    "font.size": 18,
})

In [4]:
# Show the working directory; relative data paths below (e.g. "../results/...") resolve from here.
!pwd

/Users/jlhudd/projects/nextstrain/flu-forecasting/analyses


## Load data

In [5]:
# Simulated build to analyze. Inputs are read from "<data_root>/<file>.tsv". The root is
# stored without a trailing slash so joined paths do not contain "//".
build_name = "simulated_sample_3"
data_root = "../results/builds/simulated/%s" % build_name

In [6]:
# Per-tip frequencies at each timepoint. Only the columns used below are loaded, and
# "timepoint" is parsed as a datetime so it can be offset later.
tips = pd.read_csv(
    f"{data_root}/tip_attributes_with_weighted_distances.tsv",
    sep="\t",
    usecols=["strain", "timepoint", "frequency"],
    parse_dates=["timepoint"],
)

In [7]:
tips.iloc[:1]

Unnamed: 0,strain,timepoint,frequency
0,sample_3236_14,2016-10-01,1.3e-05


In [8]:
# Tip-to-clade mappings per timepoint (columns: tip, clade_membership, depth, timepoint).
tips_to_clades = pd.read_csv(
    f"{data_root}/tips_to_clades.tsv",
    sep="\t",
    parse_dates=["timepoint"],
)

## Find clades for tips at future timepoint

In [9]:
# Each timepoint is compared against the population 12 months later.
tips = tips.assign(future_timepoint=tips["timepoint"] + pd.DateOffset(months=12))

In [10]:
tips.head(n=5)

Unnamed: 0,strain,timepoint,frequency,future_timepoint
0,sample_3236_14,2016-10-01,1.3e-05,2017-10-01
1,sample_3236_17,2016-10-01,1.3e-05,2017-10-01
2,sample_3236_19,2016-10-01,1.3e-05,2017-10-01
3,sample_3236_21,2016-10-01,1.3e-05,2017-10-01
4,sample_3236_22,2016-10-01,1.3e-05,2017-10-01


In [24]:
# Attach every clade each tip belongs to at its own timepoint (one row per tip/clade pair).
# Rows are ordered so each tip's clades run from closest (depth 0) outward.
tips_with_clades = (
    tips
    .merge(
        tips_to_clades,
        how="inner",
        left_on=["strain", "timepoint"],
        right_on=["tip", "timepoint"],
    )
    .drop(columns=["tip"])
    .sort_values(["timepoint", "strain", "depth"])
)

In [12]:
# Depth 0 marks each tip's closest clade.
closest_clades_to_tips = tips_to_clades.query("depth == 0").copy()

In [13]:
closest_clades_to_tips.head(n=5)

Unnamed: 0,tip,clade_membership,depth,timepoint
0,sample_2152_67,5a98c8d,0,2016-10-01
2,sample_2180_74,0f7e665,0,2016-10-01
5,sample_2224_75,0f7e665,0,2016-10-01
8,sample_2224_30,0f7e665,0,2016-10-01
11,sample_2200_47,0f7e665,0,2016-10-01


In [14]:
np.shape(closest_clades_to_tips)

(317520, 4)

In [15]:
# Look up each current tip's closest clade at its future timepoint. Tips with no mapping at
# the future timepoint are dropped by the inner merge.
future_clades_for_current_tips = (
    tips
    .merge(
        closest_clades_to_tips,
        left_on=["strain", "future_timepoint"],
        right_on=["tip", "timepoint"],
        suffixes=("", "_future"),
    )
    .drop(columns=["tip", "depth", "timepoint_future"])
)

In [16]:
future_clades_for_current_tips.head(n=5)

Unnamed: 0,strain,timepoint,frequency,future_timepoint,clade_membership
0,sample_3236_14,2016-10-01,1.3e-05,2017-10-01,236287d
1,sample_3236_17,2016-10-01,1.3e-05,2017-10-01,607c9ba
2,sample_3236_19,2016-10-01,1.3e-05,2017-10-01,025b499
3,sample_3236_21,2016-10-01,1.3e-05,2017-10-01,236287d
4,sample_3236_22,2016-10-01,1.3e-05,2017-10-01,236287d


In [17]:
np.shape(future_clades_for_current_tips)

(29610, 5)

In [18]:
# Total frequency of the retained current tips per timepoint (close to 1 if few tips were dropped).
future_clades_for_current_tips["frequency"].groupby(future_clades_for_current_tips["timepoint"]).sum()

timepoint
2016-10-01    0.999992
2017-04-01    0.999948
2017-10-01    0.999954
2018-04-01    0.999964
2018-10-01    0.999980
2019-04-01    0.999958
2019-10-01    0.999995
2020-04-01    0.999993
2020-10-01    1.000014
2021-04-01    0.999946
2021-10-01    0.999965
2022-04-01    0.999971
2022-10-01    0.999965
2023-04-01    0.999934
2023-10-01    0.999937
2024-04-01    0.999936
2024-10-01    0.999948
2025-04-01    0.999969
2025-10-01    0.999941
2026-04-01    0.999946
2026-10-01    0.999977
2027-04-01    0.999968
2027-10-01    0.999980
2028-04-01    0.999943
2028-10-01    0.999949
2029-04-01    0.999987
2029-10-01    0.999965
2030-04-01    0.999966
2030-10-01    0.999952
2031-04-01    0.999962
2031-10-01    0.999954
2032-04-01    0.999936
2032-10-01    0.999960
2033-04-01    0.999971
2033-10-01    0.999982
2034-04-01    0.999963
2034-10-01    0.999966
2035-04-01    0.999955
2035-10-01    0.999973
2036-04-01    0.999985
2036-10-01    0.999994
2037-04-01    0.999973
2037-10-01    0.999955
2

Find tips at the future timepoint that belong to the same clades as the current tips. To do this, first find the distinct future clades associated with the current tips. Then group by timepoint and clade to get each clade's total current frequency, its "initial frequency".

In [19]:
# Total current-tip frequency assigned to each future clade (its "initial frequency").
future_clades_for_current_timepoints = future_clades_for_current_tips.groupby(
    ["timepoint", "future_timepoint", "clade_membership"],
    as_index=False,
)["frequency"].sum()

In [20]:
future_clades_for_current_timepoints.head(n=5)

Unnamed: 0,timepoint,future_timepoint,clade_membership,frequency
0,2016-10-01,2017-10-01,00fecc0,0.000819
1,2016-10-01,2017-10-01,0207ea8,0.003637
2,2016-10-01,2017-10-01,025b499,2.6e-05
3,2016-10-01,2017-10-01,02f3aa6,0.002266
4,2016-10-01,2017-10-01,0442976,0.002118


In [21]:
# Sanity check: initial clade frequencies should sum to 1 (rtol=1e-4) at every timepoint.
initial_frequency_totals = future_clades_for_current_timepoints.groupby("timepoint")["frequency"].sum().values
np.allclose(np.ones_like(initial_frequency_totals), initial_frequency_totals, rtol=1e-4)

True

In [22]:
future_clades_for_current_timepoints["frequency"].groupby(future_clades_for_current_timepoints["timepoint"]).sum().to_numpy()

array([0.999992, 0.999948, 0.999954, 0.999964, 0.99998 , 0.999958,
       0.999995, 0.999993, 1.000014, 0.999946, 0.999965, 0.999971,
       0.999965, 0.999934, 0.999937, 0.999936, 0.999948, 0.999969,
       0.999941, 0.999946, 0.999977, 0.999968, 0.99998 , 0.999943,
       0.999949, 0.999987, 0.999965, 0.999966, 0.999952, 0.999962,
       0.999954, 0.999936, 0.99996 , 0.999971, 0.999982, 0.999963,
       0.999966, 0.999955, 0.999973, 0.999985, 0.999994, 0.999973,
       0.999955, 0.999951, 0.999987, 0.999973, 0.999977])

In [25]:
tips_with_clades.head(n=5)

Unnamed: 0,strain,timepoint,frequency,future_timepoint,clade_membership,depth
0,sample_3236_14,2016-10-01,1.3e-05,2017-10-01,236287d,0
1,sample_3236_14,2016-10-01,1.3e-05,2017-10-01,b03374d,6
2,sample_3236_14,2016-10-01,1.3e-05,2017-10-01,d98135f,15
3,sample_3236_14,2016-10-01,1.3e-05,2017-10-01,01a86f0,20
4,sample_3236_14,2016-10-01,1.3e-05,2017-10-01,81ba982,23


In [26]:
np.shape(tips)

(30870, 4)

In [27]:
np.shape(tips_to_clades)

(4546068, 4)

In [28]:
np.shape(tips_with_clades)

(723318, 6)

Next, find future tips that belong to the same clades as the current tips or that descended from those clades. A future tip can belong to several of these clades at different depths. Instead of counting it in every one, assign it to its closest clade, the one with the smallest depth.

In [29]:
# Exploration: join each (timepoint, clade) pair to the future tips that carry that clade.
# A single future tip can match several of these clades at different depths.
(
    future_clades_for_current_timepoints
    .merge(
        tips_with_clades,
        how="left",
        left_on=["future_timepoint", "clade_membership"],
        right_on=["timepoint", "clade_membership"],
        suffixes=("", "_future"),
    )
    .drop(columns=["timepoint_future", "future_timepoint_future"])
    .sort_values(["timepoint", "strain", "depth"])
    .head(10)
)

Unnamed: 0,timepoint,future_timepoint,clade_membership,frequency,strain,frequency_future,depth
163,2016-10-01,2017-10-01,539d848,0.107898,sample_3436_1,1.3e-05,8.0
777,2016-10-01,2017-10-01,76839f5,0.046673,sample_3436_1,1.3e-05,11.0
1572,2016-10-01,2017-10-01,f40963d,0.002408,sample_3436_1,1.3e-05,14.0
1502,2016-10-01,2017-10-01,de9425b,0.027904,sample_3436_11,1.3e-05,0.0
1366,2016-10-01,2017-10-01,a54b235,0.085959,sample_3436_11,1.3e-05,9.0
1522,2016-10-01,2017-10-01,ec2a1ce,0.000416,sample_3436_11,1.3e-05,10.0
110,2016-10-01,2017-10-01,4cd3f27,0.000113,sample_3436_11,1.3e-05,17.0
164,2016-10-01,2017-10-01,539d848,0.107898,sample_3436_12,1.3e-05,6.0
778,2016-10-01,2017-10-01,76839f5,0.046673,sample_3436_12,1.3e-05,13.0
1573,2016-10-01,2017-10-01,f40963d,0.002408,sample_3436_12,1.3e-05,16.0


In [30]:
# For each future tip, keep only the closest (smallest-depth) matching clade so that every
# future tip's frequency is counted in exactly one clade. Rows are sorted by depth before
# `first()`. `first()` takes the first non-null value per column, which assumes no
# missing values in the merged rows.
future_tips_with_future_clades = (
    future_clades_for_current_timepoints
    .merge(
        tips_with_clades,
        left_on=["future_timepoint", "clade_membership"],
        right_on=["timepoint", "clade_membership"],
        suffixes=("", "_future"),
    )
    .drop(columns=["timepoint_future", "future_timepoint_future"])
    .sort_values(["timepoint", "strain", "depth"])
    .groupby(["timepoint", "future_timepoint", "strain", "frequency_future"])
    .first()
    .reset_index()
)

In [32]:
future_tips_with_future_clades.head(n=5)

Unnamed: 0,timepoint,future_timepoint,strain,frequency_future,clade_membership,frequency,depth
0,2016-10-01,2017-10-01,sample_3436_1,1.3e-05,539d848,0.107898,8
1,2016-10-01,2017-10-01,sample_3436_11,1.3e-05,de9425b,0.027904,0
2,2016-10-01,2017-10-01,sample_3436_12,1.3e-05,539d848,0.107898,6
3,2016-10-01,2017-10-01,sample_3436_13,1.3e-05,5923b41,0.025487,5
4,2016-10-01,2017-10-01,sample_3436_17,1.3e-05,539d848,0.107898,8


In [33]:
future_tips_with_future_clades["frequency_future"].groupby(future_tips_with_future_clades["timepoint"]).sum()

timepoint
2016-10-01    0.999954
2017-04-01    0.999964
2017-10-01    0.999980
2018-04-01    0.999958
2018-10-01    0.999995
2019-04-01    0.999993
2019-10-01    1.000014
2020-04-01    0.999946
2020-10-01    0.999965
2021-04-01    0.999971
2021-10-01    0.999965
2022-04-01    0.999934
2022-10-01    0.999937
2023-04-01    0.999936
2023-10-01    0.999948
2024-04-01    0.999969
2024-10-01    0.999941
2025-04-01    0.999946
2025-10-01    0.999977
2026-04-01    0.999968
2026-10-01    0.999980
2027-04-01    0.999943
2027-10-01    0.999949
2028-04-01    0.999987
2028-10-01    0.999965
2029-04-01    0.999966
2029-10-01    0.999952
2030-04-01    0.999962
2030-10-01    0.999954
2031-04-01    0.999936
2031-10-01    0.999960
2032-04-01    0.999971
2032-10-01    0.999982
2033-04-01    0.999963
2033-10-01    0.999966
2034-04-01    0.999955
2034-10-01    0.999973
2035-04-01    0.999985
2035-10-01    0.999994
2036-04-01    0.999973
2036-10-01    0.999955
2037-04-01    0.999951
2037-10-01    0.999987
2

In [34]:
future_clades_for_current_timepoints.head(n=5)

Unnamed: 0,timepoint,future_timepoint,clade_membership,frequency
0,2016-10-01,2017-10-01,00fecc0,0.000819
1,2016-10-01,2017-10-01,0207ea8,0.003637
2,2016-10-01,2017-10-01,025b499,2.6e-05
3,2016-10-01,2017-10-01,02f3aa6,0.002266
4,2016-10-01,2017-10-01,0442976,0.002118


In [38]:
# Total future-tip frequency assigned to each clade 12 months later.
future_clades_for_future_timepoints = future_tips_with_future_clades.groupby(
    ["timepoint", "future_timepoint", "clade_membership"],
    as_index=False,
)["frequency_future"].sum()

In [39]:
future_clades_for_future_timepoints.head(n=5)

Unnamed: 0,timepoint,future_timepoint,clade_membership,frequency_future
0,2016-10-01,2017-10-01,3e385a7,0.008505
1,2016-10-01,2017-10-01,48af62e,0.000872
2,2016-10-01,2017-10-01,522b264,0.000278
3,2016-10-01,2017-10-01,539d848,0.931737
4,2016-10-01,2017-10-01,543f4ef,1.9e-05


In [47]:
# Initial and future frequency side by side for each clade. A clade missing on either side
# of the outer merge gets frequency 0.
merged_clades = (
    future_clades_for_current_timepoints
    .merge(
        future_clades_for_future_timepoints,
        how="outer",
        on=["timepoint", "future_timepoint", "clade_membership"],
    )
    .sort_values(["timepoint", "future_timepoint", "clade_membership"])
    .fillna(0.0)
)

In [48]:
merged_clades.head(n=5)

Unnamed: 0,timepoint,future_timepoint,clade_membership,frequency,frequency_future
0,2016-10-01,2017-10-01,00fecc0,0.000819,0.0
1,2016-10-01,2017-10-01,0207ea8,0.003637,0.0
2,2016-10-01,2017-10-01,025b499,2.6e-05,0.0
3,2016-10-01,2017-10-01,02f3aa6,0.002266,0.0
4,2016-10-01,2017-10-01,0442976,0.002118,0.0


In [52]:
merged_clades["frequency"].groupby(merged_clades["timepoint"]).sum().to_numpy()

array([0.999992, 0.999948, 0.999954, 0.999964, 0.99998 , 0.999958,
       0.999995, 0.999993, 1.000014, 0.999946, 0.999965, 0.999971,
       0.999965, 0.999934, 0.999937, 0.999936, 0.999948, 0.999969,
       0.999941, 0.999946, 0.999977, 0.999968, 0.99998 , 0.999943,
       0.999949, 0.999987, 0.999965, 0.999966, 0.999952, 0.999962,
       0.999954, 0.999936, 0.99996 , 0.999971, 0.999982, 0.999963,
       0.999966, 0.999955, 0.999973, 0.999985, 0.999994, 0.999973,
       0.999955, 0.999951, 0.999987, 0.999973, 0.999977])

In [51]:
merged_clades["frequency_future"].groupby(merged_clades["timepoint"]).sum().to_numpy()

array([0.999954, 0.999964, 0.99998 , 0.999958, 0.999995, 0.999993,
       1.000014, 0.999946, 0.999965, 0.999971, 0.999965, 0.999934,
       0.999937, 0.999936, 0.999948, 0.999969, 0.999941, 0.999946,
       0.999977, 0.999968, 0.99998 , 0.999943, 0.999949, 0.999987,
       0.999965, 0.999966, 0.999952, 0.999962, 0.999954, 0.999936,
       0.99996 , 0.999971, 0.999982, 0.999963, 0.999966, 0.999955,
       0.999973, 0.999985, 0.999994, 0.999973, 0.999955, 0.999951,
       0.999987, 0.999973, 0.999977, 0.999977, 0.999954])

In [35]:
# Cross-check: assign every future tip to its closest matching clade directly from
# tips_to_clades, look up that tip's frequency in `tips` (as "frequency_tips"), and total the
# frequencies per future timepoint.
(
    future_clades_for_current_timepoints
    .merge(
        tips_to_clades,
        left_on=["future_timepoint", "clade_membership"],
        right_on=["timepoint", "clade_membership"],
        suffixes=("", "_future"),
    )
    .drop(columns=["timepoint_future"])
    .sort_values(["timepoint", "tip", "depth"])
    .groupby(["timepoint", "future_timepoint", "tip"])
    .first()
    .reset_index()
    .merge(
        tips,
        left_on=["future_timepoint", "tip"],
        right_on=["timepoint", "strain"],
        suffixes=("", "_tips"),
    )
    .groupby(["future_timepoint"])["frequency_tips"]
    .sum()
)

future_timepoint
2017-10-01    0.999954
2018-04-01    0.999964
2018-10-01    0.999980
2019-04-01    0.999958
2019-10-01    0.999995
2020-04-01    0.999993
2020-10-01    1.000014
2021-04-01    0.999946
2021-10-01    0.999965
2022-04-01    0.999971
2022-10-01    0.999965
2023-04-01    0.999934
2023-10-01    0.999937
2024-04-01    0.999936
2024-10-01    0.999948
2025-04-01    0.999969
2025-10-01    0.999941
2026-04-01    0.999946
2026-10-01    0.999977
2027-04-01    0.999968
2027-10-01    0.999980
2028-04-01    0.999943
2028-10-01    0.999949
2029-04-01    0.999987
2029-10-01    0.999965
2030-04-01    0.999966
2030-10-01    0.999952
2031-04-01    0.999962
2031-10-01    0.999954
2032-04-01    0.999936
2032-10-01    0.999960
2033-04-01    0.999971
2033-10-01    0.999982
2034-04-01    0.999963
2034-10-01    0.999966
2035-04-01    0.999955
2035-10-01    0.999973
2036-04-01    0.999985
2036-10-01    0.999994
2037-04-01    0.999973
2037-10-01    0.999955
2038-04-01    0.999951
2038-10-01    0.9

## Find large clades

Find all clades whose initial frequency exceeds a minimum value (e.g., >15%).

In [None]:
tips.head(n=5)

In [None]:
# Pair every tip/clade mapping with that tip's frequency at the SAME timepoint. Both tables
# have a "timepoint" column, so the join must include it. Joining on the tip name alone pairs
# each mapping with the tip's frequency at every timepoint, and it splits the column into
# "timepoint_x"/"timepoint_y", which breaks the later groupby on "timepoint".
# Mappings with no matching tip frequency are kept (left join); their frequency is NaN here.
clade_tip_initial_frequencies = tips_to_clades.merge(
    tips,
    how="left",
    left_on=["tip", "timepoint"],
    right_on=["strain", "timepoint"],
).drop(columns=["strain"])

In [None]:
# Mappings whose tip has no frequency at that timepoint count as frequency 0.
clade_tip_initial_frequencies = clade_tip_initial_frequencies.fillna({"frequency": 0.0})

In [None]:
clade_tip_initial_frequencies.head(n=5)

In [None]:
# Clade frequency = sum of its tips' frequencies at each timepoint.
initial_clade_frequencies = clade_tip_initial_frequencies.groupby(
    ["timepoint", "clade_membership"],
    as_index=False,
)["frequency"].sum()

In [None]:
initial_clade_frequencies.head(n=5)

In [None]:
initial_clade_frequencies[initial_clade_frequencies["clade_membership"] == "c139e7c"]

In [None]:
# Clades count as "large" when their initial frequency exceeds this threshold.
min_initial_frequency = 0.15

# Use every timepoint that still has an observed population 12 months later. The previous
# hard-coded window ('2002-10-01' <= timepoint < '2015-04-01') ends before the first
# timepoint of this simulated build (2016-10-01), so it selected no clades at all.
forecast_horizon = pd.DateOffset(months=12)
last_initial_timepoint = initial_clade_frequencies["timepoint"].max() - forecast_horizon

initial_clades = initial_clade_frequencies.query(
    "frequency > @min_initial_frequency & timepoint <= @last_initial_timepoint"
).copy()

In [None]:
initial_clades.head(n=5)

In [None]:
initial_clades.tail(n=5)

In [None]:
np.shape(initial_clades)

In [None]:
initial_clades.query("clade_membership == 'c139e7c'")

## Find future frequencies of large clades

In [None]:
# Observed clade frequencies are measured 12 months after each initial timepoint.
initial_clades = initial_clades.assign(final_timepoint=initial_clades["timepoint"] + pd.DateOffset(months=12))

In [None]:
# Sum the final-timepoint frequencies of all tips in each large clade ("frequency_final").
initial_and_observed_clade_frequencies = (
    initial_clades
    .merge(
        clade_tip_initial_frequencies,
        left_on=["final_timepoint", "clade_membership"],
        right_on=["timepoint", "clade_membership"],
        suffixes=("", "_final"),
    )
    .groupby(["timepoint", "clade_membership", "frequency"], as_index=False)["frequency_final"]
    .sum()
)

In [None]:
np.shape(initial_and_observed_clade_frequencies)

In [None]:
initial_and_observed_clade_frequencies.head(n=5)

In [None]:
initial_and_observed_clade_frequencies[initial_and_observed_clade_frequencies["clade_membership"] == "c139e7c"]

In [None]:
# Fold change in clade frequency between the initial and final timepoints.
initial_and_observed_clade_frequencies["observed_growth_rate"] = initial_and_observed_clade_frequencies["frequency_final"].div(
    initial_and_observed_clade_frequencies["frequency"]
)

In [None]:
initial_and_observed_clade_frequencies.head(n=5)

In [None]:
# Histogram of observed growth rates: bin the x values and draw one bar per bin, with
# height equal to the number of clades in that bin. `mark_bar` draws bars from zero up to the
# count; `mark_rect` is meant for grids where both axes are binned or discrete.
alt.Chart(initial_and_observed_clade_frequencies).mark_bar().encode(
    x=alt.X("observed_growth_rate:Q", bin=True, title="Observed growth rate"),
    y=alt.Y("count()", title="Number of clades"),
)

## Find estimated future frequencies of large clades

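First, use the LBI model as an example.

**TODO:** the cells below use a `forecasts` data frame and a `get_matthews_correlation_coefficient_for_data_frame` function. Neither is defined or imported earlier in this notebook, so load and import them before running this section.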

In [None]:
# Pair every tip/clade mapping with the tip's forecast. The frame must provide "strain" and
# "fitness" (dropped below) plus "frequency" and "projected_frequency" (used in later cells).
# TODO: `forecasts` is not defined in any earlier cell of this notebook; load it before this
# cell. This join uses the tip name only. If `forecasts` also has a "timepoint" column, the
# join should include it too (see the analogous join of `tips` above).
clade_tip_estimated_frequencies = (
    tips_to_clades
    .merge(forecasts, how="left", left_on=["tip"], right_on=["strain"])
    .drop(columns=["strain", "fitness"])
)

In [None]:
clade_tip_estimated_frequencies.head(n=5)

In [None]:
# Tips without a forecast contribute zero current and projected frequency.
clade_tip_estimated_frequencies = clade_tip_estimated_frequencies.fillna(
    {"frequency": 0.0, "projected_frequency": 0.0}
)

In [None]:
clade_tip_estimated_frequencies.head(n=5)

In [None]:
# Clade-level current and projected frequencies: sums over each clade's tips per timepoint.
estimated_clade_frequencies = (
    clade_tip_estimated_frequencies
    .groupby(["timepoint", "clade_membership"])
    .agg(
        projected_frequency=("projected_frequency", "sum"),
        frequency=("frequency", "sum"),
    )
    .reset_index()
)

In [None]:
estimated_clade_frequencies.head(n=5)

In [None]:
# Combine observed and estimated frequencies per large clade. The forecast table's own
# "frequency" column becomes "frequency_other".
complete_clade_frequencies = initial_and_observed_clade_frequencies.merge(
    estimated_clade_frequencies,
    how="inner",
    on=["timepoint", "clade_membership"],
    suffixes=("", "_other"),
)

In [None]:
# Forecast fold change: projected frequency relative to the observed initial frequency.
complete_clade_frequencies["estimated_growth_rate"] = complete_clade_frequencies["projected_frequency"].div(
    complete_clade_frequencies["frequency"]
)

In [None]:
complete_clade_frequencies = complete_clade_frequencies.assign(year=complete_clade_frequencies["timepoint"].dt.year)

In [None]:
complete_clade_frequencies.head(n=5)

In [None]:
np.shape(complete_clade_frequencies)

In [None]:
# Pearson correlation between observed and estimated growth rates on the natural scale.
r, p = pearsonr(
    x=complete_clade_frequencies["observed_growth_rate"],
    y=complete_clade_frequencies["estimated_growth_rate"],
)

In [None]:
# Matthews correlation coefficient and confusion-matrix counts (read below as "tp", "fp",
# "tn", "fn"), presumably for classifying clades as growing vs. declining.
# TODO: get_matthews_correlation_coefficient_for_data_frame is neither defined nor imported in
# this notebook; import it before running this cell. The meaning of the positional `True` is
# not visible here, so confirm it against the function's definition.
mcc, confusion_matrix = get_matthews_correlation_coefficient_for_data_frame(complete_clade_frequencies, True)

In [None]:
# Matthews correlation coefficient from the previous cell.
mcc

In [None]:
# "Growth accuracy" = tp / (tp + fp), i.e. the positive predictive value of growth
# predictions (assumes "tp"/"fp" count predicted-growth clades that did / did not grow).
# Reported as NaN rather than raising when no clade is predicted to grow.
predicted_growth_count = confusion_matrix["tp"] + confusion_matrix["fp"]
growth_accuracy = (
    confusion_matrix["tp"] / float(predicted_growth_count) if predicted_growth_count > 0 else np.nan
)
growth_accuracy

In [None]:
# "Decline accuracy" = tn / (tn + fn), i.e. the negative predictive value of decline
# predictions (assumes "tn"/"fn" count predicted-decline clades that did / did not decline).
# Reported as NaN rather than raising when no clade is predicted to decline.
predicted_decline_count = confusion_matrix["tn"] + confusion_matrix["fn"]
decline_accuracy = (
    confusion_matrix["tn"] / float(predicted_decline_count) if predicted_decline_count > 0 else np.nan
)
decline_accuracy

In [None]:
# Shared axis range for the natural-scale chart: from 0 to 0.2 above the largest observed or
# estimated growth rate.
min_growth_rate = 0
max_growth_rate = complete_clade_frequencies[["observed_growth_rate", "estimated_growth_rate"]].max().max() + 0.2

In [None]:
# Added to numerator and denominator of growth ratios so zero frequencies give finite logs.
pseudofrequency = 1e-3

In [None]:
# log10 growth rates with a pseudofrequency added to both frequencies, so clades that vanish
# (final or projected frequency 0) still get a finite value.
complete_clade_frequencies = complete_clade_frequencies.assign(
    log_observed_growth_rate=np.log10(
        (complete_clade_frequencies["frequency_final"] + pseudofrequency)
        / (complete_clade_frequencies["frequency"] + pseudofrequency)
    ),
    log_estimated_growth_rate=np.log10(
        (complete_clade_frequencies["projected_frequency"] + pseudofrequency)
        / (complete_clade_frequencies["frequency"] + pseudofrequency)
    ),
)

In [None]:
# Quick look at log10 observed vs. estimated growth rates. Axes are labeled so the figure
# stands alone, and the trailing ';' suppresses the Line2D repr.
plt.plot(
    complete_clade_frequencies["log_observed_growth_rate"],
    complete_clade_frequencies["log_estimated_growth_rate"],
    "o",
)
plt.xlabel("Observed $log_{10}$ growth rate")
plt.ylabel("Estimated $log_{10}$ growth rate");

In [None]:
# Unused alternative kept for reference: natural-log growth rates taken directly from the
# unsmoothed ratios (no pseudofrequency). These lines are commented out and have no effect.
#complete_clade_frequencies["log_observed_growth_rate"] = np.log(complete_clade_frequencies["observed_growth_rate"])
#complete_clade_frequencies["log_estimated_growth_rate"] = np.log(complete_clade_frequencies["estimated_growth_rate"])

In [None]:
# Interactive scatter of observed vs. estimated natural-scale clade growth rates, with both
# axes on [min_growth_rate, max_growth_rate]. `r` and `mcc` come from the earlier
# natural-scale pearsonr and MCC cells.
tooltip_attributes = ["observed_growth_rate:Q", "estimated_growth_rate:Q", "timepoint:N", "frequency:Q", "frequency_final:Q",
                      "projected_frequency:Q", "clade_membership:N"]

chart = alt.Chart(complete_clade_frequencies).mark_circle().encode(
    alt.X("observed_growth_rate:Q", scale=alt.Scale(domain=(min_growth_rate, max_growth_rate))),
    alt.Y("estimated_growth_rate:Q", scale=alt.Scale(domain=(min_growth_rate, max_growth_rate))),
    alt.Tooltip(tooltip_attributes)
).properties(
    width=400,
    height=400,
    title="Forecasts by LBI: Pearson's R = %.2f, MCC = %.2f" % (r, mcc)
)

# This notebook analyzes a simulated build, so save under a "simulated" name. The previous
# "natural" name would be confused with, or overwrite, a natural-population chart.
chart.save("forecast_growth_correlation_simulated_lbi.svg")
chart

In [None]:
complete_clade_frequencies.head(n=5)

In [None]:
# Axis limits. Natural scale: largest growth rate rounded up to an integer. Log scale:
# 0.1 below the smallest value, and 0.1 above the largest value rounded up.
upper_limit = np.ceil(complete_clade_frequencies[["observed_growth_rate", "estimated_growth_rate"]].max().max())

log_growth_rates = complete_clade_frequencies[["log_observed_growth_rate", "log_estimated_growth_rate"]]
log_lower_limit = log_growth_rates.min().min() - 0.1
log_upper_limit = np.ceil(log_growth_rates.max().max()) + 0.1

In [None]:
# Interactive scatter of log10 observed vs. estimated growth rates, colored by timepoint.
# The title's R is computed on the plotted log10 values. `r` at this point still holds the
# natural-scale correlation, which does not describe this plot. `mcc` comes from the MCC cell.
tooltip_attributes = ["observed_growth_rate:Q", "estimated_growth_rate:Q", "timepoint:N", "frequency:Q", "frequency_final:Q",
                      "projected_frequency:Q", "clade_membership:N"]

log_r = pearsonr(
    complete_clade_frequencies["log_observed_growth_rate"],
    complete_clade_frequencies["log_estimated_growth_rate"]
)[0]

chart = alt.Chart(complete_clade_frequencies).mark_circle().encode(
    alt.X("log_observed_growth_rate:Q", scale=alt.Scale(domain=(log_lower_limit, log_upper_limit))),
    alt.Y("log_estimated_growth_rate:Q", scale=alt.Scale(domain=(log_lower_limit, log_upper_limit))),
    alt.Color("timepoint:N"),
    alt.Tooltip(tooltip_attributes)
).properties(
    width=400,
    height=400,
    title="Forecasts by LBI: Pearson's R = %.2f, MCC = %.2f" % (log_r, mcc)
)

#chart.save("forecast_log_growth_correlation_simulated_lbi.svg")
chart

In [None]:
# Lower log-scale axis limit (smallest log10 growth rate minus 0.1).
log_lower_limit

In [None]:
complete_clade_frequencies[["log_observed_growth_rate", "log_estimated_growth_rate"]].min().min()

In [None]:
# Pearson correlation on the log10 scale. This overwrites the natural-scale `r`, `p` used above.
r, p = pearsonr(
    x=complete_clade_frequencies["log_observed_growth_rate"],
    y=complete_clade_frequencies["log_estimated_growth_rate"],
)

In [None]:
# Log10-scale Pearson's R from the previous cell.
r

In [None]:
# p-value of the log10-scale Pearson correlation.
p

In [None]:
# Natural-scale correlation shown for comparison (not stored).
pearsonr(
    x=complete_clade_frequencies["observed_growth_rate"],
    y=complete_clade_frequencies["estimated_growth_rate"],
)

In [None]:
# Publication figure: log10 observed vs. estimated clade growth rates, annotated with the
# growth/decline accuracies and the log-scale Pearson correlation (`r`, `p` from the
# log-scale pearsonr cell).
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.plot(
    complete_clade_frequencies["log_observed_growth_rate"],
    complete_clade_frequencies["log_estimated_growth_rate"],
    "o",
    alpha=0.4
)

# Light reference lines at 0 (no change in frequency) on both axes, drawn behind the points.
ax.axhline(color="#cccccc", zorder=-5)
ax.axvline(color="#cccccc", zorder=-5)

if p < 0.001:
    p_value = "$p$ < 0.001"
else:
    p_value = "$p$ = %.3f" % p

ax.text(
    0.02,
    0.9,
    "Growth accuracy = %.2f\nDecline accuracy = %.2f\n$R$ = %.2f\n%s" % (growth_accuracy, decline_accuracy, r, p_value),
    fontsize=12,
    horizontalalignment="left",
    verticalalignment="center",
    transform=ax.transAxes
)

ax.set_xlabel("Observed $log_{10}$ growth rate")
ax.set_ylabel("Estimated $log_{10}$ growth rate")
# NOTE(review): the title names an "LBI + HI tree + non-epitope mutations" model, while this
# section describes LBI forecasts; confirm which model `forecasts` contains.
ax.set_title("Validation of LBI + HI tree + non-epitope mutations model", fontsize=12)

ticks = np.arange(-6, 4, 1)
ax.set_xticks(ticks)
ax.set_yticks(ticks)

ax.set_xlim(log_lower_limit, log_upper_limit)
ax.set_ylim(log_lower_limit, log_upper_limit)
ax.set_aspect("equal")

# This notebook analyzes a simulated build. Save under a "simulated" name so the
# natural-population manuscript figure is not overwritten with simulated results.
plt.savefig("../manuscript/figures/validation-of-best-model-for-simulated-populations.pdf")

In [None]:
np.shape(complete_clade_frequencies)

In [None]:
initial_and_observed_clade_frequencies[initial_and_observed_clade_frequencies["clade_membership"] == "c139e7c"]