In [63]:
import os
import tqdm
import wandb
import warnings
import numpy as np
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
import concurrent.futures

# Quieten noisy third-party output before the remaining imports run.
# NOTE(review): no TensorFlow import is visible in this notebook; presumably it is
# pulled in indirectly by one of the imports below — confirm, since the variable
# only matters if it is set before that import happens.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow logging
warnings.filterwarnings('ignore', category=UserWarning, module='google.protobuf')

from matplotlib.axes import Axes
from wandb.apis.public import Run

from typing import Union, List, Dict
from src.visualization import set_themes

# Global display configuration for every cell below.
set_themes() # Set custom themes for plots
pl.Config.set_tbl_rows(20) # Set Polars table display rows limit

# Opt in to the future pandas behaviour where operations such as `replace`
# no longer silently downcast the resulting dtype.
pd.set_option('future.no_silent_downcasting', True)

In [65]:
# Parquet snapshot of the fetched runs; when it exists the W&B API is not queried.
cache_file = "wandb/summary.parquet"

# Notebook-level switches: which model's runs to analyse and whether to keep
# only the runs whose files are present under ./models/.
config = dict(
    model="matrix_factorization",
    ensure_available_locally=False,
)

# Metric -> weight mapping; the weighted sum is the per-row score used to pick
# each run's "best" history row.
sorting_criterion = dict([
    ("epoch/test_hitrate@50", 0.5),
    ("epoch/test_ndcg@50", 0.25),
])

# Prefer the local parquet snapshot; only hit the Weights & Biases API when no
# snapshot is available.
# NOTE(review): nothing in this cell writes `cache_file`, so a freshly fetched
# table is not persisted here — confirm the snapshot is produced elsewhere.
if cache_file is not None and os.path.exists(cache_file):
    print(f"Loading cached experiment runs from {cache_file}...")
    experiment_runs = pl.read_parquet(cache_file)
    print(f"Loaded {len(experiment_runs)} runs from cache.")
else:
    print("No cache file found. Fetching experiment runs from Weights & Biases...")
    api = wandb.Api() # Initialize Weights & Biases API, used for fetching run data

    def fetch_run_metadata(run: Run, considered_metrics: Union[str, Dict[str, float]] = "epoch/epoch") -> Dict:
        """Flatten one W&B run into a single record (one future table row).

        Args:
            run: Run object exposing ``id``, ``name``, ``sweep``, ``config``,
                ``metadata`` and ``history()``.
            considered_metrics: Either the name of a single history column used
                directly as the score, or a ``{metric: weight}`` mapping whose
                weighted sum is the score.

        Returns:
            A dict holding the run identifiers, the run config (list/dict values
            stringified), every history column as a list, the history row with
            the highest score under ``best:<column>`` keys, and the GPU type /
            CPU count from the run metadata. The ``best:*`` keys are omitted when
            no row has a valid score (empty history, missing metric columns, or
            only NaN scores).

        Raises:
            ValueError: If ``considered_metrics`` is neither a str nor a dict.
        """
        if not isinstance(considered_metrics, (str, dict)):
            raise ValueError("considered_metrics must be either a string or a dictionary")

        # Convert lists and dicts to strings so every config value is a scalar.
        run_config = {
            key: str(value) if isinstance(value, (list, dict)) else value
            for key, value in run.config.items()
        }

        run_history = run.history()
        # Map the string spellings "Infinity"/"NaN" to real floats.
        run_history = run_history.replace({"Infinity": np.inf, "NaN": np.nan})

        required_metrics = [considered_metrics] if isinstance(considered_metrics, str) else list(considered_metrics)
        missing_metrics = [metric for metric in required_metrics if metric not in run_history.columns]
        if missing_metrics:
            # Runs that never logged the requested metrics must not abort the
            # whole fetch; they simply get no valid score.
            run_history["score"] = np.nan
        elif isinstance(considered_metrics, str):
            run_history["score"] = run_history[considered_metrics]
        else:
            run_history["score"] = sum(
                run_history[metric] * weight for metric, weight in considered_metrics.items()
            )

        # The replacement above can leave object-dtype columns; rank on a numeric view.
        numeric_score = pd.to_numeric(run_history["score"], errors="coerce")
        if numeric_score.notna().any():
            # argmax skips NaN, so the best row is the first row with the highest valid score.
            best_row = run_history.iloc[int(numeric_score.argmax())]
            best_summary = {f"best:{key}": val for key, val in best_row.items()}
        else:
            best_summary = {}

        # `run.metadata` may be None when no metadata is available for the run.
        run_metadata = run.metadata or {}

        return {
            "run_id": run.id,
            "run_name": run.name,
            "sweep_id": run.sweep.id if run.sweep else None,
            "model": run.config.get("model"),
            **run_config,
            **{metric: run_history[metric].to_list() for metric in run_history},
            **best_summary,
            "gpu_type": run_metadata.get("gpu"),
            "cpu_count": run_metadata.get("cpu_count"),
        }

    batch_size = 16  # maximum number of runs fetched concurrently
    records = []
    runs:List[Run] = api.runs("feedr/peppermint-matrix", per_page=2*batch_size-1, filters={"config.model": config["model"]})
    run_iterator = iter(runs)
    pending = set()
    runs_exhausted = False
    # The context manager shuts the worker threads down even when a fetch raises.
    with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
        with tqdm.tqdm(total=len(runs), ncols=128) as pbar:
            while True:
                # Top up the in-flight set so at most `batch_size` fetches run at once.
                while not runs_exhausted and len(pending) < batch_size:
                    next_run = next(run_iterator, None)
                    if next_run is None:
                        runs_exhausted = True
                    else:
                        pending.add(executor.submit(fetch_run_metadata, next_run, sorting_criterion))

                if not pending:
                    break

                # Block until at least one fetch finishes instead of polling on a timeout.
                finished_futures, pending = concurrent.futures.wait(
                    pending, return_when=concurrent.futures.FIRST_COMPLETED
                )
                for finished_future in finished_futures:
                    # `.result()` re-raises any exception from the worker thread.
                    records.append(finished_future.result())
                    pbar.update(1)

    # Create a Polars DataFrame from the records
    experiment_runs = pl.DataFrame(records, infer_schema_length=None)
    
