In [None]:
import numpy as np
import wandb
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset
import io
import os
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from util import TokenMergeBuffer, TokenUnmergeBuffer, EarlyStopper
from models import AE_no_lift
from toy_data.environments import RandomWalker

from types import SimpleNamespace

%load_ext autoreload
%autoreload 2

KeyboardInterrupt: 

# Dataset

In [None]:
# Mini-batch size used by both data loaders below.
batch_size = 128

# Build the toy random-walk dataset. The keyword arguments are passed straight
# to the project's RandomWalker; presumably n walks of `length` steps in `dim`
# dimensions -- confirm against toy_data.environments.
walker = RandomWalker(
    n=1000,
    length=10,
    dim=2,
    step_scale=0.1,
    init="normal",
    drift="centre",
).generate()

dataset = torch.tensor(walker.data, dtype=torch.float32)

# 80/20 train/validation split; the validation set takes whatever remains so
# the two sizes always add up to len(dataset).
train_pct = 0.8
train_size = int(train_pct * len(dataset))
val_size = len(dataset) - train_size

train_ds, val_ds = random_split(dataset, [train_size, val_size])

# shuffle=False keeps the batch order identical from epoch to epoch.
train_loader = DataLoader(train_ds, shuffle=False, batch_size=batch_size)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=batch_size)

# Train

In [None]:
def plot_two_merge_event(
    tokens_2d_step1,          # (L1, 2) coords BEFORE merge #1 (accepted but not drawn)
    prob_matrix1,             # (L1, L1) softmax AFTER masking/softmax
    chosen1,                  # (i1, j1)
    tokens_2d_step2,          # (L2, 2) coords drawn in the third panel
    prob_matrix2,             # (L2, L2) softmax at step 2
    chosen2,                  # (i2, j2)
    figsize=(12, 4)
):
    """
    Visualize two merge steps side by side in a single 1x3 figure.

    Panels:
      1. Merge #1 probability matrix with the chosen pair circled.
      2. Merge #2 probability matrix with the chosen pair circled.
      3. Scatter of `tokens_2d_step2` (titled "Tokens After Merge #2"; the
         caller decides which coordinates to pass).

    Matrices and coordinates may be NumPy arrays, nested lists or torch
    tensors. They are copied into float arrays, so the caller's data is never
    modified and the NaN diagonal mask also works for integer input.

    `tokens_2d_step1` is kept in the signature for backward compatibility but
    is not used by the drawing code.

    Returns the created matplotlib Figure.
    """

    def as_float_array(values):
        # Detach/move tensors first, then force a float *copy* so NaN can be
        # written without touching (or failing on) the caller's array.
        if hasattr(values, "detach"):
            values = values.detach().cpu().numpy()
        return np.array(values, dtype=float)

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    # ---- Helper to render each probability map ----
    def render_prob(ax, P, chosen, title):
        P = as_float_array(P)
        np.fill_diagonal(P, np.nan)              # a token cannot merge with itself

        im = ax.imshow(P, interpolation='nearest', cmap='viridis')
        ax.set_title(title)

        # Circle the chosen pair at both (i, j) and (j, i): the matrix is
        # displayed symmetrically.
        i, j = chosen
        ax.scatter([j], [i], s=80, edgecolor="red", facecolor="none", linewidth=1.5)
        ax.scatter([i], [j], s=80, edgecolor="red", facecolor="none", linewidth=1.5)

        # Thin white lines on the cell boundaries (pixel centres sit on
        # integer coordinates, so boundaries are at +0.5 offsets).
        ax.set_xticks(np.arange(P.shape[1]) + 0.5, minor=True)
        ax.set_yticks(np.arange(P.shape[0]) + 0.5, minor=True)
        ax.grid(which="minor", color="w", linewidth=0.4)
        ax.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)

        fig.colorbar(im, ax=ax, fraction=0.10, pad=0.03)

    # ---- Panel 1: merge #1 probability map ----
    render_prob(axes[0], prob_matrix1, chosen1, "Merge #1 Softmax")

    # ---- Panel 2: merge #2 probability map ----
    render_prob(axes[1], prob_matrix2, chosen2, "Merge #2 Softmax")

    # ---- Panel 3: token positions ----
    points = as_float_array(tokens_2d_step2)
    ax = axes[2]
    ax.scatter(points[:, 0], points[:, 1], s=40, color="black")

    ax.set_title("Tokens After Merge #2")
    ax.set_aspect('equal', 'box')
    ax.grid(True, alpha=0.2)
    ax.tick_params(labelbottom=False, labelleft=False)

    plt.tight_layout()
    return fig

