In [65]:
import os
import json
import pandas as pd

def collect_data(main_folder):
    """Collect the best test loss and learning rate of every trial folder in ``main_folder``.

    An entry of ``main_folder`` counts as a trial only if it contains both
    ``progress.csv`` and ``params.json`` and the CSV has a ``test_loss``
    column; every other entry is skipped silently.

    Args:
        main_folder: Directory whose immediate children are scanned.

    Returns:
        A ``(results, context_size)`` tuple.

        ``results`` is a list with one dict per trial, ordered by subfolder
        name, holding ``learning_rate`` (read from ``params.json``) and
        ``best_test_loss`` (the minimum of the ``test_loss`` column).

        ``context_size`` is the value from the last trial (in that same order)
        whose ``params.json`` defines a non-null ``context_size``, or ``0`` if
        no trial defines one.

    Raises:
        OSError: If ``main_folder`` cannot be listed (e.g. it does not exist).
            Errors raised while reading an individual trial are printed and
            that trial is skipped instead.
    """
    results = []
    context_size = 0

    # os.listdir() returns entries in arbitrary order; sorting keeps the row
    # order (and therefore the reported context_size) reproducible across runs.
    for subfolder in sorted(os.listdir(main_folder)):
        subfolder_path = os.path.join(main_folder, subfolder)
        progress_file = os.path.join(subfolder_path, 'progress.csv')
        params_file = os.path.join(subfolder_path, 'params.json')

        # Skip anything that is not a complete trial folder.
        if not (os.path.exists(progress_file) and os.path.exists(params_file)):
            continue

        try:
            data = pd.read_csv(progress_file)
            if 'test_loss' not in data.columns:
                continue
            min_loss = data['test_loss'].min()

            # Only the parsing needs the file handle; release it right after.
            with open(params_file, 'r') as file:
                params = json.load(file)

            # A trial without the key must not wipe out a value that an
            # earlier trial already provided.
            trial_context_size = params.get('context_size')
            if trial_context_size is not None:
                context_size = trial_context_size

            results.append({
                'learning_rate': params.get('learning_rate'),
                'best_test_loss': min_loss
            })
        except Exception as e:
            # Best-effort scan: report the broken trial and keep going.
            print(f"Error processing files in {subfolder_path}: {e}")

    return results, context_size

def create_results_table(main_folder):
    """Print every trial's best test loss under ``main_folder``, lowest loss first.

    Delegates the directory scan to ``collect_data`` and prints two things:
    the ``context_size`` it reports, and a table with one row per trial
    (``learning_rate`` and ``best_test_loss``) sorted by ascending
    ``best_test_loss``. Rows are NOT aggregated per learning rate: several
    trials sharing a learning rate each keep their own row.

    Args:
        main_folder: Directory containing one subfolder per trial.

    Returns:
        None. The table is only printed; if no trial data was collected a
        hint message is printed instead.
    """
    data, context_size = collect_data(main_folder)
    if not data:
        print("No data collected. Check the contents of your directories.")
        return

    # reset_index() (without drop=True) deliberately keeps each trial's
    # position in collection order as an ``index`` column of the printed table.
    sorted_result_df = (
        pd.DataFrame(data)
        .sort_values(by='best_test_loss', ascending=True)
        .reset_index()
    )
    print(f'context_size : {context_size}')
    print(sorted_result_df)


In [66]:
# Print the per-trial best test losses for run_2024-07-11_05-11-10.
# The run directory is resolved relative to the current working directory,
# so this cell depends on where the kernel was started.
create_results_table(os.getcwd() + '/ray_results/run_2024-07-11_05-11-10')

context_size : 8
   index  learning_rate  best_test_loss
0      5         0.0010        5.604943
1      6         0.0010        7.172198
2      8         0.0005        7.583963
3      2         0.0100        7.768374
4      4         0.0005        8.799016
5      7         0.0005        9.354083
6      0         0.0100       11.646391
7      3         0.0010       12.060120
8      1         0.0100       36.306425


In [67]:
# Print the per-trial best test losses for run_2024-07-11_11-39-40.
# NOTE(review): hardcoded absolute, machine-specific path - this cell only
# works where /user/ml4723/Prj/NIC exists.
create_results_table('/user/ml4723/Prj/NIC/ray_results/run_2024-07-11_11-39-40')

context_size : 4
    index  learning_rate  best_test_loss
0       4         0.0005        5.605410
1       1         0.0010        6.441049
2       7         0.0010        7.344255
3       2         0.0005        7.762958
4       9         0.0010        8.088158
5       5         0.0100       11.489364
6       8         0.0100       11.512672
7      10         0.0010       13.766644
8       0         0.0100       14.642271
9       3         0.0100       16.623400
10      6         0.0005       30.744198