# Tag run as available locally if the model files exist.
# Expected layout: ./models/<model>/<sweep_id>/<run_id>/
models_root = f"./models/{config['model']}/"

# A missing models directory just means that no run is available locally; it
# must not crash the notebook, since local availability is optional.
local_sweep_ids = os.listdir(models_root) if os.path.isdir(models_root) else []

local_run_ids = []
for sweep_id in local_sweep_ids:
    sweep_dir = os.path.join(models_root, sweep_id)
    # Skip stray files sitting next to the sweep folders.
    if os.path.isdir(sweep_dir):
        local_run_ids.extend(os.listdir(sweep_dir))

experiment_runs = experiment_runs.with_columns(
    # With no local runs, use a constant False rather than matching against an empty list.
    available_locally=pl.col("run_id").is_in(local_run_ids) if local_run_ids else pl.lit(False)
)

if config["ensure_available_locally"]:
    experiment_runs = experiment_runs.filter(pl.col("available_locally"))

# Sort ascending by `_timestamp` and derive each run's duration from the
# largest value in its `_runtime` list.
longest_runtime = pl.col("_runtime").list.max()
experiment_runs = experiment_runs.sort("_timestamp").with_columns(
    run_duration_second=longest_runtime,
    run_duration_minute=longest_runtime / 60,
)

# Preview of the identifying columns, two hyper-parameters and the best-row metrics.
experiment_runs.select(
    "run_id",
    "run_name",
    "sweep_id",
    "model",
    "embedding_dimension",
    "shuffle",
    "best:epoch/epoch",
    "best:epoch/train_loss",
    "best:epoch/test_loss",
    "best:epoch/test_recall@10",
    "best:epoch/test_ndcg@10",
)

Loading cached experiment runs from wandb/summary.parquet...
Loaded 695 runs from cache.


