In [1]:
import sys, os, time
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torchdiffeq import odeint_adjoint as odeint
import gym
import h5py as hf

import foundation as fd
from foundation import util
from foundation import nets
from foundation import train

from nb_backend import *

In [2]:
args = util.NS()

args.hidden_dims = [32, 32]
args.nonlin = 'elu'

In [3]:
env = gym.make('Pendulum-v0')
env.reset()

array([0.94204949, 0.33547394, 0.81032408])

In [4]:
args.din, args.dout = env.observation_space.shape[0], env.observation_space.shape[0]

In [5]:
model = nets.make_MLP(args.din, args.dout, hidden_dims=args.hidden_dims, nonlin=args.nonlin)
model

Sequential(
  (0): Linear(in_features=3, out_features=32, bias=True)
  (1): ELU(alpha=1.0, inplace)
  (2): Linear(in_features=32, out_features=32, bias=True)
  (3): ELU(alpha=1.0, inplace)
  (4): Linear(in_features=32, out_features=3, bias=True)
)

In [6]:
optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

In [7]:
true_y0 = torch.tensor([[2., 0.]])
t = torch.linspace(0., 25., args.data_size)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])

with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')

    
ii = 0

func = ODEFunc()
optimizer = optim.RMSprop(func.parameters(), lr=1e-3)
end = time.time()

time_meter = RunningAverageMeter(0.97)
loss_meter = RunningAverageMeter(0.97)

for itr in range(1, args.niters + 1):
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch()
    pred_y = odeint(func, batch_y0, batch_t)
    loss = torch.mean(torch.abs(pred_y - batch_y))
    loss.backward()
    optimizer.step()

    time_meter.update(time.time() - end)
    loss_meter.update(loss.item())

    if itr % args.test_freq == 0:
        with torch.no_grad():
            pred_y = odeint(func, true_y0, t)
            loss = torch.mean(torch.abs(pred_y - true_y))
            print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
            visualize(true_y, pred_y, func, ii)
            ii += 1

    end = time.time()

AttributeError: 'NS' object has no attribute 'data_size'