In [1]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib.patches import Rectangle, Circle
from IPython.display import HTML
import matplotlib as mpl
import warnings

warnings.filterwarnings('ignore')

In [2]:
# Initial configuration of the environment and the figure
# Raise the size limit (in MB) for animations embedded as HTML so the
# JS player produced later is not cut short.
mpl.rcParams['animation.embed_limit'] = 50.0
env = gym.make('CartPole-v1')
# `update` unpacks five values from `env.step`, i.e. the Gym >= 0.26 API.
# In that API `reset` returns an `(observation, info)` pair, so unpack it
# instead of storing the whole tuple in `observation`.
observation, _reset_info = env.reset()

fig, ax = plt.subplots(figsize=(10, 8))
ax.set_aspect('equal', adjustable='box')
ax.set_xlim(-3, 3)
ax.set_ylim(-2, 2)

# Colors
cart_color_base = 'skyblue'
cart_color_highlight = 'lightskyblue'
wheel_color = 'dimgray'
wheel_rim_color = 'black'
pendulum_color_base = 'forestgreen'
pendulum_color_highlight = 'lightgreen'
mass_color = 'gold'
joint_color = 'black'
cube_color = 'gray'
ground_color = 'lightgray'

# Close the pyplot handle so the empty figure is not rendered under this
# cell; the `fig`/`ax` objects are still used by the cells below.
plt.close()

In [3]:
# CartPole visual elements: cart, wheels, joint, pendulum rod and bob.
# All geometry is laid out around the pivot at (0, -1).
cart_width = 0.6
cart_height = 0.3
wheel_radius = 0.08
pendulum_length = 1.5
mass_radius = 0.1

# Cart body plus a slightly inset lighter rectangle drawn on top of it.
cart_body = Rectangle(
    (-cart_width / 2, -1 + cart_height / 2),
    cart_width,
    cart_height,
    color=cart_color_base,
    zorder=10,
)
cart_highlight = Rectangle(
    (-cart_width / 2 + 0.05, -1 + cart_height / 2 + 0.05),
    cart_width - 0.1,
    cart_height - 0.1,
    color=cart_color_highlight,
    zorder=10,
)

# Each wheel is two concentric circles: the wheel disc and a smaller rim.
left_wheel_center = (-cart_width / 2 + wheel_radius, -1 - wheel_radius/2)
right_wheel_center = (cart_width / 2 - wheel_radius, -1 - wheel_radius/2)
wheel_left = Circle(left_wheel_center, wheel_radius, color=wheel_color, zorder=11)
wheel_right = Circle(right_wheel_center, wheel_radius, color=wheel_color, zorder=11)
wheel_rim_left = Circle(left_wheel_center, wheel_radius * 0.8, color=wheel_rim_color, zorder=12)
wheel_rim_right = Circle(right_wheel_center, wheel_radius * 0.8, color=wheel_rim_color, zorder=12)

# Pivot, rod (empty until the first frame is drawn) and bob.
joint = Circle((0, -1), 0.05, color=joint_color, zorder=13)
pendulum_line, = ax.plot([], [], lw=3, color=pendulum_color_base, zorder=11)
bob_rest_position = (0, -1 + pendulum_length)
mass = Circle(bob_rest_position, mass_radius, color=mass_color, zorder=12)
# Small offset circle that simulates a shine on the bob.
mass_highlight = Circle(
    (0.05, -1 + pendulum_length + 0.05),
    mass_radius * 0.6,
    color='lightyellow',
    zorder=12,
)

# Register the patches on the axes in a fixed order.
for cartpole_patch in (
    cart_body,
    cart_highlight,
    wheel_left,
    wheel_right,
    wheel_rim_left,
    wheel_rim_right,
    joint,
    mass,
    mass_highlight,
):
    ax.add_patch(cartpole_patch)

In [4]:
# Add the ground: a strip spanning the full x-range of the axes whose top
# edge sits at y = -1 - wheel_radius.
ground_height = 0.1
ground = Rectangle(
    xy=(-3, -1 - wheel_radius - ground_height),
    width=6,
    height=ground_height,
    color=ground_color,
    zorder=2,
)
ax.add_patch(ground);

In [5]:
# Parameters of the simulated wireframe cube drawn behind the cart
cube_size = 2
rotation_speed = 0.02
num_cube_lines = 12
# One initially empty line artist per cube edge; `update` fills in the data.
cube_lines = [
    ax.plot([], [], color=cube_color, linewidth=1, zorder=1)[0]
    for _ in range(num_cube_lines)
]