run_id,run_name,sweep_id,model,embedding_dimension,shuffle,best:epoch/epoch,best:epoch/train_loss,best:epoch/test_loss,best:epoch/test_recall@10,best:epoch/test_ndcg@10
str,str,str,str,i64,bool,f64,f64,f64,f64,f64
"""o94q0juk""","""logical-sweep-1""","""nbysw136""","""matrix_factorization""",256,false,52.0,0.330564,0.372714,0.026077,0.100539
"""4ftaae0p""","""stilted-sweep-3""","""nbysw136""","""matrix_factorization""",4,false,59.0,0.693148,0.693148,0.006594,0.02556
"""fway5u2z""","""breezy-sweep-4""","""nbysw136""","""matrix_factorization""",512,false,4.0,0.693147,0.693147,0.0026,0.012059
"""bphcl2xf""","""clean-sweep-2""","""nbysw136""","""matrix_factorization""",1024,false,1.0,0.237272,0.239171,0.024744,0.095121
"""fftz1dek""","""trim-sweep-5""","""nbysw136""","""matrix_factorization""",256,true,57.0,0.319744,0.366027,0.025562,0.099834
"""otb8suw9""","""scarlet-sweep-6""","""nbysw136""","""matrix_factorization""",4,true,63.0,0.133806,0.217083,0.021467,0.082659
"""lvre7srl""","""solar-sweep-7""","""nbysw136""","""matrix_factorization""",256,true,8.0,0.036125,0.194509,0.021456,0.084617
"""dcbj92eg""","""ruby-sweep-8""","""nbysw136""","""matrix_factorization""",256,false,3.0,0.078173,0.179263,0.024295,0.094158
"""x17mnyw8""","""breezy-sweep-9""","""nbysw136""","""matrix_factorization""",8,false,61.0,0.111218,0.207902,0.023699,0.089898
"""7jrm756b""","""super-sweep-10""","""nbysw136""","""matrix_factorization""",4,false,10.0,0.693147,0.693147,0.002425,0.010531


# Parameter Comparison

## Embedding Dimension vs Regularization | Shuffle = False

In [71]:
# Per (embedding_dimension, l2_regularization) group of the runs with
# shuffle == False: run count plus the mean of each best-row metric.
experiment_summary = (
    experiment_runs
    .filter(~pl.col("shuffle"))
    .group_by(["embedding_dimension", "l2_regularization"])
    .agg([
        pl.col("run_id").count().alias("num_runs"),
        pl.col(
            "best:epoch/epoch",
            "best:epoch/test_recall@10",
            "best:epoch/test_ndcg@10",
            "best:epoch/test_recall@20",
            "best:epoch/test_ndcg@20",
        ).mean(),
    ])
    .sort(["embedding_dimension", "l2_regularization"])
)
experiment_summary

embedding_dimension,l2_regularization,num_runs,best:epoch/epoch,best:epoch/test_recall@10,best:epoch/test_ndcg@10,best:epoch/test_recall@20,best:epoch/test_ndcg@20
i64,f64,u32,f64,f64,f64,f64,f64
2,0.0,4,51.25,0.015956,0.061995,0.028299,0.08201
2,1.0000e-9,4,49.25,0.014855,0.057413,0.026413,0.075915
2,1.0000e-8,5,55.6,0.015249,0.060109,0.027161,0.079271
2,0.0000001,5,55.6,0.014978,0.05851,0.026745,0.077531
2,0.000001,2,34.0,0.014193,0.05583,0.025517,0.073636
2,0.00001,6,56.5,0.005557,0.023788,0.009413,0.030217
2,0.0001,4,9.5,0.002482,0.010624,0.004251,0.014389
2,0.001,3,13.333333,0.000361,0.001702,0.00068,0.00253
2,0.01,5,9.8,0.000536,0.002152,0.001209,0.003682
4,0.0,2,26.0,0.020357,0.079491,0.035603,0.102377


