In [1]:
from pathlib import Path
import pandas as pd
from sklearn.model_selection import GridSearchCV
import HQC

# File types the reader understands. 'sql' is recognised only so that it can be
# rejected with an explanatory message (see read_file_output_scores).
_SUPPORTED_FILE_TYPES = ('xlsx', 'csv', 'tsv', 'sql')

# Name of the workbook written next to the input data.
_RESULTS_FILE_NAME = 'results_table.xlsx'


def _find_input_file(folder, file_type):
    """Return the path of the single data file stored directly inside ``folder``."""
    # Ignore sub-directories and the results workbook left behind by an earlier
    # run; otherwise a second call could load its own output as input data.
    candidates = sorted(entry for entry in folder.iterdir()
                        if entry.is_file()
                        and entry.name.lower() != _RESULTS_FILE_NAME)

    # A lone file is accepted whatever its name or extension. Only when several
    # files are present is the extension used to decide which one is meant.
    if len(candidates) > 1:
        wanted_suffix = '.' + file_type
        candidates = [entry for entry in candidates
                      if entry.suffix.lower() == wanted_suffix]

    if len(candidates) != 1:
        raise ValueError(
            "Expected exactly one '{}' data file in {}, found {}".format(
                file_type, folder, len(candidates)))
    return candidates[0]


def read_file_output_scores(file_path=None, 
                            file_type='xlsx', 
                            target_var=None,
                            model=HQC.HQC(), 
                            parameters={'rescale':[1], 'n_copies':[1]}, 
                            cv_folds=5):
    """
    Read a single file of the type .xlsx, .csv or .tsv with any name in a user specified 
    folder, and return a dataframe of the cross-validated Balanced Accuracy and F1 scores 
    obtained from a grid search over user specified model hyperparameters. An excel file of 
    this dataframe (results_table.xlsx) is also written to the user specified folder.
    
    Parameters
    ----------
    file_path  : Folder location, as a string (e.g. r'C:\\data\\my folder') or a pathlib.Path. 
                 Required; the default of None raises ValueError.
    file_type  : File type of the file. If excel file use 'xlsx', if csv file use 'csv', 
                 if tsv file use 'tsv'. Default set to excel. 'sql' is rejected with a 
                 ValueError because a .sql file cannot be read without a database connection.
    target_var : Name of target variable (column) in the file. Required; the default of None 
                 raises ValueError.
    model      : Scikit-learn compatible estimator. Default set to Helstrom Quantum Centroid (HQC)
                 classifier.
    parameters : Dictionary containing hyperparameters of the specified estimator. Default set to 
                 HQC hyperparameters rescale=1 and n_copies=1.
    cv_folds   : Number of cross-validation folds. Default set to 5.
    
    Returns
    -------  
    results_table : A dataframe built from GridSearchCV's cv_results_, holding the 
                    cross-validated Balanced Accuracy and weighted F1 scores.
    
    Raises
    ------
    ValueError : If file_path or target_var is missing, if file_type is not supported, or if 
                 the folder does not contain exactly one identifiable data file.
                    
    """
    if file_path is None:
        raise ValueError('file_path must be the folder that contains the data file')
    if target_var is None:
        raise ValueError('target_var must be the name of the target column')
    if file_type not in _SUPPORTED_FILE_TYPES:
        raise ValueError('File type not supported')
    if file_type == 'sql':
        # pd.read_sql needs both a query and an open connection; a bare file path
        # can never satisfy it, so fail early with an actionable message.
        raise ValueError("File type 'sql' cannot be read without a database "
                         "connection; export the table to csv, tsv or xlsx")

    folder = Path(file_path)
    data_file = _find_input_file(folder, file_type)

    if file_type == 'xlsx':
        df = pd.read_excel(data_file)
    elif file_type == 'csv':
        df = pd.read_csv(data_file)
    else:  # 'tsv'
        df = pd.read_csv(data_file, sep='\t')

    # Every column except the target is used as a feature.
    X = df.drop(target_var, axis=1).values
    Y = df[target_var].values

    # refit=False: with two scoring metrics there is no single "best" model to
    # refit, and only the cross-validation table is needed here.
    scoring_list = ['balanced_accuracy', 'f1_weighted']
    models = GridSearchCV(estimator=model, param_grid=parameters, scoring=scoring_list, n_jobs=-1, 
                          refit=False, cv=cv_folds, verbose=0, pre_dispatch='2*n_jobs', 
                          error_score='raise', return_train_score=False).fit(X, Y)

    # These change pandas' global display options so that the wide results table
    # is shown in full when the notebook renders the return value.
    pd.set_option('display.max_rows', 500)
    pd.set_option('display.max_columns', 500)
    results_table = pd.DataFrame(models.cv_results_)
    # Build the output path with pathlib so it works for Path inputs and on any OS.
    results_table.to_excel(folder / _RESULTS_FILE_NAME)

    return results_table

