# RandomWalker (length 10) with one merge: sanity checks

In [1]:
import numpy as np
import wandb
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset
import io
import os
import matplotlib.pyplot as plt

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from util import TokenMergeBuffer, TokenUnmergeBuffer, EarlyStopper
from models import AE
from toy_data.environments import RandomWalker

from types import SimpleNamespace

%load_ext autoreload
%autoreload 2

## Dataset

In [2]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 128

# Generate the random-walk data. The keyword arguments go straight to RandomWalker
# (imported from toy_data.environments); presumably n is the number of trajectories,
# length the number of points per trajectory and dim the coordinate dimension —
# TODO confirm against the RandomWalker implementation.
walker = RandomWalker(
    n=1000, 
    length=10, 
    dim=2, 
    step_scale=0.1, 
    init="normal", 
    drift="centre").generate()

# Convert the generated data to one float32 tensor.
dataset = torch.tensor(walker.data, dtype=torch.float32)

# 80 / 20 train / validation split; the validation set takes the remainder so the
# two sizes always add up to len(dataset).
train_pct = 0.8
train_size = int(train_pct * len(dataset))
val_size = len(dataset) - train_size

# random_split is called without a generator, so the split is not seeded here.
train_ds, val_ds = random_split(dataset, [train_size, val_size])

# NOTE(review): shuffle=False on the training loader means every epoch sees the same
# batches in the same order — confirm this is intended for the sanity check.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

In [3]:
# Pull the first batch from the training loader and display its shape.
batch = next(iter(train_loader))
batch.shape

torch.Size([128, 10, 2])

## Train Util

In [20]:
def get_merge_depth(curr_epoch, total_epochs, max_depth):
    """
    Curriculum schedule for the merge depth.

    Scales max_depth by the fraction of training completed
    (curr_epoch / total_epochs), truncates the result to an integer and
    never returns less than 1.
    """
    scheduled = int((curr_epoch / total_epochs) * max_depth)

    # Early in training the truncated value is 0; one merge is the minimum.
    if scheduled <= 0:
        return 1

    return scheduled


def visualise(orig, recon, show=False):
    """
    Plot an original trajectory and its reconstruction on the same axes.

    Parameters:
    - orig: original trajectory, a sequence of (x, y) points (numpy array or torch tensor)
    - recon: reconstructed trajectory, same layout as orig
    - show: if True the figure is displayed with plt.show() and nothing is returned;
            otherwise the figure is closed and returned

    Returns:
    - the matplotlib figure when show is False, otherwise None
    """
    if isinstance(orig, torch.Tensor):
        orig = orig.detach().cpu().numpy()
    if isinstance(recon, torch.Tensor):
        recon = recon.detach().cpu().numpy()

    def plot_trajectory(traj, ax, color="blue", label=None, linewidth=1, linestyle="solid"):
        # Draw the trajectory segment by segment, honouring the requested line width
        # (it used to be fixed to 1 regardless of the argument).
        for (x0, y0), (x1, y1) in zip(traj[:-1], traj[1:]):
            ax.plot([x0, x1], [y0, y1], color=color, linewidth=linewidth, linestyle=linestyle)

        if label:
            # Empty proxy line: gives the legend exactly one entry per trajectory,
            # drawn with the same width and style as the segments.
            ax.plot([], [], color=color, label=label, linewidth=linewidth, linestyle=linestyle)

        return ax

    fig, ax = plt.subplots(figsize=(6, 6))
    # The original is drawn thicker so it stays visible underneath the reconstruction.
    plot_trajectory(orig, ax=ax, color="blue", label="Original", linewidth=2)
    plot_trajectory(recon, ax=ax, color="red", label="Reconstruction", linestyle="solid")

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.legend()

    # Equal scaling on both axes so distances are not distorted.
    ax.set_aspect("equal")

    if show:
        plt.show()
    else:
        # Close the figure so it is not rendered by the notebook; the caller gets
        # the figure object instead.
        plt.close(fig)
        return fig
    
