In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from itertools import combinations
import os
from pathlib import Path
import pickle
from types import SimpleNamespace
from datetime import datetime
import time

import numpy as np
import matplotlib.pyplot as plt
from sklearn import manifold
from sklearn import model_selection
from scipy.integrate import trapz
import optuna
from tqdm.auto import tqdm, trange

import notebook_setup
import ppo, utils, meta
from ppo import DEVICE
from systems import TanksPhysicalEnv, TanksFactory, plot_tanks

# Root folder for every artefact this notebook writes (pickles and figures),
# resolved to an absolute path from the notebook's working directory.
datadir = Path(os.path.abspath('../bin/systol21'))
for folder in (datadir, datadir / 'plots'):
    os.makedirs(folder, exist_ok=True)

def now():
    """Return the current local time formatted as 'yymmddHHMM-'.

    The string is used as a prefix for the names of saved experiment files.
    """
    return datetime.now().strftime('%y%m%d%H%M-')

## Environment

In [None]:
# Env parameters
model_seed = 0          # seed passed to environments, agents and random generators
tstep = 1.0             # time step handed to TanksPhysicalEnv
timesteps = 40000       # step budget handed to agent.learn()
random_faults = False   # True: replace the fault list with random composite faults

# PPO hyperparameters (used as the defaults of make_model)
alpha = 0.0367          # learning rate
update_interval = 500
n_latent_var = 96
epochs = 10

# Single mode faults, 0 index is nominal
# Each entry is a dict of keyword arguments that get_task() forwards to
# TanksFactory(**params); the empty dict at index 0 keeps the factory defaults.
# Every array has six entries. The position of each fault family in this list
# is mirrored by `mode_indices`, so keep both in sync when editing.
faults = [
        dict(),
    # valves 25
        dict(valves_min=np.asarray([0,0.25,0,0,0.25,0])),
        dict(valves_min=np.asarray([0,0,0.25,0,0,0.25])),
        dict(valves_min=np.asarray([0.25,0.25,0,0,0,0])),
        dict(valves_min=np.asarray([0,0.25,0.25,0,0,0])),
    # valves 50
        dict(valves_min=np.asarray([0,0.5,0,0.5,0,0])),
        dict(valves_min=np.asarray([0,0,0.5,0,0.5,0])),
        dict(valves_min=np.asarray([0.5,0,0.5,0,0,0])),
        dict(valves_min=np.asarray([0,0.5,0.5,0,0,0])),
#     # valve
#         dict(valves_min=np.asarray([0,0.5,0,0,0,0])),
#         dict(valves_min=np.asarray([0,0,0.5,0,0,0])),
#         dict(valves_min=np.asarray([0,0,0,0.5,0,0])),
#         dict(valves_min=np.asarray([0,0,0,0,0.5,0])),
    # leaks
        dict(leaks=np.asarray([0.005,0,0,0,0,0])),
        dict(leaks=np.asarray([0,0.005,0,0,0,0])),
        dict(leaks=np.asarray([0,0,0,0,0.005,0])),
        dict(leaks=np.asarray([0,0,0,0,0,0.005])),
    # pumps
        dict(pumps=np.asarray([1e-2,1e-3,1e-2,1e-2,1e-2,1e-2])),
        dict(pumps=np.asarray([1e-2,1e-2,1e-3,1e-2,1e-2,1e-2])),
        dict(pumps=np.asarray([1e-2,1e-2,1e-3,1e-3,1e-2,1e-2])),
        dict(pumps=np.asarray([1e-2,1e-2,1e-2,1e-3,1e-3,1e-2])),
    ]

# Position of every fault family inside `faults` (and `policies`), by family name.
mode_indices = {
    'nominal': slice(0, 1),
    'valves25': slice(1, 5),
    'valves50': slice(5, 9),
    # 'valve': slice(9, 13),
    'leaks': slice(9, 13),
    'pumps': slice(13, 17),
}

# Scatter-plot keyword arguments per fault family (colour, marker, size).
scatter_fmt = {
    'nominal': {'c': 'black', 'marker': 'o', 's': 60},
    'valves25': {'c': 'r', 'marker': 'h', 's': 40},
    'valves50': {'c': 'lightcoral', 'marker': 'D', 's': 40},
    'valve': {'c': 'firebrick', 'marker': 's', 's': 40},
    'leaks': {'c': 'g', 'marker': 'v', 's': 40},
    'pumps': {'c': 'b', 'marker': 'p', 's': 40},
}

# Line-style keyword arguments: solid for the first entry, dotted for the next 22.
ls = [{'ls': '-'}] + 22 * [{'ls': ':'}]


if random_faults:
    # Random composite faults
    # Replace every non-nominal fault by a randomly drawn composite fault while
    # keeping the list length, so indices still line up with the rest of the
    # notebook. The count must be taken BEFORE the list is reset; taking it
    # afterwards gives range(0) and no fault would ever be generated.
    n_random_faults = len(faults) - 1
    random_ = np.random.RandomState(model_seed)
    faults = [dict()]
    for _ in range(n_random_faults):
        faults.append(dict(valves_min=random_.rand(6) * 0.5,
                            valves_max=random_.rand(6) * 0.5 + 0.5,
                            pumps=random_.rand(6) * 1e-2 + 5e-3,
                            engines=random_.rand(2) * 2e-2 + 1e-2))

