In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

#### Functions representing dz/dt

In [2]:
# f(t, z) = a*z + b*t
class diff_eq(nn.Module):
    '''Linear toy dynamics for the hidden state: dz/dt = a*z + b*t.

        a and b are learnable one-element parameters, initialised to 3 and 1.'''

    def __init__(self):
        super().__init__()

        # Registration order (a, then b) determines the order of .parameters()
        self.a = nn.Parameter(torch.tensor([3.]))
        self.b = nn.Parameter(torch.tensor([1.]))

    def forward(self, t, z):
        '''Evaluate dz/dt at time t for hidden state z.'''
        state_term = self.a * z
        time_term = self.b * t
        return state_term + time_term

class mini_net(nn.Module):
    '''Single-weight dynamics: dz/dt = sig(W * y), with W a learnable one-element parameter.'''

    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.randn(1))

    def forward(self, t, y):
        '''Return dz/dt for hidden state y. The time argument t is accepted but not used.'''
        scaled = self.W * y
        return sig(scaled)
    
class small_net(nn.Module):
    '''Two-matrix dynamics: dz/dt = Y @ sig(W @ y), with W and Y learnable (dim x dim) matrices.'''

    def __init__(self, dim=5):
        super().__init__()
        shape = (dim, dim)
        # W is drawn before Y, which fixes both the RNG consumption and the parameter order
        self.W = nn.Parameter(torch.randn(*shape))
        self.Y = nn.Parameter(torch.randn(*shape))

    def forward(self, t, y):
        '''Return dz/dt for hidden state y. The time argument t is accepted but not used.'''
        hidden = sig(self.W @ y)
        return self.Y @ hidden
    
def sig(x):
    '''Element-wise logistic sigmoid, 1 / (1 + exp(-x)).

        Delegates to torch.sigmoid so the result stays finite for large |x|.
        The direct formula exp(x) / (exp(x) + 1) evaluates to inf / inf = nan
        once exp(x) overflows.
        Args: x is a tensor
        Return: tensor of the same shape with values in [0, 1]'''
    return torch.sigmoid(x)

class ode_forward(nn.Module):
    '''Example right-hand side with six learnable one-element parameters A..F:

            f(t, y) = y - A*exp(t/B)*sin(C*t) + D*exp(t/E)*cos(F*t)'''

    def __init__(self):
        super().__init__()
        # Drawn in this order from the global torch RNG; the order also fixes .parameters()
        self.A = nn.Parameter(torch.randn(1))
        self.B = nn.Parameter(torch.randn(1))
        self.C = nn.Parameter(torch.randn(1))
        self.D = nn.Parameter(torch.randn(1))
        self.E = nn.Parameter(torch.randn(1))
        self.F = nn.Parameter(torch.randn(1))

    def forward(self, t, y):
        '''Evaluate the example differential equation.
        Args: t is time (Python number or tensor)
              y is the current value of the state
        Return: tensor f(t, y)'''
        # as_tensor converts plain numbers but reuses an existing tensor, so no
        # copy-construct warning is raised; detach() keeps t out of the autograd
        # graph, exactly as the previous torch.tensor(t) copy did.
        t = torch.as_tensor(t).detach()
        decay_sin = self.A * torch.exp(t / self.B) * torch.sin(self.C * t)
        decay_cos = self.D * torch.exp(t / self.E) * torch.cos(self.F * t)
        return y - decay_sin + decay_cos

#### Solvers

In [3]:
def rk(fun, t, y0):
    '''Function performs RK4 algorithm over a grid of times
        Args: fun is a function fun(t, y) representing the derivative of the state with respect to time
              t is an indexable sequence (list or 1-D tensor) of times to obtain estimates for
              y0 is the initial condition of the state, i.e. the state at t[0]
        Return: list with the state at each time in t (first entry is y0)'''
    y = [y0]
    y_ = y0
    for i in range(len(t) - 1):
        # Each interval gets its own width, so non-uniform grids are integrated
        # correctly (a fixed t[1]-t[0] is only right for evenly spaced times).
        h = t[i+1] - t[i]
        y_ = rk_step(fun, t=t[i], y=y_, h=h)
        y.append(y_)
    return y

def rk_step(fun, t, y, h):
    '''Function takes a single step in RK4 algorithm
        Args: fun is a function fun(t, y) representing the derivative of the state with respect to time
              t is the time at the start of the step; it is handed to fun as t, t+h/2 and t+h
              y is the state at the start of the step
              h is the step size (may be negative to step backwards in time)
        Return: the state after the step, y + (k1 + 2*k2 + 2*k3 + k4)/6'''
    k1 = h * fun(t     , y)
    k2 = h * fun(t+.5*h, y+.5*k1)
    k3 = h * fun(t+.5*h, y+.5*k2)
    k4 = h * fun(t+   h, y+   k3)

    # Combine the slopes with plain arithmetic rather than a separately created
    # float32 weight tensor: the result keeps the dtype and device of y, and
    # the weights are not rounded to float32 when y is float64.
    return y + (k1 + 2*k2 + 2*k3 + k4) / 6

