# Testing Gradient Updates for Demographic Stochasticity
When we run the model with arbitrarily chosen stochasticity parameters, eggs hatch too early. Here we develop a method for updating the stochasticity parameters by gradient descent that works within current memory constraints.

In [60]:
%load_ext autoreload
%autoreload 2

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

import SpongyMothIPM.meteorology as met
from SpongyMothIPM.config import Config
import SpongyMothIPM.util as util
import SpongyMothIPM.kernels as kernels
import SpongyMothIPM.visualization as viz

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load Weather Data

In [74]:
# Daily Daymet weather for Mont St-Hilaire, expanded into sub-daily
# temperature samples by the project's diurnal model.
df = met.load_daymet_data('../data/mont_st_hilaire/mont_st_hilaire_1980_1991.csv')

# Diurnal-model settings: presumably the hours of the daily minimum and
# maximum temperature, the sampling interval, and the first sample hour
# (TODO confirm against met.daymet_to_diurnal).
low_time, high_time = 1, 13
sample_period = 4
sample_start_time = 1
# NOTE(review): the trailing 365 is passed positionally; its meaning is
# defined in met.daymet_to_diurnal (possibly a day count) — confirm.
temps = met.daymet_to_diurnal(
    df, low_time, high_time, sample_period, sample_start_time, 365)

# One model step per temperature sample, expressed as a fraction of a day.
config = Config(dtype=torch.float, delta_t=sample_period / 24)

# Whole days covered by `temps`, assuming one row per sample.
days = len(temps) // (24 // sample_period)
# Step size for the gradient-descent updates of the stage sigmas.
learning_rate = 1e-8

## Model Setup

