In [1]:
import pandas as pd

# Widen pandas' text repr so the wide results tables are not truncated.
pd.set_option('display.width', 1000, 'display.max_columns', 500)

# Best result per model and dataset: one CSV per model family.
graph, hypergraph, cell, simplicial = (
    pd.read_csv(csv_path)
    for csv_path in (
        'best_results_graph.csv',
        'best_results_hypergraph.csv',
        'best_results_cell.csv',
        'best_results_simplicial.csv',
    )
)

In [2]:
# Raw dataset column names, in the order they appear in the final table
# (one line per benchmark group; 'Model' first).
col_order = [
    'Model',
    'Cora', 'citeseer', 'PubMed',
    'MUTAG', 'PROTEINS', 'NCI1', 'NCI109', 'IMDB-BINARY', 'IMDB-MULTI', 'REDDIT-BINARY',
    'US-county-demos-Election', 'US-county-demos-BachelorRate', 'US-county-demos-BirthRate',
    'US-county-demos-DeathRate', 'US-county-demos-MedianIncome', 'US-county-demos-MigraRate',
    'US-county-demos-UnemploymentRate',
    'amazon_ratings', 'minesweeper', 'questions', 'roman_empire', 'tolokers',
    'ZINC',
]

# Raw column name -> display name. Columns that keep their raw name are
# listed once and mapped to themselves; the rest are spelled out below.
rename_dict = {name: name for name in (
    'Model', 'Cora', 'MUTAG', 'PROTEINS', 'NCI1', 'NCI109',
    'IMDB-BINARY', 'IMDB-MULTI', 'REDDIT-BINARY', 'ZINC',
)}
rename_dict.update({
    'citeseer': 'Citeseer',
    'PubMed': 'Pubmed',
    'US-county-demos-Election': 'Election',
    'US-county-demos-BachelorRate': 'Bachelor Rate',
    'US-county-demos-BirthRate': 'Birth Rate',
    'US-county-demos-DeathRate': 'Death Rate',
    'US-county-demos-MedianIncome': 'Median Income',
    'US-county-demos-MigraRate': 'Migra Rate',
    'US-county-demos-UnemploymentRate': 'Unemployment Rate',
    'roman_empire': 'RomanEmpire',
    'amazon_ratings': 'Amazon Ratings',
    'minesweeper': 'Minesweeper',
    'questions': 'Questions',
    'tolokers': 'Tolokers',
})

# Raw model id (values of the CSVs' 'Model' column) -> display name.
rename_models = dict(
    gcn='GCN', gat='GAT', gin='GIN',
    allsettransformer='AllSetTransformer', edgnn='EDGNN', unignn2='UniGNN2',
    ccxn='CCXN', cwn_dcm='CCCN', cwn='CWN',
    scn='SCN', sccn='SCCN', sccnn_custom='SCCNN',
)

# Row order of the final table (display names from rename_models).
model_order = ['GCN', 'GAT', 'GIN', 'AllSetTransformer', 'EDGNN', 'UniGNN2', 'CWN', 'CCCN', 'CCXN', 'SCN', 'SCCN', 'SCCNN']

# Stack the four per-family result tables into one models x datasets frame.
# Built as a single chain so re-running this cell always starts from the raw
# CSV frames (no in-place mutation, no hidden state).
df = (
    pd.concat([graph, hypergraph, cell, simplicial], axis=0)
    # Keep only the dataset columns, in table order, with display names.
    .loc[:, col_order]
    .rename(columns=rename_dict)
    # Raw model ids -> display names; ids not in rename_models become NaN and
    # are then dropped by the model_order selection below.
    .assign(Model=lambda d: d['Model'].map(rename_models))
    .set_index('Model')
    # Select and order the rows; raises KeyError if a model is missing.
    .loc[model_order]
    # Missing results are rendered as 'NA': real NaNs as well as the
    # 'nan ± nan' strings present in the CSVs.
    .fillna('NA')
    .replace('nan ± nan', 'NA')
)

