In [None]:
import functools
import time

import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def measure_time(func):
    """Decorate ``func`` so that every call prints how long it took.

    The wrapper forwards all positional and keyword arguments, prints
    ``"<name> took <seconds> seconds"`` once the call has returned, and hands
    back the wrapped function's return value unchanged. Nothing is printed
    when ``func`` raises; the exception simply propagates to the caller.

    Args:
        func: The callable to time.

    Returns:
        A wrapper that accepts the same arguments as ``func``.
    """

    @functools.wraps(func)  # preserve __name__, __doc__, etc. of ``func``
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be skewed by system clock adjustments.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} took {elapsed:.2f} seconds")
        return result

    return wrapper

class SharedMLP(nn.Module):
    """Two fully connected ReLU layers whose weights can be exported/imported.

    Both linear layers are created on the module-level ``device`` and inputs
    are moved to that device in ``forward``. ``get_parameters`` and
    ``set_parameters`` let callers snapshot and overwrite the weights, which
    is how the parameters are averaged across agents elsewhere in this file.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        """Build ``input_dim -> hidden_dim -> hidden_dim`` linear layers."""
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim).to(device)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim).to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ReLU-activated output of the two linear layers."""
        x = x.to(device)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

    def get_parameters(self) -> list:
        """Return detached copies of all parameters, in ``parameters()`` order.

        The returned tensors share no storage with the module, so callers may
        modify them freely without affecting the network.
        """
        # detach().clone() replaces the discouraged ``.data`` access: the copy
        # is cut from the autograd graph without bypassing its bookkeeping.
        return [p.detach().clone() for p in self.parameters()]

    def set_parameters(self, new_params) -> None:
        """Overwrite every parameter in place with the matching new tensor.

        Args:
            new_params: Iterable of tensors, one per parameter, in the same
                order as ``parameters()`` / ``get_parameters()``.

        Raises:
            ValueError: If the number of tensors differs from the number of
                parameters (a plain ``zip`` would silently skip the rest).
        """
        targets = list(self.parameters())
        new_params = list(new_params)
        if len(new_params) != len(targets):
            raise ValueError(
                f"expected {len(targets)} parameter tensors, got {len(new_params)}"
            )
        # In-place writes to leaf tensors that require grad are only allowed
        # while gradient tracking is disabled.
        with torch.no_grad():
            for p, new_p in zip(targets, new_params):
                p.copy_(new_p)

class LocalNetwork(nn.Module):
    """Per-agent network: a ``SharedMLP`` trunk feeding two linear heads.

    ``actor`` maps the trunk features to one score per action and ``critic``
    maps them to a single value. All sub-modules are placed on the
    module-level ``device``.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 64):
        """Create the trunk, then the actor head, then the critic head."""
        super().__init__()
        self.mlp = SharedMLP(state_dim, hidden_dim).to(device)
        self.actor = nn.Linear(hidden_dim, action_dim).to(device)
        self.critic = nn.Linear(hidden_dim, 1).to(device)

    def pi(self, x: torch.Tensor) -> torch.Tensor:
        """Return action probabilities (softmax over the last dimension)."""
        features = self.mlp(x.to(device))
        scores = self.actor(features)
        return F.softmax(scores, dim=-1)

    def v(self, x: torch.Tensor) -> torch.Tensor:
        """Return the critic head's output for ``x``."""
        features = self.mlp(x.to(device))
        return self.critic(features)

    def sample_action(self, x: torch.Tensor):
        """Sample an action from ``pi(x)``.

        Returns:
            ``(action, log_prob)`` where ``action`` is a Python number obtained
            via ``.item()`` (so the sample must hold exactly one element) and
            ``log_prob`` is the tensor log-probability of that action.
        """
        distribution = Categorical(self.pi(x))
        chosen = distribution.sample()
        return chosen.item(), distribution.log_prob(chosen)

