In [1]:
import mlflow
import pandas as pd

import mlflow
import pandas as pd

def generate_recommendations_table(experiment_ids, prefix_note="sizes_acts", dataset="MovieLens"):
    """Collect MLflow run metrics into one wide table of recommendation quality.

    Every experiment in ``experiment_ids`` is queried with
    ``mlflow.search_runs(..., output_format="list")``. A run is kept when its
    ``note`` param contains ``prefix_note`` and its ``dataset`` param equals
    ``dataset``; runs without a ``note`` param are skipped.

    Rows are keyed by the run params ``(embedding_dim, top_k,
    SAE_fusion_strategy, topk_inference)``. Columns are keyed by
    ``(group_type, label)`` where ``label`` is one of ``G/mean``, ``U/mean``,
    ``U/min`` or ``Pop``. If several kept runs share the same row key and
    group type, the run returned last overwrites the earlier values.

    Args:
        experiment_ids: Iterable of MLflow experiment IDs to search.
        prefix_note: Substring that must occur in the run's ``note`` param.
        dataset: Required value of the run's ``dataset`` param.

    Returns:
        pandas.DataFrame sorted by the row key, with the key levels exposed as
        the columns ``Dimensions``, ``TopK``, ``Aggregation`` and
        ``Activation`` followed by the metric columns.

    Raises:
        ValueError: If no run matches the filters.
    """
    index_names = ["Dimensions", "TopK", "Aggregation", "Activation"]
    # Column label shown in the table -> metric name logged to MLflow.
    metric_sources = {
        "G/mean": "CommonItemsNDCG20/mean",
        "U/mean": "NDCG20/mean",
        "U/min": "NDCG20/min",
        "Pop": "Popularity/mean",
    }

    records = {}
    for exp_id in experiment_ids:
        runs = mlflow.search_runs(
            experiment_ids=[exp_id],
            output_format="list"
        )
        for run in runs:
            params = run.data.params

            # A missing note is treated as an empty one instead of crashing
            # on `substring in None`.
            note = params.get("note") or ""
            if prefix_note not in note or params.get("dataset") != dataset:
                continue

            row_key = (
                int(params.get("embedding_dim", 0)),
                int(params.get("top_k", 0)),
                params.get("SAE_fusion_strategy", "none"),
                params.get("topk_inference", "False"),
            )
            group_type = params.get("group_type", "none")

            # Runs of different group types that share a row key are merged
            # into the same row, one column group per group type.
            row = records.setdefault(row_key, {})
            for label, metric_name in metric_sources.items():
                row[(group_type, label)] = run.data.metrics.get(metric_name)

    if not records:
        raise ValueError(
            f"No MLflow runs matched note containing {prefix_note!r} and "
            f"dataset {dataset!r} in experiments {experiment_ids!r}"
        )

    df = pd.DataFrame.from_dict(records, orient="index")
    df.index.names = index_names

    # Order columns by group type then metric label, and rows by the row key.
    df = df.sort_index(axis=1, level=[0, 1]).sort_values(by=index_names)

    return df.reset_index()

def highlight_top3_dark_to_light(s):
    """Return per-cell CSS that shades the largest values of a Series.

    Despite the name, only two shades are used: the distinct values found
    among the two largest entries of ``s`` are coloured from dark
    (``mediumseagreen``) to light (``lightgreen``). If the two largest entries
    are equal, only the dark shade is used. All other cells get an empty
    style string.

    Args:
        s: pandas Series of comparable (numeric) values, as passed by
            ``Styler.apply``.

    Returns:
        list[str] with one CSS declaration (or ``''``) per element of ``s``.
    """
    palette = ['mediumseagreen', 'lightgreen']
    ranked_values = s.nlargest(2).unique()

    css = [''] * len(s)
    for shade, target in zip(palette, ranked_values):
        for pos, cell in enumerate(s):
            # Never overwrite a cell that already received a darker shade.
            if cell == target and css[pos] == '':
                css[pos] = f'background-color: {shade}'
    return css

