In [1]:
import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from matplotlib.lines import Line2D


In [4]:
# Directory holding the "{prof}_{data}_analysis_output_diff.csv" inputs; the
# generated plots are written back into the same directory. Set the
# TRENDS_DATA_PATH environment variable to override it without editing the notebook.
path = os.environ.get('TRENDS_DATA_PATH', '<path to trends data>')


In [5]:
def get_assistant_name(name: str):
    """Return the assistant architecture family encoded in a run name.

    Args:
        name: run identifier; the assistant part follows a ``"_assistant_"`` marker.

    Returns:
        ``'None'`` when the name has no ``"_assistant_"`` marker. Otherwise, the
        upper-cased first tag among mlp, lstm, rnn, gru, sac, ddpg (checked in
        that order) that occurs in the text between the first marker and the
        next marker (or the end of the name); ``'unknown'`` if no tag occurs.
    """
    segments = name.split("_assistant_")
    if len(segments) < 2:
        # No assistant marker: this is an unaided pilot run.
        return 'None'

    suffix = segments[1]
    # Substring matching, so tag order decides ties.
    matches = (tag.upper() for tag in ('mlp', 'lstm', 'rnn', 'gru', 'sac', 'ddpg') if tag in suffix)
    return next(matches, 'unknown')


def get_assistant_specialized_name(name: str):
    """Return a short upper-case label identifying the specific assistant in a run name.

    Args:
        name: run identifier; the assistant part follows a ``"_assistant_"`` marker.

    Returns:
        * ``'None'`` when the name has no ``"_assistant_"`` marker.
        * For assistant parts that do not contain ``'assistant'`` (e.g.
          ``"sac_airl_300k"``): the first two ``_``-separated tokens joined by
          ``-`` (``"SAC-AIRL"``), or only the first token when the part contains
          ``'penalizeDeflection'`` or has a single token.
        * Otherwise (e.g. ``"assistant_mars_mlp_good_small_window"``): the
          architecture token followed by a proficiency suffix
          (``all_prof`` -> ``-GMB``, ``good_med`` -> ``-GM``, ``good`` -> ``-G``,
          ``med`` -> ``-M``, ``bad`` -> ``-B``) and a window suffix
          (``small_window_future`` -> ``-0.3``, ``small_window`` -> ``-0.2``,
          window-less mlp -> ``-0``, anything else -> ``-0.5``).
    """
    parts = name.split("_assistant_")
    if len(parts) == 1:
        return 'None'
    assistant_name = parts[1]

    if 'assistant' not in assistant_name:
        tokens = assistant_name.split('_')
        # A single-token name (e.g. "sac") has no second token to append.
        if 'penalizeDeflection' in assistant_name or len(tokens) < 2:
            return tokens[0].upper()
        return f"{tokens[0]}-{tokens[1]}".upper()

    assistant_name = assistant_name.removeprefix('assistant_mars_')
    label = assistant_name.split('_', maxsplit=1)[0]

    # 'good_med' must be tested before 'good' and 'med', which it contains.
    if 'all_prof' in assistant_name:
        label += '-GMB'
    elif 'good_med' in assistant_name:
        label += '-GM'
    elif 'good' in assistant_name:
        label += '-G'
    elif 'med' in assistant_name:
        label += '-M'
    elif 'bad' in assistant_name:
        label += '-B'

    # 'small_window_future' contains 'small_window', so it must be tested first;
    # otherwise the -0.3 branch is unreachable.
    if 'small_window_future' in assistant_name:
        label += '-0.3'
    elif 'small_window' in assistant_name:
        label += '-0.2'
    elif 'mlp' in assistant_name and 'window' not in assistant_name:
        label += '-0'
    else:
        label += '-0.5'

    return label.upper()


In [None]:

# Pilot proficiency levels and pilot data sources; each (prof, data) pair selects
# one "{prof}_{data}_analysis_output_diff.csv" file under `path`.
profs = ['good', 'med', 'bad']
datas = ['vip', 'mars']

# Factor applied to 'Deflection magnitude mean' before plotting; the legend
# label below is derived from it so the two cannot drift apart.
DEFLECTION_SCALE = 30

# Metric column in the CSV -> legend label. The position of each entry also
# selects its marker shape from Line2D.filled_markers.
columns = {
 '# crashes': 'Crashes',
 '% destabilizing actions': '% destab.',
 'Distance from DOB mean': 'μ|θ|',
 'Angular position SD': 'σ(θ)',
 'Velocity magnitude mean': 'μ|Mag|vel',
 'Angular velocity SD': 'σ(|Mag|vel)',
 'Velocity RMS': 'vel RMS',
 'Deflection magnitude mean': f'μ|d| * {DEFLECTION_SCALE}'
}

# x-axis order of assistant architecture families ('None' = no assistant).
assistant_order = ['None', 'MLP', 'RNN', 'LSTM', 'GRU', 'DDPG', 'SAC']

# Runs are ranked lexicographically on these metrics: ascending sort = "best",
# descending sort = "worst".
ranking_metrics = ['# crashes', '% destabilizing actions', 'Distance from DOB mean']

best_assistant_per_pilot = []
worst_assistant_per_pilot = []

plt.rcParams.update({'font.size': 32})


