In [None]:
from mx.progress import init_with_progress
init_with_progress()
from mx.prelude import *

In [4]:
import itertools
from re import A
from mx import datasets as mxd, tasks as mxt, models as mxm, embeddings as mxe, layers as mxl
from mx.embeddings.vector_embeddings import DebugCodebookTriples
from mx.pipeline import Pipeline
from mx.tasks.vector_sequence_mse import RandomSequenceMSE
from mx.visualizer import Visualizer

def datasets():
    """Yield the dataset options offered by the interactive pipeline menu.

    Each option is a ``Box`` carrying the menu fields (``default``, ``name``,
    ``desc``), the ``viz_batch_size`` handed to the pipeline, and a
    ``make(b)`` factory that builds the dataset from the option itself.
    """

    def bvh_maker(decimate):
        # A factory (instead of a lambda written inside the loop below) so
        # each option captures its own ``decimate`` value.
        return lambda b: mxd.BvhDataset(
            recluster=False,
            decimate=decimate,
            name=b.name,
            desc=b.desc,
        )

    yield Box(
        default=True,
        name="mnist",
        desc="MNIST",
        viz_batch_size=10,
        make=lambda b: mxd.MxMNIST(name=b.name, desc=b.desc),
    )

    bvh_variants = (
        # name      desc                                        decimate
        ("bvh", "BVH (Dense)", False),
        ("bvhdec", "BVH (Sparse, 0.5 norm-based decimation)", True),
    )
    for variant_name, variant_desc, decimate in bvh_variants:
        yield Box(
            default=False,
            name=variant_name,
            desc=variant_desc,
            viz_batch_size=3,
            make=bvh_maker(decimate),
        )

def tasks(dataset):
    """Yield the task options valid for the chosen ``dataset`` option.

    MNIST and BVH datasets are offered the same two tasks: sequential
    ("next") and random-order ("rand") sequence MSE. Only their menu
    descriptions differ. Any other dataset name yields no options.

    Each option's ``make(c_size, b)`` takes the chosen chunk-size option
    (read for ``chunk_size``) and the task option itself.
    """
    if dataset.name == "mnist":
        next_desc = "Next Pixel, left-to-right, top-to-bottom"
        rand_desc = "Next Pixel (random order)"
    elif dataset.name.startswith("bvh"):
        next_desc = "Next Frame (left-to-right)"
        rand_desc = "Next Frame (random order)"
    else:
        return

    yield Box(
        default=False,
        name="next",
        desc=next_desc,
        make=lambda c_size, b: mxt.VectorSequenceMSE(
            c_size.chunk_size,
            name=b.name,
            desc=b.desc,
        ),
    )
    yield Box(
        default=True,
        name="rand",
        desc=rand_desc,
        make=lambda c_size, b: mxt.RandomSequenceMSE(
            c_size.chunk_size,
            name=b.name,
            desc=b.desc,
        ),
    )

def embeddings(task):
    """Yield the embedding options compatible with the chosen ``task`` option.

    - ``"next"`` task: codebook (default) or sinusoidal multidim embeddings.
    - ``"rand"`` task: codebook triples.
    - Any other task name yields no options.

    Each option's ``make(m_size, b)`` takes the chosen model-size option
    (read for ``n_embd``) and the embedding option itself.
    """
    if task.name == "next":
        yield Box(
            default=True,
            name="codebook",
            desc="Codebook Multidim",
            make=lambda m_size, b: mxe.VectorCodebookMultidim(
                m_size.n_embd,
                name=b.name,
                desc=b.desc,
            ),
        )
        yield Box(
            default=False,
            name="sinusoidal",
            desc="Sinusoidal Multidim",
            make=lambda m_size, b: mxe.VectorSinusoidalMultidim(
                m_size.n_embd,
                name=b.name,
                desc=b.desc,
            ),
        )
    elif task.name == "rand":
        yield Box(
            default=True,
            name="codebook",
            desc="Codebook (Triples)",
            # Use the class imported directly from
            # mx.embeddings.vector_embeddings at the top of this cell rather
            # than relying on the ``mx.embeddings`` package re-exporting it.
            # A missing re-export would otherwise only surface as an
            # AttributeError after every menu has been answered.
            make=lambda m_size, b: DebugCodebookTriples(
                m_size.n_embd,
                name=b.name,
                desc=b.desc,
            ),
        )