# `colors` is stored in env_params but is not defined anywhere above this cell,
# so a fresh kernel raised NameError here. Keep an existing value (e.g. one
# restored from a saved env_params) and otherwise fall back to matplotlib's
# current colour cycle.
# NOTE(review): the fallback palette is an assumption - confirm the intended colours.
try:
    colors
except NameError:
    colors = plt.rcParams['axes.prop_cycle'].by_key().get('color', [])

# Bundle the configuration so it can be pickled alongside experiment results.
env_params = SimpleNamespace(alpha=alpha, model_seed=model_seed, tstep=tstep,
                             update_interval=update_interval, timesteps=timesteps,
                             random_faults=random_faults, faults=faults, colors=colors, ls=ls)

In [None]:
def get_task(seed=None, randomize=False, params=None, tstep=tstep):
    """Create a tanks environment, optionally with a fault and/or randomized.

    Args:
        seed: Seed forwarded to TanksPhysicalEnv.
        randomize: If True, return the result of ``env.randomize()`` instead of
            the freshly built environment.
        params: Keyword arguments for TanksFactory (one entry of ``faults``).
            None or an empty dict builds the factory with its defaults.
        tstep: Time step forwarded to TanksPhysicalEnv. The default is bound
            to the notebook-level ``tstep`` when this cell is executed.

    Returns:
        The environment (or whatever ``env.randomize()`` returns).
    """
    # None instead of a shared mutable `{}` default.
    factory = TanksFactory(**(params or {}))
    env = TanksPhysicalEnv(factory, seed=seed, tstep=tstep)
    return env.randomize() if randomize else env

def make_model(seed=None, n_latent_var=n_latent_var, lr=alpha, update_interval=update_interval, epochs=epochs):
    """Build a PPO agent with a discrete actor-critic policy and no environment.

    The caller attaches an environment afterwards via ``agent.env = ...``.
    State and action dimensions are fixed to 6. The hyperparameter defaults
    are bound to the notebook-level values when this cell is executed.
    """
    hyperparams = dict(n_latent_var=n_latent_var,
                       lr=lr,
                       update_interval=update_interval,
                       epochs=epochs)
    return ppo.PPO(env=None,
                   policy=ppo.ActorCriticDiscrete,
                   state_dim=6,
                   action_dim=6,
                   seed=seed,
                   **hyperparams)

In [None]:
# read params
# Restore a previously saved configuration and overwrite the notebook-level
# names with it. n_latent_var and epochs are not part of env_params and keep
# their current values.
env_params = utils.read_pickle(datadir / 'env_params')
(alpha, model_seed, tstep,
 update_interval, timesteps, random_faults,
 colors, ls, faults) = (
    env_params.alpha, env_params.model_seed, env_params.tstep,
    env_params.update_interval, env_params.timesteps, env_params.random_faults,
    env_params.colors, env_params.ls, env_params.faults,
)

### RL hyperparameters

In [None]:
def objective(trial: optuna.Trial):
    """Optuna objective: train PPO on the nominal task and score the result.

    Samples the learning rate, latent layer width, update interval and number
    of epochs, trains a fresh agent for ``timesteps`` steps, and returns the
    mean of the last 10 entries of the reward history returned by ``learn``.
    """
    # suggest_float(log=True) replaces the deprecated suggest_loguniform.
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
    n_latent_vars = trial.suggest_int('n_latent_vars', 1, 200, log=True)
    update_interval = trial.suggest_int('update_interval', 500, 5000, step=500)
    epochs = trial.suggest_int('epochs', 1, 10)

    env_ = get_task(model_seed, tstep=tstep)
    # Pass the SAMPLED width; the notebook-level `n_latent_var` was passed
    # before, so this hyperparameter was never actually searched.
    agent_ = make_model(model_seed, n_latent_vars, lr, update_interval, epochs)
    agent_.env = env_
    r = agent_.learn(timesteps)
    return np.mean(r[-10:])

# Seed the sampler so the hyperparameter search proposes the same trials on
# every run (TPESampler is Optuna's default sampler; only the seed is new).
study = optuna.create_study(direction='maximize',
                            sampler=optuna.samplers.TPESampler(seed=model_seed))
study.optimize(objective, n_trials=25)
print(study.best_params)

## Reward components

In [None]:
# Roll out one episode of the nominal environment with an all-zero action and
# record, per step, the reward and its individual components.
env = get_task(seed=model_seed, params=faults[0])
s = env.reset()
components, reward = [], []
done = False
while not done:
    a = np.zeros(6)
    ns, r, done, _ = env.step(a)
    # env.t has already advanced, so step back one tstep for the component call.
    components.append(env.reward_components(env.t - env.tstep, s, a, ns, done))
    reward.append(r)
    s = ns

# Left axis: one curve per reward component. Right (twin) axis: the reward
# returned by env.step at every step.
components = np.asarray(components)
reward_fig, ax_components = plt.subplots(figsize=(8,4))
component_names = ('centre', 'activity', 'spread', 'level', 'deficit')
for name, comp in zip(component_names, np.transpose(components)):
    ax_components.plot(comp, label=name)
ax_components.set_ylabel('Components')
ax_components.legend()

ax_reward = ax_components.twinx()
# NOTE(review): `reward` holds the per-step value returned by env.step -
# confirm that the 'Cumulative Reward' label matches what the env returns.
ax_reward.plot(reward, ls=':', c='black', label='reward')
ax_reward.set_ylabel('Cumulative Reward')
ax_reward.legend()
ax_reward.set_xlabel('Timesteps')
ax_reward.set_title('Rewards')
reward_fig.savefig(datadir / 'plots' / 'rewards')