def highlight_bottom3_dark_to_light(s):
    """Return per-cell CSS that shades the smallest values of a Series.

    The distinct values found among the three smallest entries of ``s`` are
    coloured, in ascending order, ``mediumblue``, ``lightblue`` and
    ``paleturquoise``. Duplicates among those three entries reduce the number
    of shades used. All other cells get an empty style string.

    Args:
        s: pandas Series of comparable (numeric) values, as passed by
            ``Styler.apply``.

    Returns:
        list[str] with one CSS declaration (or ``''``) per element of ``s``.
    """
    palette = ['mediumblue', 'lightblue', 'paleturquoise']
    lowest_values = s.nsmallest(3).unique()

    def shade_of(cell):
        # The first (i.e. smallest) matching value decides the shade.
        for shade, low in zip(palette, lowest_values):
            if cell == low:
                return f'background-color: {shade}'
        return ''

    return [shade_of(cell) for cell in s]

# Aggregation functions

For each aggregation function, we have already selected whether it performs better with or without the top-k activation function.

## SAE group recommendations table aggregated across all sizes

**Dataset: MovieLens** (group types: similar, random, outlier)

Each value is the mean across all 9 size variants.


In [3]:
# MLflow experiment IDs searched by generate_recommendations_table() below.
experiments = ['523100174176986081', '333391697323445885']

# Helper that turns the aggregated table into LaTeX-ready strings.

def format_latex(df, highlight_max_cols=None, highlight_min_cols=None, round_digits=3):
    """Convert a table to strings, marking the best two values per column in LaTeX.

    For every column listed in ``highlight_max_cols`` the largest value is
    wrapped in ``\\textbf{}`` and the second largest in ``\\underline{}``;
    columns in ``highlight_min_cols`` are treated the same way using the
    smallest values. Cells of highlighted columns are rendered with
    ``round_digits`` decimals. All other columns are converted with plain
    string formatting.

    Ties follow ``Series.nlargest(2)`` / ``nsmallest(2)``: when the two best
    entries are equal, both are bold and nothing is underlined.

    Columns with fewer than two non-NaN values are handled gracefully: the
    missing rank is simply not applied and NaN cells are rendered as ``nan``.

    Args:
        df: Source DataFrame; it is not modified.
        highlight_max_cols: Column labels where larger is better.
        highlight_min_cols: Column labels where smaller is better.
        round_digits: Number of decimals for highlighted columns.

    Returns:
        A copy of ``df`` whose cells are all strings.
    """
    highlight_max_cols = highlight_max_cols or []
    highlight_min_cols = highlight_min_cols or []
    formatted_df = df.copy()

    def _emphasise(values, ranked):
        # nlargest/nsmallest drop NaN, so `ranked` may hold 0, 1 or 2 entries.
        best = ranked[0] if len(ranked) > 0 else None
        runner_up = ranked[1] if len(ranked) > 1 else None

        def format_cell(val):
            text = f"{val:.{round_digits}f}"
            if best is not None and val == best:
                return f"\\textbf{{{text}}}"
            if runner_up is not None and val == runner_up:
                return f"\\underline{{{text}}}"
            return text

        return values.apply(format_cell)

    for col in df.columns:
        col_values = df[col]

        if col in highlight_max_cols:
            formatted_df[col] = _emphasise(col_values, col_values.nlargest(2).values)
        elif col in highlight_min_cols:
            formatted_df[col] = _emphasise(col_values, col_values.nsmallest(2).values)
        else:
            formatted_df[col] = col_values.apply(lambda val: f"{val}")

    return formatted_df

# Columns of the aggregated table where larger is better (NDCG-based metrics)
# and where smaller is better (popularity); consumed by format_latex().
highlight_max_cols = [
    (group, metric, 'mean')
    for group in ('sim', 'random', 'outlier')
    for metric in ('G/mean', 'U/min', 'U/mean')
]
highlight_min_cols = [(group, 'Pop', 'mean') for group in ('sim', 'random', 'outlier')]

table = generate_recommendations_table(experiments, prefix_note="aggregations", dataset="MovieLens")

# (Aggregation, Activation) pairs kept for the summary.
row_indexes_selected = [
    ('average', 'True'),
    ('common_features', 'False'),
    ('max', 'True'),
    ('square_average', 'False'),
    ('topk', 'True'),
    ('wcom', 'True'),
]

table = table.loc[
    table.set_index(['Aggregation', 'Activation']).index.isin(row_indexes_selected)
].reset_index()

# One (group type, metric) column per combination, grouped by group type.
group_types = ["sim", "random", "outlier"]
selected_columns = [
    (grouptype, metric)
    for grouptype in group_types
    for metric in ["G/mean", "U/mean", "U/min", "Pop"]
]

# Average every selected metric per (Aggregation, Activation) pair.
agg_table = (
    table
    .groupby(["Aggregation", "Activation"])[selected_columns]
    .agg(['mean'])
    .round(3)
)