def average_shared_mlp(agent_networks):
    """Replace every agent's shared-MLP weights with their element-wise mean.

    Each agent must expose ``agent.mlp.get_parameters()`` (a list of tensors)
    and ``agent.mlp.set_parameters(list_of_tensors)``. The mean is computed
    position by position across agents and then written back to all of them.

    Args:
        agent_networks: Iterable of agents to synchronise. An empty iterable
            is a no-op.

    Returns:
        None. The agents are updated in place.
    """
    agent_networks = list(agent_networks)
    if not agent_networks:
        return  # nothing to average

    num_agents = len(agent_networks)
    with torch.no_grad():
        # Seed the running mean with the first agent's contribution; the
        # tensors stay on whatever device the parameters already live on.
        averaged_params = [
            param / num_agents
            for param in agent_networks[0].mlp.get_parameters()
        ]

        for agent in agent_networks[1:]:
            for total, param in zip(averaged_params, agent.mlp.get_parameters()):
                total += param / num_agents  # in-place accumulate

        for agent in agent_networks:
            agent.mlp.set_parameters(averaged_params)

@measure_time
def train_single_agent(env, agent, optimizer, episodes=500, gamma=0.99):
    """Train one agent with an episodic policy-gradient (REINFORCE-style) loss.

    For each episode the agent acts until the environment reports that the
    episode is over, then one optimizer step is taken on
    ``sum_t(-log_prob_t * G_t)`` where ``G_t`` is the discounted return from
    step ``t``. Only the log-probabilities returned by
    ``agent.sample_action`` enter the loss.

    Args:
        env: Environment following the Gymnasium API
            (``reset() -> (obs, info)``,
            ``step(a) -> (obs, reward, terminated, truncated, info)``).
        agent: Object providing ``sample_action(state) -> (action, log_prob)``.
        optimizer: Optimizer over the agent's parameters.
        episodes: Number of episodes to run.
        gamma: Discount factor applied to future rewards.

    Returns:
        List with the scalar loss of every episode, in order.
    """
    losses = []

    for episode in range(episodes):
        state, _ = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        log_probs, rewards = [], []
        done = False

        while not done:
            action, log_prob = agent.sample_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            # An episode ends on termination OR truncation (e.g. a time
            # limit); stepping past either is invalid until reset().
            done = terminated or truncated

            log_probs.append(log_prob)
            rewards.append(reward)
            state = torch.tensor(next_state, dtype=torch.float32, device=device).unsqueeze(0)

        # Walk the trajectory backwards so each step's discounted return is
        # built from the return of the step after it.
        discounted_return = 0
        policy_loss = []
        for log_prob, reward in zip(reversed(log_probs), reversed(rewards)):
            discounted_return = reward + gamma * discounted_return
            policy_loss.append(-log_prob * discounted_return)

        optimizer.zero_grad()
        loss = torch.stack(policy_loss).sum()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        print(f"Episode {episode + 1}: Loss = {loss.item():.4f}")

    return losses

