In [1]:
import os
import pandas as pd
import numpy as np
import optuna
import sklearn.metrics
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# TRIALS: only the first TRIALS qualifying Optuna trials (ordered by trial number) are considered.
# TOP_N: how many of those trials (ranked by objective value, then validation/training loss) are evaluated.
TRIALS, TOP_N = 250, 1

In [3]:
def _classification_metrics(df_split):
  """Return (accuracy, macro F1, macro precision, macro recall) of column 'prediction' against 'real'."""
  y_true = df_split['real']
  y_pred = df_split['prediction']
  return (
    sklearn.metrics.accuracy_score(y_true, y_pred),
    sklearn.metrics.f1_score(y_true, y_pred, average = 'macro'),
    sklearn.metrics.precision_score(y_true, y_pred, average = 'macro'),
    sklearn.metrics.recall_score(y_true, y_pred, average = 'macro'),
  )


def load_performance_metrics_sliding_windows(dataset, gnn, llm, method, display_information = False):
  """Collect per-run validation/test metrics for the top Optuna trials of one configuration.

  The Optuna study '<dataset>-<gnn>-<llm with '/' -> '-'>-<method>-Sliding_Windows' is loaded
  from its SQLite file. Trials with a negative or missing objective are discarded. Of the first
  TRIALS remaining trials (by trial number), the TOP_N best are kept. "Best" means highest
  objective, then lowest 'user_attrs_validation_loss', then lowest 'user_attrs_training_loss'.
  For each kept trial, every random-state subdirectory's predictions.csv is scored on the
  'validation' and 'test' splits.

  Args:
    dataset, gnn, llm, method: identify the study and the output directory.
    display_information: if True, print each selected trial's Optuna row.

  Returns:
    DataFrame with columns trial, random_state, split, accuracy, f1_score, precision, recall,
    with one row per (trial, random_state, split).
  """
  study_name = f'{dataset}-{gnn}-{llm.replace("/", "-")}-{method}-Sliding_Windows'
  storage = f'sqlite:///../../sliding_windows/optuna_studies/{study_name}.db'
  study = optuna.load_study(study_name = study_name, storage = storage)

  study_df = study.trials_dataframe()
  # Keep trials with a non-negative objective; NaN values (e.g. unfinished trials) also fail this comparison.
  study_df = study_df[study_df['value'] >= 0.0].sort_values(by = 'number', ascending = True)
  columns = (
    ['number', 'value']
    + [x for x in study_df.columns if x.startswith('params_')]
    + [x for x in study_df.columns if x.startswith('user_attrs_')]
  )

  # Restrict to the first TRIALS qualifying trials (by number), rank them by objective (desc)
  # with validation/training loss (asc) as tie-breakers, and keep the TOP_N best.
  top_trials = (
    study_df[columns]
    .head(TRIALS)
    .sort_values(by = ['value', 'user_attrs_validation_loss', 'user_attrs_training_loss'], ascending = [False, True, True])
    .head(TOP_N)  # .head(3) for R8, BART-large due to exceptions
  )

  # NOTE(review): the output directory uses `llm` verbatim, whereas the study name replaces '/' with '-'.
  # The callers in this notebook pass '-'-separated names, so both agree; confirm before passing 'org/model'.
  outputs_root = f'../../sliding_window_outputs/{dataset}-{gnn}-{llm}/{method}/Sliding_Windows'

  performance = list()
  for _, row in top_trials.iterrows():
    if display_information:
      print(row)
    # iterrows() does not preserve dtypes: an all-numeric row would be upcast to float and the
    # directory name would become e.g. '121.0'. Cast back to int explicitly.
    trial = int(row['number'])
    trial_dir = os.path.join(outputs_root, str(trial))

    # Sorted so the result's row order is deterministic (os.listdir order is arbitrary).
    random_states = sorted(x for x in os.listdir(trial_dir) if os.path.isdir(os.path.join(trial_dir, x)))
    for random_state in random_states:
      df = pd.read_csv(os.path.join(trial_dir, random_state, 'predictions.csv'))
      for split in ['validation', 'test']:
        accuracy, f1_score, precision, recall = _classification_metrics(df[df['split'] == split])
        performance.append((trial, random_state, split, accuracy, f1_score, precision, recall))
  return pd.DataFrame(performance, columns = ['trial', 'random_state', 'split', 'accuracy', 'f1_score', 'precision', 'recall'])