In [74]:
# Run counts as a grid: one row per embedding_dimension, one column per
# l2_regularization value. `on=` replaces the deprecated `columns=` argument.
experiment_summary[["embedding_dimension", "l2_regularization", "num_runs"]].pivot(
    on="l2_regularization",
    index="embedding_dimension",
    values=["num_runs"],
)

  experiment_summary[["embedding_dimension", "l2_regularization", "num_runs"]].pivot(


embedding_dimension,0.0,1e-9,1e-8,1e-7,1e-6,0.00001,0.0001,0.001,0.01,1e-10
i64,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32
2,4,4,5,5.0,2,6,4,3,5,
4,2,3,2,2.0,7,4,6,1,3,2.0
8,3,3,2,5.0,1,2,3,4,2,8.0
16,5,1,6,2.0,4,3,8,4,4,1.0
32,2,1,4,4.0,5,4,3,9,6,1.0
64,3,1,4,7.0,5,3,4,3,1,2.0
128,6,3,4,5.0,3,3,3,7,1,1.0
256,4,3,4,4.0,4,4,2,5,5,2.0
512,2,2,5,,3,3,7,4,2,1.0
1024,5,1,2,3.0,3,4,6,1,3,2.0


In [75]:
# Mean best-row epoch as a grid: one row per embedding_dimension, one column per
# l2_regularization value. `on=` replaces the deprecated `columns=` argument.
experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/epoch"]].pivot(
    on="l2_regularization",
    index="embedding_dimension",
    values=["best:epoch/epoch"],
)

  experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/epoch"]].pivot(


embedding_dimension,0.0,1e-9,1e-8,1e-7,1e-6,0.00001,0.0001,0.001,0.01,1e-10
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
2,51.25,49.25,55.6,55.6,34.0,56.5,9.5,13.333333,9.8,
4,26.0,36.666667,50.5,60.0,57.857143,59.5,7.166667,49.0,9.333333,54.0
8,51.333333,47.666667,40.0,60.6,53.0,62.0,6.666667,34.25,9.5,46.875
16,33.0,33.0,38.0,58.0,59.0,61.666667,5.875,46.0,9.75,20.0
32,13.5,21.0,22.75,59.0,60.0,60.5,5.333333,57.666667,8.333333,15.0
64,8.333333,8.0,16.0,59.142857,53.2,62.0,5.75,61.333333,8.0,7.5
128,7.666667,3.333333,5.0,54.2,57.333333,62.333333,5.333333,57.714286,8.0,3.0
256,5.25,3.0,3.75,53.75,49.25,61.5,5.5,57.6,8.4,4.5
512,1.0,1.0,2.0,,53.0,62.666667,5.571429,62.25,9.0,1.0
1024,1.0,1.0,1.0,4.333333,39.666667,15.75,5.5,49.0,8.0,1.0


In [76]:
# Mean best-row test recall@20 as a grid: one row per embedding_dimension, one
# column per l2_regularization value. `on=` replaces the deprecated `columns=`.
experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/test_recall@20"]].pivot(
    on="l2_regularization",
    index="embedding_dimension",
    values=["best:epoch/test_recall@20"],
)

  experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/test_recall@20"]].pivot(


embedding_dimension,0.0,1e-9,1e-8,1e-7,1e-6,0.00001,0.0001,0.001,0.01,1e-10
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
2,0.028299,0.026413,0.027161,0.026745,0.025517,0.009413,0.004251,0.00068,0.001209,
4,0.035603,0.034158,0.035595,0.036566,0.036019,0.010423,0.004288,0.000676,0.001358,0.036021
8,0.0392,0.039011,0.040219,0.042264,0.041996,0.011371,0.004176,0.000602,0.001154,0.039337
16,0.04007,0.040536,0.04125,0.045753,0.043829,0.011708,0.004205,0.000635,0.001355,0.040328
32,0.040401,0.04091,0.041409,0.048015,0.044864,0.011938,0.0043,0.000637,0.001517,0.041414
64,0.039906,0.041384,0.041036,0.048532,0.045193,0.012688,0.00418,0.00066,0.001715,0.041631
128,0.03902,0.041844,0.042047,0.048786,0.045222,0.012968,0.004322,0.000708,0.00228,0.042033
256,0.038145,0.041942,0.04228,0.047684,0.045084,0.013638,0.004299,0.000725,0.001791,0.041675
512,0.040658,0.041295,0.041465,,0.04552,0.013945,0.004337,0.000747,0.001614,0.041881
1024,0.041928,0.043326,0.043283,0.043072,0.044551,0.014939,0.004374,0.000796,0.001707,0.042935


In [77]:
# Mean best-row test recall@20 as a grid: one row per embedding_dimension, one
# column per l2_regularization value. `on=` replaces the deprecated `columns=`.
# NOTE(review): this cell repeats the recall@20 pivot of the preceding cell;
# possibly a different metric was intended here — confirm or delete.
experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/test_recall@20"]].pivot(
    on="l2_regularization",
    index="embedding_dimension",
    values=["best:epoch/test_recall@20"],
)

  experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/test_recall@20"]].pivot(


