In [3]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt

In [40]:
def get_sweep_df(sweep_id, partition='test', metric="accuracy", best='max'):
    """Summarise a W&B sweep and pick the best configuration per optimizer setup.

    Every run of the sweep ``r252_bel/setup_tests/<sweep_id>`` contributes its
    hyperparameters (``with_sam``, ``base_optimizer``, ``batch_size``,
    ``num_hops``) and its ``train/<metric>``, ``val/<metric>`` and
    ``test/<metric>`` summary values. Runs sharing the same hyperparameters are
    aggregated (mean and std of each metric), and for each
    ``(base_optimizer, with_sam)`` pair the configuration with the best *mean
    validation* metric is kept.

    Args:
        sweep_id: Identifier of the sweep inside the hard-coded entity/project.
        partition: Currently unused; kept so existing callers keep working.
        metric: Metric name *without* partition prefix (e.g. ``"accuracy"``);
            the ``train/``, ``val/`` and ``test/`` prefixes are added here.
        best: ``'max'`` selects the highest mean validation metric; any other
            value selects the lowest.

    Returns:
        A DataFrame with one row per ``(base_optimizer, with_sam)`` pair and the
        flattened columns ``with_sam_``, ``base_optimizer_``, ``batch_size_``,
        ``num_hops_`` and ``<partition>/<metric>_{mean,std}``. The frame is also
        printed.

    Raises:
        ValueError: If no run of the sweep provides the required config keys
            and ``val``/``test`` summary values.
    """
    entity = "r252_bel"
    project_name = "setup_tests"

    hyperparams = ['with_sam', 'base_optimizer', 'batch_size', 'num_hops']
    train_key = 'train/' + metric
    val_key = 'val/' + metric
    test_key = 'test/' + metric
    metric_keys = [train_key, val_key, test_key]

    # Fetch the sweep
    api = wandb.Api()
    sweep = api.sweep(f"{entity}/{project_name}/{sweep_id}")

    # Collect one record per usable run
    data = []
    for run in sweep.runs:
        config = run.config
        summary_metrics = run.summary
        try:
            dct = {name: config[name] for name in hyperparams}
            dct[val_key] = summary_metrics[val_key]
            dct[test_key] = summary_metrics[test_key]
            # The train metric is optional: keep it when the run logged it.
            if train_key in summary_metrics.keys():
                dct[train_key] = summary_metrics[train_key]
        except Exception:
            # Best effort: skip incomplete runs, but (unlike a bare ``except``)
            # let KeyboardInterrupt / SystemExit propagate.
            print(f"Run {run.id} failed")
            continue
        data.append(dct)

    if not data:
        raise ValueError(
            f"No run of sweep {sweep_id!r} provides {val_key!r} and {test_key!r} "
            f"together with the hyperparameters {hyperparams}"
        )

    df = pd.DataFrame(data)

    # Only create the train column when no run reported it; overwriting it
    # unconditionally would throw away the values collected above.
    if train_key not in df.columns:
        df[train_key] = float('nan')

    # Make sure the metric columns are numeric so mean/std are well defined;
    # anything unparsable becomes NaN instead of breaking the aggregation.
    for key in metric_keys:
        df[key] = pd.to_numeric(df[key], errors='coerce')

    # Average over repeated runs (e.g. seeds) of the same hyperparameter setting
    grouped_df = (
        df.groupby(hyperparams)
        .agg({key: ['mean', 'std'] for key in metric_keys})
        .reset_index()
    )

    # Flatten the (name, stat) column hierarchy; hyperparameter columns end up
    # with a trailing underscore (e.g. 'with_sam_') because their stat is ''.
    grouped_df.columns = ['_'.join(col).strip() for col in grouped_df.columns.values]

    # Model selection is done on the validation mean, never on the test metric.
    val_mean_col = val_key + '_mean'
    # Configurations without a validation mean cannot be ranked.
    candidates = grouped_df.dropna(subset=[val_mean_col])
    per_setup = candidates.groupby(['base_optimizer_', 'with_sam_'])[val_mean_col]
    if best == 'max':
        indices = per_setup.idxmax()
    else:
        indices = per_setup.idxmin()

    result_df = grouped_df.loc[indices]

    # Display the resulting DataFrame
    print(result_df)
    return result_df


# Amazon-GCN-new-nl

In [41]:
# Summarise the sweep with get_sweep_df's defaults (metric="accuracy", best='max').
sweep_id = "anmjm6yi"
result_df = get_sweep_df(sweep_id)

Run 647tptcy failed
Run 5l40nnr9 failed
    with_sam_ base_optimizer_  batch_size_  num_hops_ train/accuracy_mean  \
2       False            adam          512          6                 NaN   
26       True            adam          512          6                 NaN   
12      False             sgd          512          2                 NaN   
36       True             sgd          512          2                 NaN   

    train/accuracy_std  val/accuracy_mean  val/accuracy_std  \
2                  NaN           0.465387          0.009193   
26                 NaN           0.468699          0.001174   
12                 NaN           0.379719          0.000000   
36                 NaN           0.379719          0.000000   

    test/accuracy_mean  test/accuracy_std  
2             0.470946           0.001574  
26            0.475552           0.001790  
12            0.377238           0.000000  
36            0.377238           0.000000  


