In [None]:
import runpy
import sys
from io import StringIO
import time
from IPython.display import IFrame

import importlib_resources
import pkg_resources
from importlib_resources import path
from pstats import Stats
import pandas as pd
import pstats
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from benchmarks.benchmark_utils import swap_attr

In [None]:
# Load the snakeviz IPython extension, presumably for interactive inspection of the
# .prof files written below (see the snakeviz cell at the end of the notebook).
%load_ext snakeviz

## Configuration

In [None]:
# Number of profiling repetitions per code version. The analysis aggregates
# (mean/median/min/max/std) over these repetitions.
NUM_RUNS: int = 25

In [None]:
# Code versions to profile, in the order in which they are run and plotted.
#   name    - label of the version in plots and tables
#   sha     - commit checked out in the separate clone; also part of the .prof file names
#   date    - commit date (informational only, not read by the code in this notebook)
#   message - commit message, where recorded (informational only)
loop = [
    {
        "name": "before lru",
        "sha": "f37c7f7947d823317651521994aaaf464e6e8dfa",
        "date": "Sat Nov 19 17:55:28 2022 +0000",
        "message": "introduce find_swaps2 - faster version of find_swaps",
    },
    {
        "name": "v4.0.0",
        "sha": "45768358",
        "date": "Fri Oct 27 15:19:24 2023 +0200",
        "message": "Release version 4.0.0",
    },
    {
        "name": "v4.0.2",
        "sha": "5a97ccb6aec2e7c6227aba8a3b33de54f567ee3a",
        "date": "Tue Apr 23 15:17:36 2024 +0200",
    },
    {
        "name": "v4.0.4",
        "sha": "9115580bf7c602ca3c524ad392489bd712f355da",
        "date": "Tue Feb 18 17:03:18 2025 +0100",
    },
    {
        "name": "118-fix-lru-cache-in-env-loading",
        "sha": "01d4c7ae8179c7a716059552eb31865772e5a549",
        "date": "Tue Feb 18 17:11:28 2025 +0100",
    },
]

## Run Benchmarks with Profiler

In [None]:
# Use a separate clone in which each code version to profile is checked out.
# The profiling CLI itself is taken from the current code next to this notebook, so its
# env configuration stays under our control.
# CAVEAT: the checked-out code runs in the current Python env (the one running this notebook).
# This could lead to inconsistencies in the future: removed or updated requirements, or
# backwards-incompatible changes to the benchmarking CLI.
# Clone only once. On re-runs, reuse the existing clone and fetch, so that commits pushed
# after the first clone can still be checked out.
!test -d /tmp/flatland-rl || git clone https://github.com/flatland-association/flatland-rl.git /tmp/flatland-rl
!git -C /tmp/flatland-rl fetch --tags origin
# -d also removes untracked directories that a previously checked-out version may have left behind
!git -C /tmp/flatland-rl clean -fd && git -C /tmp/flatland-rl reset --hard

In [None]:
# Profile every version NUM_RUNS times. Runs are interleaved (each repetition visits all
# versions), which spreads slow drifts in machine load over all versions.
COOLDOWN_SECONDS = 2  # pause between profiling runs (presumably to let the machine settle)
for i in range(NUM_RUNS):
    for version in loop:
        sha = version["sha"]
        prof_file = f"flatland_performance_profiling.py_{sha}_{i}.prof"
        print("===================================================================")
        print(f'{version["name"]} - {sha} - {i}')
        print("===================================================================")
        # Checkout and profiling are chained in ONE shell command. If the checkout fails,
        # nothing is profiled. Otherwise the previously checked-out version would be
        # profiled and silently saved under this sha.
        # "{sys.executable}" runs the profiler with the kernel's interpreter (the "current env"
        # referred to above) rather than whatever `python` is first on PATH.
        # The example script path is relative to the notebook's working directory.
        !git -C /tmp/flatland-rl checkout {sha} && git -C /tmp/flatland-rl log -1 && PYTHONPATH=/tmp/flatland-rl "{sys.executable}" ../examples/flatland_performance_profiling.py -o {prof_file}
        time.sleep(COOLDOWN_SECONDS)

## Analyse Profiling

In [None]:
# https://stackoverflow.com/questions/44302726/pandas-how-to-store-cprofile-output-in-a-pandas-dataframe
def prof_to_df(st):
    """Flatten a profile into a DataFrame with one row per profiled function.

    Every key of ``st.stats`` is unpacked positionally into the columns
    ``file, line, fn``. Every value is unpacked into ``cc, ncalls, tottime,
    cumtime, callers``.

    Parameters
    ----------
    st : pstats.Stats
        Loaded profile; only its ``stats`` dict is read.

    Returns
    -------
    pd.DataFrame
        Columns ``file, line, fn, cc, ncalls, tottime, cumtime, callers``, with rows in
        the iteration order of ``st.stats``.
    """
    key_columns = ("file", "line", "fn")
    value_columns = ("cc", "ncalls", "tottime", "cumtime", "callers")
    entries = list(st.stats.items())

    columns = {}
    for position, column in enumerate(key_columns):
        columns[column] = [func_key[position] for func_key, _ in entries]
    for position, column in enumerate(value_columns):
        columns[column] = [func_stats[position] for _, func_stats in entries]
    return pd.DataFrame(columns)

