In [77]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

#### Functions representing dz/dt

In [130]:
def ode_torch(t,y):
    '''Example differential equation
        dy/dt = y - 1/2*exp(t/2)*sin(5t) + 5*exp(t/2)*cos(5t)
    Args: t is the time at which the derivative is evaluated (Python number or tensor)
          y is the current state
    Returns: the derivative dy/dt evaluated at (t, y)'''
    # as_tensor wraps plain numbers but hands tensors back untouched, so a tensor
    # time is not copy-constructed (torch.tensor(tensor) warns and detaches).
    t = torch.as_tensor(t)
    growth = torch.exp(t/2) # shared exponential factor of both forcing terms
    return y - 1/2 * growth * torch.sin(5*t) + 5 * growth * torch.cos(5*t)

class mini_net(nn.Module):
    '''Minimal one-parameter dynamics model: dz/dt = sigmoid(W * z).'''

    def __init__(self):
        super(mini_net, self).__init__()
        self.W = nn.Parameter(torch.randn(1)) # single learnable weight, randomly initialised

    def forward(self, t, y):
        '''Evaluate dz/dt (change in hidden state with respect to time).
        Args: t is time (accepted for solver compatibility; not used by these dynamics)
              y is the current hidden state
        Returns: sigmoid(W * y), same shape as the broadcast of W and y'''
        # torch.sigmoid is numerically stable; exp(x)/(exp(x)+1) turns into
        # inf/inf = NaN as soon as W*y is large enough to overflow exp.
        return torch.sigmoid(self.W * y)
    
def sig(x):
    '''Logistic sigmoid 1 / (1 + exp(-x)), applied element-wise.
    Args: x is a tensor
    Returns: tensor of the same shape with values in [0, 1]'''
    # The textbook form exp(x)/(exp(x)+1) evaluates to inf/inf = NaN once exp(x)
    # overflows; torch.sigmoid computes the same function without overflowing.
    return torch.sigmoid(x)

class ode_forward(nn.Module):
    '''Learnable version of the example differential equation
        dy/dt = y - A*exp(t/B)*sin(C*t) + D*exp(t/E)*cos(F*t)
    with six independent, randomly initialised one-element parameters A..F.'''

    def __init__(self):
        super(ode_forward, self).__init__()
        # Creation order (A..F) fixes which random draw each parameter receives
        # after torch.manual_seed, so keep it stable.
        self.A = nn.Parameter(torch.randn(1)) # amplitude of the sine term
        self.B = nn.Parameter(torch.randn(1)) # time scale of the sine term's exponential
        self.C = nn.Parameter(torch.randn(1)) # frequency of the sine term
        self.D = nn.Parameter(torch.randn(1)) # amplitude of the cosine term
        self.E = nn.Parameter(torch.randn(1)) # time scale of the cosine term's exponential
        self.F = nn.Parameter(torch.randn(1)) # frequency of the cosine term

    def forward(self,t,y):
        '''Evaluate the derivative dy/dt.
        Args: t is the time at which the derivative is evaluated (Python number or tensor)
              y is the current state
        Returns: dy/dt evaluated at (t, y)'''
        # as_tensor wraps plain numbers but hands tensors back untouched, so a tensor
        # time is not copy-constructed (torch.tensor(tensor) warns and detaches).
        t = torch.as_tensor(t)
        sine_term = self.A * torch.exp(t/self.B) * torch.sin(self.C*t)
        cosine_term = self.D*torch.exp(t/self.E)*torch.cos(self.F*t)
        return y - sine_term + cosine_term

# f = ax + bt
class diff_eq(nn.Module):
    '''Linear dynamics f(t, z) = a*z + b*t, i.e. the change in hidden state with respect to time.'''

    def __init__(self):
        super().__init__()

        # Both coefficients start from fixed values (a = 3, b = 1), not random draws.
        self.a = nn.Parameter(torch.tensor([3.]))
        self.b = nn.Parameter(torch.tensor([1.]))

    def forward(self, t, z):
        '''Evaluate dz/dt at time t and hidden state z.'''
        state_term = self.a * z
        time_term = self.b * t
        return state_term + time_term

#### Solvers

In [134]:
def rk(fun, t, y0):
    '''Function performs RK4 algorithm
        Args: fun is a function fun(t, y) representing the derivative of the state with respect to time
              t is a 1-D sequence (list or tensor) of times to obtain estimates for
              y0 is the initial condition of the state, i.e. the state at t[0]
        Return: tensor representing the state at each time, concatenated along dim 0
                (y0 comes first, then one entry per subsequent time)'''
    states = [y0]
    current = y0
    # Walk consecutive time pairs so every interval uses its own step size;
    # a single global h = t[1]-t[0] is only right for uniformly spaced grids.
    for t_now, t_next in zip(t[:-1], t[1:]):
        current = rk_step(fun, t=t_now, y=current, h=t_next - t_now)
        states.append(current)
    return torch.cat(states,dim=0)

