In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import numpy as np

# 1. Load AG News data: train and test splits
train_dataset = load_dataset('ag_news', split='train')
test_dataset  = load_dataset('ag_news', split='test')


# 2. Build vocab on *training data only*
def tokenizer(s):
    """Lower-case ``s`` and split it on whitespace."""
    return s.lower().split()


# Specials are placed first, so "<pad>" gets index 0 and "<unk>" index 1.
# Index 0 is what pad_sequence(padding_value=0) writes, so padding no longer
# collides with the unknown-token id.
vocab = build_vocab_from_iterator(
    (tokenizer(x['text']) for x in train_dataset),
    specials=["<pad>", "<unk>"],
)
vocab.set_default_index(vocab["<unk>"])

# 3. Encode samples (train & test separately)
def encode(text, max_len=8):
    """Convert ``text`` into a 1-D LongTensor of at most ``max_len`` vocab ids.

    Args:
        text: raw input string; tokenized with the module-level ``tokenizer``.
        max_len: maximum number of leading tokens kept (default 8, the
            sequence length used throughout this notebook).

    Returns:
        ``torch.long`` tensor of length ``min(len(tokens), max_len)``. Unknown
        tokens map to the vocab's default index.
    """
    # Truncate before the lookup so long articles don't pay for ids that are
    # discarded immediately afterwards.
    tokens = tokenizer(text)[:max_len]
    return torch.tensor([vocab[token] for token in tokens], dtype=torch.long)

# Encode each split into a right-padded id matrix (pad id 0) plus a label vector.
X_train = pad_sequence(
    [encode(sample['text']) for sample in train_dataset],
    batch_first=True,
    padding_value=0,
)
y_train = torch.tensor([sample['label'] for sample in train_dataset])

X_test = pad_sequence(
    [encode(sample['text']) for sample in test_dataset],
    batch_first=True,
    padding_value=0,
)
y_test = torch.tensor([sample['label'] for sample in test_dataset])

n_samples_train, n_samples_test = len(X_train), len(X_test)

