In [85]:
import gymnasium as gym
import torch
import time
import numpy as np
import torch.nn.functional as F
from modules import GenerativeModel, ReplayBuffer, DiscreteSACAgent, EpisodicRewardWrapper
import copy

In [86]:
# ---------------------------------------------------------------------------
# Runtime
# ---------------------------------------------------------------------------
# Use CUDA when torch reports it as available, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------------------------------------------------------------------------
# Optimisation
# ---------------------------------------------------------------------------
# Learning rate for actor and critic.
LR_POLICY = 3e-4
# Learning rate for the generative model.
LR_GENERATIVE = 3e-4
# Discount factor.
GAMMA = 0.99
# Soft update coefficient for target networks.
TAU = 0.005
# SAC temperature parameter (entropy regularization).
ALPHA = 0.2

# ---------------------------------------------------------------------------
# Network and replay sizes
# ---------------------------------------------------------------------------
# Hidden dimension for neural networks.
HIDDEN_DIM = 256
# Batch size for training.
BATCH_SIZE = 256
# Size of the replay buffer.
REPLAY_BUFFER_SIZE = 50_000

# ---------------------------------------------------------------------------
# Episode schedule
# ---------------------------------------------------------------------------
# Total number of episodes to run.
MAX_EPISODES = 500
# Max steps per episode for CartPole-v1.
MAX_STEPS_PER_EPISODE = 500
# Number of episodes to collect data before training starts.
START_TRAINING_EPISODES = 10

# ---------------------------------------------------------------------------
# GRD generative model loss (L_reg)
# ---------------------------------------------------------------------------
# Per-edge-type weights controlling the sparsity of the learned causal graph.
# Increased to encourage sparsity.
LAMBDA_S_R = 5e-4  # state -> reward
LAMBDA_A_R = 1e-5  # action -> reward
LAMBDA_S_S = 5e-5  # state -> state
LAMBDA_A_S = 1e-8  # action -> state

In [87]:
# Build the CartPole environment with on-screen ('human') rendering.
# EpisodicRewardWrapper is already imported in the notebook's import cell, so
# it is not re-imported here; the wrapper itself is currently disabled.
env = gym.make("CartPole-v1", render_mode='human')
# env = EpisodicRewardWrapper(env)

# Observation vector length and number of discrete actions, read from the env.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# The agent is sized for a compact state as wide as the full state.
compact_state_dim = state_dim

# Components whose trained weights are restored further down the notebook.
# Arguments are positional; keep the order in sync with the constructors in
# `modules`.
replay_buffer = ReplayBuffer(REPLAY_BUFFER_SIZE)
generative_model = GenerativeModel(
    state_dim, action_dim, HIDDEN_DIM, DEVICE, LR_GENERATIVE, GAMMA,
    LAMBDA_S_S, LAMBDA_S_R, LAMBDA_A_S, LAMBDA_A_R,
)
sac_agent = DiscreteSACAgent(
    state_dim, action_dim, compact_state_dim, HIDDEN_DIM, DEVICE,
    LR_POLICY, TAU, GAMMA, ALPHA,
)