In [3]:

# Display the merged results: one row per model (index 'Model'), one column per dataset.
df

Unnamed: 0_level_0,Cora,Citeseer,Pubmed,MUTAG,PROTEINS,NCI1,NCI109,IMDB-BINARY,IMDB-MULTI,REDDIT-BINARY,Election,Bachelor Rate,Birth Rate,Death Rate,Median Income,Migra Rate,Unemployment Rate,Amazon Ratings,Minesweeper,Questions,RomanEmpire,Tolokers,ZINC
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
GCN,87.09 ± 0.2,75.53 ± 1.27,89.4 ± 0.3,69.79 ± 6.8,75.7 ± 2.14,72.86 ± 0.69,72.2 ± 1.22,72.0 ± 2.48,49.97 ± 2.16,76.24 ± 0.54,0.3062 ± 0.0215,0.2928 ± 0.022,0.7171 ± 0.0903,0.5081 ± 0.0389,0.2218 ± 0.0256,0.7951 ± 0.1184,0.2531 ± 0.0312,49.56 ± 0.55,87.52 ± 0.42,76.11 ± 0.82,78.16 ± 0.32,83.02 ± 0.71,0.6217 ± 0.0088
GAT,86.71 ± 0.95,74.41 ± 1.75,89.44 ± 0.24,72.77 ± 2.77,76.34 ± 1.66,75.0 ± 0.99,73.8 ± 0.73,69.76 ± 2.65,50.13 ± 3.87,75.68 ± 1.0,0.2908 ± 0.0227,0.2842 ± 0.0221,0.7148 ± 0.0939,0.5108 ± 0.0369,0.2038 ± 0.0219,0.7722 ± 0.1279,0.2288 ± 0.0278,50.17 ± 0.59,89.64 ± 0.43,77.89 ± 0.72,84.02 ± 0.51,84.43 ± 1.0,0.6059 ± 0.0058
GIN,87.21 ± 1.89,73.73 ± 1.23,89.29 ± 0.41,79.57 ± 6.13,75.2 ± 3.3,74.26 ± 0.96,74.42 ± 0.7,70.96 ± 1.93,47.68 ± 4.21,81.96 ± 1.36,0.2841 ± 0.0151,0.3073 ± 0.0287,0.718 ± 0.0876,0.5231 ± 0.0444,0.2137 ± 0.0238,0.797 ± 0.1036,0.2212 ± 0.0239,49.16 ± 1.02,87.82 ± 0.34,76.38 ± 0.88,79.56 ± 0.2,80.72 ± 1.19,0.5713 ± 0.0371
AllSetTransformer,88.92 ± 0.44,73.85 ± 2.21,89.62 ± 0.25,71.06 ± 6.49,76.63 ± 1.74,75.18 ± 1.24,73.75 ± 1.09,70.32 ± 3.27,50.51 ± 2.92,74.84 ± 2.68,0.2897 ± 0.0136,0.2996 ± 0.0276,0.7111 ± 0.0823,0.4938 ± 0.0456,0.2066 ± 0.0216,0.7775 ± 0.123,0.2199 ± 0.022,50.5 ± 0.27,81.14 ± 0.05,,79.5 ± 0.13,83.26 ± 0.1,0.5959 ± 0.0234
EDGNN,87.06 ± 1.09,74.93 ± 1.39,89.04 ± 0.51,80.0 ± 4.9,73.91 ± 4.39,73.97 ± 0.82,74.93 ± 2.5,69.12 ± 2.92,49.17 ± 4.35,83.24 ± 1.45,0.3418 ± 0.0246,0.2934 ± 0.0242,0.7036 ± 0.0743,0.5192 ± 0.0466,0.225 ± 0.0236,0.7988 ± 0.1225,0.2576 ± 0.0281,48.18 ± 0.09,84.52 ± 0.05,,81.01 ± 0.24,77.53 ± 0.01,0.5218 ± 0.0082
UniGNN2,86.97 ± 0.88,74.72 ± 1.08,89.34 ± 0.45,80.43 ± 4.09,75.2 ± 2.96,73.02 ± 0.92,70.76 ± 1.11,71.04 ± 1.31,49.76 ± 3.55,75.56 ± 3.19,0.367 ± 0.0204,0.311 ± 0.0229,0.7257 ± 0.0952,0.511 ± 0.0454,0.2342 ± 0.0218,0.7923 ± 0.1162,0.2833 ± 0.0207,49.06 ± 0.08,78.02 ± 0.0,,77.06 ± 0.2,77.35 ± 0.03,0.5991 ± 0.0056
CWN,86.32 ± 1.38,75.2 ± 1.82,88.64 ± 0.36,80.43 ± 1.78,76.13 ± 2.7,73.93 ± 1.87,73.8 ± 2.06,70.4 ± 2.02,49.71 ± 2.83,,0.3437 ± 0.0216,0.3306 ± 0.0279,0.7181 ± 0.086,0.5399 ± 0.0553,0.2468 ± 0.03,0.838 ± 0.1286,0.2535 ± 0.031,,88.62 ± 0.04,,81.81 ± 0.62,,0.3497 ± 0.012
CCCN,87.44 ± 1.28,75.63 ± 1.58,88.52 ± 0.44,77.02 ± 9.32,73.33 ± 2.3,76.67 ± 1.48,75.35 ± 1.5,69.12 ± 2.82,47.79 ± 3.45,,0.307 ± 0.0155,0.313 ± 0.0248,0.7134 ± 0.0915,0.5443 ± 0.0568,0.2263 ± 0.0187,0.8373 ± 0.1206,0.2446 ± 0.0317,,89.42 ± 0.0,,82.14 ± 0.0,,0.336 ± 0.0125
CCXN,86.79 ± 1.81,74.67 ± 2.24,88.91 ± 0.47,74.89 ± 5.51,75.63 ± 2.57,74.86 ± 0.82,75.66 ± 1.3,70.08 ± 1.21,47.63 ± 3.45,,0.3471 ± 0.0154,0.3166 ± 0.0288,0.745 ± 0.1123,0.5426 ± 0.055,0.2489 ± 0.0296,0.8482 ± 0.185,0.2707 ± 0.0258,,88.88 ± 0.36,,81.44 ± 0.31,,0.4194 ± 0.0161
SCN,82.27 ± 1.34,71.24 ± 1.68,88.72 ± 0.5,73.62 ± 6.13,75.27 ± 2.14,74.46 ± 1.11,75.7 ± 1.04,,,,0.4648 ± 0.043,0.3186 ± 0.0241,0.7122 ± 0.0836,0.5208 ± 0.0525,0.2526 ± 0.0247,0.9209 ± 0.1993,0.3753 ± 0.0432,,90.32 ± 0.11,,88.79 ± 0.46,,


