In [1]:
import numpy as np
import wandb
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset
import io
import os
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from util import TokenMergeBuffer, TokenUnmergeBuffer, EarlyStopper
from models import AE_no_lift
from toy_data.environments import RandomWalker

from types import SimpleNamespace

%load_ext autoreload
%autoreload 2

# Dataset

In [2]:
# Seed every RNG the data pipeline touches so the split (and, presumably, the
# generated walks) are reproducible across Restart & Run All.
# NOTE(review): assumes RandomWalker draws from numpy's / torch's global RNG — confirm
# in toy_data.environments; if it takes its own seed argument, pass SEED there too.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 128
train_pct = 0.8

walker = RandomWalker(
    n=1000,
    length=10,
    dim=2,
    step_scale=0.1,
    init="normal",
    drift="centre",
).generate()

# presumably (n, length, dim) = (1000, 10, 2) — TODO confirm against RandomWalker
dataset = torch.tensor(walker.data, dtype=torch.float32)

train_size = int(train_pct * len(dataset))
val_size = len(dataset) - train_size

# Dedicated seeded generator: the split stays fixed even if other cells consume RNG.
train_ds, val_ds = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)

# Reshuffle training batches every epoch; validation order does not affect the loss.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

# Train

In [3]:
# def plot_merge_event(tokens_2d, prob_matrix, chosen_i, chosen_j, orig_traj, recon_traj, merged_token_2d=None,
#                      title_positions="Token positions",
#                      title_matrix="Softmax merge probs",
#                      title_recon="Original (blue), Recon (red)"):

#     # ---------------- existing code ----------------
#     if hasattr(tokens_2d, "cpu"):
#         tokens_2d = tokens_2d.detach().cpu().numpy()
#     if hasattr(prob_matrix, "cpu"):
#         prob_matrix = prob_matrix.detach().cpu().numpy()

#     N = tokens_2d.shape[0]
#     fig, axes = plt.subplots(1, 2, figsize=(10,4))


#     # --------------------------------------------------------
#     # 2 — Softmax matrix
#     # --------------------------------------------------------
#     ax = axes[0]
#     im = ax.imshow(prob_matrix, interpolation="nearest")
#     fig.colorbar(im, ax=ax, fraction=0.046)
#     ax.set_title(title_matrix)
#     ax.set_xlabel("Token index")
#     ax.set_ylabel("Token index")

#     ax.set_xticks(range(N))
#     ax.set_yticks(range(N))

#     ax.scatter([chosen_j], [chosen_i], s=250, facecolor="none",
#                edgecolor="red", linewidth=2)
#     ax.scatter([chosen_i], [chosen_j], s=250, facecolor="none",
#                edgecolor="red", linewidth=2)

#     ax.set_xticks(np.arange(-.5, N, 1), minor=True)
#     ax.set_yticks(np.arange(-.5, N, 1), minor=True)
#     ax.grid(which="minor", color="black", linewidth=0.2)

#         # ================= LEFT PLOT =====================
#     ax = axes[1]
#     ax.scatter(tokens_2d[:,0], tokens_2d[:,1], s=40)

#     for idx, (x,y) in enumerate(tokens_2d):
#         ax.text(x+0.02, y+0.02, str(idx))

#     # highlight original merged tokens
#     ax.scatter(
#         tokens_2d[[chosen_i, chosen_j],0],
#         tokens_2d[[chosen_i, chosen_j],1],
#         s=200, edgecolor="black", facecolor="none", linewidth=2,
#         label="Original merge pair"
#     )

#     # --- NEW: plot the merged token ---
#     if merged_token_2d is not None:
#         if hasattr(merged_token_2d, "cpu"):
#             merged_token_2d = merged_token_2d.detach().cpu().numpy()

#         ax.scatter(
#             [merged_token_2d[0]], [merged_token_2d[1]],
#             s=220, marker="*", color="red", linewidth=2,
#             label="Merged token"
#         )

#     ax.set_title(f"{title_positions}\nMerged: ({chosen_i}, {chosen_j})")
#     ax.set_aspect("equal")
#     ax.set_xlabel("dim 1")
#     ax.set_ylabel("dim 2")
#     # ax.legend()
#     # # --------------------------------------------------------
#     # # 1 — Token positions
#     # # --------------------------------------------------------
#     # ax = axes[1]
#     # ax.scatter(tokens_2d[:,0], tokens_2d[:,1], s=40)
#     # for idx, (x,y) in enumerate(tokens_2d):
#     #     ax.text(x+0.02, y+0.02, str(idx))

#     # ax.scatter(tokens_2d[[chosen_i, chosen_j],0],
#     #            tokens_2d[[chosen_i, chosen_j],1],
#     #            s=200, edgecolor="black", facecolor="none", linewidth=2)

