In [None]:
from pprint import pprint

import torch

from experiments import (
    MODEL_NAMES,
    auto_device,
    run_experiments,
    run_flops_benchmarks,
    run_invariance_tests,
    run_time_benchmarks,
)

In [None]:
# List the model names exported by `experiments`; the next cell picks a subset of these.
pprint(MODEL_NAMES)

In [None]:
# We don't test the full-symmetrization models: they need N! permutations per
# input, which grows astronomically fast (10! ≈ 3.6e6, 100! ≈ 9.3e157).
model_to_test = [
    "canonical-mlp",
    "canonical-attn",
    "symmetry-sampling-mlp",
    "symmetry-sampling-attn",
    "intrinsic",
    "augmented-mlp",
    "augmented-attn",
]

# Catch a misspelled model name here, before the long-running cells below.
# NOTE(review): assumes MODEL_NAMES is a collection of (or a dict keyed by)
# model-name strings, as printed in the previous cell — confirm.
unknown_models = sorted(set(model_to_test) - set(MODEL_NAMES))
assert not unknown_models, f"unknown model names: {unknown_models}"

In [None]:
# Configuration: compute device, sequence length N and feature size D.
device = auto_device()
seq_len = 10  # N
feature_dim = 5  # D

# All three checks share the same positional arguments.
benchmark_args = (seq_len, feature_dim, device)

# Equivariance / invariance tests.
run_invariance_tests(*benchmark_args, model_names=model_to_test)

# FLOP counts.
run_flops_benchmarks(*benchmark_args, model_names=model_to_test)

# Timings for inference and for a forward+backward pass (last expression, so
# any returned value is displayed exactly as before).
run_time_benchmarks(*benchmark_args, model_names=model_to_test)

In [None]:
# Train and evaluate every model over a grid of sequence lengths and
# training-set sizes; run_experiments logs the results with TensorBoard for
# later analysis.
# `device` is not passed here: run_experiments chooses the best device itself.
TRAIN_SEQ_LENS = (10, 100)
TRAIN_SET_SIZES = (100, 1000, 10000)
SEED = 0

# A distinct loop variable avoids overwriting the benchmark cell's `seq_len`.
for train_seq_len in TRAIN_SEQ_LENS:
    for train_set_size in TRAIN_SET_SIZES:
        # Re-seed before each run so every (N, train size) result is reproducible
        # on its own, independent of which runs came before it.
        # NOTE(review): only torch's global RNG is seeded; if `experiments` also
        # draws from numpy or `random`, seed those as well — confirm.
        torch.manual_seed(SEED)
        run_experiments(
            train_seq_len,
            feature_dim=feature_dim,  # same D as the benchmark cell, not a repeated literal
            train_size=train_set_size,
            model_names=model_to_test,
        )