In [None]:
from mx.progress import init_with_progress
init_with_progress()
from mx.prelude import *
# tf.config.run_functions_eagerly(True)
# tf.data.experimental.enable_debug_mode()

In [None]:
import itertools
from mx import datasets as mxd, tasks as mxt, models as mxm, embeddings as mxe, layers as mxl
from mx.embeddings import DebugCodebookTriples
from mx.pipeline import Pipeline
from mx.tasks import RandomSequenceMSE
from mx.visualizer import Visualizer

def datasets():
    """Yield the dataset menu entries.

    Every entry is a ``Box`` carrying ``name``, ``desc``, ``viz_batch_size``
    and a ``make(b)`` factory that is called with the entry itself
    (``dataset.make(dataset)`` in ``choose_pipeline``). MNIST is preselected.
    """
    entries = [
        Box(
            default=True,
            name="mnist",
            desc="MNIST",
            viz_batch_size=10,
            make=lambda b: mxd.MxMNIST(name=b.name, desc=b.desc),
        ),
        Box(
            name="bvh",
            desc="BVH (Dense)",
            viz_batch_size=3,
            make=lambda b: mxd.BvhDataset(name=b.name, desc=b.desc),
        ),
        Box(
            name="bvhdec",
            desc="BVH (Sparse, 1.5 norm-based decimation)",
            viz_batch_size=3,
            # do_decimate=1.5 is the "1.5" referred to in the description.
            make=lambda b: mxd.BvhDataset(do_decimate=1.5, name=b.name, desc=b.desc),
        ),
    ]
    yield from entries

def tasks(dataset):
    """Yield the task menu entries that apply to ``dataset``.

    Entries are ``Box`` objects with ``name``, ``desc``, an optional
    ``default=True`` and a ``make(c_size, b)`` factory, where ``c_size`` is the
    selected chunk-size entry and ``b`` is the task entry itself. A dataset
    named neither "mnist" nor "bvh*" gets no tasks.
    """
    if dataset.name == "mnist":
        options = [
            Box(
                name="next",
                desc="Next Pixel, left-to-right, top-to-bottom",
                make=lambda c_size, b: mxt.VectorSequenceMSE(
                    c_size.chunk_size, name=b.name, desc=b.desc
                ),
            ),
            Box(
                name="rand",
                desc="Next Pixel (random order)",
                make=lambda c_size, b: mxt.RandomSequenceMSE(
                    c_size.chunk_size, name=b.name, desc=b.desc
                ),
            ),
            Box(
                default=True,
                name="randtoken",
                desc="Next Token (randomized)",
                # NOTE(review): unlike its siblings this does not forward name/desc;
                # confirm whether mxt.RandomTokens accepts them.
                make=lambda c_size, b: mxt.RandomTokens(c_size.chunk_size),
            ),
        ]
    elif dataset.name.startswith("bvh"):
        # The BVH angle tasks also read n_test_val_repeats, so the chosen
        # chunk-size entry must define it.
        options = [
            Box(
                name="next",
                desc="Next Frame (left-to-right)",
                make=lambda c_size, b: mxt.ForwardAngleAMSE(
                    c_size.chunk_size,
                    n_test_val_repeats=c_size.n_test_val_repeats,
                    name=b.name,
                    desc=b.desc,
                ),
            ),
            Box(
                name="targ",
                desc="Next Frame, except the start (targeted)",
                make=lambda c_size, b: mxt.TargetedAngleAMSE(
                    c_size.chunk_size,
                    n_test_val_repeats=c_size.n_test_val_repeats,
                    name=b.name,
                    desc=b.desc,
                ),
            ),
            Box(
                default=True,
                name="rand",
                desc="Next Frame (random order)",
                make=lambda c_size, b: mxt.RandomAngleAMSE(
                    c_size.chunk_size,
                    n_test_val_repeats=c_size.n_test_val_repeats,
                    name=b.name,
                    desc=b.desc,
                ),
            ),
        ]
    else:
        options = []
    yield from options