#     # ax.set_title(f"{title_positions}\nMerged: ({chosen_i}, {chosen_j})")
#     # ax.set_xlabel("dim 1")
#     # ax.set_ylabel("dim 2")
#     # ax.set_aspect("equal")

#     # --------------------------------------------------------
#     # 3 — Original vs Reconstruction trajectory
#     # --------------------------------------------------------
#     ax = axes[2]

#     # draw lines
#     for (x0, y0), (x1, y1) in zip(orig_traj[:-1], orig_traj[1:]):
#         ax.plot([x0, x1], [y0, y1], color="blue", linewidth=1)

#     for (x0, y0), (x1, y1) in zip(recon_traj[:-1], recon_traj[1:]):
#         ax.plot([x0, x1], [y0, y1], color="red", linewidth=1)

#     ax.set_aspect("equal")
#     ax.set_title(title_recon)
#     ax.set_xlabel("x")
#     ax.set_ylabel("y")

#     plt.tight_layout()
#     return fig


In [4]:
def plot_merge_event_full(tokens_2d, prob_matrix, chosen_i, chosen_j,
                     orig_traj, recon_traj, merged_token_2d=None,
                     title_positions="Token positions",
                     title_matrix="Softmax merge probs",
                     title_recon="Original (blue), Recon (red)",
                     highlight_recon_idx=None):
    """Three-panel figure describing one merge event for a single sample.

    Panels:
      1. the pair-probability matrix (diagonal masked out) with the chosen pair circled,
      2. the active-token coordinates with the chosen pair circled and, optionally,
         the merged token drawn as a red star,
      3. the original (blue) and reconstructed (red) trajectories, optionally with
         two reconstructed points starred.

    Args:
        tokens_2d: (N, 2) array or tensor of active-token coordinates.
        prob_matrix: (N, N) array or tensor. A float copy is plotted; the input is
            never modified.
        chosen_i, chosen_j: indices of the merged pair.
        orig_traj, recon_traj: (T, 2) arrays, tensors or nested sequences.
        merged_token_2d: optional length-2 coordinates of the merged token.
        title_positions, title_matrix, title_recon: panel titles.
        highlight_recon_idx: optional (i, j) row indices into ``recon_traj`` to mark.

    Returns:
        The matplotlib Figure (left open; the caller is responsible for closing it).
    """
    def _to_numpy(x):
        # Tensors may live on the GPU or require grad; detach before converting.
        if hasattr(x, "cpu"):
            x = x.detach().cpu().numpy()
        return np.asarray(x)

    tokens_2d = _to_numpy(tokens_2d)
    orig_traj = _to_numpy(orig_traj)
    recon_traj = _to_numpy(recon_traj)
    # Private float copy: the diagonal is overwritten with NaN below, and
    # .numpy() of a CPU tensor shares memory with the caller's tensor.
    prob_matrix = np.array(_to_numpy(prob_matrix), dtype=float, copy=True)

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    N = tokens_2d.shape[0]

    # ========================================================
    # 1 — Softmax matrix
    # ========================================================
    np.fill_diagonal(prob_matrix, np.nan)  # self-pairs are never candidates
    ax = axes[0]
    im = ax.imshow(prob_matrix, interpolation="nearest")
    fig.colorbar(im, ax=ax, fraction=0.046)
    ax.set_title(title_matrix)
    ax.set_xlabel("Token index")
    ax.set_ylabel("Token index")

    ax.set_xticks(range(N))
    ax.set_yticks(range(N))

    # the matrix is symmetric, so circle both (i, j) and (j, i)
    ax.scatter([chosen_j], [chosen_i], s=250, facecolor="none",
               edgecolor="red", linewidth=2)
    ax.scatter([chosen_i], [chosen_j], s=250, facecolor="none",
               edgecolor="red", linewidth=2)

    # cell boundaries sit at half-integers
    ax.set_xticks(np.arange(-.5, N, 1), minor=True)
    ax.set_yticks(np.arange(-.5, N, 1), minor=True)
    ax.grid(which="minor", color="black", linewidth=0.2)

    # ========================================================
    # 2 — Token positions + merged token
    # ========================================================
    ax = axes[1]
    ax.scatter(tokens_2d[:, 0], tokens_2d[:, 1], s=40)

    for idx, (x, y) in enumerate(tokens_2d):
        ax.text(x + 0.02, y + 0.02, str(idx))

    ax.scatter(
        tokens_2d[[chosen_i, chosen_j], 0],
        tokens_2d[[chosen_i, chosen_j], 1],
        s=200, edgecolor="black", facecolor="none", linewidth=2
    )

    if merged_token_2d is not None:
        merged_token_2d = _to_numpy(merged_token_2d)
        ax.scatter(
            [merged_token_2d[0]], [merged_token_2d[1]],
            s=220, marker="*", color="red", linewidth=2,
        )

    ax.set_title(f"{title_positions}\nMerged: ({chosen_i}, {chosen_j})")
    ax.set_aspect("equal")
    ax.set_xlabel("dim 1")
    ax.set_ylabel("dim 2")

    # ========================================================
    # 3 — Original vs Reconstruction trajectory
    # ========================================================
    ax = axes[2]

    # one polyline per trajectory instead of one Line2D per segment
    ax.plot(orig_traj[:, 0], orig_traj[:, 1], color="blue", linewidth=1)
    ax.plot(recon_traj[:, 0], recon_traj[:, 1], color="red", linewidth=1)

    if highlight_recon_idx is not None:
        for k in highlight_recon_idx:
            ax.scatter([recon_traj[k, 0]], [recon_traj[k, 1]],
                       s=50, marker="*", color="red", edgecolor="black", linewidth=1)

    ax.set_aspect("equal")
    ax.set_title(title_recon)
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    plt.tight_layout()
    return fig


