In [1]:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm

%matplotlib qt
plt.close('all')

In [2]:
# Function to plot the landscape of the cost function

def plotCostSpace(cost_function, param_inteval=[0,1], param_name="π", ax=None, cmap=cm.coolwarm):

    if ax is None:
        _, ax = plt.subplots(subplot_kw=dict(projection='3d'))

    p_range = np.linspace(param_inteval[0], param_inteval[1], 100)
    x, y = np.meshgrid(p_range, p_range)
    p = np.stack([x, y])
    c = cost_function(p)
    ax.plot_surface(x, y, c, cmap=cmap, linewidth=0, antialiased=False, alpha=0.5)

    plt.xlabel(param_name + "_1")
    plt.ylabel(param_name + "_2")
    ax.set_zlabel("Cost (" + param_name + ")")

    plt.show()
    plt.title("Cost space for parametrization w.r.t. " + param_name)
    return ax



In [3]:
# Function to plot the landscape of the cost function

def plotCostSpace(cost_function, param_inteval=[0,1], param_name="π", ax=None, cmap=cm.coolwarm):

    if ax is None:
        _, ax = plt.subplots(subplot_kw=dict(projection='3d'))

    p_range = np.linspace(param_inteval[0], param_inteval[1], 100)
    x, y = np.meshgrid(p_range, p_range)
    p = np.stack([x, y])
    c = cost_function(p)
    ax.plot_surface(x, y, c, cmap=cmap, linewidth=0, antialiased=False, alpha=0.5)

    plt.xlabel(param_name + "_1")
    plt.ylabel(param_name + "_2")
    ax.set_zlabel("Cost (" + param_name + ")")

    plt.show()
    plt.title("Cost space for parametrization w.r.t. " + param_name)
    return ax

In [4]:
# Function to plot the landscape of the cost function

def plot3DCostSpace(cost_function, param_inteval=[0,1], param_name="π", cmap=cm.coolwarm):
    
    p_range = np.linspace(param_inteval[0], param_inteval[1], 100)
    x, y = np.meshgrid(p_range, p_range)
    z = np.ones(x.shape) * np.median(p_range)

    c1 = cost_function(np.stack([x, y, z]))
    c2 = cost_function(np.stack([x, z, y]))
    c3 = cost_function(np.stack([z, x, y]))

    A0_eq = p_range
    A1_eq = np.log(1 - np.exp(A0_eq))
    A2_eq = np.ones(A0_eq.shape) * np.median(p_range)
    C_eq = cost_function(np.stack([A0_eq, A1_eq, A2_eq]))

    i1_min = np.argmin(c1)
    i2_min = np.argmin(c2)
    i3_min = np.argmin(c3)

    ieq_min = np.argmin(C_eq)

    fig = plt.figure(figsize=plt.figaspect(0.5))
   
    ax = fig.add_subplot(1, 3, 1, projection='3d')
    ax.plot_surface(x, y, c1, cmap=cmap, linewidth=0, antialiased=False, alpha=0.5)
    ax.plot(A0_eq, A1_eq, zs=C_eq, zdir='z', color="black", linewidth=3, label='policy submanifold')
    ax.scatter(x.flatten()[i1_min], y.flatten()[i1_min], zs=c1.flatten()[i1_min], zdir='z', c="purple", s=100, marker="*", label='global extremum')
    ax.scatter(A0_eq.flatten()[ieq_min], A1_eq.flatten()[ieq_min], zs=C_eq.flatten()[ieq_min], zdir='z', c="green", s=100, marker="*", label='policy submanifold global extremum')
    plt.xlabel("A_0")
    plt.ylabel("A_1")
    ax.set_zlabel("LogC(A)")
    plt.title("log C(A | A2 = constant) ")
    plt.legend()

    ax = fig.add_subplot(1, 3, 2, projection='3d')
    ax.plot_surface(x, y, c2, cmap=cmap, linewidth=0, antialiased=False, alpha=0.5)
    ax.scatter(x.flatten()[i2_min], y.flatten()[i2_min], zs=c2.flatten()[i2_min], zdir='z', c="purple", s=100, marker="*", label='global extremum')
    plt.xlabel("A_0")
    plt.ylabel("A_2")
    ax.set_zlabel("LogC(A)")
    plt.title("log C(A | A1 = constant) ")

    ax = fig.add_subplot(1, 3, 3, projection='3d')
    ax.plot_surface(x, y, c3, cmap=cmap, linewidth=0, antialiased=False, alpha=0.5)
    ax.scatter(x.flatten()[i3_min], y.flatten()[i3_min], zs=c3.flatten()[i3_min], zdir='z', c="purple", s=100, marker="*", label='global extremum')
    plt.xlabel("A_1")
    plt.ylabel("A_2")
    ax.set_zlabel("LogC(A)")
    plt.title("log C(A | A0 = constant) ")

    plt.show()
    return ax

