This repository presents code for our paper:
Cluster-aware Semi-supervised Learning: Relational Knowledge Distillation Provably Learns Clustering.
Yijun Dong*, Kevin Miller*, Qi Lei, Rachel Ward. NeurIPS 2023.
$ conda env create -f environment.yml -n rkd
$ conda activate rkd
- The relative paths for the datasets, the pretrained teacher models/features, and the pre-allocated directory for results are configured as follows.
..
|-- cifar10_pretrained    # teacher models pretrained on CIFAR-10
|-- data                  # datasets: cifar-10/100
|-- pretrained            # teacher models
|   |-- cifar10
|   |   |-- densenet161_cifar10_dim10.pt            # pretrained teacher features
|   |   |-- densenet161_cifar1010_active-fl_40.npy  # coreset labeled samples selected via StochasticGreedy
|   |   |-- ...
|   |-- cifar100
|   |   |-- resnet50w5_swav_dim1000.pt              # pretrained teacher features
|   |   |-- resnet50w5_swav1000_active-fl_400.npy   # coreset labeled samples selected via StochasticGreedy
|   |   |-- ...
|-- result                # experiment results
|-- Semi_Supervised_Knowledge_Distillation          # main implementation (this repo)
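Since the scripts read from and write to these relative paths, the directories should exist before running anything. A minimal sketch for pre-allocating them (the exact path set is an assumption read off the tree above):

import os

# Pre-allocate the relative layout assumed above (path names are assumptions).
for d in ["../data", "../pretrained/cifar10", "../pretrained/cifar100", "../result"]:
    os.makedirs(d, exist_ok=True)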
- CIFAR-10 pretrained models: For CIFAR-10 experiments, we use in-distribution teacher models pretrained on the same dataset (CIFAR-10), based on the PyTorch models trained on CIFAR-10, as follows:
$ cd ..
$ git clone https://github.com/huyvnphan/PyTorch_CIFAR10.git
$ mv PyTorch_CIFAR10 cifar10_pretrained
$ cd cifar10_pretrained
$ python train.py --download_weights 1
Alternatively, one can download the pretrained weights directly from the Google Drive link provided in PyTorch_CIFAR10 and unzip the file in ../cifar10_pretrained/cifar10_models/.
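Once the weights are in place, a teacher can be instantiated for feature extraction, e.g. (a minimal sketch assuming the module layout of PyTorch_CIFAR10 cloned and renamed as above; cifar10_models is that repo's package, not part of this codebase):

import sys
sys.path.append("../cifar10_pretrained")  # assumes the clone/rename above

from cifar10_models.densenet import densenet161

teacher = densenet161(pretrained=True)  # loads the downloaded CIFAR-10 weights
teacher.eval()  # freeze the teacher for inference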
- CIFAR-100 pretrained models: For CIFAR-100 experiments, we use out-of-distribution teacher models pretrained on a different dataset (ImageNet) via unsupervised contrastive learning (SwAV), based on the official PyTorch implementation and pretrained models for SwAV, as follows:
import torch
model = torch.hub.load('facebookresearch/swav:main', 'resnet50w5')
- Inference of teacher features on the pretrained teacher models:
$ bash teach.sh
For CIFAR-10 features extracted with a DenseNet-161 teacher pretrained with supervision on CIFAR-10:
$ python teach.py --dataset cifar10 --teacher_arch densenet161 --teacher_pretrain cifar10
For CIFAR-100 features extracted with a ResNet-50 of width x5 (i.e., resnet50w5) pretrained on ImageNet via SwAV:
$ python teach.py --dataset cifar100 --teacher_arch resnet50w5 --teacher_pretrain swav
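teach.py handles this end to end; conceptually, teacher-feature inference is a forward pass of the frozen teacher over the dataset, e.g. (a hedged sketch, not the actual teach.py logic; the normalization constants and the output file name are assumptions matching the layout above):

import sys
sys.path.append("../cifar10_pretrained")
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from cifar10_models.densenet import densenet161  # PyTorch_CIFAR10 package

# Standard CIFAR-10 statistics (assumption: what the teacher was trained with).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
dataset = datasets.CIFAR10("../data", train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=256, shuffle=False)

teacher = densenet161(pretrained=True).eval()
feats = []
with torch.no_grad():
    for x, _ in loader:
        feats.append(teacher(x))  # 10-dim outputs, matching *_dim10.pt above
torch.save(torch.cat(feats), "../pretrained/cifar10/densenet161_cifar10_dim10.pt")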
- Our training code follows fbuchert's and kekmodel's unofficial PyTorch implementations of FixMatch.
- CIFAR-10 experiments
$ bash train_cifar10.sh
- CIFAR-100 experiments
$ bash train_cifar100.sh
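For intuition about the relational objective in the paper's title, here is a minimal sketch of one common pairwise-similarity form of relational knowledge distillation (an illustration only, not the repository's exact loss):

import torch
import torch.nn.functional as F

def rkd_similarity_loss(student_feats: torch.Tensor, teacher_feats: torch.Tensor) -> torch.Tensor:
    # Match the batch-wise pairwise cosine-similarity structure of the two
    # feature sets, so the student inherits the teacher's cluster geometry.
    s = F.normalize(student_feats, dim=1)
    t = F.normalize(teacher_feats, dim=1)
    return F.mse_loss(s @ s.T, t @ t.T)

In a semi-supervised run, a term of this kind would typically be combined with the labeled and FixMatch consistency losses on each batch.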
- fbuchert: Unofficial PyTorch implementation of FixMatch
- kekmodel: Unofficial PyTorch implementation of FixMatch
- PyTorch models trained on CIFAR-10 dataset
- PyTorch implementation and pretrained models for SwAV
- Official TensorFlow implementation of FixMatch
- Unofficial PyTorch implementation of MixMatch
- Unofficial PyTorch Reimplementation of RandAugment
- PyTorch image models
@article{sohn2020fixmatch,
title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence},
author={Sohn, Kihyuk and Berthelot, David and Carlini, Nicholas and Zhang, Zizhao and Zhang, Han and Raffel, Colin A and Cubuk, Ekin Dogus and Kurakin, Alexey and Li, Chun-Liang},
journal={Advances in neural information processing systems},
volume={33},
pages={596--608},
year={2020}
}
@inproceedings{caron2020unsupervised,
title={Unsupervised Learning of Visual Features by Contrasting Cluster Assignments},
author={Caron, Mathilde and Misra, Ishan and Mairal, Julien and Goyal, Priya and Bojanowski, Piotr and Joulin, Armand},
booktitle={Proceedings of Advances in Neural Information Processing Systems (NeurIPS)},
year={2020}
}
@inproceedings{dong2023clusteraware,
title={Cluster-aware Semi-supervised Learning: Relational Knowledge Distillation Provably Learns Clustering},
author={Yijun Dong and Kevin Miller and Qi Lei and Rachel Ward},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023}
}