In [10]:
import torch
import random
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import numpy as np

# random.seed(42)
# np.random.seed(42)
# torch.manual_seed(42)
# torch.cuda.manual_seed_all(42)

# 1. Load AG News data
dataset = load_dataset('ag_news', split='train[:2000]')  # Keep it small for demo

# 2. Build a simple vocab
tokenizer = lambda s: s.lower().split()  # Replace with better tokenizer if needed
vocab = build_vocab_from_iterator((tokenizer(x['text']) for x in dataset), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# 3. Encode samples
def encode(text):
    tokens = tokenizer(text)
    return torch.tensor([vocab[token] for token in tokens][:8], dtype=torch.long)  # seq_len=8
    

X = [encode(sample['text']) for sample in dataset]
X = pad_sequence(X, batch_first=True, padding_value=0)  # Pad to max len in batch (or use fixed len)

# 4. Prepare labels and agent assignments
y = torch.tensor([sample['label'] for sample in dataset])
user_ids = y  # Assign agent per news category
n_samples = len(X)

# --------- 2. Model Components ---------

class Backbone(nn.Module):
    def __init__(self, input_dim, model_dim, n_heads):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, model_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, 10, model_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
    def forward(self, x):
        x = self.embedding(x) + self.pos_embedding[:, :x.size(1)]
        out = self.encoder(x)
        return out[:, 0, :]
        
class AgentFFN(nn.Module):
    def __init__(self, model_dim, out_dim):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, model_dim),
            nn.ReLU(),
            nn.Linear(model_dim, out_dim)
        )
    def forward(self, x):
        return self.ffn(x)

class RoutingNetwork(nn.Module):
    def __init__(self, model_dim, n_agents):
        super().__init__()
        self.linear = nn.Linear(model_dim, n_agents)
    def forward(self, features):
        logits = self.linear(features)
        probs = torch.softmax(logits, dim=-1)
        return logits, probs

class AssignmentModule:
    def __init__(self, n_agents):
        self.n_agents = n_agents

    def __call__(self, user_id):
        # Assign agent by rule: user_id % n_agents (can be extended)
        if isinstance(user_id, torch.Tensor):
            return (user_id % self.n_agents).item() 
        else:
            return user_id % self.n_agents

class DualRoutingModule(nn.Module):
    def __init__(self, model_dim, n_agents, agents):
        super().__init__()
        self.routing_network = RoutingNetwork(model_dim, n_agents)
        self.assignment_module = AssignmentModule(n_agents)
        self.agents = agents  # nn.ModuleList of AgentFFN

    def forward(self, features, user_id=None, mode='dynamic', return_routing=False):
        """
        features: (batch, model_dim)
        user_id: for static routing
        mode: 'dynamic' or 'static'
        return_routing: if True, return logits/probs (dynamic only)
        Returns:
            outputs: (batch, out_dim)
            (optional: logits, probs for dynamic)
        """
        batch_size = features.size(0)
        outputs = []
        if mode == 'dynamic':
            logits, probs = self.routing_network(features)
            agent_indices = torch.argmax(probs, dim=-1)  # (batch,)
            for i in range(batch_size):
                ai = agent_indices[i].item()
                outputs.append(self.agents[ai](features[i:i+1]))
            outputs = torch.cat(outputs, dim=0)
            if return_routing:
                return outputs, logits, probs
            else:
                return outputs
        elif mode == 'static':
            assert user_id is not None, "user_id required for static routing"
            agent_idx = self.assignment_module(user_id)
            out = self.agents[agent_idx](features)
            return out
        else:
            raise ValueError("[Agentic] mode must be 'dynamic' or 'static'")

class AgenticTransformerDualRouting(nn.Module):
    def __init__(self, n_agents, vocab_size, model_dim, out_dim, n_heads=2):
        super().__init__()
        self.backbone = Backbone(vocab_size, model_dim, n_heads=n_heads)
        agents = nn.ModuleList([AgentFFN(model_dim, out_dim) for _ in range(n_agents)])
        self.dual_routing_module = DualRoutingModule(model_dim, n_agents, agents)

    def forward(self, x, user_id=None, mode='dynamic', return_routing=False):
        shared = self.backbone(x)  # (batch, model_dim)
        return self.dual_routing_module(shared, user_id=user_id, mode=mode, return_routing=return_routing)

