In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import pickle

In [45]:
# import pickle file from ./test_policy/test_trajectory.pkl
with open('./test_policy/test_fully-connected_10/test_trajectory.pkl', 'rb') as f:
    trajectories = pickle.load(f)

print(len(trajectories))

# trajectories is a list of trajectories. Where each trajectory is a list of:
# [timestep, agent1_x, agent1_y, agent2_x, agent2_y, agent3_x, agent3_y, landmark1_x, landmark1_y, landmark2_x, landmark2_y, landmark3_x, landmark3_y, reward_1, reward_2, reward_3]
# convert trajectories to pd dataframe with columns: episode, timestep, agent1_x, agent1_y, agent2_x, agent2_y, agent3_x, agent3_y, reward_1, reward_2, reward_3
# Flatten the nested list
flattened_data = [tup for trajectory in trajectories for tup in trajectory]

# Convert to DataFrame
df = pd.DataFrame(flattened_data, columns=[ 'episode', 'timestep',
                                            'agent1_x', 'agent1_y',
                                            'agent2_x', 'agent2_y', 
                                            'agent3_x', 'agent3_y', 
                                            'land1_x', 'land1_y',
                                            'land2_x', 'land2_y', 
                                            'land3_x', 'land3_y', 
                                            'reward_1', 'reward_2', 'reward_3'])

# print(df.head())

BLUE = [0, 0.4470, 0.7410]
RED = [0.8500, 0.3250, 0.0980]
YELLOW = [0.929, 0.6940, 0.1250]
# plot the x,y trajectory of agents in episode 1 sorted by timestep using seaborn
episode = 20
sns.lineplot(x="agent1_x", y="agent1_y", data=df[df['episode']==episode], sort=False, color=BLUE)
sns.lineplot(x="agent2_x", y="agent2_y", data=df[df['episode']==episode], sort=False, color=RED)
sns.lineplot(x="agent3_x", y="agent3_y", data=df[df['episode']==episode], sort=False, color=YELLOW)

# mark the location of landmarks
sns.scatterplot(x="land1_x", y="land1_y", data=df[df['episode']==episode], color='gray', s=300)
sns.scatterplot(x="land2_x", y="land2_y", data=df[df['episode']==episode], color='gray', s=300)
sns.scatterplot(x="land3_x", y="land3_y", data=df[df['episode']==episode], color='gray', s=300)


# make axis equal and limit the axis to -1.0 to 1.0
plt.axis('square')
plt.xlim(-1,1)
plt.ylim(-1,1)
plt.xlabel('x')
plt.ylabel('y')
plt.show()



1000


In [46]:
# plot the reward of agents in episode 1 sorted by timestep using seaborn
sns.lineplot(x="timestep", y="reward_1", data=df[df['episode']==episode], sort=False, color=BLUE)
sns.lineplot(x="timestep", y="reward_2", data=df[df['episode']==episode], sort=False, color=RED)
sns.lineplot(x="timestep", y="reward_3", data=df[df['episode']==episode], sort=False, color=YELLOW)

<Axes: xlabel='timestep', ylabel='reward_1'>

In [44]:

%matplotlib qt
import pygame
import time

pygame.init()

# Display settings
WIDTH, HEIGHT = 640, 480
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Multi-Agent and Landmarks Animation")

# Colors
BLUE = [0, 0.4470*255, 0.7410*255]
RED = [0.8500*255, 0.3250*255, 0.0980*255]
YELLOW = [0.929*255, 0.6940*255, 0.1250*255]

WHITE = (255, 255, 255)
AGENTS_COLORS = [BLUE, RED, YELLOW]
LANDMARK_COLORS = [(50, 50, 50), (50, 50, 50), (50, 50, 50)]
BLACK = (0, 0, 0)
# Sample trajectory data
trajectory = trajectories[4]

# drop the first element of the list for each list item in trajectory
trajectory = [t[1:] for t in trajectory]

# shift and scale the data (-1,1) to fit the screen size (0,500) and (0,500)

SCALE_FACTOR = 240  # given our dimensions and trajectory range
SCREEN_CENTER = (WIDTH // 2, HEIGHT // 2)

# Choose a font (using a default system font here)
font = pygame.font.SysFont("arial", 16)

def map_to_screen(pos):
    """Map a trajectory position to a screen position."""
    return int(pos[0] * SCALE_FACTOR + SCREEN_CENTER[0]), int(pos[1] * SCALE_FACTOR + SCREEN_CENTER[1])


def draw_entity(screen, x, y, color, size=.2*240):
    pygame.draw.circle(screen, color, (int(x), int(y)), size)



def display_text(text, x, y, color=BLACK):
    """Render and display text on the screen at specified coordinates."""
    text_surface = font.render(text, True, color)
    screen.blit(text_surface, (x, y))


def main():
    clock = pygame.time.Clock()
    running = True
    current_time = 0
    last_time = time.time()

    while running:
        screen.fill(WHITE)

        # Calculate the current positions of the agents and landmarks
        for i in range(len(trajectory) - 1):
            t0, *data0 = trajectory[i]
            t1, *data1 = trajectory[i + 1]

            if t0 <= current_time < t1:
                alpha = (current_time - t0) / (t1 - t0)

                # Drawing agents
                for j in range(3):  
                    x0, y0, x1, y1 = data0[j * 2], data0[j * 2 + 1], data1[j * 2], data1[j * 2 + 1]
                    r0, r1 = data0[12 + j], data1[12 + j]
                    x = x0 * (1 - alpha) + x1 * alpha
                    y = y0 * (1 - alpha) + y1 * alpha
                    r = r0 * (1 - alpha) + r1 * alpha
                    screen_x, screen_y = map_to_screen((x, y))
                    draw_entity(screen, screen_x, screen_y, AGENTS_COLORS[j])
                    display_text(f"{r:.3f}", screen_x-10, screen_y-60)

                # Drawing landmarks
                for j in range(3):
                    x0, y0, x1, y1 = data0[6 + j * 2], data0[7 + j * 2], data1[6 + j * 2], data1[7 + j * 2]
                    x = x0 * (1 - alpha) + x1 * alpha
                    y = y0 * (1 - alpha) + y1 * alpha
                    screen_x, screen_y = map_to_screen((x, y))
                    draw_entity(screen, screen_x, screen_y, LANDMARK_COLORS[j], size=.05*240)

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False

        pygame.display.flip()
        clock.tick(60)

        # Update current time
        current_time += .01

        # Exit loop when trajectory ends
        if current_time > trajectory[-1][0]:
            running = False

    pygame.quit()

if __name__ == "__main__":
    main()


KeyboardInterrupt: 