In [1]:
from train_any_order import BracketFlowModule
from sampling import any_order_mask_insertion_euler_sampling, SamplingTraceDatapoint

In [2]:
checkpoint_path = "checkpoints/bracket-flow/bracket-any-order-mask-flow-epoch=89-train/total_loss=3.7114.ckpt"
model = BracketFlowModule.load_from_checkpoint(checkpoint_path)

In [13]:
steps = 2000
batch_size = 20
samples, trace = any_order_mask_insertion_euler_sampling(
    model,
    model.interpolant,
    steps=steps,
    mask=0,
    pad=3,
    batch_size=batch_size,
    max_length=64,
    return_trace=True,
)

In [4]:
def process_trace(trace: list[SamplingTraceDatapoint]):
    event_type_mapping = dict(change="c", insertion="i")
    token_mapping = {0: "m", 1: "(", 2: ")"}

    def _process_datapoint(datapoint: SamplingTraceDatapoint):
        return dict(
            t=datapoint.t,
            a=event_type_mapping[datapoint.event_type],
            tk=token_mapping[datapoint.token],
            i=datapoint.position,
        )

    return [_process_datapoint(datapoint) for datapoint in trace]

In [11]:
samples

tensor([[1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 2, 1, 2, 1, 2, 1, 1,
         1, 2, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1,
         2, 2, 1, 1, 0, 1, 2, 2, 2, 1, 1, 2, 2, 1, 2, 1]], device='cuda:0')

In [12]:
from IPython.display import display, HTML
import json

# Select the index of the tensor to visualize
id = 0
assert 0 <= id < batch_size, "id must be in [0, batch_size)"

events = process_trace(trace[id])
print(events)

display(
    HTML(f"""
<style>
.token {{
    display: inline-block;
    padding: 5px 10px;
    margin: 2px;
    border: 1px solid #aaa;
    border-radius: 4px;
    font-family: monospace;
    transition: background 0.5s;
}}
.hi-i {{ background: lightgreen; }}
.hi-c {{ background: yellow; }}
#s {{ min-height: 30px; margin-bottom: 20px; }}
</style>
<div id="s"></div>
<button id="r">Replay</button>
<script>
var ev = {json.dumps(events)};
var s = [];
var delay = 100; // Delay between steps (ms)

function render(hIdx, action) {{
    var container = document.getElementById("s");
    container.innerHTML = "";
    s.forEach(function(token, idx) {{
        var span = document.createElement("span");
        span.className = "token" + (idx === hIdx ? " hi-" + action : "");
        span.textContent = token;
        container.appendChild(span);
    }});
}}

async function run() {{
    s = [];
    render();
    
    for (const e of ev) {{
        if (e.a === "i") {{
            // Insert token at position e.i
            var pos = (e.i < 0 || e.i > s.length) ? Math.floor(s.length / 2) : e.i;
            s.splice(pos, 0, e.tk);
            render(pos, "i");
        }} else {{
            // Change token at position e.i
            var pos = (e.i === undefined || e.i < 0 || e.i >= s.length) ? Math.floor(s.length / 2) : e.i;
            s[pos] = e.tk;
            render(pos, "c");
        }}
        
        // Small delay to make the animation visible (but still synchronous in execution order)
        await new Promise(resolve => setTimeout(resolve, delay));
    }}
}}

document.getElementById("r").onclick = run;
run(); // Start automatically
</script>
""")
)

[{'t': 0.019999999552965164, 'a': 'i', 'tk': 'm', 'i': 0}, {'t': 0.1250000149011612, 'a': 'c', 'tk': ')', 'i': 0}, {'t': 0.1300000101327896, 'a': 'i', 'tk': 'm', 'i': 1}, {'t': 0.15999998152256012, 'a': 'i', 'tk': 'm', 'i': 2}, {'t': 0.21499992907047272, 'a': 'c', 'tk': ')', 'i': 1}, {'t': 0.21499992907047272, 'a': 'i', 'tk': 'm', 'i': 3}, {'t': 0.22999991476535797, 'a': 'i', 'tk': 'm', 'i': 4}, {'t': 0.24999989569187164, 'a': 'i', 'tk': 'm', 'i': 5}, {'t': 0.2649998962879181, 'a': 'i', 'tk': 'm', 'i': 5}, {'t': 0.2749998867511749, 'a': 'i', 'tk': 'm', 'i': 4}, {'t': 0.2949998676776886, 'a': 'i', 'tk': 'm', 'i': 6}, {'t': 0.33999982476234436, 'a': 'c', 'tk': '(', 'i': 3}, {'t': 0.35999980568885803, 'a': 'i', 'tk': 'm', 'i': 5}, {'t': 0.35999980568885803, 'a': 'i', 'tk': 'm', 'i': 0}, {'t': 0.3849997818470001, 'a': 'c', 'tk': '(', 'i': 3}, {'t': 0.42499974370002747, 'a': 'c', 'tk': ')', 'i': 7}, {'t': 0.4599997103214264, 'a': 'i', 'tk': 'm', 'i': 10}, {'t': 0.4899996817111969, 'a': 'c',