In [4]:
import re
import numpy as np

def get_metric(s, metric=0):
    '''
    Parse one "mean ± std" result cell.

    Parameters
    ----------
    s : object
        Cell value, normally a string such as "87.09 ± 0.2".
    metric : int
        0 selects the mean, 1 selects the std.

    Returns
    -------
    float or int
        The selected number; -1 if ``s`` is a string without '±' (e.g. 'NA'
        or a model name); NaN if ``s`` is not a ``str`` at all.
    '''
    if type(s) != str:
        return float('nan')
    parts = s.split('±')
    if len(parts) == 1:
        # Sentinel for "string without a score".
        return -1
    return float(parts[metric].strip())

# Dataset columns where a lower score is better; every other dataset column is
# treated as higher-is-better.
LOWER_IS_BETTER = frozenset([
    'Election', 'Bachelor Rate', 'Birth Rate', 'Death Rate',
    'Median Income', 'Migra Rate', 'Unemployment Rate', 'ZINC',
])

BEST_STYLE = 'background: grey; font-weight: bold;'
CLOSE_TO_BEST_STYLE = 'background: blue;'


def highlight_best(row):
    '''
    Styler callback returning one CSS string (or None) per cell of a column.

    Used as ``df.style.apply(highlight_best, axis=0)``, so ``row`` is a whole
    column and ``row.name`` is the column label.

    - ``'avg. ranking'``: the smallest value(s) get BEST_STYLE.
    - Columns without any "mean ± std" cell (e.g. ``'Model'``) are unstyled.
    - Other columns: the best mean (minimum for LOWER_IS_BETTER columns,
      maximum otherwise) gets BEST_STYLE; every other cell whose mean lies
      within the best cell's std of the best mean gets CLOSE_TO_BEST_STYLE.
      Cells without a "mean ± std" value (e.g. 'NA') are never styled.
    '''
    if row.name == 'avg. ranking':
        # Must be an ndarray: comparing a plain list to a scalar gives a
        # single bool instead of an element-wise mask.
        ranks = np.array([float(v) for v in row.values])
        return np.where(ranks == np.nanmin(ranks), BEST_STYLE, None)

    # Only "mean ± std" strings carry a score. Everything else stays NaN so it
    # can never be chosen as best (get_metric maps 'NA' to -1, which would
    # otherwise "win" every lower-is-better column).
    x_mean = np.full(len(row), np.nan)
    x_std = np.full(len(row), np.nan)
    for i, val in enumerate(row):
        if isinstance(val, str) and '±' in val:
            x_mean[i] = get_metric(val, 0)
            x_std[i] = get_metric(val, 1)

    styles = np.array([None] * len(row))
    if np.isnan(x_mean).all():
        # Label column (e.g. 'Model') or a column with no results at all.
        return styles

    if row.name in LOWER_IS_BETTER:
        best_idx = np.nanargmin(x_mean)
        close_to_best = x_mean < x_mean[best_idx] + x_std[best_idx]
    else:
        best_idx = np.nanargmax(x_mean)
        close_to_best = x_mean > x_mean[best_idx] - x_std[best_idx]
    best_mean = x_mean[best_idx]
    close_to_best[best_idx] = False

    styles[close_to_best] = CLOSE_TO_BEST_STYLE
    # Ties with the best mean are all marked as best.
    styles[x_mean == best_mean] = BEST_STYLE
    return styles