## Nominal vs. faulty rewards

In [None]:
# Reward per step under an all-zero action, for the nominal system and for
# every fault from index 5 onwards.
plt.figure(figsize=(12,6))
rewards = []
# Iterate over true fault indices so labels and the printed table refer to
# entries of `faults` (previously curve "1" was actually faults[5]), and so
# the progress bar total matches the number of runs.
fault_ids = [0] + list(range(5, len(faults)))
for f in tqdm(fault_ids, leave=False):
    fault = faults[f]
    env = get_task(seed=model_seed, params=fault)
    done = False
    env.reset()
    reward = []
    while not done:
        _, r, done, _ = env.step(np.zeros(6))
        reward.append(r)
    rewards.append(reward)  # same order as fault_ids
    plt.plot(reward, label='nominal' if f==0 else str(f), ls='-' if f==0 else ':')
    print('%2d\t%4.1f\t%s' % (f, sum(reward), fault))
plt.xlabel('Timesteps')
plt.ylabel('Reward')
plt.legend()

## RL with nominal and faulty env

In [None]:
# Train a fresh agent on the nominal system and on the first fault, and plot
# the smoothed reward histories.
fig = plt.figure(figsize=(12, 6))
# learn() gets an integer step budget (timesteps * 1.5 is a float), and the
# same number is what the title reports.
train_steps = int(timesteps * 1.5)
compared_faults = faults[:2]
# total must match the two runs actually made, not len(faults)
for f, fault in enumerate(tqdm(compared_faults, total=len(compared_faults), leave=False)):
    env = get_task(model_seed, params=fault)
    agent = make_model(model_seed)
    agent.env = env
    r = agent.learn(train_steps)
    plt.plot(utils.rollmean(r, 5), label='nominal' if f==0 else str(f), ls=':')
plt.legend()
plt.title('lr=%.3f, tstep=%.3f, interval=%d, steps=%d' % (alpha, tstep, update_interval, train_steps))

## Transfer RL using proximal policies

* Sample memory
* Find closest policy in library
* Substitute and learn
* Add to library and prune

Benchmark against:

* Vanilla RL
* Bayesian policy reuse

### Experiment Setup

In [None]:
def adapt(agent, env, library, timesteps):
    """Select the library policy ranked first on freshly sampled experience.

    Fills a new ``ppo.Memory`` by calling ``agent.experience`` on ``env`` with
    the agent's current policy for ``timesteps`` steps, then asks
    ``meta.rank_library`` to rank ``library`` against that memory.

    Returns:
        (params, index): the top-ranked entry of ``library`` and its position.
    """
    sampled = ppo.Memory()
    agent.experience(sampled, timesteps, env, agent.policy)
    ranking, _scores = meta.rank_library(library, agent, sampled)
    best = ranking[0]
    return library[best], best

def component_agg(store: dict):
    """Build a step callback that accumulates reward components per episode.

    ``store`` maps a component name to a list of runs; each run is a list with
    one running total per episode. The callback works on the LAST run of every
    component, so the caller appends a new ``[0.]`` run before training.

    The callback receives a dict with (at least) the keys 'info' and 'done':
    * not done: every value in ``info`` is added to the current episode total;
    * done: a new total of 0. is started for every key in ``info`` (the values
      of that final step are not added anywhere).
    """
    def agg(lcl: dict):
        info = lcl.get('info')
        if lcl.get('done'):
            # Episode finished: open the accumulator for the next episode.
            for key in info.keys():
                store[key][-1].append(0.)
            return
        for key, value in info.items():
            store[key][-1][-1] += value
    return agg

In [None]:
# train nominal environment
# Naming convention used throughout: `agent` / `env` are the notebook-level
# nominal agent and environment; `agent_` / `env_` are throw-away instances.
agent = make_model(model_seed) # agent is the global agent class, agent_ is local
env = get_task(model_seed)     # env is the global nominal env class, env_ is local
agent.env = env
# Reward history returned by learn(); later cells plot its tail as the
# pre-fault baseline.
nominal_r = agent.learn(timesteps)
# Parameters of the trained nominal policy; every transfer experiment starts from them.
# NOTE(review): if state_dict() returns live references to the policy's tensors
# (as torch.nn.Module.state_dict does), any further training of `agent` would
# also change nominal_params - confirm, and copy if needed.
nominal_params = agent.policy.state_dict()

In [None]:
# Convenience cell: after re-training the nominal agent, refresh entry 0 of an
# already built policy library. On a fresh kernel the library does not exist
# yet (it is created by the cell below, which starts from the nominal entry),
# so there is nothing to patch and the cell must not raise NameError.
if 'policies' in globals() and 'policies_r' in globals():
    policies[0] = nominal_params
    policies_r[0] = nominal_r

In [None]:
# train policies for all faults
# each entry in faults corresponds to an entry in policies
# Index 0 is the nominal policy trained above; every further entry is a fresh
# agent (same seed and hyperparameters) trained from scratch on one fault.
policies = [nominal_params]
policies_r = [nominal_r]
for fault in tqdm(env_params.faults[1:], leave=False, desc='Training policies for all faults'):
    agent_ = make_model(model_seed)
    env_ = get_task(model_seed, params=fault)
    agent_.env = env_
    # reward history first, then the parameters of the trained policy
    policies_r.append(agent_.learn(timesteps))
    policies.append(agent_.policy.state_dict())

