# Testing Gradient Updates for Demographic Stochasticity
When the model is run with arbitrarily chosen stochasticity parameters, eggs hatch too early. Here we build a method for tuning the stochasticity parameters by gradient descent that works within current memory constraints.

In [1]:
%load_ext autoreload
%autoreload 2

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

import SpongyMothIPM.meteorology as met
from SpongyMothIPM.config import Config
import SpongyMothIPM.util as util
import SpongyMothIPM.kernels as kernels
import SpongyMothIPM.visualization as viz

## Load Weather Data

In [28]:
# Sampling settings passed to met.daymet_to_diurnal (all in hours; low_time /
# high_time are presumably the hour-of-day of the daily minimum / maximum
# temperature -- confirm against the meteorology module).
low_time, high_time = 1, 13
sample_period = 4
sample_start_time = 1

df = met.load_daymet_data('../data/mont_st_hilaire/mont_st_hilaire_1980_1991.csv')
temps = met.daymet_to_diurnal(
    df, low_time, high_time, sample_period, sample_start_time, 365)

# Model time step is one sample period, expressed as a fraction of a day.
config = Config(dtype=torch.float, delta_t=sample_period / 24)

# Whole days covered by the series: 24 // sample_period samples per day.
days = len(temps) // (24 // sample_period)
learning_rate = 0.001

## Model Setup