import numpy as np
import matplotlib.pyplot as plt

def plot_two_merge_event(
    tokens_2d_step1,          # (L1, 2) coords BEFORE merge #1 (currently not plotted)
    prob_matrix1,             # (L1, L1) softmax AFTER masking/softmax
    chosen1,                  # (i1, j1)
    tokens_2d_step2,          # (L2, 2) coords plotted in panel 3
    prob_matrix2,             # (L2, L2) softmax at step 2
    chosen2,                  # (i2, j2)
    figsize=(12, 4)
):
    """Visualise two consecutive merge steps.

    Panels:
      1. merge #1 probability matrix with the chosen pair circled,
      2. merge #2 probability matrix with the chosen pair circled,
      3. a scatter of ``tokens_2d_step2``.

    All array arguments may be numpy arrays or torch tensors. The probability
    matrices are copied before their diagonal is masked, so inputs are never
    modified. ``tokens_2d_step1`` is accepted for interface symmetry but unused.

    NOTE(review): the parameter comment originally said ``tokens_2d_step2`` holds
    coordinates *before* merge #2, while panel 3 is titled "Tokens After Merge #2".
    Confirm which one callers actually pass.

    Returns:
        The matplotlib Figure (left open).
    """
    def _to_numpy(x):
        if hasattr(x, "cpu"):
            x = x.detach().cpu().numpy()
        return np.asarray(x)

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    def render_prob(ax, P, chosen, title):
        # float copy: works for tensors / int arrays and leaves the caller's data intact
        P = np.array(_to_numpy(P), dtype=float, copy=True)
        np.fill_diagonal(P, np.nan)              # mask self-pairs

        im = ax.imshow(P, interpolation='nearest', cmap='viridis')
        ax.set_title(title)

        # matrix is symmetric: circle both (i, j) and (j, i)
        i, j = chosen
        ax.scatter([j], [i], s=80, edgecolor="red", facecolor="none", linewidth=1.5)
        ax.scatter([i], [j], s=80, edgecolor="red", facecolor="none", linewidth=1.5)

        # cell boundaries at half-integers, matching the other matrix plots
        ax.set_xticks(np.arange(-.5, P.shape[1], 1), minor=True)
        ax.set_yticks(np.arange(-.5, P.shape[0], 1), minor=True)
        ax.grid(which="minor", color="w", linewidth=0.4)
        ax.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)

        fig.colorbar(im, ax=ax, fraction=0.10, pad=0.03)

    render_prob(axes[0], prob_matrix1, chosen1, "Merge #1 Softmax")
    render_prob(axes[1], prob_matrix2, chosen2, "Merge #2 Softmax")

    ax = axes[2]
    tokens = _to_numpy(tokens_2d_step2)
    ax.scatter(tokens[:, 0], tokens[:, 1], s=40, color="black")
    # To star the newly created token (if the caller appends it last):
    # ax.scatter([tokens[-1, 0]], [tokens[-1, 1]], s=200, marker="*", color="red", edgecolor="black")

    ax.set_title("Tokens After Merge #2")
    ax.set_aspect('equal', 'box')
    ax.grid(True, alpha=0.2)
    ax.tick_params(labelbottom=False, labelleft=False)

    plt.tight_layout()
    return fig