In [None]:
# Aggregations applied per code version in `analyse_df`: identifying columns keep their
# first value, and the timing columns are summarised over all runs of that version.
agg = {
    **{column: ["first"] for column in ("fn", "sha")},
    **{column: ["mean", "median", "min", "max", "std"] for column in ("cumtime", "tottime")},
}

In [None]:
def aggregate(example, versions=None, num_runs=None):
    """Load all profiles of one example into a single long-format DataFrame.

    Reads the files ``{example}_{sha}_{i}.prof`` for every version and every run index
    ``i`` in ``range(num_runs)``. These are the names written by the profiling cell.

    Parameters
    ----------
    example : str
        File-name prefix of the profiles, e.g. ``"flatland_performance_profiling.py"``.
    versions : list of dict, optional
        Version entries with ``"sha"`` and ``"name"`` keys. Defaults to the notebook-level
        ``loop``.
    num_runs : int, optional
        Number of runs per version. Defaults to the notebook-level ``NUM_RUNS``.

    Returns
    -------
    pd.DataFrame
        The ``prof_to_df`` rows of every profile plus ``sha`` and ``name`` columns,
        with a fresh ``RangeIndex``.

    Raises
    ------
    OSError
        If a profile file is missing (``pstats.Stats`` cannot open it).
    ValueError
        If there are no profiles at all (``pd.concat`` of an empty list).
    """
    if versions is None:
        versions = loop
    if num_runs is None:
        num_runs = NUM_RUNS

    frames = []
    for version in versions:
        for i in range(num_runs):
            prof_path = f'{example}_{version["sha"]}_{i}.prof'
            frame = prof_to_df(pstats.Stats(prof_path))  # load each profile exactly once
            frame["sha"] = version["sha"]
            frame["name"] = version["name"]
            frames.append(frame)
    # Every per-file frame has its own 0..n index. ignore_index avoids duplicate labels in
    # the combined frame, which break label-based operations such as reindexing.
    return pd.concat(frames, ignore_index=True)

In [None]:
# Column meanings (as reported by Python's pstats):
# tottime: total time spent in the function itself, excluding time spent in the functions it calls.
# cumtime: cumulative time spent in the function, including all functions it calls.

In [None]:
def filter_df(df, conditions):
    """Return the rows of ``df`` that match any of the given (function, file) pairs.

    A row matches a pair ``(fn, file)`` when its ``fn`` column equals ``fn`` exactly and
    its ``file`` column contains ``file`` as a literal substring.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with at least the columns ``fn`` and ``file``.
    conditions : iterable of (str, str)
        ``(function name, file fragment)`` pairs; matching rows are combined with OR.

    Returns
    -------
    pd.DataFrame
        The matching rows in their original order. The result is empty if ``conditions``
        is empty.
    """
    # Start from an all-False mask aligned with df, so an empty `conditions` yields an
    # empty frame instead of indexing df with a bare False.
    mask = pd.Series(False, index=df.index)
    for fn, file in conditions:
        # regex=False: fragments such as "rail_env.py" are meant literally. As a regex,
        # the "." would match any character.
        mask = mask | ((df["fn"] == fn) & df["file"].str.contains(file, regex=False))
    return df[mask]

In [None]:
def analyse_df(df, fn, file, sort_by="cumtime", aggregations=None):
    """Summarise one profiled function per code version and compare the versions.

    Selects the rows whose ``fn`` equals ``fn`` and whose ``file`` contains ``file`` as a
    literal substring. These rows are grouped by version ``name`` and aggregated with
    ``aggregations``. The result is sorted ascending by the median of ``sort_by``, so the
    fastest version (by median) comes first.

    Added columns, all computed from the ``sort_by`` statistics:

    - ``diff_median`` / ``diff_mean``: difference to the first row's median / mean.
    - ``diff%_median`` / ``diff%_mean``: that difference as a percentage of the first
      row's median / mean, i.e. how much slower than the first version.

    The first row holds NaN in all four added columns.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format profile frame with columns ``name``, ``fn``, ``file`` and the
        aggregated timing columns.
    fn : str
        Function name to select (exact match).
    file : str
        Literal fragment of the file path to select.
    sort_by : str
        Timing column to sort and compare by (e.g. ``"cumtime"`` or ``"tottime"``).
    aggregations : dict, optional
        Passed to ``DataFrame.agg``; must include ``"median"`` and ``"mean"`` for
        ``sort_by``. Defaults to the notebook-level ``agg``.

    Returns
    -------
    pd.DataFrame
        One row per version name, with MultiIndex columns ``(column, statistic)``.
    """
    if aggregations is None:
        aggregations = agg
    selected = df[(df["fn"] == fn) & df["file"].str.contains(file, regex=False)]
    summary = (
        selected.groupby("name")
        .agg(aggregations)
        .sort_values((sort_by, "median"), ascending=True)
    )
    for stat in ("median", "mean"):
        values = summary[(sort_by, stat)]
        # diff().cumsum() equals values - values.iloc[0], with NaN in the first row.
        diff = values.diff().cumsum()
        summary[f"diff_{stat}"] = diff
        # values - diff recovers the first row's value, the reference for the percentage.
        # The statistic is taken from `sort_by` (not always cumtime), so that the
        # numerator and the denominator use the same column.
        summary[f"diff%_{stat}"] = diff / (values - diff) * 100
    return summary