In [None]:
def plot_merge_event_full(tokens_2d, prob_matrix, chosen_i, chosen_j,
                     orig_traj, recon_traj, merged_token_2d=None,
                     title_positions="Token positions",
                     title_matrix="Softmax merge probs",
                     title_recon="Original (blue), Recon (red)",
                     highlight_recon_idx=None):
    """
    Draw a 1x3 summary figure for a single merge event.

    Panels:
      1. `prob_matrix` as a heat map (diagonal masked) with the chosen pair
         circled at (i, j) and (j, i).
      2. Scatter of `tokens_2d` with index labels; the chosen pair is circled.
      3. `orig_traj` (blue) and `recon_traj` (red) drawn as poly-lines, with
         optional star markers on two reconstructed points.

    Args:
        tokens_2d: (N, 2) token coordinates; array or torch tensor.
        prob_matrix: (N, N) pair probabilities; array or torch tensor. A float
            copy is made, so the caller's matrix is left untouched.
        chosen_i, chosen_j: indices of the merged pair (anything `int()` accepts).
        orig_traj, recon_traj: (T, 2) trajectories; array or torch tensor.
        merged_token_2d: accepted for interface compatibility; not drawn.
        title_positions, title_matrix, title_recon: panel titles.
        highlight_recon_idx: optional (i, j) rows of `recon_traj` to mark.

    Returns:
        The created matplotlib Figure (the caller is responsible for closing it).
    """

    def to_numpy(values):
        # Anything exposing `.cpu` is treated as a torch tensor.
        if hasattr(values, "cpu"):
            return values.detach().cpu().numpy()
        return values

    tokens_2d = to_numpy(tokens_2d)
    orig_traj = to_numpy(orig_traj)
    recon_traj = to_numpy(recon_traj)

    # Copy before masking: `.cpu().numpy()` shares memory with a CPU tensor, so
    # filling the diagonal in place would silently write NaN into the caller's
    # data. Casting to float also lets the NaN fill work on integer input.
    prob_matrix = np.array(to_numpy(prob_matrix), dtype=float)

    # Plain ints keep the title readable and the fancy indexing below uniform
    # even if 0-d tensors / NumPy scalars are passed.
    chosen_i, chosen_j = int(chosen_i), int(chosen_j)

    fig, axes = plt.subplots(1, 3, figsize=(15,4))
    N = tokens_2d.shape[0]

    # ========================================================
    # 1 — Softmax matrix
    # ========================================================
    np.fill_diagonal(prob_matrix, np.nan)
    ax = axes[0]
    im = ax.imshow(prob_matrix, interpolation="nearest")
    fig.colorbar(im, ax=ax, fraction=0.046)
    ax.set_title(title_matrix)
    ax.set_xlabel("Token index")
    ax.set_ylabel("Token index")

    ax.set_xticks(range(N))
    ax.set_yticks(range(N))

    # highlight chosen merge pair (both symmetric cells)
    ax.scatter([chosen_j], [chosen_i], s=250, facecolor="none",
               edgecolor="red", linewidth=2)
    ax.scatter([chosen_i], [chosen_j], s=250, facecolor="none",
               edgecolor="red", linewidth=2)

    # minor ticks on the cell boundaries give a thin grid between cells
    ax.set_xticks(np.arange(-.5, N, 1), minor=True)
    ax.set_yticks(np.arange(-.5, N, 1), minor=True)
    ax.grid(which="minor", color="black", linewidth=0.2)

    # ========================================================
    # 2 — Token positions
    # ========================================================
    ax = axes[1]
    ax.scatter(tokens_2d[:,0], tokens_2d[:,1], s=40)

    for idx, (x,y) in enumerate(tokens_2d):
        ax.text(x+0.02, y+0.02, str(idx))

    # circle the pair that was merged
    ax.scatter(
        tokens_2d[[chosen_i, chosen_j],0],
        tokens_2d[[chosen_i, chosen_j],1],
        s=200, edgecolor="black", facecolor="none", linewidth=2
    )

    ax.set_title(f"{title_positions}\nMerged: ({chosen_i}, {chosen_j})")
    ax.set_aspect("equal")
    ax.set_xlabel("dim 1")
    ax.set_ylabel("dim 2")

    # ========================================================
    # 3 — Original vs Reconstruction trajectory
    # ========================================================
    ax = axes[2]

    # One poly-line per trajectory (consecutive points joined by segments).
    ax.plot(orig_traj[:, 0], orig_traj[:, 1], color="blue", linewidth=1)
    ax.plot(recon_traj[:, 0], recon_traj[:, 1], color="red", linewidth=1)

    if highlight_recon_idx is not None:
        for point_idx in highlight_recon_idx:
            ax.scatter([recon_traj[point_idx,0]], [recon_traj[point_idx,1]],
                    s=50, marker="*", color="red", edgecolor="black", linewidth=1)

    ax.set_aspect("equal")
    ax.set_title(title_recon)
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    plt.tight_layout()
    return fig

