# Model fitting

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import pickle

import seaborn as sns
import matplotlib.pyplot as plt
sns.set_theme(style="whitegrid")

from scipy import stats
from scipy.optimize import minimize
from scipy.stats import truncexpon

import optuna
from optuna.samplers import TPESampler

from pygam import GAM, s

from sklearn.linear_model import LinearRegression

from polymodel.utils import (
    find_soln_given_beta_and_no_control,
    logit10_vectorised,
    inverse_logit10,
    find_beta,
    find_beta_vectorised,
    truncated_exp_pdf,
    find_sev_given_beta_and_no_control
)
from polymodel.params import PARAMS

# Fungicide data

NB `"../data/02_processed/proth_control_raw.csv"` is generated from `src/fitting/fungicide/FungicideControlCurveGenerator.R`.

This converts raw fungicide data (sent via email from Frank I think) to `"../data/02_processed/proth_control_raw.csv"`.

In [None]:
# Column renames from the R output to the names used by the model code.
PROTHIO_COLUMN_NAMES = {
    'RD2': 'asymptote',
    'k2': 'curvature',
    'Srel': 'stb_relative_to_uncontrolled',
    'Control': 'control',
    'minNum': 'min_num',
}

# Per-year prothioconazole dose-response fits; the first CSV column is dropped
# (presumably the R row index written by FungicideControlCurveGenerator.R).
prothio_df = (
    pd.read_csv("../data/02_processed/proth_control_raw.csv")
    .iloc[:, 1:]
    .set_index('year')
    .rename(columns=PROTHIO_COLUMN_NAMES)
    .assign(
        untreated_large_sev=0.37,
        notes='3 replicates per year - Frank email',
        notes2='untreated_large_sev = Mean disease severity of the 10% largest S0 values - see `Identifying when it is financially beneficial ` paper',
    )
)

prothio_df

## Save fungicide control?

In [None]:
# Flip to True to (re)write the fungicide control table used as a model input.
if False:
    filename = "../data/03_model_inputs/control_prothio.csv"
    print(f"saving to {filename}")
    prothio_df.to_csv(filename)

# Host data

From `01_raw/host_trials.csv`, applying `HostCurveGenerator.R` we get:

<!-- - `03_model_inputs/Host/Varieties/FrameFull.csv` -->
<!-- - `03_model_inputs/Host/Varieties/YearlyWorstSevs.csv` -->
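- `03_model_inputs/Host/Varieties/allData.csv` (NB the cell below reads `02_processed/Host/allData.csv` — **TODO** confirm these are the same file)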

These are the relevant columns from the raw data, restricted to plots untreated by fungicide and excluding cultivar mixtures.

We also filter by high pressure locations.

In [None]:
# Host trial data from HostCurveGenerator.R; the first CSV column (row index) is dropped.
# The high-pressure location filter is applied further down, via location scores.
all_data = pd.read_csv('../data/02_processed/Host/allData.csv').iloc[:, 1:]

all_data

Keep only high-pressure locations, i.e. those with a location score greater than or equal to 6.

In [None]:
high_pressure_locations = (
    all_data
    .astype({'post_code': 'int64'})
    .set_index('post_code')
    .join(
        pd.read_csv('../data/01_raw/location_scores.csv')
        .replace('<4000', 4000)
        .astype('int64')
        .set_index('postal code')
        .loc[:, ['score_return']]
        .rename(columns={'score_return': 'location_score'})
    )
    .loc[lambda df : df.location_score>=6]
    .drop('location_score', axis=1)
)

high_pressure_locations.head()

In [None]:
cultivars_use = pd.read_csv('../data/04_justification/whichHost/MoreThan8Years.csv').iloc[:, 1:]

cultivars_use

Group data by year, cultivar and location, and get the mean stb score.

Retain the number of data from which this mean was calculated.

In [None]:
# Mean STB score and number of observations per (year, cultivar, location).
# Named aggregation makes the output columns explicit. The previous
# droplevel-and-relabel approach silently assumed `stb` was the only
# non-key column left after the groupby.
mean_by_yr_clt_loc = (
    high_pressure_locations
    .groupby(['year', 'cultivar', 'location'])
    .agg(
        stb_mean=('stb', 'mean'),
        count=('stb', 'count'),
    )
    .reset_index()
    .sort_values(['cultivar', 'year'])
    # Drop zero means; the worst cultivar's mean is the denominator of the
    # control calculation below.
    .loc[lambda df: df.stb_mean > 0]
)

mean_by_yr_clt_loc

Find the worst performing cultivar in each year/location, by mean stb score

In the event of multiple cultivars being equally bad, retain the one with max N data points

If still a clash, arbitrarily select the one first alphabetically

In [None]:
# Highest mean STB score per (year, location), i.e. the score of the worst
# cultivar there. Assumes `stb_mean` is the only column left to aggregate once
# `count` and `cultivar` are dropped.
worst_stb_each_loc = (
    mean_by_yr_clt_loc
    .drop(['count', 'cultivar'], axis=1)
    .groupby(['year', 'location'])
    .max()
    .reset_index()
)

# Join back on (year, location, stb_mean) to recover which cultivar(s) scored
# that maximum, plus their counts. Ties give more than one row per year/location.
# The float-equality join works because each max is one of the original values.
worst_stb_with_cult_and_count_non_unique = (
    worst_stb_each_loc
    .set_index(['year', 'location', 'stb_mean'])
    .join(mean_by_yr_clt_loc.set_index(['year', 'location', 'stb_mean']))
    .reset_index()
)

worst_stb_with_cult_and_count = (
    worst_stb_with_cult_and_count_non_unique
    # now pick worst cultivar with highest count (alphabetically sorted in case of a tie)
    # so that have 1 unique worst cultivar per year and location
    .sort_values(['count', 'cultivar'], ascending=[False, True])
    .groupby(['year', 'location'])
    .first()
    .reset_index()
)

worst_stb_with_cult_and_count

Combine so that we have dataframe with, for each year and location:
- stb scores from cultivars we care about
- stb scores from worst cultivar
- names of both
- number of data for both
- control offered
- minimum of number of data for cultivar vs worst cultivar


<i>We could filter out any year/location combos where the worst cultivar had a mean severity lower than 5.</i>

<i>This is because dividing by small numbers causes greater variation, and we are most interested in situations where fungicide control would be required i.e. higher disease pressure - CHOOSING NOT TO, UNNECESSARY COMPLICATION</i>



In [None]:
# Pair each (year, location, cultivar) mean with the worst cultivar's mean and
# count at the same year/location. The overlapping columns get the `_worst`
# suffix. Only the cultivars of interest are then kept.
control_df = (
    mean_by_yr_clt_loc
    .set_index(['year', 'location'])
    .join(
        worst_stb_with_cult_and_count
        .set_index(['year', 'location']),
        rsuffix='_worst'
    )
    .reset_index()
    
    .loc[lambda df: df.cultivar.isin(cultivars_use.cultivar)]
    .assign(
        # % reduction in mean STB relative to the worst cultivar at that year/location
        control = lambda df: 100 * (df.stb_mean_worst - df.stb_mean) / df.stb_mean_worst,
        # smaller of the two sample sizes behind each control value
        min_num = lambda df: df.loc[:, ['count', 'count_worst']].min(axis=1)
    )
    
    .sort_values(['cultivar', 'year', 'location'])
    # optional low-pressure filter discussed above - deliberately not applied
    # .loc[lambda df: df.stb_mean_worst>=5]
)

control_df

In [None]:
# Control against year for one example cultivar (Mariboss); point size = min_num.
# No `ax=` is passed, so seaborn draws on the current axes (`ax`).
f, ax = plt.subplots(figsize=(14,7))

sns.scatterplot(
    x='year', 
    y='control',
    size='min_num', 
    data=control_df.loc[control_df.cultivar=='Mariboss'],
    alpha=0.8,
)

## Save host control?

These are the control values for each interesting host variety for model fitting

In [None]:
# Flip to True to (re)write the host control table used for model fitting.
if False:
    filename = '../data/03_model_inputs/control_host.csv'
    print(f"saving to {filename}")
    control_df.to_csv(filename)

control_df.head(5)

In [None]:
# Count-weighted mean of the worst cultivar's STB score across locations, for
# each (year, cultivar): sum(count_worst * stb_mean_worst) / sum(count_worst).
# Only the two numeric columns are summed.
mean_worst_severities = (
    control_df
    .assign(scaled_worst=lambda df: df.count_worst * df.stb_mean_worst)
    .groupby(['year', 'cultivar'])[['scaled_worst', 'count_worst']]
    .sum()
    .assign(worst_stb=lambda df: df.scaled_worst / df.count_worst)
    .loc[:, ['worst_stb']]
    .reset_index()
    .sort_values(['cultivar', 'year'])
    .reset_index(drop=True)
)

mean_worst_severities

In [None]:
# Count-weighted worst-cultivar STB against year, coloured by cultivar.
f, ax = plt.subplots(figsize=(14,7))

sns.scatterplot(x='year', y='worst_stb', hue='cultivar', data=mean_worst_severities, ax=ax)

ax.set_ylim([0,50])

## Save worst stb?

These are the severities to use as inputs into the host model fitting for each year, corresponding to the mean worst stb, by host and year

In [None]:
# Flip to True to (re)write the per-host, per-year input severities.
if False:
    filename = '../data/03_model_inputs/input_severities_host.csv'
    print(f"saving to {filename}")
    mean_worst_severities.to_csv(filename)

mean_worst_severities.head(5)

# I0

Default value of I0 is going to be found from the time series data and just kept as fixed.

Filter out those which are 0

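Growth stage mapping used (NB these values come from interpolating evenly between table rows, as in the first cell below. The second cell instead interpolates on the GS value, giving ≈2130 dd for GS63, 2385 dd for GS71 and 2515 dd for GS75 after rounding, and it is this second mapping that is used downstream — **TODO** confirm which is intended): 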
- 2275.0 dd : GS61-65 (based on 2066 = GS61)
- 2485.0 dd : GS71
- 2690.0 dd : GS75

NB GS87 = 2900

Rounding to nearest 5.

## Growth stage mapping to degree days

In [None]:
# NOTE(review): superseded by the next cell. Plain `.interpolate()` fills the
# NaNs as if the rows were evenly spaced (2274.5, 2483, 2691.5 dd for GS63, GS71,
# GS75). These are the values quoted in the markdown above. The next cell
# instead interpolates on the GS value, and it is that result which is used
# downstream. Confirm which is intended.
gs_dd_map = (
    pd.DataFrame(dict(
        dd = [2066, np.nan, np.nan, np.nan, 2900],
        GS = [61, 63, 71, 75, 87]
    ))
    .interpolate()
)

gs_dd_map

In [None]:
# Anchor GS61 = 2066 dd and GS87 = 2900 dd. The intermediate growth stages are
# interpolated linearly in GS itself (method='index'), not by row position.
gs_dd_map = (
    pd.DataFrame(
        {'dd': [2066, np.nan, np.nan, np.nan, 2900]},
        index=pd.Index([61, 63, 71, 75, 87], name='GS'),
    )
    .interpolate(method='index')
    .reset_index()
)

gs_dd_map

In [None]:
# Sanity-check plot: dd against GS.
gs_dd_map.set_index('GS').plot()

In [None]:
# Model times: interpolated dd rounded to the nearest 5.
gs_dd_map = gs_dd_map.assign(to_use = lambda df: 5*round(df.dd/5))
gs_dd_map

In [None]:
# Start time t0 (the GS61 row).
gs_dd_map.to_use[0]

No longer filter by cultivar, because all cultivars should have same I0

In [None]:
# Raw column -> short name. I1..I3 are the three assessments, which are mapped
# below to the GS63, GS71 and GS75 model times (severities presumably in %).
I_COLUMN_NAMES = {
    'L2': 'I1',
    '17_06_gs71_l2': 'I2',
    '25_06_19_gs75_l2': 'I3',
}

I_at_diff_times = (
    pd.read_csv('../data/02_processed/I0_in_better_colnames.csv')
    .loc[:, ['treatment', *I_COLUMN_NAMES]]
    .rename(columns=I_COLUMN_NAMES)
    # untreated plots ('a') with non-zero severity at every assessment
    # (logit10 of 0 would be -inf)
    .query("treatment == 'a' and I1 > 0 and I2 > 0 and I3 > 0")
    .drop(columns='treatment')
    .reset_index(drop=True)
    .assign(
        # model times t0..t3 (dd) for GS61, GS63, GS71, GS75
        **{f't{k}': gs_dd_map.to_use[k] for k in range(4)},
        IL1=lambda df: logit10_vectorised(0.01 * df.I1),
        IL2=lambda df: logit10_vectorised(0.01 * df.I2),
        IL3=lambda df: logit10_vectorised(0.01 * df.I3),
    )
)

I_at_diff_times

In [None]:
# Long format: one row per (sample, assessment), holding the assessment's model
# time in dd (`time_dd`) and the observed logit10 severity (`stb`).
stacked = (
    I_at_diff_times
    .filter(like='IL')
    .stack()
    .reset_index(level=1)
    # replace the column labels IL1..IL3 (now in 'level_1') with their times
    .replace('IL1', gs_dd_map.to_use[1])
    .replace('IL2', gs_dd_map.to_use[2])
    .replace('IL3', gs_dd_map.to_use[3])
    .rename(columns={
        'level_1': 'time_dd', 
        0: 'stb'
    })
    .reset_index(drop=True)
)

stacked.head(5)

In [None]:
# Earlier approach: a straight-line fit of logit10 severity against time.
# NOTE(review): the saved I0 comes from the optuna fit below, not from here.
lr = LinearRegression().fit(
    np.array(stacked.time_dd).reshape(-1, 1),
    np.array(stacked.stb)
)

In [None]:
# Extrapolate back to 1456 dd (the same cutoff used in I0Objective.run_model).
IL0_pred = lr.predict(np.array([1456]).reshape(-1, 1))[0]

In [None]:
IL0_pred

In [None]:
I0_pred = inverse_logit10(IL0_pred)

In [None]:
# NOTE(review): `I0_pred`, `tt` and `model_preds` are not used elsewhere in this notebook.
tt = np.linspace(1456, stacked.time_dd.max(), 5)
model_preds = lr.predict(tt.reshape(-1,1))

## New (works with non-constant host)

In [None]:
class I0Objective:
    """Optuna objective for fitting I0 and beta of the no-control model.

    For a parameter set ``{'I0': ..., 'beta': ...}`` the score is the sum of
    squared residuals between the model's logit10 severity and the observed
    logit10 severities (``stb`` column) at the observation times (``time_dd``).
    """

    def __init__(self, df=None) -> None:
        # Observations in long format (columns `time_dd`, `stb`). Defaults to the
        # notebook-level `stacked` frame, so `I0Objective()` behaves as before.
        self.df = stacked if df is None else df

    def __call__(self, trial):
        """Score one optuna trial, or fixed default parameters when `trial` is None."""
        if trial is None:
            params = {'I0': 0.001, 'beta': 0.001}
        else:
            params = self.get_params(trial)

        res = self.run_model(params)

        # np.sum on a Series dispatches to Series.sum, so NaN residuals are skipped
        score = np.sum(res.sq_residual)

        return score

    def run_model(self, params):
        """Solve the model at `params` and line it up against the observations.

        Returns a frame with columns `time_dd`, `data` (observed logit10
        severity), `model` (model logit10 severity) and `sq_residual`. It is
        restricted to times after 1456 dd.
        """
        # Start time plus every observation time, in increasing order.
        t_vals = [PARAMS.T_1] + list(self.df.time_dd.unique())

        t_vals.sort()

        soln = find_soln_given_beta_and_no_control(
            params['beta'],
            params['I0'],
            t_vals)

        # assumes soln rows 0 and 1 are S and I over t_vals - TODO confirm
        no_control_model = (
            pd.DataFrame(soln.T)
            .rename(columns={0: 'S', 1: 'I'})
            .assign(
                time_dd = t_vals,
                sev = lambda df: df.I / (df.S + df.I),
                sev_logit = lambda df: logit10_vectorised(df.sev)
            )
            .set_index('time_dd')
        )

        res = (
            no_control_model

            .join(
                self.df.set_index('time_dd'),
                how='outer'
            )

            .rename(columns = {
                'stb': 'data', 
                'sev_logit': 'model'
            })

            .loc[:, ['data', 'model']]

            .reset_index()

            # NOTE(review): 1456 is presumably PARAMS.T_1 (the start time, which
            # has no observation). Confirm, or use PARAMS.T_1 directly.
            .loc[lambda df: df.time_dd>1456]
            
            .assign(sq_residual = lambda df: (df.data - df.model)**2)
        )

        return res

    def get_params(self, trial):
        """Draw I0 and beta from `trial` within the search ranges."""
        params = {
            # NOTE(review): I0 is sampled on a linear scale across four orders of
            # magnitude, whereas beta uses log=True. Consider log=True here too.
            "I0": trial.suggest_float(
                "I0",
                1e-6,
                1e-2
            ),
            "beta": trial.suggest_float(
                "beta",
                1e-5,
                1e-2,
                log=True
            ),
        }
        return params

In [None]:
optuna.logging.set_verbosity(
    optuna.logging.WARNING
)

In [None]:
sampler = TPESampler(seed=0)
study = optuna.create_study(sampler=sampler)
obj = I0Objective()

In [None]:
study.optimize(obj, n_trials=500)

In [None]:
study.best_trial

In [None]:
study.best_params

In [None]:
res = I0Objective().run_model(study.best_params)
res.head()

In [None]:
f, ax = plt.subplots()

res.set_index('time_dd').model.plot(ax=ax, lw=4, c='r')

res.plot.scatter(x='time_dd', y='data', ax=ax)

## Save I0 value?

In [None]:
# Best-fit I0 as a one-row table.
I0_df = (
    pd.DataFrame(dict(I0_value = [study.best_params['I0']]))
)

I0_df

In [None]:
# NOTE(review): unlike most save cells this guard is True, so every run
# overwrites I0_value.csv with this session's optuna result (read back below).
if True:
    filename = '../data/03_model_inputs/I0_value.csv'
    print(f'saving to {filename}')
    I0_df.to_csv(filename)

# Beta

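We collate the STB scores of the worst cultivars in the high-pressure locations across all years, and fit a truncated exponential distribution on $[0, 100]$ to these severities by maximum likelihood. (A Gaussian kernel-density estimate with a small bandwidth of 0.05 was also considered for smoothing the values; the code below uses the truncated exponential.)

We can then use the single I0 value found above to give us a beta value for each final severity sampled from the fitted distribution.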

In [None]:
# Read I0 back from disk so the beta section uses the saved value.
I0_value = pd.read_csv('../data/03_model_inputs/I0_value.csv').I0_value.iloc[0]
I0_value

In [None]:
# Individual STB observations of every worst cultivar at each high-pressure
# year/location. Tied worst cultivars are all included, since this uses the
# non-unique table.
stb_values = (
    worst_stb_with_cult_and_count_non_unique
    .drop(['stb_mean', 'count'], axis=1)
    .set_index(['year', 'location', 'cultivar'])
    .join(
        high_pressure_locations
        .set_index(['year', 'location', 'cultivar'])
    )
    .reset_index()
)

stb_values.stb.hist()

In [None]:
stb_values

## Truncated exponential

In [None]:
stb_values.stb.describe()

In [None]:
# Plotting grid for the fitted density, slightly wider than [0, 100].
xx = np.linspace(-0.5,100.5,301)

In [None]:
def neg_log_likelihood(lambd):
    """Negative log-likelihood of the observed severities under the truncated exponential.

    Evaluates `truncated_exp_pdf(x, lambd)` at every observation in the global
    `stb_values.stb`. It is used as the objective for `scipy.optimize.minimize`.
    """
    densities = [truncated_exp_pdf(obs, lambd) for obs in stb_values.stb]
    return -np.sum(np.log(densities))

In [None]:
min_out = minimize(
    neg_log_likelihood,
    [0.04],
    bounds=[(1e-6, 100)],
    tol=1e-6,
)
min_out

In [None]:
lambd_fitted = min_out.x[0]
lambd_fitted

In [None]:
my_line = [truncated_exp_pdf(ii, lambd_fitted) for ii in xx]

In [None]:
if False:
    stb_values.to_csv('../data/03_model_inputs/stb_vals.csv')

In [None]:
# Observed severities (20 bins, density-scaled) with the fitted density overlaid.
f, ax = plt.subplots(figsize=(8,7))

(
    stb_values
    .stb
    .hist(ax=ax, 
          bins=20,
          density=True
    )
)

ax.plot(xx, my_line, c='r', lw=3)

In [None]:
f.savefig('../figures/paper_figs/trunc_exp_20.jpg')

In [None]:
# Same figure with 40 bins.
f, ax = plt.subplots(figsize=(8,7))

(
    stb_values
    .stb
    .hist(ax=ax, 
          bins=40,
          density=True
    )
)

ax.plot(xx, my_line, c='r', lw=3)

In [None]:
f.savefig('../figures/paper_figs/trunc_exp_40.jpg')

In [None]:
# Fitted rate as a one-row table, for saving below.
ldf = pd.DataFrame(dict(lambda_fitted = [lambd_fitted]))
ldf

## Save lambda (exponential value)

In [None]:
filename = '../data/03_model_inputs/lambda_fitted.csv'

# Flip to True to overwrite the saved rate with this session's fit.
if False:
    print(f'saving to {filename}')
    ldf.to_csv(filename)

# Use the rate stored on disk. Column 0 is the index written by `to_csv`.
# The single scalar is taken explicitly, because calling float() on a
# one-element Series is deprecated in pandas.
lambda_use = float(pd.read_csv(filename).iloc[0, 1])
lambda_use

## Sample stb and beta

Need to check that the resulting beta values are sensible. Should be somewhere in the order of `1e-3`.

Then can post-hoc filter out any values that do something weird.

In [None]:
def find_stb(lambd, p, upper=100):
    """Inverse CDF of an exponential(rate=`lambd`) truncated to [0, upper].

    Maps probability levels `p` in [0, 1] to severities in [0, upper]. Used for
    inverse-transform sampling from uniform draws. Works elementwise on scalars
    and numpy/pandas arrays.

    Parameters
    ----------
    lambd : float
        Rate of the untruncated exponential (non-zero).
    p : float or array-like
        Cumulative probability level(s) in [0, 1].
    upper : float, default 100
        Upper truncation point. Severities here are in %, so 100 is the maximum.

    Returns
    -------
    float or array
        x such that F(x) = p, where
        F(x) = (1 - exp(-lambd * x)) / (1 - exp(-lambd * upper)).
    """
    arg = 1 - p + p*np.exp(-upper*lambd)
    out = -1/lambd * np.log(arg)
    return out

In [None]:
# Number of severities to sample from the fitted distribution.
N_SAMPLE = 20000

In [None]:
# Legacy global seed, so the uniform draws (and hence the betas) are reproducible.
np.random.seed(1)
random_unif = np.random.uniform(size=N_SAMPLE)

In [None]:
# Inverse-transform sampling: uniform draws -> truncated-exponential severities (%).
stb_generated = find_stb(lambda_use, random_unif)

In [None]:
pd.DataFrame(dict(stb=stb_generated)).describe()

In [None]:
I0_value

In [None]:
# Spot-check beta at a very low and a very high final severity (presumably fractions).
find_beta_vectorised([2e-3, 9.9e-1], I0_value)

In [None]:
# One beta per sampled severity; the % severities are scaled by 0.01 first.
beta_df = (
    pd.DataFrame(dict(stb=stb_generated))
    .assign(beta = lambda df: 
            find_beta_vectorised(0.01*df.stb, I0_value)
    )
)

beta_df.head()

In [None]:
# All samples, before the 1e-4 boundary cases are filtered out below.
beta_df.to_csv('../data/03_model_inputs/many_sampled_betas.csv')

In [None]:
f, ax = plt.subplots()

beta_df.hist(ax=ax
             # , bins=50
            )

In [None]:
# Samples whose beta is (numerically) 1e-4.
beta_df.sort_values(['beta', 'stb']).loc[lambda df: np.isclose(df.beta, 1e-4)]

In [None]:
# Samples whose beta is implausibly small or large, or could not be found (NaN).
bad_betas = (
    beta_df
    .loc[lambda df: (
        (df.beta <= 1e-4)
        | (df.beta >= 5e-2)
        | df.beta.isna()
    )]
)

bad_betas

In [None]:
beta_df.loc[np.isclose(beta_df.beta,0.0001), :] = np.nan

In [None]:
sampled_betas_use = beta_df.loc[~np.isclose(beta_df.beta,0.0001)]

sampled_betas_use.shape

NB: we think this lowest severity can't be achieved with the fixed I0 — even with beta=0 there will be some minimum severity. So we just filter out these very rare cases (13 out of 20000).

In [None]:
sampled_betas_use.describe()

In [None]:
# Sampled final severity against log(beta) for the retained samples.
(
    sampled_betas_use
    .assign(log_b = lambda df: np.log(df.beta))
    .plot
    .scatter(x='log_b', y='stb', alpha=0.1)
)

In [None]:
# NOTE(review): if re-enabled, `f` here is the figure from the beta histogram
# cell above, not this scatter (pandas `.plot.scatter` without `ax=` makes its
# own figure). Save the scatter's figure instead.
# f.savefig('../figures/paper_figs/stb_vs_beta_expo.jpg')

## Save sampled betas?

In [None]:
# Always written (guard is True). The resampling section below reads this file
# and then overwrites it.
# if SAVING:
if True:
    filename = '../data/03_model_inputs/beta_sampled.csv'
    print(f'saving beta to {filename}')
    sampled_betas_use.beta.to_csv(filename)

## Resample to get more

In [None]:
betas = pd.read_csv('../data/03_model_inputs/beta_sampled.csv').iloc[1:]

In [None]:
betas2 = (
    pd.concat([betas] +
        [betas.sample(frac=1, random_state=ii) for ii in range(10)]
    )
)

betas2.head(10)

In [None]:
betas2.shape

In [None]:
betas2.loc[:, 'beta'].to_csv('../data/03_model_inputs/beta_sampled.csv', index=False)

### for nik cdf

In [None]:
cdf_df = (
    pd.DataFrame(dict(x = np.linspace(0,1,100)))
    .assign(cdf = lambda df: find_stb(lambda_use, df.x))
)

In [None]:
f, ax= plt.subplots(figsize=(14,8))

cdf_df.plot(x='x', y='cdf', ax=ax)


(
    stb_values
    .sort_values('stb')
    .reset_index(drop=True)
    .assign(quartile = lambda df: (
        df.index / (df.shape[0] - 1)
    )
    )
    .set_index('quartile')
    .loc[:, ['stb']]
    .plot(ax=ax)
)

# (
#     random_sample
#     .sort_values('random_stb')
#     .reset_index(drop=True)
#     .assign(quartile = lambda df: (
#         df.index / (df.shape[0] - 1)
#     )        
#     )
#     .set_index('quartile')
#     .loc[:, ['random_stb']]
#     .rename(columns={'random_stb': 'smoothed_stb'})
#     .plot(ax=ax)
# )

In [None]:
f.savefig('../figures/paper_figs/cdf.jpg')

In [None]:
betas = pd.read_csv('../data/03_model_inputs/beta_sampled.csv').iloc[:, 1:]
betas

In [None]:
betas.mean(), betas.median()

## Save median beta?

NB: a value has already been saved. We could keep using it so that this doesn't need to be re-run, although beta is fixed in the fungicide/host fitting anyway.

In [None]:
# Median of the sampled betas as a one-row table. `.iloc[0]` takes the value
# positionally; integer keys on a label-indexed Series are deprecated in pandas.
bdf = pd.DataFrame(dict(beta_median=[betas.median().iloc[0]]))

if False:
    filename = '../data/03_model_inputs/beta_value.csv'
    print(f'saving to {filename}')
    bdf.to_csv(filename)

bdf

## Save 'actual' betas for figure?

Not for model

In [None]:
# Beta for each observed (not sampled) severity, for a figure only.
# NOTE(review): stb_vals.csv is only (re)written when the `if False` guard after
# the truncated-exponential fit is flipped, so this relies on an existing file.
b_actual = find_beta_vectorised(
    0.01*pd.read_csv('../data/03_model_inputs/stb_vals.csv').stb,
    I0_value
)

In [None]:
# NOTE(review): guard is True, so this file is rewritten on every run.
if True:
    print('saving')
    
    (
        pd.DataFrame(dict(betas = b_actual))
        .to_csv('../data/03_model_inputs/beta_from_data_not_sampled_dist.csv')
    )

# Yield relationship

Here we choose to include all cultivars. We could instead have used only the 'good' ones, or excluded the benchmark or the mixtures.

In [None]:
# Yield and STB observations per plot (cultivar `cult`, treatment `treat`).
yield_df = (
    pd.read_csv("../data/01_raw/YR_in_soenderborg.csv")
    .rename(columns = {'stb L2 27.06': 'stb'})
    .assign(
        # % severity -> fraction, as elsewhere in this notebook
        stb = lambda df: df.stb/100,
        # Scale the Series itself rather than assigning a one-column DataFrame.
        # NOTE(review): /10 presumably converts yield units - confirm.
        yld = lambda df: df['yield']/10,
    )
    .loc[:, ['yld', 'stb', 'cult', 'treat']]
)

In [None]:
# Three cultivars singled out for comparison.
ydf_good = yield_df.loc[lambda df: df.cult.isin(['kalmar', 'sheriff', 'informer'])]

In [None]:
# Yield against severity, all cultivars.
f, ax = plt.subplots(figsize=(10,8))

sns.scatterplot(
    x='stb',
    y='yld',
    hue='cult',
    data=yield_df,
    size='treat',
    ax=ax
)

In [None]:
# Yield against severity, the three selected cultivars only.
f, ax = plt.subplots(figsize=(10,8))

sns.scatterplot(
    x='stb',
    y='yld',
    hue='cult',
    data=ydf_good,
    size='treat',
    ax=ax
)

In [None]:
ydf_dont = (
    yield_df
    .loc[lambda df: ~df.cult.isin(['kalmar', 'sheriff', 'informer'])]
)

In [None]:
f, ax = plt.subplots(figsize=(10,8))

sns.scatterplot(
    x='stb',
    y='yld',
    hue='cult',
    data=yield_df,
    size='treat',
    ax=ax
)

In [None]:
# All cultivars and treatments; rows missing either value are dropped.
ydf_use = yield_df.loc[:, ['stb', 'yld']].dropna()

In [None]:
if False:
    ydf_use.to_csv('../data/03_model_inputs/yield_vs_stb.csv')

In [None]:
ydf_use

### Old - find pars - linear

In [None]:
# Ordinary least squares: yield against severity (fraction).
yield_lr = LinearRegression().fit(
    np.array(ydf_use.stb).reshape(-1,1),
    np.array(ydf_use.yld)
)

In [None]:
yield_lr.intercept_, yield_lr.coef_[0]

In [None]:
# Slope expressed as a fraction of the intercept (the fitted yield at zero severity).
gdt = yield_lr.coef_[0] / yield_lr.intercept_
gdt

In [None]:
intercept = yield_lr.intercept_
intercept

In [None]:
# Superseded by the GAM below; saved (if enabled) under `old/`.
yr_out = pd.DataFrame(dict(gdt = [gdt], intercept=[intercept]))

if False:
    filename = '../data/03_model_inputs/old/yield_relationship_linear.csv'
    print(f'saving yield relationship params to {filename}')

    yr_out.to_csv(filename)

yr_out

### GAM

Constrained to be monotonic decreasing

In [None]:
# Spline term (5 splines) on severity, constrained to be monotonically
# decreasing, so the fitted yield never rises with more disease.
g = GAM(
    s(0, n_splines=5), 
    constraints='monotonic_dec'
)

g.fit(np.array(ydf_use.stb), np.array(ydf_use.yld))

In [None]:
# Compare the GAM (red) and linear (green, dashed) fits over severity in [0, 1].
# NOTE(review): reuses the name `xx`, earlier the truncated-exponential plotting grid.
xx = np.linspace(0,1,100)

f, ax = plt.subplots(figsize=(10,8))

preds_df = pd.DataFrame(dict(x=xx, GAM=g.predict(xx))).set_index('x')

preds_df_linear = pd.DataFrame(dict(x=xx, Linear=yield_lr.predict(xx.reshape(-1,1)))).set_index('x')

ydf_use.plot.scatter(x='stb', y='yld', ax=ax)

preds_df.plot(ax=ax, color='r', lw=3)

preds_df_linear.plot(ax=ax, color='g', lw=3, ls='--')

f.savefig('../figures/paper_figs/test_yr.jpg')

### Save GAM?

In [None]:
# Flip to True to pickle the fitted GAM. The file handle gets its own name so
# that the figure handle `f` from earlier cells is not overwritten.
if False:
    filename = 'gam.pickle'
    
    with open(filename, 'wb') as gam_file:
        pickle.dump(g, gam_file)

# End