## Adjoint method

In [4]:
class aug_dynamics(nn.Module):
    '''Class representing augmented dynamics of the system. Forward dynamics are governed by
        dz/dt = model(t, z). To update parameters, we require the derivative of the loss with
        respect to the hidden state, dL/dz (i.e., the adjoint). To compute the adjoint and the
        derivative of the loss with respect to the parameters, we use the ODE solver with the
        augmented dynamics.'''

    def __init__(self, model, init_hidden_state):
        super().__init__()

        self.model = model # model to represent the dynamics of hidden state (i.e. dz/dt)
        self.init_hidden_state = init_hidden_state # only its shape/numel are used, to unpack the augmented state

    def forward(self, t, aug_state):
        '''Forward pass of the dynamics. That is, evaluate the derivatives.
            Args: t is an indexable time argument; the model is evaluated at t[0]
                  aug_state is the flat tensor [a, z, theta_grad], i.e. adjoint state, hidden
                      state and accumulated parameter gradients
            Returns: flat tensor [da/dt, dz/dt, d(theta_grad)/dt] in the same layout,
                      detached from the autograd graph'''
        params = list(self.model.parameters())
        a, z, _ = unpack_aug_state(aug_state, self.init_hidden_state, params)
        a = a.detach()
        z = z.detach().requires_grad_()

        # Vector-Jacobian products a^T df/dz and a^T df/dtheta. autograd.grad is
        # used instead of .backward() so the model's .grad fields are left
        # untouched, and allow_unused tolerates inputs f does not depend on.
        with torch.enable_grad():
            dz_dt = self.model(t[0], z) # evaluate function
            wanted = [z] + [p for p in params if p.requires_grad]
            vjps = iter(torch.autograd.grad(dz_dt, wanted, grad_outputs=a, allow_unused=True))

        vjp_z = next(vjps)
        da_dt = torch.zeros_like(z) if vjp_z is None else -vjp_z

        da_theta_dt = []
        for p in params:
            vjp_p = next(vjps) if p.requires_grad else None
            # A parameter that does not influence dz/dt contributes a zero derivative
            da_theta_dt.append(torch.zeros_like(p) if vjp_p is None else -vjp_p)

        return pack_aug_state(da_dt, dz_dt.detach(), da_theta_dt)
    
def get_loss_and_adjoint(loss_fn, y, y_forward):
    '''Evaluate the loss at the final time and the adjoint used to start the backward pass
        Args: loss_fn is a callable invoked as loss_fn(y, final_state) that returns a scalar
              y is ground truth
              y_forward holds the outputs of the forward pass; only the last entry is used
        Returns: loss is a scalar tensor
                 a is the adjoint at the final time – i.e., the derivative of the loss
                     with respect to the final hidden state'''
    # Cut the final state out of the solver's graph and turn it into a fresh leaf,
    # so that backward() stops at this tensor instead of unrolling the solver.
    final_state = y_forward[-1].detach().clone()
    final_state.requires_grad_(True)

    loss = loss_fn(y, final_state)
    loss.backward() # fills final_state.grad with dL/dz at t = t_1

    return loss, final_state.grad

def my_backward(func, t, y_forward, a):
    '''Custom backward function to compute parameter gradients using adjoint method
        Args: func is the module representing derivative of hidden state with respect to time
              t is a 1-D tensor representing the function evaluation times for the forward pass
              y_forward is a list (or stacked tensor) with the hidden state at each time in the forward pass
              a is the adjoint at the final time, dL/dz(t_1)
        Side effect: sets p.grad for every parameter p of func; nothing is returned'''
    t = torch.flip(t,dims=(0,)) # flip time and y_forward along time axis
    # Detach every stored state: the adjoint pass must not hold on to (or extend)
    # the autograd graph built by the forward solver. Iterating also lets a
    # stacked tensor be used where a list was required before.
    z = [state.detach() for state in y_forward][::-1]
    params = list(func.parameters())
    dyn = aug_dynamics(model=func, init_hidden_state = z[0]) # define hidden dynamics function

    theta_grad = [torch.zeros_like(p) for p in params] # gradient accumulators, one per parameter
    aug_state = pack_aug_state(a,z[0],theta_grad)
    n = a.numel()

    for i in range(len(t) - 1): # Traverse time steps and hidden states in reverse
        h_ = t[i+1] - t[i] # get step size (negative: we integrate backwards)
        aug_state[n:2*n] = z[i].reshape(-1) # reset hidden state to the stored forward value

        # Go to next time step, using ODE solver; the dynamics read the current time from t[0]
        aug_state = rk_step(fun=dyn.forward, t = t[i:i+2], y=aug_state, h=h_)

    # hand the accumulated gradients to the parameters
    _,_,theta_grad = unpack_aug_state(aug_state, z[0], theta_grad)
    for p, g in zip(params, theta_grad):
        # clone so each .grad owns its memory instead of viewing aug_state
        p.grad = g.detach().clone()

