Improving Consistency-Based Semi-Supervised Learning with Weight Averaging

This repository contains the code for our paper Improving Consistency-Based Semi-Supervised Learning with Weight Averaging, which achieves the best known performance for semi-supervised learning on CIFAR-10 and CIFAR-100. By analyzing the geometry of training objectives that involve consistency regularization, we significantly improve the Mean Teacher model (Tarvainen and Valpola, NIPS 2017) and the Pi Model (Laine and Aila, ICLR 2017) using stochastic weight averaging (SWA) and our proposed variant, fastSWA, which reduces prediction error faster.

The BibTeX entry for the paper is:

@article{athiwaratkun2018improving,
  title={Improving Consistency-Based Semi-Supervised Learning with Weight Averaging},
  author={Athiwaratkun, Ben and Finzi, Marc and Izmailov, Pavel and Wilson, Andrew Gordon},
  journal={arXiv preprint arXiv:1806.05594},
  year={2018}
}

Preparing Packages and Data

The code runs on Python 3 with PyTorch 0.3. The following packages are also required:

pip install scipy tqdm matplotlib pandas msgpack

Then prepare CIFAR-10 and CIFAR-100 with the following commands:

./data-local/bin/prepare_cifar10.sh
./data-local/bin/prepare_cifar100.sh

Semi-Supervised Learning with fastSWA

We provide training scripts in the experiments directory. To replicate the results for CIFAR-10 using the Mean Teacher model on 4000 labels with a 13-layer CNN, run the following:

python experiments/cifar10_mt_cnn_short_n4k.py

Similarly, for CIFAR-100 with 10k labels:

python experiments/cifar100_mt_cnn_short_n10k.py

The results are saved to the directories results/cifar10_mt_cnn_short_n4k and results/cifar100_mt_cnn_short_n10k. To plot the accuracy versus epoch, run

python read_log.py --pattern results/cifar10_*n4k --cutoff 84 --interval 2 --upper 92
python read_log.py --pattern results/cifar100_*n10k --cutoff 54 --interval 4 --upper 70

Figure 1. CIFAR-10 Test Accuracy of the Mean Teacher Model with SWA and fastSWA using a 13-layer CNN and 4000 labels.

Figure 2. CIFAR-100 Test Accuracy of the Mean Teacher Model with SWA and fastSWA using a 13-layer CNN and 10000 labels.

The experiments directory also contains scripts for ResNet-26 with Shake-Shake regularization on CIFAR-10 and CIFAR-100, as well as scripts for other label settings.

fastSWA and SWA Implementation

fastSWA can be incorporated into training with just a few lines of code. First, we create a replica of the model (which can be set to require no gradients to save memory) and initialize the weight averaging object that wraps it.

fastswa_net = create_model(no_grad=True)
fastswa_net_optim = optim_weight_swa.WeightSWA(fastswa_net)
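
To illustrate what such a weight averaging object keeps track of, here is a minimal pure-Python sketch of an equal-weight running average. The class name `RunningWeightAverage` and the dict-of-lists weight format are illustrative assumptions; this is not the repository's `WeightSWA` class, which operates on PyTorch parameters.

```python
class RunningWeightAverage:
    """Equal-weight running average of model weights (hypothetical helper)."""

    def __init__(self):
        self.num_models = 0   # number of weight snapshots averaged so far
        self.average = {}     # parameter name -> list of averaged values

    def update(self, weights):
        """Fold one snapshot (name -> list of floats) into the average."""
        self.num_models += 1
        if self.num_models == 1:
            # The first snapshot simply initializes the average.
            self.average = {name: list(values) for name, values in weights.items()}
            return
        inv = 1.0 / self.num_models
        for name, values in weights.items():
            avg = self.average[name]
            for i, value in enumerate(values):
                # avg <- avg + (value - avg) / n
                avg[i] += inv * (value - avg[i])


averager = RunningWeightAverage()
for snapshot in ({"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}, {"w": [5.0, 6.0]}):
    averager.update(snapshot)
print(averager.average)  # {'w': [3.0, 4.0]}
```

The incremental form avoids storing every snapshot: only the current average and a counter are needed, which is why a single extra copy of the model suffices.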

Then, the fastSWA model can be updated every fastswa_freq epochs. Note that after updating the weights, we need to update Batch Normalization running average variables by passing the training data through the fastSWA model.

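# `model` is assumed to be the network being trained; its current weights are folded into the average.
start_epoch = args.epochs - args.cycle_interval  # first epoch at which averaging begins
if epoch >= start_epoch and (epoch - start_epoch) % fastswa_freq == 0:
  fastswa_net_optim.update(model)
  update_batchnorm(fastswa_net, train_loader, train_loader_len)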
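
To make the update schedule concrete, the following sketch lists the epochs at which the average would be updated, assuming the counter starts at the first averaging epoch, `epochs - cycle_interval`. The function name `fastswa_update_epochs`, the `last_epoch` parameter, and the example settings are hypothetical and chosen only for illustration.

```python
def fastswa_update_epochs(epochs, cycle_interval, fastswa_freq, last_epoch):
    """Return the epochs in [0, last_epoch] at which the average is updated."""
    start = epochs - cycle_interval  # first epoch eligible for averaging
    return [
        epoch
        for epoch in range(last_epoch + 1)
        if epoch >= start and (epoch - start) % fastswa_freq == 0
    ]


# Hypothetical settings: 180 epochs, 30-epoch cycles, update every 3 epochs.
print(fastswa_update_epochs(epochs=180, cycle_interval=30, fastswa_freq=3, last_epoch=165))
# [150, 153, 156, 159, 162, 165]
```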

For more details, see main.py.
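
The Batch Normalization refresh mentioned above is needed because the averaged weights were never used in a forward pass, so their running statistics are stale. The sketch below shows one common way to re-estimate such statistics for a single scalar feature by streaming over all batches. It is an assumption-laden illustration in pure Python: the function name `reestimate_running_stats` is hypothetical, and the repository's `update_batchnorm` may differ in its details.

```python
def reestimate_running_stats(batches):
    """Recompute running mean/variance of one scalar feature over all batches."""
    running_mean, running_var, seen = 0.0, 0.0, 0
    for batch in batches:
        size = len(batch)
        batch_mean = sum(batch) / size
        # Biased (population) variance of this batch.
        batch_var = sum((x - batch_mean) ** 2 for x in batch) / size
        # Cumulative momentum: each batch contributes in proportion to its size.
        momentum = size / (seen + size)
        running_mean += momentum * (batch_mean - running_mean)
        running_var += momentum * (batch_var - running_var)
        seen += size
    return running_mean, running_var


mean, var = reestimate_running_stats([[1.0, 3.0], [5.0, 7.0]])
print(mean, var)  # 4.0 1.0
```

Note that the resulting variance is the size-weighted average of per-batch variances, which mirrors how running statistics are accumulated batch by batch rather than the variance of the whole dataset.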

Note: the code is adapted from https://github.com/CuriousAI/mean-teacher/tree/master/pytorch