In [None]:
# Reload edited modules (e.g. msdm) automatically before running code, so
# library changes are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import importlib
import itertools

import IPython.display as display
import matplotlib.pyplot as plt
import numpy as np

import msdm
from msdm.algorithms.correlatedq import TabularCorrelatedQLearner
from msdm.algorithms.friendfoeq import TabularFriendFoeQLearner
from msdm.algorithms.multiagentqlearning import TabularMultiagentQLearner
from msdm.algorithms.nashq import TabularNashQLearner
from msdm.core.assignment.assignmentmap import AssignmentMap
from msdm.core.problemclasses.stochasticgame.policy.tabularpolicy import SingleAgentPolicy
from msdm.domains.gridgame.policyviztools import positionMapping, positionActionMapping, weightMapping
from msdm.domains.gridgame.tabulargridgame import TabularGridGame
from msdm.domains.gridworld.mdp import GridWorld

In [None]:
# Two-player map for TabularGridGame. "A0"/"A1" are the agent symbols and "G"
# is registered below as a goal shared by both agents. Presumably "#" marks
# walls and "." open floor; the ".~" suffix on the agent cells is msdm layout
# syntax (TODO confirm its meaning).
two_player = """
# # # # #
# . G . # 
# . . . #
# A0.~ . A1.~ #
# # # # #
"""
gg = TabularGridGame(
    two_player,
    agent_symbols=("A0", "A1"),
    goal_symbols=(("G", ("A0", "A1")),),
    step_cost=-1,
    collision_cost=-5,
    goal_reward=100,
)
# Touch state_list in this cell; presumably it is built lazily on first access
# (TODO confirm), so this moves that cost here rather than into later cells.
gg.state_list
print("State list generated")

In [None]:
# Uniform-random policy for a single agent: in every state, each action that
# gg.joint_actions(state) lists for that agent gets probability
# 1 / (number of that agent's actions).
random_agent = "A1"
random_action_probs = AssignmentMap()
for state in gg.state_list:
    agent_actions = list(gg.joint_actions(state)[random_agent])
    random_action_probs[state] = AssignmentMap()
    if agent_actions:
        uniform_prob = 1.0 / len(agent_actions)
        for action in agent_actions:
            random_action_probs[state][action] = uniform_prob
# The probability table and the policy live under separate names, so
# `random_policy` only ever refers to the SingleAgentPolicy.
random_policy = SingleAgentPolicy(random_agent, gg, random_action_probs)

In [None]:
# Agents in the game, and which of them learn (here: all of them).
all_agents = ["A0", "A1"]
learning_agents = ["A0", "A1"]
# Friend-or-Foe roles derived from all_agents, so they stay in sync if the
# agent list changes: nobody has friends, and every other agent is a foe.
friends = {agent: [] for agent in all_agents}
foes = {agent: [other for other in all_agents if other != agent] for agent in all_agents}
# Fixed (non-learned) policies, presumably keyed by agent name; empty here.
# Alternative run against the uniform-random policy built for "A1" above:
#     other_policies = {"A1": random_policy}
# Key it by the agent that policy was built for. Presumably "A1" should then
# also be removed from learning_agents (TODO confirm).
other_policies = {}
# Hyperparameters passed (as **params) to every learner constructed below.
params = {
    "num_episodes": 2000,
    "epsilon": .01,
    "epsilon_decay": 1.0,
    "discount_rate": .99,
    "learning_rate": 0.01,
    "show_progress": True,
    "default_q_value": 1.0,
}

In [None]:
# Multi-agent Q-learning and Friend-or-Foe Q learners, all sharing `params`.
q_learner = TabularMultiagentQLearner(
    learning_agents, other_policies, all_actions=True, alg_name="Q-Learning", **params
)
ffq_learner = TabularFriendFoeQLearner(
    learning_agents, friends, foes, other_policies, alg_name="FFQ-Learning", **params
)
# Correlated-Q learners that differ only in `objective_func`; they are
# constructed in the order of this tuple (Libertarian first).
ceq_learners = {
    objective: TabularCorrelatedQLearner(
        learning_agents,
        other_policies,
        objective_func=objective,
        alg_name=f"{objective} CEQ",
        **params,
    )
    for objective in ("Libertarian", "Utilitarian", "Republican", "Egalitarian")
}
libertarian_q_learner = ceq_learners["Libertarian"]
utilitarian_q_learner = ceq_learners["Utilitarian"]
republican_q_learner = ceq_learners["Republican"]
egalitarian_q_learner = ceq_learners["Egalitarian"]
nash_q_learner = TabularNashQLearner(
    learning_agents, other_policies, alg_name="Nash-Q Learning", **params
)
# Order used for training and plotting; Libertarian CEQ is listed after the
# other CEQ variants.
algorithms = [
    q_learner,
    ffq_learner,
    utilitarian_q_learner,
    republican_q_learner,
    egalitarian_q_learner,
    libertarian_q_learner,
    nash_q_learner,
]

In [None]:
# Train every learner on gg in turn. After each run, keep the training result
# and one example rollout of its policy (`.pi.run_on` with maxSteps=10).
results, example_trajectories = [], []
for learner in algorithms:
    trained = learner.train_on(gg)
    results.append(trained)
    example_trajectories.append(trained.pi.run_on(gg, maxSteps=10))

In [None]:
# Value maps: one figure per algorithm, one panel per agent, each drawn via
# positionMapping from a single sampled initial state.
for k, alg in enumerate(algorithms):
    trained_policy = results[k].pi
    # Plot with a uniform COPY of the occupancy matrix. Calling .fill() on
    # trained_policy.occupancy_matrix itself would overwrite that attribute of
    # the trained policy in place, and later cells that run results[k].pi may
    # rely on it. (Presumably a numpy array, since it supports .fill().)
    occupancy_matrix = trained_policy.occupancy_matrix.copy()
    # Every entry becomes 1 / (number of columns).
    occupancy_matrix.fill(1.0 / len(occupancy_matrix[0]))
    # Sample once so all of this algorithm's panels share the same start.
    initial_state = gg.initial_state_dist().sample()

    # squeeze=False keeps `axes` 2-D even when there is only one agent.
    fig, axes = plt.subplots(1, len(all_agents), figsize=(20, 10), squeeze=False)
    fig.suptitle(alg.alg_name)
    for ax, agent_name in zip(axes[0], all_agents):
        plotter = gg.plot(ax=ax)
        plotter.title(agent_name + " Values")
        q_matrix = trained_policy.single_agent_policies[agent_name].q_matrix
        plotter.plot_state_map(
            positionMapping(trained_policy, agent_name, q_matrix, occupancy_matrix, initial_state)
        )

In [None]:
# Animate, for each algorithm, the example trajectory recorded right after
# training (example_trajectories), and show the animations as interactive HTML.
# Reusing the stored rollouts keeps this cell independent of any later changes
# to results[k].pi.
animations = []
animation_html = []
for alg, trajectory in zip(algorithms, example_trajectories):
    fig, ax = plt.subplots(1, 1, figsize=(20, 10))
    fig.suptitle(alg.alg_name)
    animator = gg.animate(figure=fig, ax=ax)
    animation = animator.animate_trajectory(trajectory)
    animations.append(animation)
    # Render to JS/HTML while the figure is still open, then close THIS figure.
    # One plt.close() after the loop closes only the last figure, so the
    # inline backend also dumps the others as static images when the cell ends.
    animation_html.append(display.HTML(animation.to_jshtml()))
    plt.close(fig)
display.display(*animation_html)