def pack_aug_state(a,z,theta_grad):
    '''Function packs the adjoint state, hidden state, and parameter gradients into vector for ODE solver
        Args: a is adjoint state
              z is hidden state
              theta_grad is list of tensors, one per model parameter (may be empty)
        Returns: 1-D tensor [a, z, theta_grad[0], theta_grad[1], ...], detached from the autograd graph'''
    # reshape(-1) also flattens non-contiguous tensors (view(-1) raises on them),
    # and a single cat over all pieces works when theta_grad is empty.
    pieces = [a.reshape(-1), z.reshape(-1)]
    pieces.extend(p.reshape(-1) for p in theta_grad)
    return torch.cat(pieces).detach()

def unpack_aug_state(aug_state, z_dummy, parameter_dummy):
    '''Split the flat augmented state back into adjoint, hidden state and parameter gradients
        Args: aug_state is a vector representing the augmented state
              z_dummy is a tensor whose shape/numel describe the hidden state (values are ignored)
              parameter_dummy is a list of tensors whose shape/numel describe each parameter
        Returns: a is adjoint
                 z is hidden state
                 theta_grad is a list with one tensor per entry of parameter_dummy
                 (all are views into aug_state)'''
    n_state = z_dummy.numel()
    state_shape = z_dummy.shape

    # Layout: [ adjoint | hidden state | parameter 0 | parameter 1 | ... ]
    a = aug_state[:n_state].view(state_shape)
    z = aug_state[n_state:2 * n_state].view(state_shape)
    flat_params = aug_state[2 * n_state:]

    theta_grad = []
    start = 0
    for template in parameter_dummy:
        stop = start + template.numel()
        theta_grad.append(flat_params[start:stop].view(template.shape))
        start = stop

    return a, z, theta_grad

## Compare normal backprop with custom adjoint and torchdiffeq adjoint

#### One-dimensional hidden state

In [5]:
# Compare three ways of obtaining parameter gradients for the same 1-D problem:
# backprop through the RK4 solver, the custom adjoint, and torchdiffeq's adjoint.

# Seed before drawing the target so that every run compares the same problem.
torch.manual_seed(1)
loss_fn = nn.MSELoss()
y = torch.randn(1) # regression target for the final hidden state
y0 = torch.tensor([2.]) # initial hidden state, shared by all three approaches

# Parameters for Runge-Kutta
t_span = [1.,2.] # start/end times 
h=.01 # step size
t = torch.tensor(list(np.concatenate([np.arange(*t_span, h), [t_span[-1]]]))) # get evaluation times
lr = .3 # set learning rate for optimizer

import torch.optim as optim
f = ode_forward() # NOTE(review): not used in this cell; kept in case a later cell reads it

################ backpropping through solver ##############
torch.manual_seed(0)
diffeq = diff_eq() # dz/dt model; ode_forward() or mini_net() can be swapped in here
solver = rk
optimizer = optim.SGD(diffeq.parameters(), lr=lr)

y_forward = solver(fun=diffeq, t=t, y0=y0) # Forward pass
loss = loss_fn(y_forward[-1], y)
loss.backward()
optimizer.step()

print('\nUpdated parameters (backprop through solver)')
print(list(diffeq.named_parameters()))


################ using custom adjoint ##############
torch.manual_seed(0)
diffeq = diff_eq() # same model choice as above, re-created so each approach starts identically
solver = rk
optimizer = optim.SGD(diffeq.parameters(), lr=lr)

#### ADJOINT METHOD ####
diffeq.zero_grad()
y_forward = solver(fun=diffeq, t=t, y0=y0) # Forward pass
loss, a = get_loss_and_adjoint(loss_fn, y, y_forward)
my_backward(diffeq, t, y_forward, a)
optimizer.step()

print('\nUpdated parameters (custom adjoint)')
print(list(diffeq.named_parameters()))


################ using torchdiffeq package ##############
from torchdiffeq import odeint_adjoint, odeint
method = 'rk4'
torch.manual_seed(0)    
diffeq = diff_eq() # same model choice as above
optimizer = optim.SGD(diffeq.parameters(), lr=lr)

