### - import the updated dataset
### - load the policy from its pickle file
### - create a new algorithm that uses the imported policy
### - run algorithm.step()
### - pickle the new algorithm, the new policy and the new action dispatcher

In [12]:
# Enable IPython's autoreload extension so edits to imported modules are picked
# up without restarting the kernel. `%aimport` imports the listed modules and
# marks them for reloading; `%autoreload 2` reloads modules before running each cell.
%load_ext autoreload
%aimport os, pandas, numpy, pickle
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
import pandas as pd
# Quick look at the raw inputs: `ds` is read from dataset.csv, `rt` from ref_traj.csv
# (read in that order).
ds, rt = (pd.read_csv(csv_path) for csv_path in ('./trajectory/dataset.csv', './trajectory/ref_traj.csv'))

In [28]:
import pandas as pd
import sys
import os
import pickle
import argparse
from trlib.policies.valuebased import EpsilonGreedy, Softmax
from trlib.policies.qfunction import ZeroQ
from sklearn.ensemble.forest import ExtraTreesRegressor
from trlib.algorithms.reinforcement.fqi_driver import FQIDriver, DoubleFQIDriver
from trlib.environments.trackEnv import TrackEnv
from trlib.utilities.ActionDispatcher import *
from fqi.dataset_preprocessing import *
from fqi.fqi_evaluate import run_evaluation
"""from fqi.et_tuning import run_tuning"""
from fqi.utils import *
from fqi.reward_function import *
from fqi.sars_creator import *
# Raise the interpreter recursion limit; presumably needed when pickling deeply
# nested objects such as fitted tree ensembles — TODO confirm.
sys.setrecursionlimit(3000)

# Notebook-level copies of the reference trajectory and the driving dataset.
ref_df, data_df = map(pd.read_csv, ['./trajectory/ref_traj.csv', './trajectory/dataset.csv'])

# Report the lap index stored in the final row of the dataset.
print("There are", data_df['NLap'].tail(1).values, "laps")

There are [169.] laps


In [29]:
# Console banner printed for each reward-function name accepted by run_experiment.
_REWARD_BANNERS = {
    'speed': 'SPEED REWARD FUNCTION',
    'spatial': 'SPATIAL REWARD FUNCTION',
    'temporal': 'TEMPORAL REWARD FUNCTION',
}


def _build_reward_function(reward_function, ref_df):
    """Return the reward object to hand to ``to_SARS``.

    A string must be one of the keys of ``_REWARD_BANNERS`` (the caller checks
    this); it is turned into the matching projection object built on
    ``ref_df``. Any non-string value is returned unchanged, so a pre-built
    reward object can still be passed straight through.
    """
    if not isinstance(reward_function, str):
        return reward_function
    print(_REWARD_BANNERS[reward_function])
    if reward_function == 'speed':
        return Speed_projection(ref_df)
    if reward_function == 'spatial':
        return Spatial_projection(ref_df)
    return Temporal_projection(ref_df)


def _dump_pickle(obj, path):
    """Pickle *obj* to *path* using the highest available protocol."""
    with open(path, 'wb') as out:
        pickle.dump(obj, out, pickle.HIGHEST_PROTOCOL)