def visualise_multiple(orig_list, recon_list, show=False):
    """
    Draw several original / reconstruction pairs side by side, one subplot per pair.

    orig_list, recon_list: lists of trajectories (numpy arrays or torch tensors),
    paired by position; the number of subplots is len(orig_list).

    With show=True the figure is displayed and nothing is returned; otherwise the
    figure is closed and returned.
    """
    def as_numpy(traj):
        # Tensors are detached and moved to the CPU; anything else passes through.
        return traj.detach().cpu().numpy() if isinstance(traj, torch.Tensor) else traj

    def draw(panel, traj, colour):
        # The empty line is created first and later serves as the legend handle.
        handle, = panel.plot([], [], color=colour, linewidth=1)
        for (x0, y0), (x1, y1) in zip(traj[:-1], traj[1:]):
            panel.plot([x0, x1], [y0, y1], color=colour, linewidth=1)
        return handle

    count = len(orig_list)
    # squeeze=False always yields a 2-D axes array, so a single subplot needs no
    # special casing.
    fig, grid = plt.subplots(1, count, figsize=(6 * count, 6), squeeze=False)
    panels = grid[0]

    for idx, (orig, recon) in enumerate(zip(orig_list, recon_list)):
        orig = as_numpy(orig)
        recon = as_numpy(recon)
        panel = panels[idx]

        orig_handle = draw(panel, orig, "blue")
        recon_handle = draw(panel, recon, "red")

        panel.set_aspect("equal")
        panel.set_xlabel("x")
        panel.set_ylabel("y")
        panel.set_title(f"Sample {idx+1}")

        # One shared legend for the whole figure, built from the first subplot.
        if idx == 0:
            legend_lines = [orig_handle, recon_handle]

    fig.legend(legend_lines, ["Original", "Reconstruction"], loc="upper right")

    if not show:
        plt.close(fig)
        return fig
    plt.show()


def visualise_merge_progress(orig, merge_snapshots, show=False):
    """
    Show the token set after each recorded merge step next to the original.

    Parameters:
    - orig: original trajectory (N, 2), numpy array or torch tensor
    - merge_snapshots: list of token sets (arrays or tensors of shape M,2), one per step
    - show: if True call plt.show() and return nothing; otherwise close the
            figure and return it

    The figure has one subplot for the original plus one per snapshot.
    """
    if isinstance(orig, torch.Tensor):
        orig = orig.detach().cpu().numpy()

    snapshots = []
    for snap in merge_snapshots:
        if isinstance(snap, torch.Tensor):
            snap = snap.detach().cpu().numpy()
        snapshots.append(snap)

    n_panels = len(snapshots) + 1
    # squeeze=False keeps the axes 2-D even for a single subplot, so indexing is
    # uniform for any number of snapshots.
    fig, grid = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)
    panels = grid[0]

    # First subplot: the untouched original points.
    first = panels[0]
    first.scatter(orig[:,0], orig[:,1], color='blue', label='Original', s=30)
    first.set_title("Original")
    first.set_aspect("equal")
    first.legend()

    # Remaining subplots: faded original as backdrop, snapshot tokens on top.
    for step, (panel, merged) in enumerate(zip(panels[1:], snapshots), start=1):
        panel.scatter(orig[:,0], orig[:,1], color='blue', alpha=0.2, label='Original', s=20)
        panel.scatter(merged[:,0], merged[:,1], color='red', label=f'Merged step {step}', s=50)
        panel.set_aspect("equal")
        panel.set_title(f"Merge step {step}")
        panel.legend()

    plt.tight_layout()

    if not show:
        plt.close(fig)
        return fig
    plt.show()