In [None]:
# plot policy rewards
# Smoothed training reward of every policy: nominal solid, faults dotted.
for i, r in enumerate(policies_r):
    if i == 0:
        curve_label, curve_style = 'nominal', '-'
    else:
        curve_label, curve_style = str(i), ':'
    plt.plot(utils.rollmean(r, 5), label=curve_label, ls=curve_style)
# plt.legend()

In [None]:
# save policies
# The file name carries a timestamp prefix, so every run of this cell writes a
# new file. `latest_policies_version` only exists after this cell has run; the
# load cell below depends on it.
latest_policies_version = now() + 'policies'
utils.write_pickle([policies, policies_r, env_params], datadir / latest_policies_version)

In [None]:
# load policies
# Restores the library written by the save cell above (same three-item layout).
# Requires `latest_policies_version`, which is only set by the save cell.
# NOTE(review): presumably utils.read_pickle unpickles the file - only load
# files you trust.
policies, policies_r, env_params = utils.read_pickle(datadir / latest_policies_version)

In [None]:
# Set up the domain of faults to use for the experiments to follow.
# If manual_split is False, domain_idx selects the faults/policies that are
# split by leave-one-out cross-validation. Otherwise train_idx/test_idx list
# the splits explicitly, as indices into the full faults/policies lists.
domain = 'novelpumps'
prune = 0
manual_split = False
# Default: the whole library. This also keeps domain_idx defined for the
# manual-split domains, which never assigned it although the experiment setup
# cell stores it in exp_data (NameError on a fresh kernel).
domain_idx = slice(None)

if domain in mode_indices:
    domain_idx = mode_indices[domain]
elif domain.startswith('valves'):
    domain_idx = slice(1, 9)
elif domain.startswith('novelvalves'):
    manual_split = True
    train_idx = [[13,14,15,16], [9,10,11,12]]
    test_idx = [[5], [6]]
elif domain.startswith('novelpumps'):
    manual_split = True
    train_idx = [[5,6,7,8], [9,10,11,12]]
    test_idx = [[13], [14]]
elif domain.startswith('novelleaks'):
    manual_split = True
    train_idx = [[5,6,7,8], [13,14,15,16]]
    test_idx = [[9], [10]]
elif domain.startswith('leakspumps'):
    domain_idx = slice(11,15)
elif domain.startswith('valveleaks'):
    domain_idx = slice(3, 7)
# any other name keeps the default (all faults)

# Build the (split number, (train indices, test indices)) iterator consumed by
# the benchmark loop. The indices refer to positions in domain_policies /
# domain_faults.
# NOTE: enumerate(...) is a one-shot iterator - re-run this cell before
# re-running the benchmark cell, otherwise that loop has nothing to iterate.
if not manual_split:
    # Restrict the library to the chosen domain (a slice or a list of indices).
    if isinstance(domain_idx, slice):
        domain_faults = faults[domain_idx]
        domain_policies = policies[domain_idx]
    else:
        domain_faults = [faults[i] for i in domain_idx]
        domain_policies = [policies[i] for i in domain_idx]
    # Leave-one-out: every domain policy is held out once as the test fault.
    xvalidator = model_selection.LeavePOut(1)
    bench_iterator = enumerate(tqdm(xvalidator.split(domain_policies),
                              leave=False,
                              total=xvalidator.get_n_splits(domain_policies),
                              desc='Cross-validation split'))
else:
    # Manual splits index the full lists directly.
    domain_policies = policies
    domain_faults = faults
    bench_iterator = enumerate(tqdm(zip(train_idx, test_idx),
                              leave=False,
                              total=len(train_idx),
                              desc='Cross-validation split'))

### Experiment - Benchmark

In [None]:
# Load existing experiment
# The timestamp prefix is hard-coded: it must match a file previously written
# by the benchmark cell for the current `domain` and `prune` settings.
exp_name = '2104141120-benchmark-cv-%s-prune%d' % (domain, prune)
exp_data = utils.read_pickle(datadir / exp_name)

In [None]:
# Set up new experiment
# exp_data collects one entry per cross-validation split in each of its lists;
# the *_comp dicts hold per-component reward accumulators (see component_agg).
exp_name = '%sbenchmark-cv-%s-prune%d' % (now(), domain, prune)
_component_names = ('centre', 'activity', 'spread', 'level', 'deficit')
exp_data = {
    'vanilla_rs': [],
    'vanilla_comp': {name: [] for name in _component_names},
    'adapt_rs': [],
    'adapt_comp': {name: [] for name in _component_names},
    'other_rs': [],
    'param_idxs': [],
    'fault_idxs': [],
    'time_adapt': [],
    'time_prune': [],
    'nominal_r': nominal_r,
    'domain_idx': domain_idx,
    'domain': domain,
    'prune': prune,
    'kept_idx': [],
    'env_params': env_params,
}