def rk_step(fun, t, y, h):
    '''Function takes a single step in RK4 algorithm
        Args: fun is a function fun(t, y) representing the derivative of the state with respect to time
              t is the time at the start of the step; it is only shifted (t + h/2, t + h)
                  and forwarded to fun, so anything fun accepts works here
              y is the state at the start of the step (any shape, dtype and device)
              h is the step size (may be negative to integrate backwards in time)
        Return: the state after the step, with the same shape as y'''
    k1 = h * fun(t     , y)
    k2 = h * fun(t+.5*h, y+.5*k1)
    k3 = h * fun(t+.5*h, y+.5*k2)
    k4 = h * fun(t+   h, y+   k3)

    # Classic RK4 weights (1/6, 1/3, 1/3, 1/6) applied element-wise. Summing the
    # increments directly avoids concatenating along dim 1 with a CPU float32
    # weight tensor, which restricted the step to 2-D CPU states.
    return y + (k1 + 2*(k2 + k3) + k4) / 6

def euler(fun, t_span, y0, h, states = None):
    '''Function to compute numerical approximation of ODE with fixed-step forward Euler.
    Args: - fun is the derivative function; it is called as fun(t, y) when states is None,
            otherwise as fun(t, y, states[-i], h=h) on the i-th step (i starts at 1)
          - t_span is a two-element sequence holding the start and end times
          - y0 is the initial state (a tensor), i.e. the state at t_span[0]
          - h is the step size (fixed, in this case)
          - states is an optional sequence that is read from the back, one entry per step
    Returns: tuple (tensor of states stacked along dim 0, tensor of the matching times);
             both include the initial point'''
    # The tiny tolerance guards against floating-point truncation: 0.3/0.1 is
    # 2.9999999999999996, which a bare int() would turn into 2 steps instead of 3.
    n_steps = int((t_span[1]-t_span[0])/h + 1e-9)
    y = y0 # set state to initial state
    t = t_span[0]
    all_y = [y0] # recorded states, starting with the initial one
    all_t = [t_span[0]] # recorded times, starting with the initial one
    for i in range(1, n_steps+1):
        if states is None:
            m = fun(t,y) # slope at the current point
        else:
            m = fun(t,y,states[-i], h=h) # slope given the i-th entry from the end of states
        y = y + h*m # approximate next value
        t = t + h
        all_y.append(y) # record values
        all_t.append(t)
    return(torch.stack(all_y, dim=0), torch.tensor(all_t))

## Adjoint method

In [135]:
class aug_dynamics(nn.Module):
    '''Class representing augmented dynamics of the system. Forward dynamics are governed by
        the wrapped model, dz/dt = model(t, z). To update parameters, we require the derivative
        of the loss with respect to the hidden state, dL/dz (i.e., the adjoint). To compute the
        adjoint and the derivative of the loss with respect to the parameters, we use the ODE
        solver with the augmented dynamics.'''

    def __init__(self, model, len_hidden_state):
        super(aug_dynamics, self).__init__()

        self.model = model # model to represent the dynamics of hidden state (i.e. dz/dt)
        self.len_hidden_state = len_hidden_state # number of rows taken by a (and by z) in aug_state

    def forward(self, t, aug_state):
        '''Forward pass of the dynamics. That is, evaluate the derivatives.
            Args: t is an indexable pair of times; only t[0] is used to evaluate the model
                  aug_state is a column tensor stacking [a, z, dL/dtheta], i.e. adjoint state,
                      hidden state and accumulated parameter gradients
            Returns: column tensor stacking [da/dt, dz/dt, d(dL/dtheta)/dt]; the parameter
                      part holds every parameter flattened, in self.model.parameters() order'''
        n = self.len_hidden_state
        a = aug_state[:n,:].detach() # the adjoint only acts as the vector in the vector-Jacobian product
        z = aug_state[n:2*n,:].detach().clone().requires_grad_(True) # fresh leaf so z.grad can be read

        dz_dt = self.model(t[0], z) # evaluate function

        self.model.zero_grad() # zero gradients in model
        dz_dt.backward(a) # fills z.grad with a^T df/dz and each p.grad with a^T df/dtheta

        da_dt = -z.grad
        # Flatten every parameter gradient into one row. A parameter the model did not
        # use has grad None and contributes zeros; multi-element parameters are kept
        # whole instead of being forced through a scalar conversion.
        grad_rows = [
            (torch.zeros_like(p) if p.grad is None else p.grad).reshape(1, -1)
            for p in self.model.parameters()
        ]
        if grad_rows:
            da_theta_dt = -torch.cat(grad_rows, dim=1)
        else:
            da_theta_dt = aug_state.new_zeros((1, 0)) # parameter-free model: nothing to append

        return torch.cat([da_dt, dz_dt.detach(), da_theta_dt], dim=1).transpose(0,1)
    