embedding_dimension,0.0,1e-9,1e-8,1e-7,1e-6,0.00001,0.0001,0.001,0.01,1e-10
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
2,0.028299,0.026413,0.027161,0.026745,0.025517,0.009413,0.004251,0.00068,0.001209,
4,0.035603,0.034158,0.035595,0.036566,0.036019,0.010423,0.004288,0.000676,0.001358,0.036021
8,0.0392,0.039011,0.040219,0.042264,0.041996,0.011371,0.004176,0.000602,0.001154,0.039337
16,0.04007,0.040536,0.04125,0.045753,0.043829,0.011708,0.004205,0.000635,0.001355,0.040328
32,0.040401,0.04091,0.041409,0.048015,0.044864,0.011938,0.0043,0.000637,0.001517,0.041414
64,0.039906,0.041384,0.041036,0.048532,0.045193,0.012688,0.00418,0.00066,0.001715,0.041631
128,0.03902,0.041844,0.042047,0.048786,0.045222,0.012968,0.004322,0.000708,0.00228,0.042033
256,0.038145,0.041942,0.04228,0.047684,0.045084,0.013638,0.004299,0.000725,0.001791,0.041675
512,0.040658,0.041295,0.041465,,0.04552,0.013945,0.004337,0.000747,0.001614,0.041881
1024,0.041928,0.043326,0.043283,0.043072,0.044551,0.014939,0.004374,0.000796,0.001707,0.042935


In [78]:
# Mean best-row test NDCG@20 as a grid: one row per embedding_dimension, one
# column per l2_regularization value. `on=` replaces the deprecated `columns=`.
experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/test_ndcg@20"]].pivot(
    on="l2_regularization",
    index="embedding_dimension",
    values=["best:epoch/test_ndcg@20"],
)

  experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/test_ndcg@20"]].pivot(


embedding_dimension,0.0,1e-9,1e-8,1e-7,1e-6,0.00001,0.0001,0.001,0.01,1e-10
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
2,0.08201,0.075915,0.079271,0.077531,0.073636,0.030217,0.014389,0.00253,0.003682,
4,0.102377,0.096656,0.100571,0.102822,0.102865,0.033224,0.015206,0.002372,0.003976,0.1025
8,0.110761,0.110237,0.112467,0.117875,0.117645,0.034171,0.014815,0.002131,0.003596,0.111143
16,0.113374,0.114356,0.115944,0.128185,0.122799,0.036012,0.01491,0.002359,0.003994,0.1153
32,0.113685,0.115539,0.116762,0.132806,0.125672,0.037479,0.015194,0.00237,0.00493,0.116097
64,0.113178,0.114448,0.116176,0.134394,0.12657,0.038883,0.01464,0.002634,0.006554,0.11693
128,0.111663,0.118013,0.118536,0.134409,0.126275,0.040437,0.015177,0.002728,0.00809,0.11919
256,0.108346,0.118845,0.119785,0.131961,0.126971,0.041902,0.015096,0.002757,0.005035,0.117174
512,0.113271,0.114751,0.116177,,0.126861,0.042879,0.015423,0.002837,0.00429,0.118107
1024,0.117476,0.119362,0.12,0.119814,0.125853,0.044805,0.015453,0.002837,0.005742,0.1193


In [79]:
# Highest group-mean recall@20 in the current summary table.
experiment_summary.get_column("best:epoch/test_recall@20").max()

0.04878597855567932

## Embedding Dimension vs Regularization | Shuffle = True

In [57]:
# Per (embedding_dimension, l2_regularization) group of the runs with
# shuffle == True: run count plus the mean of each best-row metric.
experiment_summary = (
    experiment_runs
    .filter(pl.col("shuffle"))
    .group_by(["embedding_dimension", "l2_regularization"])
    .agg([
        pl.col("run_id").count().alias("num_runs"),
        pl.col(
            "best:epoch/epoch",
            "best:epoch/test_recall@10",
            "best:epoch/test_ndcg@10",
            "best:epoch/test_recall@20",
            "best:epoch/test_ndcg@20",
        ).mean(),
    ])
    .sort(["embedding_dimension", "l2_regularization"])
)