# Labels of the aggregated quality columns and of their (not computed here)
# standard-deviation counterparts.
selected_columns = [
    (group_type, metric, 'mean')
    for group_type in group_types
    for metric in ["G/mean", "U/mean", "U/min"]
]
std_selected_columns = [
    (group_type, metric, 'std')
    for group_type in group_types
    for metric in ["G/mean", "U/mean", "U/min"]
]

agg_table

Unnamed: 0_level_0,Unnamed: 1_level_0,sim,sim,sim,sim,random,random,random,random,outlier,outlier,outlier,outlier
Unnamed: 0_level_1,Unnamed: 1_level_1,G/mean,U/mean,U/min,Pop,G/mean,U/mean,U/min,Pop,G/mean,U/mean,U/min,Pop
Unnamed: 0_level_2,Unnamed: 1_level_2,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
Aggregation,Activation,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3
average,True,0.646,0.703,0.557,0.501,0.631,0.688,0.539,0.545,0.546,0.666,0.487,0.477
common_features,False,0.585,0.663,0.513,0.459,0.532,0.619,0.46,0.478,0.351,0.52,0.338,0.357
max,True,0.629,0.692,0.549,0.495,0.615,0.678,0.533,0.537,0.528,0.651,0.487,0.463
topk,True,0.644,0.702,0.554,0.514,0.632,0.688,0.537,0.56,0.557,0.67,0.494,0.505
wcom,True,0.634,0.694,0.545,0.498,0.618,0.68,0.524,0.542,0.533,0.659,0.468,0.48


In [6]:
# Export the MovieLens summary as a LaTeX table.
# The file name is dataset-specific: the LastFM1k export in this notebook
# previously wrote to the same "sae_table.tex" and overwrote this table on a
# full run.
# NOTE(review): the caption speaks of "percentage change", while agg_table
# holds plain means of the logged metrics -- confirm the wording.
format_latex(
    agg_table.reset_index(),
    highlight_max_cols=highlight_max_cols,
    highlight_min_cols=highlight_min_cols,
    round_digits=3
).to_latex(
    "sae_table_movielens.tex",
    index=False,
    float_format="%.3f",
    bold_rows=False,
    column_format="ll|rrrr|rrrr|rrrr",
    escape=False,  # cells already contain LaTeX markup from format_latex()
    caption = (
        "Table summarizing the performance of different SAE aggregation strategies on MovieLens dataset. "
        "'G/mean' shows the percentage change in mean NDCG@20 using ground-truth recommendations seen by all group members. "
        "'U/min' shows the change in the mean of the minimum NDCG@20 across group members. "
        "'U/mean' shows the change in the mean of the average NDCG@20 across group members. "
        "'Pop' shows the change in the mean popularity of recommended items."
    ),
    label="tab:aggregations:movielens"
)

## SAE group recommendations table aggregated across all sizes

**Dataset: LastFM1k** (group types: similar, random, outlier)

Each value is the mean across all 9 size variants.

In [7]:
# MLflow experiment IDs searched by generate_recommendations_table() below.
experiments = ['523100174176986081', '333391697323445885']

# Helper that turns the aggregated table into LaTeX-ready strings.

def format_latex(df, highlight_max_cols=None, highlight_min_cols=None, round_digits=3):
    """Convert a table to strings, marking the best two values per column in LaTeX.

    For every column listed in ``highlight_max_cols`` the largest value is
    wrapped in ``\\textbf{}`` and the second largest in ``\\underline{}``;
    columns in ``highlight_min_cols`` are treated the same way using the
    smallest values. Cells of highlighted columns are rendered with
    ``round_digits`` decimals. All other columns are converted with plain
    string formatting.

    Ties follow ``Series.nlargest(2)`` / ``nsmallest(2)``: when the two best
    entries are equal, both are bold and nothing is underlined.

    Columns with fewer than two non-NaN values are handled gracefully: the
    missing rank is simply not applied and NaN cells are rendered as ``nan``.

    Args:
        df: Source DataFrame; it is not modified.
        highlight_max_cols: Column labels where larger is better.
        highlight_min_cols: Column labels where smaller is better.
        round_digits: Number of decimals for highlighted columns.

    Returns:
        A copy of ``df`` whose cells are all strings.
    """
    highlight_max_cols = highlight_max_cols or []
    highlight_min_cols = highlight_min_cols or []
    formatted_df = df.copy()

    def _emphasise(values, ranked):
        # nlargest/nsmallest drop NaN, so `ranked` may hold 0, 1 or 2 entries.
        best = ranked[0] if len(ranked) > 0 else None
        runner_up = ranked[1] if len(ranked) > 1 else None

        def format_cell(val):
            text = f"{val:.{round_digits}f}"
            if best is not None and val == best:
                return f"\\textbf{{{text}}}"
            if runner_up is not None and val == runner_up:
                return f"\\underline{{{text}}}"
            return text

        return values.apply(format_cell)

    for col in df.columns:
        col_values = df[col]

        if col in highlight_max_cols:
            formatted_df[col] = _emphasise(col_values, col_values.nlargest(2).values)
        elif col in highlight_min_cols:
            formatted_df[col] = _emphasise(col_values, col_values.nsmallest(2).values)
        else:
            formatted_df[col] = col_values.apply(lambda val: f"{val}")

    return formatted_df

