In [1]:
from model_helpers import *

import cfgrib
import xarray as xr

import pandas as pd
import numpy as np

from pyPhenology import models, utils

from tqdm import trange, tqdm

import matplotlib.pyplot as plt

from warnings import warn
import warnings

warnings.filterwarnings('ignore')

high_cutoff_year = 2022
low_cutoff_year = 2010

In [10]:
def make_test_df(train_df):
    #print(train_df)
    species_sites = train_df['site_id'].unique()
        
    #print(species_sites)
    
    site_ripenesses = []

    for site in species_sites:
        site_df = train_df[train_df['site_id'] == site]

        site_ripenesses.append({
            'site_id': site,
            'doy': np.mean(site_df['doy'])
        })

    species_test_df = pd.DataFrame(site_ripenesses)
    species_test_df['year'] = high_cutoff_year
    
    return species_test_df

# More specific to our uses.
def train_ripeness_small(observations, predictors, test_observations, test_predictors, model_name = 'ThermalTime'):

    print("running model {m}".format(m=model_name))
    Model = utils.load_model(model_name)
    model = Model()
    model.fit(observations, predictors, optimizer_params='practical')
    
    print(model)
    
    print("making predictions for model {m}".format(m=model_name))        
    preds = model.predict(test_observations, test_predictors)

    print(preds)
    test_days = test_observations.doy.values
    print(test_days)
    
    # Various error types
    model_mae = mae(test_days, preds)
    model_rmse = rmse(test_days, preds)
    median_error = np.median(np.abs(test_days - preds))

    print('model {m} got a MAE of {a}'.format(m=model_name,a=model_mae))
    print('model {m} got an RMSE of {a}'.format(m=model_name,a=model_rmse))
    print('model {m}\'s median error is: {a}'.format(m=model_name,a=median_error))

    print("Ripeness Day: {}".format(np.mean(preds)))
    
    ripeness_data = test_observations
    ripeness_data['flowering_day'] = preds
    
    return ripeness_data

Idea for corrections: take the lower error between the base error and the year-transformed error.


Best Approach is:
- High time resolution, correcting for missing data by using averaged data from previous years.
- Add European Weather data.
- No southern hemisphere. 
- Corrected error (i.e. date wrapping). 

Best reporting statistic: what portion of results lie under X.
For example, 80% of results lie under 1 STD, 95% lie under 2 STD. 
Can make a "confidence score" from this – percentile error? Ex. This is less error than 90% of things.  

In [28]:
# Load in high-res weather data
print("loading weather data")
grib_data = cfgrib.open_datasets('../data/monthly_weather_data.grib')

core_data = grib_data[0]

print("Loading Plant Data")
formatted_plants = pd.read_csv("../data/model_training_data/all_plants_formatted.csv", index_col=0)

formatted_plants['rounded_lat'] = np.round(formatted_plants['latitude'], 1)
formatted_plants['rounded_lon'] = np.round(formatted_plants['lon_360'], 1)

rounded_sites = formatted_plants[['site_id', 'rounded_lat', 'rounded_lon']].drop_duplicates()

site_x_vals = xr.DataArray(rounded_sites['rounded_lat'], dims=['site'])
site_y_vals = xr.DataArray(rounded_sites['rounded_lon'], dims=['site'])

print("filtering weather data")
full_weather_data = core_data.sel(latitude=site_x_vals, longitude=site_y_vals, method='nearest').to_dataframe().dropna()

print("formatting weather data")
formatted_weather = format_weather_data(full_weather_data)

formatted_weather['latitude'] = np.round(formatted_weather['latitude'], 1)
formatted_weather['longitude'] = np.round(formatted_weather['longitude'], 1)

print("adding site IDs to weather data")
rounded_sites['coordstring'] = rounded_sites['rounded_lat'].astype(str) + rounded_sites['rounded_lon'].astype(str)
formatted_weather['coordstring'] = formatted_weather['latitude'].astype(str) + formatted_weather['longitude'].astype(str)

## Add Site ID to the weather data
weather_with_sites = pd.merge(formatted_weather, rounded_sites[['coordstring', 'site_id']], on='coordstring')#.drop('coordstring', axis=1)
## Separate into training data and testing data

# filter out current year
print("separating weather data")
weather_with_sites = weather_with_sites[weather_with_sites['year'] != 2023]

weather_training = weather_with_sites[weather_with_sites['year'] < high_cutoff_year]
weather_test = weather_with_sites[weather_with_sites['year'] >= high_cutoff_year]

# final formatting steps for plants
print("formatting plant data")
species_list = formatted_plants['formatted_sci_name'].unique()
formatted_plants.drop('species', axis=1, inplace=True)

# correct for missing sites
weather_sites = weather_with_sites['site_id'].unique()

print("filtering plant data")
filtered_plants = formatted_plants[(formatted_plants['site_id'].isin(weather_sites)) & 
                                   (formatted_plants['year'] != 2023) &
                                   (formatted_plants['latitude'] > 0) &
                                    (formatted_plants['doy'] >= 60)]