def replace_style_incl(text, name, new_name):
    r'''
    Turn a Styler-emitted style prefix into a LaTeX command wrapping the cell.

    Rewrites ``\<name> <cell text> &`` (or ``... \\`` at the end of a row) into
    ``\<new_name>{<cell text>} &``, e.g.
    ``\font-weightbold 72.0 ± 2.48 &`` -> ``\textbf{72.0 ± 2.48} &``.

    ``name`` is matched literally (regex-escaped). A replacement function is
    used instead of a template so ``new_name`` and the cell text are inserted
    verbatim.
    '''
    pattern = re.compile(r'\\' + re.escape(name) + r'\s([^&]+?)\s([\\&])')
    return pattern.sub(
        lambda m: '\\' + new_name + '{' + m.group(1) + '} ' + m.group(2),
        text,
    )

def replace_style_normal(text, name, new_name):
    '''Swap every literal occurrence of the style token ``name`` in ``text`` for ``new_name``.'''
    replaced = text.replace(name, new_name)
    return replaced

def beautify_table(df):
    '''
    Build a Styler for ``df``: highlight_best is applied column by column
    (axis=0) and the row index is hidden so it is not rendered.
    '''
    return (
        df.style
        .apply(highlight_best, axis=0)
        .hide(axis="index")
    )