In [None]:
class SimpleModel():
    """Spongy moth life-cycle IPM whose stage ``sigma`` parameters are fit by
    plain gradient descent against an observed cumulative-hatch curve.

    Relies on the notebook-level globals ``config``, ``temps`` and
    ``sample_period``, and, when no explicit step size is given to
    :meth:`update_params`, ``learning_rate``.
    """

    # Inclusive day-of-year window whose first-instar transfers are recorded
    # for the loss. The validation series must cover the same days
    # (np.arange(100, 201) -> 101 values).
    HATCH_FIRST_DAY = 100
    HATCH_LAST_DAY = 200

    def __init__(self):
        # Build life stages (no saving, zero mortality so only timing is fit).
        self.prediapause = kernels.Prediapause(
            config, save=False, save_rate=1, mortality=0)
        self.diapause = kernels.Diapause(
            config, n_bins_I=45, n_bins_D=45, save=False, save_rate=1, mortality=0)
        self.postdiapause = kernels.Postdiapause(
            config, save=False, save_rate=1, mortality=0)
        self.first_instar = kernels.FirstInstar(
            config, save=False, save_rate=1, mortality=0, 
            file_path='memory')
        self.second_instar = kernels.SecondInstar(
            config, save=False, save_rate=1, mortality=0)
        self.third_instar = kernels.ThirdInstar(
            config, save=False, save_rate=1, mortality=0)
        self.fourth_instar = kernels.FourthInstar(
            config, save=False, save_rate=1, mortality=0)
        self.male_late_instar = kernels.MaleFifthInstar(
            config, save=False, save_rate=1, mortality=0)
        self.female_late_instar = kernels.FemaleFifthSixthInstar(
            config, save=False, save_rate=1, mortality=0)
        self.male_pupae = kernels.MalePupae(
            config, save=False, save_rate=1, mortality=0)
        self.female_pupae = kernels.FemalePupae(
            config, save=False, save_rate=1, mortality=0)
        self.adults = kernels.Adult(
            config, save=False, save_rate=1, mortality=0)

    def init_pop(self):
        """Seed the whole population in diapause and reset the hatch record."""
        mu = 0.2
        sigma = 1.1
        total = 1
        empty = 0
        self.prediapause.init_pop(empty, mu, sigma)
        self.diapause.init_pop(total, mu, sigma)
        self.postdiapause.init_pop(empty, mu, sigma)
        self.first_instar.init_pop(empty, mu, sigma)
        self.second_instar.init_pop(empty, mu, sigma)
        self.third_instar.init_pop(empty, mu, sigma)
        self.fourth_instar.init_pop(empty, mu, sigma)
        self.male_late_instar.init_pop(empty, mu, sigma)
        self.female_late_instar.init_pop(empty, mu, sigma)
        self.male_pupae.init_pop(empty, mu, sigma)
        self.female_pupae.init_pop(empty, mu, sigma)
        self.adults.init_pop(empty, mu, sigma)

        # For tracking emerging eggs
        self.hatched = []

    def forward(self):
        """Run every stage through each day of ``temps``.

        Transfers out of the first-instar stage on days
        ``HATCH_FIRST_DAY``..``HATCH_LAST_DAY`` of each year are appended to
        ``self.hatched``.
        """
        steps_per_day = 24 // sample_period
        start_year = temps['year'].min()
        end_year = temps['year'].max()
        start = 0
        for year in range(start_year, end_year + 1):
            print(f"Starting year {year}")
            days = temps.loc[temps['year'] == year, 'yday'].max()
            for day in range(1, days + 1):
                end = start + steps_per_day
                day_temps = temps.iloc[start:end]
                transfers = self.prediapause.run_one_step(day_temps)
                transfers = self.diapause.run_one_step(day_temps, transfers)
                transfers = self.postdiapause.run_one_step(day_temps, transfers)
                transfers = self.first_instar.run_one_step(day_temps, transfers)
                if self.HATCH_FIRST_DAY <= day <= self.HATCH_LAST_DAY:
                    # NOTE(review): this records the first-instar stage's
                    # *output*, not the postdiapause output entering it. If
                    # "hatch" in the validation data means egg hatch, the
                    # postdiapause transfers may be intended — confirm.
                    # NOTE(review): entries accumulate across years, so a
                    # multi-year `temps` yields more points than a one-year
                    # validation series.
                    self.hatched.append(transfers)

                transfers = self.second_instar.run_one_step(day_temps, transfers)
                transfers = self.third_instar.run_one_step(day_temps, transfers)
                transfers_dif = self.fourth_instar.run_one_step(day_temps, transfers)
                # Fourth-instar output splits evenly between the sexes.
                transfers = self.male_late_instar.run_one_step(day_temps, transfers_dif/2)
                to_adult = self.male_pupae.run_one_step(day_temps, transfers)
                transfers = self.female_late_instar.run_one_step(day_temps, transfers_dif/2)
                to_adult += self.female_pupae.run_one_step(day_temps, transfers)
                transfers = self.adults.run_one_step(day_temps, to_adult)
                self.prediapause.add_transfers(transfers/2)

                start = end

    def _named_sigmas(self):
        """Return ``(label, tensor)`` pairs for every trainable sigma, in a fixed order."""
        return [
            ('Prediapause: ', self.prediapause.sigma),
            ('Diapause I: ', self.diapause.sigma_I),
            ('Diapause D: ', self.diapause.sigma_D),
            ('Postdiapause: ', self.postdiapause.sigma),
            ('First Instar: ', self.first_instar.sigma),
            ('Second Instar: ', self.second_instar.sigma),
            ('Third Instar: ', self.third_instar.sigma),
            ('Fourth Instar: ', self.fourth_instar.sigma),
            ('Male Late Instar: ', self.male_late_instar.sigma),
            ('Female Late Instar: ', self.female_late_instar.sigma),
            ('Male Pupae: ', self.male_pupae.sigma),
            ('Female Pupae: ', self.female_pupae.sigma),
            ('Adult: ', self.adults.sigma),
        ]

    def print_params(self):
        """Print each trainable sigma together with its current gradient."""
        for label, param in self._named_sigmas():
            print(label, param, param.grad)

    def update_params(self, validation, lr=None):
        """Take one gradient-descent step on all stage sigmas.

        Args:
            validation: Target cumulative-hatch values, one per recorded day
                (must broadcast against the stacked ``self.hatched``).
            lr: Step size; defaults to the global ``learning_rate``.

        Raises:
            ValueError: If ``forward`` recorded no hatch data.
        """
        if lr is None:
            lr = learning_rate
        if not self.hatched:
            raise ValueError(
                "No hatch data recorded; call init_pop() and forward() first.")

        # Running total of recorded transfers, kept differentiable.
        self.cum_hatched = torch.cumsum(torch.stack(self.hatched), dim=0)

        # Compute loss and gradients
        loss = torch.mean((self.cum_hatched - validation)**2)
        print(loss)
        loss.backward()

        # Use gradients to update trainable parameters in place, so they
        # stay leaf tensors.
        with torch.no_grad():
            self.print_params()
            for _, param in self._named_sigmas():
                # A sigma that did not take part in the graph has grad None;
                # skip it rather than fail with a TypeError.
                if param.grad is None:
                    continue
                param -= param.grad * lr
                param.grad.zero_()


In [55]:
# Observed cumulative egg-hatch fractions (presumably the 1988 season, per
# the file name); the raw table is echoed before processing.
validation = pd.read_csv('../data/mont_st_hilaire/hilaire_88.csv')
print(validation)