def embeddings(ds, task):
    """Yield the embedding menu entries valid for the chosen dataset and task.

    Each entry's ``make(m_size, b)`` builds the embedding from a model-size
    entry (its ``n_embd``) and the entry itself. Combinations not listed below
    yield nothing.
    """
    options = []
    if ds.name.startswith("mnist"):
        if task.name == "next":
            options = [
                Box(
                    default=True,
                    name="codebook",
                    desc="Codebook Multidim",
                    make=lambda m_size, b: mxe.VectorCodebookMultidim(
                        m_size.n_embd, name=b.name, desc=b.desc
                    ),
                ),
                Box(
                    name="sinusoidal",
                    desc="Sinusoidal Multidim",
                    make=lambda m_size, b: mxe.VectorSinusoidalMultidim(
                        m_size.n_embd, name=b.name, desc=b.desc
                    ),
                ),
            ]
        elif task.name.startswith("rand"):
            # Covers both "rand" and "randtoken".
            options = [
                Box(
                    name="codebook",
                    desc="Codebook (Triples)",
                    make=lambda m_size, b: mxe.DebugCodebookTriples(
                        m_size.n_embd, name=b.name, desc=b.desc
                    ),
                ),
                Box(
                    default=True,
                    name="dubcodetrip",
                    desc="Codebook/Codebook (Triples)",
                    make=lambda m_size, b: mxe.DebugCodebookCodebookTriples(
                        m_size.n_embd, name=b.name, desc=b.desc
                    ),
                ),
            ]
    elif ds.name.startswith("bvh"):
        # NOTE(review): the BVH embeddings are not given desc=b.desc, unlike MNIST.
        if task.name == "next":
            options = [
                Box(
                    default=True,
                    name="anglevec",
                    desc="Angle Embd / Codebook Position Embd",
                    make=lambda m_size, b: mxe.AngleCodebook(
                        n_embd=m_size.n_embd, n_repeats=1, name=b.name
                    ),
                ),
                Box(
                    name="anglevecsin",
                    desc="Angle Embd / Sinusoidal Position Embd",
                    make=lambda m_size, b: mxe.AngleSinusoidal(
                        n_embd=m_size.n_embd, n_repeats=1, name=b.name
                    ),
                ),
            ]
        elif task.name in ("rand", "targ"):
            options = [
                Box(
                    name="ranglevec",
                    desc="Angle Embd / Codebook Position Embd - (Val, Inp, Tar) triples",
                    make=lambda m_size, b: mxe.AngleCodebookTriples(
                        n_embd=m_size.n_embd, n_repeats=1, name=b.name
                    ),
                ),
                Box(
                    default=True,
                    name="singletrip",
                    desc="Angle Embd / Sinusoidal Position Embd - (Val, Inp, Tar) triples",
                    # n_repeats grows with width: one per 102 embedding dims, at
                    # least 1 (the reason for 102 is not visible here).
                    make=lambda m_size, b: mxe.AngleSinusoidalTriples(
                        n_embd=m_size.n_embd,
                        n_repeats=max(1, int(m_size.n_embd / 102)),
                        name=b.name,
                    ),
                ),
            ]
    yield from options

def models(task):
    """Yield the model-architecture menu entries.

    ``task`` is accepted for symmetry with the other menu builders but is not
    consulted. Each entry's ``make(m_size, b)`` builds the model from a
    model-size entry. ``desc`` is intentionally not forwarded because the
    models produce their own, more detailed description.
    """
    dropout = 0.1  # shared by every architecture below
    yield Box(
        name="debugmlp",
        desc="Debug 2-layer MLP",
        make=lambda m_size, b: mxm.DebugMLP(
            n_hidden=m_size.n_hidden, dropout=dropout, name=b.name
        ),
    )
    yield Box(
        default=True,
        name="transformer",
        desc="Decoder-only Transformer",
        make=lambda m_size, b: mxm.DecoderOnlyTransformer(
            n_layers=m_size.n_layers,
            n_hidden=m_size.n_hidden,
            n_heads=m_size.n_heads,
            dropout=dropout,
            name=b.name,
        ),
    )
    yield Box(
        name="resnet",
        desc="Resnet",
        make=lambda m_size, b: mxm.Resnet(
            n_layers=m_size.n_layers,
            n_hidden=m_size.n_hidden,
            dropout=dropout,
            name=b.name,
        ),
    )


def model_sizes():
    """Yield the model-size menu entries ("medium" is the default).

    Each entry carries ``name``, ``desc``, ``n_embd``, ``n_hidden``,
    ``n_heads`` and ``n_layers``, read by the embedding and model factories.
    """
    default_size = "medium"
    # NOTE(review): "large" and "honking" use n_heads=12, which divides neither
    # their n_embd nor n_hidden; confirm the transformer does not require that.
    table = (
        # name       desc                 n_embd  n_hidden  n_heads  n_layers
        ("tiny",     "Tiny Size",           32,      64,      4,       2),
        ("small",    "Small Size",          64,     128,      8,       4),
        ("medium",   "Medium Size",        128,     256,      8,       6),
        ("large",    "Large Size",         256,     512,     12,       8),
        ("honking",  "Truly Honker Size",  512,    1024,     12,      16),
        ("wide",     "WIDE Size",         1024,    2048,      8,       4),
        ("deep",     "DEEP Size",          256,     512,      8,      24),
    )
    for key, label, embd, hidden, heads, layers in table:
        fields = {"default": True} if key == default_size else {}
        fields.update(
            name=key,
            desc=label,
            n_embd=embd,
            n_hidden=hidden,
            n_heads=heads,
            n_layers=layers,
        )
        yield Box(**fields)