def get_loss_and_adjoint(loss_fn, y, y_forward):
    '''Evaluate the loss at the final time and the adjoint that seeds the backward pass.
        Args: loss_fn is called as loss_fn(y, z1) and must return a scalar tensor
              y is ground truth
              y_forward holds the outputs of the forward pass, one row per time step
        Returns: loss is a scalar
                 a is the adjoint at the final time, i.e. the derivative of the loss
                     with respect to the final hidden state (same shape as y_forward[-1:,:])'''
    # Cut the last row out of the solver's graph and make it a leaf, so the only
    # gradient produced below is dL/dz at the final time.
    final_state = y_forward[-1:,:]
    z1 = final_state.detach().clone().requires_grad_(True)

    loss = loss_fn(y, z1)
    loss.backward() # populates z1.grad

    return loss, z1.grad

def my_backward(func, t, y_forward, a):
    '''Custom backward function to set parameter gradients using the adjoint method.
        Integrates the augmented state [a, z, dL/dtheta] backwards in time with rk_step
        and writes the resulting dL/dtheta into p.grad for every parameter of func.
        Args: func is the module representing derivative of hidden state with respect to time
              t is a 1-D tensor of the function evaluation times of the forward pass
              y_forward is a tensor representing the hidden states at each time in the forward pass
              a is the adjoint at the final time (dL/dz at t[-1]), one row
        Returns: None; the gradients are stored on func's parameters'''
    t = torch.flip(t,dims=(0,)) # flip time and y_forward along time axis
    # The stored states are constants for the backward solve. Detaching them keeps the
    # solve from building an autograd graph and leaves the final p.grad without a grad_fn.
    z = torch.flip(y_forward.detach(),dims=(0,))
    n = len(a)
    params = list(func.parameters())
    dyn = aug_dynamics(model=func, len_hidden_state = n) # define hidden dynamics function

    # Initialize augmented state: adjoint, final hidden state, one zero per parameter element
    grad_init = torch.zeros((1, sum(p.numel() for p in params)), dtype=a.dtype, device=a.device)
    aug_state = torch.cat([a, z[0:1], grad_init],dim=1).transpose(0,1)

    for i in range(len(t) - 1): # Traverse time steps and hidden states in reverse
        h_ = t[i+1] - t[i] # get step size (negative, since time runs backwards)
        aug_state[n:2*n] = z[i:i+1] # reset hidden state to the one stored by the forward pass

        # Go to next time step, using ODE solver
        aug_state = rk_step(fun=dyn.forward, t = torch.stack((t[i], t[i+1])), y=aug_state, h=h_)

    # store gradients: rows after [a, z] hold the flattened parameter gradients in order
    offset = 2*n
    for p in params:
        count = p.numel()
        p.grad = aug_state[offset:offset+count].detach().reshape(p.shape).to(p.dtype).clone()
        offset += count

## Compare normal backprop with custom adjoint and torchdiffeq adjoint

#### One-dimensional hidden state

In [136]:
# Compare three ways of getting parameter gradients for the linear dynamics diff_eq:
# (1) backprop through the RK4 solver, (2) the custom adjoint, (3) torchdiffeq's adjoint.
# Every variant takes one SGD step from the same initial parameters; the printed
# parameters should agree closely.
loss_fn = nn.MSELoss()
torch.manual_seed(0) # seed before sampling the target so the cell reproduces on a fresh kernel
y = torch.randn(1) # regression target for the final hidden state

# Parameters for Runge-Kutta
t_span = [1.,2.] # start/end times 
h=.01 # step size
t = torch.tensor(list(np.concatenate([np.arange(*t_span, h), [t_span[-1]]]))) # get evaluation times
lr = .3 # set learning rate for optimizer

# NOTE(review): imports belong in the notebook's first cell
import torch.optim as optim
f = ode_forward() # NOTE(review): not used in this cell; kept in case a later cell reads `f`

################ backpropping through solver ##############
torch.manual_seed(0)
diffeq = diff_eq() # dz/dt model; swap in ode_forward() or mini_net() to try other dynamics
solver = rk
optimizer = optim.SGD(diffeq.parameters(), lr=lr)

y_forward = solver(fun=diffeq, t=t, y0=torch.tensor([[2.]])) # Forward pass
loss = loss_fn(y_forward[-1], y) # loss on the final hidden state only
loss.backward() # autograd differentiates through every solver step
optimizer.step()

print('\nUpdated parameters (backprop through solver)')
print(list(diffeq.named_parameters()))


