In [1]:
import os
import argparse
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from argparse import Namespace


# Notebook stand-in for command-line flags: solver name, dataset/batch sizes,
# iteration counts, evaluation frequency, and the target device.
args = Namespace(
    method='dopri5',
    data_size=1000,
    batch_time=10,
    batch_size=20,
    niters=2000,
    test_freq=20,
    viz=False,
    gpu=0,
    adjoint=False,
    device='cpu',
)

# Bind the name `odeint` to either the plain solver or the adjoint-based
# solver, depending on the `adjoint` flag; the rest of the notebook only ever
# calls `odeint`.
if not args.adjoint:
    from torchdiffeq import odeint
else:
    from torchdiffeq import odeint_adjoint as odeint




# Ground-truth setup, all placed on args.device:
#   true_A  - 2x2 matrix of the true dynamics used by Lambda: dy/dt = (y**3) @ true_A
#   true_y0 - initial state, a single row of two coordinates
#   t       - args.data_size evenly spaced time points on [0, 25]
true_A = torch.tensor(
    [[-0.1, 2.0],
     [-2.0, -0.1]]
).to(args.device)
true_y0 = torch.tensor([[2.0, 0.0]]).to(args.device)
t = torch.linspace(0.0, 25.0, args.data_size).to(args.device)


class Lambda(nn.Module):
    """Ground-truth vector field: dy/dt = (y ** 3) @ true_A.

    The time argument ``t`` is accepted only to match the ``func(t, y)``
    signature that ``odeint`` calls; the dynamics do not depend on it.
    """

    def forward(self, t, y):
        cubed_state = y.pow(3)
        return torch.mm(cubed_state, true_A)

# Integrate the ground-truth dynamics once, without tracking gradients, to
# obtain the reference trajectory (one state per entry of `t`) that training
# batches are sampled from. The solver is taken from `args.method` so this
# step stays consistent with the training and evaluation solves, instead of
# hard-coding 'dopri5' and silently ignoring the flag.
with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method=args.method)

def get_batch(device):
    """Sample a mini-batch of short windows from the reference trajectory.

    Draws ``args.batch_size`` distinct start indices (without replacement)
    from ``[0, args.data_size - args.batch_time)``, so every window of
    ``args.batch_time`` consecutive states fits inside ``true_y``.

    Returns a tuple ``(batch_y0, batch_t, batch_y)``, each moved to ``device``:
      - batch_y0: ``true_y`` at the start indices, shape ``(M, *true_y.shape[1:])``
      - batch_t:  the first ``args.batch_time`` entries of ``t``
      - batch_y:  the windows, shape ``(T, M, *true_y.shape[1:])``
    where M = ``args.batch_size`` and T = ``args.batch_time``. With the
    ``(1, 2)``-shaped initial state used here, the trailing dims are ``(1, 2)``.
    """
    candidates = np.arange(args.data_size - args.batch_time, dtype=np.int64)
    starts = torch.from_numpy(
        np.random.choice(candidates, args.batch_size, replace=False)
    )

    # Row i of `window_idx` holds `starts + i`; a single gather therefore
    # produces the (T, M, ...) stack of consecutive states.
    offsets = torch.arange(args.batch_time).unsqueeze(1)  # (T, 1)
    window_idx = starts.unsqueeze(0) + offsets            # (T, M)

    first_states = true_y[starts]
    time_grid = t[:args.batch_time]
    windows = true_y[window_idx]
    return first_states.to(device), time_grid.to(device), windows.to(device)


class ODEFunc(nn.Module):
    """Learnable vector field f(t, y) = MLP(y ** 3).

    The MLP maps 2 -> 50 -> 2 with a tanh in between. Every linear layer's
    weight is re-initialised from N(0, 0.1**2) and its bias is set to zero.
    The input is cubed before the MLP, mirroring the cubic form of the
    ground-truth dynamics. ``t`` is accepted for the ``odeint`` calling
    convention but is unused.
    """

    def __init__(self):
        super().__init__()

        hidden_width = 50
        self.net = nn.Sequential(
            nn.Linear(2, hidden_width),
            nn.Tanh(),
            nn.Linear(hidden_width, 2),
        )

        linear_layers = (
            layer for layer in self.net.modules() if isinstance(layer, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, t, y):
        cubed = y ** 3
        return self.net(cubed)
    

# Train the learned vector field with RMSprop on mean-absolute-error between
# predicted and reference windows; every `args.test_freq` iterations, report
# the error of a full-trajectory rollout from the true initial state.
func = ODEFunc().to(args.device)

optimizer = optim.RMSprop(func.parameters(), lr=1e-3)

for itr in range(1, args.niters + 1):
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch(args.device)
    # get_batch already places its outputs on args.device, so the solution
    # computed from them needs no further .to(args.device). The solver is
    # taken from args.method, consistent with the ground-truth integration.
    pred_y = odeint(func, batch_y0, batch_t, method=args.method)
    loss = torch.mean(torch.abs(pred_y - batch_y))
    loss.backward()
    optimizer.step()

    if itr % args.test_freq == 0:
        # Evaluation only: integrate over the whole time grid without
        # building a graph, and compare against the reference trajectory.
        with torch.no_grad():
            pred_y = odeint(func, true_y0, t, method=args.method)
            loss = torch.mean(torch.abs(pred_y - true_y))
            print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))

Iter 0020 | Total Loss 0.637718
Iter 0040 | Total Loss 0.762712
Iter 0060 | Total Loss 0.805768
Iter 0080 | Total Loss 0.870600
Iter 0100 | Total Loss 0.368864
Iter 0120 | Total Loss 0.659126
Iter 0140 | Total Loss 0.560290
Iter 0160 | Total Loss 0.298013
Iter 0180 | Total Loss 0.407210
Iter 0200 | Total Loss 0.313286
Iter 0220 | Total Loss 0.279348
Iter 0240 | Total Loss 0.322765
Iter 0260 | Total Loss 0.486933
Iter 0280 | Total Loss 0.482868
Iter 0300 | Total Loss 0.254792
Iter 0320 | Total Loss 0.427532
Iter 0340 | Total Loss 0.267018
Iter 0360 | Total Loss 0.547727
Iter 0380 | Total Loss 0.258040
Iter 0400 | Total Loss 0.580842
Iter 0420 | Total Loss 0.249644
Iter 0440 | Total Loss 0.180842
Iter 0460 | Total Loss 0.359119
Iter 0480 | Total Loss 0.676958
Iter 0500 | Total Loss 0.236674
Iter 0520 | Total Loss 0.327736
Iter 0540 | Total Loss 0.303796
Iter 0560 | Total Loss 0.566056
Iter 0580 | Total Loss 0.613506
Iter 0600 | Total Loss 0.277182
Iter 0620 | Total Loss 0.231397
Iter 064