In [5]:
def get_merge_depth(curr_epoch, total_epochs, max_depth, power=2.0):
    """Curriculum schedule: how many merges to perform at a given epoch.

    ``progress = curr_epoch / total_epochs`` is clamped to [0, 1] and mapped
    through ``progress ** (1 / power)``. With ``power > 1`` this rises quickly
    early and flattens later. The result is floored onto integer depths and
    clipped to ``[1, max_depth]``, so ``max_depth`` itself is reached only when
    progress is 1 (the final epoch).

    Args:
        curr_epoch: current epoch (the training loop passes 1-based values).
        total_epochs: total number of epochs; must be positive.
        max_depth: largest depth to return.
        power: curve shape; must be positive.

    Returns:
        An int in ``[1, max_depth]`` (or ``max_depth`` itself if it is < 1).

    Raises:
        ValueError: if ``total_epochs`` or ``power`` is not positive.
    """
    if total_epochs <= 0:
        raise ValueError(f"total_epochs must be positive, got {total_epochs}")
    if power <= 0:
        raise ValueError(f"power must be positive, got {power}")
    # Clamp: a negative base with a fractional exponent would give a complex number.
    progress = min(max(curr_epoch / total_epochs, 0.0), 1.0)
    scaled = progress ** (1.0 / power)     # fast rise for power > 1
    depth = int(scaled * max_depth)        # quantize into buckets
    return min(max(depth, 1), max_depth)

def plot_token_matrix(matrix, title="Token Merge Probabilities"):
    """Heat-map of an (N, N) token-pair matrix with a per-cell grid.

    Args:
        matrix: (N, N) numpy array or torch tensor. Tensors are detached first,
            so tensors that require grad (e.g. built from softmax outputs) work.
        title: axes title.

    Returns:
        The matplotlib Figure (left open).
    """
    if hasattr(matrix, "cpu"):
        # .numpy() refuses tensors that require grad, hence the detach
        matrix = matrix.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(5, 4))
    im = ax.imshow(matrix, interpolation="nearest")
    ax.set_title(title)
    ax.set_xlabel("Token index")
    ax.set_ylabel("Token index")
    fig.colorbar(im, ax=ax)

    N = matrix.shape[0]
    ax.set_xticks(range(N))
    ax.set_yticks(range(N))

    # minor ticks on cell boundaries draw the grid between cells
    ax.set_xticks(np.arange(-.5, N, 1), minor=True)
    ax.set_yticks(np.arange(-.5, N, 1), minor=True)
    ax.grid(which="minor", color="black", linestyle="-", linewidth=0.2)

    plt.tight_layout()
    return fig


def plot_token_positions(tokens_2d, chosen_i, chosen_j, title="2D Token Embeddings"):
    """Scatter (N, 2) token coordinates with index labels and circle a chosen pair.

    Args:
        tokens_2d: (N, 2) numpy array or torch tensor.
        chosen_i, chosen_j: indices of the pair to highlight.
        title: axes title.

    Returns:
        The matplotlib Figure (left open).
    """
    if hasattr(tokens_2d, "cpu"):
        tokens_2d = tokens_2d.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(5, 4))

    ax.scatter(tokens_2d[:, 0], tokens_2d[:, 1])

    for idx, (x, y) in enumerate(tokens_2d):
        ax.text(x + 0.02, y + 0.02, str(idx))

    # Hollow rings, as in the other merge plots; a filled marker would hide the points.
    ax.scatter(tokens_2d[[chosen_i, chosen_j], 0],
               tokens_2d[[chosen_i, chosen_j], 1],
               s=150, edgecolor="black", facecolor="none", linewidth=2)

    ax.set_title(title)
    ax.set_xlabel("dim 1")
    ax.set_ylabel("dim 2")
    ax.set_aspect("equal")

    plt.tight_layout()
    return fig


def visualise(orig, recon, show=False):
    """Overlay one original trajectory (blue, width 2) and its reconstruction (red).

    Args:
        orig, recon: (T, 2) trajectories as numpy arrays or torch tensors.
        show: if True, call ``plt.show()`` and return None. Otherwise close the
            figure (so it is not auto-displayed) and return it, e.g. for wandb.Image.
    """
    if isinstance(orig, torch.Tensor):
        orig = orig.detach().cpu().numpy()
    if isinstance(recon, torch.Tensor):
        recon = recon.detach().cpu().numpy()

    def plot_trajectory(traj, ax, color="blue", label=None, linewidth=1, linestyle="solid"):
        # Honour `linewidth` (the caller draws the original thicker) and let the
        # line itself carry the legend label instead of an empty proxy line.
        traj = np.asarray(traj)
        ax.plot(traj[:, 0], traj[:, 1], color=color, linewidth=linewidth,
                linestyle=linestyle, label=label if label else "_nolegend_")
        ax.set_aspect("equal")
        return ax

    fig, ax = plt.subplots(figsize=(6, 6))
    plot_trajectory(orig, ax=ax, color="blue", label="Original", linewidth=2)
    plot_trajectory(recon, ax=ax, color="red", label="Reconstruction", linestyle="solid")

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.legend()
    ax.set_aspect("equal")

    if show:
        plt.show()
    else:
        plt.close(fig)
        return fig
    