def convert_to_latex(pivot_df):
    r'''
    Render one results table as a highlighted LaTeX tabular *body*.

    The ``tabular`` environment lines, ``\toprule`` and ``\bottomrule`` are
    stripped so several bodies can be concatenated inside one hand-written
    tabular; ``\midrule`` becomes ``\hline`` and a leading ``\hline`` is added.
    Styler's CSS-derived commands are converted to LaTeX: bold best cells via
    ``\textbf{...}``, and background colours via ``\cellcolor{blue!25}`` /
    ``\cellcolor{gray!25}``.

    Parameters
    ----------
    pivot_df : DataFrame
        Results table; expected to hold a 'Model' column plus dataset columns
        with "mean ± std" strings or 'NA'.

    Returns
    -------
    str
        LaTeX fragment ending with a newline.
    '''
    highlighted_df = beautify_table(pivot_df)

    # column_format='' yields an empty "\begin{tabular}{}" that is removed below.
    latex_table = highlighted_df.to_latex(column_format='', hrules=True)
    for wrapper in ('\\begin{tabular}{}', '\\end{tabular}', '\\toprule', '\\bottomrule'):
        latex_table = latex_table.replace(wrapper, '')
    latex_table = latex_table.replace('\\midrule', '\\hline')

    # The Styler writes CSS "prop: value" as "\propvalue"; map those to LaTeX.
    latex_table = replace_style_incl(latex_table, 'font-weightbold', 'textbf')
    latex_table = replace_style_normal(latex_table, '\\backgroundblue', '\\cellcolor{blue!25}')
    # xcolor's predefined colour is 'gray'; '\cellcolor{grey!25}' would be undefined.
    latex_table = replace_style_normal(latex_table, '\\backgroundgrey', '\\cellcolor{gray!25}')

    latex_table = f'''
        \\hline
        {latex_table}
        '''

    return latex_table + '\n'


In [5]:
# Move the model names from the index back into a regular 'Model' column so
# the export cell can select it like any other column. Guarded so that
# re-running this cell does not add a spurious 'index' column.
if 'Model' not in df.columns:
    df = df.reset_index()

In [6]:
# Display the results table with a default integer index (after the reset_index cell above).
df

