# Fitting mutation scales

## Approach

Find the maximum fungicide and host mutation scales that would still allow us to fit the data if the initial distribution were a delta function (i.e. as narrow as possible). This corresponds to all of the breakdown being caused by mutation from a single initial strain.

We find this mutation scale using the first and last years of the data, rather than the detailed trajectory over the initial years. The exact shape of the early decline depends on the shape of the initial distribution, which we don't think is actually a delta function.


**CHOICES**:
- Gaussian or exponential kernel (*Gaussian seems best*)
- mutation proportion
- how poor a fit is acceptable?

Then fix maximum mutation scale.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
sns.set_theme(style="whitegrid")

import plotly.graph_objects as go

import optuna
from optuna.visualization import (
    plot_optimization_history,
    plot_contour,
)
from optuna.samplers import TPESampler

from polymodel.fitting import FungMaxMutationObjective, score_for_this_df
from polymodel.config import Config
from polymodel.consts import MUTATION_PROP

from plots.fns import standard_layout

In [None]:
# Silence optuna's per-trial INFO messages. Use the named level: the old magic
# number 0 is logging.NOTSET, which sets no threshold of its own and defers to
# the root logger's level, so it only hid trial logs when the root logger
# happened to be at WARNING or above.
optuna.logging.set_verbosity(optuna.logging.WARNING)

# Fungicide

<!-- OLD - NOT TRUE ANY MORE - NB we restrict our attention to values with `mean>5` and `scale>0.1` so that we don't just fit the mean value. Something to do with the dispersal kernel meaning that this doesn't give a good fit without doing this. -->

## Find optimal value

In [None]:
# Placeholder value for config attributes that this mutation-only fit does not
# use (per the name NOT_USED_NUM).
NOT_USED_NUM = 0.5

fung_fit_config = Config(
    'single',
    n_k=500,
    n_l=10,
    mutation_proportion=MUTATION_PROP,
    mutation_scale_fung=1,
    mutation_scale_host=1,
    verbose=False,
)

# Overwrite the k_* / l_* parameters with the placeholder.
fung_fit_config.k_mu = fung_fit_config.k_b = NOT_USED_NUM
fung_fit_config.l_mu = fung_fit_config.l_b = NOT_USED_NUM

fung_fit_config.print_string_repr()

In [None]:
# Seeded TPE sampler so the sequence of sampled trials is reproducible.
sampler = TPESampler(seed=0)

# 'minimize' is optuna's default direction, stated explicitly here; presumably
# the objective returns a fit error to be minimised -- TODO confirm.
study = optuna.create_study(direction='minimize', sampler=sampler)

obj_f = FungMaxMutationObjective(fung_fit_config)

In [None]:
%%time

# Run 300 trials of the objective; the cell then displays the best objective
# value found so far, truncated to an int.
study.optimize(obj_f, n_trials=300)
int(study.best_value)

In [None]:
%%time

# Calling optimize again on the same study continues it: a further 300 trials
# are added on top of the previous cell's (600 in total if each cell runs once).
study.optimize(obj_f, n_trials=300)
int(study.best_value)

In [None]:
# Contour plot of the objective over pairs of sampled parameters.
plot_contour(study)

In [None]:
# Objective value of each trial, plus the best value so far.
plot_optimization_history(study)

## Replicate results

In [None]:
# Best parameter set found; later cells read its 'mean' and 'mutation_scale' keys.
study.best_params

In [None]:
# Re-run the model at the best parameters found by the study.
# To inspect a hand-picked point instead, pass e.g.
#   params={'mean': 0.08, 'mutation_scale': 0.02}
yf = FungMaxMutationObjective(fung_fit_config).run_model(params=study.best_params)

yf

In [None]:
# Observed control per year, with years re-based so the first year is 0.
control_data_f = obj_f.df.loc[:, ['year', 'data_control']]
control_data_f = control_data_f.assign(
    year=control_data_f['year'] - control_data_f['year'].min()
)

control_data_f

In [None]:
# Score the best-fit model output against the observed control data.
# control_data_f already carries the 'year' and 'data_control' column names,
# so no renaming is needed. Pass a copy in case score_for_this_df modifies its
# input, since control_data_f is reused for plotting below.
score_for_this_df(control_data_f.copy(), yf)

In [None]:
f, ax = plt.subplots(figsize=(14, 7))

# Observed control against re-based year (first year = 0).
sns.scatterplot(data=control_data_f, x='year', y='data_control', ax=ax)

# Model output: no x is given, so matplotlib plots yf against its index (0, 1, ...).
ax.plot(yf, color='red', lw=4)

## Save mutation scale?

In [None]:
# Set to False to inspect the fitted values without overwriting the saved CSV.
SAVE_MUTATION_SCALE = True

# Best-fit fungicide mutation scale and mean from the study.
fdf = pd.DataFrame({
    'mutation_scale': [study.best_params['mutation_scale']],
    'fung_mean': [study.best_params['mean']],
})

if SAVE_MUTATION_SCALE:
    print('saving')
    # NOTE: to_csv writes the DataFrame index as an unnamed first column
    # (pandas default index=True); readers may want index_col=0.
    fdf.to_csv('../data/03_model_inputs/mutation_scale.csv')

fdf

## Plot

In [None]:
# Hex colour codes from seaborn's 'muted' palette; fung_fig uses the first two.
COLZ = sns.color_palette('muted').as_hex()

In [None]:
def fung_fig(df_in, y_in):
    """Build a plotly figure comparing observed fungicide control with model output.

    Parameters
    ----------
    df_in : DataFrame
        Must have ``year`` and ``data_control`` columns; drawn as markers.
    y_in : array-like
        Model control values, drawn as a line against the consecutive years
        ``df_in.year.min()`` .. ``df_in.year.max()`` inclusive (presumably
        one value per year -- lengths should match; confirm with caller).

    Returns
    -------
    plotly.graph_objects.Figure
        Figure with a fixed 0-100 y-axis labelled 'Control (%)'.
    """
    data_colour, model_colour = COLZ[0], COLZ[1]

    years = df_in.year
    model_years = np.arange(years.min(), years.max() + 1)

    data_trace = go.Scatter(
        x=years,
        y=df_in.data_control,
        mode='markers',
        name='Data (fungicide)',
        marker=dict(color=data_colour),
    )
    model_trace = go.Scatter(
        x=model_years,
        y=y_in,
        mode='lines',
        name='Model (mutation only)',
        line=dict(color=model_colour),
    )

    fig = go.Figure(
        data=[data_trace, model_trace],
        layout=standard_layout(True, height=400),
    )
    fig.update_layout(legend=dict(x=0.05, y=0.1))
    fig.update_xaxes(title='Year')
    fig.update_yaxes(title='Control (%)', range=[0, 100])
    return fig

In [None]:
# Observed data keeping calendar years (unlike control_data_f, which re-bases
# them to start at 0); fung_fig aligns the model line to these years.
data_use = obj_f.df[['year', 'data_control']]

In [None]:
# Rebinds `f` (earlier the matplotlib figure) to the plotly figure; the export
# cell below writes this plotly figure.
f = fung_fig(data_use, yf)

f.show()

In [None]:
# Export the plotly figure `f` to PNG. Static export needs a plotly image
# engine (e.g. kaleido) installed -- confirm in the environment.
f.write_image('../figures/paper_figs/app2_fung_mutation.png')