def models(task):
    """Yield the model-architecture options for the interactive menu.

    ``task`` is accepted for symmetry with the other option generators but
    is not consulted: every task is offered the same three architectures.
    Each option's ``make(m_size, b)`` takes the chosen model-size option and
    the model option itself.
    """
    dropout = 0.1  # shared by every architecture below

    def make_debug_mlp(m_size, b):
        return mxm.DebugMLP(
            n_hidden=m_size.n_hidden,
            dropout=dropout,
            name=b.name,
            desc=b.desc,
        )

    def make_transformer(m_size, b):
        return mxm.DecoderOnlyTransformer(
            n_layers=m_size.n_layers,
            n_hidden=m_size.n_hidden,
            n_heads=m_size.n_heads,
            dropout=dropout,
            name=b.name,
            desc=b.desc,
        )

    def make_resnet(m_size, b):
        return mxm.Resnet(
            n_layers=m_size.n_layers,
            n_hidden=m_size.n_hidden,
            dropout=dropout,
            name=b.name,
            desc=b.desc,
        )

    yield Box(default=False, name="debugmlp", desc="Debug 2-layer MLP", make=make_debug_mlp)
    yield Box(default=True, name="transformer", desc="Decoder-only Transformer", make=make_transformer)
    yield Box(default=False, name="resnet", desc="Resnet", make=make_resnet)


def model_sizes():
    """Yield the model-size presets for the interactive menu.

    The chosen preset is passed as ``m_size`` to both the embedding and the
    model factories, which read ``n_embd``, ``n_hidden``, ``n_heads`` and
    ``n_layers`` from it.
    """
    # NOTE(review): "large" pairs 12 heads with n_embd=256 / n_hidden=512,
    # and neither width is divisible by 12; every other preset divides
    # evenly. If the transformer splits a width evenly across heads this
    # preset will fail. Confirm in mxm.DecoderOnlyTransformer.
    presets = (
        # name      desc           n_embd  n_hidden  n_heads  n_layers  default
        ("tiny",   "Tiny Size",    32,     64,       4,       2,        False),
        ("small",  "Small Size",   64,     128,      8,       4,        False),
        ("medium", "Medium Size",  128,    256,      8,       6,        True),
        ("large",  "Large Size",   256,    512,      12,      8,        False),
    )
    for name, desc, n_embd, n_hidden, n_heads, n_layers, is_default in presets:
        yield Box(
            default=is_default,
            name=name,
            desc=desc,
            n_embd=n_embd,
            n_hidden=n_hidden,
            n_heads=n_heads,
            n_layers=n_layers,
        )

def lengths():
    """Yield the training-length options; each sets the pipeline's ``n_steps``."""
    presets = (
        # name      desc         n_steps     default
        ("debug",  "10 steps",  10,         False),
        ("normal", "20k steps", 20_000,     True),
        ("long",   "1M steps",  1_000_000,  False),
    )
    for name, desc, n_steps, is_default in presets:
        yield Box(
            default=is_default,
            name=name,
            desc=desc,
            n_steps=n_steps,
        )

def chunk_sizes():
    """Yield the sequence-length / batch-size options for the interactive menu.

    ``chunk_size`` is handed to the task factory and ``batch_size`` to the
    pipeline. Every option keeps ``chunk_size * batch_size`` equal to
    784 * 16 elements per batch, so shorter sequences get larger batches.
    (784 = 28 * 28 is the "Full Image" length, presumably one MNIST image.)
    """
    elements_per_batch = 784 * 16
    presets = (
        # name       desc          chunk_size  default
        ("blind",   "Blind",      1,          False),
        ("chunks",  "Chunks",     28 * 2,     False),
        ("fullimg", "Full Image", 784,        True),
    )
    for name, desc, chunk_size, is_default in presets:
        yield Box(
            default=is_default,
            name=name,
            desc=desc,
            chunk_size=chunk_size,
            batch_size=elements_per_batch // chunk_size,
        )

def choose_from_iterator(iterator, title, input_fn=input):
    """Show a numbered menu of options and return the one the user picks.

    Every item must expose ``name``, ``desc`` and ``default`` attributes
    (e.g. the ``Box`` options yielded by the generators in this notebook).
    An empty answer selects the first item whose ``default`` is truthy.
    Invalid answers re-prompt instead of raising, so a typo does not abort
    a multi-step selection. Invalid means non-numeric, out of range, or
    empty when no item is marked default.

    Args:
        iterator: Iterable of option items; consumed once.
        title: Heading printed above the menu.
        input_fn: Callable taking a prompt string and returning the user's
            answer. Defaults to the builtin ``input``; override it to drive
            the menu non-interactively.

    Returns:
        The selected item.

    Raises:
        ValueError: If ``iterator`` yields no items (e.g. ``tasks`` for a
            dataset name it does not recognise).
    """
    options = list(iterator)
    if not options:
        raise ValueError(f"No options available for {title!r}")

    name_width = max(len(item.name) for item in options)
    index_width = len(str(len(options) - 1))
    default_index = next(
        (i for i, item in enumerate(options) if item.default), None
    )

    print()
    print(f"    {title}")
    for i, item in enumerate(options):
        marker = "(default)" if item.default else " " * 9
        print(f"{marker} ({i: >{index_width}}): {item.name:.<{name_width}}....{item.desc}")

    while True:
        answer = input_fn("        choice: ").strip()
        if not answer:
            if default_index is not None:
                return options[default_index]
            print("        no default available; enter a number")
            continue
        try:
            index = int(answer)
        except ValueError:
            print(f"        not a number: {answer!r}")
            continue
        # Reject negatives explicitly: Python's negative indexing would
        # otherwise silently select an item counted from the end.
        if 0 <= index < len(options):
            return options[index]
        print(f"        out of range: enter 0-{len(options) - 1}")