def init():
    """Put every artist back in its starting pose.

    Passed to ``FuncAnimation`` as ``init_func``. The cart is centred at
    x = 0, the bob sits straight above the pivot, and the rod and cube
    edges are emptied.

    Returns:
        list: All animated artists, in the same order `update` returns them.
    """
    cart_left = -cart_width / 2
    cart_bottom = -1 + cart_height / 2
    wheel_y = -1 - wheel_radius/2
    left_wheel_xy = (cart_left + wheel_radius, wheel_y)
    right_wheel_xy = ((cart_width / 2 - wheel_radius), wheel_y)

    # Cart body and its inset highlight
    cart_body.set_xy((cart_left, cart_bottom))
    cart_highlight.set_xy((cart_left + 0.05, cart_bottom + 0.05))

    # Wheels and rims share their centres
    wheel_left.center = left_wheel_xy
    wheel_rim_left.center = left_wheel_xy
    wheel_right.center = right_wheel_xy
    wheel_rim_right.center = right_wheel_xy

    # Pivot, rod and bob
    joint.center = (0, -1)
    pendulum_line.set_data([], [])
    mass.center = (0, -1 + pendulum_length)
    mass_highlight.center = (0.05, -1 + pendulum_length + 0.05)

    # Static ground and empty cube edges
    ground.set_xy((-3, -1 - wheel_radius - ground_height))
    for edge_line in cube_lines:
        edge_line.set_data([], [])

    artists = [
        cart_body, cart_highlight,
        wheel_left, wheel_right, wheel_rim_left, wheel_rim_right,
        joint, pendulum_line, mass, mass_highlight, ground,
    ]
    return artists + cube_lines

def update(frame):
    """Advance the environment one step and redraw every artist.

    A random action is sampled and applied. The cart, wheels, joint, rod
    and bob are moved to match the new observation, and the background
    wireframe cube is rotated according to the frame index. When the
    episode ends the environment is reset, so the next frame starts a
    fresh episode.

    Args:
        frame: Frame index supplied by ``FuncAnimation``; only used to
            derive the cube's rotation angle.

    Returns:
        list: All animated artists, in the same order `init` returns them.
    """
    global observation
    action = env.action_space.sample()
    # The five-value step API reports episode end through two flags; an
    # episode is over if it was either terminated or truncated.
    observation, _, terminated, truncated, _ = env.step(action)
    x, _, theta, _ = observation

    # Move the cart (body, highlight, wheels, rims, joint) to x
    wheel_y = -1 - wheel_radius/2
    left_wheel_xy = (x - cart_width / 2 + wheel_radius, wheel_y)
    right_wheel_xy = (x + cart_width / 2 - wheel_radius, wheel_y)
    cart_body.set_xy((x - cart_width / 2, -1 + cart_height / 2))
    cart_highlight.set_xy((x - cart_width / 2 + 0.05, -1 + cart_height / 2 + 0.05))
    wheel_left.center = left_wheel_xy
    wheel_right.center = right_wheel_xy
    wheel_rim_left.center = left_wheel_xy
    wheel_rim_right.center = right_wheel_xy
    joint.center = (x, -1)

    # Move the pendulum: theta is used as the angle from the upward vertical
    pendulum_x = x + np.sin(theta) * pendulum_length
    pendulum_y = -1 + np.cos(theta) * pendulum_length
    pendulum_line.set_data([x, pendulum_x], [-1, pendulum_y])
    mass.center = (pendulum_x, pendulum_y)
    mass_highlight.center = (pendulum_x + 0.05, pendulum_y + 0.05)  # follow the bob

    # The ground stays fixed for now
    ground.set_xy((-3, -1 - wheel_radius - ground_height))

    # Rotate the cube about its vertical axis and project it onto the
    # x-y plane; the rotated depth coordinate is not needed for that
    # projection, so it is not computed.
    angle = frame * rotation_speed
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    half = cube_size / 2
    vertices = [
        (-half, -half, -half),
        (half, -half, -half),
        (half, half, -half),
        (-half, half, -half),
        (-half, -half, half),
        (half, -half, half),
        (half, half, half),
        (-half, half, half),
    ]
    vertices_2d = [(vx * cos_a - vz * sin_a, vy) for vx, vy, vz in vertices]
    edges = [(0, 1), (1, 2), (2, 3), (3, 0), (4, 5), (5, 6), (6, 7), (7, 4), (0, 4), (1, 5), (2, 6), (3, 7)]
    for edge_line, (start, end) in zip(cube_lines, edges):
        x_data = [vertices_2d[start][0], vertices_2d[end][0]]
        # Shift the cube down by 1 so it is centred on the pivot height
        y_data = [vertices_2d[start][1] - 1, vertices_2d[end][1] - 1]
        edge_line.set_data(x_data, y_data)

    if terminated or truncated:
        # `reset` returns (observation, info) in this API; keep only the
        # observation so the global stays a plain state vector.
        observation, _ = env.reset()

    return [cart_body, cart_highlight, wheel_left, wheel_right, wheel_rim_left, wheel_rim_right, joint, pendulum_line, mass, mass_highlight, ground] + cube_lines

In [6]:
num_frames = 300
ani = animation.FuncAnimation(fig, update, frames=num_frames, init_func=init, blit=True, interval=33)

plt.close()
# Jupyter only renders the value of a cell's LAST expression, so the HTML
# player must come after `plt.close()`; otherwise it is built and then
# silently discarded.
HTML(ani.to_jshtml())

In [7]:
# Export the animation as a GIF at 30 frames per second using the Pillow writer.
# NOTE(review): saving presumably re-invokes `update` for every frame, so the
# exported episode may differ from any earlier rendering — confirm if that matters.
gif_path = 'cartpole_with_ground_rotating_cube.gif'
ani.save(gif_path, writer='pillow', fps=30)
plt.close()