def run_experiment(track_file_name, rt_file_name, data_path, max_iterations, output_path, n_throttle,
                   n_brake, n_steer, n_jobs, output_name, reward_function, delta_t,
                   filter_actions, ad_type, tuning, kdt_norm, kdt_param, filt_a_outliers, double_fqi, evaluation):
    """Continue FQI training from a pickled policy and pickle the updated objects.

    Reads ``<data_path>/<track_file_name>.csv`` and ``<data_path>/<rt_file_name>.csv``,
    builds the SARS table with the selected reward function, loads the policy
    pickled as ``<output_path>/policy_<output_name>.pkl``, runs one ``step()``
    of a (Double)FQI driver seeded with that policy, then writes
    ``<output_name>.pkl``, ``policy_<output_name>.pkl`` (overwriting the policy
    that was loaded) and ``AD_<output_name>.pkl`` into ``output_path``.

    Args:
        track_file_name: stem of the dataset CSV in ``data_path``; also forwarded to ``run_evaluation``.
        rt_file_name: stem of the reference-trajectory CSV in ``data_path``.
        data_path: directory containing both CSV files; also forwarded to ``run_evaluation``.
        max_iterations: forwarded to the FQI driver.
        output_path: directory holding the input policy pickle and receiving the output pickles.
        n_throttle, n_brake, n_steer: forwarded to ``create_action_combinations`` (``ad_type == 'discrete'`` only).
        n_jobs: ``n_jobs`` for the ExtraTrees regressor, the action dispatcher and the evaluation.
        output_name: base name of the pickles.
        reward_function: 'speed', 'spatial' or 'temporal', or an already-built reward object.
        delta_t: unused; kept for interface compatibility.
        filter_actions: forwarded to ``create_action_combinations`` and ``run_evaluation``.
        ad_type: 'fkdt', 'rkdt' or 'discrete'; any other value disables the action dispatcher.
        tuning: unused; kept for interface compatibility.
        kdt_norm: forwarded to the driver as ``s_norm``.
        kdt_param: forwarded to the driver as ``ad_param``.
        filt_a_outliers: forwarded to the driver as ``filter_a_outliers``.
        double_fqi: use ``DoubleFQIDriver`` instead of ``FQIDriver``.
        evaluation: run ``run_evaluation`` on the saved algorithm afterwards.

    Raises:
        ValueError: if ``reward_function`` is a string other than the supported names.
    """
    # Fail fast, before any I/O: previously an unknown name fell through the
    # if/elif chain and the raw string was handed to to_SARS.
    if isinstance(reward_function, str) and reward_function not in _REWARD_BANNERS:
        raise ValueError('unknown reward_function %r; expected one of %s'
                         % (reward_function, sorted(_REWARD_BANNERS)))

    # Load the inputs named by the arguments instead of relying on notebook
    # globals, which made the file-name parameters ineffective.
    data_df = pd.read_csv(os.path.join(data_path, track_file_name + '.csv'))
    ref_df = pd.read_csv(os.path.join(data_path, rt_file_name + '.csv'))

    sars_data = to_SARS(data_df, _build_reward_function(reward_function, ref_df))
    print('SARS ready')

    # state_cols / action_cols / state_prime_cols / knn_state_cols are not defined
    # here; presumably they come from the fqi.* star imports — confirm.
    mdp = TrackEnv(len(state_cols), len(action_cols), 0.99999, 'continuous')

    # ExtraTrees regressor settings passed through the FQI driver.
    nmin = 5
    regressor_params = {'n_estimators': 100,
                        # 'mse' is the pre-1.0 scikit-learn spelling (later renamed
                        # 'squared_error'); kept to match the legacy
                        # sklearn.ensemble.forest import used by this notebook.
                        'criterion': 'mse',
                        'min_samples_split': 2,
                        'min_samples_leaf': nmin,
                        'n_jobs': n_jobs,
                        'random_state': 42}

    if ad_type == 'fkdt':
        action_dispatcher = FixedKDTActionDispatcher
        alg_actions = sars_data[action_cols].values
    elif ad_type == 'rkdt':
        action_dispatcher = RadialKDTActionDispatcher
        alg_actions = sars_data[action_cols].values
    elif ad_type == 'discrete':
        action_dispatcher = ConstantActionDispatcher
        _, alg_actions = create_action_combinations(sars_data, n_throttle, n_brake, n_steer, filter_actions)
    else:
        action_dispatcher = None
        alg_actions = None

    algorithm_name = output_name + '.pkl'
    algorithm_path = os.path.join(output_path, algorithm_name)
    policy_path = os.path.join(output_path, 'policy_' + algorithm_name)
    ad_path = os.path.join(output_path, 'AD_' + algorithm_name)

    # NOTE: pickle.load can execute arbitrary code; only load policy files you produced.
    with open(policy_path, 'rb') as pol:
        print('loading policy')
        pi = pickle.load(pol)
        print('pi:', pi)

    # Column order expected by the algorithm; state_prime_cols are the next-state columns.
    cols = ['t'] + state_cols + action_cols + ['r'] + state_prime_cols + ['absorbing']
    # Masks used by the action dispatcher: positions of the KNN state columns
    # within the state columns and within the full data matrix.
    knn_cols = set(knn_state_cols)
    state_mask = [i for i, s in enumerate(state_cols) if s in knn_cols]
    data_mask = [i for i, c in enumerate(cols) if c in knn_cols]

    fqi = DoubleFQIDriver if double_fqi else FQIDriver
    algorithm = fqi(mdp=mdp,
                    policy=pi,
                    actions=alg_actions,
                    max_iterations=max_iterations,
                    regressor_type=ExtraTreesRegressor,
                    data=sars_data[cols].values,
                    action_dispatcher=action_dispatcher,
                    state_mask=state_mask,
                    data_mask=data_mask,
                    s_norm=kdt_norm,
                    filter_a_outliers=filt_a_outliers,
                    ad_n_jobs=n_jobs,
                    ad_param=kdt_param,
                    verbose=True,
                    **regressor_params)

    algorithm.step()

    _dump_pickle(algorithm, algorithm_path)
    print('Saved algorithm object')

    _dump_pickle(algorithm._policy, policy_path)
    print('Saved policy object')

    _dump_pickle(algorithm._action_dispatcher, ad_path)
    print('Saved Action Dispatcher')

    if evaluation:
        print('*** Evaluation ***')
        run_evaluation(algorithm_path, track_file_name, data_path, n_jobs, output_path,
                       'eval_' + output_name, filter_actions, ad_path)



In [None]:
# Inputs: CSV stems and their directory; outputs go to ./model_file/.
track_file_name = 'dataset'
rt_file_name = 'ref_traj'
data_path = './trajectory/'
output_path = './model_file/'

# FQI / parallelism settings.
max_iterations = 100
n_jobs = 10

