# Fitting fungicide initial distribution

Now using `p_m` from the Alexey/McDonald paper.

- Gamma distribution fits better than beta
- `3/4` is ok but `2/3` is better

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import optuna

from optuna.visualization import (
    plot_optimization_history,
    plot_contour,
)

from optuna.samplers import TPESampler

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

import seaborn as sns

from polymodel.fitting import (
    HostObjective,
    FungicideObjective,
    score_for_this_df,
    fitting_df,
)
    
from polymodel.config import Config

from polymodel.consts import (
    MUTATION_PROP,
    DEFAULT_P,
    HOST_MUTATION_SCALE,
    FUNG_MUTATION_SCALE,
)

# Fungicide

## Fit

In [None]:
optuna.logging.set_verbosity(0)

In [None]:
from polymodel.consts import DEFAULT_P

In [None]:
DEFAULT_P

In [None]:
fung_fit_config = Config(
    'single', 
    n_k=500,
    n_l=10,
    mutation_proportion=MUTATION_PROP,
    mutation_scale_fung=DEFAULT_P * FUNG_MUTATION_SCALE,
    mutation_scale_host=DEFAULT_P * HOST_MUTATION_SCALE,
)

In [None]:
sampler_f = TPESampler(seed=0)
study_f = optuna.create_study(sampler=sampler_f)
obj_f = FungicideObjective(fung_fit_config)

In [None]:
%%time

study_f.optimize(obj_f, n_trials=300)
int(study_f.best_value)

In [None]:
%%time

# run a further 300 trials on the same study; trials accumulate, so best_value can only improve
study_f.optimize(obj_f, n_trials=300)
int(study_f.best_value)

In [None]:
plot_optimization_history(study_f)

In [None]:
plot_contour(study_f)

## Replicate result

In [None]:
study_f.best_params

In [None]:
yf = (
    obj_f
    .run_model(params=study_f.best_params)
    
    # .run_model(params = {
    #     'mu': 10,
    #     'b': 1,
    # })
)

yf

In [None]:
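rng = np.random.default_rng(0)  # seeded so the plotting jitter is reproducible

control_data_f = (
    obj_f.df
    .loc[:, [
        'data_control',
        # 'n_data',
        'year',
    ]]
    .assign(
        # years since the first year in the data
        year=lambda df: df.year - df.year.min(),
        # small horizontal jitter so overlapping points are visible in the scatter plot
        yearnoise=lambda df: df.year + rng.normal(scale=0.05, size=len(df)),
    )
)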

control_data_f.head()

In [None]:
score_for_this_df(control_data_f, yf)

In [None]:
f, ax = plt.subplots(figsize=(14,7))

sns.scatterplot(
    x='yearnoise',
    y='data_control',
    # size='n_data',
    data=control_data_f,
    ax=ax,
    alpha=0.5,
)

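# fitted model curve; x-axis is years since the first year
ax.plot(yf, lw=4, color='red')

ax.set_ylim([0, 100])
ax.set_xlabel('Years since first year')
ax.set_ylabel('Control (%)')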

Results for different values of `p`:
- `p=0.01; score: 12275; 2 vs 3 looks bad`
- `p=0.10; score: ; 2 vs 3 looks bad`
- `p=0.20; score: ; 2 vs 3 looks good/bad?`
- `p=0.50; score: 10841; 2 vs 3 looks good/bad?`

Tried changing the fungicide decay rate from `1.11e-2` to `6.91e-3` (Elderfield/van den Berg/Hobbelen), because the model couldn't achieve high enough control in the early years.
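To see what this change means in practice, here is a minimal sketch assuming the fungicide concentration decays exponentially, C(t) = C0·exp(-r·t), with the rate `r` in whatever time units the model uses (not stated here). The helper names are hypothetical, not part of `polymodel`.

```python
import math

# Decay rates compared in the note above (time units as used in the model).
OLD_DECAY_RATE = 1.11e-2
NEW_DECAY_RATE = 6.91e-3  # Elderfield/van den Berg/Hobbelen value


def half_life(rate):
    """Time for an exponentially decaying concentration to halve: ln(2) / rate."""
    return math.log(2) / rate


def fraction_remaining(rate, t):
    """Fraction of the initial dose left after time t, assuming C(t) = C0 * exp(-rate * t)."""
    return math.exp(-rate * t)


for name, rate in [("old", OLD_DECAY_RATE), ("new", NEW_DECAY_RATE)]:
    print(
        f"{name}: rate={rate:.3g}, half-life={half_life(rate):.1f}, "
        f"left after 100 time units={fraction_remaining(rate, 100):.2f}"
    )
```

Under this assumption the half-life goes from about 62 to about 100 time units, so the fungicide stays effective roughly 1.6 times longer, which is the direction needed for higher early-year control.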

Need `TRAIN_TEST_SPLIT_PROPORTION` = 0.75 or 1, so that the curve has a reasonable chance of passing through the final points, which are higher than some of those around years 9-11.
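A minimal sketch of why the split proportion matters, assuming (hypothetically; the actual `polymodel` logic may differ) that the split is chronological, with the earliest years used for fitting and the latest held out. The function `split_years` and the 16-season range are illustrative only, not the real data.

```python
import math


def split_years(years, proportion):
    """Hypothetical chronological split: the earliest ceil(proportion * n) distinct
    years are used for fitting, the remaining later years are held out."""
    distinct = sorted(set(years))
    n_train = math.ceil(proportion * len(distinct))
    return distinct[:n_train], distinct[n_train:]


years = list(range(16))  # illustrative: 16 seasons, year offsets 0..15
for p in (0.75, 0.8, 1.0):
    train, test = split_years(years, p)
    print(f"proportion={p}: fit on years {train[0]}-{train[-1]}, held out {test}")
```

With a proportion below 1 the final years never influence the fit, so the curve has no reason to bend back up towards them; with 1 every year is used.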

Could try a different `beta` value?

Mutation seems to have little effect within the range of plausible values.

### Variables to tweak

- `TRAIN_TEST_SPLIT_PROPORTION=0.75,1`
- `FUNG_DECAY_RATE=6.91e-3, ~9e-3, 1.11e-2`
- `n_k = 50,400,1000,2000`
- `DEFAULT_P = 0.1, others`
- maybe `beta` in the 1 vs 2 vs 3 simulation
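The listed values can be enumerated as a grid of fitting runs. This is a sketch only: `DEFAULT_P` and `beta` are left out because only "0.1, others" and "maybe `beta`" are given, and `9e-3` stands in for "~9e-3". The dictionary keys are illustrative names, not `Config` arguments.

```python
from itertools import product

# Values taken from the list above.
TRAIN_TEST_SPLITS = [0.75, 1]
FUNG_DECAY_RATES = [6.91e-3, 9e-3, 1.11e-2]  # 9e-3 stands in for "~9e-3"
N_K_VALUES = [50, 400, 1000, 2000]

# One dict per combination; each would need its own fit (new Config, new study).
grid = [
    {"train_test_split": s, "fung_decay_rate": r, "n_k": n}
    for s, r, n in product(TRAIN_TEST_SPLITS, FUNG_DECAY_RATES, N_K_VALUES)
]
print(len(grid))  # 2 * 3 * 4 = 24 combinations
```

That is already 24 separate 300+ trial fits before varying `DEFAULT_P` or `beta`, so it may be worth fixing some of these first.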

NB: if the decay rate changes, the maximum fungicide mutation scale needs to be refitted.


### Outcome

Needed to switch to a gamma distribution on curvature space for the initial distribution. This helps get the right shape for the fungicide decline.
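A minimal sketch of what a discretised gamma initial distribution could look like, assuming (hypothetically) that the trait grid is `n_k` equal-width bins on (0, `x_max`] and that the gamma density is evaluated at bin midpoints and normalised. The shape/scale values are illustrative, not fitted, and the real `polymodel` grid and parameterisation may differ. Looping over the `n_k` values also gives a quick check on how much the discretisation alone depends on `n_k`.

```python
import math


def gamma_pdf(x, shape, scale):
    """Gamma density at x > 0 (shape-scale parameterisation), computed in log space."""
    return math.exp(
        (shape - 1) * math.log(x)
        - x / scale
        - math.lgamma(shape)
        - shape * math.log(scale)
    )


def discretised_gamma(n_k, shape, scale, x_max):
    """Evaluate the density at the midpoints of n_k equal-width bins on (0, x_max]
    and normalise so the weights sum to 1."""
    width = x_max / n_k
    points = [(i + 0.5) * width for i in range(n_k)]
    weights = [gamma_pdf(x, shape, scale) for x in points]
    total = sum(weights)
    return points, [w / total for w in weights]


# Illustrative parameters only; these are not fitted values.
for n_k in (50, 400, 1000, 2000):
    points, weights = discretised_gamma(n_k, shape=2.0, scale=1.0, x_max=10.0)
    mean = sum(x * w for x, w in zip(points, weights))
    print(f"n_k={n_k}: mean of discretised distribution = {mean:.4f}")
```

If the mean barely moves with `n_k`, the discretisation itself is unlikely to drive an `n_k` effect, although the full model may still depend on `n_k` in other ways.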

With `TRAIN_TEST=0.8`, got something that only just works for 2 vs 3. Best to proceed with `TRAIN_TEST=1`, especially given the mutation issue. That should give something that works both for 2 vs 3 and for fitting, with no extra nuance (a held-out test set) to explain to non-ML people.

Need to check the effect of `n_k`; hopefully it is minimal.

## Save best values

In [None]:
filename = '../data/03_model_inputs/fitted.csv'

In [None]:
fitted = pd.read_csv(filename)
fitted

In [None]:
data_f = fitting_df(fung_fit_config, study_f, trait='Fungicide')
data_f

In [None]:
combined_f = (
    pd.concat([
        fitted,
        data_f
    ])
    .sort_values('date', ascending=False)
    .drop_duplicates()
    .reset_index(drop=True)
)

combined_f

In [None]:
# write back to the same file that was read above
combined_f.to_csv(filename, index=False)