# In [None]:
import sys
sys.path.append('..')

from src.environments.grid_env import SmartGridEnv
import matplotlib.pyplot as plt
import numpy as np

# Instantiate the grid environment with five agents.
env = SmartGridEnv(n_agents=5)

# Smoke-test reset(): it is expected to hand back an (observations, info)
# pair; the observations container is indexed by agent id below.
obs, info = env.reset()
first_obs = obs[0]
print("Observation shape:", first_obs.shape)
print("Number of agents:", len(obs))

# Drive the environment with sampled actions for at most 24 steps and keep a
# per-agent history of the reward and the battery level seen at each step.
episode_rewards = {}
battery_levels = {}
for i in range(env.n_agents):
    episode_rewards[i] = []
    battery_levels[i] = []

for step in range(24):
    # One action per agent, each drawn from env.action_space.sample().
    actions = {}
    for i in range(env.n_agents):
        actions[i] = env.action_space.sample()

    obs, rewards, dones, truncated, info = env.step(actions)

    # Record this step's reward and battery level for every agent.
    for i in range(env.n_agents):
        episode_rewards[i].append(rewards[i])
        battery_levels[i].append(info['battery_levels'][i])

    # Stop as soon as the environment reports the whole episode as done.
    # NOTE(review): `truncated` is not consulted here -- confirm whether the
    # environment can end an episode through truncation alone.
    if dones['__all__']:
        break

# Plot results: one line per agent for rewards (top axis) and battery
# levels (bottom axis), then save the figure to disk and display it.
from pathlib import Path

fig, axes = plt.subplots(2, 1, figsize=(12, 8))

# Each panel is described as (axis, per-agent series, title, y-axis label)
# so both axes are drawn by the same code path.
panels = (
    (axes[0], episode_rewards, 'Rewards per Step', 'Reward'),
    (axes[1], battery_levels, 'Battery Levels', 'Battery Level (normalized)'),
)
for ax, series, title, ylabel in panels:
    for i in range(env.n_agents):
        ax.plot(series[i], label=f'Agent {i}')
    ax.set_title(title)
    ax.set_xlabel('Step')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)

plt.tight_layout()

# savefig() does not create missing parent directories, so make sure the
# output folder exists; otherwise a fresh checkout fails here.
figure_path = Path('../results/figures/environment_test.png')
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path)
plt.show()

# Summarise the episode: sum each agent's recorded step rewards and print
# the totals to two decimal places.
print("\n=== Episode Statistics ===")
agent_totals = {i: sum(episode_rewards[i]) for i in range(env.n_agents)}
for i, total_reward in agent_totals.items():
    print(f"Agent {i} - Total Reward: {total_reward:.2f}")