In [None]:
import os
import glob
import torch
import networkx as nx
import numpy as np
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr, rankdata, entropy
from scipy import stats
from typing import List, Dict, Any, Optional
import matplotlib.pyplot as plt
import pandas as pd
import random

In [None]:
def _set_global_seed(seed: int):
    """Seed every RNG this notebook uses and switch torch to deterministic kernels.

    Writes ``PYTHONHASHSEED`` to the environment, seeds Python's ``random``,
    NumPy's global RNG and torch (CPU and all CUDA devices), then disables cuDNN
    benchmarking and enables deterministic cuDNN / torch algorithms.

    NOTE: setting PYTHONHASHSEED here does not change the hash seed of the
    already-running interpreter; it only affects processes started afterwards.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)


_set_global_seed(42)

In [None]:
def load_latest_interaction_file(
    logs_root: str,
    seed_folder: str,
    split: str = "validation",
    prefix: str = "interaction_gpu0",
    weights_only: Optional[bool] = None,
) -> object:
    """Load the interaction dump from the highest-numbered epoch folder of a run.

    Searches ``<logs_root>/<seed_folder>/interactions/<split>/epoch_<N>/<prefix>*``
    and loads the file whose ``N`` is largest (compared numerically, so
    ``epoch_10`` beats ``epoch_9``). If one epoch folder holds several matching
    files, the lexicographically last path is chosen. This makes the pick
    independent of ``glob`` ordering. The chosen path is printed.

    Args:
        logs_root: root directory of the logs.
        seed_folder: run sub-folder (may itself contain path separators).
        split: interactions split sub-folder, e.g. ``"validation"``.
        prefix: filename prefix of the dump to load.
        weights_only: forwarded to ``torch.load`` when not None. Newer torch
            versions default to ``weights_only=True``, which can refuse
            arbitrary pickled objects; pass ``False`` for trusted logs if needed.

    Returns:
        Whatever ``torch.load`` deserialises (mapped to CPU).

    Raises:
        FileNotFoundError: no file matches the pattern.
        ValueError: an ``epoch_*`` folder suffix is not an integer.
    """
    pattern = os.path.join(
        logs_root,
        seed_folder,
        "interactions",
        split,
        "epoch_*",
        f"{prefix}*"
    )
    files = glob.glob(pattern)
    if not files:
        raise FileNotFoundError(f"No files matching {pattern}")

    def sort_key(path: str) -> tuple:
        # (epoch number parsed from "epoch_<N>", full path) -> deterministic tie-break
        folder = os.path.basename(os.path.dirname(path))
        return int(folder.split("_", 1)[1]), path

    last_file = max(files, key=sort_key)
    print(last_file)

    # NOTE: torch.load unpickles objects; only load logs from trusted sources.
    load_kwargs: Dict[str, Any] = {"map_location": "cpu"}
    if weights_only is not None:
        load_kwargs["weights_only"] = weights_only
    return torch.load(last_file, **load_kwargs)

# Bee Language Analysis

In [None]:
# Latest validation interaction dump for the seed-42 bee run.
bee_interaction_obj = load_latest_interaction_file(
    "../logs/interactions/2025-07-02",
    "gamesize10_bee_gs_seed42",
    split="validation",
    prefix="interaction_gpu0",
)

In [None]:
# Direction label -> integer code, numbered clockwise starting at North.
DIRECTIONS = {
    label: code
    for code, label in enumerate(("N", "NE", "E", "SE", "S", "SW", "W", "NW"))
}
# Integer code -> direction label (inverse of DIRECTIONS).
INV_DIRECTIONS: Dict[int, str] = dict(zip(DIRECTIONS.values(), DIRECTIONS.keys()))

# Angular width (radians) of one compass sector.
SECTOR_ANGLE = 2 * np.pi / len(DIRECTIONS)

In [None]:
def build_graph(batch) -> nx.DiGraph:
    """Construct a directed graph from a single graph object (e.g. one item of a DataBatch).

    Reads ``batch.edge_index`` (transposed so each row is one ``(u, v)`` pair) and
    ``batch.edge_attr`` (each row unpacked as ``(distance, direction)``). Each edge
    gets a float ``distance`` and a ``direction`` label from ``DIRECTIONS``.

    Direction values may be strings/bytes (validated against ``DIRECTIONS``) or
    numeric codes (mapped through ``INV_DIRECTIONS``). Numeric codes are rounded,
    not truncated, because they come out of a float tensor.

    If the object exposes ``num_nodes``, nodes ``0..num_nodes-1`` are added up
    front so that nodes without any edge are still part of the graph.

    Raises:
        ValueError: a direction label or code is not known.
    """
    edge_indices = batch.edge_index.cpu().numpy().T
    edge_attrs = batch.edge_attr.cpu().numpy()

    graph = nx.DiGraph()
    num_nodes = getattr(batch, "num_nodes", None)
    if num_nodes is not None:
        graph.add_nodes_from(range(int(num_nodes)))

    for (u, v), (dist, dir_raw) in zip(edge_indices, edge_attrs):
        ui, vi = int(u), int(v)
        if isinstance(dir_raw, (bytes, str)):
            dir_str = dir_raw.decode() if isinstance(dir_raw, bytes) else dir_raw
            if dir_str not in DIRECTIONS:
                raise ValueError(f"Unknown direction '{dir_str}' in edge_attrs")
        else:
            # round instead of int() truncation so e.g. 2.9999994 maps to 2 -> 3, not 2
            dir_int = int(round(float(dir_raw)))
            dir_str = INV_DIRECTIONS.get(dir_int)
            if dir_str is None:
                raise ValueError(f"Unknown direction code {dir_int} in edge_attrs")
        graph.add_edge(ui, vi, distance=float(dist), direction=dir_str)
    return graph

In [None]:
# Sanity-check build_graph: take the first stored batch, split it into
# per-graph objects and build the graph for the first one.
first_batch = bee_interaction_obj.aux_input["data"][0]
graphs = first_batch.to_data_list()
single = graphs[0]
G = build_graph(single)
n_nodes, n_edges = G.number_of_nodes(), G.number_of_edges()
print(f"Graph has {n_nodes} nodes and {n_edges} edges")

In [None]:
# Show the attributes of the first five edges.
first_edges = list(G.edges(data=True))[:5]
for src, dst, edge_data in first_edges:
    print(f"  {src}->{dst} dist={edge_data['distance']:.2f}, dir={edge_data['direction']}")

In [None]:
def visualize_single_batch(batch):
    """Plot one graph with typed nodes and distance/direction edge labels.

    Node type is the argmax of each row of ``batch.x``: 0 -> "nest", 1 -> "food",
    anything else -> "distractor". Node positions are taken from ``batch.pos``
    (presumably 2-D coordinates — TODO confirm). Edges are drawn curved, and each
    label shows ``distance`` (1 decimal) and ``direction`` from ``build_graph``.
    """
    G = build_graph(batch)
    pos = {i: tuple(xy) for i, xy in enumerate(batch.pos.cpu().numpy())}

    # x[:,0]=nest, x[:,1]=food, x[:,2]=distractor
    type_by_index = {0: "nest", 1: "food"}
    types = {
        i: type_by_index.get(int(feat.argmax()), "distractor")
        for i, feat in enumerate(batch.x.cpu().numpy())
    }

    color_map = {"nest": "lightblue", "food": "red", "distractor": "grey"}
    nodes = list(G.nodes())
    node_size = 500

    fig, ax = plt.subplots(figsize=(8, 6))
    nx.draw_networkx_nodes(
        G, pos, nodelist=nodes,
        node_color=[color_map[types[i]] for i in nodes],
        node_size=node_size, ax=ax,
    )
    nx.draw_networkx_labels(G, pos, labels={i: types[i] for i in nodes}, font_size=8, ax=ax)

    # One call for all edges instead of one per edge; node_size matches the drawn
    # nodes so arrowheads stop at the node border instead of inside it.
    nx.draw_networkx_edges(
        G, pos,
        node_size=node_size,
        arrowstyle='-|>',
        arrowsize=12,
        connectionstyle='arc3,rad=0.1',
        ax=ax,
    )
    edge_labels = {
        (u, v): f"{d['distance']:.1f}m, dir={d['direction']}"
        for u, v, d in G.edges(data=True)
    }
    # label_pos=0.4 keeps labels of u->v and v->u from landing on the same spot
    nx.draw_networkx_edge_labels(
        G, pos, edge_labels=edge_labels, font_size=7, label_pos=0.4, ax=ax
    )

    ax.set_axis_off()
    plt.show()

In [None]:
visualize_single_batch(batch=single)  # draw the graph built in the sanity check above

In [None]:

# Heading in degrees (counter-clockwise from East) for each direction label.
_MEANING_DIR_DEGREES: Dict[str, float] = {
    "N": 90.0, "NE": 45.0, "E": 0.0, "SE": 315.0,
    "S": 270.0, "SW": 225.0, "W": 180.0, "NW": 135.0,
}
# Unit displacement vector for each direction label (used by vector-sum hypotheses).
_MEANING_DIR_VECS: Dict[str, np.ndarray] = {
    d: np.array([np.cos(np.deg2rad(phi)), np.sin(np.deg2rad(phi))])
    for d, phi in _MEANING_DIR_DEGREES.items()
}
# Width (radians) of one of the 8 compass sectors.
_MEANING_SECTOR_ANGLE = 2 * np.pi / 8
# Hypotheses understood by extract_meaning_spaces.
_MEANING_HYPOTHESES = (
    'coordinates',
    'hop_count_distance_vector_sum_direction',
    'sum_distances_vector_sum_direction',
    'hop_count_distance_angle_direction',
    'sum_distances_angle_direction',
)


def _meaning_sector(degrees: float) -> int:
    """Bin an angle (degrees) into 8 sectors numbered counter-clockwise from East (E=0, NE=1, N=2, ..., SE=7)."""
    return int(((degrees % 360.0) + 22.5) // 45) % 8


def _meaning_path(G, source, target):
    """Fewest-hop path with the smallest summed edge ``distance``; remaining ties -> lexicographically smallest node list."""
    hop_paths = list(nx.all_shortest_paths(G, source=source, target=target))
    costs = [sum(G[u][v]['distance'] for u, v in zip(p, p[1:])) for p in hop_paths]
    min_cost = min(costs)
    # tolerance so float round-off in the sums does not break genuine ties
    tied = [p for p, c in zip(hop_paths, costs) if np.isclose(c, min_cost, rtol=1e-9, atol=1e-12)]
    return min(tied)


def extract_meaning_spaces(
    interaction,
    hypothesis: str
) -> np.ndarray:
    """
    Generate meaning vectors [distance, direction] under a specified hypothesis.

    For every example the graph is built with ``build_graph``, and a nest->food
    path is chosen: fewest hops, then smallest summed edge distance, then the
    lexicographically smallest node list. Only edge attributes are used, except
    for 'coordinates', which uses ``data.pos``.

    Hypotheses:
      - 'coordinates': straight-line distance and ``arctan2`` angle (radians,
        in (-pi, pi]) between nest and food positions.
      - '*_vector_sum_direction': the sector of the sum of the edges' unit
        vectors, as ``sector * pi/4`` (sectors counter-clockwise from East).
        If the vectors cancel, the first edge's own sector is used.
      - '*_angle_direction': the sector of (sum of edge headings in degrees) mod 360.
      - 'hop_count_distance_*': distance is the number of hops.
      - 'sum_distances_*': distance is the summed edge ``distance``.

    Returns:
        float array of shape (n_examples, 2).

    Raises:
        ValueError: unknown hypothesis, or the number of graphs / nest indices /
            food indices differ.
    """
    if hypothesis not in _MEANING_HYPOTHESES:
        raise ValueError(f"Unknown hypothesis {hypothesis}")

    # split batched graphs into individual graphs
    data_list = [g for batch in interaction.aux_input['data'] for g in batch.to_data_list()]

    nest_indices = interaction.aux_input['nest_idx'].tolist()
    food_indices = interaction.aux_input['food_idx'].tolist()
    if len(data_list) != len(nest_indices):
        raise ValueError(
            f"Mismatch examples vs indices: {len(data_list)} graphs but {len(nest_indices)} nest indices"
        )
    if len(food_indices) != len(nest_indices):
        raise ValueError(
            f"Mismatch indices: {len(nest_indices)} nest indices but {len(food_indices)} food indices"
        )

    representations: List[List[float]] = []
    for data, start_idx, target_idx in zip(data_list, nest_indices, food_indices):
        pos = data.pos.cpu().numpy()
        G = build_graph(data)

        path = _meaning_path(G, start_idx, target_idx)
        edges = list(zip(path, path[1:]))
        edge_dirs = [G[u][v]['direction'] for u, v in edges]
        edge_dists = [G[u][v]['distance'] for u, v in edges]

        # path-based distances
        total_dist = float(sum(edge_dists))
        hop_count = float(len(edges))

        # straight-line metrics
        p0, p1 = pos[start_idx], pos[target_idx]
        dx, dy = p1[0] - p0[0], p1[1] - p0[1]
        straight_dist = float(np.hypot(dx, dy))
        straight_ang = float(np.arctan2(dy, dx))

        # angle arithmetic: sum of headings, treating directions like rotations
        sum_deg = sum(_MEANING_DIR_DEGREES[d] for d in edge_dirs) % 360.0
        angle_arith = _meaning_sector(sum_deg) * _MEANING_SECTOR_ANGLE

        # vector sum: treat each hop as a unit displacement
        if edge_dirs:
            vsum = np.sum([_MEANING_DIR_VECS[d] for d in edge_dirs], axis=0)
        else:
            vsum = np.zeros(2)
        if np.allclose(vsum, 0):
            # Displacements cancel: fall back to the first edge's own sector, using
            # the same East=0 counter-clockwise numbering as the branch below.
            sector_vs = _meaning_sector(_MEANING_DIR_DEGREES[edge_dirs[0]]) if edge_dirs else 0
        else:
            sector_vs = _meaning_sector(np.degrees(np.arctan2(vsum[1], vsum[0])))
        angle_vs = sector_vs * _MEANING_SECTOR_ANGLE

        candidates = {
            'coordinates': (straight_dist, straight_ang),
            'hop_count_distance_vector_sum_direction': (hop_count, angle_vs),
            'sum_distances_vector_sum_direction': (total_dist, angle_vs),
            'hop_count_distance_angle_direction': (hop_count, angle_arith),
            'sum_distances_angle_direction': (total_dist, angle_arith),
        }
        representations.append(list(candidates[hypothesis]))

    # reshape keeps the (n, 2) shape even when there are no examples
    return np.array(representations, dtype=float).reshape(-1, 2)

In [None]:
def compute_topsim(
    message_vectors: np.ndarray,
    meaning_vectors: np.ndarray,
    normalize: str = "linear"
) -> float:
    """Topographic similarity between two-slot messages and 2-D meanings.

    Spearman rank correlation between all pairwise message distances and all
    pairwise meaning distances. Pairs are (i, j), i < j, in the same condensed
    order as ``scipy.spatial.distance.pdist``.

    Message space (rows ``[dir_token, distance]``):
        the direction token is truncated to int, reduced mod 8 and placed on the
        compass circle. Pair distance is ``hypot(Δdistance, Δθ/π)``, where Δθ is
        the circular angle difference.
    Meaning space (rows ``[distance, θ in radians]``):
        distance is normalised per ``normalize`` ("linear" min-max, "log"
        min-max of log1p, or "rank" average ranks scaled to [0, 1]). Pair
        distance is ``hypot(Δdistance_norm, Δθ/π)``.

    NOTE(review): the message distance column is NOT normalised, while the
    meaning distance column is, so the two axes are weighted differently in the
    two spaces. Kept as-is to preserve reported numbers.

    Returns:
        Spearman rho, or nan when there are fewer than two examples or either set
        of pairwise distances is constant.

    Raises:
        ValueError: unknown ``normalize`` or mismatched input lengths.
    """
    if normalize not in ("linear", "log", "rank"):
        raise ValueError(f"Unknown normalize={normalize!r}")

    msgs = np.asarray(message_vectors, dtype=float)
    means = np.asarray(meaning_vectors, dtype=float)
    n = len(means)
    if len(msgs) != n:
        raise ValueError(f"Got {len(msgs)} messages but {n} meanings")
    if n < 2:
        return np.nan

    # condensed pair indices (i < j, row-major) == pdist output order
    i, j = np.triu_indices(n, k=1)

    def circular_gap(theta: np.ndarray) -> np.ndarray:
        """Circular |θ_i - θ_j| folded into [0, π], divided by π."""
        gap = np.abs(theta[i] - theta[j]) % (2 * np.pi)
        gap = np.where(gap > np.pi, 2 * np.pi - gap, gap)
        return gap / np.pi

    # 1) message distances
    tokens = msgs[:, 0].astype(int) % 8
    msg_theta = 2 * np.pi * tokens / 8
    msg_dists = np.hypot(msgs[i, 1] - msgs[j, 1], circular_gap(msg_theta))
    if np.std(msg_dists) == 0:
        return np.nan

    # 2) meaning distances: normalise the distance column first
    d = means[:, 0]
    a = means[:, 1]
    if normalize == "linear":
        if d.max() != d.min():
            d_norm = (d - d.min()) / (d.max() - d.min())
        else:
            d_norm = np.zeros_like(d)
    elif normalize == "log":
        d_log = np.log1p(d)
        if d_log.max() != d_log.min():
            d_norm = (d_log - d_log.min()) / (d_log.max() - d_log.min())
        else:
            d_norm = np.zeros_like(d_log)
    else:  # "rank"
        ranks = rankdata(d, method="average")
        d_norm = (ranks - 1) / (n - 1)

    meaning_dists = np.hypot(d_norm[i] - d_norm[j], circular_gap(a))
    if np.std(meaning_dists) == 0:
        return np.nan

    # 3) Spearman rho
    return float(spearmanr(msg_dists, meaning_dists, nan_policy="raise").correlation)



In [None]:
from sklearn.metrics import mutual_info_score

def _percentile_bins(x: np.ndarray, n_bins: int) -> tuple[np.ndarray, np.ndarray]:
    """Discretise ``x`` into ``n_bins`` equal-frequency (percentile) bins.

    Returns:
        (bin ids in ``[0, n_bins - 1]``, the bin edges used).

    The last edge is nudged up by 1e-12 so ``max(x)`` falls in the last bin. For
    large-magnitude inputs that nudge is lost to float rounding, so ids are also
    clipped into range.
    """
    pct = np.linspace(0, 100, n_bins + 1)
    edges = np.percentile(x, pct)
    edges[-1] += 1e-12
    ids = np.clip(np.digitize(x, edges) - 1, 0, n_bins - 1)
    return ids, edges

def _shannon_entropy(ids: np.ndarray) -> float:
    """Empirical Shannon entropy (natural log, nats) of a label array.

    A 1e-12 offset is added inside the log as a guard against log(0).
    """
    _, counts = np.unique(ids, return_counts=True)
    probs = counts / counts.sum()
    return float(-(probs * np.log(probs + 1e-12)).sum())

def _angle_to_sector(theta: np.ndarray) -> np.ndarray:
    """8-way compass bin of angles in radians (E=0, NE=1, N=2, ..., SE=7).

    Each sector is centred on its heading (±π/8). Angles just below 2π (or just
    below 0) wrap back to sector 0 instead of producing a spurious sector 8.
    """
    return ((((theta % (2 * np.pi)) + np.pi / 8) // (np.pi / 4)).astype(int)) % 8

def compute_posdis(
    meanings : np.ndarray, # (N,2)  [dist , θ(rad)]
    messages : np.ndarray, # (N,2)  [dist_tok , dir_tok]
    *, n_bins_distance: int = 2
) -> tuple[float, dict[str, float]]:
    """
    Positional disentanglement (Chaabouni et al., 2020) for two-slot messages.

    Column order matters: ``meanings`` rows are ``[distance, θ (radians)]`` and
    ``messages`` rows are ``[distance_token, dir_token]``.

    Concepts are discretised as distance -> ``n_bins_distance`` percentile bins
    and angle -> 8 compass sectors. Message slot 0 (distance token) is
    percentile-binned the same way, and slot 1 (direction token) is cast to int.
    For each slot, the score is (largest MI with a concept - second largest MI)
    divided by the slot's entropy, or 0 when the entropy is 0.

    Returns:
        (mean score over the two slots,
         {'distance_token': slot-0 score, 'direction_token': slot-1 score}).
        Scores lie in [0, 1]; higher means cleaner token-concept alignment.

    Raises:
        ValueError: inputs are not matching 2-D arrays with at least two columns.
    """
    meanings = np.asarray(meanings, dtype=float)
    messages = np.asarray(messages)
    if (meanings.ndim != 2 or messages.ndim != 2
            or meanings.shape[1] < 2 or messages.shape[1] < 2
            or len(meanings) != len(messages)):
        raise ValueError(
            f"Expected matching (N, 2) arrays, got meanings {meanings.shape} and messages {messages.shape}"
        )

    # discretise the two concepts
    concept_dist, _ = _percentile_bins(meanings[:, 0], n_bins_distance)
    concept_dir = _angle_to_sector(meanings[:, 1])
    concepts = (concept_dist, concept_dir)

    def slot_score(token_disc: np.ndarray) -> float:
        """(top MI - runner-up MI) / H(token) for one discretised message slot."""
        mi_top, mi_second = sorted(
            (mutual_info_score(token_disc, c) for c in concepts), reverse=True
        )
        H = _shannon_entropy(token_disc)
        return (mi_top - mi_second) / H if H > 0 else 0.0

    # slot 0 is continuous (distance) -> bin it; slot 1 is a discrete direction id
    dist_token_disc, _ = _percentile_bins(messages[:, 0], n_bins_distance)
    dir_token_disc = messages[:, 1].astype(int)

    breakdown = {
        'distance_token': slot_score(dist_token_disc),
        'direction_token': slot_score(dir_token_disc),
    }
    total = float(np.mean([breakdown['distance_token'], breakdown['direction_token']]))
    return total, breakdown


In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split

class _TREModel(nn.Module):
    """Additive linear composition model used for TRE (Appendix C of the TRE paper, per the original note).

    Output = A(φ(distance_token)) + B(ψ(dir_token)), a 3-vector intended as
    [z_dist, z_cos, z_sin]. φ is a linear map from one scalar to ``emb`` and ψ a
    learned embedding table with ``vocab_size`` rows.
    """

    def __init__(self, vocab_size: int, emb: int = 32):
        super().__init__()
        # Sub-modules are created in this order; parameter initialisation under a
        # fixed seed depends on it.
        self.dist_emb = nn.Linear(1, emb)
        self.dir_emb = nn.Embedding(vocab_size, emb)
        self.A = nn.Linear(emb, 3, bias=False)
        self.B = nn.Linear(emb, 3, bias=False)

    def forward(self, dist_t: torch.Tensor, dir_t: torch.Tensor) -> torch.Tensor:
        """dist_t: float tensor whose last dim is 1; dir_t: integer token ids. Returns (..., 3)."""
        from_distance = self.A(self.dist_emb(dist_t))
        from_direction = self.B(self.dir_emb(dir_t))
        return from_distance + from_direction

def _loss_tre(pred: torch.Tensor, target: torch.Tensor, 
              dist_weight: float = 1.0, dir_weight: float = 1.0) -> torch.Tensor:
    """Weighted TRE loss: ``dist_weight * MSE(col 0) + dir_weight * MSE(cols 1:)``.

    Column 0 holds the distance prediction; the remaining columns hold (cos, sin)
    of the angle.
    """
    pred_dist, pred_dir = pred[:, 0], pred[:, 1:]
    true_dist, true_dir = target[:, 0], target[:, 1:]
    weighted_dist = dist_weight * F.mse_loss(pred_dist, true_dist)
    weighted_dir = dir_weight * F.mse_loss(pred_dir, true_dir)
    return weighted_dist + weighted_dir

def compute_tre(
    meanings : np.ndarray,             # (N,2)   [distance , θ(rad)]
    messages : np.ndarray,             # (N,2)   [distance_token , dir_token]
    hyperparams: Optional[Dict[str, Any]] = None
) -> float:
    """
    Tree-reconstruction error. Lower means more compositional.

    Fits ``_TREModel`` to map each message ``[distance_token, dir_token]`` (the
    distance token is z-scored) to its meaning target
    ``[z-scored distance, cos θ, sin θ]``. Training uses Adam with early
    stopping, and the best validation loss is reported (as in the paper).

    Column order matters: ``messages`` rows must be ``[distance_token, dir_token]``.

    Hyperparameters (defaults): batch_size=256, val_split=0.2, epochs=300,
    lr=1e-2, weight_decay=1e-5, seed=42, emb=32, dist_weight=1.0,
    dir_weight=1.0, patience=25.

    Raises:
        ValueError: mismatched input lengths or an empty validation split.
    """
    hp = dict(batch_size=256, val_split=0.2, epochs=300,
              lr=1e-2, weight_decay=1e-5, seed=42, emb=32,
              dist_weight=1.0, dir_weight=1.0, patience=25)
    if hyperparams:
        hp.update(hyperparams)
    _set_global_seed(hp["seed"])  # seeds random / numpy / torch

    meanings = np.asarray(meanings, dtype=float)
    messages = np.asarray(messages)
    if len(meanings) != len(messages):
        raise ValueError(f"Got {len(meanings)} meanings but {len(messages)} messages")

    # ── inputs ──────────────────────────────────────────────────────────────
    dist_tok = messages[:, 0].astype(np.float32).reshape(-1, 1)
    dist_tok = (dist_tok - dist_tok.mean()) / (dist_tok.std() + 1e-8)   # z-score
    dir_tok = messages[:, 1].astype(np.int64)
    vocab = int(dir_tok.max()) + 1

    # ── targets come from the MEANING ───────────────────────────────────────
    # z-scored meaning distance + unit vector of the meaning angle. (Using the
    # distance token itself as target would make the distance term trivially
    # reconstructable from the model input.)
    m_dist = meanings[:, 0]
    y_dist = (m_dist - m_dist.mean()) / (m_dist.std() + 1e-8)
    y = np.column_stack(
        [y_dist, np.cos(meanings[:, 1]), np.sin(meanings[:, 1])]
    ).astype(np.float32)

    dataset = TensorDataset(torch.from_numpy(dist_tok), torch.from_numpy(dir_tok), torch.from_numpy(y))
    n_val = int(hp["val_split"] * len(dataset))
    n_train = len(dataset) - n_val
    if n_val == 0:
        raise ValueError("Validation split is empty; increase val_split or provide more examples")

    g = torch.Generator()
    g.manual_seed(hp["seed"])
    ds_train, ds_val = random_split(dataset, [n_train, n_val], generator=g)

    ld_train = DataLoader(ds_train, batch_size=hp["batch_size"], shuffle=True, num_workers=0,
                          generator=torch.Generator().manual_seed(hp["seed"]))
    ld_val = DataLoader(ds_val, batch_size=hp["batch_size"], num_workers=0)

    model = _TREModel(vocab, emb=hp["emb"])
    opt = torch.optim.Adam(model.parameters(), lr=hp["lr"],
                           weight_decay=hp["weight_decay"])

    best, bad = np.inf, 0
    for _ in range(hp["epochs"]):
        model.train()
        for xd, xc, yb in ld_train:
            loss = _loss_tre(model(xd, xc), yb, hp["dist_weight"], hp["dir_weight"])
            opt.zero_grad()
            loss.backward()
            opt.step()

        # validation: sample-weighted mean of per-batch losses
        model.eval()
        val = 0.0
        with torch.no_grad():
            for xd, xc, yb in ld_val:
                val += _loss_tre(model(xd, xc), yb, hp["dist_weight"], hp["dir_weight"]).item() * len(yb)
        val /= n_val

        if val < best - 1e-4:
            best, bad = val, 0
        else:
            bad += 1
            if bad >= hp["patience"]:
                break  # early stop

    return float(best)

In [None]:
# Decode bee messages: columns 0..vocab_size-1 are argmaxed into a direction
# token and the last column is read directly as the distance value.
# Resulting rows are [dir_token, distance].
vocab_size = 8
token_directions = bee_interaction_obj.message[:, :vocab_size].argmax(dim=-1)
token_distances = bee_interaction_obj.message[:, -1]
messages = torch.stack((token_directions, token_distances), dim=1).numpy()

In [None]:
meaning = extract_meaning_spaces(bee_interaction_obj, hypothesis='hop_count_distance_vector_sum_direction')  # (N, 2) [hops, angle]

In [None]:
# compute_tre expects messages as [distance_token, dir_token], but `messages`
# is built as [dir_token, distance]; swap the columns before calling it.
tre_messages = messages[:, [1, 0]]
tre_distance_only = compute_tre(meaning, tre_messages, {'dist_weight': 1.0, 'dir_weight': 0.0})
tre_direction_only = compute_tre(meaning, tre_messages, {'dist_weight': 0.0, 'dir_weight': 1.0})
tre_distance_heavy = compute_tre(meaning, tre_messages, {'dist_weight': 2.0, 'dir_weight': 1.0})
tre_direction_heavy = compute_tre(meaning, tre_messages, {'dist_weight': 1.0, 'dir_weight': 2.0})

In [None]:
# compute_posdis(meanings, messages) expects messages as [distance_token, dir_token];
# `messages` is [dir_token, distance], so swap the columns.
total, per_tok = compute_posdis(meaning, messages[:, [1, 0]])

# Distance Analysis

In [None]:
# Distribution of the raw distance values emitted by the bee sender.
raw = np.array(token_distances)
fig, ax = plt.subplots()
ax.hist(raw, bins=50)
ax.set(title="Histogram of raw token_distances", xlabel="token_distance", ylabel="count")
plt.show()

In [None]:
# Same values after exponentiation.
expd = np.exp(raw)
fig, ax = plt.subplots()
ax.hist(expd, bins=50)
ax.set(title="Histogram of exp(token_distances)", xlabel="exp(token_distance)", ylabel="count")
plt.show()

# Compositionality

In [None]:
# Runs to analyse: one folder per training seed under logs_root.
logs_root = "../logs/interactions/2025-07-02"
# baseline runs
seed_folders = [f"gamesize10_bee_gs_seed{s}" for s in (42, 123, 2025, 31, 27)]
# binned-distance runs (disabled):
# seed_folders = [f"binneddistance_bee_gs_seed{s}" for s in (27, 31, 42, 123, 2025)]

# Meaning-space hypotheses evaluated by extract_meaning_spaces.
hypotheses = [
    "coordinates",
    "hop_count_distance_vector_sum_direction",
    "sum_distances_vector_sum_direction",
    "hop_count_distance_angle_direction",
    "sum_distances_angle_direction",
]

In [None]:
# ── Compositionality metrics for every seed × meaning hypothesis ─────────────
# Layout note: messages are built as [dir_token, distance], which is what
# compute_topsim expects. compute_posdis(meanings, messages) and
# compute_tre(meanings, messages) expect [distance_token, dir_token], so the
# columns are swapped for those two metrics.
RESULTS_DIR = os.path.join("..", "results")
os.makedirs(RESULTS_DIR, exist_ok=True)

# TRE variants: output column -> (dist_weight, dir_weight)
TRE_VARIANTS = {
    'tre':            (1.0, 1.0),
    'tre_dist_only':  (1.0, 0.0),
    'tre_dir_only':   (0.0, 1.0),
    'tre_dist_heavy': (2.0, 1.0),
    'tre_dir_heavy':  (1.0, 2.0),
}
RESULT_COLUMNS = [
    'seed', 'hypothesis',
    'topsim_lin', 'topsim_log', 'topsim_rank',
    'posdis_total', 'posdis_distance', 'posdis_direction',
    'tre', 'tre_dist_only', 'tre_dir_only', 'tre_dist_heavy', 'tre_dir_heavy',
]


def _compositionality_row(msgs_dir_first, meaning, tre_variants):
    """TopSim (3 normalisations), PosDis and the requested TRE variants for one message/meaning set."""
    msgs_dist_first = msgs_dir_first[:, [1, 0]]
    row = {
        'topsim_lin':  compute_topsim(msgs_dir_first, meaning, normalize="linear"),
        'topsim_log':  compute_topsim(msgs_dir_first, meaning, normalize="log"),
        'topsim_rank': compute_topsim(msgs_dir_first, meaning, normalize="rank"),
    }
    ptot, pbd = compute_posdis(meaning, msgs_dist_first)
    row.update({
        'posdis_total':     ptot,
        'posdis_distance':  pbd['distance_token'],
        'posdis_direction': pbd['direction_token'],
    })
    for column, (dist_w, dir_w) in tre_variants.items():
        row[column] = compute_tre(meaning, msgs_dist_first,
                                  {'dist_weight': dist_w, 'dir_weight': dir_w})
    return row


rows = []
for seed_folder in seed_folders:
    interaction = load_latest_interaction_file(logs_root, seed_folder)

    # ── emerged messages as [dir_token, distance] ────────────────────────
    token_dirs = interaction.message[:, :8].argmax(dim=-1)
    token_dists = interaction.message[:, -1]
    emerged_msgs = torch.stack([token_dirs, token_dists], dim=1).numpy()
    print(f"Processing {seed_folder}: {len(emerged_msgs)} messages")

    seed = int(seed_folder.split("seed")[-1])
    rng = np.random.RandomState(seed)
    print(f"Seed: {seed}")

    for hyp in hypotheses:
        meaning = extract_meaning_spaces(interaction, hyp)
        rows.append({'seed': seed_folder, 'hypothesis': hyp,
                     **_compositionality_row(emerged_msgs, meaning, TRE_VARIANTS)})

    # ── random sanity check (same [dir_token, distance] layout) ─────────
    N = len(emerged_msgs)
    rand_logits = rng.randn(N, 8)
    rand_dist = rng.rand(N, 1) * 10
    rand_dir = rand_logits.argmax(axis=-1).reshape(-1, 1)
    rand_msgs = np.hstack([rand_dir, rand_dist])

    rand_true_d = rng.rand(N) * 10
    rand_true_a = rng.rand(N) * 2 * np.pi - np.pi
    rand_mean = np.vstack([rand_true_d, rand_true_a]).T

    rows.append({'seed': seed_folder, 'hypothesis': 'random',
                 **_compositionality_row(rand_msgs, rand_mean, {'tre': TRE_VARIANTS['tre']})})

df = pd.DataFrame(rows, columns=RESULT_COLUMNS)
df.to_csv(os.path.join(RESULTS_DIR, "bee_compositionality.csv"), index=False)

In [None]:

df.to_csv(path_or_buf="bee_compositionality.csv", index=False)  # extra copy in the working directory

# Human Analysis

In [None]:
# Latest validation interaction dump for the max_len=10, seed-42 human-language run.
human_interaction_obj = load_latest_interaction_file(
    "../logs/interactions/2025-06-22",
    "maxlen10_human_gs_seed42",
    split="validation",
    prefix="interaction_gpu0",
)

In [None]:
# Peek at a few decoded human-language messages: argmax over the vocabulary at each
# position, dropping the final position (presumably the EOS slot — TODO confirm).
human_interaction_obj.message.argmax(-1)[:5, :-1].cpu().numpy()

In [None]:
PAD = 0  # reserved integer id for padding in integer sequences
# Direction label -> integer code. Same mapping as the bee section's DIRECTIONS,
# redefined so this section can be read on its own.
DIRECTIONS = { "N":0,  "NE":1,  "E":2,  "SE":3,
               "S":4,  "SW":5,  "W":6,  "NW":7 }

def intify(seq, vocab):
    """Map symbols to ints >= 1, extending ``vocab`` in place; PAD stays 0.

    Unseen symbols get id ``len(vocab) + 1``. A symbol equal to ``PAD`` is passed
    through as ``PAD`` and never added to ``vocab``, so 0 remains reserved for
    padding.

    Args:
        seq: iterable of hashable symbols.
        vocab: dict symbol -> id; mutated with any new symbols.

    Returns:
        list of int ids, same length as ``seq``.
    """
    out = []
    for s in seq:
        if s == PAD:
            out.append(PAD)
            continue
        if s not in vocab:
            vocab[s] = len(vocab) + 1
        out.append(vocab[s])
    return out

def pad_to_max_len(seqs, max_len):
    """Return an int array of shape (len(seqs), max_len): each sequence truncated to max_len, right-padded with zeros."""
    padded = np.zeros((len(seqs), max_len), dtype=int)
    for row, seq in zip(padded, seqs):
        kept = seq[:max_len]
        row[:len(kept)] = kept
    return padded

def _truth_path(G, source, target):
    """Fewest-hop path with the smallest summed edge ``distance``; remaining ties -> lexicographically smallest node list."""
    cands = list(nx.all_shortest_paths(G, source=source, target=target))
    costs = [sum(G[u][v]['distance'] for u, v in zip(p, p[1:])) for p in cands]
    min_cost = min(costs)
    # tolerance so float round-off in the sums does not break genuine ties
    tied = [p for p, c in zip(cands, costs) if np.isclose(c, min_cost, rtol=1e-9, atol=1e-12)]
    return min(tied)


def extract_truth_sequences(interaction, max_len, dist_bins=3):
    """
    Per-hop ground-truth direction codes and distance bins for each example.

    For each example in ``interaction``:
      - enumerate all fewest-hop nest->food paths (``nx.all_shortest_paths``);
      - keep those with the smallest summed ``G[u][v]['distance']`` and break any
        remaining tie with the lexicographically smallest node list (the same
        rule ``extract_meaning_spaces`` uses);
      - map each edge's ``direction`` to 0..7 via ``DIRECTIONS``;
      - bin each edge's raw distance into ``dist_bins`` global percentile bins
        (0..dist_bins-1, computed over all edges of all chosen paths);
      - right-pad both sequences with ``PAD`` or truncate them to ``max_len``.

    NOTE(review): PAD (0) coincides with direction "N" (0) and distance bin 0, so
    padding cannot be told apart from real hops in the returned lists. Callers
    that need the true length should compute the hop count separately.

    Returns:
      truth_dirs, truth_dists: List[List[int]], each inner list of length max_len.

    Raises:
      ValueError: number of graphs / nest indices / food indices differ.
    """
    # 1) gather graphs & indices
    data_list = [g for batch in interaction.aux_input['data'] for g in batch.to_data_list()]
    nests = interaction.aux_input['nest_idx'].tolist()
    foods = interaction.aux_input['food_idx'].tolist()
    if len(data_list) != len(nests) or len(foods) != len(nests):
        raise ValueError("mismatch data vs indices")

    # 2) chosen path per example, as (direction label, raw distance) per edge
    raw_paths, all_dists = [], []
    for data, s, t in zip(data_list, nests, foods):
        G = build_graph(data)
        best = _truth_path(G, s, t)
        eds = [(G[u][v]['direction'], G[u][v]['distance']) for u, v in zip(best, best[1:])]
        all_dists.extend(raw for _, raw in eds)
        raw_paths.append(eds)

    # 3) global percentile bin edges (none needed if no example has any hop)
    if all_dists:
        edges = np.percentile(all_dists, np.linspace(0, 100, dist_bins + 1))
        edges[-1] += 1e-12  # close last bin
    else:
        edges = None

    # 4) discrete, padded/truncated sequences
    truth_dirs, truth_dists = [], []
    for eds in raw_paths:
        dirs = [DIRECTIONS[dstr] for dstr, _ in eds][:max_len]
        # clip: the 1e-12 nudge can vanish for large distances
        dists = [
            min(max(int(np.digitize(raw, edges)) - 1, 0), dist_bins - 1)
            for _, raw in eds
        ][:max_len]
        padlen = max_len - len(dirs)
        truth_dirs.append(dirs + [PAD] * padlen)
        truth_dists.append(dists + [PAD] * padlen)

    return truth_dirs, truth_dists

In [None]:
def compute_posdis_seq(msg_seqs, truth_dirs, truth_dists):
    """
    Positional disentanglement over message slots.

    For every slot p:
      MI_dir  = I(Msg_p; Dir_p),  MI_dst = I(Msg_p; Dist_p)
      posdis_p = (max(MI_dir, MI_dst) - min(MI_dir, MI_dst)) / H(Msg_p),
    or 0 when H(Msg_p) == 0 (entropy in nats).

    Inputs must convert to rectangular int arrays with the same number of slots.
    Returns (mean over slots, list of per-slot scores).
    """
    msgs = np.array(msg_seqs, dtype=int)
    dirs = np.array(truth_dirs, dtype=int)
    dists = np.array(truth_dists, dtype=int)
    _, n_slots = msgs.shape

    slot_scores = []
    for slot in range(n_slots):
        tokens = msgs[:, slot]
        _, counts = np.unique(tokens, return_counts=True)
        h_slot = entropy(counts / counts.sum())
        mi_a = mutual_info_score(tokens, dirs[:, slot])
        mi_b = mutual_info_score(tokens, dists[:, slot])
        gap = max(mi_a, mi_b) - min(mi_a, mi_b)
        slot_scores.append(gap / h_slot if h_slot > 0 else 0.0)

    return float(np.mean(slot_scores)), slot_scores

In [None]:
class TRESeq(nn.Module):
    """Bag-of-tokens model for sequence TRE.

    Embeds each derivation token (ids equal to PAD get the zero padding
    embedding), sums the embeddings over the sequence, and maps the sum with one
    bias-free linear layer to logits for every message position. The output
    shape is (batch, seq_len, msg_vocab_size).
    """
    def __init__(self, deriv_vocab_size: int, msg_vocab_size: int, seq_len: int, emb: int = 64):
        super().__init__()
        self.seq_len = seq_len
        self.msg_vocab_size = msg_vocab_size
        # Creation order (embedding, then head) fixes parameter init under a seed.
        self.emb = nn.Embedding(deriv_vocab_size, emb, padding_idx=PAD)
        self.head = nn.Linear(emb, seq_len * msg_vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = self.emb(x).sum(dim=1)
        flat_logits = self.head(pooled)
        return flat_logits.reshape(-1, self.seq_len, self.msg_vocab_size)

def compute_tre_seq(
    deriv_seqs: list[list[int]],
    msg_seqs:   list[list[int]],
    tre_hp:     dict | None = None
) -> float:
    """
    Train TRESeq to predict ``msg_seqs`` from ``deriv_seqs``.

    Returns the best per-token validation cross-entropy. Message positions equal
    to PAD are ignored, and the loss is averaged over the non-PAD tokens of the
    whole validation split. Lower means more compositional.

    deriv_seqs : list of integer lists (concatenated dir+dist derivation tokens)
    msg_seqs   : list of integer lists (the actual message token sequences)
    tre_hp     : optional hyperparameter overrides, e.g. {'seed': 123, 'epochs': 100}.
                 Defaults: seed=42, emb=64, batch_size=256, val_split=0.2,
                 epochs=200, lr=1e-2, weight_decay=1e-5, patience=20.

    Raises:
        ValueError: no examples, mismatched lengths, or a validation split with no
            non-PAD targets.
    """
    # 1) default hyperparameters
    hp = {
        'seed': 42,
        'emb': 64,
        'batch_size': 256,
        'val_split': 0.2,
        'epochs': 200,
        'lr': 1e-2,
        'weight_decay': 1e-5,
        'patience': 20,
    }
    if tre_hp:
        hp.update(tre_hp)
    if len(deriv_seqs) == 0 or len(deriv_seqs) != len(msg_seqs):
        raise ValueError(
            f"Need matching non-empty inputs, got {len(deriv_seqs)} derivations and {len(msg_seqs)} messages"
        )
    _set_global_seed(hp['seed'])

    # 2) vocab sizes and lengths (empty sequences are allowed)
    deriv_vocab_size = max((max(s) for s in deriv_seqs if len(s) > 0), default=PAD) + 1
    msg_vocab_size = max((max(s) for s in msg_seqs if len(s) > 0), default=PAD) + 1
    L_deriv = max(1, max(len(s) for s in deriv_seqs))
    L_msg = max(1, max(len(s) for s in msg_seqs))

    # 3) pad sequences to fixed length with PAD
    def pad(seqs, L):
        arr = np.full((len(seqs), L), PAD, dtype=np.int64)
        for i, s in enumerate(seqs):
            ln = min(len(s), L)
            arr[i, :ln] = s[:ln]
        return arr

    X_t = torch.from_numpy(pad(deriv_seqs, L_deriv))
    Y_t = torch.from_numpy(pad(msg_seqs, L_msg))

    # 4) dataset & split
    ds = TensorDataset(X_t, Y_t)
    n_val = int(hp['val_split'] * len(ds))
    n_tr = len(ds) - n_val
    g = torch.Generator().manual_seed(hp['seed'])
    tr_ds, vl_ds = random_split(ds, [n_tr, n_val], generator=g)
    val_targets = Y_t[torch.as_tensor(list(vl_ds.indices), dtype=torch.long)]
    if not bool((val_targets != PAD).any()):
        raise ValueError("Validation split has no non-PAD targets; increase val_split or data size")
    tr_ld = DataLoader(tr_ds, batch_size=hp['batch_size'], shuffle=True,
                       generator=torch.Generator().manual_seed(hp['seed']))
    vl_ld = DataLoader(vl_ds, batch_size=hp['batch_size'])

    # 5) model & optimizer
    model = TRESeq(deriv_vocab_size, msg_vocab_size, L_msg, emb=hp['emb'])
    opt = torch.optim.Adam(model.parameters(),
                           lr=hp['lr'],
                           weight_decay=hp['weight_decay'])

    # 6) training with early stopping
    best_loss, bad = float('inf'), 0
    for _ in range(hp['epochs']):
        model.train()
        for xb, yb in tr_ld:
            if not bool((yb != PAD).any()):
                continue  # all targets PAD: the masked mean would be NaN
            logits = model(xb)
            loss = F.cross_entropy(
                logits.view(-1, msg_vocab_size),
                yb.view(-1),
                ignore_index=PAD
            )
            opt.zero_grad()
            loss.backward()
            opt.step()

        # per-token validation loss over all non-PAD targets
        model.eval()
        loss_sum, n_tokens = 0.0, 0
        with torch.no_grad():
            for xb, yb in vl_ld:
                logits = model(xb)
                loss_sum += F.cross_entropy(
                    logits.view(-1, msg_vocab_size),
                    yb.view(-1),
                    ignore_index=PAD,
                    reduction='sum'
                ).item()
                n_tokens += int((yb != PAD).sum())
        val_loss = loss_sum / n_tokens

        if val_loss < best_loss - 1e-4:
            best_loss, bad = val_loss, 0
        else:
            bad += 1
            if bad >= hp['patience']:
                break

    return float(best_loss)

In [None]:
def _ed(a, b):
    """Levenshtein edit distance (unit-cost insert / delete / substitute) between two sequences."""
    previous = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        current = [i]
        for j, y in enumerate(b, start=1):
            substitution = previous[j - 1] + (0 if x == y else 1)
            deletion = previous[j] + 1
            insertion = current[j - 1] + 1
            current.append(min(substitution, deletion, insertion))
        previous = current
    return previous[-1]

def compute_topsim_seq(msg_seqs, truth_seqs):
    """
    Topographic similarity for fixed-length integer sequences.

    Spearman rho between all pairwise
      msg_dist(i, j)   = edit_dist(msg_i, msg_j) / L
      truth_dist(i, j) = edit_dist(truth_i, truth_j) / L
    where L is the message sequence length. NaN inputs raise (nan_policy="raise").
    """
    msgs = np.array(msg_seqs, dtype=int)
    truths = np.array(truth_seqs, dtype=int)
    _, seq_len = msgs.shape

    def normalised_ed(u, v):
        return _ed(tuple(u), tuple(v)) / seq_len

    msg_d = pdist(msgs, metric=normalised_ed)
    truth_d = pdist(truths, metric=normalised_ed)
    return spearmanr(msg_d, truth_d, nan_policy="raise").correlation

In [None]:
import re
# Human-language runs: one folder per (max_len, seed); max_len is parsed from the name.
human_seeds = [
    "logs/interactions/2025-07-08/maxlen2_vocab100_human_gs_seed27",
    "logs/interactions/2025-07-08/maxlen2_vocab100_human_gs_seed31",
    "logs/interactions/2025-06-22/maxlen2_human_gs_seed42",
    "logs/interactions/2025-06-23/human_maxlen_sweep_max_len2_seed123",
    "logs/interactions/2025-06-23/human_maxlen_sweep_max_len2_seed2025",

    "logs/interactions/2025-07-08/maxlen4_vocab100_human_gs_seed27",
    "logs/interactions/2025-07-08/maxlen4_vocab100_human_gs_seed31",
    "logs/interactions/2025-06-22/maxlen4_human_gs_seed42",
    "logs/interactions/2025-06-23/human_maxlen_sweep_max_len4_seed123",
    "logs/interactions/2025-06-23/human_maxlen_sweep_max_len4_seed2025",

    "logs/interactions/2025-07-08/maxlen6_vocab100_human_gs_seed27",
    "logs/interactions/2025-07-08/maxlen6_vocab100_human_gs_seed31",
    "logs/interactions/2025-06-22/maxlen6_human_gs_seed42",
    "logs/interactions/2025-06-23/human_maxlen_sweep_max_len6_seed123",
    "logs/interactions/2025-06-23/human_maxlen_sweep_max_len6_seed2025",

    "logs/interactions/2025-07-08/maxlen10_vocab100_human_gs_seed27",
    "logs/interactions/2025-07-08/maxlen10_vocab100_human_gs_seed31",
    "logs/interactions/2025-06-22/maxlen10_human_gs_seed42",
    "logs/interactions/2025-06-23/human_maxlen_sweep_max_len10_seed123",
    "logs/interactions/2025-06-23/human_maxlen_sweep_max_len10_seed2025"
]

logs_root = ".."
_set_global_seed(42)

# Marks ground-truth positions beyond an example's true path length. It must
# differ from every real direction code (0..7) and distance bin.
TRUTH_PAD = -1


def _hop_counts(interaction) -> List[int]:
    """Number of edges on a fewest-hop nest->food path, per example.

    All fewest-hop paths share this length, so it does not depend on how
    extract_truth_sequences breaks ties between them.
    """
    graphs = [g for batch in interaction.aux_input['data'] for g in batch.to_data_list()]
    nests = interaction.aux_input['nest_idx'].tolist()
    foods = interaction.aux_input['food_idx'].tolist()
    return [
        nx.shortest_path_length(build_graph(g), source=s, target=t)
        for g, s, t in zip(graphs, nests, foods)
    ]


def _mask_past_length(seqs, lengths, fill):
    """Copy of ``seqs`` with every position at index >= the example's length replaced by ``fill``."""
    return [[v if k < n else fill for k, v in enumerate(seq)] for seq, n in zip(seqs, lengths)]