In [None]:
def remap_softmax_to_buffer(prob_matrix, active_idx, full_N, removed_pairs):
    """
    Scatter an active-space probability matrix into full buffer index space.

    prob_matrix:   (n_active, n_active) matrix; entry [a, b] belongs to the
                   buffer cell [active_idx[a], active_idx[b]]
    active_idx:    buffer index for every row/column of prob_matrix
    full_N:        side length of the returned square matrix
    removed_pairs: buffer indices whose whole row and column are blanked out

    Returns a (full_N, full_N) float array that is NaN everywhere except at
    the cells addressed by active_idx (and not blanked by removed_pairs).
    """
    buffer_probs = np.full((full_N, full_N), np.nan)

    # Write the whole active block in one shot via an open mesh of indices.
    slots = np.asarray(active_idx, dtype=np.intp)
    n_active = slots.shape[0]
    if n_active:
        active_block = np.asarray(prob_matrix)[:n_active, :n_active]
        buffer_probs[np.ix_(slots, slots)] = active_block

    # Blank out every row/column that belongs to a removed token.
    for slot in removed_pairs:
        buffer_probs[slot, :] = np.nan
        buffer_probs[:, slot] = np.nan

    return buffer_probs


def plot_2_merges(
        tokens_1, 
        tokens_2, 
        prob_matrix_1, 
        prob_matrix_2,
        active_idx_1,
        active_idx_2,
        removed_1,
        removed_2,
        orig_traj_1,
        orig_traj_2,
        recon_traj_1,
        recon_traj_2,
        merged_token_1,
        merged_token_2
    ):
    """
    Stack the three-panel summaries of two consecutive merge steps in one figure.

    tokens_1, tokens_2: (n_active,2) 2D coordinates of tokens before each merge
    prob_matrix_1, prob_matrix_2: softmax matrices in ACTIVE index space
    active_idx_1, active_idx_2: active token -> buffer index mapping
    removed_1, removed_2: the pair removed at each step (global buffer indices)
    orig_traj_*, recon_traj_*: trajectories forwarded to plot_merge_event_full
    merged_token_1, merged_token_2: (2,) 2D coords of new merged token

    Each step is rendered by plot_merge_event_full into its own temporary
    figure, saved to an in-memory PNG, and shown as one row of the returned
    figure. Matplotlib axes belong to the figure that created them and cannot
    be re-parented, which is why the rows are composed as images.

    NOTE(review): plot_merge_event_full uses chosen_i/chosen_j both on the
    full-buffer matrix and to index the token coordinates, so the removed
    indices must also be valid rows of tokens_1 / tokens_2 -- confirm that the
    caller passes tokens in a compatible index space.

    Returns the composed matplotlib Figure (2 rows, one per merge step).
    """

    def largest_index(indices):
        # Works for torch tensors, NumPy arrays and plain sequences.
        if hasattr(indices, "detach"):
            indices = indices.detach().cpu().numpy()
        return int(np.max(np.asarray(indices)))

    # Side length of the buffer-space matrices shared by both steps.
    full_N = max(largest_index(active_idx_1), largest_index(active_idx_2)) + 1

    steps = (
        (1, tokens_1, prob_matrix_1, active_idx_1, removed_1,
         orig_traj_1, recon_traj_1, merged_token_1),
        (2, tokens_2, prob_matrix_2, active_idx_2, removed_2,
         orig_traj_2, recon_traj_2, merged_token_2),
    )

    fig, axes = plt.subplots(2, 1, figsize=(17,10))

    for row_ax, (step, tokens, probs, active_idx, removed,
                 orig_traj, recon_traj, merged_token) in zip(axes, steps):
        # Re-express the active-space softmax in buffer space (NaN elsewhere).
        prob_full = remap_softmax_to_buffer(probs, active_idx, full_N, removed)

        step_fig = plot_merge_event_full(
            tokens_2d=tokens,
            prob_matrix=prob_full,
            chosen_i=removed[0], chosen_j=removed[1],
            orig_traj=orig_traj,
            recon_traj=recon_traj,
            merged_token_2d=merged_token,
            title_positions=f"Step {step} token positions",
            title_matrix=f"Step {step} softmax (full buffer space)",
            title_recon=f"Step {step} trajectories",
        )

        # Rasterise the temporary figure and release it immediately so it is
        # neither displayed on its own nor kept alive by pyplot.
        png_bytes = io.BytesIO()
        step_fig.savefig(png_bytes, format="png", dpi=150)
        plt.close(step_fig)
        png_bytes.seek(0)

        row_ax.imshow(plt.imread(png_bytes, format="png"))
        row_ax.axis("off")

    fig.tight_layout()
    return fig