def lengths():
    """Yield training-length menu entries (``n_steps``); "long" is the default."""
    default_length = "long"
    for key, label, steps in (
        ("debug", "10 steps", 10),
        ("normal", "20k steps", 20_000),
        ("long", "1M steps", 1_000_000),
    ):
        fields = {"default": True} if key == default_length else {}
        fields.update(name=key, desc=label, n_steps=steps)
        yield Box(**fields)

def chunk_sizes(dataset):
    """Yield sequence/batch-size menu entries for ``dataset``.

    Each entry is a ``Box`` with ``name``, ``desc``, ``chunk_size`` (passed as
    the first argument to the task constructor), ``batch_size`` (used by
    ``choose_pipeline`` for both training and test batches) and
    ``n_test_val_repeats`` (forwarded to the BVH task constructors).
    Datasets other than "mnist*" / "bvh*" yield nothing.
    """
    if dataset.name.startswith("mnist"):
        yield Box(
            name="blind",
            desc="Blind",
            chunk_size=1,
            batch_size=784*16,
            n_test_val_repeats=5,
        )
        yield Box(
            default=True,
            name="chunks",
            desc="Chunks",
            chunk_size=28*2,
            batch_size=784*16//(28*2),
            n_test_val_repeats=5,
        )
        yield Box(
            name="fullimg",
            desc="Full Image",
            chunk_size=784,
            batch_size=16,
            n_test_val_repeats=5,
        )
    elif dataset.name.startswith("bvh"):
        # Every BVH task factory reads c_size.n_test_val_repeats. Without the
        # key, attribute access on a Box raises (assuming default Box settings,
        # i.e. no default_box), so each BVH entry must define it.
        # 5 mirrors the MNIST entries. TODO: tune for BVH if needed.
        yield Box(
            name="chunkshort",
            desc="Chunk Size = 32, Batch Size = 256",
            chunk_size=32,
            batch_size=16*16,
            n_test_val_repeats=5,
        )
        yield Box(
            default=True,
            name="chunklong",
            desc="Chunk Size = 512, Batch Size = 16",
            chunk_size=32*16,
            batch_size=16,
            n_test_val_repeats=5,
        )
        yield Box(
            name="chunklittle",
            desc="Chunk Size = 32, Batch Size = 16",
            chunk_size=32,
            batch_size=16,
            n_test_val_repeats=5,
        )


def choose_from_iterator(iterator, title, input_fn=input):
    """Print a numbered menu of ``iterator``'s items and return the one picked.

    Each item must expose ``name`` and ``desc``. An item whose ``default``
    attribute is truthy is marked "(default)" and is returned when the user
    just presses Enter. If several are marked, the first one wins.

    Args:
        iterator: Any iterable of menu entries (consumed once).
        title: Heading printed above the menu.
        input_fn: Callable used to read the answer. It defaults to ``input``
            and exists so the menu can be driven non-interactively.

    Returns:
        The selected item.

    Raises:
        ValueError: If ``iterator`` yields no items.

    Invalid answers (non-integers, out-of-range or negative indices, or Enter
    when no default exists) re-prompt. They no longer crash, and a negative
    index no longer silently picks from the end of the list.
    """
    options = list(iterator)
    if not options:
        raise ValueError(f"No options available for {title!r}")

    def is_default(item):
        # getattr(..., False) covers entries that omit the "default" key.
        return bool(getattr(item, "default", False))

    default_idx = next((i for i, item in enumerate(options) if is_default(item)), None)
    name_width = max(len(item.name) for item in options)
    idx_width = len(str(len(options) - 1))

    print()
    print(f"    {title}")
    for i, item in enumerate(options):
        marker = "(default)" if is_default(item) else " " * 9
        print(f"{marker} ({i: >{idx_width}}): {item.name:.<{name_width}}....{item.desc}")

    while True:
        answer = input_fn("        choice: ").strip()
        if not answer:
            if default_idx is not None:
                return options[default_idx]
            print("        no default available; enter a number")
            continue
        try:
            idx = int(answer)
        except ValueError:
            print(f"        {answer!r} is not a number")
            continue
        if 0 <= idx < len(options):
            return options[idx]
        print(f"        choose a number between 0 and {len(options) - 1}")