@torch.no_grad()
def evaluate(model, val_loader, device, policy, current_depth, criterion, temperature=1):
    """
    Compute the average reconstruction loss over val_loader.

    Each batch goes through a merge / unmerge round trip: while
    TokenMergeBuffer.can_merge(current_depth) holds, one token pair per sample is
    selected and replaced by model.encode of the pair; afterwards the recorded
    merges are undone with model.decode through a TokenUnmergeBuffer, and
    criterion compares the reconstructed tokens with the input tokens.

    Parameters:
    - model: provides lift, unlift, encode, decode and a forward pass
    - val_loader: iterable of batches; a batch is a tensor or a list/tuple whose
                  first element is the tensor
    - device: device the batches are moved to
    - policy: pair selection rule: "argmin" (lowest pair reconstruction error),
              "softmax" (sampled with probability softmax(-error / temperature))
              or "uniform" (sampled uniformly)
    - current_depth: value passed to TokenMergeBuffer.can_merge
    - criterion: loss called as criterion(reconstructed, raw_tokens)
    - temperature: softmax temperature, only used by the "softmax" policy

    Returns:
    - sum of the per-batch losses divided by len(val_loader)

    Raises:
    - ValueError: if policy is not one of the names above

    The model is switched to eval mode and left in that mode.
    """
    model.eval()
    total_loss = 0.0

    # Element-wise MSE, only used to score candidate pairs; built once, not per step.
    pair_mse = nn.MSELoss(reduction="none")

    for batch in val_loader:
        # Accept bare tensors as well as (tensor, ...) batches, as train() does.
        raw_tokens = batch[0].to(device) if isinstance(batch, (list, tuple)) else batch.to(device)
        buf = TokenMergeBuffer(raw_tokens)

        # ---- merge phase -------------------------------------------------
        while buf.can_merge(current_depth):
            active = buf.get_active_tokens()
            B, N, _ = active.shape

            # All unordered index pairs (i < j) over the N active tokens.
            pair_idx = torch.combinations(torch.arange(N, device=device))
            num_pairs = pair_idx.shape[0]

            # Gradients are disabled by the decorator, so no detach is needed.
            t1 = active[:, pair_idx[:, 0], :]
            t2 = active[:, pair_idx[:, 1], :]

            # Score every pair by how well the autoencoder reconstructs it.
            x = torch.cat([model.lift(t1), model.lift(t2)], dim=-1)
            x_hat = model(x)

            t1_hat_lifted, t2_hat_lifted = torch.chunk(x_hat, 2, dim=-1)
            t1_hat = model.unlift(t1_hat_lifted)
            t2_hat = model.unlift(t2_hat_lifted)

            recon_t1 = pair_mse(t1, t1_hat).mean(dim=-1)
            recon_t2 = pair_mse(t2, t2_hat).mean(dim=-1)
            loss_per_pair = recon_t1 + recon_t2

            # Pick one pair per sample.
            if policy == "argmin":
                chosen = loss_per_pair.argmin(dim=1)
            elif policy == "softmax":
                probs = torch.softmax(-loss_per_pair / temperature, dim=1)
                chosen = torch.multinomial(probs, 1).squeeze(1)
            elif policy == "uniform":
                chosen = torch.randint(0, num_pairs, (B,), device=device)
            else:
                raise ValueError(f"Unknown policy {policy}")

            local_t1_idx = pair_idx[chosen, 0]
            local_t2_idx = pair_idx[chosen, 1]

            # Row index on the same device as the tokens, so the gather does not
            # mix CPU and device index tensors.
            rows = torch.arange(B, device=active.device)
            chosen_t1 = model.lift(active[rows, local_t1_idx, :])
            chosen_t2 = model.lift(active[rows, local_t2_idx, :])

            chosen_x = torch.cat([chosen_t1, chosen_t2], dim=-1)

            # Encode the chosen pair into a single token and store it in the buffer.
            merged_tokens = model.encode(chosen_x)
            merged_tokens = model.unlift(merged_tokens)

            buf.merge_batch(local_t1_idx, local_t2_idx, merged_tokens)

        # ---- unmerge phase -----------------------------------------------
        unmerge_buf = TokenUnmergeBuffer(
            buffer=buf.buffer,
            active_mask=buf.get_active_mask(),
            merges=buf.get_merge_history(),
            n_original=raw_tokens.shape[1],
        )

        while not unmerge_buf.is_done():
            merged_token = unmerge_buf.get_next_to_unmerge()
            merged_token = model.lift(merged_token)
            pred = model.decode(merged_token)
            # The decoder output holds both tokens of the pair, concatenated.
            t1_pred, t2_pred = torch.chunk(pred, 2, dim=-1)
            t1_pred = model.unlift(t1_pred)
            t2_pred = model.unlift(t2_pred)
            unmerge_buf.step_unmerge(t1_pred, t2_pred)

        reconstructed = unmerge_buf.get_original_tokens()

        loss = criterion(reconstructed, raw_tokens)

        total_loss += loss.item()  # TODO: possibly accumulate without .item()

    return total_loss / len(val_loader)



