In [None]:
import warnings
import matplotlib
# Suppress matplotlib's deprecation warnings so they do not clutter cell outputs.
warnings.filterwarnings("ignore", category=matplotlib.MatplotlibDeprecationWarning)

In [None]:
# %matplotlib widget

import os

import numpy as np
import matplotlib.pyplot as plt
# import pandas as pd
import matplotlib.gridspec as gridspec
from matplotlib.ticker import FormatStrFormatter
# NOTE: the order of the following lines is kept as in the original notebook,
# because the star imports may shadow (or be shadowed by) neighbouring names.
from src.microcircuit import *
from src.save_exp import *
# Explicitly bind the package name `src`, which is used as `src.save_exp.load(...)`.
import src.save_exp
import sys
from scipy.stats import sem

# Apply the project's matplotlib style sheet (path is relative to the working directory).
plt.style.use('./matplotlib_style.mplstyle')
cm_to_inch = 1/2.54  # centimeters in inches
pagewidth = 17 # width in cm

# Load the autoreload extension. Mode 1 reloads only the modules registered with
# %aimport (here: src.microcircuit) before each cell is executed.
%reload_ext autoreload
%autoreload 1
%aimport src.microcircuit

In [None]:
def moving_average(x, w, n_repeat=None):
    """Sample-and-hold `x`, then smooth it with a length-`w` moving average.

    Every `n_repeat`-th entry of `x` is kept and repeated `n_repeat` times
    (so the held signal is piecewise constant), and the result is averaged
    over a sliding window of `w` entries.

    Parameters
    ----------
    x : 1-D numpy array
        Input signal (np.convolve only accepts one-dimensional input).
    w : int
        Width of the averaging window.
    n_repeat : int, optional
        Hold length of the sample-and-hold step. Defaults to `w`.

    Returns
    -------
    numpy array
        The 'valid' part of the moving average, i.e. only positions where the
        window fully overlaps the held signal.
    """
    # Identity comparison is the correct test for None.
    if n_repeat is None:
        n_repeat = w
    held = np.repeat(x[::n_repeat], n_repeat, axis=0)
    return np.convolve(held, np.ones(w), 'valid') / w

In [None]:
# Prefix that is put in front of every experiment path.
MAIN_PATH = ''

# One experiment directory per model variant; the list order determines the
# index into network_MC_list further down.
PATHS = [
    'experiments/Fig2a_cartpole_ideal_emc/',
    'experiments/Fig2a_cartpole_tanh_emc/',
]
# Sub-directories (relative to each entry of PATHS) that hold a model.pkl.
PATH_APPENDIX = ['']
NAME = ''
# Directory the figures are written to.
OUTPUT_PATH = 'experiments/cartpole_plots/'

### Plot

In [None]:
# Load the saved model list of every experiment.
# network_MC_list[k] holds whatever src.save_exp.load returns for the k-th
# (PATH, PATH_APP) combination, in iteration order.
# NOTE(review): model.pkl is presumably a pickle -- only load files from trusted sources.
network_MC_list = []
for PATH in PATHS:
    for PATH_APP in PATH_APPENDIX:
        # os.path.join avoids doubled separators and does not depend on
        # MAIN_PATH / PATH ending with a trailing slash.
        model_file = os.path.join(MAIN_PATH, PATH, PATH_APP, 'model.pkl')
        MC_list = src.save_exp.load(model_file)
        network_MC_list.append(MC_list)

In [None]:
# define the number of recorded time steps which belong to the pre-training
# and therefore should be skipped in plotting
# All derived constants are taken from the first model of the most recently
# loaded experiment (MC_list is left over from the loading loop).
_ref_mc = MC_list[0]

# Number of recorded time steps that belong to the pre-training (settling)
# phase and should therefore be skipped when plotting.
TPRE = int(_ref_mc.settling_time / _ref_mc.dt / _ref_mc.rec_per_steps)