Unnamed: 0,Model,Cora,Citeseer,Pubmed,MUTAG,PROTEINS,NCI1,NCI109,IMDB-BINARY,IMDB-MULTI,REDDIT-BINARY,Election,Bachelor Rate,Birth Rate,Death Rate,Median Income,Migra Rate,Unemployment Rate,Amazon Ratings,Minesweeper,Questions,RomanEmpire,Tolokers,ZINC
0,GCN,87.09 ± 0.2,75.53 ± 1.27,89.4 ± 0.3,69.79 ± 6.8,75.7 ± 2.14,72.86 ± 0.69,72.2 ± 1.22,72.0 ± 2.48,49.97 ± 2.16,76.24 ± 0.54,0.3062 ± 0.0215,0.2928 ± 0.022,0.7171 ± 0.0903,0.5081 ± 0.0389,0.2218 ± 0.0256,0.7951 ± 0.1184,0.2531 ± 0.0312,49.56 ± 0.55,87.52 ± 0.42,76.11 ± 0.82,78.16 ± 0.32,83.02 ± 0.71,0.6217 ± 0.0088
1,GAT,86.71 ± 0.95,74.41 ± 1.75,89.44 ± 0.24,72.77 ± 2.77,76.34 ± 1.66,75.0 ± 0.99,73.8 ± 0.73,69.76 ± 2.65,50.13 ± 3.87,75.68 ± 1.0,0.2908 ± 0.0227,0.2842 ± 0.0221,0.7148 ± 0.0939,0.5108 ± 0.0369,0.2038 ± 0.0219,0.7722 ± 0.1279,0.2288 ± 0.0278,50.17 ± 0.59,89.64 ± 0.43,77.89 ± 0.72,84.02 ± 0.51,84.43 ± 1.0,0.6059 ± 0.0058
2,GIN,87.21 ± 1.89,73.73 ± 1.23,89.29 ± 0.41,79.57 ± 6.13,75.2 ± 3.3,74.26 ± 0.96,74.42 ± 0.7,70.96 ± 1.93,47.68 ± 4.21,81.96 ± 1.36,0.2841 ± 0.0151,0.3073 ± 0.0287,0.718 ± 0.0876,0.5231 ± 0.0444,0.2137 ± 0.0238,0.797 ± 0.1036,0.2212 ± 0.0239,49.16 ± 1.02,87.82 ± 0.34,76.38 ± 0.88,79.56 ± 0.2,80.72 ± 1.19,0.5713 ± 0.0371
3,AllSetTransformer,88.92 ± 0.44,73.85 ± 2.21,89.62 ± 0.25,71.06 ± 6.49,76.63 ± 1.74,75.18 ± 1.24,73.75 ± 1.09,70.32 ± 3.27,50.51 ± 2.92,74.84 ± 2.68,0.2897 ± 0.0136,0.2996 ± 0.0276,0.7111 ± 0.0823,0.4938 ± 0.0456,0.2066 ± 0.0216,0.7775 ± 0.123,0.2199 ± 0.022,50.5 ± 0.27,81.14 ± 0.05,,79.5 ± 0.13,83.26 ± 0.1,0.5959 ± 0.0234
4,EDGNN,87.06 ± 1.09,74.93 ± 1.39,89.04 ± 0.51,80.0 ± 4.9,73.91 ± 4.39,73.97 ± 0.82,74.93 ± 2.5,69.12 ± 2.92,49.17 ± 4.35,83.24 ± 1.45,0.3418 ± 0.0246,0.2934 ± 0.0242,0.7036 ± 0.0743,0.5192 ± 0.0466,0.225 ± 0.0236,0.7988 ± 0.1225,0.2576 ± 0.0281,48.18 ± 0.09,84.52 ± 0.05,,81.01 ± 0.24,77.53 ± 0.01,0.5218 ± 0.0082
5,UniGNN2,86.97 ± 0.88,74.72 ± 1.08,89.34 ± 0.45,80.43 ± 4.09,75.2 ± 2.96,73.02 ± 0.92,70.76 ± 1.11,71.04 ± 1.31,49.76 ± 3.55,75.56 ± 3.19,0.367 ± 0.0204,0.311 ± 0.0229,0.7257 ± 0.0952,0.511 ± 0.0454,0.2342 ± 0.0218,0.7923 ± 0.1162,0.2833 ± 0.0207,49.06 ± 0.08,78.02 ± 0.0,,77.06 ± 0.2,77.35 ± 0.03,0.5991 ± 0.0056
6,CWN,86.32 ± 1.38,75.2 ± 1.82,88.64 ± 0.36,80.43 ± 1.78,76.13 ± 2.7,73.93 ± 1.87,73.8 ± 2.06,70.4 ± 2.02,49.71 ± 2.83,,0.3437 ± 0.0216,0.3306 ± 0.0279,0.7181 ± 0.086,0.5399 ± 0.0553,0.2468 ± 0.03,0.838 ± 0.1286,0.2535 ± 0.031,,88.62 ± 0.04,,81.81 ± 0.62,,0.3497 ± 0.012
7,CCCN,87.44 ± 1.28,75.63 ± 1.58,88.52 ± 0.44,77.02 ± 9.32,73.33 ± 2.3,76.67 ± 1.48,75.35 ± 1.5,69.12 ± 2.82,47.79 ± 3.45,,0.307 ± 0.0155,0.313 ± 0.0248,0.7134 ± 0.0915,0.5443 ± 0.0568,0.2263 ± 0.0187,0.8373 ± 0.1206,0.2446 ± 0.0317,,89.42 ± 0.0,,82.14 ± 0.0,,0.336 ± 0.0125
8,CCXN,86.79 ± 1.81,74.67 ± 2.24,88.91 ± 0.47,74.89 ± 5.51,75.63 ± 2.57,74.86 ± 0.82,75.66 ± 1.3,70.08 ± 1.21,47.63 ± 3.45,,0.3471 ± 0.0154,0.3166 ± 0.0288,0.745 ± 0.1123,0.5426 ± 0.055,0.2489 ± 0.0296,0.8482 ± 0.185,0.2707 ± 0.0258,,88.88 ± 0.36,,81.44 ± 0.31,,0.4194 ± 0.0161
9,SCN,82.27 ± 1.34,71.24 ± 1.68,88.72 ± 0.5,73.62 ± 6.13,75.27 ± 2.14,74.46 ± 1.11,75.7 ± 1.04,,,,0.4648 ± 0.043,0.3186 ± 0.0241,0.7122 ± 0.0836,0.5208 ± 0.0525,0.2526 ± 0.0247,0.9209 ± 0.1993,0.3753 ± 0.0432,,90.32 ± 0.11,,88.79 ± 0.46,,


