# Model fitting

In [15]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
import pandas as pd
import numpy as np
import pickle

import seaborn as sns
import matplotlib.pyplot as plt
sns.set_theme(style="whitegrid")

from pygam import GAM, s

# Yield relationship

We include all cultivars in `ydf_use`. Alternatively, we could have used only the 'good' cultivars, or excluded the benchmark or the mixtures - would either of those be preferable?

In [None]:
# Load the raw trial data and derive the two modelling variables:
#   stb - the 'stb L2 27.06' column divided by 100
#   yld - the 'yield' column divided by 10
# NOTE(review): the divisors look like unit conversions (e.g. percent ->
# proportion) - confirm against the data dictionary.
yield_df = (
    pd.read_csv("../data/01_raw/YR_in_soenderborg.csv")
    .rename(columns={'stb L2 27.06': 'stb'})
    .assign(
        stb=lambda df: df['stb'] / 100,
        # Select 'yield' as a Series. The previous `df.loc[:, ['yield']]`
        # produced a one-column DataFrame, which only works as a column
        # value because pandas coerces single-column frames on assignment.
        yld=lambda df: df['yield'] / 10,
    )
    .loc[:, ['yld', 'stb', 'cult', 'treat']]
)

In [None]:
# Rows belonging to the three 'good' cultivars: kalmar, sheriff and informer.
ydf_good = yield_df[yield_df['cult'].isin(['kalmar', 'sheriff', 'informer'])]

In [None]:
# Yield against stb for every cultivar; colour encodes cultivar and marker
# size encodes treatment.
f, ax = plt.subplots(figsize=(10, 8))

sns.scatterplot(data=yield_df, x='stb', y='yld', hue='cult', size='treat', ax=ax)

In [None]:
# Same scatter, restricted to the `ydf_good` subset; colour encodes cultivar
# and marker size encodes treatment.
f, ax = plt.subplots(figsize=(10, 8))

sns.scatterplot(data=ydf_good, x='stb', y='yld', hue='cult', size='treat', ax=ax)

In [None]:
# Complement of the three-cultivar subset: every row whose cultivar is not
# kalmar, sheriff or informer.
ydf_dont = yield_df[~yield_df['cult'].isin(['kalmar', 'sheriff', 'informer'])]

In [None]:
# Scatter of yield against stb for the cultivars outside the kalmar / sheriff /
# informer group; colour encodes cultivar and marker size encodes treatment.
f, ax = plt.subplots(figsize=(10, 8))

sns.scatterplot(
    x='stb',
    y='yld',
    hue='cult',
    # Plot the `ydf_dont` subset built in the preceding cell. This cell
    # previously passed `yield_df`, which merely duplicated the all-cultivar
    # plot and left `ydf_dont` unused.
    data=ydf_dont,
    size='treat',
    ax=ax
)

In [None]:
# Modelling table: only the stb/yld pair, dropping any row where either is missing.
ydf_use = yield_df[['stb', 'yld']].dropna()

In [None]:
# Set to True to (re)write the model-input CSV. Disabled by default so that
# running the notebook does not overwrite the stored file.
SAVE_YIELD_INPUTS = False

if SAVE_YIELD_INPUTS:
    ydf_use.to_csv('../data/03_model_inputs/yield_vs_stb.csv')

In [None]:
# Show the modelling table (stb and yld columns) as the cell output.
ydf_use

### GAM

The fitted spline is constrained to be monotonically decreasing.

In [None]:
# GAM with a single 5-spline smooth of feature 0 (stb), fitted to yld.
# The monotonic-decreasing constraint is declared on the spline term itself,
# which is where pyGAM's terms API documents it, rather than being passed as a
# model-level keyword to GAM().
g = GAM(s(0, n_splines=5, constraints='monotonic_dec'))

# Last expression: fit() is left un-assigned so the cell displays its return value.
g.fit(ydf_use.stb.to_numpy(), ydf_use.yld.to_numpy())

In [None]:
# Compare the fitted GAM with a straight-line baseline on a 100-point grid of
# stb values spanning [0, 1], overlaid on the observed data, and save the figure.
xx = np.linspace(0, 1, 100)

f, ax = plt.subplots(figsize=(10, 8))

preds_df = pd.DataFrame(dict(x=xx, GAM=g.predict(xx))).set_index('x')

# The previous version called `yield_lr.predict(...)`, but `yield_lr` is not
# defined in any visible cell, so this cell raised NameError on a fresh kernel.
# Fit the linear baseline here by ordinary least squares on the same data the
# GAM was fitted to.
# NOTE(review): assumes `yield_lr` was an OLS regression of yld on stb - confirm.
slope, intercept = np.polyfit(ydf_use.stb, ydf_use.yld, deg=1)
preds_df_linear = pd.DataFrame(dict(x=xx, Linear=slope * xx + intercept)).set_index('x')

ydf_use.plot.scatter(x='stb', y='yld', ax=ax)

preds_df.plot(ax=ax, color='r', lw=3)

preds_df_linear.plot(ax=ax, color='g', lw=3, ls='--')

f.savefig('../figures/paper_figs/test_yr.jpg')

### Save GAM?

In [None]:
# Set to True to pickle the fitted GAM into the current working directory.
# Disabled by default so that running the notebook does not overwrite an
# existing pickle.
SAVE_GAM = False

if SAVE_GAM:
    filename = 'gam.pickle'

    # Use a dedicated handle name: `as f` would rebind `f`, which the plotting
    # cell uses for the matplotlib figure.
    with open(filename, 'wb') as gam_file:
        pickle.dump(g, gam_file)