In [2]:
# Example: score HQC on the csv stored in the folder below, using 5-fold
# cross-validation over a 3 x 2 grid (three rescale values, one or two copies).
read_file_output_scores(
    file_path=r'C:\Users\HP\Desktop\IBM\New folder',
    file_type='csv',
    target_var='target',
    model=HQC.HQC(),
    parameters={'rescale':[0.5, 1, 1.5], 'n_copies':[1, 2]},
    cv_folds=5,
)

Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_n_copies,param_rescale,params,split0_test_balanced_accuracy,split1_test_balanced_accuracy,split2_test_balanced_accuracy,split3_test_balanced_accuracy,split4_test_balanced_accuracy,mean_test_balanced_accuracy,std_test_balanced_accuracy,rank_test_balanced_accuracy,split0_test_f1_weighted,split1_test_f1_weighted,split2_test_f1_weighted,split3_test_f1_weighted,split4_test_f1_weighted,mean_test_f1_weighted,std_test_f1_weighted,rank_test_f1_weighted
0,0.898051,0.10254,0.225613,0.037958,1,0.5,"{'n_copies': 1, 'rescale': 0.5}",0.663239,0.672182,0.637841,0.651858,0.664211,0.657866,0.011926,5,0.661041,0.670118,0.627732,0.647389,0.659581,0.653172,0.014634,5
1,0.929453,0.142841,0.243614,0.04297,1,1.0,"{'n_copies': 1, 'rescale': 1}",0.709132,0.700054,0.73839,0.698084,0.709753,0.711082,0.014435,4,0.701999,0.693644,0.734594,0.690566,0.700452,0.704251,0.015748,4
2,0.948454,0.162174,0.235813,0.036148,1,1.5,"{'n_copies': 1, 'rescale': 1.5}",0.58117,0.596104,0.63372,0.608934,0.602099,0.604405,0.017286,6,0.577387,0.595294,0.633608,0.610284,0.602015,0.603718,0.018459,6
3,1.709698,0.222597,0.442825,0.075853,2,0.5,"{'n_copies': 2, 'rescale': 0.5}",0.717112,0.733756,0.719937,0.732506,0.737899,0.728242,0.008181,3,0.721709,0.740322,0.724089,0.738846,0.744234,0.73384,0.009136,3
4,1.515087,0.21896,0.395423,0.054829,2,1.0,"{'n_copies': 2, 'rescale': 1}",0.755178,0.722879,0.765497,0.731561,0.765814,0.748186,0.017755,2,0.762527,0.730554,0.773656,0.739291,0.773039,0.755813,0.017727,2
5,1.490285,0.270316,0.376021,0.111908,2,1.5,"{'n_copies': 2, 'rescale': 1.5}",0.760702,0.726235,0.773459,0.741691,0.777642,0.755946,0.019422,1,0.768152,0.733824,0.781214,0.74934,0.784749,0.763456,0.019329,1
