In [None]:
# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import torch

from startorch import example as ex
from startorch import sequence as seq
from startorch import tensor as te
from startorch import timeseries as ts
from startorch import transformer as tr
from startorch import transition as tran
from startorch.tensor.transformer import Abs
from startorch.utils.seed import get_torch_generator
import itertools

In [None]:
# Notebook-wide plotting defaults: the "bmh" style sheet and a wide default
# figure size (individual cells may still override it via `figsize=`).
plt.style.use("bmh")
plt.rcParams.update({"figure.figsize": (16, 5)})

In [None]:
# Seed handed to `get_torch_generator` so that sampling is reproducible.
random_seed = 15_803_713_921_897_248_048

# Table of content

## Markov + Exponential
[top](#Table-of-content)

In [None]:
# Shared random source: the transition matrix and, later, the batch are both
# drawn from this generator, so statement order determines the sampled values.
rng = get_torch_generator(random_seed)
num_classes = 9
actions = torch.arange(num_classes)

# --- Transition matrix -------------------------------------------------------
# The generator is assembled stage by stage, innermost first.
# Other entry distributions that can be swapped in for the uniform one:
#   te.RandCauchy(0.0, 1.0) or te.Full(1.0)
transition_entries = tran.TensorTransitionGenerator(te.RandUniform(0.5, 1.0))
# Abs() presumably only matters for signed alternatives such as RandCauchy.
transition_abs = tran.TransformTransitionGenerator(
    generator=transition_entries,
    transformer=Abs(),
)
# NOTE(review): `num_mask` looks like the number of masked entries per row,
# which would leave two reachable states per action -- confirm against the
# startorch documentation.
transition_masked = tran.MaskTransitionGenerator(
    generator=transition_abs,
    num_mask=num_classes - 2,
)
generator_transition = tran.Normalize(generator=transition_masked, p=1.0)
transition = generator_transition.generate(n=num_classes, rng=rng)

# One rate per action, log-spaced between 10**0.1 and 10**1.0.
weights = torch.logspace(start=0.1, end=1.0, steps=num_classes)

# --- Example generator -------------------------------------------------------
# Actions follow a Markov chain driven by `transition`; for independent
# actions use {"action": seq.RandInt(0, num_classes)} instead.
action_timeseries = ts.SequenceTimeSeriesGenerator(
    generators={"action": seq.MarkovChain(transition=transition)},
)
action_examples = ex.TimeSeriesExampleGenerator(
    generators=action_timeseries,
    seq_len=te.Full(128),
)
# "action" is looked up in `weights` to produce "rate", and "rate" then
# parameterizes the exponential draw stored under "time".
rate_then_time = tr.Sequential(
    [
        tr.LookupTable(weights=weights, index="action", output="rate"),
        tr.Exponential(rate="rate", output="time"),
    ]
)
generator = ex.TransformExampleGenerator(
    generator=action_examples,
    transformer=rate_then_time,
)

In [None]:
# Draw one large batch of examples from the generator built above.
batch = generator.generate(batch_size=10000, rng=rng)

# Flatten both fields to 1-D NumPy arrays so they can be used directly for
# boolean masking and plotting in the analysis cells below.
action = batch["action"].numpy().ravel()
time = batch["time"].numpy().ravel()

In [None]:
# Analyze the action distribution
action_transition = torch.zeros(num_classes, num_classes)
for i, j in itertools.pairwise(action):
    action_transition[i, j] += 1
action_transition /= action_transition.sum(dim=1)
fig, axes = plt.subplots(figsize=(16, 8), ncols=2)
t0 = axes[0].matshow(transition)
t1 = axes[1].matshow(action_transition)
_ = fig.colorbar(t0, ax=axes[0], fraction=0.046, pad=0.02)
_ = fig.colorbar(t1, ax=axes[1], fraction=0.046, pad=0.02)

In [None]:
# Analyze the time distribution
fig, ax = plt.subplots(figsize=(16, 8))
_ = ax.scatter(time, action, alpha=0.1, color="tab:cyan", label="sample")
_ = ax.scatter(
    [time[action == i].mean().item() for i in range(num_classes)],
    actions,
    color="tab:red",
    label="mean (estimated)",
    marker="x",
)
_ = ax.scatter(
    (1 / weights).numpy(), actions.numpy(), color="black", label="mean (exact)", marker="x"
)
_ = ax.legend()