def train(
        model,
        train_loader,
        val_loader,
        config,
        device,
        checkpoint_path,
        max_depth,
        ):
    """
    Train the merge autoencoder and log the progress to Weights & Biases.

    Per batch: while TokenMergeBuffer.can_merge(max_depth) holds, one token pair
    per sample is selected (config.sampling_policy) and replaced by model.encode
    of the pair; the merges are then undone with model.decode through a
    TokenUnmergeBuffer, and the MSE between reconstructed and input tokens is
    back-propagated.

    After every epoch the model is evaluated on val_loader. When checkpoint_path
    is set and the validation loss improved, the state dict is saved and plots of
    up to five samples from the epoch's first batch are created; the
    reconstruction plot is logged for that epoch.

    Parameters:
    - model: provides lift, unlift, encode, decode and a forward pass
    - train_loader: iterable of batches; a batch is a tensor or a list/tuple whose
                    first element is the tensor
    - val_loader: validation batches, passed to evaluate()
    - config: object with the attributes learning_rate, epochs, sampling_policy
              ("argmin", "uniform" or "softmax"), temperature and model_name; it is
              also handed to wandb.init as the run config
    - device: device for the model and the batches
    - checkpoint_path: file the best state dict is written to; a falsy value
                       disables checkpointing and plotting
    - max_depth: merge depth passed to TokenMergeBuffer.can_merge

    Raises:
    - ValueError: if config.sampling_policy is not a known policy
    """
    # Number of samples from the first batch that are plotted.
    n_vis = 5

    model.to(device)
    optimiser = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.MSELoss()
    # Element-wise MSE, only used to score candidate pairs; built once, not per step.
    pair_mse = nn.MSELoss(reduction="none")

    # No LR scheduler is active; assign one here to step it and have its LR logged.
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #     optimiser, mode='min', factor=0.8, patience=15, verbose=True, min_lr=1e-6
    # )
    scheduler = None

    # Kept for the (currently disabled) early-stopping check at the end of the epoch.
    stopper = EarlyStopper(patience=10, min_delta=1e-4)
    best_val_loss = float('inf')

    # current_depth = get_merge_depth(epoch+1, config.epochs, config.max_depth)
    # The curriculum is disabled: every epoch uses the full depth. Defined before
    # the loops so it also exists when a loader yields no batch.
    current_depth = max_depth

    wandb.init(
        project="random-walker",
        config=config,
        name=config.model_name
    )

    # Reported to wandb.finish; only set to 0 once the loop completed, so an
    # interrupted or crashed run is closed as failed instead of staying open.
    exit_code = 1
    try:
        for epoch in range(config.epochs):
            model.train()
            total_loss = 0.0

            # Visualisation data, always taken from the FIRST batch of the epoch so
            # that originals, reconstructions and merge snapshots belong together.
            orig_samples = None
            recon_samples = None
            merge_snapshots_per_sample = []
            # Reset per epoch: a plot is only logged in the epoch that created it.
            fig = None

            for batch in train_loader:
                optimiser.zero_grad()
                raw_tokens = batch[0].to(device) if isinstance(batch, (list, tuple)) else batch.to(device)

                is_first_batch = orig_samples is None
                if is_first_batch:
                    orig_samples = raw_tokens
                    # Never track more samples than the batch contains.
                    n_tracked = min(n_vis, raw_tokens.shape[0])
                    merge_snapshots_per_sample = [[] for _ in range(n_tracked)]

                buf = TokenMergeBuffer(raw_tokens)

                # ---- merge phase -----------------------------------------
                while buf.can_merge(current_depth):
                    active = buf.get_active_tokens()  # (B, N, D)
                    B, N, _ = active.shape

                    if is_first_batch:
                        # Snapshot of the active tokens BEFORE this merge step
                        # (first two coordinates only, for 2-D plotting).
                        for i, snapshots in enumerate(merge_snapshots_per_sample):
                            snapshots.append(active[i, :, :2].detach().cpu().numpy())

                    # All unordered index pairs (i < j) over the N active tokens.
                    pair_idx = torch.combinations(torch.arange(N, device=device))
                    num_pairs = pair_idx.shape[0]

                    # Pair scoring only produces integer indices, so no gradient
                    # can flow through it; skip building the autograd graph.
                    with torch.no_grad():
                        t1 = active[:, pair_idx[:, 0], :]
                        t2 = active[:, pair_idx[:, 1], :]

                        x = torch.cat([model.lift(t1), model.lift(t2)], dim=-1)
                        x_hat = model(x)

                        t1_hat_lifted, t2_hat_lifted = torch.chunk(x_hat, 2, dim=-1)
                        t1_hat = model.unlift(t1_hat_lifted)
                        t2_hat = model.unlift(t2_hat_lifted)

                        recon_t1 = pair_mse(t1, t1_hat).mean(dim=-1)
                        recon_t2 = pair_mse(t2, t2_hat).mean(dim=-1)
                        loss_per_pair = recon_t1 + recon_t2

                    # Pick one pair per sample.
                    policy = config.sampling_policy
                    if policy == "argmin":
                        chosen = loss_per_pair.argmin(dim=1)
                    elif policy == "uniform":
                        chosen = torch.randint(0, num_pairs, (B,), device=device)
                    elif policy == "softmax":
                        probs = torch.softmax(-loss_per_pair / config.temperature, dim=1)
                        chosen = torch.multinomial(probs, 1).squeeze(1)
                    else:
                        raise ValueError(f"Unknown sampling policy: {policy}")

                    local_t1_idx = pair_idx[chosen, 0]
                    local_t2_idx = pair_idx[chosen, 1]

                    # Row index on the same device as the tokens, so the gather
                    # does not mix CPU and device index tensors.
                    rows = torch.arange(B, device=active.device)
                    chosen_t1 = model.lift(active[rows, local_t1_idx, :])
                    chosen_t2 = model.lift(active[rows, local_t2_idx, :])

                    chosen_x = torch.cat([chosen_t1, chosen_t2], dim=-1)

                    # This encode call is the part of the merge that is trained.
                    merged_tokens = model.encode(chosen_x)
                    merged_tokens = model.unlift(merged_tokens)
                    buf.merge_batch(local_t1_idx, local_t2_idx, merged_tokens)

                # ---- unmerge phase ---------------------------------------
                unmerge_buf = TokenUnmergeBuffer(buffer=buf.buffer,
                                                 active_mask=buf.get_active_mask(),
                                                 merges=buf.get_merge_history(),
                                                 n_original=raw_tokens.shape[1])

                while not unmerge_buf.is_done():
                    merged_token = unmerge_buf.get_next_to_unmerge()
                    merged_token = model.lift(merged_token)
                    pred = model.decode(merged_token)
                    # The decoder output holds both tokens of the pair, concatenated.
                    t1_pred, t2_pred = torch.chunk(pred, 2, dim=-1)
                    t1_pred = model.unlift(t1_pred)
                    t2_pred = model.unlift(t2_pred)
                    unmerge_buf.step_unmerge(t1_pred, t2_pred)
                reconstructed = unmerge_buf.get_original_tokens()

                if recon_samples is None:
                    # Detached copy for plotting, so the batch's autograd graph is
                    # not kept alive until the end of the epoch.
                    recon_samples = reconstructed.detach()

                loss = criterion(reconstructed, raw_tokens)
                loss.backward()
                optimiser.step()

                total_loss += loss.item()

            train_loss = total_loss / len(train_loader)

            val_loss = evaluate(model, val_loader, device, config.sampling_policy, current_depth, criterion, config.temperature)

            # scheduler.step(val_loss)

            if checkpoint_path and val_loss < best_val_loss:
                best_val_loss = val_loss
                # dirname is "" for a bare file name, and os.makedirs("") raises.
                checkpoint_dir = os.path.dirname(checkpoint_path)
                if checkpoint_dir:
                    os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(model.state_dict(), checkpoint_path)

                # Plot at most n_vis samples, fewer if the first batch was smaller.
                n_shown = min(n_vis, len(orig_samples))
                orig_list = [orig_samples[i] for i in range(n_shown)]
                recon_list = [recon_samples[i] for i in range(n_shown)]
                fig = visualise_multiple(orig_list, recon_list)

                # Built but not logged yet, see the TODO below.
                process_fig = visualise_merge_progress(orig_list[0], merge_snapshots_per_sample[0])

            log_dict = {
                "train_loss": train_loss,
                "val_loss": val_loss,
                "epoch": epoch,
                "merge_depth": current_depth,
            }

            if fig is not None:
                log_dict["reconstruction_plot"] = wandb.Image(fig)

            if scheduler is not None:
                log_dict["lr"] = scheduler.get_last_lr()[0]

            # TODO: merge-progress logging is broken and therefore disabled
            # if process_fig is not None:
            #     log_dict["merge_progress_plot"] = wandb.Image(process_fig)

            wandb.log(log_dict)

            # TODO: early stopping misbehaved and is disabled for now
            # if stopper.should_stop(val_loss):
            #     print(f"Early stopping at epoch {epoch}")
            #     break

        exit_code = 0
    finally:
        wandb.finish(exit_code=exit_code)