# Snap observation dates to whole days, then linearly interpolate onto
# day-of-year 100..200. np.interp holds the first/last observed values
# constant outside the observed range.
validation['doy'] = validation['doy'].round()
validation = torch.tensor(
    np.interp(np.arange(100, 201), validation['doy'], validation['hatch'])
)

           doy     hatch
0   128.000000  0.000000
1   128.994628  0.019048
2   130.002686  0.119048
3   131.006279  0.180952
4   132.000384  0.411905
5   132.994767  0.520238
6   133.991244  0.646429
7   135.000837  0.759524
8   136.003593  0.814286
9   136.981929  0.910714
10  137.998919  0.961905
11  138.981859  0.972619


In [92]:
# Anomaly detection is a debugging aid for backward passes; keep it off for
# normal runs.
torch.autograd.set_detect_anomaly(False)

# Build the model, seed the starting population, simulate, then take one
# gradient-descent step on the stage sigmas against `validation`.
model = SimpleModel()
model.init_pop()
model.forward()
model.update_params(validation)

Starting year 1980
tensor(0.0733, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(1.1000, requires_grad=True) tensor(-15323.6523)
tensor(1.5000, requires_grad=True) tensor(-462749.4375)
tensor(1.5000, requires_grad=True) tensor(-684220.4375)
tensor(1.1000, requires_grad=True) tensor(-4668.1426)
tensor(1.1000, requires_grad=True) tensor(-40865.6875)
tensor(1.1000, requires_grad=True) tensor(-49437.4805)
tensor(1.1000, requires_grad=True) tensor(-27977.7734)
tensor(1.1000, requires_grad=True) tensor(-40237.2031)
tensor(1.1000, requires_grad=True) tensor(-14380.5537)
tensor(1.1000, requires_grad=True) tensor(2289.0955)
tensor(1.1000, requires_grad=True) tensor(-15012.1377)
tensor(1.1000, requires_grad=True) tensor(-40323.7070)
tensor(1.1000, requires_grad=True) tensor(-35134.4102)


In [93]:
# Another gradient step: re-seed the population (init_pop also clears the
# recorded hatch list), re-simulate with the updated sigmas, and update again.
model.init_pop()
model.forward()
model.update_params(validation)

Starting year 1980
tensor(0.0730, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(1.1002, requires_grad=True) tensor(-15364.0391)
tensor(1.5046, requires_grad=True) tensor(-464950.8750)
tensor(1.5068, requires_grad=True) tensor(-688919.1250)
tensor(1.1000, requires_grad=True) tensor(-4667.4243)
tensor(1.1004, requires_grad=True) tensor(-40940.1680)
tensor(1.1005, requires_grad=True) tensor(-49656.8867)
tensor(1.1003, requires_grad=True) tensor(-28050.3809)
tensor(1.1004, requires_grad=True) tensor(-40372.0742)
tensor(1.1001, requires_grad=True) tensor(-14427.1455)
tensor(1.1000, requires_grad=True) tensor(2295.8904)
tensor(1.1002, requires_grad=True) tensor(-15023.6279)
tensor(1.1004, requires_grad=True) tensor(-40374.4648)
tensor(1.1004, requires_grad=True) tensor(-35232.4297)


In [94]:
# Another gradient step: re-seed the population (init_pop also clears the
# recorded hatch list), re-simulate with the updated sigmas, and update again.
model.init_pop()
model.forward()
model.update_params(validation)

Starting year 1980
tensor(0.0726, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(1.1003, requires_grad=True) tensor(-15404.5137)
tensor(1.5093, requires_grad=True) tensor(-467133.7500)
tensor(1.5137, requires_grad=True) tensor(-693539.7500)
tensor(1.1001, requires_grad=True) tensor(-4666.7021)
tensor(1.1008, requires_grad=True) tensor(-41015.0391)
tensor(1.1010, requires_grad=True) tensor(-49875.4141)
tensor(1.1006, requires_grad=True) tensor(-28123.0059)
tensor(1.1008, requires_grad=True) tensor(-40506.8789)
tensor(1.1003, requires_grad=True) tensor(-14473.7920)
tensor(1.1000, requires_grad=True) tensor(2302.7205)
tensor(1.1003, requires_grad=True) tensor(-15035.2354)
tensor(1.1008, requires_grad=True) tensor(-40425.9922)
tensor(1.1007, requires_grad=True) tensor(-35330.6797)


In [83]:
# Inspect the prediapause sigma gradient; None means no gradient is currently
# stored on the tensor.
print(model.prediapause.sigma.grad)

None


  print(model.prediapause.sigma.grad)