linestyles = ['dotted', 'dashdot', 'dashed', 'solid']

# presumably the number of recorded steps per pass over the dataset -- TODO confirm
LABEL_MULTIPLIER = int(1/ _ref_mc.rec_per_steps * _ref_mc.Tpres * _ref_mc.dataset_size / _ref_mc.dt)

In [None]:
# Duration used as one "epoch" when binning the penalty log (in msecs);
# T / dt gives the number of time steps per epoch.
T = 200
# Number of sub-simulations per epoch.
n_subsims = 10

In [None]:
# for mc in network_MC_list[0][7:9]:
#     for i in range(2):
#         plt.figure()
#         data = [arr[i] for arr in mc.WPP_time_series]
#         data = np.array(data)
#         if i == 0:
#             plt.plot(data.reshape(-1,16))
#         elif i == 1:
#             plt.plot(data.reshape(-1,4))
#         plt.ylim(-1.5,1.5)
#         plt.show()

In [None]:
# Limits for reset: cart position bound and pole angle bound.
x_lim = 2.4
# 12 degrees expressed in radians.
theta_lim = 12 * 2 * np.pi / 360

## Penalty log

In [None]:
def calc_val_penalty(MC_list, epoch_duration=None):
    """
        Calculates the penalty per epoch from the penalty log.
        Penalty log: 1 every time the pendulum falls, else 0

        Parameters
        ----------
        MC_list : sequence of models
            One model per seed. Each must provide `penalty_log` (all of equal
            length) and `dt`; `dt` is read from the first model only.
        epoch_duration : number, optional
            Duration of one epoch in the same time unit as `dt`.
            Defaults to the notebook-level constant `T`.

        Returns
        -------
        numpy array of shape (n_seeds, n_epochs)
            Summed penalty of every epoch, with the first entry of each epoch
            excluded.

        Raises
        ------
        ValueError
            (from numpy) if the log length is not a multiple of the number of
            log entries per epoch.
    """
    if epoch_duration is None:
        epoch_duration = T

    # np.array builds a fresh array, so zeroing entries below does not touch
    # the logs stored on the models. Format: seeds x time
    penalty_logs = np.array([mc.penalty_log for mc in MC_list])
    n_seeds = len(penalty_logs)

    steps_per_epoch = int(epoch_duration / MC_list[0].dt)

    # Split every log into epochs; each epoch is taken to contain
    # (steps_per_epoch - 1) log entries.
    penalty_per_step = penalty_logs.reshape(n_seeds, -1, steps_per_epoch - 1)
    # Skip the first entry of every epoch, which is 1 due to the reset.
    penalty_per_step[:, :, 0] = 0

    # Sum over the penalties within each epoch.
    return np.sum(penalty_per_step, axis=2)

In [None]:
# network_MC_list follows the order of PATHS: index 0 -> ideal EMC runs,
# index 1 -> tanh EMC runs.
ideal_emc_runs = network_MC_list[0]
tanh_runs = network_MC_list[1]

# The first entry of the ideal-EMC list is evaluated separately as the
# teacher; all remaining entries of each list are the trained networks.
val_penalty_teacher = calc_val_penalty(MC_list=ideal_emc_runs[:1])
val_penalty_ideal_emc = calc_val_penalty(MC_list=ideal_emc_runs[1:])
val_penalty_tanh = calc_val_penalty(MC_list=tanh_runs[1:])

In [None]:
# Box plot of the penalty per epoch for both model variants, with the median
# penalty of the teacher drawn as a horizontal reference line.
fig, ax = plt.subplots(figsize=(pagewidth*cm_to_inch/2.2, pagewidth*cm_to_inch/4))

ax.set_ylabel("penalty \n per epoch")
ax.set_xlabel("trained epochs")
colors = ['C0', 'C1']
# '\\sigma' keeps the backslash for mathtext without relying on the invalid
# escape sequence '\s'; '\n' is an intended line break inside the label.
model_labels = ['error MC \n linear, $\\sigma_{L} = 0.0$',
                'error MC \n tanh, $\\sigma_{L} = 0.3$']