In [5]:
# Function to test convexity on two points

def testConvexity(pa, pb, cost_func, param_name="π", ax = None, color="black"):

    if ax is None:
        ax = plt.gca()

    pm = (pa + pb) / 2

    Ca  = cost_func(pa)
    Cb  = cost_func(pb)
    Cm  = cost_func(pm)

    ax.plot([pa[0], pb[0]], [pa[1], pb[1]],zs=[Ca ,Cb], color="black")
    ax.scatter([pa[0]], [pa[1]], s=100, zs=[Ca], color="black")
    ax.scatter([pb[0]], [pb[1]], s=100, zs=[Cb], color="black")

    ax.text(pa[0], pa[1], s=param_name+"a", z=Ca+0.1, color=color, fontsize="x-large")
    ax.text(pb[0], pb[1], s=param_name+"b", z=Cb+0.1, color=color, fontsize="x-large")

    ax.plot([pm[0], pm[0]], [pm[1], pm[1]],zs=[(Ca+Cb)/2 ,Cm], color=color, linestyle="--")
    ax.scatter([pm[0]], [pm[1]], s=100, zs=[Cm], color=color)
    ax.text(pm[0], pm[1], s=param_name+"m", z=Cm+0.1, color=color, fontsize="x-large")

    convexity =  Cm <= (Ca+Cb)/2
    if convexity:
        txt = "the cost space for parametrization π is not concave (it could be convex)"
    else:
        txt = "the cost space for parametrization π is not convex (it could be concave)"
    ax.text(.0, .0, 0.5, txt, ha='center')

    return ax




Let's consider the simple MDP shown in Agarwal et al.(2022), page 120.  
For direct parametrization, the cost function is: Cπ = -Vπ = - π1 * π2 * r

In [6]:
# Assuming r=1:
Vπ_func = lambda π : np.prod(π, axis=0)
Lπ_func = lambda π : np.log(Vπ_func(π))
Cπ_func = lambda π : 1-Vπ_func(π)
LCπ_func = lambda π : np.log(Cπ_func(π))

Consider the two policies πa = <0, 0> and πb = <1, 1>, and the middle-point policy πm=<.5, .5>.
If the cost space is convex, then C_πa + C_πb >= C_πm. 
In this case we have C_πa = 0, C_πb = 1, C_πm = -0.25  (no convexity)

In [7]:
πa = np.array([0.1, 0.1])
πb = np.array([0.9, 0.9])

fig = plt.figure(figsize=plt.figaspect(0.5))
ax = fig.add_subplot(1, 2, 1, projection='3d')

plotCostSpace(Vπ_func, param_inteval=[0, 1], param_name="π", ax=ax)
testConvexity(πa, πb, Vπ_func)
ax.set_zlabel("V(π)")
plt.title("V(π)")

ax = fig.add_subplot(1, 2, 2, projection='3d')
plotCostSpace(Lπ_func, param_inteval=[0, 1], param_name="π", ax=ax)
testConvexity(πa, πb, Lπ_func)
ax.set_zlabel("Log V(π)")
plt.title("Log V(π)")

  Lπ_func = lambda π : np.log(Vπ_func(π))


Text(0.5, 0.92, 'Log V(π)')

In [8]:
πa = np.array([0.1, 0.1])
πb = np.array([0.9, 0.9])

fig = plt.figure(figsize=plt.figaspect(0.5))
ax = fig.add_subplot(1, 2, 1, projection='3d')

plotCostSpace(Cπ_func, param_inteval=[0, 1], param_name="π", ax=ax)
testConvexity(πa, πb, Cπ_func)
ax.set_zlabel("C(π)")
plt.title("C(π)")

ax = fig.add_subplot(1, 2, 2, projection='3d')
plotCostSpace(LCπ_func, param_inteval=[0, 1], param_name="π", ax=ax)
testConvexity(πa, πb, LCπ_func)
ax.set_zlabel("Log C(π)")
plt.title("Log C(π)")

  txs, tys, tzs = vecw[0]/w, vecw[1]/w, vecw[2]/w
  LCπ_func = lambda π : np.log(Cπ_func(π))


Text(0.5, 0.92, 'Log C(π)')

Now, let's consider the softmax parametrization π_ij = e^A_ij 

In [9]:
CA_func = lambda A : -np.prod(np.exp(A), axis=0)  # for r=1

Aa = np.array([-3, -2])
Ab = np.array([-2, 0])

plotCostSpace(CA_func, param_inteval=[-5, 0], param_name="A")
testConvexity(Aa, Ab, CA_func,param_name="A")
plt.title("C(A) with π = exp(A)")

