How sensitive is PAWS to batch size? #5

Closed
frank-xwang opened this issue May 3, 2021 · 12 comments

@frank-xwang

frank-xwang commented May 3, 2021

Hi, thanks for sharing the code! I am curious about PAWS' sensitivity to batch size. Have you tried experimenting with smaller batch sizes (such as 256 or 512) that 8 GPUs can afford on ImageNet? Thanks. @MidoAssran

@MidoAssran
Contributor

Hi @frank-xwang,

We didn't explore smaller batches, but I'm happy to help if you're interested in investigating this. In general, the loss should be relatively robust, but the size of the support set does make a difference (as per the ablation in Section 7). Thus, you may need longer training with small batches.

@frank-xwang
Author

Hello, thank you for your reply. I think a batch-size ablation study would be very interesting for researchers in academic groups that do not have as many computing resources. It would be great if you could provide this kind of ablation study.
Also, it seems that we need to install Slurm to run your code on ImageNet, which requires sudo permissions :-(. Could you please release a version of the code, or a main file, that can run directly on a single machine without installing Slurm? Thanks a lot!

@MidoAssran
Contributor

I'll get back to you about the batch-size ablation, but unfortunately it's unlikely I'll be able to get to it soon. As for a version that doesn't require Slurm, you can launch your ImageNet jobs with "main.py" instead of "main_distributed.py", and that should work on a single GPU without Slurm! For example:

python main.py \
  --sel paws_train \
  --fname configs/paws/imgnt_train_1GPU.yaml

@frank-xwang
Author

frank-xwang commented May 18, 2021

Awesome, thank you! Regarding the main file, sorry for being unclear: I meant one machine with 8 GPUs. Do we have to use main_distributed.py? Is there a main file that can do distributed training on 8 GPUs without Slurm?

@MidoAssran
Contributor

Oh yes, I see what you mean. I just pushed a change so that you can now run main.py on several GPUs of a multi-GPU machine; just specify the devices as command-line arguments. For example, to run training on 8 GPUs, specify the devices like so:

python main.py \
  --sel paws_train \
  --fname configs/paws/imgnt_train_8GPU.yaml \
  --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7

@frank-xwang
Author

Great! Thanks!

@Ir1d

Ir1d commented May 27, 2021

@frank-xwang Hi, did you manage to run it successfully on 8 GPUs? Could you share your training time?

@frank-xwang
Author

Hi, after reducing "unsupervised_batch_size" and "supervised_imgs_per_class", I can run it on 4 V100 GPUs. Each epoch takes approximately 0.8 hours. But reducing the batch size may hurt performance, which still needs to be verified once the experiment completes.
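For concreteness, here is a minimal sketch of how these two YAML fields translate into effective (global) batch sizes when unique_classes_per_rank is enabled; the per-GPU values below are illustrative assumptions, not the exact settings used in this run:

# Minimal sketch; the per-GPU values are illustrative assumptions.
num_gpus = 4                    # e.g. the 4 x V100 setup described above
unsupervised_batch_size = 32    # per-GPU unlabeled images (YAML: unsupervised_batch_size)
classes_per_batch = 70          # per-GPU support classes (YAML: classes_per_batch)
supervised_imgs_per_class = 3   # labeled images per class (YAML: supervised_imgs_per_class)

# With unique_classes_per_rank: true, every GPU samples its own classes, so the
# effective (global) batch sizes grow linearly with the number of GPUs.
global_unsupervised = num_gpus * unsupervised_batch_size                    # 4 * 32 = 128
global_support = num_gpus * classes_per_batch * supervised_imgs_per_class   # 4 * 70 * 3 = 840

print(f"global unsupervised batch: {global_unsupervised}")
print(f"global support set size: {global_support}")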

@CloudRR

CloudRR commented Jun 18, 2021

Hi, after reducing "unsupervised_batch_size" and "supervised_imgs_per_class", I can run it on 4 V100 GPUs. Each epoch takes approximately 0.8 hours. But reducing the batch size may hurt performance, which still needs to be verified once the experiment completes.
@frank-xwang Hi, have you finished your experiment on 4 V100 GPUs? I also want to run the experiment on 8 V100 GPUs, but I am a little worried about the performance and the speed. Thanks a lot!

@frank-xwang
Author

Hi @CloudRR, I tried several hyperparameter settings but failed to reproduce the reported results with 4 V100 GPUs. The speed is not bad (training one epoch takes about 1 hour), but it seems that PAWS is also sensitive to batch size, as has been observed for many self-supervised learning methods.

@Ir1d

Ir1d commented Jun 29, 2021

Same here. I couldn't reproduce the results with 4 GPUs, and also saw about 1 h/epoch.

@MidoAssran
Contributor

MidoAssran commented Jun 29, 2021

Hi,

I've had a lot on my plate, but I did manage to try out a PAWS run on ImageNet with a small batch-size, and it essentially reproduces the large-batch numbers.

Using 8 V100 GPUs for 100 epochs with 10% of ImageNet labels, I get

  • 1h/epoch
  • 70.2% top-1

This top-1 accuracy is consistent with the ablation in the bottom row of Table 4 of the paper (similar support set, but much larger batch size).

Here is the config I used to produce this result when running on 8 GPUs. To explain some of the choices:

  • Roughly square-root scaling of the learning rate (though I haven't explored this much, and other learning rates might yield better performance); a short sketch of this scaling follows the config below.
  • I also set me_max: false. With a small batch-size, it's not clear to me that using me-max regularization makes sense, so I turned it off.
  • With 8 GPUs, the unsupervised batch size is 256 (8 x 32) and the support batch size is 1680 (560 classes (8 x 70) with 3 images per class). I tried to use as large a support set as possible, since the ablation in Table 4 shows that larger support sets lead to better performance.

All other hyper-parameters are identical to the large-batch setup.

criterion:
  classes_per_batch: 70
  me_max: false
  sharpen: 0.25
  supervised_imgs_per_class: 3
  supervised_views: 1
  temperature: 0.1
  unsupervised_batch_size: 32
data:
  color_jitter_strength: 1.0
  data_seed: null
  dataset: imagenet
  image_folder: imagenet_full_size/061417/
  label_smoothing: 0.1
  multicrop: 6
  normalize: true
  root_path: datasets/
  subset_path: imagenet_subsets
  unique_classes_per_rank: true
  unlabeled_frac: 0.90
logging:
  folder: /path_to_save_models_and_logs/
  write_tag: paws
meta:
  copy_data: true
  device: cuda:0
  load_checkpoint: false
  model_name: resnet50
  output_dim: 2048
  read_checkpoint: null
  use_fp16: true
  use_pred_head: true
optimization:
  epochs: 100
  final_lr: 0.0012
  lr: 1.2
  momentum: 0.9
  nesterov: false
  start_lr: 0.3
  warmup: 10
  weight_decay: 1.0e-06
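
To make the rough square-root learning-rate scaling above concrete, here is a minimal sketch; only the small-batch size (8 x 32 = 256 unlabeled images) and the resulting lr: 1.2 come from the config above, while the large-batch reference value is a placeholder assumption, not a number taken from the paper or the repo:

import math

# Only small_batch comes from the config above (8 GPUs x 32 unlabeled images per GPU);
# the large-batch reference is a placeholder assumption.
small_batch = 256
reference_batch = 4096   # assumption: batch size of the large-batch setup being scaled down

scale = math.sqrt(small_batch / reference_batch)   # sqrt(1/16) = 0.25
print(f"square-root lr scale factor: {scale:.2f}")
# A large-batch learning rate base_lr would then become roughly base_lr * scale,
# with some further hand-tuning (the config above ends up with lr: 1.2).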