# class SimpleLearnableCoordinator(nn.Module):
#     """
#     Minimal learnable coordinator: for demonstration, returns a fixed or randomly chosen subset of agent_ids.
#     Replace logic as needed for your application.
#     """
#     def __init__(self, n_agents, n_select=2):
#         super().__init__()
#         self.n_agents = n_agents
#         self.n_select = n_select
#         self.selector = nn.Linear(1, n_agents)  # Just a dummy for illustration

#     def forward(self, x, agent_ids):
#         # For demonstration: randomly select n_select agents (could use logits from self.selector)
#         if len(agent_ids) <= self.n_select:
#             return agent_ids
#         selected = torch.randperm(len(agent_ids))[:self.n_select].tolist()
#         return [agent_ids[i] for i in selected]

class LearnableCoordinator(nn.Module):
    def __init__(self, model_dim, n_agents, n_select=2):
        super().__init__()
        self.n_agents = n_agents
        self.n_select = n_select
        self.selector = nn.Linear(model_dim, n_agents)  # Takes backbone features

    def forward(self, features, agent_ids):
        # features: (batch, model_dim) or (model_dim,) if batch=1
        if features.dim() == 1:
            features = features.unsqueeze(0)
        logits = self.selector(features)  # (batch, n_agents)
        probs = torch.softmax(logits, dim=-1)
        # Select top-k agents for each input in the batch
        selected_indices = torch.topk(probs, self.n_select, dim=-1).indices
        # Convert indices to agent IDs for each batch item
        selected_agents = []
        for i in range(features.size(0)):
            selected_agents.append([agent_ids[j] for j in selected_indices[i].tolist()])
        return selected_agents  # list of list of agent IDs (per batch)