# Columns of the aggregated table where larger is better (NDCG-based metrics)
# and where smaller is better (popularity); consumed by format_latex().
highlight_max_cols = [
    (group, metric, 'mean')
    for group in ('sim', 'random', 'outlier')
    for metric in ('G/mean', 'U/min', 'U/mean')
]
highlight_min_cols = [(group, 'Pop', 'mean') for group in ('sim', 'random', 'outlier')]

table = generate_recommendations_table(experiments, prefix_note="aggregations", dataset="LastFM1k")

# (Aggregation, Activation) pairs kept for the summary.
row_indexes_selected = [
    ('average', 'True'),
    ('common_features', 'False'),
    ('max', 'True'),
    ('square_average', 'False'),
    ('topk', 'True'),
    ('wcom', 'True'),
]

table = table.loc[
    table.set_index(['Aggregation', 'Activation']).index.isin(row_indexes_selected)
].reset_index()

# One (group type, metric) column per combination, grouped by group type.
group_types = ["sim", "random", "outlier"]
selected_columns = [
    (grouptype, metric)
    for grouptype in group_types
    for metric in ["G/mean", "U/mean", "U/min", "Pop"]
]

# Average every selected metric per (Aggregation, Activation) pair.
agg_table = (
    table
    .groupby(["Aggregation", "Activation"])[selected_columns]
    .agg(['mean'])
    .round(3)
)

# Labels of the aggregated quality columns and of their (not computed here)
# standard-deviation counterparts.
selected_columns = [
    (group_type, metric, 'mean')
    for group_type in group_types
    for metric in ["G/mean", "U/mean", "U/min"]
]
std_selected_columns = [
    (group_type, metric, 'std')
    for group_type in group_types
    for metric in ["G/mean", "U/mean", "U/min"]
]

agg_table

Unnamed: 0_level_0,Unnamed: 1_level_0,sim,sim,sim,sim,random,random,random,random,outlier,outlier,outlier,outlier
Unnamed: 0_level_1,Unnamed: 1_level_1,G/mean,U/mean,U/min,Pop,G/mean,U/mean,U/min,Pop,G/mean,U/mean,U/min,Pop
Unnamed: 0_level_2,Unnamed: 1_level_2,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
Aggregation,Activation,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3
average,True,0.59,0.805,0.635,0.57,0.5,0.752,0.544,0.637,0.413,0.715,0.448,0.583
common_features,False,0.57,0.795,0.619,0.586,0.481,0.739,0.526,0.647,0.4,0.694,0.447,0.605
max,True,0.59,0.804,0.64,0.559,0.499,0.75,0.549,0.625,0.417,0.71,0.469,0.574
topk,True,0.584,0.802,0.632,0.564,0.487,0.744,0.532,0.626,0.392,0.704,0.428,0.563
wcom,True,0.576,0.798,0.621,0.578,0.49,0.747,0.53,0.646,0.394,0.71,0.419,0.586


