feizc/Meta-Ensemble

Meta-Ensemble Parameter Learning

This is the PyTorch implementation for training and inference of the WeightFormer network, as described in:

Meta-Ensemble Parameter Learning

In this work, we introduce a new task, referred to as meta-ensemble parameter learning, which aims to directly predict the parameters of a knowledge-distillation (student) model from the parameters of the base learners together with a small part of the training dataset.
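The task can be written down compactly. The notation below is our own and is not taken from the paper: K base learners with parameters \theta_1, ..., \theta_K, a parameter generator g_\phi (e.g. WeightFormer), a student network f, a small training subset \mathcal{D}_s, and a generic training loss \ell. The generator is trained so that the student built from its predicted parameters performs well on \mathcal{D}_s:

```latex
\hat{\theta}_S = g_\phi(\theta_1, \ldots, \theta_K), \qquad
\phi^\star = \arg\min_{\phi} \frac{1}{|\mathcal{D}_s|}
  \sum_{(x, y) \in \mathcal{D}_s}
  \ell\big(f(x;\, g_\phi(\theta_1, \ldots, \theta_K)),\, y\big)
```

At inference time, only a forward pass of g_\phi is needed to obtain \hat{\theta}_S for a new set of base learners, instead of running a full distillation procedure.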

🔥 WeightFormer

We introduce WeightFormer, a model that directly predicts the parameters of the distilled student model. Its architecture takes inspiration from the Transformer and incorporates three key novelties that imitate the characteristics of a model ensemble: cross-layer information flow, a learnable attention mask, and a shift consistency limitation.
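To make the idea of a learnable attention mask concrete, here is a minimal sketch assuming the mask is a trainable additive bias on the attention logits between teacher tokens. The class name `LearnableMaskAttention`, the single-head design, and the zero initialization are our own assumptions for illustration, not the repository's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LearnableMaskAttention(nn.Module):
    """Hypothetical single-head self-attention with a learnable additive mask.

    The mask is a trainable (num_tokens x num_tokens) bias added to the
    attention logits, so the model can learn how strongly each teacher
    token attends to every other one.
    """

    def __init__(self, dim: int, num_tokens: int):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        # Zero init: behaves like plain self-attention at the start of training.
        self.mask = nn.Parameter(torch.zeros(num_tokens, num_tokens))
        self.scale = dim ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_tokens, dim)
        logits = self.q(x) @ self.k(x).transpose(-2, -1) * self.scale
        # The mask broadcasts over the batch dimension.
        attn = F.softmax(logits + self.mask, dim=-1)
        return attn @ self.v(x)


torch.manual_seed(0)
layer = LearnableMaskAttention(dim=16, num_tokens=3)
out = layer(torch.randn(2, 3, 16))
print(out.shape)
```

Because the mask is an `nn.Parameter`, it receives gradients like any other weight and is updated by the same optimizer.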

[Figure: WeightFormer architecture. Overview of WeightFormer generating the weights of one layer.]
The Transformer-based weight generator receives the concatenated weight matrices of the teacher models, along with model ID and position information, and produces the corresponding layer weights. Once generated, the predicted student model is used to compute the loss on the training set, and the gradients of this loss are used to update the weights of WeightFormer.
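The generation loop described above can be sketched in a few lines. This is a simplified illustration, not the repository's code, and rests on several assumptions: the student layer is a single linear layer, each row of each teacher's weight matrix becomes one token, model-ID and row-position embeddings are added to the tokens, the encoder outputs are mean-pooled over teachers, and random tensors stand in for real teacher weights and data. All names and sizes (`LayerWeightGenerator`, `NUM_TEACHERS`, `MODEL_DIM`, etc.) are hypothetical, and the cross-layer flow, learnable mask, and shift consistency components are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes: 3 teachers, each with a 10x32 linear layer.
NUM_TEACHERS, OUT_DIM, IN_DIM, MODEL_DIM = 3, 10, 32, 64


class LayerWeightGenerator(nn.Module):
    """Hypothetical generator for one linear layer's weight matrix."""

    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(IN_DIM, MODEL_DIM)
        self.model_id = nn.Embedding(NUM_TEACHERS, MODEL_DIM)  # which teacher
        self.position = nn.Embedding(OUT_DIM, MODEL_DIM)       # which row
        enc_layer = nn.TransformerEncoderLayer(
            MODEL_DIM, nhead=4, dropout=0.0, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=2)
        self.proj_out = nn.Linear(MODEL_DIM, IN_DIM)

    def forward(self, teacher_weights: torch.Tensor) -> torch.Tensor:
        # teacher_weights: (K, OUT_DIM, IN_DIM); every weight row is a token.
        k = teacher_weights.shape[0]
        tokens = self.proj_in(teacher_weights)                   # (K, OUT_DIM, MODEL_DIM)
        ids = torch.arange(k).view(k, 1).expand(k, OUT_DIM)
        pos = torch.arange(OUT_DIM).view(1, OUT_DIM).expand(k, OUT_DIM)
        tokens = tokens + self.model_id(ids) + self.position(pos)
        h = self.encoder(tokens.reshape(1, k * OUT_DIM, MODEL_DIM))
        h = h.reshape(k, OUT_DIM, MODEL_DIM).mean(dim=0)         # pool over teachers
        return self.proj_out(h)                                  # (OUT_DIM, IN_DIM)


torch.manual_seed(0)
teachers = torch.randn(NUM_TEACHERS, OUT_DIM, IN_DIM)  # placeholder teacher layers
x = torch.randn(8, IN_DIM)                             # placeholder training batch
y = torch.randint(0, OUT_DIM, (8,))

gen = LayerWeightGenerator()
opt = torch.optim.Adam(gen.parameters(), lr=1e-3)

w_student = gen(teachers)            # predicted student weights
logits = F.linear(x, w_student)      # student forward pass with generated weights
loss = F.cross_entropy(logits, y)    # loss on the training batch
opt.zero_grad()
loss.backward()                      # gradients flow back into the generator
opt.step()
```

The key design point is that the student's forward pass uses the generated tensor directly (here via `F.linear`), so the training loss is differentiable with respect to the generator's parameters and only the generator is optimized.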

⚙ Dataset and Model

We support three image classification datasets for performance evaluation: CIFAR-10, CIFAR-100, and ImageNet. Download the corresponding datasets and place them under data/, or specify a different path through the command-line (argparse) arguments.

The trained checkpoints for WeightFormer will be made available on Google Drive.

🙌 Training

Training scripts for different training scenarios.

All training scripts are in the ./scripts folder; run python script_name.py for the corresponding process.

Script                   Scenario
train_vgg.py             Train a single VGGNet-11
train_resnet.py          Train a single ResNet-50
train_distillation.py    Average knowledge distillation for the model ensemble
train_mlp.py             MLP network for weight generation
train_transformer.py     WeightFormer for weight generation
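As a rough illustration of what "average knowledge distillation" for an ensemble can look like, the sketch below uses the standard soft-target distillation loss with the teachers' temperature-softened probabilities averaged into a single target. The function name `ensemble_kd_loss` and the defaults `T=4.0` and `alpha=0.9` are hypothetical choices, not values taken from `train_distillation.py`.

```python
import torch
import torch.nn.functional as F


def ensemble_kd_loss(student_logits, teacher_logits_list, targets, T=4.0, alpha=0.9):
    """Hypothetical average-ensemble distillation loss.

    The soft target is the mean of the teachers' softmax(logits / T);
    the student matches it via KL divergence, mixed with the usual
    cross-entropy on the hard labels.
    """
    teacher_probs = torch.stack(
        [F.softmax(t / T, dim=-1) for t in teacher_logits_list]
    ).mean(dim=0)
    log_student = F.log_softmax(student_logits / T, dim=-1)
    # KL(teacher || student), scaled by T^2 to keep gradient magnitudes comparable.
    kd = F.kl_div(log_student, teacher_probs, reduction="batchmean") * T * T
    ce = F.cross_entropy(student_logits, targets)
    return alpha * kd + (1.0 - alpha) * ce


torch.manual_seed(0)
student = torch.randn(4, 10, requires_grad=True)   # placeholder student logits
teachers = [torch.randn(4, 10) for _ in range(3)]  # placeholder teacher logits
targets = torch.randint(0, 10, (4,))
loss = ensemble_kd_loss(student, teachers, targets)
loss.backward()
```

Averaging probabilities (rather than logits) is one common way to combine teachers; it keeps the soft target a valid distribution regardless of each teacher's logit scale.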

For help or issues related to this package, please submit a GitHub issue.