def visualise_multiple(orig_list, recon_list, show=False):
    """One panel per sample overlaying original (blue) and reconstructed (red) trajectories.

    Args:
        orig_list, recon_list: equally long, non-empty sequences of (T, 2)
            trajectories (numpy arrays or torch tensors).
        show: if True, call ``plt.show()`` and return None. Otherwise close the
            figure (so it is not auto-displayed) and return it.

    Raises:
        ValueError: if the sequences are empty or differ in length.
    """
    if len(orig_list) != len(recon_list):
        raise ValueError(
            f"orig_list and recon_list differ in length: {len(orig_list)} vs {len(recon_list)}"
        )
    n = len(orig_list)
    if n == 0:
        raise ValueError("visualise_multiple needs at least one trajectory pair")

    def _to_numpy(traj):
        if isinstance(traj, torch.Tensor):
            traj = traj.detach().cpu().numpy()
        return np.asarray(traj)

    # squeeze=False always yields a 2-D array of Axes, so n == 1 needs no special case.
    fig, axes = plt.subplots(1, n, figsize=(6 * n, 6), squeeze=False)

    legend_lines = None
    for i, (orig, recon) in enumerate(zip(orig_list, recon_list)):
        orig, recon = _to_numpy(orig), _to_numpy(recon)
        ax = axes[0, i]

        # one polyline per trajectory; the returned lines double as legend handles
        line_orig, = ax.plot(orig[:, 0], orig[:, 1], color="blue", linewidth=1)
        line_recon, = ax.plot(recon[:, 0], recon[:, 1], color="red", linewidth=1)

        ax.set_aspect("equal")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_title(f"Sample {i+1}")

        if legend_lines is None:
            legend_lines = [line_orig, line_recon]

    fig.legend(legend_lines, ["Original", "Reconstruction"], loc="upper right")

    if show:
        plt.show()
    else:
        plt.close(fig)
        return fig


def visualise_merge_progress(orig, merge_snapshots, show=False):
    """Show merging step by step: the original points, then one panel per snapshot.

    Parameters:
    - orig: original trajectory, (N, 2) array or tensor
    - merge_snapshots: sequence of intermediate merged-token arrays/tensors, each (M, 2)
    - show: if True call plt.show() and return None; else close and return the figure
    """
    def as_array(x):
        return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x

    orig = as_array(orig)
    snapshots = [as_array(snap) for snap in merge_snapshots]

    n_panels = len(snapshots) + 1
    # squeeze=False keeps a (1, n_panels) grid even when there is a single panel
    fig, grid = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)
    panels = grid[0]

    first = panels[0]
    first.scatter(orig[:, 0], orig[:, 1], color='blue', label='Original', s=30)
    first.set_title("Original")
    first.set_aspect("equal")
    first.legend()

    for step, (panel, merged) in enumerate(zip(panels[1:], snapshots), start=1):
        panel.scatter(orig[:, 0], orig[:, 1], color='blue', alpha=0.2, label='Original', s=20)
        panel.scatter(merged[:, 0], merged[:, 1], color='red', label=f'Merged step {step}', s=50)
        panel.set_aspect("equal")
        panel.set_title(f"Merge step {step}")
        panel.legend()

    plt.tight_layout()
    if show:
        plt.show()
        return None
    plt.close(fig)
    return fig


# def plot_merge_event(tokens_2d, prob_matrix, chosen_i, chosen_j, 
#                      title_positions="Token positions", 
#                      title_matrix="Softmax merge probs"):
#     """
#     tokens_2d: (N,2) numpy or tensor
#     prob_matrix: (N,N) numpy or tensor
#     chosen_i, chosen_j: merged indices
#     """
#     # tensor → numpy
#     if hasattr(tokens_2d, "cpu"):
#         tokens_2d = tokens_2d.detach().cpu().numpy()
#     if hasattr(prob_matrix, "cpu"):
#         prob_matrix = prob_matrix.detach().cpu().numpy()

#     N = tokens_2d.shape[0]

#     fig, axes = plt.subplots(1, 2, figsize=(10,4))

#     # --------------------------------------------------------
#     # LEFT PLOT — Token positions + highlight merged pair
#     # --------------------------------------------------------
#     ax = axes[0]
#     ax.scatter(tokens_2d[:,0], tokens_2d[:,1], s=40)

#     for idx, (x,y) in enumerate(tokens_2d):
#         ax.text(x+0.02, y+0.02, str(idx))

#     # highlight merged tokens
#     ax.scatter(tokens_2d[[chosen_i, chosen_j],0],
#                tokens_2d[[chosen_i, chosen_j],1],
#                s=200, edgecolor="black", facecolor="none", linewidth=2)