# Horizontal shift of the two box groups around each epoch tick, and box width.
xoffset = 0.17
width = 0.25

# One tick per epoch, counted from 1.
xpos = np.arange(1, len(val_penalty_ideal_emc[0])+1, 1)

flierprops = dict(marker='x', markerfacecolor='None', markersize=3,  markeredgecolor='black')

# Draw one group of boxes per model variant (columns = epochs, rows = seeds),
# shifted left/right of the tick and filled with the variant's color.
bplots = []
for values, shift, facecolor in zip((val_penalty_ideal_emc, val_penalty_tanh),
                                    (-xoffset, xoffset),
                                    colors):
    bplot = ax.boxplot(values, patch_artist=True, flierprops=flierprops,
                       positions=xpos + shift, widths=width)
    for patch in bplot['boxes']:
        patch.set_facecolor(facecolor)
    for median_line in bplot['medians']:
        median_line.set_color('black')
    bplots.append(bplot)
bplot1, bplot2 = bplots

# boxplot puts ticks at the shifted positions; reset them to the epoch index.
ax.set_xticks(xpos)
ax.set_xticklabels(xpos)

# plot labels once
for lab, c in zip(model_labels, colors):
    # Legend-only proxy: a zero-width bar is invisible but yields a colored
    # legend handle (presumably a workaround for the matplotlib issue below).
    # https://github.com/matplotlib/matplotlib/issues/21506
    ax.bar([1],[1], label=lab, color=c, width=0)

# Reference line: median penalty of the teacher over all of its epochs.
teacher_median = np.median(val_penalty_teacher)
ax.axhline(teacher_median, c='black', label='LQR')

legend = ax.legend(prop={'size': 8}, loc='upper center', ncol=4, handlelength=1, bbox_to_anchor=(0.4, 1.55), columnspacing=1, handletextpad=0.6)

ax.set_yscale('log')
ax.minorticks_on()
# Minor ticks are only wanted on the (logarithmic) y axis.
ax.tick_params(axis='x', which='minor', bottom=False)
ax.set_yticks([1, 10, 100, 1000])
ax.set_ylim([0.5, 2000])
fig.tight_layout()
# Manual margins applied after tight_layout, so these values are the final ones.
fig.subplots_adjust(left = 0.21, top = 0.76, right = 0.92, bottom = 0.28, hspace = 0.2, wspace = 0.0)

# Make sure the output directory exists before writing the figure.
if OUTPUT_PATH:
    os.makedirs(OUTPUT_PATH, exist_ok=True)
fig.savefig(OUTPUT_PATH + "val_penalty.pdf")
plt.show()

## Example trajectories

We pick a network and show its trajectories before and after training.

In [None]:
# Axis/legend labels for the four columns of state_log, in column order.
state_names = [
    r"$x$",
    r"$\dot{x}$",
    r"$\theta$ [rad]",
    r"$\dot{\theta}$",
]

In [None]:
# for mc in network_MC_list[0][:1]:
# Quick-look figure per network of the first experiment (the first list entry
# is skipped): cart position and pole angle together with the action log,
# zoomed to the last T / dt logged steps.
for mc in network_MC_list[0][1:]:
    print(mc.state_log.shape)

    fig, axis = plt.subplots()

    # columns 0 and 2 of state_log: x and theta
    for state_idx in (0, 2):
        axis.plot(mc.state_log[:, state_idx], label=state_names[state_idx])
    axis.plot(mc.action_log, label="reset", alpha=0.5, ls="--")

    axis.set_ylim(-1.5, 1.5)
    axis.set_title("Cartpole dynamics with error-mc controller")
    file_name = 'cartpole_dynamics.png'

    # only last epoch
    n_logged = len(mc.state_log)
    axis.set_xlim(n_logged-T/mc.dt, n_logged)
    plt.show()