# choose pipeline interactively
# previous implementation was a big list, now we do it part-at-a-time
# pipeline, dataset, task, model = choose_pipeline()
def choose_pipeline():
    """Interactively assemble a ``Pipeline``, one menu at a time.

    Prompts, in order, for dataset, task, embedding, model, model size,
    training length and seq/batch size. Later menus depend on earlier answers:
    tasks and chunk sizes on the dataset, embeddings on dataset + task.
    The chosen entries are instantiated through their ``make`` factories and
    combined into a pipeline whose name ("-"-joined) and desc (", "-joined)
    concatenate the entries'. Prints the pipeline's name/desc and returns it.
    """
    ds_choice = choose_from_iterator(datasets(), "Dataset")
    task_choice = choose_from_iterator(tasks(ds_choice), "Task")
    emb_choice = choose_from_iterator(embeddings(ds_choice, task_choice), "Embedding")
    model_choice = choose_from_iterator(models(task_choice), "Model")
    size_choice = choose_from_iterator(model_sizes(), "Model Size")
    length_choice = choose_from_iterator(lengths(), "Training Length")
    chunk_choice = choose_from_iterator(chunk_sizes(ds_choice), "Seq & Batch Size")

    # Instantiate in menu order. Each factory also receives its own entry so
    # it can read that entry's name/desc.
    built_dataset = ds_choice.make(ds_choice)
    built_task = task_choice.make(chunk_choice, task_choice)
    built_embedding = emb_choice.make(size_choice, emb_choice)
    built_model = model_choice.make(size_choice, model_choice)

    picks = (ds_choice, task_choice, emb_choice, model_choice,
             size_choice, length_choice, chunk_choice)
    pipeline = Pipeline(
        name="-".join(pick.name for pick in picks),
        desc=", ".join(pick.desc for pick in picks),
        batch_size=chunk_choice.batch_size,
        test_batch_size=chunk_choice.batch_size,
        viz_batch_size=ds_choice.viz_batch_size,
        n_steps=length_choice.n_steps,
        dataset=built_dataset,
        task=built_task,
        embedding=built_embedding,
        model=built_model,
    )
    print()
    print(f"Pipeline: {pipeline.name}")
    print(f"    {pipeline.desc}")
    print()
    return pipeline


In [None]:
from mx import train

rand_run_name = u.random_run_name()

# Run-name menu. "dev" is picked when the prompt is left blank.
run_name_options = [
    Box(default=True, name="dev", desc="Interactive development"),
    Box(name="random", desc=f"'{rand_run_name}' (Random name)"),
    Box(name="prev", desc="Existing previous run"),
    Box(name="interactive", desc="Interactive"),
]
run_name = choose_from_iterator(run_name_options, "Run Name")

# Mode -> (name for u.set_run_name, force_new, force_not_new). The two flags
# are forwarded to pipeline.make_or_load_model further down.
_RUN_NAME_MODES = {
    "dev": ("dev", True, False),
    "prev": ("blessed", False, True),
    "random": (rand_run_name, True, False),
}
_mode = _RUN_NAME_MODES.get(run_name.name)
if _mode is None:
    # "interactive": u.set_run_name is not called, so u keeps its default run
    # name ("interactive-<date>" per the original note), and neither flag is forced.
    force_new, force_not_new = False, False
else:
    u.set_run_name(_mode[0])
    force_new, force_not_new = _mode[1], _mode[2]

# Build the pipeline interactively (dataset, task, embedding, model, size,
# training length, seq/batch size).
pipeline = choose_pipeline()

output_dir = pipeline.output_dir()

# Set the global Keras mixed-precision policy BEFORE the model is created
# below, presumably so the layers built there pick it up.
if pipeline.use_float16:
    tf.keras.mixed_precision.set_global_policy("mixed_float16")

# Tell `u` which dataset family is in use. What setregtype configures lives
# in the mx package and is not visible here.
if isinstance(pipeline.dataset, mxd.MxMNIST):
    u.setregtype('mnist')
elif isinstance(pipeline.dataset, mxd.BvhDataset):
    u.setregtype('bvh')

# force_new / force_not_new were set by the run-name selection above.
model = pipeline.make_or_load_model(force_new=force_new, force_not_new=force_not_new)
loss_fn = pipeline.make_loss_fn()
train_data, val_data = pipeline.make_train_data()
# NOTE(review): predict_fn and vizs are not used anywhere else in this notebook
# (vizs only feeds the commented-out Visualizer below).
predict_fn = pipeline.task.make_predict_fn(model)
vizs = pipeline.dataset.get_visualizations(model, output_dir)
# Visualizer is currently disabled; vizr=None is passed to the train loop.
# vizr = Visualizer(
#     visualizations=vizs,
#     configs={},
# )
vizr = None
# log_interval="never" presumably disables periodic logging. Confirm in mx.train.
train_loop = train.make_train_loop(
    model=model,
    loss_fn=loss_fn,
    data=train_data,
    val_data=val_data,
    output_dir=output_dir,
    vizr=vizr,
    log_interval="never",
)

In [None]:
from mx.progress import create_progress_manager


# Train only when the run-name choice set force_new ("dev" or "random").
# NOTE(review): for "interactive", make_or_load_model was called with both flags
# False, so a fresh model may be created and then never trained here. Confirm
# this is intended.
if force_new:
    train_loop()
