Code for the paper: "AdaMatch: A Unified Approach to Semi-Supervised Learning and Domain Adaptation" by David Berthelot, Rebecca Roelofs, Kihyuk Sohn, Nicholas Carlini, and Alex Kurakin.
This is not an officially supported Google product.
sudo apt install python3-dev python3-virtualenv python3-tk imagemagick
virtualenv -p python3 --system-site-packages ~/jax3
. ~/jax3/bin/activate
# Install dependencies (replace with your installed CUDA version)
CUDA_VERSION=11.2
pip install --upgrade -r requirements.txt
pip install -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g`
export ML_DATA="path to where you want the datasets saved"
export PYTHONPATH=$PYTHONPATH:.
# Use this config if you won't want JAX to take the whole GPU memory
export XLA_PYTHON_CLIENT_PREALLOCATE=false
# Download datasets and save them as TF records
CUDA_VISIBLE_DEVICES= python scripts/create_datasets.py
# Alternatively if your machine has many cores, this parallel version is much faster to run.
bash ./runs/create_datasets.sh
--logdir
where the results are saved.--uratio
ratio of unlabeled data / label data.--augment
typically can take two values (weak,strong) augmentation and a prefix (Control-Theory Augment (CTA) introduced in ReMixMatch or no prefix). Examples:--augment=\(sm,smc\)
weak augmentation is shift and mirror, strong augmentation is shift, mirror and cutout.--augment=CTA\(sm,sm\)
weak augmentation is shift and mirror, strong augmentation CTA on top of shift and mirror.
--dataset
dataset to use with desired size, example:domainnet32
(32x32 images)--source
source subset (for exampleclipart
)--target
target subset to which to adapt the source (for examplequickdraw
)
Note: for SSDA, the target takes a different form.
--target
target subset to which to adapt the source and how many label per class to use and what random seed to use (for examplequickdraw(3,seed=2)
means use 3 labels per class picked at random using random seed 2)
--dataset
combines both dataset and source from SSDA in a single option. For example:domainnet32_quickdraw(3,seed=2)
.
# Baseline: source and target must be the same. Additionally, one can specify extra test sets.
python fully_supervised/baseline.py --dataset=domainnet32 --source=clipart --target=clipart\
--logdir experiments/2021/02.12-32 --augment=CTA\(sm,sm\)\
--test_extra=clipart,infograph,quickdraw,real,sketch,painting
# Baseline: does nothing for unlabeled except running it with labeled as a single batch through the network.
python domain_adaptation/baseline.py --dataset=domainnet32 --source=clipart --target=quickdraw\
--logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
python domain_adaptation/baseline.py --dataset=domainnet32 --source=clipart --target=quickdraw\
--logdir experiments/2021/02.12-32 --uratio=1 --augment=\(sm,smc\)
# FixMatch
python domain_adaptation/fixmatch_da.py --dataset=domainnet32 --source=clipart --target=quickdraw\
--logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
python domain_adaptation/fixmatch_da.py --dataset=domainnet32 --source=clipart --target=quickdraw\
--logdir experiments/2021/02.12-32 --uratio=1 --augment=\(sm,smc\)
# AdaMatch
python domain_adaptation/adamatch.py --dataset=domainnet32 --source=clipart --target=quickdraw\
--logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
python domain_adaptation/adamatch.py --dataset=domainnet32 --source=clipart --target=quickdraw\
--logdir experiments/2021/02.12-32 --uratio=1 --augment=\(sm,smc\)
# NoisyStudent
## Teacher
python domain_adaptation/noisy_student.py --dataset=domainnet32 --source=clipart --target=quickdraw\
--logdir experiments/2021/02.12-32 --pseudo_label_th=0.9 --augment=CTA\(sm,sm\)\
--id=0
## Student
python domain_adaptation/noisy_student.py --dataset=domainnet32 --source=clipart --target=quickdraw\
--logdir experiments/2021/02.12-32 --pseudo_label_th=0.9 --augment=CTA\(sm,sm\) --id=1\
--pseudo_label_file=experiments/2021/03.1-32/DA/domainnet32/clipart/quickdraw/CTA\(sm,sm\)/NoisyStudent/archwrn28-2_batch64_lr0.03_lr_decay0.25_wd0.001/0/predictions.npy
# MCD
python domain_adaptation/mcd.py --dataset=domainnet64 --source=clipart --target=quickdraw\
--uratio=1 --arch=wrn28-2 --augment='CTA(sm,sm)' --train_mimg=8 --logdir experiments/2021/02.12-32 --lr_decay 0.25
Use no
in front of the subdomain to use all domains but the one concerned.
python domain_adaptation/adamatch.py --dataset=domainnet32 --source=no_quickdraw --target=quickdraw\
--logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
# FixMatch
python semi_supervised_domain_adaptation/fixmatch_da.py --dataset=domainnet32 --source=clipart\
--target=quickdraw\(10,seed=1\)\
--logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
# AdaMatch
python semi_supervised_domain_adaptation/adamatch.py --dataset=domainnet32 --source=clipart\
--target=quickdraw\(10,seed=1\)\
--logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
# FixMatch
python semi_supervised/fixmatch_da.py --dataset=domainnet32_quickdraw\(10,seed=1\)\
--logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
# AdaMatch
python semi_supervised/adamatch.py --dataset=domainnet32_quickdraw\(10,seed=1\)\
--logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
tensorboard --logdir experiments