def _derivation_sequences(truth_dirs, truth_dists, lengths):
    """Interleave the real hops as [dir+1, dist+1+len(DIRECTIONS), ...]; all ids are >= 1 so PAD (0) stays free."""
    offset = len(DIRECTIONS)
    return [
        [tok for d, z in zip(dseq[:n], zseq[:n]) for tok in (d + 1, z + 1 + offset)]
        for dseq, zseq, n in zip(truth_dirs, truth_dists, lengths)
    ]


rows = []
for seed_folder in human_seeds:
    m = re.search(r"max[_]?len[_]?(\d+)", seed_folder)
    if not m:
        raise ValueError(f"Could not parse max_len from '{seed_folder}'")
    max_len = int(m.group(1))

    interaction = load_latest_interaction_file(
        logs_root=logs_root,
        seed_folder=seed_folder,
        split='validation',
        prefix='interaction_gpu0'
    )

    truth_dirs, truth_dists = extract_truth_sequences(
        interaction,
        max_len=max_len,
        dist_bins=3
    )
    # extract_truth_sequences pads with 0, which is also direction "N" and
    # distance bin 0. Recover the true lengths and mark padding distinctly.
    lengths = [min(h, max_len) for h in _hop_counts(interaction)]
    assert len(lengths) == len(truth_dirs)
    dirs_masked = _mask_past_length(truth_dirs, lengths, TRUTH_PAD)
    dists_masked = _mask_past_length(truth_dists, lengths, TRUTH_PAD)
    deriv_seqs = _derivation_sequences(truth_dirs, truth_dists, lengths)

    msg_ids = interaction.message.argmax(-1)[:, :-1].cpu().numpy()
    msg_seqs = [list(row[:max_len]) for row in msg_ids]

    topsim_actual = compute_topsim_seq(msg_seqs, dirs_masked)
    posdis_actual, _ = compute_posdis_seq(msg_seqs, dirs_masked, dists_masked)
    tre_actual = compute_tre_seq(deriv_seqs, msg_seqs, tre_hp={'seed': 42})

    rows.append({
        'seed_folder': seed_folder,
        'max_len':     max_len,
        'type':        'actual',
        'TopSim':      topsim_actual,
        'PosDis':      posdis_actual,
        'TRE':         tre_actual
    })

    # random baseline: uniform draws from the tokens the sender actually used
    unique_tokens = sorted({t for seq in msg_seqs for t in seq})
    np.random.seed(42)
    rand_seqs = [
        list(np.random.choice(unique_tokens, size=max_len))
        for _ in msg_seqs
    ]

    topsim_rand = compute_topsim_seq(rand_seqs, dirs_masked)
    posdis_rand, _ = compute_posdis_seq(rand_seqs, dirs_masked, dists_masked)
    tre_rand = compute_tre_seq(deriv_seqs, rand_seqs, tre_hp={'seed': 42})

    rows.append({
        'seed_folder': seed_folder,
        'max_len':     max_len,
        'type':        'random',
        'TopSim':      topsim_rand,
        'PosDis':      posdis_rand,
        'TRE':         tre_rand
    })

df = pd.DataFrame(rows)
df.to_csv("human_compositionality.csv", index=False)