In [None]:
print(exp_name)
# Cache of library-policy transfer runs:
# (index of the source policy in domain_policies, test fault index) -> reward history
fault_rs = dict()
# use subset of faults/policies for train/cross validation
for split, (train_faults, test_faults) in bench_iterator:
    library = [domain_policies[t] for t in train_faults]
    # Domain-level index of every library entry, used as the cache key. A key
    # built from the position inside `library` would collide across splits,
    # because position 0 is a different policy in every split.
    source_ids = list(train_faults)
    if prune > 0:
        memory = ppo.Memory()
        agent.experience(memory, update_interval, agent.env, agent.policy)
        start = time.process_time()
        library, _, kept_idx = meta.prune_library(library, agent, memory, max_library_size=prune, seed=model_seed)
        end = time.process_time()
        exp_data['kept_idx'].append(kept_idx)
        exp_data['time_prune'].append(end - start)
        # The pruned library is a new list; its mapping back to domain indices
        # is not relied upon here, so caching is disabled for pruned runs.
        source_ids = None

    # Benchmarking
    fidx = test_faults[0]
    fault = domain_faults[fidx] # only 1 test fault (leave 1 out)
    exp_data['fault_idxs'].append(fidx)

    # Vanilla RL: keep training the nominal policy on the faulty system
    agent_ = make_model(model_seed)
    agent_.policy.load_state_dict(nominal_params)
    env_ = get_task(model_seed, params=fault)
    agent_.env = env_
    for _, v in exp_data['vanilla_comp'].items(): v.append([0.])
    exp_data['vanilla_rs'].append(agent_.learn(timesteps, step_callback=component_agg(exp_data['vanilla_comp'])))

    # Proximal policy transfer RL: start from the library policy picked by adapt()
    agent_ = make_model(model_seed)
    env_ = get_task(model_seed, params=fault)
    agent_.env = env_
    start = time.process_time()
    params, idx = adapt(agent_, env_, [nominal_params] + library, update_interval)
    end = time.process_time()
    exp_data['time_adapt'].append(end - start)
    idx -= 1  # nominal params are prepended to library, converting back to library index
              # if -1, means that nominal params were most suitable
    exp_data['param_idxs'].append(idx)
    agent_.policy.load_state_dict(params)
    for _, v in exp_data['adapt_comp'].items(): v.append([0.])
    exp_data['adapt_rs'].append(agent_.learn(timesteps, step_callback=component_agg(exp_data['adapt_comp'])))

    # Remainder of policies, to see whether the sampled policy is indeed better
    r = []
    for lib_idx, lib_params in enumerate(tqdm(library, leave=False, desc='Library benchmark')):
        if lib_idx == idx:
            continue  # idx was selected policy, already calculated
        key = None if source_ids is None else (source_ids[lib_idx], fidx)
        if key is not None and key in fault_rs:
            # Reuse the stored run (previously a cache hit was silently dropped).
            r.append(fault_rs[key])
            continue
        agent_ = make_model(model_seed)
        env_ = get_task(model_seed, params=fault)
        agent_.env = env_
        agent_.policy.load_state_dict(lib_params)
        # Train exactly once. The agent used to be trained a second time and
        # the rewards of that second run were the ones recorded.
        lib_r = agent_.learn(timesteps)
        if key is not None:
            fault_rs[key] = lib_r
        r.append(lib_r)
    exp_data['other_rs'].append(r)
    utils.write_pickle(exp_data, datadir / exp_name) # write with each cv split, so we can check results before cell finishes

In [None]:
# Tabulated numbers
# Area under the split-averaged curves (trapezoidal rule) for the adapted
# agent, the vanilla agent and the remaining library policies.
exp_name = '2105172156-benchmark-cv-novelpumps-prune0'
exp_data = utils.read_pickle(datadir / exp_name)

print('ADAPT')
a = utils.homogenous_array(exp_data['adapt_rs'])
am, astd = np.mean(a, axis=0), np.std(a, axis=0)
aauc = trapz(am)
print('Total:', aauc)
for comp in exp_data['adapt_comp']:
    c = utils.homogenous_array(exp_data['adapt_comp'][comp])
    cm = np.mean(c, axis=0)
    cauc = trapz(cm)
    print(comp, cauc)

print('\nVANILLA')
v = utils.homogenous_array(exp_data['vanilla_rs'])
vm = np.mean(v, axis=0)
vauc = trapz(vm)
print('Total:', vauc)
for comp in exp_data['vanilla_comp']:
    c = utils.homogenous_array(exp_data['vanilla_comp'][comp])
    cm = np.mean(c, axis=0)
    cauc = trapz(cm)
    print(comp, cauc)

print('\nAVERAGE')
# Pool the runs of all non-selected library policies over every split.
x = [run for split_runs in exp_data['other_rs'] for run in split_runs]
x = utils.homogenous_array(x)
xm = np.mean(x, axis=0)
xauc = trapz(xm)
print('Total:', xauc)

In [None]:
# plotting params
# Rolling-mean window: half the square root of the number of entries in
# nominal_r, but never below 1. A window of 0 would make nominal_span 0, and
# `nominal_r[-0:]` is the WHOLE history while the matching x range is empty.
rolling_mean = max(1, int(np.sqrt(len(nominal_r)) / 2))
# Number of trailing nominal entries drawn before the fault marker at x=0.
nominal_span = rolling_mean * 2

In [None]:
# Total Reward performance per split
# One panel per cross-validation split: tail of the nominal run (blue, x <= 0),
# fault marker at x=0, adapted agent (green), vanilla agent (blue) and the
# remaining library policies (dotted green).
rows = len(exp_data['vanilla_rs'])
fig, axes = plt.subplots(rows, 1, figsize=(6, 3 * rows), squeeze=False)
axes = np.ravel(axes)
per_split = zip(axes,
                exp_data['vanilla_rs'],
                exp_data['adapt_rs'],
                exp_data['other_rs'],
                exp_data['param_idxs'],
                exp_data['fault_idxs'])