In [4]:
def remove_trials_with_exceptions(df, expected_random_states = 10, splits_per_random_state = 2):
  """Keep only trials that produced a complete set of evaluation rows.

  `load_performance_metrics_sliding_windows` emits one row per (random_state, split). A trial
  whose runs all finished therefore has `expected_random_states * splits_per_random_state` rows.
  Trials with fewer rows are dropped; presumably some of their runs raised exceptions (TODO confirm).

  Args:
    df: metrics frame containing at least a 'trial' column.
    expected_random_states: random-state runs expected per trial. The default of 10 is assumed
      to match the training setup; verify against the runner.
    splits_per_random_state: rows emitted per run. The default is 2: 'validation' and 'test'.

  Returns:
    A new DataFrame with a fresh RangeIndex that contains only the complete trials.
  """
  minimum_rows = expected_random_states * splits_per_random_state
  return df.groupby('trial').filter(lambda group: len(group) >= minimum_rows).reset_index(drop = True)

In [5]:
def _format_percentage(value):
  """Render a fraction as a percentage string with two decimals (0.95413 -> '95.41')."""
  return '{:.2f}'.format(np.round(value * 100, decimals = 2))


def _print_split_summary(best_trial, split):
  """Print a '---------- <Split> ----------' header and 'mean ± std' (in %) of each metric for `split`."""
  print('-' * 10, split.capitalize(), '-' * 10)
  row = best_trial[best_trial['split'] == split].iloc[0]
  for label, metric in (('Accuracy:', 'accuracy'), ('F1-score:', 'f1_score'), ('Precision:', 'precision'), ('Recall:', 'recall')):
    print(label, _format_percentage(row[f'{metric}_mean']), '±', _format_percentage(row[f'{metric}_std']))


def get_best_results_sliding_windows(dataset, gnn, llm, method, target_metric, selection_split = 'validation'):
  """Pick the best Optuna trial for one configuration and print its validation/test metrics.

  The function loads per-random-state metrics via `load_performance_metrics_sliding_windows`,
  which also prints each candidate trial's Optuna row. It drops incomplete trials and aggregates
  every metric over random states. It then selects the trial with the highest mean
  `target_metric` on `selection_split`; ties go to the lower std, then the lower trial number.
  The chosen trial number and 'mean ± std' percentages for both splits are printed.

  Args:
    dataset, gnn, llm, method: identify the Optuna study and its output directory.
    target_metric: one of 'accuracy', 'f1_score', 'precision', 'recall'.
    selection_split: split used for model selection. The default 'validation' ensures the test
      split is never used to choose the trial.

  Raises:
    ValueError: if no complete trial remains, or no aggregated row exists for `selection_split`.
  """
  df = load_performance_metrics_sliding_windows(dataset = dataset, gnn = gnn, llm = llm, method = method, display_information = True)
  df = remove_trials_with_exceptions(df)
  if df.empty:
    raise ValueError(f'No complete trials found for {dataset}-{gnn}-{llm}-{method}.')

  df_aggregated = df.groupby(['trial', 'split']).agg({
    'accuracy' : ['mean', 'std', 'max'],
    'f1_score' : ['mean', 'std'],
    'precision' : ['mean', 'std'],
    'recall' : ['mean', 'std'],
  }).reset_index()
  # Flatten the (metric, statistic) MultiIndex into 'metric_statistic'; 'trial'/'split' lose their trailing '_'.
  df_aggregated.columns = df_aggregated.columns.map('_'.join).str.strip('_')

  # Rank on the selection split only. Sorting by 'split' first would place 'test' before
  # 'validation' (alphabetical order) and thus choose the trial by its test-set score.
  candidates = df_aggregated[df_aggregated['split'] == selection_split]
  if candidates.empty:
    raise ValueError(f'No aggregated rows for split {selection_split!r}.')
  best_trial_number = candidates.sort_values(
    by = [f'{target_metric}_mean', f'{target_metric}_std', 'trial'],
    ascending = [False, True, True],
  ).iloc[0]['trial']

  best_trial = df_aggregated[df_aggregated['trial'] == best_trial_number]

  print(best_trial_number)
  for split in ('validation', 'test'):
    _print_split_summary(best_trial, split)