In [None]:
class SimpleModel:
    """Chain of life-stage kernels run day by day and tuned by gradient descent.

    Relies on notebook globals defined in earlier cells: ``config`` (passed to
    every kernel), ``temps`` and ``sample_period`` (read by ``forward``) and
    ``learning_rate`` (default step size for ``update_params``).
    """

    def __init__(self):
        """Build every life-stage kernel with save=False, save_rate=1, mortality=0."""
        self.prediapause = kernels.Prediapause(
            config, save=False, save_rate=1, mortality=0)
        self.diapause = kernels.Diapause(
            config, n_bins_I=45, n_bins_D=45, save=False, save_rate=1, mortality=0)
        self.postdiapause = kernels.Postdiapause(
            config, save=False, save_rate=1, mortality=0)
        # NOTE(review): file_path='memory' is only used for this stage; its
        # meaning is defined in kernels.FirstInstar -- confirm there.
        self.first_instar = kernels.FirstInstar(
            config, save=False, save_rate=1, mortality=0,
            file_path='memory')
        self.second_instar = kernels.SecondInstar(
            config, save=False, save_rate=1, mortality=0)
        self.third_instar = kernels.ThirdInstar(
            config, save=False, save_rate=1, mortality=0)
        self.fourth_instar = kernels.FourthInstar(
            config, save=False, save_rate=1, mortality=0)
        self.male_late_instar = kernels.MaleFifthInstar(
            config, save=False, save_rate=1, mortality=0)
        self.female_late_instar = kernels.FemaleFifthSixthInstar(
            config, save=False, save_rate=1, mortality=0)
        self.male_pupae = kernels.MalePupae(
            config, save=False, save_rate=1, mortality=0)
        self.female_pupae = kernels.FemalePupae(
            config, save=False, save_rate=1, mortality=0)
        self.adults = kernels.Adult(
            config, save=False, save_rate=1, mortality=0)

    def init_pop(self):
        """Reset all stage populations and the hatch record.

        The diapause stage starts with total=1; every other stage starts
        with 0. ``mu``/``sigma`` are forwarded to each stage's ``init_pop``
        (presumably the parameters of the initial distribution).
        """
        mu = 0.2
        sigma = 1.1
        total = 1
        empty = 0
        self.prediapause.init_pop(empty, mu, sigma)
        self.diapause.init_pop(total, mu, sigma)
        self.postdiapause.init_pop(empty, mu, sigma)
        self.first_instar.init_pop(empty, mu, sigma)
        self.second_instar.init_pop(empty, mu, sigma)
        self.third_instar.init_pop(empty, mu, sigma)
        self.fourth_instar.init_pop(empty, mu, sigma)
        self.male_late_instar.init_pop(empty, mu, sigma)
        self.female_late_instar.init_pop(empty, mu, sigma)
        self.male_pupae.init_pop(empty, mu, sigma)
        self.female_pupae.init_pop(empty, mu, sigma)
        self.adults.init_pop(empty, mu, sigma)

        # Per-day record consumed by update_params().
        self.hatched = []

    def forward(self, record_year=None):
        """Step every life stage through ``temps`` one day at a time.

        Each simulated day consumes ``24 // sample_period`` consecutive rows
        of ``temps``. For days 100-200 (inclusive) the output of
        ``first_instar.run_one_step`` is appended to ``self.hatched``.

        Args:
            record_year: If given, only days of this year are recorded. The
                default (None) records days 100-200 of *every* simulated
                year, as before -- note that the validation data loaded in
                this notebook covers a single year.
        """
        steps_per_day = 24 // sample_period
        start_year = temps['year'].min()
        end_year = temps['year'].max()
        start = 0
        for year in range(start_year, end_year + 1):
            print(f"Starting year {year}")
            record = record_year is None or year == record_year
            n_days = temps.loc[temps['year'] == year, 'yday'].max()
            for day in range(1, n_days + 1):
                end = start + steps_per_day
                day_temps = temps.iloc[start:end]
                transfers = self.prediapause.run_one_step(day_temps)
                transfers = self.diapause.run_one_step(day_temps, transfers)
                transfers = self.postdiapause.run_one_step(day_temps, transfers)
                transfers = self.first_instar.run_one_step(day_temps, transfers)
                # NOTE(review): this records what first_instar.run_one_step
                # returns, not the postdiapause output feeding into it. If
                # "hatched" means egg hatch, confirm which quantity is meant.
                if record and 100 <= day <= 200:
                    self.hatched.append(transfers)

                transfers = self.second_instar.run_one_step(day_temps, transfers)
                transfers = self.third_instar.run_one_step(day_temps, transfers)
                transfers_dif = self.fourth_instar.run_one_step(day_temps, transfers)
                # Fourth-instar output is split evenly between the male and
                # female branches.
                transfers = self.male_late_instar.run_one_step(day_temps, transfers_dif/2)
                to_adult = self.male_pupae.run_one_step(day_temps, transfers)
                transfers = self.female_late_instar.run_one_step(day_temps, transfers_dif/2)
                # Out-of-place add: an in-place `+=` would mutate the tensor
                # returned by male_pupae, which autograd or the kernel itself
                # may still reference.
                to_adult = to_adult + self.female_pupae.run_one_step(day_temps, transfers)
                transfers = self.adults.run_one_step(day_temps, to_adult)
                # Half of the adult output is fed back into prediapause.
                self.prediapause.add_transfers(transfers/2)

                start = end

    def _trainable_params(self):
        """Return the sigma tensors tuned by update_params, in a fixed order."""
        single_sigma_stages = (
            self.postdiapause,
            self.first_instar,
            self.second_instar,
            self.third_instar,
            self.fourth_instar,
            self.male_late_instar,
            self.female_late_instar,
            self.male_pupae,
            self.female_pupae,
            self.adults,
        )
        return [
            self.prediapause.sigma,
            self.diapause.sigma_I,
            self.diapause.sigma_D,
            *(stage.sigma for stage in single_sigma_stages),
        ]

    def update_params(self, validation, lr=None):
        """Take one gradient-descent step on every stage's sigma parameter(s).

        Computes the mean squared error between the cumulative sum of
        ``self.hatched`` and ``validation``, backpropagates, and updates each
        sigma tensor *in place* so it stays a leaf tensor and keeps receiving
        gradients on later calls. Parameters that received no gradient are
        left unchanged. Gradients are cleared after the step.

        Args:
            validation: Observed values to compare against, one per recorded
                day; converted to the dtype/device of the model output.
            lr: Step size; defaults to the notebook-level ``learning_rate``.

        Raises:
            RuntimeError: If nothing has been recorded in ``self.hatched``.
        """
        if lr is None:
            lr = learning_rate
        if not self.hatched:
            raise RuntimeError(
                "update_params() called before forward() recorded any hatch data")

        # Cumulative abundance at each recorded time point. A single cumsum
        # node keeps the autograd graph smaller than a chain of additions.
        self.cum_hatched = torch.cumsum(torch.stack(self.hatched), dim=0)
        validation = torch.as_tensor(
            validation,
            dtype=self.cum_hatched.dtype,
            device=self.cum_hatched.device)
        # NOTE(review): this relies on broadcasting cum_hatched against
        # validation; confirm the recorded transfers' shape matches.

        loss = torch.mean((self.cum_hatched - validation)**2)
        print(loss)
        loss.backward()

        print(self.prediapause.sigma, self.prediapause.sigma.grad)
        with torch.no_grad():
            for param in self._trainable_params():
                if param.grad is None:
                    # Parameter did not contribute to the loss; skip it,
                    # as torch.optim optimizers do.
                    continue
                param.sub_(param.grad * lr)
                param.grad = None
        print(self.prediapause.sigma)