In [None]:


def get_merge_depth(curr_epoch, total_epochs, max_depth, power=2.0):
    """
    Merge depth to use at `curr_epoch` on a fast-rising schedule.

    The training progress (curr_epoch / total_epochs) is raised to 1/power,
    scaled by max_depth, truncated to an int, and finally limited to at
    least 1 and at most max_depth.
    """
    fraction_done = curr_epoch / total_epochs
    bucket = int((fraction_done ** (1.0 / power)) * max_depth)

    # Never schedule fewer than one merge ...
    if bucket < 1:
        bucket = 1
    # ... and never more than the configured maximum.
    return max_depth if max_depth < bucket else bucket


@torch.no_grad()
def evaluate(model, val_loader, device, policy, current_depth, criterion, temperature=1):
    """
    Mean reconstruction loss of the merge -> unmerge round trip over `val_loader`.

    For every batch the active tokens are merged pair by pair while
    `buf.can_merge(current_depth)` holds: all token pairs are scored by the
    model's reconstruction error, one pair per sample is selected according to
    `policy`, and `model.encode` of that pair is handed to the merge buffer.
    The merges are then undone with `model.decode` and the recovered tokens
    are compared with the input using `criterion`.

    Args:
        model: callable autoencoder that also exposes `encode` and `decode`.
        val_loader: iterable of batches; a batch is either a tensor or a
            list/tuple whose first element is the token tensor (same
            convention as `train`).
        device: device the batches are moved to.
        policy: "argmin" (lowest error), "softmax" (sample from
            softmax(-error / temperature)) or "uniform" (random pair).
        current_depth: forwarded to `TokenMergeBuffer.can_merge`.
        criterion: loss applied to (reconstructed, original) tokens.
        temperature: softmax temperature; only used by the "softmax" policy.

    Returns:
        Average per-batch loss as a float, or NaN when the loader yields no
        batches.

    Raises:
        ValueError: if `policy` is not one of the supported names (raised when
            the first merge step is reached).

    Note: the model is left in eval mode; callers that continue training must
    switch it back with `model.train()`.
    """
    model.eval()
    pair_mse = nn.MSELoss(reduction="none")   # hoisted: identical for every merge step
    total_loss = 0.0
    n_batches = 0

    for batch in val_loader:
        raw_tokens = batch[0].to(device) if isinstance(batch, (list, tuple)) else batch.to(device)
        buf = TokenMergeBuffer(raw_tokens)

        while buf.can_merge(current_depth):
            active = buf.get_active_tokens()
            B, N, D = active.shape
            pair_idx = torch.combinations(torch.arange(N, device=device))
            num_pairs = pair_idx.shape[0]

            # Every unordered pair of active tokens, concatenated feature-wise.
            t1 = active[:, pair_idx[:, 0], :]
            t2 = active[:, pair_idx[:, 1], :]
            x = torch.cat([t1, t2], dim=-1)

            x_hat = model(x)
            t1_hat, t2_hat = torch.chunk(x_hat, 2, dim=-1)

            # Per-pair score: reconstruction error of both halves.
            recon_t1 = pair_mse(t1, t1_hat).mean(dim=-1)
            recon_t2 = pair_mse(t2, t2_hat).mean(dim=-1)
            loss_per_pair = recon_t1 + recon_t2

            if policy == "argmin":
                chosen = loss_per_pair.argmin(dim=1)
            elif policy == "softmax":
                probs = torch.softmax(-loss_per_pair / temperature, dim=1)
                chosen = torch.multinomial(probs, 1).squeeze(1)
            elif policy == "uniform":
                chosen = torch.randint(0, num_pairs, (B,), device=device)
            else:
                raise ValueError(f"Unknown policy {policy}")

            local_t1_idx = pair_idx[chosen, 0]
            local_t2_idx = pair_idx[chosen, 1]

            # Keep the batch index on the same device as the tensor it indexes.
            batch_rows = torch.arange(B, device=active.device)
            chosen_t1 = active[batch_rows, local_t1_idx, :]
            chosen_t2 = active[batch_rows, local_t2_idx, :]

            chosen_x = torch.cat([chosen_t1, chosen_t2], dim=-1)
            merged_tokens = model.encode(chosen_x)

            buf.merge_batch(local_t1_idx, local_t2_idx, merged_tokens)

        # Undo the recorded merges with the decoder.
        unmerge_buf = TokenUnmergeBuffer(
            buffer=buf.buffer,
            active_mask=buf.get_active_mask(),
            merges=buf.get_merge_history(),
            n_original=raw_tokens.shape[1],
        )

        while not unmerge_buf.is_done():
            merged_token = unmerge_buf.get_next_to_unmerge()
            pred = model.decode(merged_token)
            t1_pred, t2_pred = torch.chunk(pred, 2, dim=-1)
            unmerge_buf.step_unmerge(t1_pred, t2_pred)

        reconstructed = unmerge_buf.get_original_tokens()

        loss = criterion(reconstructed, raw_tokens)
        total_loss += loss.item()
        n_batches += 1

    # An empty loader has no meaningful average; report NaN instead of
    # failing with ZeroDivisionError.
    if n_batches == 0:
        return float("nan")
    return total_loss / n_batches




