In [None]:
#########################################################################################################################
# Inferring Temporal Logic Specifications for Robot-Assisted Feeding in Social Dining Settings
#
# Jan Ondras (janko@cs.cornell.edu, jo951030@gmail.com)
# Project for Program Synthesis (CS 6172)
# Cornell University, Fall 2021
#########################################################################################################################
# Run LoTuS cross-validation (Matlab function lotusCrossVal) for various parameters and save the results as pickle files
#########################################################################################################################

In [1]:
# First, open Matlab with LoTuS project and run: matlab.engine.shareEngine

import matlab.engine
import numpy as np
import time
import pickle
import json
import matplotlib.pyplot as plt

# window_sizes = [30, 60, 90, 120]
# window_sizes = [10, 20, 30, 60, 90, 120, 150, 180, 210]
# window_sizes = [15, 45, 75, 105, 135, 165, 195]

window_sizes = [15, 30, 45, 60, 75, 90, 105, 120, 135, 150, 165, 180, 195, 210]
# window_sizes = [75] # for depth 5
# window_sizes = [30, 60] # for depth 3

feature_types = [
    # 'tR2',
    'R2',
    'dR2',
    'R2dR2',
    # 'tR2dR2'
]
# max_tree_depths = [1, 2, 3, 4, 5, 6, 7]
# max_tree_depths = [1, 3, 5, 7, 9, 12, 15]
# max_tree_depths = [3, 5, 7]

max_tree_depths = [5]
max_tree_depths = [3]

primitives_sets = [
    'setPrim1',
    'setPrim2',
    'setPrim12'
]

feature_sources = [
    'solo',
    'duo',
    'trio'
]

eng = matlab.engine.connect_matlab()
 
cv_results = {}
for feature_source in feature_sources:
    for primitives_set in primitives_sets:
        
        for feature_type in feature_types:

            cv_results[feature_type] = {}
            for window_size in window_sizes:

                traces_type = f'{feature_source}_{feature_type}f_{window_size}w'
                cv_results[feature_type][window_size] = {}
                for max_tree_depth in max_tree_depths:
                    start_time = time.time()
                    res = eng.lotusCrossVal(traces_type, float(max_tree_depth), primitives_set)
                    # Convert matlab.double to numpy array
                    for field in ['times', 'train_mcrs', 'test_mcrs', 'train_conf_mats', 'test_conf_mats', 
                                  'train_mcrs_pruned', 'test_mcrs_pruned', 'train_conf_mats_pruned', 'test_conf_mats_pruned',
                                  'train_conf_mats_sum', 'test_conf_mats_sum', 'train_conf_mats_pruned_sum', 'test_conf_mats_pruned_sum']:
                        res[field] = np.asarray(res[field]);
                    cv_results[feature_type][window_size][max_tree_depth] = res
                    print('Time taken: ', time.time() - start_time)

        with open(f'./cv_results/cv_results_{feature_source}_{primitives_set}_3_compareWindows.pkl', 'wb') as f:
            pickle.dump(cv_results, f)


Time taken:  47.946561336517334
Time taken:  52.51865911483765
Time taken:  43.14264488220215
Time taken:  46.26475501060486
Time taken:  89.95598459243774
Time taken:  115.27894592285156
Time taken:  1349.8727099895477
Time taken:  1423.2005743980408
Time taken:  1113.752592086792
Time taken:  1247.0072391033173
Time taken:  2410.2394273281097
Time taken:  2659.863125562668
Time taken:  1289.328797340393
Time taken:  1460.7385969161987
Time taken:  1171.7511081695557
Time taken:  1260.227205991745
Time taken:  2519.2025575637817
Time taken:  2573.546110391617
Time taken:  65.43915557861328
Time taken:  82.99714756011963
Time taken:  56.286391258239746
Time taken:  66.60621285438538
Time taken:  142.8233323097229
Time taken:  234.5059950351715
Time taken:  1657.0232496261597
Time taken:  1998.1550946235657
Time taken:  1236.2620704174042
Time taken:  1495.1522977352142
Time taken:  3096.078025341034
Time taken:  3611.0507073402405
Time taken:  1844.222285747528
Time taken:  2049.946063