In [68]:
# Print the per-trial best test losses for run_2024-07-11_17-01-38.
# NOTE(review): hardcoded absolute, machine-specific path - this cell only
# works where /user/ml4723/Prj/NIC exists.
create_results_table('/user/ml4723/Prj/NIC/ray_results/run_2024-07-11_17-01-38')

context_size : 2
    index  learning_rate  best_test_loss
0       4         0.0005        6.900608
1       1         0.0010        7.325198
2      10         0.0005        7.735133
3      14         0.0010        7.857705
4       8         0.0010        8.699912
5       2         0.0100       10.317407
6       0         0.0100       11.460215
7       3         0.0005       12.818644
8       9         0.0100       13.487335
9       6         0.0005       16.386970
10      5         0.0100       16.590300
11      7         0.0010       16.957243
12     13         0.0010       18.019910
13     11         0.0005       18.183999
14     12         0.0100       25.067751


# The results below are for n_stores = 50

In [69]:
# Print the per-trial best test losses for run_2024-07-11_05-27-18.
# NOTE(review): hardcoded absolute, machine-specific path - this cell only
# works where /user/ml4723/Prj/NIC exists.
create_results_table('/user/ml4723/Prj/NIC/ray_results/run_2024-07-11_05-27-18')

context_size : 128
   index  learning_rate  best_test_loss
0      1         0.0005        5.412806
1      4         0.0010        5.415018
2      0         0.0005        5.506853
3      5         0.0010        5.781743
4      3         0.0100        7.152602
5      2         0.0100       16.189754


In [70]:
# Print the per-trial best test losses for run_2024-07-11_07-52-50.
# NOTE(review): hardcoded absolute, machine-specific path - this cell only
# works where /user/ml4723/Prj/NIC exists.
create_results_table('/user/ml4723/Prj/NIC/ray_results/run_2024-07-11_07-52-50')

context_size : 64
   index  learning_rate  best_test_loss
0      0         0.0010        5.409906
1      5         0.0005        7.076955
2      4         0.0010        7.972043
3      3         0.0005        8.897128
4      1         0.0100       11.613703
5      2         0.0100       17.898005


In [71]:
# Print the per-trial best test losses for run_2024-07-11_09-41-43.
# NOTE(review): hardcoded absolute, machine-specific path - this cell only
# works where /user/ml4723/Prj/NIC exists.
create_results_table('/user/ml4723/Prj/NIC/ray_results/run_2024-07-11_09-41-43')

context_size : 32
   index  learning_rate  best_test_loss
0      3         0.0005        5.411628
1      2         0.0010        7.393025
2      1         0.0100       16.262605
3      0         0.0100       16.342239


In [72]:
# Print the per-trial best test losses for run_2024-07-11_11-52-10.
# NOTE(review): hardcoded absolute, machine-specific path - this cell only
# works where /user/ml4723/Prj/NIC exists.
create_results_table('/user/ml4723/Prj/NIC/ray_results/run_2024-07-11_11-52-10')

context_size : 16
    index  learning_rate  best_test_loss
0      13         0.0010        5.412956
1       7         0.0005        6.980137
2       3         0.0005        7.714144
3       4         0.0005        8.068696
4       6         0.0010        8.379202
5       5         0.0010        8.604063
6      12         0.0005       11.359692
7      14         0.0010       13.784768
8       9         0.0100       14.719614
9      11         0.0100       15.860225
10      0         0.0100       16.307592
11      8         0.0005       16.750998
12     10         0.0100       19.036855
13      2         0.0100       21.494105
14      1         0.0010       22.327871


In [73]:
# Print the per-trial best test losses for run_2024-07-11_19-32-22.
# NOTE(review): hardcoded absolute, machine-specific path - this cell only
# works where /user/ml4723/Prj/NIC exists.
create_results_table('/user/ml4723/Prj/NIC/ray_results/run_2024-07-11_19-32-22')

context_size : 8
    index  learning_rate  best_test_loss
0       9         0.0010        5.413410
1       2         0.0010        7.537731
2      10         0.0005        7.735478
3       1         0.0005        8.919952
4       7         0.0005       10.710806
5       6         0.0100       11.243040
6       8         0.0010       13.570966
7       4         0.0100       13.993180
8       3         0.0100       16.343363
9       5         0.0100       22.620466
10      0         0.0010       36.146848


In [74]:
# Print the per-trial best test losses for run_2024-07-12_01-44-35.
# NOTE(review): hardcoded absolute, machine-specific path - this cell only
# works where /user/ml4723/Prj/NIC exists.
create_results_table('/user/ml4723/Prj/NIC/ray_results/run_2024-07-12_01-44-35')

context_size : 4
   index  learning_rate  best_test_loss
0      2         0.0010        6.828720
1      4         0.0100        8.461158
2      7         0.0005        8.570793
3      3         0.0005       10.900044
4      1         0.0010       11.167397
5      6         0.0010       15.930729
6      0         0.0100       20.107402
7      9         0.0005       23.392702
8      8         0.0100       27.114183
9      5         0.0100       31.480385