In [58]:
# Run counts as a grid: one row per embedding_dimension, one column per
# l2_regularization value. `on=` replaces the deprecated `columns=` argument.
experiment_summary[["embedding_dimension", "l2_regularization", "num_runs"]].pivot(
    on="l2_regularization",
    index="embedding_dimension",
    values=["num_runs"],
)

  experiment_summary[["embedding_dimension", "l2_regularization", "num_runs"]].pivot(


embedding_dimension,0.0,1e-10,1e-8,1e-7,1e-6,0.00001,0.0001,0.001,0.01,1e-9
i64,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32
2,4,3.0,4,2,1,4,2,5,4,
4,5,3.0,5,5,1,5,8,3,2,4.0
8,3,3.0,5,1,4,3,5,3,6,1.0
16,2,2.0,9,2,1,9,3,4,4,
32,5,1.0,3,5,8,1,3,3,4,1.0
64,3,2.0,3,5,5,3,3,2,5,2.0
128,4,,4,4,4,5,4,4,4,2.0
256,3,2.0,4,5,4,3,3,3,4,2.0
512,3,1.0,5,2,6,1,2,3,5,3.0
1024,3,3.0,4,3,5,4,4,3,3,3.0


In [59]:
# Mean best-row epoch as a grid: one row per embedding_dimension, one column per
# l2_regularization value. `on=` replaces the deprecated `columns=` argument.
experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/epoch"]].pivot(
    on="l2_regularization",
    index="embedding_dimension",
    values=["best:epoch/epoch"],
)

  experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/epoch"]].pivot(


embedding_dimension,0.0,1e-10,1e-8,1e-7,1e-6,0.00001,0.0001,0.001,0.01,1e-9
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
2,44.5,48.666667,53.75,53.0,61.0,44.0,11.5,59.0,9.0,
4,58.2,60.666667,54.6,61.8,56.0,43.2,8.75,61.0,8.5,45.0
8,50.333333,50.0,55.2,60.0,58.75,33.333333,10.2,59.0,8.833333,63.0
16,48.5,40.5,52.444444,62.5,63.0,45.444444,11.333333,63.0,8.25,
32,18.0,28.0,41.0,59.6,53.625,57.0,11.666667,61.666667,9.0,21.0
64,16.0,15.0,22.333333,46.8,52.2,53.0,12.0,62.5,8.0,19.0
128,15.5,,14.0,55.5,53.25,57.6,10.25,61.75,8.75,16.5
256,9.666667,11.0,11.75,53.8,56.75,0.0,10.0,61.666667,7.75,10.0
512,6.666667,7.0,10.0,37.0,48.5,0.0,10.5,58.666667,8.0,8.333333
1024,0.0,0.0,0.0,12.333333,51.2,0.0,12.0,58.666667,8.0,0.0


In [60]:
# Mean best-row test recall@20 as a grid: one row per embedding_dimension, one
# column per l2_regularization value. `on=` replaces the deprecated `columns=`.
experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/test_recall@20"]].pivot(
    on="l2_regularization",
    index="embedding_dimension",
    values=["best:epoch/test_recall@20"],
)

  experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/test_recall@20"]].pivot(


embedding_dimension,0.0,1e-10,1e-8,1e-7,1e-6,0.00001,0.0001,0.001,0.01,1e-9
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
2,0.026371,0.027117,0.027395,0.027732,0.026154,0.022498,0.004749,0.001249,0.001291,
4,0.03502,0.035082,0.035336,0.036723,0.034739,0.023605,0.004646,0.001449,0.000948,0.035559
8,0.039253,0.039335,0.04002,0.042699,0.041714,0.023033,0.004607,0.001213,0.001234,0.039325
16,0.039995,0.039968,0.041175,0.046312,0.044298,0.023699,0.004775,0.00145,0.001523,
32,0.039954,0.039321,0.040843,0.047524,0.044643,0.024045,0.004856,0.001429,0.001495,0.040157
64,0.039525,0.039815,0.040729,0.046027,0.045053,0.024785,0.004896,0.00159,0.001663,0.039497
128,0.038355,,0.039403,0.048712,0.044757,0.02121,0.00482,0.001645,0.001317,0.038563
256,0.037374,0.03723,0.038363,0.047734,0.045211,0.017743,0.004804,0.001716,0.001603,0.038289
512,0.035717,0.036024,0.036662,0.043993,0.044945,0.017482,0.004849,0.001912,0.001574,0.035875
1024,0.037187,0.037858,0.037974,0.040105,0.044832,0.016709,0.004946,0.002026,0.00116,0.037917