In [88]:
def models_equal(model1, model2, rtol=1e-05, atol=1e-08):
    """Return True when two models hold numerically matching parameters.

    Parameters are compared pairwise in the order yielded by ``parameters()``.

    Args:
        model1: First model; must expose ``parameters()``.
        model2: Second model; must expose ``parameters()``.
        rtol: Relative tolerance forwarded to ``torch.allclose``.
        atol: Absolute tolerance forwarded to ``torch.allclose``.

    Returns:
        bool: False if the models have a different number of parameters, if
        any pair differs in shape, or if any pair is not element-wise close
        within the given tolerances; True otherwise.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())

    # zip() stops at the shorter sequence, so without this check a model
    # carrying extra parameters would compare equal to its prefix.
    if len(params1) != len(params2):
        return False

    # The shape test comes first because torch.allclose broadcasts its
    # inputs (and raises on non-broadcastable shapes); tensors of different
    # shape are never considered equal here.
    return all(
        p1.shape == p2.shape and torch.allclose(p1, p2, rtol=rtol, atol=atol)
        for p1, p2 in zip(params1, params2)
    )

In [89]:
# Restore both components from the same 'weights' location onto DEVICE,
# generative model first and then the SAC agent.
for component in (generative_model, sac_agent):
    component.load('weights', DEVICE)

In [90]:
# Inspect the loaded causal structure: print the state->reward edge
# probabilities and the compact-state mask derived from the causal masks.
causal_module = generative_model.causal_module

with torch.no_grad():
    C_s_s, _, C_s_r, C_a_r = causal_module.get_causal_masks(training=False)
    compact_mask = causal_module.get_compact_representation_mask(C_s_s, C_s_r)

    # Softmax over the last axis of each logit table, then take column 1.
    # NOTE(review): presumably column 1 is the "edge present" class - confirm.
    s_r_probs = F.softmax(causal_module.s_to_r_logits, dim=-1)[:, 1].cpu().numpy()
    # Computed alongside s_r_probs but not printed in this cell.
    a_r_probs = F.softmax(causal_module.a_to_r_logits, dim=-1)[:, 1].cpu().numpy()

    s_r_formatted = [format(p, '.2f') for p in s_r_probs]
    print(f"  Causal Probs (S->R): {s_r_formatted}")
    print(f"  Compact Mask: {compact_mask.cpu().numpy()}")

  Causal Probs (S->R): ['1.00', '1.00', '1.00', '1.00']
  Compact Mask: [1. 1. 1. 1.]


In [None]:

def to_state_tensor(observation):
    """Convert an environment observation to a float32 tensor on DEVICE.

    A leading dimension of size 1 is added with ``unsqueeze(0)``.
    """
    return torch.tensor(observation, dtype=torch.float32, device=DEVICE).unsqueeze(0)


# Roll out a single episode with the loaded agent and report its total reward.

# No training happens in this cell, so the compact mask is computed once up
# front instead of on every environment step.
# NOTE(review): assumes get_causal_masks(training=False) is deterministic for
# fixed weights - confirm against the causal module.
with torch.no_grad():
    C_s_s, _, C_s_r, _ = generative_model.causal_module.get_causal_masks(training=False)
    compact_mask = generative_model.causal_module.get_compact_representation_mask(C_s_s, C_s_r)

state, _ = env.reset()
state = to_state_tensor(state)

episode_reward = 0

for t in range(MAX_STEPS_PER_EPISODE):
    # Element-wise masking of the state before it is handed to the policy.
    with torch.no_grad():
        compact_state = state * compact_mask

    action = sac_agent.select_action(compact_state)
    next_state_np, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated

    episode_reward += reward
    state = to_state_tensor(next_state_np)

    # Stop on either termination or truncation.
    if done:
        print("done")
        break

print(f"Reward: {episode_reward}")

    

1
1
0
1
0
0
1
0
1
0
0
1
0
1
0
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
0
1
1
0
0
1
0
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
1
0
0
1
1
0
0
1
0
1
0
1
1
0
1
0
0
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
0
1
0
1
1
0
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
1
0
0
1
0
1
1
0
0
1
1
0
0
1
0
1
1
0
0
1
1
0
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
1
0
0
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
1
0
0
1
0
1
1
0
1
0
0
1
0
1
1
0
0
1
1
0
1
0
0
1
0
1
1
0
1
0
0
1
0
1
1
0
1
0
0
1
0
1
1
0
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
1
0
0
1
1
0
0
1
1
0
0
1
0
1
1
0
1
0
0
1
1
0
0
1
1
0
0
1
1
0
1
0
0
1
1
0
0
1
1
0
0
1
1
0
1
0


KeyboardInterrupt: 