yhat = odeint_adjoint(func=diffeq, y0=y0, t = t, method=method)
loss = loss_fn(yhat[-1], y)
loss.backward()
optimizer.step()

print('\nUpdated parameters (torchdiffeq adjoint)')
print(list(diffeq.named_parameters()))


Updated parameters (backprop through solver)
[('a', Parameter containing:
tensor([-1336.1219], requires_grad=True)), ('b', Parameter containing:
tensor([-237.9217], requires_grad=True))]

Updated parameters (custom adjoint)
[('a', Parameter containing:
tensor([-1336.1255], requires_grad=True)), ('b', Parameter containing:
tensor([-237.9224], requires_grad=True))]

Updated parameters (torchdiffeq adjoint)
[('a', Parameter containing:
tensor([-1336.1305], requires_grad=True)), ('b', Parameter containing:
tensor([-237.9231], requires_grad=True))]


#### Multi-dimensional (5-D) hidden state

In [7]:
# Same three-way comparison as above, now with a 5-dimensional hidden state
# and the small_net dynamics.

# Seed before drawing the data so that every run compares the same problem.
# A different seed than the model's is used so the data are not a copy of the
# first values drawn for the weights.
torch.manual_seed(1)
loss_fn = nn.MSELoss()
y = torch.randn(5) # regression target for the final hidden state
y0 = torch.randn(5) # initial hidden state, shared by all three approaches

# Parameters for Runge-Kutta
t_span = [1.,2.] # start/end times 
h=.01 # step size
t = torch.tensor(list(np.concatenate([np.arange(*t_span, h), [t_span[-1]]]))) # get evaluation times
lr = 1 # set learning rate for optimizer

import torch.optim as optim

################ backpropping through solver ##############
torch.manual_seed(0)
diffeq = small_net() # function representing derivative of hidden state with respect to time
solver = rk
optimizer = optim.SGD(diffeq.parameters(), lr=lr)

y_forward = solver(fun=diffeq, t=t, y0=y0) # Forward pass
loss = loss_fn(y_forward[-1], y)
loss.backward()
optimizer.step()

print('\nUpdated parameters (backprop through solver)')
print(list(diffeq.named_parameters()))


################ using custom adjoint ##############
torch.manual_seed(0)
diffeq = small_net() # re-created with the same seed so each approach starts identically
solver = rk
optimizer = optim.SGD(diffeq.parameters(), lr=lr)

#### ADJOINT METHOD ####
diffeq.zero_grad()
y_forward = solver(fun=diffeq, t=t, y0=y0) # Forward pass
loss, a = get_loss_and_adjoint(loss_fn, y, y_forward)
my_backward(diffeq, t, y_forward, a)
optimizer.step()

print('\nUpdated parameters (custom adjoint)')
print(list(diffeq.named_parameters()))


################ using torchdiffeq package ##############
from torchdiffeq import odeint_adjoint, odeint
method = 'rk4'
torch.manual_seed(0)    
diffeq = small_net()
optimizer = optim.SGD(diffeq.parameters(), lr=lr)

yhat = odeint_adjoint(func=diffeq, y0=y0, t = t, method=method)
loss = loss_fn(yhat[-1], y)
loss.backward()
optimizer.step()

print('\nUpdated parameters (torchdiffeq adjoint)')
print(list(diffeq.named_parameters()))


Updated parameters (backprop through solver)
[('W', Parameter containing:
tensor([[-1.1335, -1.2001, -0.2830, -0.4458,  0.8574],
        [ 0.7190, -0.1226, -2.0350,  0.3863, -0.2137],
        [ 1.3352, -0.4272, -0.2711,  0.6907, -0.0038],
        [-0.1010, -0.5419,  1.2960,  2.0265,  0.0359],
        [ 0.6098, -0.4651, -0.8746, -2.3297, -0.0920]], requires_grad=True)), ('Y', Parameter containing:
tensor([[-0.7411,  0.5684,  0.5148,  0.0708, -0.6752],
        [ 0.9338,  0.4604, -1.5833, -0.7748,  1.0989],
        [ 1.2833, -1.5199,  1.2361, -1.7298,  0.3181],
        [ 1.5101,  2.0870, -0.2958, -0.3247, -1.0645],
        [ 1.1116, -0.1611,  0.1360, -0.3055, -0.7695]], requires_grad=True))]

Updated parameters (custom adjoint)
[('W', Parameter containing:
tensor([[-1.1335, -1.2001, -0.2830, -0.4458,  0.8574],
        [ 0.7190, -0.1226, -2.0350,  0.3863, -0.2137],
        [ 1.3352, -0.4272, -0.2711,  0.6907, -0.0038],
        [-0.1010, -0.5419,  1.2960,  2.0265,  0.0359],
        [ 0.609