loading weather data
Loading Plant Data
filtering weather data
formatting weather data
formatting date columns
correcting leap years
adding site IDs to weather data
separating weather data
formatting plant data
filtering plant data


TODO: make the europe data rounded to .1 degrees. 

In [4]:
euro_data = load_euro_weather_data(euro_path, '../data/high_res_euro_stations.csv')

In [5]:
## Merge both weather data. 

## If site is in europe data, replace it with the european version. 

# Create a list of "mutual sites".
euro_coords = euro_data['coordstring'].unique()

mutual_sites = weather_with_sites[weather_with_sites['coordstring'].isin(euro_coords)][['site_id', 'coordstring']].drop_duplicates()
# Get those sites from europe
mutual_sites_euro = euro_data[euro_data['coordstring'].isin(mutual_sites['coordstring'])]
mutual_sites_euro = mutual_sites_euro.merge(mutual_sites, on='coordstring')

mutual_sites_euro['temperature'] += 272.5

# Remove those sites from monthly 
unmutual_monthly = weather_with_sites[~weather_with_sites['site_id'].isin(mutual_sites)]

# rbind the two (a union essentially?)
merged_euro = pd.concat([mutual_sites_euro, unmutual_monthly]).drop('station', axis=1).drop_duplicates()

merged_euro['temperature'] = np.round(merged_euro['temperature'], 1)


In [6]:

weather_training = merged_euro[merged_euro['year'] < high_cutoff_year]
weather_test = merged_euro[merged_euro['year'] >= high_cutoff_year]

In [None]:
## Train models

species_prediction_dict = {}

for s in tqdm(species_list):
    print(s)
    species_train_df = filtered_plants.query('formatted_sci_name == "{}" and year < {}'.format(s, high_cutoff_year))
    
    if len(species_train_df) == 0:
        continue
    
    species_test_df = filtered_plants.query('formatted_sci_name == "{}" and year >= {}'.format(s, high_cutoff_year))
    
   # print(species_train_df, species_test_df)
    
    if len(species_test_df) == 0:
        # make predictions and compare to the mean ripeness day at each site
        species_test_df = make_test_df(species_train_df)
    
    if len(species_test_df) == 0:
        print("No test data for {}".format(s))
        #print(species_test_df)
        
    predictions = train_ripeness_small(species_train_df, weather_training,
                        species_test_df, weather_test)
    
    species_prediction_dict[s] = predictions

  0%|                                                                                                                                                                              | 0/97 [00:00<?, ?it/s]

Rubus
running model ThermalTime


  1%|█▋                                                                                                                                                                    | 1/97 [00:03<05:30,  3.44s/it]

making predictions for model ThermalTime
[183 183 183 183 183 183 183 183 183 183 183 183 183 183 183 214 183 183
 183 214 183 214 214 183 183 183 183 183 183 183]
[183. 181. 190. 193. 165. 165. 184. 181. 178. 184. 246. 202. 288. 192.
 173. 214. 159. 191. 191. 196. 187. 214. 216. 136. 305. 305. 305. 274.
 191. 195.]
model ThermalTime got a MAE of 28.6
model ThermalTime got an RMSE of 49.24970389081881
model ThermalTime's median error is: 9.5
Ripeness Day: 187.13333333333333
Rubus occidentalis
running model ThermalTime


  2%|███▍                                                                                                                                                                  | 2/97 [00:07<06:21,  4.01s/it]

making predictions for model ThermalTime
[183 183 183 183 183 183 214 214 214 183]
[164. 189. 189. 189. 184. 189. 191. 191. 191. 187.]
model ThermalTime got a MAE of 11.7
model ThermalTime got an RMSE of 14.522396496446445
model ThermalTime's median error is: 6.0
Ripeness Day: 192.3
Ficus
running model ThermalTime


  3%|█████▏                                                                                                                                                                | 3/97 [00:11<05:44,  3.67s/it]

making predictions for model ThermalTime
[245 245 245 245 245 245 245 245 245 245 245]
[240. 210. 227. 229. 229. 248. 270. 270. 281. 239. 300.]
model ThermalTime got a MAE of 21.818181818181817
model ThermalTime got an RMSE of 26.460948928219075
model ThermalTime's median error is: 18.0
Ripeness Day: 245.0
Ficus auriculata
Ficus carica
running model ThermalTime


  5%|████████▌                                                                                                                                                             | 5/97 [00:13<03:29,  2.27s/it]

making predictions for model ThermalTime
[245 245 245 275 275 245 245 245 245 245]
[180. 232. 236. 278. 248. 258. 258. 248. 205. 335.]
model ThermalTime got a MAE of 27.6
model ThermalTime got an RMSE of 39.06404996924922
model ThermalTime's median error is: 13.0
Ripeness Day: 251.0
Ficus citrifolia
running model ThermalTime


  6%|██████████▎                                                                                                                                                           | 6/97 [00:18<04:42,  3.10s/it]