# 4. Model setup: shared backbone, agent heads, routing and coordination modules
class Backbone(nn.Module):
    """Shared encoder: token embedding + learned positions + one Transformer layer.

    The pooled representation is the encoder output at position 0 (the first
    token of the sequence). No attention mask is passed to the encoder, so
    padded positions are attended to like any other token.
    """

    def __init__(self, input_dim, model_dim, n_heads, max_len=10):
        """
        Args:
            input_dim: vocabulary size (rows of the embedding table).
            model_dim: embedding / Transformer width.
            n_heads: number of attention heads.
            max_len: number of learned positional embeddings; the longest
                sequence ``forward`` can accept (default 10, as before).
        """
        super().__init__()
        self.max_len = max_len
        self.embedding = nn.Embedding(input_dim, model_dim)
        # Learned absolute positions, shape (1, max_len, model_dim).
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, model_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)

    def forward(self, x):
        """Encode token ids ``x`` of shape (batch, seq_len) into (batch, model_dim).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len`` (no positional
                embedding exists for the extra positions).
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len={self.max_len}")
        h = self.embedding(x) + self.pos_embedding[:, :seq_len]
        out = self.encoder(h)
        return out[:, 0, :]
        
class AgentFFN(nn.Module):
    """One agent head: a two-layer MLP (Linear -> ReLU -> Linear)."""

    def __init__(self, model_dim, out_dim):
        super().__init__()
        # Layers are created in the same order as before (hidden, then head)
        # so parameter initialisation consumes the RNG identically.
        hidden = nn.Linear(model_dim, model_dim)
        activation = nn.ReLU()
        head = nn.Linear(model_dim, out_dim)
        self.ffn = nn.Sequential(hidden, activation, head)

    def forward(self, x):
        """Map features whose last dim is ``model_dim`` to ``out_dim`` logits."""
        return self.ffn(x)

class RoutingNetwork(nn.Module):
    """Linear router that scores every agent for each feature vector."""

    def __init__(self, model_dim, n_agents):
        super().__init__()
        self.linear = nn.Linear(model_dim, n_agents)

    def forward(self, features):
        """Return ``(logits, probs)``; ``probs`` is the softmax of ``logits`` over agents."""
        scores = self.linear(features)
        return scores, scores.softmax(dim=-1)

class AssignmentModule:
    """Static assignment rule: a user id is mapped to agent ``user_id % n_agents``."""

    def __init__(self, n_agents):
        self.n_agents = n_agents

    def __call__(self, user_id):
        """Return the agent index for ``user_id``.

        Tensor ids are converted with ``.item()`` so callers get a Python number.
        """
        slot = user_id % self.n_agents
        return slot.item() if isinstance(user_id, torch.Tensor) else slot

class DualRoutingModule(nn.Module):
    """Send shared features to one of several agent heads.

    Two routing modes are supported:

    * ``'dynamic'``: a learned ``RoutingNetwork`` scores every agent for each
      sample, and each sample is processed by its arg-max agent.
    * ``'static'``: the whole batch is processed by the agent chosen by
      ``AssignmentModule(user_id)``, i.e. ``user_id % n_agents``.
    """

    def __init__(self, model_dim, n_agents, agents):
        super().__init__()
        self.routing_network = RoutingNetwork(model_dim, n_agents)
        self.assignment_module = AssignmentModule(n_agents)
        # Presumably an nn.ModuleList (as built by AgenticTransformerDualRouting),
        # so the heads are registered as submodules; must be int-indexable.
        self.agents = agents

    def forward(self, features, user_id=None, mode='dynamic', return_routing=False):
        """Route ``features`` (batch, model_dim) to agent heads.

        Args:
            features: shared backbone features, one row per sample.
            user_id: required in ``'static'`` mode; selects the agent.
            mode: ``'dynamic'`` or ``'static'``.
            return_routing: in dynamic mode, also return the router logits
                and probabilities.

        Returns:
            Dynamic mode: per-sample agent outputs in input order, or
            ``(outputs, logits, probs)`` when ``return_routing`` is true.
            Static mode: the selected agent's output for the whole batch.

        Raises:
            ValueError: for an unknown ``mode``, or static mode without ``user_id``.
        """
        if mode == 'dynamic':
            logits, probs = self.routing_network(features)
            agent_indices = torch.argmax(probs, dim=-1)
            outputs = self._run_routed_agents(features, agent_indices)
            if return_routing:
                return outputs, logits, probs
            return outputs
        if mode == 'static':
            # An explicit raise (not assert) so validation survives `python -O`.
            if user_id is None:
                raise ValueError("user_id required for static routing")
            agent_idx = self.assignment_module(user_id)
            return self.agents[agent_idx](features)
        raise ValueError("mode must be 'dynamic' or 'static'")

    def _run_routed_agents(self, features, agent_indices):
        """Run each agent once on the rows routed to it; return rows in input order."""
        positions = []
        chunks = []
        for agent_idx in agent_indices.unique().tolist():
            rows = (agent_indices == agent_idx).nonzero(as_tuple=True)[0]
            positions.append(rows)
            chunks.append(self.agents[agent_idx](features[rows]))
        grouped_order = torch.cat(positions)
        # Chunks are grouped by agent; invert that permutation so that row i of
        # the result corresponds to sample i of the input.
        return torch.cat(chunks, dim=0)[torch.argsort(grouped_order)]

class AgenticTransformerDualRouting(nn.Module):
    """Shared Transformer backbone followed by a pool of agent heads with dual routing."""

    def __init__(self, n_agents, vocab_size, model_dim, out_dim, n_heads=2):
        super().__init__()
        # Construction order (backbone, agent heads, router) is kept so that
        # parameter initialisation consumes the RNG exactly as before.
        self.backbone = Backbone(vocab_size, model_dim, n_heads=n_heads)
        agent_pool = nn.ModuleList(AgentFFN(model_dim, out_dim) for _ in range(n_agents))
        self.dual_routing_module = DualRoutingModule(model_dim, n_agents, agent_pool)

    def forward(self, x, user_id=None, mode='dynamic', return_routing=False):
        """Encode token ids ``x`` and delegate to the dual routing module."""
        pooled = self.backbone(x)
        return self.dual_routing_module(
            pooled,
            user_id=user_id,
            mode=mode,
            return_routing=return_routing,
        )

class LearnableCoordinator(nn.Module):
    """Score agents from backbone features and pick the top-``n_select`` per sample.

    Note that the top-k selection itself is not differentiable.
    """

    def __init__(self, model_dim, n_agents, n_select=2):
        """
        Raises:
            ValueError: if ``n_select`` is not in ``[1, n_agents]``.
        """
        super().__init__()
        if not 1 <= n_select <= n_agents:
            raise ValueError(f"n_select must be in [1, {n_agents}], got {n_select}")
        self.n_agents = n_agents
        self.n_select = n_select
        self.selector = nn.Linear(model_dim, n_agents)  # Takes backbone features

    def forward(self, features, agent_ids):
        """Return, for each sample, the IDs of its ``n_select`` highest-scoring agents.

        Args:
            features: (batch, model_dim), or (model_dim,) for a single sample.
            agent_ids: sequence mapping agent index -> agent ID; needs at
                least ``n_agents`` entries.

        Returns:
            List (one entry per sample) of lists of ``n_select`` agent IDs,
            ordered from most to least probable.

        Raises:
            ValueError: if ``agent_ids`` has fewer than ``n_agents`` entries.
        """
        if len(agent_ids) < self.n_agents:
            raise ValueError(f"agent_ids needs at least {self.n_agents} entries, got {len(agent_ids)}")
        if features.dim() == 1:
            features = features.unsqueeze(0)
        logits = self.selector(features)  # (batch, n_agents)
        probs = torch.softmax(logits, dim=-1)
        selected_indices = torch.topk(probs, self.n_select, dim=-1).indices
        return [[agent_ids[j] for j in row] for row in selected_indices.tolist()]

class CAC:
    """Coordinate several agent heads of a trained routing model.

    ``model`` must expose ``backbone`` and ``dual_routing_module.agents``
    (as ``AgenticTransformerDualRouting`` does).
    """

    def __init__(self, model, coordinator=None, workflow=None):
        self.model = model
        self.shared_memory = {}             # agent_id -> last output produced by that agent
        self.coordinator = coordinator      # optional callable (features, agent_ids) -> per-sample lists of agent IDs
        self.workflow = workflow            # List of agent IDs (workflow order); stored, not used by the methods below

    def communicate(self, outputs, protocol=None):
        """Return the outputs whose (simulated, random) confidence beats a threshold.

        Args:
            outputs: sequence of agent outputs.
            protocol: optional mapping; ``protocol['confidence_threshold']``
                (default 0.0) is the minimum confidence required to share.

        Returns:
            The shared outputs, in their original order.
        """
        threshold = protocol.get('confidence_threshold', 0.0) if protocol else 0.0
        shared = []
        for idx, out in enumerate(outputs):
            confidence = float(torch.rand(1))  # Simulated confidence
            if confidence > threshold:
                print(f"[Agentic][CAC][Communicate] Sharing Agent {idx} output with confidence {confidence:.2f}")
                shared.append(out)
            else:
                print(f"[Agentic][CAC][Communicate] Agent {idx} output NOT shared (confidence {confidence:.2f})")
        return shared

    def update_shared_memory(self, agent_id, data):
        """Store ``data`` as the latest output of ``agent_id``."""
        print(f"[Agentic][CAC][SharedMemory] Updating memory for agent {agent_id}.")
        self.shared_memory[agent_id] = data

    def forward(self, x, agent_ids, user_id=None, mode='static', aggregation='mean'):
        """Run the selected agents on each sample and average their outputs.

        Args:
            x: backbone input batch (e.g. token ids of shape (batch, seq_len)).
            agent_ids: candidate agent indices into ``dual_routing_module.agents``.
            user_id, mode: accepted for interface compatibility; unused.
            aggregation: only ``'mean'`` is implemented.

        Returns:
            Tensor of shape (batch, 1, out_dim): one averaged agent output per
            sample (each agent sees a single-row batch, hence the middle dim).

        Raises:
            ValueError: for an unsupported ``aggregation``, or if no agent is
                selected for a sample.
        """
        if aggregation != 'mean':
            raise ValueError(f"unsupported aggregation {aggregation!r}; only 'mean' is implemented")
        # Features come from the backbone without building an autograd graph.
        with torch.no_grad():
            features = self.model.backbone(x)  # (batch, model_dim)
        if self.coordinator is not None:
            print("\n[Agentic][CAC][Learnable Coordinator] Selecting agents dynamically...")
            selected_agents = self.coordinator(features, agent_ids)
        else:
            selected_agents = [agent_ids for _ in range(x.size(0))]  # fallback: all agents

        agents = self.model.dual_routing_module.agents
        outputs = []
        for i in range(x.size(0)):
            sample_features = features[i].unsqueeze(0)
            sample_outputs = []
            for agent_id in selected_agents[i]:
                out = agents[agent_id](sample_features)
                # Store a detached copy so memory does not keep the autograd graph alive.
                self.update_shared_memory(agent_id, out.detach())
                sample_outputs.append(out)
            if not sample_outputs:
                raise ValueError(f"no agents selected for sample {i}")
            outputs.append(torch.mean(torch.stack(sample_outputs, dim=0), dim=0))
        return torch.stack(outputs, dim=0)
        

# --------- 1. Model Initialization ---------
n_agents = 4
vocab_size = len(vocab)
model_dim = 32
out_dim = 4

# The router is supervised with the class labels (router_loss below), i.e.
# agent k is trained to own class k. That only works with one agent per class.
if n_agents != out_dim:
    raise ValueError("router supervision uses class labels as agent targets, so n_agents must equal out_dim")

model = AgenticTransformerDualRouting(n_agents, vocab_size, model_dim, out_dim)
clf_loss_fn = nn.CrossEntropyLoss()
router_loss_fn = nn.CrossEntropyLoss()
lb_lambda = 3            # weight of the load-balancing penalty
router_lambda = 1.0      # weight of the label-supervised router loss
entropy_lambda = 0.05    # weight of the routing entropy (added to the loss, so high entropy is penalized)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 5
batch_size = 32

# --------- 2. Joint Training (Dynamic Routing) ---------
for epoch in range(epochs):
    model.train()
    total_ce = 0.0
    total_lb = 0.0
    total_router = 0.0
    total_entropy = 0.0
    routing_counts = torch.zeros(n_agents, dtype=torch.long)
    agent_probs_sum = torch.zeros(n_agents)
    # Draw mini-batches in a fresh random order every epoch instead of always
    # walking the dataset in its stored order.
    perm = torch.randperm(n_samples_train)
    for batch_start in range(0, n_samples_train, batch_size):
        batch_idx = perm[batch_start:batch_start + batch_size]
        X_batch = X_train[batch_idx]
        y_batch = y_train[batch_idx]
        out, logits, probs = model(X_batch, user_id=None, mode='dynamic', return_routing=True)
        ce_loss = clf_loss_fn(out, y_batch)
        router_loss = router_loss_fn(logits, y_batch)
        # Load balancing: squared distance of the batch-mean routing distribution from uniform.
        probs_mean = probs.mean(dim=0)
        lb_loss = ((probs_mean - 1.0 / n_agents) ** 2).sum()
        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1).mean()
        loss = ce_loss + router_lambda * router_loss + lb_lambda * lb_loss + entropy_lambda * entropy
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        n_batch = X_batch.size(0)
        total_ce += ce_loss.item() * n_batch
        total_router += router_loss.item() * n_batch
        total_lb += lb_loss.item() * n_batch
        total_entropy += entropy.item() * n_batch
        routing_counts += torch.bincount(probs.argmax(dim=-1), minlength=n_agents)
        agent_probs_sum += probs.sum(dim=0).detach()
    print(f"[Agentic][Dynamic Routing][Training] Epoch {epoch+1} | CE: {total_ce/n_samples_train:.3f} | LB: {total_lb/n_samples_train:.3f} | Router: {total_router/n_samples_train:.3f} | Entropy: {total_entropy/n_samples_train:.3f}")
    print("[Agentic][Dynamic Routing][Training] Hard assignment counts per agent:", routing_counts.tolist())
    print("[Agentic][Dynamic Routing][Training] Mean softmax probability per agent:", (agent_probs_sum / n_samples_train).tolist())

# --------- 3. Per-agent accuracy (using only samples routed to that agent). Evaluation on TEST SET (never used during training)
def evaluate_per_agent_handled(model, X, y, handled_by, n_agents, mode):
    """Return each agent's accuracy on the samples it handled.

    Args:
        model: called as ``model(x, mode=mode)`` in dynamic mode, or
            ``model(x, user_id=agent, mode=mode)`` otherwise; samples are fed
            one at a time with a leading batch dim of 1.
        X, y: indexable inputs and labels.
        handled_by: for each sample index, the agent that handled it.
        n_agents: number of agents to report on.
        mode: routing mode forwarded to ``model``.

    Returns:
        List of length ``n_agents``; ``nan`` for agents that handled nothing.
        Puts ``model`` into eval mode as a side effect.
    """
    model.eval()
    results = []
    with torch.no_grad():
        for agent in range(n_agents):
            members = [pos for pos, owner in enumerate(handled_by) if owner == agent]
            if not members:
                results.append(float('nan'))
                continue
            # Dynamic mode lets the router decide again; other modes pin this agent.
            if mode == 'dynamic':
                call_kwargs = {'mode': mode}
            else:
                call_kwargs = {'user_id': agent, 'mode': mode}
            hits = sum(
                int(model(X[pos].unsqueeze(0), **call_kwargs).argmax(dim=1).item() == y[pos].item())
                for pos in members
            )
            results.append(hits / len(members))
    return results

# --------- 4. Confusion Matrix (True class vs Routed agent) ---------

# Rows are true classes (out_dim of them), columns are routed agents.
conf_mat = torch.zeros((out_dim, n_agents), dtype=torch.long)
model.eval()
all_routed = []
correct = 0
with torch.no_grad():
    for i in range(n_samples_test):
        out, logits, probs = model(X_test[i].unsqueeze(0), return_routing=True)
        routed_agent = probs.argmax(dim=-1).item()
        label = y_test[i].item()
        conf_mat[label, routed_agent] += 1
        all_routed.append(routed_agent)
        # This forward pass already yields the dynamic-routing prediction, so
        # overall accuracy is collected here instead of in a second pass.
        correct += int(out.argmax(dim=1).item() == label)
print("\n[Agentic][Dynamic Routing] Confusion Matrix: rows=True class, cols=Routed agent")
print(conf_mat)

per_agent_acc_dynamic = evaluate_per_agent_handled(model, X_test, y_test, all_routed, n_agents, mode='dynamic')
print("[Agentic][Dynamic Routing][Testing] Per-agent accuracy (handled samples):", per_agent_acc_dynamic)

unique, counts = torch.tensor(all_routed).unique(return_counts=True)
dist = {int(u): int(c) for u, c in zip(unique, counts)}
print(f"[Agentic][Dynamic Routing][Testing] Final routing distribution: {dist}")

# Test overall accuracy (counted during the confusion-matrix pass above)
test_acc = correct / n_samples_test
print("[Agentic][Dynamic Routing][Testing] Overall test accuracy:", test_acc)

# --------- 5. Post-Deployment: Freeze backbone and other agents ---------
# After this section only the target agent still has trainable parameters.
model.backbone.requires_grad_(False)

target_agent = 0
for agent_idx, agent in enumerate(model.dual_routing_module.agents):
    if agent_idx == target_agent:
        continue
    agent.requires_grad_(False)
# --------- 6. Independent Fine-Tuning for Agent 0 (Static Routing) ---------
def print_agent_outputs(model, X, n_agents):
    """Print every agent's static-routing output for the first sample of ``X``."""
    first_sample = X[0].unsqueeze(0)
    for agent_id in range(n_agents):
        logits = model(first_sample, user_id=agent_id, mode='static')
        as_array = logits.detach().cpu().numpy()
        print(f"[Agentic] Agent {agent_id} output: {as_array}")