## SST-2

In [6]:
# SST-2: best GATv2 + BART-large ('Grouped') trial, ranked by mean accuracy over random states.
get_best_results_sliding_windows(dataset = 'SST-2', gnn = 'GATv2', llm = 'facebook-bart-large', method = 'Grouped', target_metric = 'accuracy')

number                                         121
value                                     0.954128
params_attention_heads                           9
params_balanced_loss                          True
params_batch_size                              178
params_beta_0                             0.836063
params_beta_1                             0.990599
params_co_occurrence_pooling_operation         sum
params_dropout_rate                        0.56877
params_early_stopping_patience                  21
params_embedding_pooling_operation             min
params_epochs                                  167
params_epsilon                            0.000001
params_global_pooling                          max
params_hidden_dimension                        102
params_learning_rate                      0.001089
params_number_of_hidden_layers                   4
params_plateau_divider                           4
params_plateau_patience                         14
params_weight_decay            

## Ohsumed

In [7]:
# Ohsumed: best GATv2 + BART-large ('Grouped') trial, ranked by mean macro F1 over random states.
get_best_results_sliding_windows(dataset = 'Ohsumed', gnn = 'GATv2', llm = 'facebook-bart-large', method = 'Grouped', target_metric = 'f1_score')

number                                          97
value                                     0.675289
params_attention_heads                           5
params_balanced_loss                          True
params_batch_size                              113
params_beta_0                              0.85674
params_beta_1                             0.998013
params_co_occurrence_pooling_operation         sum
params_dropout_rate                       0.558362
params_early_stopping_patience                  15
params_embedding_pooling_operation            mean
params_epochs                                  193
params_epsilon                                 0.0
params_global_pooling                         mean
params_hidden_dimension                        205
params_learning_rate                       0.00185
params_left_stride                              64
params_number_of_hidden_layers                   1
params_plateau_divider                           3
params_plateau_patience        

## R8

In [8]:
# R8: best GATv2 + BART-large ('Grouped') trial, ranked by mean macro F1 over random states.
get_best_results_sliding_windows(dataset = 'R8', gnn = 'GATv2', llm = 'facebook-bart-large', method = 'Grouped', target_metric = 'f1_score')

number                                         256
value                                     0.975418
params_attention_heads                          15
params_balanced_loss                         False
params_batch_size                              148
params_beta_0                             0.804286
params_beta_1                              0.99726
params_co_occurrence_pooling_operation         sum
params_dropout_rate                       0.574962
params_early_stopping_patience                  18
params_embedding_pooling_operation             max
params_epochs                                  161
params_epsilon                                 0.0
params_global_pooling                         mean
params_hidden_dimension                        218
params_learning_rate                      0.000132
params_left_stride                              64
params_number_of_hidden_layers                   3
params_plateau_divider                           3
params_plateau_patience        

## IMDb-1k

In [9]:
# IMDb (top 1000): best GATv2 + BART-large ('Grouped') trial, ranked by mean accuracy over random states.
get_best_results_sliding_windows(dataset = 'IMDb-top_1000', gnn = 'GATv2', llm = 'facebook-bart-large', method = 'Grouped', target_metric = 'accuracy')

number                                         223
value                                     0.957576
params_attention_heads                           4
params_balanced_loss                          True
params_batch_size                               20
params_beta_0                             0.824085
params_beta_1                             0.989291
params_co_occurrence_pooling_operation         sum
params_dropout_rate                       0.525161
params_early_stopping_patience                  19
params_embedding_pooling_operation            mean
params_epochs                                  134
params_epsilon                            0.000095
params_global_pooling                         mean
params_hidden_dimension                        167
params_learning_rate                      0.000025
params_left_stride                              32
params_number_of_hidden_layers                   4
params_plateau_divider                           4
params_plateau_patience        