def choose_pipeline():
    """Interactively choose every pipeline component and build the Pipeline.

    Menus are shown one at a time, in this order: dataset, task, embedding,
    model, model size, training length, and sequence/batch size. Later menus
    depend on earlier answers: tasks on the dataset, embeddings on the task.

    Returns:
        The assembled ``Pipeline``. Its ``name`` joins the chosen option
        names with "-" and its ``desc`` joins their descriptions with ", ".
    """
    ds_opt = choose_from_iterator(datasets(), "Dataset")
    task_opt = choose_from_iterator(tasks(ds_opt), "Task")
    emb_opt = choose_from_iterator(embeddings(task_opt), "Embedding")
    model_opt = choose_from_iterator(models(task_opt), "Model")
    size_opt = choose_from_iterator(model_sizes(), "Model Size")
    length_opt = choose_from_iterator(lengths(), "Training Length")
    chunk_opt = choose_from_iterator(chunk_sizes(), "Seq & Batch Size")

    # Each factory receives its own option as the last argument (for
    # name/desc), preceded by the size option it depends on, if any.
    built_dataset = ds_opt.make(ds_opt)
    built_task = task_opt.make(chunk_opt, task_opt)
    built_embedding = emb_opt.make(size_opt, emb_opt)
    built_model = model_opt.make(size_opt, model_opt)

    chosen = (ds_opt, task_opt, emb_opt, model_opt, size_opt, length_opt, chunk_opt)
    pipeline = Pipeline(
        name="-".join(opt.name for opt in chosen),
        desc=", ".join(opt.desc for opt in chosen),
        batch_size=chunk_opt.batch_size,
        test_batch_size=chunk_opt.batch_size,
        viz_batch_size=ds_opt.viz_batch_size,
        n_steps=length_opt.n_steps,
        dataset=built_dataset,
        task=built_task,
        embedding=built_embedding,
        model=built_model,
    )

    print()
    print(f"Pipeline: {pipeline.name}")
    print(f"    {pipeline.desc}")
    print()
    return pipeline


In [None]:
from mx import train

# Run-mode menu. "dev" and "prev" are sentinel names handled below; the
# default option's name comes from u.random_run_name() and starts a new run.
run_name_options = [
    Box(
        default=False,
        name="dev",
        desc="Interactive development",
    ),
    Box(
        default=True,
        name=u.random_run_name(),
        desc="New training run",
    ),
    Box(
        default=False,
        name="prev",
        desc="Existing previous run",
    ),
]
run_name = choose_from_iterator(run_name_options, "Run Name")
if run_name.name == "dev":
    force_new_run = False
elif run_name.name == "prev":
    # Second menu to pick which earlier run to reuse. choose_from_iterator
    # reads .name/.desc/.default on each item.
    # NOTE(review): confirm u.list_previous_runs() returns such items and
    # is non-empty.
    prev_runs = u.list_previous_runs()
    run_name = choose_from_iterator(prev_runs, "Previous Run")
    force_new_run = False
else: # random
    force_new_run = True

# From here on `run_name` holds whatever u.set_run_name returns, not the
# menu Box.
run_name = u.set_run_name(run_name.name)


pipeline = choose_pipeline()

output_dir = pipeline.output_dir()

# Must run before the model is created below: Keras layers read the global
# dtype policy when they are constructed.
if pipeline.use_float16:
    tf.keras.mixed_precision.set_global_policy("mixed_float16")

# force_new is True only for a freshly named run. "dev" and "prev" pass False
# (presumably reusing a saved model when one exists; confirm in
# Pipeline.make_or_load_model).
model = pipeline.make_or_load_model(force_new=force_new_run)
loss_fn = pipeline.make_loss_fn()
data = pipeline.make_train_data()
# The dataset's visualizations use the task-specific prediction function
# built from the model above.
predict_fn = pipeline.task.make_predict_fn(model)
vizs = pipeline.dataset.get_visualizations(
    viz_batch_size=pipeline.viz_batch_size,
    task_specific_predict_fn=predict_fn,
)
vizr = Visualizer(
    visualizations=vizs,
    output_dir=output_dir,
    configs={},
)
# Builds (but does not start) the training loop; the next cell calls it.
# NOTE(review): log_interval="never" presumably disables periodic logging;
# confirm in train.make_train_loop.
train_loop = train.make_train_loop(
    model=model,
    loss_fn=loss_fn,
    data=data,
    run_name=run_name,
    output_dir=output_dir,
    vizr=vizr,
    log_interval="never",
)

In [None]:
from mx.progress import create_progress_manager


# Train only for a freshly named run. The "dev" and "prev" run modes set
# force_new_run to False, so this cell does nothing for them.
if force_new_run:
    train_loop()
