karellat/h-next

H-NeXt

The next step towards roto-translation invariant networks

Prerequisites

  • Install Conda
  • Models were trained on RockyLinux with an NVIDIA A100 GPU
  • Create the conda environment: conda env create -f environment.yml

Training

  • The training CLI is in training.py
    • Select the GPU by setting CUDA_VISIBLE_DEVICES
    • Set the batch size, data directory, and zero padding (pad) via --datamodule_hparams, e.g.:
      • --datamodule_hparams "{'batch_size': 63, 'data_dir' : '/tmp', 'pad':0}"
    • For more details see training.py
  • Hyperparameters are set up for Wandb Sweeps; see the examples in data/sweeps.
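The *_hparams flags take Python dict literals passed as strings. A minimal sketch of how such a string can be parsed safely (an illustration only; training.py may parse it differently):

```python
import ast


def parse_hparams(arg: str) -> dict:
    """Parse a --*_hparams argument given as a Python dict literal.

    ast.literal_eval accepts only literals (no arbitrary code),
    so it is safe to apply to command-line input.
    """
    hparams = ast.literal_eval(arg)
    if not isinstance(hparams, dict):
        raise ValueError(f"expected a dict literal, got {type(hparams).__name__}")
    return hparams


# Example: the datamodule hparams string from above.
hparams = parse_hparams("{'batch_size': 63, 'data_dir' : '/tmp', 'pad':0}")
```

Note that the quoting matters on the shell side: the whole dict literal must reach the program as a single argument, hence the double quotes around it in the commands below.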

0. Activate Conda Env

conda activate h-next

MNIST Models

Each architecture design is represented by the best model out of 10 runs; how to train each one is listed below:

1. H-Nets

CUDA_VISIBLE_DEVICES=0 python training.py 

2. UP

CUDA_VISIBLE_DEVICES=0 python training.py  --backbone_name "UpscaleHnetBackbone"

3. UP + MASK

CUDA_VISIBLE_DEVICES=0 python training.py  --backbone_name "UpscaleHnetBackbone" --backbone_hparams "{'circular_masking':True}"
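Circular masking zeroes everything outside the largest circle inscribed in the image, so a rotation never moves content across the corners. A minimal sketch of the idea in plain Python (a hypothetical helper for illustration, not the repository's implementation):

```python
def circular_mask(size: int) -> list[list[float]]:
    """Return a size x size mask: 1.0 inside the inscribed circle, 0.0 outside."""
    center = (size - 1) / 2.0
    radius = size / 2.0
    mask = []
    for y in range(size):
        row = []
        for x in range(size):
            inside = (x - center) ** 2 + (y - center) ** 2 <= radius ** 2
            row.append(1.0 if inside else 0.0)
        mask.append(row)
    return mask


def apply_circular_mask(image, mask):
    """Element-wise product of an image with the mask; corners become 0."""
    return [[p * m for p, m in zip(img_row, m_row)]
            for img_row, m_row in zip(image, mask)]
```

In the models above, where the mask is applied inside the network is controlled by the circular_masking hparam.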

CIFAR-10 Models

1. UP + MASK

CUDA_VISIBLE_DEVICES=0 python training.py  --backbone_name "UpscaleHnetBackbone" --backbone_hparams "{'maximum_order': 1, 'circular_masking':True, 'in_channels':3}" --datamodule_name "cifar10-rot-test"

2. UP + MASK + HUGE

CUDA_VISIBLE_DEVICES=0 python training.py  --backbone_name "UpscaleHnetBackbone" --backbone_hparams "{'maximum_order': 2, 'circular_masking':True, 'in_channels':3, 'nf1':32, 'nf2':64, 'nf3':128}" --datamodule_name "cifar10-rot-test"

3. UP + MASK + WIDE

CUDA_VISIBLE_DEVICES=0 python training.py  --backbone_name "UpscaleHnetWideBackbone" --classnet_name "ZernikeProtypePooling" --datamodule_name "cifar10-rot-test"

4. UP + MASK + ATT

CUDA_VISIBLE_DEVICES=0 python training.py  --backbone_name "UpscaleHnetWideBackbone" --backbone_hparams "{'model_str' : 'B-8-MP,B-16' }" --classnet_name "TransformerPooling" --datamodule_name "cifar10-rot-test"

Testing

  • Model artifacts are listed in data/models, divided according to dataset.
  • See testing.ipynb for how to load and test a model.
  • For the SWN-GCN evaluation, the same models as for mnist-rot-test and cifar10-rot-test were used, so their training datasets are identical.

Datasets

The datasets are downloaded automatically on first use; see custom_datasets.py for the direct links.
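The download-on-first-use pattern can be sketched as follows (a hypothetical helper for illustration; the actual logic, paths, and URLs in custom_datasets.py will differ). The fetch step is injected as a callable so the caching logic works without network access:

```python
from pathlib import Path
from typing import Callable


def ensure_dataset(data_dir: str, name: str,
                   fetch: Callable[[Path], None]) -> Path:
    """Return the local path of a dataset, downloading it only if missing.

    fetch(target) is called at most once and must write the file at target,
    e.g. urllib.request.urlretrieve(url, target) in real use.
    """
    target = Path(data_dir) / name
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)
        fetch(target)  # download only on a cache miss
    return target
```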

Troubleshooting

libcublasLt.so.11

If Python cannot find libcublasLt.so.11 at import time, point the dynamic loader at the Conda libraries:

export LD_LIBRARY_PATH=~/miniconda3/lib:"$LD_LIBRARY_PATH"

About

Improved Harmonic Networks