Text(0.5, 0.92, 'C(A) with π = exp(A)')

Let's try again, but first let's modify the MDP so that all costs greater or equal to zero (i.e. no positive rewards).
So the cost function now is CcA = (1 - π1 * π2) * c

In [10]:
gpp_cost_func = lambda A : 1 - np.exp(A[1]) + np.exp(A[0]) * np.exp(A[1])
gpp_log_cost_func = lambda A : np.log(gpp_cost_func(A))

In [11]:
GPPa = np.array([-1, -0.])
GPPb = np.array([-1, -2])

plotCostSpace(gpp_cost_func, param_inteval=[-2, 0], param_name="A")
testConvexity(GPPa, GPPb, gpp_cost_func,param_name="A")
plt.title("C_gpp(A) with π = exp(A) and reparametrized action")

Text(0.5, 0.92, 'C_gpp(A) with π = exp(A) and reparametrized action')

Another attempt with the log-cost function:

In [12]:
plotCostSpace(gpp_log_cost_func, param_inteval=[-2, 0], param_name="A")
testConvexity(GPPa, GPPb, gpp_log_cost_func, param_name="A")
plt.title("log C_gpp(A) with π = exp(A) and reparametrized action")

Text(0.5, 0.92, 'log C_gpp(A) with π = exp(A) and reparametrized action')

Now let's consider an unconstrained parametrization of π = exp(A) with no dependent actions.
Then we have C = c * ( p_1u + p_1r * p_2u )  

In [13]:
unconst_loss = lambda A : np.log(np.exp(A[0]) + np.exp(A[1])*np.exp(A[2]))

In [14]:
plot3DCostSpace(unconst_loss, param_inteval=[-5, 0], param_name="A", cmap=cm.coolwarm)
plt.suptitle("LC(A) = log(π_1u + π_1r * π_2u) = log(exp(A0) + exp(A1) * exp(A2))")


  A1_eq = np.log(1 - np.exp(A0_eq))


Text(0.5, 0.98, 'LC(A) = log(π_1u + π_1r * π_2u) = log(exp(A0) + exp(A1) * exp(A2))')

In [15]:
import numpy as np
from matplotlib import pyplot as plt

from agents.pgp.pgp_softmax import SoftMaxPGP
from environments.gridworlds.gridworlds_classic import Agarwal
from plots.gridworlds.gridworld_visualizer import GridWorldVisualizer

maze = Agarwal()
agent = SoftMaxPGP(maze)
viz = GridWorldVisualizer(maze, agent)

for i in range(100):
    p1_up = np.random.random()
    p1_right = 1 -p1_up

    p2_up = np.random.random()
    p2_right = 1 -p2_up
    p2_left = 0

    agent.reset()
    agent.theta[1, 0] = np.log(p1_up)
    agent.theta[1, 1] = np.log(p1_right)
    agent.theta[3, 0] = np.log(p2_up)
    agent.theta[3, 1] = np.log(p2_right)
    agent.theta[3, 3] = np.log(p2_left)

    plt.subplot(1,2,1)
    viz.plot_policy(plot_axis=False)
    plt.title("Initial Policy")

    plt.subplot(1,2,2)
    agent.learn(n_steps=100, alpha=0.1)
    viz.plot_policy(plot_axis=False)
    plt.title("Final Policy")
    plt.show()



  return np.log(agent.A / np.sum(agent.A, axis=1, keepdims=True))
  agent.theta[3, 3] = np.log(p2_left)
  plt.subplot(1,2,1)
  plt.subplot(1,2,2)
100%|██████████| 100/100 [00:00<00:00, 1120.11it/s]
100%|██████████| 100/100 [00:00<00:00, 1104.57it/s]
100%|██████████| 100/100 [00:00<00:00, 1108.28it/s]
100%|██████████| 100/100 [00:00<00:00, 813.09it/s]
100%|██████████| 100/100 [00:00<00:00, 1093.13it/s]
100%|██████████| 100/100 [00:00<00:00, 965.63it/s]
100%|██████████| 100/100 [00:00<00:00, 1087.71it/s]
100%|██████████| 100/100 [00:00<00:00, 990.28it/s]
100%|██████████| 100/100 [00:00<00:00, 1106.96it/s]
100%|██████████| 100/100 [00:00<00:00, 831.87it/s]
100%|██████████| 100/100 [00:00<00:00, 1091.69it/s]
100%|██████████| 100/100 [00:00<00:00, 1083.21it/s]
100%|██████████| 100/100 [00:00<00:00, 1093.39it/s]
100%|██████████| 100/100 [00:00<00:00, 1116.74it/s]
100%|██████████| 100/100 [00:00<00:00, 960.84it/s]
100%|██████████| 100/100 [00:00<00:00, 814.07it/s]
100%|██████████| 100/100 [00