# Feature switches shared by every run below.
filter_actions = False
filt_a_outliers = False
evaluation = True

# One continued-training run per reward function, each with its own output name.
for reward_function in ('speed', 'spatial', 'temporal'):
    output_name = '{}_reward_model'.format(reward_function)
    run_experiment(
        track_file_name=track_file_name,
        rt_file_name=rt_file_name,
        data_path=data_path,
        max_iterations=max_iterations,
        output_path=output_path,
        n_throttle=3,
        n_brake=3,
        n_steer=3,
        n_jobs=n_jobs,
        output_name=output_name,
        reward_function=reward_function,
        delta_t=2,
        filter_actions=filter_actions,
        ad_type='rkdt',
        tuning=False,
        kdt_norm=False,
        kdt_param=10,
        filt_a_outliers=filt_a_outliers,
        double_fqi=True,
        evaluation=evaluation,
    )

SPEED REWARD FUNCTION
SARS ready
loading policy
pi: <trlib.policies.valuebased.Softmax object at 0x7ff70027edd8>
Step 1
Finding nearest actions for each state prime
Time for action list 5.6775596141815186
Time for action set 0.48354363441467285
Time for sprime a mat 12.410649538040161
Iteration 0
fitQ 5.359355211257935
Elapsed time 5.359355211257935
Iteration 1
maxQ 12.045907258987427
fitQ 5.408914566040039
Elapsed time 17.454821825027466
Iteration 2
maxQ 13.091437339782715
fitQ 5.338975429534912
Elapsed time 18.430412769317627
Iteration 3
maxQ 12.531635522842407
fitQ 5.673698902130127
Elapsed time 18.205334424972534
Iteration 4
maxQ 12.362218618392944
fitQ 5.552037954330444
Elapsed time 17.91425657272339
Iteration 5
maxQ 12.956385612487793
fitQ 5.483731031417847
Elapsed time 18.44011664390564
Iteration 6
maxQ 12.015736103057861
fitQ 5.173650026321411
Elapsed time 17.189386129379272
Iteration 7
maxQ 10.881363868713379
fitQ 5.193310260772705
Elapsed time 16.074674129486084
Iteration 8
m

maxQ 10.032391548156738
fitQ 4.7550201416015625
Elapsed time 14.7874116897583
Iteration 88
maxQ 10.172150135040283
fitQ 4.869352340698242
Elapsed time 15.041502475738525
Iteration 89
maxQ 10.192524433135986
fitQ 4.986457586288452
Elapsed time 15.178982019424438
Iteration 90
maxQ 10.111929655075073
fitQ 4.771919012069702
Elapsed time 14.883848667144775
Iteration 91
maxQ 10.394980669021606
fitQ 4.752052545547485
Elapsed time 15.147033214569092
Iteration 92
maxQ 10.55277967453003
fitQ 5.161503076553345
Elapsed time 15.714282751083374
Iteration 93
maxQ 10.248588562011719
fitQ 4.689722299575806
Elapsed time 14.938310861587524
Iteration 94
maxQ 10.680197954177856
fitQ 5.08965277671814
Elapsed time 15.769850730895996
Iteration 95
maxQ 10.814886093139648
fitQ 4.9075846672058105
Elapsed time 15.722470760345459
Iteration 96
maxQ 10.423030376434326
fitQ 4.826903820037842
Elapsed time 15.249934196472168
Iteration 97
maxQ 10.079595804214478
fitQ 4.794272184371948
Elapsed time 14.873867988586426
Ite

Processing 95.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 96.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 97.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 98.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 99.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 100.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 101.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 102.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 103.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 104.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 105.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 106.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 107.0 of 169
Computed pilot Q values
Computing policy Q values
Processing 108.0 of 169
Computed pilot Q va

maxQ 9.884007692337036
fitQ 4.925749778747559
Elapsed time 14.809757471084595
Iteration 27
maxQ 10.100929260253906
fitQ 4.793421745300293
Elapsed time 14.8943510055542
Iteration 28
maxQ 10.745954036712646
fitQ 5.148773670196533
Elapsed time 15.89472770690918
Iteration 29
maxQ 10.635509729385376
fitQ 5.093069553375244
Elapsed time 15.72857928276062
Iteration 30
maxQ 10.416970014572144
fitQ 4.814580917358398
Elapsed time 15.231550931930542
Iteration 31
maxQ 9.918670892715454
fitQ 4.725611209869385
Elapsed time 14.644282102584839
Iteration 32
maxQ 10.60941481590271
fitQ 4.78996467590332
Elapsed time 15.39937949180603
Iteration 33
maxQ 10.335052013397217
fitQ 4.992138385772705
Elapsed time 15.327190399169922
Iteration 34
maxQ 10.481157302856445
fitQ 5.216170787811279
Elapsed time 15.697328090667725
Iteration 35
maxQ 10.603423833847046
fitQ 4.90121054649353
Elapsed time 15.504634380340576
Iteration 36
maxQ 10.469187498092651
fitQ 5.1689229011535645
Elapsed time 15.638110399246216
Iteration 