In [7]:
# The dataset columns are split into two groups; each group becomes its own
# header + body block inside a single LaTeX tabular.
column1 = ['Model',
           'Cora', 'Citeseer', 'Pubmed',
           'MUTAG', 'PROTEINS', 'NCI1', 'NCI109',
           'IMDB-BINARY', 'IMDB-MULTI', 'REDDIT-BINARY',
           'Election']

# NOTE(review): 'Questions' is not exported. This is presumably intentional,
# since most models report NA for it in the table above; confirm.
column2 = ['Model',
           'Bachelor Rate', 'Birth Rate', 'Death Rate', 'Median Income',
           'Migra Rate', 'Unemployment Rate',
           'Amazon Ratings', 'Minesweeper',
           'RomanEmpire', 'Tolokers',
           'ZINC']

latex_table = ''.join(convert_to_latex(df[columns]) for columns in (column1, column2))
latex_table += '\\midrule'

# Use \midrule for every horizontal rule (convert_to_latex emits \hline).
latex_table = latex_table.replace(r'\hline', r'\midrule')
# xcolor defines 'gray', not 'grey'; normalise any remaining colour names.
latex_table = latex_table.replace('grey', 'gray')
print(latex_table)


        \midrule
        

Model & Cora & Citeseer & Pubmed & MUTAG & PROTEINS & NCI1 & NCI109 & IMDB-BINARY & IMDB-MULTI & REDDIT-BINARY & Election \\
\midrule
GCN & 87.09 ± 0.2 & \cellcolor{blue!25} 75.53 ± 1.27 & \cellcolor{blue!25} 89.4 ± 0.3 & 69.79 ± 6.8 & \cellcolor{blue!25} 75.7 ± 2.14 & 72.86 ± 0.69 & 72.2 ± 1.22 & \cellcolor{gray!25} \textbf{72.0 ± 2.48} & \cellcolor{blue!25} 49.97 ± 2.16 & 76.24 ± 0.54 & 0.3062 ± 0.0215 \\
GAT & 86.71 ± 0.95 & \cellcolor{blue!25} 74.41 ± 1.75 & \cellcolor{blue!25} 89.44 ± 0.24 & 72.77 ± 2.77 & \cellcolor{blue!25} 76.34 ± 1.66 & 75.0 ± 0.99 & 73.8 ± 0.73 & \cellcolor{blue!25} 69.76 ± 2.65 & \cellcolor{blue!25} 50.13 ± 3.87 & 75.68 ± 1.0 & \cellcolor{blue!25} 0.2908 ± 0.0227 \\
GIN & 87.21 ± 1.89 & 73.73 ± 1.23 & 89.29 ± 0.41 & \cellcolor{blue!25} 79.57 ± 6.13 & \cellcolor{blue!25} 75.2 ± 3.3 & 74.26 ± 0.96 & 74.42 ± 0.7 & \cellcolor{blue!25} 70.96 ± 1.93 & \cellcolor{blue!25} 47.68 ± 4.21 & \cellcolor{blue!25} 81.96 ± 1.36 & \cellcolor{gray!