In [30]:
# NOTE(review): get_sweep_df prefixes `metric` with 'train/', 'val/' and 'test/', so this
# call looks up keys such as 'val/train/loss'. The recorded output (train/loss_* columns
# only) presumably comes from an earlier definition of the function — confirm, and
# consider metric='loss' with best='min' if a loss-based selection is intended.
result_df = get_sweep_df("anmjm6yi", metric='train/loss')

    with_sam_ base_optimizer_  batch_size_  num_hops_  train/loss_mean  \
9       False            adam         4096          2         1.201114   
33       True            adam         4096          2         1.245885   
12      False             sgd          512          2         1.491604   
36       True             sgd          512          2         1.493490   

    train/loss_std  
9         0.006753  
33        0.002858  
12        0.000846  
36        0.000781  


In [15]:
# Emit the current `result_df` as a LaTeX table without the index column.
# NOTE(review): the recorded table has test/accuracy columns, so this cell was executed
# while `result_df` held an accuracy summary, not the result of the cell directly above;
# execution counts are out of order — re-run top to bottom before trusting it.
print(result_df.to_latex(index=False))

\begin{tabular}{rlrrrr}
\toprule
with_sam_ & base_optimizer_ & batch_size_ & num_hops_ & test/accuracy_mean & test/accuracy_std \\
\midrule
False & adam & 512 & 6 & 0.470946 & 0.001574 \\
True & adam & 512 & 6 & 0.475552 & 0.001790 \\
False & sgd & 512 & 2 & 0.377238 & 0.000000 \\
True & sgd & 512 & 2 & 0.377238 & 0.000000 \\
\bottomrule
\end{tabular}



# neighbour_loader_node_class_Cora-GCN

In [18]:
# Summarise sweep 'vkcjqut3' with the default metric ("accuracy") and best='max'.
result_df = get_sweep_df('vkcjqut3')

    with_sam_ base_optimizer_  batch_size_  num_hops_  test/accuracy_mean  \
8       False            adam          128          6            0.810743   
26       True            adam          128          6            0.806611   
14      False             sgd           32          6            0.825367   
32       True             sgd           32          6            0.822028   

    test/accuracy_std  
8            0.002520  
26           0.004110  
14           0.004731  
32           0.005260  


In [31]:
# NOTE(review): get_sweep_df prefixes `metric` with 'train/', 'val/' and 'test/', so this
# call looks up keys such as 'val/train/loss'. The recorded output presumably predates
# the current definition — confirm, and consider metric='loss' with best='min'.
result_df = get_sweep_df("vkcjqut3", metric='train/loss')

    with_sam_ base_optimizer_  batch_size_  num_hops_  train/loss_mean  \
0       False            adam            8          2         0.004434   
18       True            adam            8          2         0.013569   
17      False             sgd          128          6         1.699061   
35       True             sgd          128          6         1.717816   

    train/loss_std  
0         0.003891  
18        0.003682  
17        0.034533  
35        0.029100  


In [32]:
# NOTE(review): get_sweep_df prefixes `metric` with 'train/', 'val/' and 'test/', so this
# call looks up keys such as 'val/test/loss'. The recorded output presumably predates
# the current definition — confirm, and consider metric='loss' with best='min'.
result_df = get_sweep_df("vkcjqut3", metric='test/loss')


    with_sam_ base_optimizer_  batch_size_  num_hops_  test/loss_mean  \
2       False            adam            8          6       14.781947   
19       True            adam            8          4       22.042636   
17      False             sgd          128          6        1.690725   
35       True             sgd          128          6        1.708623   

    test/loss_std  
2        1.072301  
19       2.333827  
17       0.035280  
35       0.029744  


In [19]:
# Emit the current `result_df` as a LaTeX table without the index column.
# NOTE(review): the recorded table has test/accuracy columns, so this cell was executed
# while `result_df` held the accuracy summary rather than the loss result of the cell
# directly above; execution counts are out of order.
print(result_df.to_latex(index=False))

\begin{tabular}{rlrrrr}
\toprule
with_sam_ & base_optimizer_ & batch_size_ & num_hops_ & test/accuracy_mean & test/accuracy_std \\
\midrule
False & adam & 128 & 6 & 0.810743 & 0.002520 \\
True & adam & 128 & 6 & 0.806611 & 0.004110 \\
False & sgd & 32 & 6 & 0.825367 & 0.004731 \\
True & sgd & 32 & 6 & 0.822028 & 0.005260 \\
\bottomrule
\end{tabular}



# neighbour_loader_node_class_CiteSeer-GCN

In [20]:
# Summarise sweep '48oqoszw' with the default metric ("accuracy") and best='max'.
result_df = get_sweep_df('48oqoszw')

    with_sam_ base_optimizer_  batch_size_  num_hops_  test/accuracy_mean  \
8       False            adam          128          6            0.703363   
26       True            adam          128          6            0.702261   
14      False             sgd           32          6            0.784583   
32       True             sgd           32          6            0.783879   

    test/accuracy_std  
8            0.004817  
26           0.011175  
14           0.002564  
32           0.003285  


# Roman-GCN-NeighbourLoader

# Summarise sweep 'rq1n4pdr' with the default metric ("accuracy") and best='max'.
result_df = get_sweep_df('rq1n4pdr')