#     ax.set_title(f"{title_positions}\nMerged: ({chosen_i}, {chosen_j})")
#     ax.set_aspect("equal")
#     ax.set_xlabel("dim 1")
#     ax.set_ylabel("dim 2")


#     # --------------------------------------------------------
#     # RIGHT PLOT — Softmax probability matrix + highlight pair
#     # --------------------------------------------------------
#     ax = axes[1]
#     im = ax.imshow(prob_matrix, interpolation="nearest")
#     fig.colorbar(im, ax=ax)

#     ax.set_title(title_matrix)
#     ax.set_xlabel("Token index")
#     ax.set_ylabel("Token index")

#     ax.set_xticks(range(N))
#     ax.set_yticks(range(N))

#     # highlight chosen merge in matrix
#     ax.scatter([chosen_j], [chosen_i], s=250, facecolor="none",
#                edgecolor="red", linewidth=2)
#     ax.scatter([chosen_i], [chosen_j], s=250, facecolor="none",
#                edgecolor="red", linewidth=2)

#     # minor grid
#     ax.set_xticks(np.arange(-.5, N, 1), minor=True)
#     ax.set_yticks(np.arange(-.5, N, 1), minor=True)
#     ax.grid(which="minor", color="black", linewidth=0.2)

#     plt.tight_layout()
#     return fig







# Sampling policies understood by _select_merge_pairs.
_MERGE_POLICIES = ("argmin", "softmax", "uniform")


def _select_merge_pairs(loss_per_pair, policy, temperature):
    """Choose one candidate pair per sample from per-pair reconstruction losses.

    Args:
        loss_per_pair: (B, num_pairs) tensor; lower means the pair reconstructs better.
        policy: "argmin" (greedy), "softmax" (sample from softmax(-loss / temperature))
            or "uniform" (ignore the losses).
        temperature: softmax temperature; only read for policy "softmax".

    Returns:
        (chosen, probs): chosen is a (B,) tensor of pair indices; probs is the
        (B, num_pairs) sampling distribution for "softmax" and None otherwise.

    Raises:
        ValueError: for an unknown policy.
    """
    n_batch, num_pairs = loss_per_pair.shape
    if policy == "argmin":
        return loss_per_pair.argmin(dim=1), None
    if policy == "softmax":
        probs = torch.softmax(-loss_per_pair / temperature, dim=1)
        return torch.multinomial(probs, 1).squeeze(1), probs
    if policy == "uniform":
        return torch.randint(0, num_pairs, (n_batch,), device=loss_per_pair.device), None
    raise ValueError(f"Unknown sampling policy: {policy}")


def _merge_to_depth(model, raw_tokens, depth, policy, temperature):
    """Merge one token pair per sample, repeatedly, while ``buf.can_merge(depth)`` holds.

    Each candidate pair (i, j) of the active tokens is scored by the MSE of
    ``model([t_i, t_j])`` against ``[t_i, t_j]``. ``policy`` picks one pair per sample,
    and ``model.encode`` of the chosen pair is handed to ``buf.merge_batch``.
    Scoring runs under no_grad because it only decides *which* pair to merge (an
    integer choice that no gradient can flow through). The encode of the chosen
    pair stays differentiable, so the training loss reaches it.

    Returns:
        (buf, last_step): the TokenMergeBuffer after merging, and a dict holding the
        tensors of the final merge step (None if no merge was performed).
    """
    buf = TokenMergeBuffer(raw_tokens)
    last_step = None
    while buf.can_merge(depth):
        active = buf.get_active_tokens()  # indexed as (B, N, D) below
        n_batch, n_active, _ = active.shape
        pair_idx = torch.combinations(torch.arange(n_active, device=active.device))

        with torch.no_grad():
            t1 = active[:, pair_idx[:, 0], :]
            t2 = active[:, pair_idx[:, 1], :]
            t1_hat, t2_hat = torch.chunk(model(torch.cat([t1, t2], dim=-1)), 2, dim=-1)
            loss_per_pair = (
                nn.functional.mse_loss(t1_hat, t1, reduction="none").mean(dim=-1)
                + nn.functional.mse_loss(t2_hat, t2, reduction="none").mean(dim=-1)
            )
            chosen, probs = _select_merge_pairs(loss_per_pair, policy, temperature)

        rows = torch.arange(n_batch, device=active.device)
        local_t1_idx = pair_idx[chosen, 0]
        local_t2_idx = pair_idx[chosen, 1]
        chosen_x = torch.cat([active[rows, local_t1_idx], active[rows, local_t2_idx]], dim=-1)
        merged_tokens = model.encode(chosen_x)
        buf.merge_batch(local_t1_idx, local_t2_idx, merged_tokens)

        last_step = {
            "active": active,
            "pair_idx": pair_idx,
            "chosen": chosen,
            "probs": probs,
            "local_t1_idx": local_t1_idx,
            "local_t2_idx": local_t2_idx,
            "merged_tokens": merged_tokens,
        }
    return buf, last_step