In [None]:
# # fig, ax = plt.subplots(nrows=1, figsize=(pagewidth*cm_to_inch/2.2, pagewidth*cm_to_inch/3.5))
# fig, ax = plt.subplots(nrows=2, figsize=(pagewidth*cm_to_inch/2.2, pagewidth*cm_to_inch/3.5))

# mc = network_MC_list[0][-4]

# X_TIME_RANGE = 10_000

# for i in [0, 2]:
#     # first epoch
#     data = mc.state_log[:X_TIME_RANGE, i]
#     x = np.arange(len(data)) * mc.dt
#     ax[0].plot(x, data, label=state_names[i])    
    
# data = mc.action_log[:X_TIME_RANGE]
# ax[0].plot(x, data, ls = ':', c='black')

# for i in [0, 2]:
#     # first epoch
#     data = mc.state_log[int(len(mc.state_log)-X_TIME_RANGE):int(len(mc.state_log)), i]
#     x = np.arange(len(data)) * mc.dt
#     ax[1].plot(x, data, label=state_names[i])
    
# data = mc.action_log[int(len(mc.action_log)-X_TIME_RANGE):int(len(mc.action_log))]
# ax[1].plot(x, data, label="action (force)", ls = ':', c='black')
    
    
    
# ax[0].tick_params(labelbottom=False)
# # plt.plot(x, data, label="action (force)", alpha=0.7)
# for ax in ax:
#     ax.set_ylim(-0.75, 0.75)

# # plt.savefig(OUTPUT_PATH + file_name, dpi=400)
# plt.xlabel('time [ms]')
# plt.subplots_adjust(left = 0.1, top = 0.835, right = 0.93, bottom = 0.25, hspace = 0.2, wspace = 0.0)
# plt.show()

In [None]:
# fig, ax = plt.subplots(nrows=1, figsize=(pagewidth*cm_to_inch/2.2, pagewidth*cm_to_inch/3.5))
# Two-row figure for one network: top row shows cart position x (left axis)
# and pole angle (right, twinned axis); bottom row shows the action log.
fig, ax = plt.subplots(nrows=2, figsize=(pagewidth*cm_to_inch/2.1, pagewidth*cm_to_inch/3.2))

# First entry of the second experiment (PATHS[1]).
# NOTE(review): entry 0 of the first experiment is treated as the teacher when
# computing penalties -- confirm that entry 0 is the intended network here.
mc = network_MC_list[1][0]

# Length of the plotted window and its start, counted back from the end of
# the log (both in logged steps).
X_TIME_RANGE = 6_000
X_TIME_OFFSET = 18_000

# multiply x by this value
TIME_MULTIPLICATION_FACTOR = 1/10

# [start, stop) indices into the logs: starts X_TIME_OFFSET steps before the
# end and spans X_TIME_RANGE - 1 steps.
INTERVAL = [int(len(mc.state_log))-X_TIME_OFFSET,int(len(mc.state_log)+X_TIME_RANGE)-X_TIME_OFFSET-1]

# Symmetric-log y axes for both rows (the twinned theta axis created below
# keeps its default linear scale).
ax[0].set_yscale('symlog')
ax[1].set_yscale('symlog')

# plot x
data = mc.state_log[INTERVAL[0]:INTERVAL[1], 0]
x = np.arange(len(data)) * mc.dt * TIME_MULTIPLICATION_FACTOR
ax[0].plot(x, data, label=state_names[0], c="C0", zorder=10)
ax[0].set_ylabel(state_names[0], color="C0")
ax[0].set_ylim(-3.4, 3.4)
# pyplot call: acts on the current axes -- presumably the bottom row, which
# was created last; verify if the figure setup changes.
plt.xlabel('time [s]')
ax[0].set_yticks([-2,0,2])
ax[0].set_yticklabels(['$-2$','$0$','$2$'])