In [61]:
# Mean best-row test NDCG@20 as a grid: one row per embedding_dimension, one
# column per l2_regularization value. `on=` replaces the deprecated `columns=`.
experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/test_ndcg@20"]].pivot(
    on="l2_regularization",
    index="embedding_dimension",
    values=["best:epoch/test_ndcg@20"],
)

  experiment_summary[["embedding_dimension", "l2_regularization", "best:epoch/test_ndcg@20"]].pivot(


embedding_dimension,0.0,1e-10,1e-8,1e-7,1e-6,0.00001,0.0001,0.001,0.01,1e-9
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
2,0.076778,0.080488,0.080453,0.080954,0.076683,0.067139,0.016989,0.004961,0.004804,
4,0.099924,0.099181,0.100808,0.103918,0.101346,0.070803,0.016738,0.005766,0.003144,0.101952
8,0.110814,0.111207,0.113028,0.120665,0.117086,0.070661,0.016199,0.004873,0.004324,0.10962
16,0.113371,0.113472,0.115648,0.12745,0.123636,0.071606,0.017064,0.005524,0.004844,
32,0.113714,0.113383,0.114622,0.131433,0.12553,0.071951,0.016999,0.005687,0.004593,0.113442
64,0.112327,0.112793,0.114465,0.127775,0.126345,0.074572,0.016865,0.005924,0.005059,0.112669
128,0.109684,,0.112423,0.134807,0.125869,0.065726,0.017174,0.006286,0.004135,0.109167
256,0.107132,0.107047,0.110123,0.131401,0.125926,0.05364,0.017114,0.006449,0.005054,0.108992
512,0.103055,0.103696,0.105173,0.12447,0.12614,0.052,0.017076,0.007372,0.005125,0.103988
1024,0.107641,0.109731,0.108747,0.113485,0.126274,0.049593,0.017329,0.007753,0.003697,0.108707


In [62]:
# Highest group-mean recall@20 in the current summary table.
experiment_summary.get_column("best:epoch/test_recall@20").max()

0.04871189594268799

# Cross-GPU Training

In [15]:
# Average run duration (minutes) for each embedding dimension.
(
    experiment_runs
    .group_by("embedding_dimension")
    .agg(pl.mean("run_duration_minute"))
    .sort(by="embedding_dimension")
)

embedding_dimension,run_duration_minute
i64,f64
2,22.742412
4,22.665075
8,23.885482
16,22.985601
32,23.285895
64,21.641968
128,23.73886
256,24.837704
512,25.84378
1024,25.352799


In [16]:
# Average run duration (minutes) for each GPU type.
(
    experiment_runs
    .group_by("gpu_type")
    .agg(pl.mean("run_duration_minute"))
    .sort(by="gpu_type")
)

gpu_type,run_duration_minute
str,f64
"""NVIDIA A100-SXM4-40GB""",47.189427
"""NVIDIA A10G""",23.009375
"""NVIDIA L4""",28.125109


failed to send, dropping 2 traces to intake at http://localhost:8126/v0.5/traces after 3 retries, 6 additional messages skipped


In [17]:
# Per (embedding_dimension, gpu_type) group of the runs with
# l2_regularization == 0: run count plus the mean of each best-row metric.
experiment_summary = (
    experiment_runs
    .filter(pl.col("l2_regularization").eq(0.))
    .group_by(["embedding_dimension", "gpu_type"])
    .agg([
        pl.col("run_id").count().alias("num_runs"),
        pl.col(
            "best:epoch/epoch",
            "best:epoch/test_recall@10",
            "best:epoch/test_ndcg@10",
            "best:epoch/test_recall@50",
            "best:epoch/test_ndcg@50",
        ).mean(),
    ])
    .sort(["embedding_dimension", "gpu_type"])
)
experiment_summary

embedding_dimension,gpu_type,num_runs,best:epoch/epoch,best:epoch/test_recall@10,best:epoch/test_ndcg@10,best:epoch/test_recall@50,best:epoch/test_ndcg@50
i64,str,u32,f64,f64,f64,f64,f64
2,"""NVIDIA A100-SXM4-40GB""",1,63.0,0.014696,0.059235,0.05268,0.105463
2,"""NVIDIA A10G""",19,63.0,0.014549,0.056882,0.053937,0.103935
2,"""NVIDIA L4""",1,63.0,0.014849,0.058681,0.056454,0.10793
4,"""NVIDIA A100-SXM4-40GB""",1,63.0,0.019013,0.074248,0.070179,0.129829
4,"""NVIDIA A10G""",14,63.0,0.019431,0.075468,0.069966,0.130528
4,"""NVIDIA L4""",4,63.0,0.019367,0.075652,0.069782,0.130409
8,"""NVIDIA A100-SXM4-40GB""",2,63.0,0.02136,0.082498,0.078069,0.141937
8,"""NVIDIA A10G""",19,63.0,0.021914,0.084904,0.078703,0.143965
8,"""NVIDIA L4""",3,63.0,0.021715,0.083619,0.078856,0.143085
16,"""NVIDIA A100-SXM4-40GB""",1,63.0,0.022388,0.087016,0.079871,0.146282


In [18]:
# Run counts as a grid: one row per embedding_dimension, one column per
# gpu_type. `on=` replaces the deprecated `columns=` argument.
experiment_summary[["embedding_dimension", "gpu_type", "num_runs"]].pivot(
    on="gpu_type",
    index="embedding_dimension",
    values=["num_runs"],
)

  experiment_summary[["embedding_dimension", "gpu_type", "num_runs"]].pivot(


embedding_dimension,NVIDIA A100-SXM4-40GB,NVIDIA A10G,NVIDIA L4
i64,u32,u32,u32
2,1.0,19,1.0
4,1.0,14,4.0
8,2.0,19,3.0
16,1.0,27,1.0
32,,27,3.0
64,1.0,21,2.0
128,,26,
256,3.0,23,4.0
512,2.0,20,3.0
1024,1.0,23,


In [19]:
# Mean best-row test recall@50 as a grid: one row per embedding_dimension, one
# column per gpu_type. `on=` replaces the deprecated `columns=` argument.
experiment_summary[["embedding_dimension", "gpu_type", "best:epoch/test_recall@50"]].pivot(
    on="gpu_type",
    index="embedding_dimension",
    values=["best:epoch/test_recall@50"],
)

  experiment_summary[["embedding_dimension", "gpu_type", "best:epoch/test_recall@50"]].pivot(


embedding_dimension,NVIDIA A100-SXM4-40GB,NVIDIA A10G,NVIDIA L4
i64,f64,f64,f64
2,0.05268,0.053937,0.056454
4,0.070179,0.069966,0.069782
8,0.078069,0.078703,0.078856
16,0.079871,0.079915,0.081216
32,,0.077398,0.077563
64,0.07268,0.073611,0.072878
128,,0.069959,
256,0.066112,0.066402,0.066835
512,0.062456,0.063286,0.063337
1024,0.061825,0.049841,


In [20]:
# Mean best-row test NDCG@50 as a grid: one row per embedding_dimension, one
# column per gpu_type. `on=` replaces the deprecated `columns=` argument.
experiment_summary[["embedding_dimension", "gpu_type", "best:epoch/test_ndcg@50"]].pivot(
    on="gpu_type",
    index="embedding_dimension",
    values=["best:epoch/test_ndcg@50"],
)

  experiment_summary[["embedding_dimension", "gpu_type", "best:epoch/test_ndcg@50"]].pivot(


embedding_dimension,NVIDIA A100-SXM4-40GB,NVIDIA A10G,NVIDIA L4
i64,f64,f64,f64
2,0.105463,0.103935,0.10793
4,0.129829,0.130528,0.130409
8,0.141937,0.143965,0.143085
16,0.146282,0.145748,0.147393
32,,0.142054,0.142608
64,0.136999,0.136925,0.135871
128,,0.132737,
256,0.127316,0.128326,0.128494
512,0.123614,0.12421,0.124616
1024,0.123555,0.098477,