In [None]:
# Load the profiles of every run of every version into one long-format frame
# (one row per profiled function per run).
df_flatland_performance_profiling = aggregate(example="flatland_performance_profiling.py")
df_flatland_performance_profiling

### Look into overall performance

In [None]:
# Median cumulative time per version of the whole simulation (run_simulation) and of
# step() / reset() in rail_env.py.
fig, ax = plt.subplots(figsize=(15, 8))
sns.barplot(
    data=filter_df(df_flatland_performance_profiling, [
        ("step", "rail_env.py"),
        ("reset", "rail_env.py"),
        ("run_simulation", "flatland_performance_profiling.py"),
    ]),
    x="name", y="cumtime", hue="fn",
    # A fixed hue order makes ax.containers[2] deterministically the run_simulation bars.
    # Otherwise it depends on the order in which the functions appear in the data.
    hue_order=["step", "reset", "run_simulation"],
    legend=True, estimator="median", ax=ax,
)
ax.bar_label(ax.containers[2], fontsize=10)
ax.set(title="Overall performance per version (median over runs)", xlabel="version", ylabel="cumtime [s]")
fig.savefig("performance_overall.png")

The `run_simulation` timings from the plot above in tabular form:

In [None]:
# Per-version statistics of the whole simulation run, fastest version (by median) first
analyse_df(df_flatland_performance_profiling, fn="run_simulation", file="flatland_performance_profiling.py")

### Look into `a_star()` and `find_conflicts()` in relation to `step()`
- improvement of A*: https://github.com/flatland-association/flatland-rl/pull/68 (came in with v4.0.2)
- improvement of motion check: https://github.com/flatland-association/flatland-rl/issues/6 (forthcoming)

In [None]:
# Mean cumulative time per version of step() together with a_star() and find_conflicts().
# NOTE(review): this plot uses the mean while the other plots use the median - confirm intended.
fig, ax = plt.subplots(figsize=(15, 8))
sns.barplot(
    data=filter_df(df_flatland_performance_profiling, [
        ("step", "rail_env.py"),
        ("a_star", "star"),
        ("find_conflicts", "agent_chains.py"),
    ]),
    x="name", y="cumtime", hue="fn",
    # fixed hue order, so that colours and legend order do not depend on the data order
    hue_order=["step", "a_star", "find_conflicts"],
    legend=True, estimator="mean", ax=ax,
)
ax.set(title="step() vs. a_star() and find_conflicts() per version (mean over runs)", xlabel="version", ylabel="cumtime [s]")
fig.savefig("performance_a_star_motion_check.png")

### Look into LRU caching speed-up
- improvement through LRU caching (came in with v4.0.0)

In [None]:
# Median cumulative time per version of the cached lookups is_dead_end() and
# get_transition() (matched in files whose path contains "map").
fig, ax = plt.subplots(figsize=(15, 8))
sns.barplot(
    data=filter_df(df_flatland_performance_profiling, [
        ("is_dead_end", "map"),
        ("get_transition", "map"),
    ]),
    x="name", y="cumtime", hue="fn",
    # A fixed hue order makes ax.containers[0] deterministically the is_dead_end bars.
    # Otherwise it depends on the order in which the functions appear in the data.
    hue_order=["is_dead_end", "get_transition"],
    legend=True, estimator="median", ax=ax,
)
ax.bar_label(ax.containers[0], fontsize=10)
ax.set(title="LRU caching: is_dead_end() and get_transition() per version (median over runs)", xlabel="version", ylabel="cumtime [s]")
fig.savefig("performance_lru.png")

### Snakeviz of individual profiles
Uncomment and run the following line to start a snakeviz server for one of the profiles and open it in a new browser window:

In [None]:
# !snakeviz "flatland_performance_profiling.py_01d4c7ae8179c7a716059552eb31865772e5a549_3.prof"