print("\n[Agentic][Static Routing] Agent outputs BEFORE:")
print_agent_outputs(model, X_test, n_agents)  # Use test set to inspect agent outputs

# ... Fine-tuning on agent 0 (pick data from train set)
# The subset is the training samples whose label equals the agent index.
target_agent = 0
idxs = (y_train == target_agent).nonzero(as_tuple=True)[0]
X_new = X_train[idxs]
y_new = y_train[idxs]
n_new = len(X_new)
if n_new == 0:
    raise ValueError(f"no training samples with label {target_agent} to fine-tune on")

# Only the target agent's parameters are optimised. The model stays in the eval
# mode set during evaluation, so the frozen backbone's dropout stays off.
optimizer = torch.optim.Adam(model.dual_routing_module.agents[target_agent].parameters(), lr=5e-4)

print("\n[Agentic][Static Routing][Training] Agent 0 fine-tuning ...")
for epoch in range(3):
    total_loss = 0.0
    for x_i, y_i in zip(X_new, y_new):
        out = model(x_i.unsqueeze(0), user_id=target_agent, mode='static')
        loss = clf_loss_fn(out, y_i.unsqueeze(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the per-sample mean, consistent with the joint-training logs.
    print(f"[Agentic][Static Routing][Agent {target_agent} Fine-Tuning] Epoch {epoch+1} | Loss: {total_loss / n_new:.3f}")

print("\n[Agentic][Static Routing] Agent outputs AFTER agent 0 fine-tuning:")
print_agent_outputs(model, X_test, n_agents)  # Use test set for after

def evaluate_per_agent_static_routing(model, X, y, n_agents):
    """Return, per agent, its static-routing accuracy on samples whose label equals its index.

    Args:
        model: called as ``model(x, user_id=agent_id, mode='static')`` with a
            leading batch dim of 1.
        X, y: inputs and a 1-D label tensor.
        n_agents: number of agents to evaluate.

    Returns:
        List of length ``n_agents``. Agents with no matching samples get
        ``nan``, consistent with ``evaluate_per_agent_handled`` (previously 0,
        which was indistinguishable from a real 0% accuracy). Puts ``model``
        into eval mode as a side effect.
    """
    results = []
    model.eval()
    with torch.no_grad():
        for agent_id in range(n_agents):
            idxs = (y == agent_id).nonzero(as_tuple=True)[0].tolist()
            if not idxs:
                results.append(float('nan'))
                continue
            correct = 0
            for i in idxs:
                out = model(X[i].unsqueeze(0), user_id=agent_id, mode='static')
                correct += int(out.argmax(dim=1).item() == y[i].item())
            results.append(correct / len(idxs))
    return results

# For static routing, each sample is assigned to its label as agent
handled_by_static = y_test.tolist()
per_agent_acc_static = evaluate_per_agent_handled(
    model, X_test, y_test, handled_by_static, n_agents, mode='static'
)
print("[Agentic][Static Routing][Testing] Per-agent accuracy (handled samples):", per_agent_acc_static)

# --------- 7. CAC Scenario: Multi-agent Coordination Modes ---------

# (A) Parallel aggregation (default)

# Majority Voting aggregation
def majority_vote(outputs):
    """Return the class predicted by the most agents; ties go to the smallest class index.

    Args:
        outputs: iterable of per-agent logit tensors, each holding exactly one
            prediction (e.g. shape (1, n_classes)), since ``.item()`` is used.

    Returns:
        The winning class index as a Python int.

    Raises:
        ValueError: if ``outputs`` is empty.
    """
    from collections import Counter

    votes = Counter(out.argmax(dim=-1).item() for out in outputs)
    if not votes:
        raise ValueError("majority_vote() needs at least one output")
    top = max(votes.values())
    # Explicit tie-break: the previous version relied on set iteration order,
    # which Python does not guarantee to be sorted.
    return min(cls for cls, n in votes.items() if n == top)

# Softmax Mean aggregation
def softmax_mean(outputs):
    """Average the agents' softmax distributions and return the arg-max class index."""
    stacked = torch.stack([F.softmax(logits, dim=-1) for logits in outputs], dim=0)
    return stacked.mean(dim=0).argmax(dim=-1).item()
    
cac_parallel = CAC(model)  # not used below; kept in case later cells reference it

input_example = X_test[2].unsqueeze(0)
agents_to_coord = [0, 1, 2, 3]

print("\n[Agentic][CAC][Parallel Aggregation] For input X_test[2], aggregate all 4 agents (Majority Voting & Softmax Mean):")
# Get each agent's output (inference only, so no autograd graph is built)
agent_outputs_A = []
with torch.no_grad():
    for aid in agents_to_coord:
        out = model(input_example, user_id=aid, mode='static')
        agent_outputs_A.append(out)
        print(f"[Agentic][CAC] Individual Agent {aid} output: {out.detach()}")
# Apply the two aggregation rules announced in the header above
# (the original cell defined them but never called them).
print("[Agentic][CAC] Majority vote predicted class:", majority_vote(agent_outputs_A))
print("[Agentic][CAC] Softmax mean predicted class:", softmax_mean(agent_outputs_A))
agent_outputs_A = torch.cat(agent_outputs_A, dim=0)
mean_output_A = agent_outputs_A.mean(dim=0)
print("[Agentic][CAC] Mean (raw logits) predicted class:", mean_output_A.argmax(dim=0).item())

# (B) Learnable coordinator (random selection for illustration)
# Freeze backbone and agents
for param in model.backbone.parameters():
    param.requires_grad = False
for agent in model.dual_routing_module.agents:
    for param in agent.parameters():
        param.requires_grad = False

# Define the coordinator. It is freshly initialised and not trained in this
# cell, so its selection is effectively random. model_dim and n_agents are
# reused from section 1 so the coordinator always matches the trained model.
coordinator = LearnableCoordinator(model_dim, n_agents, n_select=2)
cac = CAC(model, coordinator=coordinator)

agent_ids = list(range(n_agents))

# Forward pass (agent selection is made by the coordinator)
print("\n[Agentic][CAC][Learnable Coordinator] Dynamic selection for input X_test[2]:")
agent_outputs_B = cac.forward(input_example, agent_ids=agent_ids, mode='static')
print(f"[Agentic][CAC] Aggregated output (learnable selection): {agent_outputs_B.detach()}")

# To see which agents were picked:
with torch.no_grad():
    features = model.backbone(input_example)
    selected_agents = coordinator(features, agent_ids)[0]
    print("[Agentic] Agents selected by coordinator:", selected_agents)
    for aid in selected_agents:
        out = model(input_example, user_id=aid, mode='static')
        print(f"[Agentic][CAC] Selected Agent {aid} output: {out.detach()}")
# agent_outputs_B is (batch=1, 1, out_dim); averaging over the batch dim leaves
# a single row, so argmax() over the flattened tensor is that row's class.
mean_output_B = agent_outputs_B.mean(dim=0)
pred_class = mean_output_B.argmax().item()
print("[Agentic][CAC] Mean (raw logits) predicted class:", pred_class)

[Agentic][Dynamic Routing][Training] Epoch 1 | CE: 1.005 | LB: 0.017 | Router: 1.002 | Entropy: 1.035
[Agentic][Dynamic Routing][Training] Hard assignment counts per agent: [31456, 33178, 26232, 29134]
[Agentic][Dynamic Routing][Training] Mean softmax probability per agent: [0.24964860081672668, 0.24977873265743256, 0.2515988051891327, 0.24897359311580658]
[Agentic][Dynamic Routing][Training] Epoch 2 | CE: 0.636 | LB: 0.031 | Router: 0.636 | Entropy: 0.678
[Agentic][Dynamic Routing][Training] Hard assignment counts per agent: [29073, 31942, 28929, 30056]
[Agentic][Dynamic Routing][Training] Mean softmax probability per agent: [0.25011909008026123, 0.2484709769487381, 0.25215843319892883, 0.24925167858600616]
[Agentic][Dynamic Routing][Training] Epoch 3 | CE: 0.491 | LB: 0.038 | Router: 0.492 | Entropy: 0.533
[Agentic][Dynamic Routing][Training] Hard assignment counts per agent: [29002, 31358, 29413, 30227]
[Agentic][Dynamic Routing][Training] Mean softmax probability per agent: [0.2504