## Run Training

In [21]:
# Model and optimisation hyper-parameters. Each value is defined once and reused
# for the config, the run name and the model, so they cannot drift apart.
token_dim = 2
hidden_dim = 64
merges = 1
learning_rate = 1e-4
sampling_policy = "softmax"

# A merged token gets room for all tokens it represents: (merges + 1) * token_dim.
latent_dim = token_dim * (merges+1)  # - epsilon

config = SimpleNamespace(
    epochs=500,
    learning_rate=learning_rate,
    sampling_policy=sampling_policy,
    # The batch size the data loaders were built with (defined in the dataset cell).
    batch_size=batch_size,
    device=device,
    # Also used as the W&B run name and as the checkpoint file name.
    model_name=f"AE-{sampling_policy}-lr-{learning_rate}-latent-{latent_dim}-merges-{merges}",
    max_depth=merges,
    latent_dim=latent_dim,
    token_dim=token_dim,
    hidden_dim=hidden_dim,
    temperature=1,
)

model = AE(
    token_dim=token_dim,
    latent_dim=latent_dim,
    hidden_dim=hidden_dim
    ).to(device)

train(model, train_loader, val_loader, config, device, checkpoint_path=f"./checkpoints/{config.model_name}.pth", max_depth=merges)

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇█
merge_depth,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,██▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,███▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,499.0
merge_depth,1.0
train_loss,0.00166
val_loss,0.00156