class CAC:
    def __init__(self, model, coordinator=None, workflow=None):
        self.model = model
        self.shared_memory = {}
        self.coordinator = coordinator      # Should be a callable (e.g., a neural net)
        self.workflow = workflow            # List of agent IDs (workflow order)

    def communicate(self, outputs, protocol=None):
        communicated = []
        for idx, out in enumerate(outputs):
            # Example metadata: agent index, completion status, dummy confidence
            metadata = {
                'agent_idx': idx,
                'status': 'complete',
                'confidence': float(torch.rand(1))  # Simulated confidence
            }
            # Structured message (could be JSON serializable)
            message = {
                'output': out,
                'metadata': metadata
            }
            # Example rule: Only share if confidence > threshold (can set via protocol)
            threshold = protocol.get('confidence_threshold', 0.0) if protocol else 0.0
            if metadata['confidence'] > threshold:
                print(f"[Agentic][CAC][Communicate] Sharing Agent {idx} output with confidence {metadata['confidence']:.2f}")
                communicated.append({'output': out, 'metadata': {'agent_idx': idx}})
            else:
                print(f"[Agentic][CAC][Communicate] Agent {idx} output NOT shared (confidence {metadata['confidence']:.2f})")
        # Return only the outputs for downstream processing
        return [msg['output'] for msg in communicated]

    def update_shared_memory(self, agent_id, data):
        print(f"[Agentic][CAC][SharedMemory] Updating memory for agent {agent_id}.")
        self.shared_memory[agent_id] = data

    # def decompose_and_route(self, x, agents=[0,1,2,3], vocab=None):
    #     # If x is a dict and 'text' present, do AG News-specific split
    #     if isinstance(x, dict) and 'text' in x:
    #         assert vocab is not None, "vocab must be provided for raw article input"
    #         text = x['text']
    #         # Naive headline/body split
    #         parts = text.split('.', 1)
    #         headline = parts[0]
    #         body = parts[1] if len(parts) > 1 else ""
            
    #         def encode_text(txt):
    #             tokens = [vocab[token] for token in txt.lower().split()][:8]
    #             if len(tokens) == 0:
    #                 tokens = [vocab['<unk>']]  # Fallback: single unknown token
    #             return pad_sequence(
    #                 [torch.tensor(tokens, dtype=torch.long)],
    #                 batch_first=True,
    #                 padding_value=0
    #             )

    #         headline_encoded = encode_text(headline)
    #         body_encoded = encode_text(body)
            
    #         subtasks = [headline_encoded, body_encoded]
    #         routed_agents = [agents[0], agents[1]]
            
    #         # Optionally, expand to agents 2 and 3 for more complex coordination
    #         if len(agents) > 2:
    #             subtasks.append(body_encoded)       # or another aspect (e.g., summary)
    #             routed_agents.append(agents[2])
    #         if len(agents) > 3:
    #             subtasks.append(headline_encoded)   # or another aspect (e.g., sentiment)
    #             routed_agents.append(agents[3])
    #         return subtasks, routed_agents
    #     else:
    #         # Fallback: just send x to all agents (standard CAC pattern)
    #         xs = [x for _ in agents]
    #         return xs, agents

    # def aggregate(self, outputs, method='mean'):
    #     if method == 'mean':
    #         return torch.mean(torch.stack(outputs, dim=0), dim=0)
    #     elif method == 'majority':
    #         preds = [out.argmax(dim=-1).item() for out in outputs]
    #         return max(set(preds), key=preds.count)
    #     elif method == 'softmax_mean':
    #         probs = [F.softmax(out, dim=-1) for out in outputs]
    #         mean_probs = torch.mean(torch.stack(probs, dim=0), dim=0)
    #         return mean_probs.argmax(dim=-1).item()
    #     else:
    #         raise NotImplementedError(f"Aggregation '{method}' not supported.")

    def forward(self, x, agent_ids, user_id=None, mode='static', aggregation='mean'):
        # Get features from frozen backbone
        with torch.no_grad():
            features = self.model.backbone(x)  # (batch, model_dim)

        # # Workflow-based coordinator takes precedence if defined
        # if self.workflow is not None:
        #     print("\n[Workflow Coordinator] Using specified workflow:", self.workflow)
        #     agent_ids = self.workflow
        # # Otherwise, use learnable coordinator if defined
        # elif self.coordinator is not None:
        #     print("\n[Learnable Coordinator] Selecting agents dynamically...")
        #     selected_agents = self.coordinator(features, agent_ids)
        # else:
        #     selected_agents = [agent_ids for _ in range(x.size(0))]  # fallback: all agents

        if self.coordinator is not None:
            print("\n[Agentic][CAC][Learnable Coordinator] Selecting agents dynamically...")
            selected_agents = self.coordinator(features, agent_ids)
        else:
            selected_agents = [agent_ids for _ in range(x.size(0))]  # fallback: all agents

        # xs, agent_ids = self.decompose_and_route(x, agent_ids)

        outputs = []
        # For each sample in batch
        for i in range(x.size(0)):
            sample_outputs = []
            for agent_id in selected_agents[i]:
                out = self.model.dual_routing_module.agents[agent_id](features[i].unsqueeze(0))
                self.update_shared_memory(agent_id, out)
                sample_outputs.append(out)
            # Aggregate outputs for this sample (mean)
            sample_agg = torch.mean(torch.stack(sample_outputs, dim=0), dim=0)
            outputs.append(sample_agg)
        outputs = torch.stack(outputs, dim=0)
        return outputs        
# --------- 3. Model Initialization ---------

# After building vocab and preparing X, y, user_ids:
n_agents = 4                       # AG News has 4 classes (World, Sports, Business, Sci/Tech)
vocab_size = len(vocab)            # Set vocab_size to the size of your vocab
model_dim = 32                     # (or another dimension—keep consistent with model)
out_dim = 4                        # Number of classes in AG News

# Now initialize the model as before:
model = AgenticTransformerDualRouting(n_agents, vocab_size, model_dim, out_dim)
clf_loss_fn = nn.CrossEntropyLoss()
router_loss_fn = nn.CrossEntropyLoss()  # Supervision: agent == class
lb_lambda = 0.01  # load-balancing
router_lambda = 1.0  # supervised router


# --------- 4. Joint Training (Dynamic Routing) ---------
lb_lambda = 3  # load balancing weight

optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 5
batch_size = 32

# Add this before your loop
entropy_lambda = 0.05  # You can tune this value; start small!