# Color the left axis like the x trace.
ax[0].spines['left'].set_color('C0')
ax[0].xaxis.label.set_color('C0')
ax[0].tick_params(axis='y', colors='C0')

ax2 = ax[0].twinx()  # instantiate a second Axes that shares the same x-axis

# plot theta
data = mc.state_log[INTERVAL[0]:INTERVAL[1], 2]
# convert to angle (radians -> degrees)
data = data * 360/2/np.pi
x = np.arange(len(data)) * mc.dt * TIME_MULTIPLICATION_FACTOR
ax2.plot(x, data, label=state_names[2], c="C1", zorder=-10)
ax2.set_ylim(-60, 60)
ax2.set_ylabel("$\\vartheta$ [deg]", color="C1")

# Color the right axis like the theta trace; the left spine keeps the x color.
ax2.spines['left'].set_color('C0')
ax2.spines['right'].set_color('C1')
ax2.xaxis.label.set_color('C1')
ax2.tick_params(axis='y', colors='C1')

# for tick in ax2.yaxis.get_majorticklabels():
#     tick.set_horizontalalignment("right")

# Adjust the stacking order of the two overlaid axes and drop the frame of
# the x axes so both traces remain visible.
ax[0].set_zorder(ax2.get_zorder()-1)
ax[0].set_frame_on(False)



# The top row shares its time axis with the bottom row visually, so hide its tick labels.
ax[0].tick_params(labelbottom=False)

# Bottom row: action log over the same window (re-uses x from the theta plot).
data = mc.action_log[INTERVAL[0]:INTERVAL[1]]
ax[1].plot(x, data, label="force $F$",  c='black')
ax[1].set_ylabel("force $F$")
ax[1].set_ylim(-60, 60)

# NOTE(review): tick positions (-10, -1, 1, 10) are labeled with values ten
# times larger -- presumably the logged action is scaled by 10 to obtain the
# force; confirm, otherwise the labels are wrong.
ax[1].set_yticks([-10,-1,1,10])
ax[1].set_yticklabels(['$-100$','$-10$','$10$','$100$'])

fig.align_ylabels(ax)

# Faint vertical guide lines at x = 0, 20 and 40 (scaled like the time axis).
for ax1 in ax:
    for val in [0,20 * TIME_MULTIPLICATION_FACTOR,40 * TIME_MULTIPLICATION_FACTOR]:
        ax1.axvline(val, c='black', alpha=0.1)

# lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
# lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
# fig.legend(lines, labels, prop={'size': 8}, loc='upper center', ncol=4, handlelength=0.75, bbox_to_anchor=(0.5, 1.0), columnspacing=1, handletextpad=0.6)



# NOTE(review): tight_layout() runs after subplots_adjust() and recomputes the
# margins, so the manual values below are presumably overridden -- the other
# figure in this notebook uses the opposite order. Confirm which is intended.
plt.subplots_adjust(left = 0.1, top = 0.835, right = 0.93, bottom = 0.25, hspace = 0.2, wspace = 0.0)
plt.tight_layout()
plt.savefig(OUTPUT_PATH + '/example_cartpole.pdf')
plt.show()

In [None]:
# Display the `layers` attribute of the second model of the first experiment
# (shown as the cell output because it is the last expression).
network_MC_list[0][1].layers

In [None]:
# Sanity check on the first logged steps of one network: penalty log (dotted)
# against the pole angle, with the +/- theta reset limits as horizontal lines.
x_range = [0,100]
window = slice(x_range[0], x_range[1])

mc = network_MC_list[0][1]

fig, axis = plt.subplots()
axis.plot(mc.penalty_log[window], ls=':')
# column 2 of state_log is theta; the list index keeps the result 2-D
axis.plot(mc.state_log[window, [2]])
for limit in (theta_lim, -theta_lim):
    axis.axhline(limit, c='C0')
plt.show()