################ using custom adjoint ##############
torch.manual_seed(0)
diffeq = diff_eq() # fresh model so every variant starts from the same parameters
solver = rk
optimizer = optim.SGD(diffeq.parameters(), lr=lr)

diffeq.zero_grad()
y_forward = solver(fun=diffeq, t=t, y0=torch.tensor([[2.]])) # Forward pass
loss, a = get_loss_and_adjoint(loss_fn, y, y_forward) # loss and dL/dz at the final time
my_backward(diffeq, t, y_forward, a) # writes the adjoint gradients into p.grad
optimizer.step()

print('\nUpdated parameters (custom adjoint)')
print(list(diffeq.named_parameters()))


################ using torchdiffeq package ##############
from torchdiffeq import odeint_adjoint, odeint
method = 'rk4'
torch.manual_seed(0)    
diffeq = diff_eq() # fresh model so every variant starts from the same parameters
optimizer = optim.SGD(diffeq.parameters(), lr=lr)

yhat = odeint_adjoint(func=diffeq, y0=torch.tensor([2.]), t = t, method=method)
loss = loss_fn(yhat[-1], y)
loss.backward()
optimizer.step()

print('\nUpdated parameters (torchdiffeq adjoint)')
print(list(diffeq.named_parameters()))


Updated parameters (backprop through solver)
[('a', Parameter containing:
tensor([-1332.2523], requires_grad=True)), ('b', Parameter containing:
tensor([-237.2314], requires_grad=True))]

Updated parameters (custom adjoint)
[('a', Parameter containing:
tensor([-1332.2555], requires_grad=True)), ('b', Parameter containing:
tensor([-237.2320], requires_grad=True))]

Updated parameters (torchdiffeq adjoint)
[('a', Parameter containing:
tensor([-1332.2605], requires_grad=True)), ('b', Parameter containing:
tensor([-237.2326], requires_grad=True))]


#### One-dimensional hidden state with six-parameter dynamics (`ode_forward`)

In [99]:
# Same three-way gradient comparison as for diff_eq, but with the six-parameter
# ode_forward dynamics: (1) backprop through the RK4 solver, (2) the custom adjoint,
# (3) torchdiffeq's adjoint. Every variant takes one SGD step from the same seed.
loss_fn = nn.MSELoss()
torch.manual_seed(0) # seed before sampling the target so the cell reproduces on its own
y = torch.randn(1) # regression target for the final hidden state

# Parameters for Runge-Kutta
t_span = [1.,2.] # start/end times 
h=.01 # step size
t = torch.tensor(list(np.concatenate([np.arange(*t_span, h), [t_span[-1]]]))) # get evaluation times
lr = .3 # set learning rate for optimizer

# NOTE(review): imports belong in the notebook's first cell
import torch.optim as optim
f = ode_forward() # NOTE(review): not used in this cell; kept in case a later cell reads `f`

################ backpropping through solver ##############
torch.manual_seed(0)
diffeq = ode_forward() # module representing derivative of hidden state with respect to time

solver = rk
optimizer = optim.SGD(diffeq.parameters(), lr=lr)

diffeq.zero_grad()
y_forward = solver(fun=diffeq, t=t, y0=torch.tensor([[2.]])) # Forward pass

loss = loss_fn(y_forward[-1], y) # loss on the final hidden state only
loss.backward() # autograd differentiates through every solver step

optimizer.step()

print('\nUpdated parameters (backprop through solver)')
print(list(diffeq.named_parameters()))


################ using custom adjoint ##############
torch.manual_seed(0)
diffeq = ode_forward() # re-created after re-seeding so the initial parameters match
solver = rk
optimizer = optim.SGD(diffeq.parameters(), lr=lr)

#### ADJOINT METHOD ####
diffeq.zero_grad()
y_forward = solver(fun=diffeq, t=t, y0=torch.tensor([[2.]])) # Forward pass

loss, a = get_loss_and_adjoint(loss_fn, y, y_forward) # loss and dL/dz at the final time
my_backward(diffeq, t, y_forward, a) # writes the adjoint gradients into p.grad

optimizer.step()

print('\nUpdated parameters (custom adjoint)')
print(list(diffeq.named_parameters()))


################ using torchdiffeq package ##############
from torchdiffeq import odeint_adjoint, odeint
method = 'rk4'

torch.manual_seed(0)    
diffeq = ode_forward() # re-created after re-seeding so the initial parameters match
optimizer = optim.SGD(diffeq.parameters(), lr=lr)
yhat = odeint_adjoint(func=diffeq, y0=torch.tensor([2.]), t = t, method=method)
loss = loss_fn(yhat[-1], y)
loss.backward()

optimizer.step()

print('\nUpdated parameters (torchdiffeq adjoint)')
print(list(diffeq.named_parameters()))