for ax, vr, ar, xr, pidx, fidx in per_split:
    ax.plot(np.arange(-nominal_span+1, 1), nominal_r[-nominal_span:], c='b')
    ax.axvline(0, c='r', ls='-.')
    for series, colour in ((ar, 'g'), (vr, 'b')):
        ax.plot(utils.rollmean(series, rolling_mean), c=colour)
    for r in xr:
        ax.plot(utils.rollmean(r, rolling_mean), c='g', ls=':')
    ax.set_ylabel('Reward / episode')
    ax.set_xlabel('Episode')
fig.tight_layout()
fig.savefig(datadir / 'plots' / exp_name)

In [None]:
# Average reward performance over all splits
# v / a / x: reward histories of the vanilla agent, the adapted agent and the
# pooled non-selected library policies, stacked with utils.homogenous_array.
v = utils.homogenous_array(exp_data['vanilla_rs'])
x = []
for x_ in exp_data['other_rs']:
    x.extend(x_)
x = utils.homogenous_array(x)
a = utils.homogenous_array(exp_data['adapt_rs'])
# Smoothed standard deviation and mean over axis 0 for each group.
vstd = utils.rollmean(np.std(v, axis=0), rolling_mean)
xstd = utils.rollmean(np.std(x, axis=0), rolling_mean)
astd = utils.rollmean(np.std(a, axis=0), rolling_mean)
vm = utils.rollmean(np.mean(v, axis=0), rolling_mean)
xm = utils.rollmean(np.mean(x, axis=0), rolling_mean)
am = utils.rollmean(np.mean(a, axis=0), rolling_mean)
# Tail of the nominal run before the fault (x <= 0); it carries the
# 'Vanilla RL' label because the vanilla curve continues it after x = 0.
plt.plot(np.arange(-nominal_span+1, 1), nominal_r[-nominal_span:], c='b', label='Vanilla RL')
plt.gca().axvline(0, c='r', ls='-.', label='Fault')
plt.plot(vm, c='b')
plt.plot(xm, c='lime', ls='-.', label='Random average')
plt.plot(am, c='g', label='Ours')
# Shaded bands: mean +/- one standard deviation.
plt.fill_between(np.arange(len(vm)), vm-vstd, vm+vstd, alpha=0.3, color='b')
plt.fill_between(np.arange(len(am)), am-astd, am+astd, alpha=0.3, color='g')
plt.fill_between(np.arange(len(xm)), xm-xstd, xm+xstd, alpha=0.3, color='lime')
plt.ylabel('Reward / episode')
plt.xlabel('Episode')
# NOTE(review): labels are set above but plt.legend() is never called in this
# cell, so the saved figure has no legend - confirm whether that is intended.
plt.tight_layout()
plt.savefig(datadir / 'plots' / (exp_name + '-agg'))

In [None]:
# Average component performance over all splits
# Primary axis: mean +/- std of the total reward (vanilla vs. adapted).
# Twin axis: mean of every reward component, dotted for vanilla and solid for
# the adapted agent, using the same colour per component.
vanilla_comp = {key: utils.homogenous_array(val) for key, val in exp_data['vanilla_comp'].items()}
adapt_comp = {key: utils.homogenous_array(val) for key, val in exp_data['adapt_comp'].items()}

v = utils.homogenous_array(exp_data['vanilla_rs'])
a = utils.homogenous_array(exp_data['adapt_rs'])
vstd = utils.rollmean(np.std(v, axis=0), rolling_mean)
astd = utils.rollmean(np.std(a, axis=0), rolling_mean)
vm = utils.rollmean(np.mean(v, axis=0), rolling_mean)
am = utils.rollmean(np.mean(a, axis=0), rolling_mean)

plt.plot(np.arange(-nominal_span+1, 1), nominal_r[-nominal_span:], c='b', label='Vanilla RL')
plt.gca().axvline(0, c='r', ls='-.', label='Fault')
plt.plot(vm, c='b')
plt.plot(am, c='g', label='Ours')
plt.fill_between(np.arange(len(vm)), vm-vstd, vm+vstd, alpha=0.3, color='b')
plt.fill_between(np.arange(len(am)), am-astd, am+astd, alpha=0.3, color='g')
plt.ylabel('Reward / episode')
plt.xlabel('Episode')

plt.twinx()
handles = []
# The loop variable is named `linestyle`, not `ls`: `ls` is the notebook-level
# list of line styles from the configuration cell and must not be overwritten.
for comps, linestyle in zip((vanilla_comp, adapt_comp), (':', '-')):
    for key, val in comps.items():
        handles.append(plt.plot(utils.rollmean(np.mean(val, axis=0), rolling_mean), label=key, ls=linestyle)[0])
    # restart the colour cycle so both agents use the same colour per component
    plt.gca().set_prop_cycle(None)