def _reconstruct_tokens(model, buf, n_original):
    """Undo the merges recorded in ``buf`` with ``model.decode``; return the rebuilt tokens."""
    buffer = buf.buffer
    merge_history = buf.get_merge_history()
    active_mask = buf.get_active_mask()
    unmerge_buf = TokenUnmergeBuffer(buffer=buffer,
                                     active_mask=active_mask,
                                     merges=merge_history,
                                     n_original=n_original)
    while not unmerge_buf.is_done():
        pred = model.decode(unmerge_buf.get_next_to_unmerge())
        t1_pred, t2_pred = torch.chunk(pred, 2, dim=-1)
        unmerge_buf.step_unmerge(t1_pred, t2_pred)
    return unmerge_buf.get_original_tokens()


def _log_merge_event(raw_tokens, reconstructed, step, sample=0):
    """Log a plot_merge_event_full figure for one sample's merge step to wandb.

    Requires ``step["probs"]`` (only the softmax policy provides it).
    NOTE(review): the starred reconstruction points use the pair's indices among the
    *active* tokens of that step; they line up with trajectory positions only when
    that step is the first merge (depth 1).
    """
    pair_idx = step["pair_idx"]
    n_active = step["active"].shape[1]
    probs = step["probs"][sample]

    # Symmetric (N, N) matrix of the pair probabilities, filled in one vectorised scatter.
    token_prob_matrix = torch.zeros(n_active, n_active, device=probs.device, dtype=probs.dtype)
    token_prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = probs
    token_prob_matrix[pair_idx[:, 1], pair_idx[:, 0]] = probs

    chosen_pair = pair_idx[step["chosen"][sample]]
    fig_merge = plot_merge_event_full(
        tokens_2d=step["active"][sample, :, :2],
        prob_matrix=token_prob_matrix,
        chosen_i=chosen_pair[0].item(),
        chosen_j=chosen_pair[1].item(),
        orig_traj=raw_tokens[sample, :, :2],
        recon_traj=reconstructed[sample, :, :2],
        merged_token_2d=step["merged_tokens"][sample, :2].detach().cpu(),
        highlight_recon_idx=(step["local_t1_idx"][sample].item(),
                             step["local_t2_idx"][sample].item()),
    )
    wandb.log({"merge_event": wandb.Image(fig_merge)})
    plt.close(fig_merge)


@torch.no_grad()
def evaluate(model, val_loader, device, policy, current_depth, criterion, temperature=1):
    """Validation reconstruction loss after merging to ``current_depth`` and unmerging.

    Runs the same merge/unmerge procedure as ``train`` without gradients and with
    the model in eval mode (it is left in eval mode). With the "softmax" or
    "uniform" policy, merge choices are random, so results vary between calls
    unless the RNG is seeded.

    Returns:
        The sample-weighted mean of ``criterion`` over all validation samples (a
        short final batch is not over-weighted), or NaN for an empty loader.

    Raises:
        ValueError: for an unknown ``policy``.
    """
    if policy not in _MERGE_POLICIES:
        raise ValueError(f"Unknown policy {policy}")

    model.eval()
    total_loss, total_samples = 0.0, 0

    for batch in val_loader:
        # Accept bare-tensor and (tensor, ...) batches, like train().
        raw_tokens = (batch[0] if isinstance(batch, (list, tuple)) else batch).to(device)
        buf, _ = _merge_to_depth(model, raw_tokens, current_depth, policy, temperature)
        reconstructed = _reconstruct_tokens(model, buf, raw_tokens.shape[1])

        n = raw_tokens.shape[0]
        # Weight by batch size; assumes criterion averages over the batch (nn.MSELoss default).
        total_loss += criterion(reconstructed, raw_tokens).item() * n
        total_samples += n

    return total_loss / total_samples if total_samples else float("nan")