In [8]:
# Export the LastFM1k summary as a LaTeX table.
# The file name is dataset-specific: writing to the shared "sae_table.tex"
# overwrote the MovieLens table exported earlier in this notebook.
# NOTE(review): the caption speaks of "percentage change", while agg_table
# holds plain means of the logged metrics -- confirm the wording.
format_latex(
    agg_table.reset_index(),
    highlight_max_cols=highlight_max_cols,
    highlight_min_cols=highlight_min_cols,
    round_digits=3
).to_latex(
    "sae_table_lastfm1k.tex",
    index=False,
    float_format="%.3f",
    bold_rows=False,
    column_format="ll|rrrr|rrrr|rrrr",
    escape=False,  # cells already contain LaTeX markup from format_latex()
    caption = (
        "Table summarizing the performance of different SAE aggregation strategies on LastFM1k dataset. "
        "'G/mean' shows the percentage change in mean NDCG@20 using ground-truth recommendations seen by all group members. "
        "'U/min' shows the change in the mean of the minimum NDCG@20 across group members. "
        "'U/mean' shows the change in the mean of the average NDCG@20 across group members. "
        "'Pop' shows the change in the mean popularity of recommended items."
    ),
    label="tab:aggregations:lastfm1k"
)

## SAE group recommendations table aggregated across all sizes

**Group type: divergent**

Each value is the mean across all 9 size variants.

In [None]:
# MLflow experiment IDs searched by generate_recommendations_table().
experiments = ['523100174176986081', '333391697323445885']

# generate_recommendations_table() has no `group_type` argument: it filters by
# dataset and keys its metric columns by group type. The outlier-only,
# per-dataset comparison is therefore assembled from one call per dataset.
dataset_names = ["MovieLens", "LastFM1k"]
metric_names = ["G/mean", "U/mean", "U/min", "Pop"]

# Candidate (Aggregation, Activation) pairs; the row filter is intentionally
# not applied in this cell, so every pair is shown.
row_indexes_selected = [
    ('average', 'True'),
    ('common_features', 'False'),
    ('max', 'True'),
    ('square_average', 'False'),
    ('topk', 'False'),
    ('wcom', 'True'),
]

per_dataset_means = []
for dataset_name in dataset_names:
    table = generate_recommendations_table(
        experiments, prefix_note="aggregations", dataset=dataset_name
    )
    # Average the outlier-group metrics per (Aggregation, Activation) pair.
    dataset_means = (
        table
        .groupby(["Aggregation", "Activation"])[[("outlier", metric) for metric in metric_names]]
        .agg(['mean'])
    )
    # Relabel ('outlier', metric, 'mean') -> (dataset, metric, 'mean') so the
    # datasets can sit side by side in one table.
    dataset_means.columns = pd.MultiIndex.from_tuples(
        [(dataset_name, metric, 'mean') for metric in metric_names]
    )
    per_dataset_means.append(dataset_means)

agg_table = pd.concat(per_dataset_means, axis=1).round(3)

# Quality columns to highlight, and their (not computed here)
# standard-deviation counterparts.
selected_columns = [
    (grouptype, metric, 'mean')
    for grouptype in dataset_names
    for metric in ["G/mean", "U/mean", "U/min"]
]
std_selected_columns = [
    (grouptype, metric, 'std')
    for grouptype in dataset_names
    for metric in ["G/mean", "U/mean", "U/min"]
]

agg_table.style.apply(highlight_top3_dark_to_light, subset=selected_columns)

Unnamed: 0_level_0,Unnamed: 1_level_0,MovieLens,MovieLens,MovieLens,MovieLens,LastFM1k,LastFM1k,LastFM1k,LastFM1k
Unnamed: 0_level_1,Unnamed: 1_level_1,G/mean,U/mean,U/min,Pop,G/mean,U/mean,U/min,Pop
Unnamed: 0_level_2,Unnamed: 1_level_2,mean,mean,mean,mean,mean,mean,mean,mean
Aggregation,Activation,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3
average,False,0.37,0.55,0.355,0.322,0.426,0.715,0.465,0.613
average,True,0.546,0.666,0.487,0.477,0.413,0.715,0.448,0.583
common_features,False,0.351,0.52,0.338,0.357,0.4,0.694,0.447,0.605
common_features,True,0.181,0.342,0.186,0.285,0.118,0.307,0.162,0.293
max,False,0.305,0.494,0.305,0.271,0.429,0.707,0.487,0.609
max,True,0.528,0.651,0.487,0.463,0.417,0.71,0.469,0.574
topk,False,0.55,0.668,0.487,0.495,0.382,0.705,0.407,0.567
topk,True,0.557,0.67,0.494,0.505,0.392,0.704,0.428,0.563
wcom,False,0.434,0.596,0.401,0.377,0.417,0.714,0.448,0.611
wcom,True,0.533,0.659,0.468,0.48,0.394,0.71,0.419,0.586