# One legend entry per component (first half of the handles = vanilla curves).
plt.legend(handles=handles[:len(handles)//2])
plt.tight_layout()
plt.savefig(datadir / 'plots' / (exp_name + '-comp-agg'))

### Experiment - Policy distances

In [None]:
# plotting policies in relation to each other using Multi-dimensional scaling
# For every metric: fill the pairwise distance matrix between all policies
# (meta.distance returns both directions), scale it by its maximum and embed
# it with MDS. Results are stored as metrics[name] = {distmat, coords}.
memory = ppo.Memory()
agent.experience(memory, update_interval, agent.env, agent.policy)

metric_names = ('jensenshannon', 'euclidean', 'l1', 'cosine')
metrics = {}
for metric_name in metric_names:
    distmat = np.zeros((len(policies), len(policies)))
    for p1, p2 in combinations(range(len(policies)), 2):
        forward, backward = meta.distance(memory, agent.policy,
                                          policies[p1], policies[p2],
                                          metric=metric_name)
        distmat[p1, p2], distmat[p2, p1] = forward, backward
    distmat /= np.max(distmat)
    proj = manifold.MDS(dissimilarity='precomputed', random_state=model_seed)
    coords = proj.fit_transform(distmat)
    metrics[metric_name] = dict(distmat=distmat, coords=coords)

In [None]:
# 2x2 grid, one MDS projection per distance metric; points are styled by fault
# family and annotated with their policy index.
plt.figure(figsize=(12,12))
for i, (metric_name, metric) in enumerate(metrics.items()):
    ax = plt.subplot(2, 2, i+1)
    coords = metric['coords']
    for mode, idx in mode_indices.items():
        ax.scatter(coords[idx, 0], coords[idx, 1], label=mode, **scatter_fmt[mode])
    ax.set_aspect('equal', 'datalim')
    ax.set_xlim(-0.6, 0.6)
    ax.set_ylim(-0.6, 0.6)
    ax.grid(True)
    # separate name for the inner index so the panel counter `i` is not shadowed
    for policy_no in range(len(policies)):
        ax.annotate(str(policy_no), coords[policy_no], coords[policy_no]+3, textcoords='offset points')
    ax.set_title('Projection - ' + metric_name)
# the legend is drawn on the last panel only
ax.legend()
plt.savefig(datadir / 'plots' / 'mds-metrics', pad_inches=0.1, bbox_inches='tight')

In [None]:
# 2x2 grid of distance matrices, one per metric, with lines marking the
# boundaries of every fault family and one shared colorbar.
# Keep a handle on THIS figure: `fig` was never assigned here, so the colorbar
# and tight_layout below acted on a stale figure left over from an earlier cell.
fig = plt.figure(figsize=(12, 12))
axes = []
enum_faults = np.arange(len(faults))
for i, (metric_name, metric) in enumerate(metrics.items()):
    ax = plt.subplot(2, 2, i+1)
    axes.append(ax)
    distmat = metric['distmat']
    im = plt.matshow(distmat, 0)  # fignum=0: draw into the current figure/axes
#     plt.colorbar()
    for mode, idx in mode_indices.items():
        fidx = enum_faults[idx]
        start, end = fidx[0], fidx[-1]
        plt.axvline(start - 0.5, c=scatter_fmt[mode]['c'])
        plt.axvline(end + 0.5, c=scatter_fmt[mode]['c'])
        plt.axhline(start - 0.5, c=scatter_fmt[mode]['c'])
        plt.axhline(end + 0.5, c=scatter_fmt[mode]['c'])
    plt.title(metric_name)
# plt.legend()
# Shared colorbar; uses the image of the last panel (every matrix is scaled to
# a maximum of 1 when it is computed).
fig.colorbar(im, ax=axes)
fig.tight_layout()
plt.savefig(datadir / 'plots' / 'distmats-metrics', pad_inches=0, bbox_inches='tight')

### Experiment contd. - Pruning and policy space

* Progressively pruning and which policies are kept

In [None]:
exp_name = now() + 'pruned-space'
# Metric whose distance matrix / projection is stored with this experiment.
# The cell used to pick up whatever `distmat` and `coords` were left over from
# the loops of earlier cells, i.e. the last metric, 'cosine'; this makes the
# choice explicit.
# NOTE(review): confirm that 'cosine' is the intended projection.
projection_metric = 'cosine'
n_library = len(policies[1:])  # library size without the nominal policy
exp_data = dict(
    prunes = [n_library, n_library // 2 , n_library // 3, n_library // 4],
    distmat = metrics[projection_metric]['distmat'],
    coords = metrics[projection_metric]['coords'],
    domain_idx = domain_idx,
    keeps = [],
    env_params = env_params
)

# Prune the fault-policy library to progressively smaller sizes and record
# which policies survive. The loop variable is not called `prune`, so the
# benchmark's `prune` setting is left untouched.
for max_size in tqdm(exp_data['prunes'], leave=False):
    memory = ppo.Memory()
    agent.experience(memory, update_interval, agent.env, agent.policy)
    _, _, keep = meta.prune_library(policies[1:], agent, memory, max_library_size=max_size, seed=model_seed)
    exp_data['keeps'].append([k+1 for k in keep])  # convert policies[1:] indices to policies indices

keeps = exp_data['keeps']
prunes = exp_data['prunes']
utils.write_pickle(exp_data, datadir / exp_name)

In [None]:
# Two-column grid with one panel per pruning level, showing which policies
# survive in the MDS projection stored with the experiment.
rows = (len(exp_data['prunes']) + 1) // 2
fig, axes = plt.subplots(rows, 2, figsize=(8, 4 * rows), squeeze=False)
coords = exp_data['coords']
axes = np.ravel(axes)

# midx: for every fault family, the list of policy indices that belong to it.
indices = list(np.arange(len(faults)))
midx = [indices[idx] for idx in mode_indices.values()]
handles = []
for i, (ax, keep) in enumerate(zip(axes, exp_data['keeps'])):
    # The nominal policy (row 0 of coords) is always drawn.
    h = ax.scatter(coords[:1,0], coords[:1, 1], s=60, c='black', label='Nominal')
    if i==0: handles.append(h)
    for mode, idx in zip(mode_indices.keys(), midx):
        if mode=='nominal': continue
        # policies of this family that survived pruning
        mode_keep = list(set(idx).intersection(set(keep)))
        kept_coords = coords[np.ix_(mode_keep)]
        if len(kept_coords):
            h = ax.scatter(kept_coords[:,0], kept_coords[:, 1], label=mode, **scatter_fmt[mode])
            # legend handles are collected from the first panel only
            if i==0: handles.append(h)

    ax.set_aspect('equal', 'datalim')
    # reuse the limits of the first panel so all panels are comparable
    ax.set_xlim(axes[0].get_xlim())
    ax.set_ylim(axes[0].get_ylim())
    ax.set(xticks=[], yticks=[])
    ax.set_title('No pruning' if i==0 else ('%d policies' % len(keep)))

# plt.legend acts on the current axes, i.e. the last panel created.
plt.legend(handles=handles, fontsize='medium')
plt.tight_layout()
plt.savefig(datadir / 'plots' / 'mds-prunes')

### Experiment contd. - Policy Sampling

* Finding best policy for each fault

In [None]:
exp_name = now() + 'policy-sampling'
exp_data = dict(
    fault_idx = [],
    selected_idx = [],
    env_params = env_params,
    faults = faults,
    prunes = prunes,
    keeps = keeps,
    coords = coords
)
# NOTE(review): unlike the other experiment cells, exp_data is never written
# to disk here although exp_name is generated - confirm whether it should be.

# For every pruning level: hold out each kept fault in turn, let adapt() pick a
# policy from the nominal policy plus the remaining kept policies, and record
# (held-out fault, selected policy) as indices into `policies`.
# Loop-local names are used (kept_*, *_pos, sample_library) so the benchmark's
# domain_faults / domain_policies / train_idx / test_idx / library are not
# overwritten.
for keep in tqdm(keeps, leave=False):
    kept_faults = [faults[k] for k in keep]    # test for all faults (excluding nominal idx==0)
    kept_policies = [policies[k] for k in keep]
    sample_iterator = model_selection.LeavePOut(1).split(kept_faults)
    fidx, sidx = [], []
    for (train_pos, test_pos) in tqdm(sample_iterator, leave=False, total=len(kept_faults)):
        sample_library = [kept_policies[pos] for pos in train_pos]
        agent_ = make_model(model_seed)
        env_ = get_task(model_seed, params=kept_faults[test_pos[0]])
        _, sample_idx = adapt(agent_, env_, [nominal_params] + sample_library, update_interval)
        if sample_idx == 0:
            # The prepended nominal policy was selected: it is policies[0].
            # `sample_idx - 1` would be -1 and silently pick the LAST training
            # policy instead.
            selected = 0
        else:
            # position in sample_library -> position in keep -> index in policies
            selected = keep[train_pos[sample_idx - 1]]
        fidx.append(keep[test_pos[0]])  # convert from kept_faults to faults indexing
        sidx.append(selected)
    exp_data['fault_idx'].append(fidx)
    exp_data['selected_idx'].append(sidx)

In [None]:
# Two-column grid with one panel per pruning level: kept policies in the MDS
# projection, plus an arrow from every held-out fault to the policy that was
# selected for it.
rows = (len(exp_data['prunes']) + 1) // 2
fig, axes = plt.subplots(rows, 2, figsize=(8, 4 * rows), squeeze=False)
coords = exp_data['coords']
axes = np.ravel(axes)

# midx: for every fault family, the list of policy indices that belong to it.
indices = list(np.arange(len(faults)))
midx = [indices[idx] for idx in mode_indices.values()]
handles = []
for i, (ax, keep, fidx, sidx) in enumerate(zip(axes, exp_data['keeps'], exp_data['fault_idx'], exp_data['selected_idx'])):
    # Plot projections
    h = ax.scatter(coords[:1,0], coords[:1, 1], s=60, c='black', label='Nominal')
    if i==0: handles.append(h)
    for mode, idx in zip(mode_indices.keys(), midx):
        if mode=='nominal': continue
        # policies of this family that survived pruning
        mode_keep = list(set(idx).intersection(set(keep)))
        kept_coords = coords[np.ix_(mode_keep)]
        if len(kept_coords):
            h = ax.scatter(kept_coords[:,0], kept_coords[:, 1], label=mode, **scatter_fmt[mode])
            # legend handles are collected from the first panel only
            if i==0: handles.append(h)
    # Plot arrows between projections
    # fi: held-out fault, si: policy selected for it (both index into coords)
    for fi, si in zip(fidx, sidx):
        dx, dy = coords[si] - coords[fi]
        ax.arrow(coords[fi, 0], coords[fi, 1],  dx, dy, length_includes_head=True, width=1e-4, head_width=10e-4, ls=':')
    
    # label every kept policy with its index
    for no in keep:
        ax.annotate(str(no), coords[no], coords[no]+3, textcoords='offset points')

    ax.set_aspect('equal', 'datalim')
    # reuse the limits of the first panel so all panels are comparable
    ax.set_xlim(axes[0].get_xlim())
    ax.set_ylim(axes[0].get_ylim())
    ax.set(xticks=[], yticks=[])
    ax.set_title('No pruning' if i==0 else ('%d policies' % len(keep)))
plt.tight_layout()
# plt.legend acts on the current axes, i.e. the last panel created.
plt.legend(handles=handles, fontsize='medium')
plt.savefig(datadir / 'plots' / 'mds-sampling')