This project benchmarks three attention families on three tasks:
- MNIST (classification)
- Tiny ImageNet (classification)
- MIL (AUC on tiger/fox/elephant)
Implemented model families:
- kf_* (Karcher Flow)
- hf_* (Hopfield)
- ein_* (Einstein-midpoint aggregation)

python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
pip install -r requirements.txt

python test_mnist.py [OPTIONS]
Key options and defaults:
- --model choices: kf_attention, kf_layer, kf_pooling, hf_attention, hf_layer, hf_pooling, ein_attention, ein_layer, ein_pooling
- --batch-size: 64
- --epochs: 14
- --lr: 0.001
- --gamma: 0.96
- --hidden-dim: 8
- --beta: None (when unset, the core derives it as 1/sqrt(d))
- --num-states: 1
- --num-memories: 64
- --seed: 1
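The --beta default above can be read as a simple fallback rule. The following is a minimal sketch of that rule, not the project's actual code: the function name resolve_beta is hypothetical, and d is assumed to be the feature dimension the attention operates on (the option list does not say which dimension is used).

```python
import math
from typing import Optional


def resolve_beta(beta: Optional[float], d: int) -> float:
    """Return beta unchanged if given, otherwise fall back to 1/sqrt(d).

    Hypothetical helper: `d` is assumed to be the feature dimension.
    """
    if beta is None:
        return 1.0 / math.sqrt(d)
    return beta


print(resolve_beta(None, 64))  # 0.125
print(resolve_beta(2.0, 64))   # 2.0 (an explicit value is kept as-is)
```

Under this reading, a larger dimension gives a smaller default beta, which is the usual scaled-attention convention.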
Example:
python test_mnist.py --model ein_attention --hidden-dim 32

python test_tiny_imagenet.py [OPTIONS]
Example:
python test_tiny_imagenet.py --model ein_attention --hidden-dim 128 --epochs 1 --dry-run

python test_mil.py [OPTIONS]
Key options and defaults:
- --model choices: kf_attention, kf_layer, kf_pooling, hf_attention, hf_layer, hf_pooling, ein_attention, ein_layer, ein_pooling
- --dataset choices: tiger, fox, elephant
- --batch-size: 16
- --epochs: 100
- --lr: 0.001
- --gamma: 0.96
- --hidden-dim: 128
- --beta: None (when unset, the core derives it as 1/sqrt(d))
- --num-states: 1
- --num-memories: 64
- --bag-dropout: 0.5
- --seed: 1
Example:
python test_mil.py --dataset fox --model ein_pooling

python run_mnist_table.py
Current benchmark defaults:
- Models: ["kf_attention", "hf_attention", "ein_attention"]
- Hidden dims: [4, 8, 32]
- Trials: 5
- Epochs: 14
- Optimizer/LR/Gamma: AdamW, 0.001, 0.96
Output:
results/mnist/mnist_benchmark_results.csv
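
To get a feel for the size of the MNIST benchmark, the defaults listed above can be expanded into a run grid. This is only a sketch under the assumption that run_mnist_table.py trains every model at every hidden dimension for every trial; the function name mnist_grid and the (model, hidden_dim, trial) tuple layout are illustrative, not taken from the script.

```python
from itertools import product

# Values copied from the benchmark defaults listed in this README.
MODELS = ["kf_attention", "hf_attention", "ein_attention"]
HIDDEN_DIMS = [4, 8, 32]
TRIALS = 5


def mnist_grid():
    """Enumerate (model, hidden_dim, trial) combinations.

    Assumes a full Cartesian product; the real script may differ.
    """
    return list(product(MODELS, HIDDEN_DIMS, range(TRIALS)))


runs = mnist_grid()
print(len(runs))  # 45 = 3 models x 3 hidden dims x 5 trials
print(runs[0])    # ('kf_attention', 4, 0)
```

At 14 epochs per run, that assumption implies 45 x 14 = 630 training epochs in total.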

python run_tiny_imagenet.py
Current benchmark defaults:
- Models: ["kf_attention", "hf_attention", "ein_attention"]
- Hidden dim: 64
- Trials: 5 (seeds 42..46)
- Epochs: 14
- Optimizer/LR/Gamma: AdamW, 0.001, 0.96
Output:
results/tiny_imagenet/tiny_imagenet_benchmark_results.csv

python run_mil_table.py
Current benchmark defaults:
- Datasets: ["tiger", "fox", "elephant"]
- Models: ["kf_pooling", "hf_pooling", "ein_pooling"]
- Trials: 5 (seeds 42..46)
- Epochs: 100
- Batch size/LR/Gamma: 16, 0.001, 0.96
Output:
results/mil/mil_benchmark_results.csv
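
Each benchmark script writes a CSV to the path shown under its Output entry. The column layout is not documented here, so the sketch below makes no assumption about column names: it only reports the header and the number of data rows. The helper name summarize_csv is hypothetical.

```python
import csv
import os


def summarize_csv(path):
    """Return (column_names, row_count) for a CSV file with a header row."""
    with open(path, newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        return list(reader.fieldnames or []), len(rows)


# Output paths as listed in this README; files exist only after a benchmark run.
for path in (
    "results/mnist/mnist_benchmark_results.csv",
    "results/tiny_imagenet/tiny_imagenet_benchmark_results.csv",
    "results/mil/mil_benchmark_results.csv",
):
    if os.path.exists(path):
        columns, n_rows = summarize_csv(path)
        print(path, columns, n_rows)
    else:
        print(path, "not found (run the matching benchmark script first)")
```

Inspecting the header first is a safe way to find the metric columns before computing per-model averages.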