def train(
        model,
        train_loader,
        val_loader,
        config,
        device,
        checkpoint_path,
        max_depth,
        ):
    """Train the merge/unmerge autoencoder, logging metrics and figures to wandb.

    Each epoch, the merge depth follows ``get_merge_depth`` (driven by
    ``config.max_depth``). Every batch is merged to that depth, reconstructed, and
    trained with ``nn.MSELoss`` against the raw tokens. Once per epoch (softmax
    policy only), a merge-event figure is logged for sample 0 of the first batch.
    Whenever the validation loss improves, the state dict is saved to
    ``checkpoint_path`` (if given). A reconstruction figure of up to 5 first-batch
    samples is also logged as "reconstruction_plot". The wandb run is always
    finished, even if training raises.

    Args:
        model: model exposing ``__call__``, ``encode`` and ``decode``.
        train_loader, val_loader: loaders yielding token tensors or (tensor, ...) batches.
        config: namespace with epochs, learning_rate, sampling_policy, temperature,
            max_depth and model_name.
        device: torch device for model and data.
        checkpoint_path: where to save the best state dict; falsy disables saving.
        max_depth: unused; the schedule reads ``config.max_depth`` (kept for callers).

    Raises:
        ValueError: for an unknown ``config.sampling_policy`` (before wandb.init).
    """
    policy = config.sampling_policy
    if policy not in _MERGE_POLICIES:
        raise ValueError(f"Unknown sampling policy: {policy}")

    model.to(device)
    optimiser = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.MSELoss()
    # Early stopping (util.EarlyStopper(patience=10, min_delta=1e-4)) and a
    # ReduceLROnPlateau scheduler are currently disabled.
    best_val_loss = float('inf')

    wandb.init(
        project="random-walker",
        config=config,
        name=config.model_name
    )
    try:
        for epoch in range(config.epochs):
            model.train()
            # The curriculum depth depends only on the epoch, so compute it once.
            current_depth = get_merge_depth(epoch + 1, config.epochs, config.max_depth)
            total_loss, total_samples = 0.0, 0
            orig_samples = recon_samples = None
            capture_done = False

            for batch in train_loader:
                optimiser.zero_grad()
                raw_tokens = (batch[0] if isinstance(batch, (list, tuple)) else batch).to(device)

                buf, last_step = _merge_to_depth(
                    model, raw_tokens, current_depth, policy, config.temperature)
                reconstructed = _reconstruct_tokens(model, buf, raw_tokens.shape[1])

                if orig_samples is None:
                    # First batch of the epoch feeds the reconstruction figure.
                    orig_samples, recon_samples = raw_tokens, reconstructed.detach()

                # Once per epoch: plot the last merge step of sample 0. Only the softmax
                # policy produces a probability matrix to show.
                if not capture_done and last_step is not None and last_step["probs"] is not None:
                    _log_merge_event(raw_tokens, reconstructed, last_step)
                    capture_done = True

                loss = criterion(reconstructed, raw_tokens)
                loss.backward()
                optimiser.step()

                n = raw_tokens.shape[0]
                total_loss += loss.item() * n  # sample-weighted, like evaluate()
                total_samples += n

            train_loss = total_loss / total_samples if total_samples else float("nan")
            val_loss = evaluate(model, val_loader, device, policy, current_depth,
                                criterion, config.temperature)

            log_dict = {
                "train_loss": train_loss,
                "val_loss": val_loss,
                "epoch": epoch,
                "merge_depth": current_depth,
            }

            if checkpoint_path and val_loss < best_val_loss:
                best_val_loss = val_loss
                checkpoint_dir = os.path.dirname(checkpoint_path)
                if checkpoint_dir:  # os.makedirs("") raises for a bare filename
                    os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(model.state_dict(), checkpoint_path)

                if orig_samples is not None:
                    k = min(5, orig_samples.shape[0])  # last/small batches may hold < 5
                    fig = visualise_multiple(list(orig_samples[:k]), list(recon_samples[:k]))
                    log_dict["reconstruction_plot"] = wandb.Image(fig)

            wandb.log(log_dict)
    finally:
        wandb.finish()

# Run

In [6]:
# Run configuration: every tunable lives here, and the run name is derived from
# these values so it cannot drift out of sync with the config.
token_dim = 2
hidden_dim = 128
merges = 1
learning_rate = 1e-4
sampling_policy = "softmax"
temperature = 0.5

config = SimpleNamespace(
    epochs=100,
    learning_rate=learning_rate,
    sampling_policy=sampling_policy,
    batch_size=batch_size,  # the value the DataLoaders in the dataset cell were built with
    device=device,
    model_name=f"sanity-{sampling_policy}-{temperature}-lr-{learning_rate}-merges-{merges}",
    max_depth=merges,
    token_dim=token_dim,
    hidden_dim=hidden_dim,
    temperature=temperature,
)

model = AE_no_lift(
    token_dim=token_dim,
    hidden_dim=hidden_dim
    ).to(device)

train(model, train_loader, val_loader, config, device,
      checkpoint_path=f"./checkpoints/{config.model_name}.pth", max_depth=merges)

[34m[1mwandb[0m: Currently logged in as: [33mfhahn[0m ([33mfabianhahn[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇█
merge_depth,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▆▄▄▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▇▇▆▅▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
merge_depth,1.0
train_loss,0.00174
val_loss,0.00203
