In [6]:
import tensorflow as tf
import numpy as np
import pandas as pd
import copy

import altair as alt

In [8]:
# Load the logged gradient data. The columns arrive with positional labels
# ("0", "1", ..., "11") as written to the parquet file; descriptive names are
# assigned afterwards.
data = pd.read_parquet(
    "../adaptation_rate.parquet",
    engine="pyarrow",
)
data.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,10.0,0.0,2.417426,1.074746,2.997491,0.878639,-0.226966,2.026018,3.952866,-0.035098,1.134337,0.892359
1,10.0,1.0,-3.311921,-2.513282,3.297085,3.943522,-1.148066,-1.081385,6.188364,-0.340237,0.970285,0.125787
2,10.0,2.0,8.297198,-9.012988,0.369586,1.369888,-0.427774,-1.517351,2.071255,-0.796607,0.848126,-0.383256
3,10.0,3.0,-2.468949,3.3247,-10.99008,-4.776526,-0.869051,-2.122751,5.169607,-0.3771,1.104195,-0.471872
4,10.0,4.0,8.226031,6.355278,-0.30907,3.013748,1.198458,5.652486,0.576945,-0.153741,0.173159,-0.327417


In [9]:
# Rename the positional column labels:
#   "0" -> adaptation_rate, "1" -> iteration,
#   "k" (k >= 2) -> g_{k-2}  (one g_* column per remaining column).
# The number of g_* columns is derived from the data rather than hard-coded,
# so a log with a different gradient dimension is renamed completely.
# Re-running this cell is safe: rename() ignores labels that are no longer
# present, so already-renamed columns are left untouched.
headers = {"0": "adaptation_rate", "1": "iteration"}
n_gradients = len(data.columns) - 2
for gradient in range(n_gradients):
    headers[str(gradient + 2)] = f"g_{gradient}"

data = data.rename(columns=headers)
data.head()

Unnamed: 0,adaptation_rate,iteration,g_0,g_1,g_2,g_3,g_4,g_5,g_6,g_7,g_8,g_9
0,10.0,0.0,2.417426,1.074746,2.997491,0.878639,-0.226966,2.026018,3.952866,-0.035098,1.134337,0.892359
1,10.0,1.0,-3.311921,-2.513282,3.297085,3.943522,-1.148066,-1.081385,6.188364,-0.340237,0.970285,0.125787
2,10.0,2.0,8.297198,-9.012988,0.369586,1.369888,-0.427774,-1.517351,2.071255,-0.796607,0.848126,-0.383256
3,10.0,3.0,-2.468949,3.3247,-10.99008,-4.776526,-0.869051,-2.122751,5.169607,-0.3771,1.104195,-0.471872
4,10.0,4.0,8.226031,6.355278,-0.30907,3.013748,1.198458,5.652486,0.576945,-0.153741,0.173159,-0.327417


In [10]:
#color=alt.condition(
#        gradient_domain[0] <= alt.datum.a_0 <= gradient_domain[1],
#        alt.value("steelblue"),
#        alt.value("orange")
#)

In [11]:
# Plot every g_* column against the iteration index: one row of panels per
# adaptation rate, one panel per g_* column. The result is saved as HTML.
gradient_domain = (-1, 1)  # fixed y-range; marks use clip=True, so values outside are hidden
action_names = list(data.columns)[2:]  # the g_* columns (everything after adaptation_rate, iteration)

# Base layer: a line over iterations. No data is attached here; the DataFrame
# is attached once at the top level of the concatenated chart below and
# inherited by every panel.
base = alt.Chart().mark_line(clip=True).encode(
    x="iteration:O",
).properties(
    height=200,
    width=600,
)

# Overlay points for inspection convenience. mark_point() returns a copy of
# `base`, so the x encoding and the panel size are inherited and need not be
# repeated.
base += base.mark_point(clip=True)

# Manually build the faceted chart: build each row, then stack the rows into
# the finished plot.
chart = alt.vconcat(data=data)
for adaptation_rate in data["adaptation_rate"].unique():
    row = alt.hconcat()
    for action in action_names:
        row |= base.encode(
            y=alt.Y(
                field=action,
                type="quantitative",
                scale=alt.Scale(domain=gradient_domain, zero=True),
                axis=alt.Axis(title=f"{action}, {adaptation_rate}"),
            ),
            # Explicit types: the panels carry no data of their own, so do not
            # rely on Altair inferring field types for bare column names.
            tooltip=["iteration:Q", "adaptation_rate:Q", f"{action}:Q"],
        ).transform_filter(
            # A field-equality predicate instead of a hand-formatted Vega
            # expression string; float() makes the value a plain JSON number
            # rather than a numpy scalar.
            alt.FieldEqualPredicate(field="adaptation_rate", equal=float(adaptation_rate))
        )
    chart &= row

# save the plot
chart.save("adaptation_rate.html")