In [1]:
import logging
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)  # set level to INFO for wordy
import tqdm

from mpl_toolkits import mplot3d
from matplotlib import cm
import matplotlib.pyplot as plt

import numpy as np

from extravaganza.dynamical_systems import KSwitchingLTV
from extravaganza.observables import Observable, TimeDelayedObservation, Trajectory
from extravaganza.lifters import Identity, NN
from extravaganza.system_models import LiftedController, OfflineSysid, HardFTH
from extravaganza.explorer import Explorer
from extravaganza.utils import summarize_lds, sample, opnorm, SAMPLING_METHOD

# Random seeds for the system and for sysid; leaving a value as `None` means a random seed is used.
# NOTE(review): neither constant is referenced by `run` in the next cell — confirm they are wired in elsewhere.
SYSTEM_SEED, SYSID_SEED = None, None

INFO: Created a temporary directory at /var/folders/5m/0xr906c130vdqvkm3g21n6wr0000gn/T/tmplnfk5vp3
INFO: Writing /var/folders/5m/0xr906c130vdqvkm3g21n6wr0000gn/T/tmplnfk5vp3/_remote_module_non_scriptable.py


In [None]:
def run(observable, ds: int, du: int, k: int, switch_every: int,
        T0: int, reset_every: int, sysid_args, sysid_cls=None):
    """
    Gather exploration data from a `KSwitchingLTV` system and train a sysid on it.

    The system is built with no disturbance and a quadratic cost. Every `reset_every`
    steps (starting at step 0) the system is reset, its state is overwritten with
    `sample((ds,), sampling_method='sphere')`, the sysid is told the current trajectory
    ended, and a fresh `Trajectory` is started. On each of the `T0` steps the current
    control is applied, the resulting (cost, state) is recorded, `observable` is evaluated
    on the trajectory so far, and `sysid.explore` chooses the next control.

    Args:
        observable: callable taking the current `Trajectory` and returning the
            observation passed to `sysid.explore`.
        ds: state dimension (shape of the sampled initial state).
        du: control dimension (the first control is `np.zeros(du)`).
        k: forwarded to `KSwitchingLTV` (presumably the number of switching modes).
        switch_every: forwarded to `KSwitchingLTV` (presumably steps between switches).
        T0: total number of interaction steps.
        reset_every: number of steps between resets; must be positive.
        sysid_args: keyword arguments forwarded to the sysid constructor.
        sysid_cls: class used to build the sysid. `None` selects `OfflineSysid`.

    Returns:
        Tuple `(A, B, sysid, max_sq_norm)`:
        - `A` and `B` are read from the trained sysid's `A`/`B` attributes, or are
          `None` if it has no such attribute.
        - `sysid` is the trained sysid.
        - `max_sq_norm` is the largest squared Euclidean norm of any observed state.

    Raises:
        ValueError: if `reset_every` is not positive.
    """
    if reset_every <= 0:
        # `t % reset_every` below would otherwise fail with an opaque ZeroDivisionError
        raise ValueError(f'reset_every must be positive, got {reset_every}')
    if sysid_cls is None:
        # NOTE(review): assumed default — `OfflineSysid` is the imported sysid class and
        # presumably provides explore / end_trajectory / end_exploration; confirm.
        sysid_cls = OfflineSysid

    # make system
    system = KSwitchingLTV(ds, du, k, switch_every, disturbance_type='none', cost_fn='quad')

    # make sysid
    sysid = sysid_cls(**sysid_args)

    # interaction loop
    control = np.zeros(du)
    max_sq_norm = 0.
    for t in tqdm.trange(T0):
        if t % reset_every == 0:
            # t == 0 always takes this branch, so `traj` is bound before its first use
            system.reset()
            system.state = sample((ds,), sampling_method='sphere')
            sysid.end_trajectory()
            traj = Trajectory()
        cost, state = system.interact(control)
        traj.add_state(cost, state)
        obs = observable(traj)
        control = sysid.explore(cost, obs)
        traj.add_control(control)
        max_sq_norm = max(max_sq_norm, np.linalg.norm(state) ** 2)
    sysid.end_exploration(wordy=True)

    # Fall back to None instead of crashing after the (possibly long) data-collection loop.
    # NOTE(review): assumes callers want the sysid's learned matrices rather than the
    # true system's — confirm against the downstream cells.
    A = getattr(sysid, 'A', None)
    B = getattr(sysid, 'B', None)
    return A, B, sysid, max_sq_norm