def train(
        model,
        train_loader,
        val_loader,
        config,
        device,
        checkpoint_path,
        max_depth,
        ):
    """
    Train `model` on the merge -> unmerge reconstruction task and log to wandb.

    Per batch: token pairs are merged while `buf.can_merge(max_depth)` holds
    (pairs are scored by reconstruction error and selected according to
    `config.sampling_policy`), the merges are undone with `model.decode`, and
    the MSE between recovered and original tokens is back-propagated.

    Per epoch: the validation loss is computed with `evaluate`, the model is
    checkpointed when that loss improves, and losses plus diagnostic figures
    from the first batch are logged.

    Args:
        model: autoencoder exposing `__call__`, `encode` and `decode`.
        train_loader, val_loader: iterables of batches (tensor, or list/tuple
            whose first element is the token tensor).
        config: namespace providing `learning_rate`, `epochs`,
            `sampling_policy` ("argmin" | "uniform" | "softmax"),
            `temperature` (softmax only) and `model_name`.
        device: device used for the model and the batches.
        checkpoint_path: where to save the best state dict; falsy disables
            checkpointing.
        max_depth: merge depth forwarded to `TokenMergeBuffer.can_merge`.

    Raises:
        ValueError: if `config.sampling_policy` is not a supported policy.

    Logged keys: "train_loss", "val_loss", "epoch", "merge_depth",
    "recon_fig", and "merge_event_step_<k>" for each captured merge step
    (only the softmax policy produces merge-event figures).
    """
    policy = config.sampling_policy
    if policy not in ("argmin", "uniform", "softmax"):
        # Fail before a wandb run is opened for a configuration that cannot train.
        raise ValueError(f"Unknown sampling policy: {policy}")

    model.to(device)
    optimiser = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.MSELoss()
    pair_mse = nn.MSELoss(reduction="none")   # per-element error used for pair scoring

    best_val_loss = float('inf')

    # current_depth = get_merge_depth(epoch+1, config.epochs, config.max_depth)
    current_depth = max_depth   # the depth schedule above is currently disabled

    viz_sample = 0       # index within the first batch that is visualised
    max_viz_events = 2   # number of merge steps captured per epoch
    max_recon_panels = 5

    wandb.init(
        project="random-walker",
        config=config,
        name=config.model_name
    )

    # try/finally so the wandb run is closed even if training raises or is
    # interrupted from the notebook.
    try:
        for epoch in range(config.epochs):
            model.train()
            total_loss = 0.0
            n_batches = 0

            capture_done = False      # becomes True after the first batch
            merge_viz_events = []
            merge_figs = []
            recon_fig = None

            for batch in train_loader:
                optimiser.zero_grad()
                raw_tokens = batch[0].to(device) if isinstance(batch, (list, tuple)) else batch.to(device)

                buf = TokenMergeBuffer(raw_tokens)

                while buf.can_merge(current_depth):
                    active = buf.get_active_tokens()  # (B, N, D)
                    B, N, D = active.shape

                    pair_idx = torch.combinations(torch.arange(N, device=device))
                    num_pairs = pair_idx.shape[0]

                    # The scores only drive a non-differentiable selection
                    # (argmin / sampling), so no autograd graph is needed here.
                    with torch.no_grad():
                        t1 = active[:, pair_idx[:, 0], :]
                        t2 = active[:, pair_idx[:, 1], :]
                        x = torch.cat([t1, t2], dim=-1)

                        x_hat = model(x)
                        t1_hat, t2_hat = torch.chunk(x_hat, 2, dim=-1)

                        recon_t1 = pair_mse(t1, t1_hat).mean(dim=-1)
                        recon_t2 = pair_mse(t2, t2_hat).mean(dim=-1)
                        loss_per_pair = recon_t1 + recon_t2

                    event = None   # filled only when this merge step is captured
                    if policy == "argmin":
                        chosen = loss_per_pair.argmin(dim=1)
                    elif policy == "uniform":
                        chosen = torch.randint(0, num_pairs, (B,), device=device)
                    else:  # "softmax" (validated above)
                        probs = torch.softmax(-loss_per_pair / config.temperature, dim=1)
                        chosen = torch.multinomial(probs, 1).squeeze(1)

                        if not capture_done and len(merge_viz_events) < max_viz_events:
                            # Expand the per-pair probabilities of one sample
                            # into a symmetric (N, N) token-by-token matrix.
                            rows, cols = pair_idx[:, 0], pair_idx[:, 1]
                            token_prob_matrix = torch.zeros(
                                N, N, device=probs.device, dtype=probs.dtype)
                            token_prob_matrix[rows, cols] = probs[viz_sample]
                            token_prob_matrix[cols, rows] = probs[viz_sample]

                            event = {
                                "tokens_2d": active[viz_sample, :, :].detach().cpu(),
                                "prob_matrix": token_prob_matrix.detach().cpu(),
                                "chosen_i": pair_idx[chosen[viz_sample], 0].item(),
                                "chosen_j": pair_idx[chosen[viz_sample], 1].item(),
                                "local_t1_idx": None,    # filled after the merge
                                "local_t2_idx": None,
                                "merged_token_2d": None,
                            }

                    local_t1_idx = pair_idx[chosen, 0]
                    local_t2_idx = pair_idx[chosen, 1]

                    batch_rows = torch.arange(B, device=active.device)
                    chosen_t1 = active[batch_rows, local_t1_idx, :]
                    chosen_t2 = active[batch_rows, local_t2_idx, :]

                    chosen_x = torch.cat([chosen_t1, chosen_t2], dim=-1)

                    merged_tokens = model.encode(chosen_x)
                    buf.merge_batch(local_t1_idx, local_t2_idx, merged_tokens)

                    if event is not None:
                        event["local_t1_idx"] = local_t1_idx[viz_sample].item()
                        event["local_t2_idx"] = local_t2_idx[viz_sample].item()
                        event["merged_token_2d"] = merged_tokens[viz_sample, :2].detach().cpu()
                        merge_viz_events.append(event)

                unmerge_buf = TokenUnmergeBuffer(buffer=buf.buffer,
                                                 active_mask=buf.get_active_mask(),
                                                 merges=buf.get_merge_history(),
                                                 n_original=raw_tokens.shape[1])

                while not unmerge_buf.is_done():
                    merged_token = unmerge_buf.get_next_to_unmerge()
                    pred = model.decode(merged_token)
                    t1_pred, t2_pred = torch.chunk(pred, 2, dim=-1)
                    unmerge_buf.step_unmerge(t1_pred, t2_pred)
                reconstructed = unmerge_buf.get_original_tokens()

                if recon_fig is None:
                    # Scatter the first few samples of the first batch; never
                    # ask for more panels than the batch has samples.
                    n_panels = min(max_recon_panels, reconstructed.shape[0])
                    recon_fig, axes = plt.subplots(
                        1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
                    recon_fig.suptitle("Original (red), Recon (blue)")

                    for i, ax in enumerate(axes[0]):
                        recon_pts = reconstructed[i].detach().cpu()
                        orig_pts = raw_tokens[i].detach().cpu()

                        ax.scatter(recon_pts[:, 0], recon_pts[:, 1], color='blue', alpha=0.7)
                        ax.scatter(orig_pts[:, 0], orig_pts[:, 1], color='red', alpha=0.7)
                        ax.set_xlabel("x")
                        ax.set_ylabel("y")

                if not capture_done:
                    if merge_viz_events:
                        orig_traj = raw_tokens[viz_sample, :, :2].detach().cpu()
                        recon_traj = reconstructed[viz_sample, :, :2].detach().cpu()

                        for k, event in enumerate(merge_viz_events):
                            fig = plot_merge_event_full(
                                tokens_2d=event["tokens_2d"],
                                prob_matrix=event["prob_matrix"],
                                chosen_i=event["chosen_i"],
                                chosen_j=event["chosen_j"],
                                orig_traj=orig_traj,
                                recon_traj=recon_traj,
                                merged_token_2d=event["merged_token_2d"],
                                highlight_recon_idx=(event["local_t1_idx"], event["local_t2_idx"]),
                                title_matrix=f"Softmax merge probs (step {k+1})",
                                title_positions=f"Token positions (step {k+1})"
                            )
                            merge_figs.append(fig)
                            # Unregister from pyplot; the Figure object stays
                            # usable for logging below.
                            plt.close(fig)

                    capture_done = True

                loss = criterion(reconstructed, raw_tokens)
                loss.backward()
                optimiser.step()

                total_loss += loss.item()
                n_batches += 1

            train_loss = total_loss / n_batches if n_batches else float("nan")

            val_loss = evaluate(model, val_loader, device, config.sampling_policy, current_depth, criterion, config.temperature)

            if checkpoint_path and val_loss < best_val_loss:
                best_val_loss = val_loss
                checkpoint_dir = os.path.dirname(checkpoint_path)
                if checkpoint_dir:   # a bare filename has no directory to create
                    os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(model.state_dict(), checkpoint_path)

            log_dict = {
                "train_loss": train_loss,
                "val_loss": val_loss,
                "epoch": epoch,
                "merge_depth": current_depth,
            }
            # Figures are optional: merge events exist only for the softmax
            # policy, and there may be fewer than `max_viz_events` of them.
            for step, fig in enumerate(merge_figs, start=1):
                log_dict[f"merge_event_step_{step}"] = wandb.Image(fig)
            if recon_fig is not None:
                log_dict["recon_fig"] = wandb.Image(recon_fig)
                plt.close(recon_fig)

            wandb.log(log_dict)
    finally:
        wandb.finish()

# Run

In [None]:
# Model size and merge budget.
token_dim = 2
latent_dim = 2
hidden_dim = 128
merges = 1

# Optimisation / sampling hyperparameters. They are named once here so the
# run name (and therefore the checkpoint file name) always matches the values
# that are actually used for training.
learning_rate = 1e-4
sampling_policy = "softmax"
temperature = 0.5

config = SimpleNamespace(
    epochs=10,
    learning_rate=learning_rate,
    sampling_policy=sampling_policy,
    batch_size=batch_size,  # same value the data loaders were built with
    device=device,
    model_name=f"{sampling_policy}-{temperature}-lr-{learning_rate}-merges-{merges}",
    max_depth=merges,
    token_dim=token_dim,
    hidden_dim=hidden_dim,
    temperature=temperature,
)

model = AE_no_lift(
    token_dim=token_dim,
    hidden_dim=hidden_dim,
    latent_dim=latent_dim
    ).to(device)

train(model, train_loader, val_loader, config, device, checkpoint_path=f"./checkpoints/{config.model_name}.pth", max_depth=merges)