def _with_type_order(frame: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `frame` with an 'order' column, sorted by `assistant_order`.

    Types missing from `assistant_order` (e.g. 'unknown') are placed last
    instead of raising ValueError.
    """
    rank = {typ: i for i, typ in enumerate(assistant_order)}
    order = frame['assistant_type'].map(rank).fillna(len(rank))
    # assign() builds a new frame, so no SettingWithCopyWarning on filtered input.
    return frame.assign(order=order).sort_values('order')


def _plot_metric_lines(frame: pd.DataFrame, x_col: str, xlabel: str, out_path: str,
                       xtick_rotation: float = 0) -> None:
    """Plot every metric in `columns` against `frame[x_col]` and save it to `out_path`.

    The figure is closed after saving so figures do not pile up across the
    (prof, data) loop; the saved PNG files are the output.
    """
    fig, ax = plt.subplots(figsize=(20, 10), dpi=300)
    for marker_idx, (col, label) in enumerate(columns.items()):
        ax.plot(frame[x_col], frame[col], label=label, markersize=20,
                linestyle='-.', marker=Line2D.filled_markers[marker_idx])

    ax.set_xlabel(xlabel)
    ax.set_ylabel('Difference from unaided pilot')
    if xtick_rotation:
        ax.tick_params(axis='x', labelrotation=xtick_rotation)

    fig.tight_layout()
    # Legend anchored to the right edge of the axes; bbox_inches='tight' keeps it in the file.
    ax.legend(loc='best', bbox_to_anchor=(1, 0.5, 0, 0.5))
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


for prof in profs:
    for data in datas:
        prefix = f"{path}/{prof}_{data}_analysis"

        # Per-run metric differences (one row per run, identified by 'Name').
        diff_data = pd.read_csv(f"{prefix}_output_diff.csv")
        diff_data['Deflection magnitude mean'] = diff_data['Deflection magnitude mean'] * DEFLECTION_SCALE
        diff_data['assistant_type'] = diff_data['Name'].apply(get_assistant_name)
        diff_data['specialized_type'] = diff_data['Name'].apply(get_assistant_specialized_name)

        # Metrics averaged over all runs of each architecture family.
        averaged_data_by_type = _with_type_order(
            diff_data.groupby('assistant_type').mean(numeric_only=True).reset_index()
        )
        _plot_metric_lines(averaged_data_by_type, 'assistant_type', 'Assistant type',
                           f"{prefix}_diff_average_type_plot.png")

        # Best run per family: first row of each type after an ascending sort.
        sorted_best = diff_data.sort_values(ranking_metrics)
        best_assistant_per_pilot.append(sorted_best['Name'].iloc[0])
        best_assistants_filtered = _with_type_order(sorted_best.drop_duplicates('assistant_type'))
        _plot_metric_lines(best_assistants_filtered, 'specialized_type', 'Assistant',
                           f"{prefix}_diff_best_type_plot.png", xtick_rotation=20)

        # Worst run per family: first row of each type after a descending sort.
        sorted_worst = diff_data.sort_values(ranking_metrics, ascending=False)
        # The two worst-ranked runs overall (a shorter tuple if the file has one row).
        worst_assistant_per_pilot.append(tuple(sorted_worst['Name'].iloc[:2]))
        worst_assistants_filtered = _with_type_order(sorted_worst.drop_duplicates('assistant_type'))
        _plot_metric_lines(worst_assistants_filtered, 'specialized_type', 'Assistant',
                           f"{prefix}_diff_worst_type_plot.png", xtick_rotation=20)

In [26]:
# Name of the top-ranked run for each (prof, data) pair, in loop order
# (ascending sort on '# crashes', '% destabilizing actions', 'Distance from DOB mean').
best_assistant_per_pilot

['pilot_vip_lstm_good_small_window_assistant_sac_airl_300k',
 'pilot_mars_lstm_good_small_window_assistant_assistant_mars_mlp_all_prof',
 'pilot_vip_gru_med_small_window_future_assistant_sac_airl_300k',
 'pilot_mars_gru_med_small_window_future_assistant_sac_airl_300k',
 'pilot_vip_mlp_bad_window_assistant_sac_airl_300k',
 'pilot_mars_mlp_bad_window_assistant_sac_penalizeDeflection']

In [27]:
# (worst, second-worst) run names for each (prof, data) pair, in loop order
# (descending sort on '# crashes', '% destabilizing actions', 'Distance from DOB mean').
worst_assistant_per_pilot

[('pilot_vip_lstm_good_small_window_assistant_assistant_mars_mlp_good_window',
  'pilot_vip_lstm_good_small_window_assistant_assistant_mars_lstm_all_prof'),
 ('pilot_mars_lstm_good_small_window_assistant_assistant_mars_rnn_good_med',
  'pilot_mars_lstm_good_small_window_assistant_assistant_mars_lstm_all_prof'),
 ('pilot_vip_gru_med_small_window_future',
  'pilot_vip_gru_med_small_window_future_assistant_assistant_mars_rnn_all_prof'),
 ('pilot_mars_gru_med_small_window_future_assistant_assistant_mars_mlp_good_small_window_future',
  'pilot_mars_gru_med_small_window_future_assistant_assistant_mars_lstm_good_small_window'),
 ('pilot_vip_mlp_bad_window_assistant_assistant_mars_mlp_good_window',
  'pilot_vip_mlp_bad_window_assistant_assistant_mars_gru_good_small_window'),
 ('pilot_mars_mlp_bad_window',
  'pilot_mars_mlp_bad_window_assistant_assistant_mars_rnn_good')]