for epoch in range(epochs):
    model.train()
    total_ce = 0
    total_lb = 0
    total_router = 0
    total_entropy = 0
    routing_counts = [0 for _ in range(n_agents)]
    agent_probs_sum = torch.zeros(n_agents)
    for batch_start in range(0, n_samples, batch_size):
        X_batch = X[batch_start:batch_start+batch_size]
        y_batch = y[batch_start:batch_start+batch_size]
        out, logits, probs = model(X_batch, user_id=None, mode='dynamic', return_routing=True)
        # Main classification loss
        ce_loss = clf_loss_fn(out, y_batch)
        # Supervised router loss: force router to send label i to expert i
        router_loss = router_loss_fn(logits, y_batch)
        # Load balancing (softmax entropy penalty)
        probs_mean = probs.mean(dim=0)
        lb_loss = ((probs_mean - 1.0/n_agents) ** 2).sum()
        # Entropy regularization (maximize entropy)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1).mean()
        # Total loss
        loss = ce_loss + router_lambda * router_loss + lb_lambda * lb_loss + entropy_lambda * entropy
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_ce += ce_loss.item() * X_batch.size(0)
        total_router += router_loss.item() * X_batch.size(0)
        total_lb += lb_loss.item() * X_batch.size(0)
        total_entropy += entropy.item() * X_batch.size(0)
        # Routing stat
        routed = probs.argmax(dim=-1)
        for idx in routed.tolist():
            routing_counts[idx] += 1
        agent_probs_sum += probs.sum(dim=0).detach()
    print(f"[Agentic][Dynamic Routing][Training] Epoch {epoch+1} | CE: {total_ce/n_samples:.3f} | LB: {total_lb/n_samples:.3f} | Router: {total_router/n_samples:.3f} | Entropy: {total_entropy/n_samples:.3f}")
    print("[Agentic][Dynamic Routing][Training] Hard assignment counts per agent:", routing_counts)
    print("[Agentic][Dynamic Routing][Training] Mean softmax probability per agent:", (agent_probs_sum / n_samples).tolist())

# 4. Per-agent accuracy (using only samples routed to that agent)
def evaluate_per_agent_handled(model, X, y, handled_by, n_agents, mode):
    per_agent_acc = []
    model.eval()
    with torch.no_grad():
        for agent in range(n_agents):
            idxs = [i for i, assigned in enumerate(handled_by) if assigned == agent]
            if len(idxs) == 0:
                per_agent_acc.append(float('nan'))
                continue
            correct = 0
            for i in idxs:
                if mode == 'dynamic':
                    out = model(X[i].unsqueeze(0), mode=mode)
                else:
                    out = model(X[i].unsqueeze(0), user_id=agent, mode=mode)
                pred = out.argmax(dim=1).item()
                correct += int(pred == y[i].item())
            per_agent_acc.append(correct / len(idxs))
    return per_agent_acc
    

# 5. Confusion Matrix (True class vs Routed agent)
conf_mat = torch.zeros((n_agents, n_agents), dtype=torch.long)
model.eval()
all_routed = []
with torch.no_grad():
    for i in range(n_samples):
        _, logits, probs = model(X[i].unsqueeze(0), return_routing=True)
        routed_agent = probs.argmax(dim=-1).item()
        conf_mat[y[i], routed_agent] += 1
        all_routed.append(routed_agent)
print("\n[Agentic][Dynamic Routing] Confusion Matrix: rows=True class, cols=Routed agent")
print(conf_mat)

# all_routed should already be a list of agent assignments by the router for each sample
per_agent_acc_dynamic = evaluate_per_agent_handled(model, X, y, all_routed, n_agents, mode='dynamic')
print("[Agentic][Dynamic Routing][Testing] Per-agent accuracy (handled samples):", per_agent_acc_dynamic)

unique, counts = torch.tensor(all_routed).unique(return_counts=True)
dist = {int(u): int(c) for u, c in zip(unique, counts)}
print(f"[Agentic][Dynamic Routing][Testing] Final routing distribution: {dist}")
# --------- 5. Post-Deployment: Freeze backbone and other agents ---------
for param in model.backbone.parameters():
    param.requires_grad = False

target_agent = 0
for i, agent in enumerate(model.dual_routing_module.agents):
    if i != target_agent:
        for param in agent.parameters():
            param.requires_grad = False

# --------- 6. Independent Fine-Tuning for Agent 0 (Static Routing) ---------
def print_agent_outputs(model, X, n_agents):
    for agent_id in range(n_agents):
        out = model(X[0].unsqueeze(0), user_id=agent_id, mode='static')
        print(f"[Agentic] Agent {agent_id} output: {out.detach().cpu().numpy()}")