In [44]:
validation = pd.read_csv('../data/mont_st_hilaire/hilaire_88.csv')
print(validation)
# Snap observation days to whole days, then resample the observed hatch
# curve onto every day from 100 to 200 inclusive (the window recorded by
# the model) with linear interpolation.
validation['doy'] = validation['doy'].round()
validation = torch.tensor(
    np.interp(np.arange(100, 201), validation['doy'], validation['hatch'])
)

           doy     hatch
0   128.000000  0.000000
1   128.994628  0.019048
2   130.002686  0.119048
3   131.006279  0.180952
4   132.000384  0.411905
5   132.994767  0.520238
6   133.991244  0.646429
7   135.000837  0.759524
8   136.003593  0.814286
9   136.981929  0.910714
10  137.998919  0.961905
11  138.981859  0.972619


In [50]:
# Anomaly detection is a debugging aid that slows execution (per the PyTorch
# docs). Scope it to this cell with the context-manager form so it is switched
# off again afterwards, even when backward raises, instead of staying enabled
# for the rest of the kernel session.
with torch.autograd.set_detect_anomaly(True):
    model = SimpleModel()
    model.init_pop()
    # .grad is None until a backward pass reaches the tensor.
    print(model.first_instar.sigma.grad)
    model.forward()
    model.update_params(validation)

None
Starting year 1980


  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "c:\Users\406260\AppData\Local\miniforge3\envs\SpongyMothIPM\Lib\site-packages\ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "c:\Users\406260\AppData\Local\miniforge3\envs\SpongyMothIPM\Lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
    app.start()
  File "c:\Users\406260\AppData\Local\miniforge3\envs\SpongyMothIPM\Lib\site-packages\ipykernel\kernelapp.py", line 739, in start
    self.io_loop.start()
  File "c:\Users\406260\AppData\Local\miniforge3\envs\SpongyMothIPM\Lib\site-packages\tornado\platform\asyncio.py", line 211, in start
    self.asyncio_loop.run_forever()
  File "c:\Users\406260\AppData\Local\miniforge3\envs\SpongyMothIPM\Lib\asyncio\base_events.py", line 645, in run_forever
    self._run_once()
  File "c:\Users\406260\AppData\Local\miniforge3\envs\SpongyMothIPM\Lib\asyncio\base_events.py", li

RuntimeError: Function 'ErfBackward0' returned nan values in its 0th output.

In [48]:
# Display the gradient accumulated on first_instar.sigma (None until a
# backward pass reaches it).
model.first_instar.sigma.grad

  model.first_instar.sigma.grad