making predictions for model ThermalTime
[245 245]
[201. 280.]
model ThermalTime got a MAE of 39.5
model ThermalTime got an RMSE of 39.75550276376844
model ThermalTime's median error is: 39.5
Ripeness Day: 245.0
Ficus macrophylla
running model ThermalTime


In [51]:
apple_train_df = filtered_plants[(filtered_plants['genus'] == 'Malus') & (filtered_plants['year'] < 2022)]
apple_test_df = filtered_plants[(filtered_plants['genus'] == 'Malus') & (filtered_plants['year'] == 2022)]

Model = utils.load_model('ThermalTime')
model = Model()
model.fit(apple_train_df, weather_training, optimizer_params='practical')

print(model)

print("making predictions for model {m}".format(m='ThermalTime'))        
preds = model.predict(apple_test_df, weather_test)


<pyPhenology.models.thermaltime.ThermalTime object at 0x16af50ed0>
making predictions for model ThermalTime


In [52]:
print(preds)

[245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245
 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245
 245 245 245 245 245 275 245 245 245 275 275 245 275 275 245 245 275 275
 275 245 245 275 245 245 245 245 245 245 245 245 245 245 245 245 245 245
 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245
 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245
 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245
 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 245 275 245
 245 245 245 275 245 245 245 245 245 245 245 245 245 245 275 275 275 245
 245 245 245 245 245 275 275 275 245 245 245 245 245 245 275 245 245 245
 245 245 245 245 245 275 245 245 245 245 245 245 245 245 245 245 245 245
 275 245 245 245 245 245 245 275 245 245 245 245 245 245 245 245 275 275
 275 275 275 245 245 245 245 245 245 245 245 245 245 245 245 245 245 275
 245 245 245 245 245 245 245 245 245 245 245 245 24

In [31]:
specific_test_df = merged_euro[(merged_euro['site_id'].isin(apple_test_df['site_id'])) & (merged_euro['year'] == 2022)]

In [53]:
model.predict(apple_test_df, weather_test)

array([245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245,
       245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245,
       245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245,
       245, 245, 275, 245, 245, 245, 275, 275, 245, 275, 275, 245, 245,
       275, 275, 275, 245, 245, 275, 245, 245, 245, 245, 245, 245, 245,
       245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245,
       245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245,
       245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245,
       245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245,
       245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245,
       245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 275,
       245, 245, 245, 245, 275, 245, 245, 245, 245, 245, 245, 245, 245,
       245, 245, 275, 275, 275, 245, 245, 245, 245, 245, 245, 275, 275,
       275, 245, 245, 245, 245, 245, 245, 275, 245, 245, 245, 24

In [54]:
model.get_params()

{'t1': 185.64310747459587, 'T': 14.178133503486173, 'F': 574.1517057222219}

In [40]:
#print(preds)
test_days = apple_test_df.doy.values
print(test_days)

[313. 355. 334.  63. 191. 273. 270. 203. 231. 217. 252. 238. 224. 280.
 259. 196. 210. 189. 295. 288. 302. 280. 342. 294. 348. 287. 308. 301.
 324. 317. 333. 124. 288. 231. 266. 240. 282. 238. 246. 256. 253. 309.
 309. 289. 321. 323. 304. 303. 288. 324. 297. 328. 316. 310. 302. 294.
 316. 295. 301. 271. 277. 291. 310. 269. 264. 272. 284. 270. 253. 264.
 282. 289. 275. 333. 293. 296. 325. 300. 309. 319. 315. 268.  63. 303.
 173. 172. 243. 250. 257. 278. 285. 292. 264. 271. 311. 273. 342. 318.
 301. 313. 280. 336. 308. 292. 306. 294. 283. 329. 285. 297. 324. 304.
 274. 292. 290. 322. 298. 313. 319. 316. 326. 310. 306. 295. 266. 231.
 222. 138. 250. 236. 323. 252. 236. 182. 186. 186. 186. 203. 295.  85.
 178. 240. 234. 235. 257. 232. 235. 236. 235. 240. 233. 233. 240. 240.
 241. 241. 241. 241. 242. 242. 242. 244. 243. 245. 246. 247. 247. 248.
 248. 249. 248. 249. 250. 250. 251. 251. 251. 251. 215. 214. 216. 220.
 225. 225. 225. 225. 225. 229. 229. 233. 215. 231. 231. 232. 232. 233.
 233. 

In [None]:
genus_prediction_dict = {}

for s in tqdm(filtered_plants['genus'].unique()):
    print(s)
    species_train_df = filtered_plants.query('genus == "{}" and year < {}'.format(s, high_cutoff_year))
    
    if len(species_train_df) == 0:
        continue
    
    species_test_df = filtered_plants.query('genus == "{}" and year >= {}'.format(s, high_cutoff_year))
    
    if len(species_test_df) == 0:
        # make predictions and compare to the mean ripeness day at each site
        species_test_df = make_test_df(species_train_df)
    
    if len(species_test_df) == 0:
        print("No test data for {}".format(s))
        #print(species_test_df)
        
    predictions = train_ripeness_small(species_train_df, weather_training,
                        species_test_df, weather_test)
    
    genus_prediction_dict[s] = predictions