print("\n[Agentic][Static Routing] Agent outputs BEFORE:")
print_agent_outputs(model, X, n_agents)

# Use real AG News samples for fine-tuning agent 0
# target_agent = 0  # Or another agent index
# X_new = X[:20]
# user_ids_new = torch.ones(20, dtype=torch.long) * target_agent
# y_new = y[:20]

target_agent = 0
idxs = (y == target_agent).nonzero(as_tuple=True)[0]
X_new = X[idxs]
y_new = y[idxs]

optimizer = torch.optim.Adam(model.dual_routing_module.agents[target_agent].parameters(), lr=5e-4)

print("\n[Agentic][Static Routing][Training] Agent 0 fine-tuning ...")
for epoch in range(3):
    total_loss = 0
    for i in range(len(X_new)):
        # Static routing: always assign to agent 0
        out = model(X_new[i].unsqueeze(0), user_id=target_agent, mode='static')
        loss = clf_loss_fn(out, y_new[i].unsqueeze(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"[Agentic][Static Routing][Agent {target_agent} Fine-Tuning] Epoch {epoch+1} | Loss: {total_loss:.3f}")

print("\n[Agentic][Static Routing] Agent outputs AFTER agent 0 fine-tuning:")
print_agent_outputs(model, X, n_agents)

def evaluate_per_agent_static_routing(model, X, y, n_agents):
    results = []
    model.eval()
    with torch.no_grad():
        for agent_id in range(n_agents):
            idxs = (y == agent_id).nonzero(as_tuple=True)[0]
            correct = 0
            total = 0
            for i in idxs:
                out = model(X[i].unsqueeze(0), user_id=agent_id, mode='static')
                pred = out.argmax(dim=1).item()
                correct += (pred == y[i].item())
                total += 1
            acc = correct / total if total > 0 else 0
            results.append(acc)
    return results

# For static routing, each sample is assigned to its label as agent
handled_by_static = [label.item() for label in y]
per_agent_acc_static = evaluate_per_agent_handled(model, X, y, handled_by_static, n_agents, mode='static')
print("[Agentic][Static Routing][Testing] Per-agent accuracy (handled samples):", per_agent_acc_static)

# --------- 7. CAC Scenario: Multi-agent Coordination Modes ---------

# (A) Parallel aggregation (default)

# Majority Voting aggregation
def majority_vote(outputs):
    preds = [out.argmax(dim=-1).item() for out in outputs]
    # In case of ties, returns the smallest class
    return max(set(preds), key=preds.count)

# Softmax Mean aggregation
def softmax_mean(outputs):
    probs = [F.softmax(out, dim=-1) for out in outputs]
    mean_probs = torch.mean(torch.stack(probs, dim=0), dim=0)
    return mean_probs.argmax(dim=-1).item()
    
cac_parallel = CAC(model)

input_example = X[2].unsqueeze(0)
agents_to_coord = [0, 1, 2, 3]

print("\n[Agentic][CAC][Parallel Aggregation] For input X[2], aggregate all 4 agents (Majority Voting & Softmax Mean):")

# Get each agent's output
agent_outputs_A = []
for aid in agents_to_coord:
    out = model(input_example, user_id=aid, mode='static')
    agent_outputs_A.append(out)
    print(f"[Agentic][CAC] Individual Agent {aid} output: {out.detach()}")

# # Majority voting aggregation
# maj_vote_pred = majority_vote(agent_outputs_A)
# print(f"Majority Vote predicted class: {maj_vote_pred}")

# # Softmax mean aggregation
# softmax_mean_pred = softmax_mean(agent_outputs_A)
# print(f"Softmax Mean predicted class: {softmax_mean_pred}")

# # For comparison: mean aggregation (as before)
# mean_output = torch.mean(torch.stack(aagent_outputs_A, dim=0), dim=0)
# mean_pred = mean_output.argmax(dim=-1).item()
# print(f"Mean (raw logits) predicted class: {mean_pred}")

agent_outputs_A = torch.cat(agent_outputs_A, dim=0)  # shape: [4, 1, num_classes]
# Aggregate (mean or softmax/majority vote as before)
mean_output_A = agent_outputs_A.mean(dim=0)  # shape: [num_classes]
print("[Agentic][CAC] Mean (raw logits) predicted class:", mean_output_A.argmax(dim=0).item())

# (B) Learnable coordinator (random selection for illustration)
# Freeze backbone and agents
for param in model.backbone.parameters():
    param.requires_grad = False
for agent in model.dual_routing_module.agents:
    for param in agent.parameters():
        param.requires_grad = False

# Define coordinator (to be trained)
model_dim = 32
n_agents = 4
coordinator = LearnableCoordinator(model_dim, n_agents, n_select=2)
cac = CAC(model, coordinator=coordinator)

# Example batch
# batch_X = X[:8]  # Batch of 8 samples
agent_ids = list(range(n_agents))

# Forward pass (agent selection is dynamic, learned)
print("\n[Agentic][CAC][Learnable Coordinator] Dynamic selection for input X[2]:")
agent_outputs_B = cac.forward(input_example, agent_ids=agent_ids, mode='static')
print(f"[Agentic][CAC] Aggregated output (learnable selection): {agent_outputs_B.detach()}")

# To see which agents were picked:
with torch.no_grad():
    features = model.backbone(input_example)
    selected_agents = coordinator(features, agent_ids)[0]
    print("[Agentic] Agents selected by coordinator:", selected_agents)
    for aid in selected_agents:
        out = model(input_example, user_id=aid, mode='static')
        print(f"[Agentic][CAC] Selected Agent {aid} output: {out.detach()}")

mean_output_B = agent_outputs_B.mean(dim=0)  # shape: [num_classes]
pred_class = mean_output_B.argmax().item()
print("[Agentic][CAC] Mean (raw logits) predicted class:", pred_class)

# (C) Workflow-based coordinator (sequential handoff)
# workflow = [0, 2]
# cac_workflow = CAC(model, workflow=workflow)

# print("\n[Workflow Coordinator] Sequential workflow for input X[2]:")
# workflow_output = cac_workflow.forward(input_example, agent_ids=list(range(n_agents)), mode='static')
# print(f"Workflow output (final agent): {workflow_output.detach()}")

# (D) Task decomposition and routing (example usage of decompose_and_route)
# sample_idx = 10
# article = dataset[sample_idx]
# print("\n[Decompose and Route]:")    
# subtasks, routed_agents = cac_parallel.decompose_and_route(article, agents=[0,1,2,3], vocab=vocab)
# for subtask, aid in zip(subtasks, routed_agents):
#     out = model(subtask, user_id=aid, mode='static')
#     print(f"Agent {aid} processed: {out.detach().cpu().numpy()}")

# # Optionally, aggregate the outputs (e.g., mean or majority vote)
# outputs = [model(subtask, user_id=aid, mode='static') for subtask, aid in zip(subtasks, routed_agents)]
# agg_output = torch.mean(torch.stack(outputs, dim=0), dim=0)
# print(f"Aggregated output (headline+body): {agg_output.detach()}")

[Agentic][Dynamic Routing][Training] Epoch 1 | CE: 1.418 | LB: 0.072 | Router: 1.295 | Entropy: 1.247
[Agentic][Dynamic Routing][Training] Hard assignment counts per agent: [627, 138, 516, 719]
[Agentic][Dynamic Routing][Training] Mean softmax probability per agent: [0.2396836280822754, 0.20570972561836243, 0.2349652498960495, 0.3196414113044739]
[Agentic][Dynamic Routing][Training] Epoch 2 | CE: 1.361 | LB: 0.015 | Router: 1.358 | Entropy: 1.350
[Agentic][Dynamic Routing][Training] Hard assignment counts per agent: [666, 89, 231, 1014]
[Agentic][Dynamic Routing][Training] Mean softmax probability per agent: [0.2457910031080246, 0.2177312970161438, 0.23370280861854553, 0.30277499556541443]
[Agentic][Dynamic Routing][Training] Epoch 3 | CE: 1.323 | LB: 0.015 | Router: 1.333 | Entropy: 1.346
[Agentic][Dynamic Routing][Training] Hard assignment counts per agent: [626, 107, 318, 949]
[Agentic][Dynamic Routing][Training] Mean softmax probability per agent: [0.2408578097820282, 0.22088824212