@measure_time
def train_federated_agents(env, agents, optimizers, episodes=500, sync_interval=10, gamma=0.99):
    """Train several agents independently and periodically average their trunks.

    In every episode each agent, in turn, plays one full episode on ``env``
    and takes one optimizer step on ``sum_t(-log_prob_t * G_t)`` (``G_t`` is
    the discounted return from step ``t``). Whenever
    ``episode % sync_interval == 0`` (including the very first episode) the
    agents' shared-MLP weights are averaged via ``average_shared_mlp``.

    Args:
        env: Environment following the Gymnasium API
            (``reset() -> (obs, info)``,
            ``step(a) -> (obs, reward, terminated, truncated, info)``).
            The same instance is used by all agents, one after another.
        agents: Sequence of agents providing
            ``sample_action(state) -> (action, log_prob)`` and an ``mlp``
            attribute accepted by ``average_shared_mlp``.
        optimizers: One optimizer per agent, in the same order as ``agents``.
        episodes: Number of federated episodes to run.
        sync_interval: Number of episodes between weight-averaging rounds.
        gamma: Discount factor applied to future rewards.

    Returns:
        List with one entry per episode: the mean of the agents' losses for
        that episode (the same value that is printed).
    """
    losses = []

    for episode in range(episodes):
        episode_losses = []

        for agent, optimizer in zip(agents, optimizers):
            state, _ = env.reset()
            state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
            log_probs, rewards = [], []
            done = False

            while not done:
                action, log_prob = agent.sample_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                # An episode ends on termination OR truncation (e.g. a time
                # limit); stepping past either is invalid until reset().
                done = terminated or truncated

                log_probs.append(log_prob)
                rewards.append(reward)
                state = torch.tensor(next_state, dtype=torch.float32, device=device).unsqueeze(0)

            # Walk the trajectory backwards so each step's discounted return
            # is built from the return of the step after it.
            discounted_return = 0
            policy_loss = []
            for log_prob, reward in zip(reversed(log_probs), reversed(rewards)):
                discounted_return = reward + gamma * discounted_return
                policy_loss.append(-log_prob * discounted_return)

            optimizer.zero_grad()
            loss = torch.stack(policy_loss).sum()
            loss.backward()
            optimizer.step()

            episode_losses.append(loss.item())

        synced = episode % sync_interval == 0
        if synced:
            average_shared_mlp(agents)

        # Record the cross-agent mean so the returned curve matches the
        # printed value instead of reflecting only the last agent.
        avg_loss = sum(episode_losses) / len(episode_losses)
        losses.append(avg_loss)
        print(f"Episode {episode + 1}: Federated Sync {synced}, Loss = {avg_loss:.4f}")

    return losses



def plot_losses(single_agent_losses, federated_losses, sync_interval=10):
    """Plot both loss curves on one figure and show it.

    Args:
        single_agent_losses: Per-episode losses of the single agent.
        federated_losses: Per-episode losses of the federated run.
        sync_interval: Spacing, in episodes, of the dashed vertical markers
            drawn at ``0, sync_interval, 2 * sync_interval, ...``. Pass the
            value used during federated training; ``None`` or a non-positive
            value disables the markers.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(single_agent_losses, label='Single Agent', alpha=0.8)
    plt.plot(federated_losses, label='Federated Agents', alpha=0.8)

    plt.xlabel('Episode')
    plt.ylabel('Loss')
    plt.title('Training Loss Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Mark the episodes at which the federated weights were averaged.
    if sync_interval and sync_interval > 0:
        for sync_episode in range(0, len(federated_losses), sync_interval):
            plt.axvline(x=sync_episode, color='gray', linestyle='--', alpha=0.3)

    plt.tight_layout()
    plt.show()


# Experiment driver: train one agent alone, then three agents with periodic
# averaging of their shared MLP, and compare the loss curves.
env = gym.make("CartPole-v1")
try:
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    hidden_dim = 64
    episodes = 100
    learning_rate = 0.01

    # Single Agent Training
    single_agent = LocalNetwork(state_dim, action_dim, hidden_dim).to(device)
    single_optimizer = torch.optim.Adam(single_agent.parameters(), lr=learning_rate)
    single_agent_losses = train_single_agent(env, single_agent, single_optimizer, episodes=episodes)

    # Federated Agents Training
    agents = [LocalNetwork(state_dim, action_dim, hidden_dim).to(device) for _ in range(3)]
    optimizers = [torch.optim.Adam(agent.parameters(), lr=learning_rate) for agent in agents]
    federated_losses = train_federated_agents(env, agents, optimizers, episodes=episodes, sync_interval=10)
finally:
    # Release the environment's resources even if training fails part-way.
    env.close()

# Plot the results
plot_losses(single_agent_losses, federated_losses)

print("Training Complete")