Use of pretrained weights #5

Closed
marcelomatheusgauy opened this issue May 21, 2024 · 14 comments

Comments

@marcelomatheusgauy

marcelomatheusgauy commented May 21, 2024

Hi,

Thanks for making your code open source. I plan to use your models for tasks involving audio (voice and speech) from hospital patients with respiratory issues. We have had success with pre-trained models before, and yours seem well suited to the types of tasks we consider. As we work with Brazilian Portuguese audio, we want to test whether additional pre-training on unlabeled Brazilian Portuguese audio, on top of the weights already pre-trained on AudioSet, could improve results on the health-related downstream tasks we will consider later.

I have been able to start pre-training from scratch on a set of audio files following your instructions, using the script train_audio.py. After inspecting the code, I believe loading the pretrained weights (to perform further pre-training) involves passing a --resume argument on the command line. So I run python train_audio.py --csv_main=my_audios.csv --resume=path_to_pretrained_weights. Is this understanding correct? It returns the following error when loading the weights:

---------------------------------------------------------------------------
Traceback (most recent call last):
  File "train_audio.py", line 393, in <module>
    main(args)
  File "train_audio.py", line 283, in main
    load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, delta_epoch=0, strict=False)
  File "train_audio.py", line 171, in load_model
    checkpoint = torch.load(args.resume, map_location='cpu')
  File "/home/gauy/anaconda3/envs/ar/lib/python3.8/site-packages/torch/serialization.py", line 705, in load
    with _open_zipfile_reader(opened_file) as opened_zipfile:
  File "/home/gauy/anaconda3/envs/ar/lib/python3.8/site-packages/torch/serialization.py", line 242, in __init__
    super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
RuntimeError: Expected hasRecord("version") to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
--------------------------------------------------------------------------

I am using the m2d_vit_base-80x608p16x16-221006-mr7 weight set you provided. I have also set up the conda environment as suggested in another issue here. Do you have an idea why this error is raised? I have not made any changes to train_audio.py, I have set up my_audios.csv following your instructions on a set of Brazilian Portuguese audio files we have available, and I pass the correct path to the pretrained model.

Thanks!

@daisukelab
Collaborator

Hi, thanks for your interest. I'm glad to hear that the pre-trained weight is fairly suitable for your tasks so far.

It looks like an environment issue, such as the PyTorch version.
I found a similar report: ultralytics/yolov5#581
And I confirmed that this similar command line works:

python train_audio.py --csv_main data/files_icbhi2017.csv --resume m2d_vit_base-80x608p16x16-221006-mr7/checkpoint-300.pth

The PyTorch version is:

>>> import torch; torch.__version__
'2.1.2'

Could you update your PyTorch installation?
You can follow https://pytorch.org/get-started/locally/ to get the exact command line you need.

As you have set up the conda environment, your command could be the following if your CUDA version is <= 11.8:

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

In addition, I am willing to help with your attempt at additional (further) pre-training. I expect the pre-training settings (e.g., training and warmup epochs) will need to be adjusted in the next step.
You might also want to use background noise, which we confirmed to be effective.

python train_audio.py --csv_main data/files_icbhi2017.csv --csv_bg data/files_f_s_d_5_0_k.csv --resume m2d_vit_base-80x608p16x16-221006-mr7/checkpoint-300.pth 
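The background-noise idea can be illustrated with a minimal NumPy sketch that mixes a random crop of a noise clip into a training clip at a fixed ratio. Everything here is an assumption for illustration: the function name mix_background_noise, the linear blend (1 - ratio) * clip + ratio * noise, and mixing raw waveforms. The repository's actual handling of the background-noise CSV and --noise_ratio may use a different formula or domain (for example, spectrograms).

```python
import numpy as np


def mix_background_noise(clip: np.ndarray, noise: np.ndarray, ratio: float,
                         rng: np.random.Generator) -> np.ndarray:
    """Hypothetical sketch: blend a random crop of `noise` into `clip`.

    ratio = 0.0 returns the clip unchanged; ratio = 1.0 returns only noise.
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must be in [0, 1]")
    # Tile the noise if it is shorter than the clip, then take a random crop
    # of exactly the clip's length.
    if len(noise) < len(clip):
        reps = -(-len(clip) // len(noise))  # ceiling division
        noise = np.tile(noise, reps)
    start = rng.integers(0, len(noise) - len(clip) + 1)
    crop = noise[start:start + len(clip)]
    return (1.0 - ratio) * clip + ratio * crop


rng = np.random.default_rng(0)
speech = np.sin(np.linspace(0, 100, 16000))  # stand-in for a 1 s clip at 16 kHz
noise = rng.standard_normal(4000)            # stand-in for a shorter noise file
mixed = mix_background_noise(speech, noise, ratio=0.01, rng=rng)
```

A small ratio such as the 0.01 used later in this thread keeps the original signal dominant while still perturbing every training example.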

Anyway, let's resolve the torch.load() issue first.

@marcelomatheusgauy
Author

marcelomatheusgauy commented May 22, 2024

Thank you for your answer. I checked my commands and found that I had not realized the pretrained weights file was zipped. All I had to do was unzip the file and pass the correct path. That fixes the torch.load issue.
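For anyone hitting the same "hasRecord("version")" error, a quick check can tell an ordinary downloaded zip bundle apart from a checkpoint written by torch.save, whose zip-based format stores a '<archive>/version' record. The stdlib sketch below is a heuristic, not part of the M2D repository; the helper names and the bundle name weights.zip are hypothetical, and the demo files contain dummy bytes rather than real weights.

```python
import os
import tempfile
import zipfile


def looks_like_plain_zip(path: str) -> bool:
    """Heuristic: True for an ordinary zip archive (e.g., a downloaded weight
    bundle), False for a torch.save checkpoint, which contains a 'version' record."""
    if not zipfile.is_zipfile(path):
        return False
    with zipfile.ZipFile(path) as zf:
        return not any(name.rsplit("/", 1)[-1] == "version" for name in zf.namelist())


def extract_checkpoints(path: str, dest: str) -> list:
    """Return `path` itself if it is not a plain zip; otherwise extract the
    archive into `dest` and return every .pth file found inside."""
    if not looks_like_plain_zip(path):
        return [path]
    with zipfile.ZipFile(path) as zf:
        zf.extractall(dest)
    found = []
    for root, _dirs, files in os.walk(dest):
        found.extend(os.path.join(root, f) for f in files if f.endswith(".pth"))
    return sorted(found)


# Demo: a fake downloaded bundle and a fake torch-style checkpoint.
work = tempfile.mkdtemp()
bundle = os.path.join(work, "weights.zip")
with zipfile.ZipFile(bundle, "w") as zf:
    zf.writestr("m2d_vit_base-80x608p16x16-221006-mr7/checkpoint-300.pth", b"dummy")
torch_like = os.path.join(work, "fake_torch_checkpoint.pth")
with zipfile.ZipFile(torch_like, "w") as zf:
    zf.writestr("archive/data.pkl", b"")
    zf.writestr("archive/version", "3")

ckpts = extract_checkpoints(bundle, os.path.join(work, "extracted"))
print(ckpts)  # the extracted checkpoint-300.pth path, ready for --resume
```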

Now, if you are willing to help with further pre-training, I have a few questions:

  1. I believe background noise was mostly used as a data augmentation step for cases where the further pre-training dataset is small. In our case, we are using fairly large public databases of Brazilian Portuguese speech, which together amount to approximately 1,200 h. Do you think including background noise still helps, or is it unnecessary?
  2. I suspect further pre-training needs fewer epochs and overall different parameters (warmup epochs and perhaps others). Do you have an intuition on how to choose them?
  3. Expanding on question 2: I have easy access to a server with two 8GB GPUs, and less convenient access to another server (shared by many people, with job queues) whose nodes typically have eight 8GB GPUs. I could request more than one node, but I am unsure whether that would actually help. So the best case is eight 8GB GPUs, and preliminary tests will run on two 8GB GPUs. Since pre-training is typically very demanding, this may not be enough, in which case I will have to figure out what to do. Do you think further pre-training will need as much as the original AudioSet pre-training (i.e., 300 epochs), or could far fewer epochs suffice? On a single 8GB GPU I can run about 5 epochs per day with a batch size of 32 (which I think is the maximum possible).

@marcelomatheusgauy
Author

To further explain how your model will be used: first, we will do additional pre-training on Brazilian Portuguese speech data and test the new (and hopefully improved) model on standard Brazilian Portuguese speech tasks we will prepare. If this succeeds, the model will later be fine-tuned on audio from hospital patients with respiratory problems. The healthcare patient datasets will be very small (i.e., minutes, reaching 1-2 hours at best), so effective transfer learning is very important. Hospital data collection will most likely start in a few months, but the first part can be done now to prepare models that can be fine-tuned quickly once we get the data.

Thanks for your help!

@daisukelab
Collaborator

daisukelab commented May 23, 2024

Hi, I have summarized a guideline based on my experience.

https://github.com/nttcslab/m2d/blob/master/Guide_app.md

Based on it, quick comments for your use case are:

  • Pre-training from scratch using 2 8GB GPUs would be virtually impossible because it would take months.
  • If your data is speech or close to speech, the LibriSpeech pre-trained weights may be a better starting point. (Please note that an AudioSet weight was better for respiratory sounds in our experiments.)
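For a rough sense of why from-scratch pre-training on two 8GB GPUs would take months, here is back-of-the-envelope arithmetic using the throughput reported in this thread (about 5 epochs per day on one 8GB GPU at batch size 32, measured on that user's own data). The helper name is hypothetical, and it optimistically assumes throughput scales linearly with GPU count; real multi-GPU scaling is usually lower, so treat the results as lower bounds.

```python
def estimated_days(epochs: int, epochs_per_day_per_gpu: float, num_gpus: int,
                   scaling_efficiency: float = 1.0) -> float:
    """Wall-clock days for `epochs`, assuming throughput grows linearly with
    the number of GPUs, times an efficiency factor (<= 1 in practice)."""
    return epochs / (epochs_per_day_per_gpu * num_gpus * scaling_efficiency)


# Reported throughput: ~5 epochs/day on one 8GB GPU at batch size 32.
for epochs in (200, 300):
    print(f"{epochs} epochs on 2 GPUs: ~{estimated_days(epochs, 5, 2):.0f} days")
```

Under ideal scaling, 300 epochs on two GPUs already come to about a month, before any scaling losses, reruns, or hyperparameter searches.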

Recommendation for your #1 "we will do additional pre-training in Brazilian Portuguese speech data and test the new and perhaps improved model on standard Brazilian Portuguese speech tasks we will prepare."

Recommendation for your #2 "the model will later be fine-tuned on hospital patients suffering from respiratory problems."

  • If your pre-training succeeds, fine-tune that model.
  • Fine-tune a downloaded AudioSet pre-trained weight as it is: m2d_vit_base-80x608p16x16-221006-mr7
  • Fine-tune a downloaded LibriSpeech pre-trained weight as it is: m2d_s_vit_base-80x400p80x2-230201
  • Further pre-train and fine-tune a pre-trained model.

Further pre-training (Fur-PT) guide

A possible command line for you is:

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train_audio.py --epochs 600 --warmup_epochs 24 --resume m2d_vit_base-80x608p16x16-221006-mr7/checkpoint-300.pth --model m2d_x_vit_base --batch_size 32 --accum_iter 4 --csv_main __your__.csv --csv_bg_noise data/files_f_s_d_5_0_k.csv --noise_ratio 0.01 --save_freq 100 --eval_after 600 --seed 3 --teacher m2d_vit_base-80x608p16x16-221006-mr7/checkpoint-300.pth --blr 3e-4 --loss_off 1. --min_ds_size 10000

The following options use an existing weight for initialization (resume) and as the teacher model in the M2D-X regularization setting.

--resume m2d_vit_base-80x608p16x16-221006-mr7/checkpoint-300.pth
--teacher m2d_vit_base-80x608p16x16-221006-mr7/checkpoint-300.pth

These parameters set an effective batch size of 128, which we used in our Fur-PT.

--batch_size 32 --accum_iter 4
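The arithmetic behind these two flags can be sketched as follows. The helper names are hypothetical, and the learning-rate line assumes the MAE-style linear scaling rule (absolute lr = blr × effective batch / 256) used by many MAE-derived training scripts. This thread does not confirm that train_audio.py does the same, nor whether it also multiplies by the number of processes as MAE does; the stated 128 matches the per-process product.

```python
def effective_batch_size(batch_size: int, accum_iter: int, world_size: int = 1) -> int:
    """Samples consumed per optimizer step (and per target-encoder update)."""
    return batch_size * accum_iter * world_size


def absolute_lr(blr: float, eff_batch_size: int) -> float:
    """Assumed MAE-style linear scaling rule: lr = blr * eff_batch_size / 256."""
    return blr * eff_batch_size / 256


# --batch_size 32 --accum_iter 4 and --batch_size 64 --accum_iter 2 give the
# same per-process product of 128.
print(effective_batch_size(32, 4), effective_batch_size(64, 2))
# If the script also multiplies by the number of processes (MAE convention),
# a --nproc_per_node=2 launch would consume twice as many samples per step.
print(effective_batch_size(32, 4, world_size=2))
print(absolute_lr(3e-4, 128))
```

This is also why a larger per-GPU batch lets you lower accum_iter without changing the effective batch size.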

600 epochs (with 24 warm-up epochs) may work well, though this could be decreased.
You might also adjust the checkpoint save frequency (save_freq) and the epoch after which evaluation during pre-training starts (eval_after).

--epochs 600 --warmup_epochs 24 --save_freq 100 --eval_after 600

This option virtually increases the dataset size by repeating the list of samples. Use it if you have fewer than 5,000 samples.

 --min_ds_size 10000
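The --min_ds_size behavior described above (repeating the sample list to enlarge an epoch) can be sketched like this. The function is hypothetical. The thread does not say whether the repository repeats the whole list, samples with replacement, or trims to exactly min_ds_size, so this version simply repeats the full list until it is at least that long.

```python
import math


def expand_to_min_size(files: list, min_ds_size: int) -> list:
    """Hypothetical sketch of --min_ds_size: repeat the whole sample list until it
    has at least `min_ds_size` entries. No new audio is created."""
    if not files or len(files) >= min_ds_size:
        return list(files)
    repeats = math.ceil(min_ds_size / len(files))
    return list(files) * repeats


small = [f"clip_{i:04d}.wav" for i in range(3000)]
expanded = expand_to_min_size(small, 10000)
print(len(expanded))  # 12000: four passes over the 3000 clips
```

Because each "epoch" now contains several passes over the same files, more optimizer steps happen per epoch, which matters for the epoch-based schedules discussed in this thread.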

You can change the base learning rate with --blr, but 3e-4 should be fine with the batch size and epochs above.

@marcelomatheusgauy
Author

Thanks for your comments and help. From what I understand, I have four options:

  1. Pre-train from scratch with M2D on Brazilian Portuguese speech and later fine-tune on the specific tasks. This may require more powerful GPUs, as it is very demanding. It would provide a useful baseline for performance on the tasks.

  2. Use the AudioSet pre-trained weights and just fine-tune on the specific tasks. This does not require GPUs much better than what I already have, and it should perform well and provide a baseline.

  3. Use the LibriSpeech pre-trained weights and just fine-tune on the tasks. This will probably perform better than AudioSet on the first batch of speech-related tasks we will give the model, but I suspect it will be inferior on the actual target: tasks related to respiratory issues in hospital patients, since speech is probably not the only important factor there.

  4. Select one of the pre-trained weights (AudioSet or LibriSpeech) and perform further pre-training on Brazilian Portuguese speech, as proposed in your paper Masked Modeling Duo: Towards a Universal Audio Pre-Training Framework. For that, I will use a variant of the command line you provided in the Fur-PT section. My remaining doubt concerns the number of epochs. In your paper you used 600 epochs, but with a dataset of 10,000 samples after applying a sort of data augmentation (the --min_ds_size parameter). Our collection of Brazilian Portuguese datasets has almost 0.5 million samples and more than 1,000 h of audio, so I would expect far fewer than 600 epochs to suffice, right? In that case, it should be feasible to run the model on the two 8GB GPUs* or, if necessary, on eight 8GB GPUs without it taking too long. I am thinking 20 epochs or fewer should be enough for further pre-training on a large dataset. Lastly, is background noise needed as well? I will probably run ablations to check, as I am unsure it helps much when the dataset is already large.

*A small correction: after updating PyTorch to the latest version, I found that I can set the batch size to 64 instead of 32. So we can set accum_iter to 2 instead of 4, which should help a little.

Thank you very much for your help.

@marcelomatheusgauy
Author

A small addendum for others: to set up distributed mode, I had to adapt the command line to CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch --nproc_per_node=2 train_audio.py ...

Without torch.distributed.launch, the model does not run in distributed mode.

@daisukelab
Collaborator

A small addendum for others: to set up distributed mode, I had to adapt the command line to CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch --nproc_per_node=2 train_audio.py ...

Without torch.distributed.launch, the model does not run in distributed mode.

Yes, you're correct. My example was wrong; I have corrected it above (for others' future reference).
FYI -- usually I use the torchrun: https://pytorch.org/docs/stable/elastic/run.html
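As a rough illustration of why a plain python train_audio.py call stays single-process: launchers such as torchrun and python -m torch.distributed.launch export RANK and WORLD_SIZE to every process they start, and MAE-style training scripts check for those variables before initializing distributed mode. Assuming train_audio.py follows that pattern (not confirmed in this thread), the check looks roughly like this hypothetical helper.

```python
import os
from typing import Mapping, Optional, Tuple


def detect_distributed(env: Mapping[str, str] = os.environ) -> Optional[Tuple[int, int]]:
    """Return (rank, world_size) if a launcher exported them, else None,
    meaning the script would fall back to single-process training."""
    if "RANK" in env and "WORLD_SIZE" in env:
        return int(env["RANK"]), int(env["WORLD_SIZE"])
    return None


# A plain `python train_audio.py` run has no launcher variables:
print(detect_distributed({}))
# Roughly what the second process sees under `--nproc_per_node=2`:
print(detect_distributed({"RANK": "1", "WORLD_SIZE": "2", "LOCAL_RANK": "1"}))
```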

I will answer the questions above...

@daisukelab
Collaborator

Regarding your questions about the four options: yes, those are your options. It's good that you identified the 4th option, and I also recommend a 5th option.

First, the combination of the number of epochs and the batch size matters because we use an EMA-updated target encoder and an annealing learning rate schedule.
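To see why the effective batch size matters for the EMA target encoder, here is a toy sketch of a single EMA update, applied once per optimizer step: target <- m * target + (1 - m) * online. Plain Python lists stand in for parameter tensors, and the momentum of 0.9 is a toy value chosen so the drift is visible. The momentum M2D actually uses is not stated in this thread.

```python
from typing import List


def ema_update(target: List[float], online: List[float], momentum: float) -> List[float]:
    """One EMA step per parameter: target <- m * target + (1 - m) * online."""
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target, online)]


# Toy parameters: the target drifts slowly toward fixed online weights.
target, online = [0.0, 0.0], [1.0, -1.0]
for _ in range(3):
    target = ema_update(target, online, momentum=0.9)
print(target)  # approximately [0.271, -0.271], i.e. +/-(1 - 0.9**3)
```

Since the target moves only once per optimizer step, changing the batch size or the number of epochs changes how many times, and therefore how far, the target encoder evolves during training.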

  • The target encoder creates the training signal, and it is updated each time the samples of one effective batch (batch size × accum_iter) have been consumed. So the effective batch size matters for obtaining a useful training signal.
  • The learning rate schedule for 20 epochs is very different from that for 300 epochs. Finding effective combinations of epochs and batch size takes time, so I recommend a combination similar to ours: bs=2048 & epochs=300 for 2M-sample data, bs=2048 & epochs=1000 for 281k-sample data, and bs=128 & epochs=600 for 10k-sample data.
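One way to compare these configurations is the total number of optimizer steps, which, assuming one EMA update per step, is also the number of target-encoder updates. The sketch below assumes the last incomplete batch of each epoch is dropped. It uses the approximate sample counts quoted in this thread, including the 0.5M-sample, 20-epoch, effective-batch-128 plan proposed earlier; the helper name is hypothetical.

```python
def optimizer_steps(num_samples: int, eff_batch_size: int, epochs: int) -> int:
    """Total optimizer steps (= EMA target updates), dropping each epoch's
    last incomplete batch."""
    return (num_samples // eff_batch_size) * epochs


configs = {
    "2M samples, bs=2048, 300 ep": (2_000_000, 2048, 300),
    "281k samples, bs=2048, 1000 ep": (281_000, 2048, 1000),
    "10k samples, bs=128, 600 ep": (10_000, 128, 600),
    "0.5M samples, bs=128, 20 ep (proposed)": (500_000, 128, 20),
}
for name, (n, bs, ep) in configs.items():
    print(f"{name}: {optimizer_steps(n, bs, ep):,} steps")
```

By this count, the proposed 20-epoch run would take more steps than the 10k-sample, 600-epoch setting but fewer than the two large-scale recipes.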

Answers to your four options, and one more from me follow.

  1. Yes, pre-training from scratch to get an effective model would take time. I estimate it would need at least about 200 epochs. With LibriSpeech, I used 960 h (281k samples) for 1000 epochs.

  2. Same as below.

  3. Yes, your understanding matches mine.

  4. It's a nice idea.

    • As you have 0.5M samples, there is no need for min_ds_size.
    • One concern is the number of epochs. 20 epochs could be good enough, but the resulting learning rate schedule would differ from those I have tried. Still, it is worth a try. In addition, use fewer warmup epochs, such as --warmup_epochs 3. (I also updated --warmup_epochs in my comment above.)
    • For the background noise, you can try it later. I recommend it because it was useful for M2D-S & medical applications; it would likely be useful for you too.
  5. Similar to 4, but perform further pre-training on the final application data because it was effective in my TASLP paper.
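On the learning-rate side, a common "annealing" schedule in MAE-derived code is linear warmup followed by half-cycle cosine decay, evaluated at fractional epochs. Whether train_audio.py uses exactly this form is an assumption here, and the peak lr value below is hypothetical. The sketch only illustrates how --epochs and --warmup_epochs shape the curve, for example 600/24 versus the suggested 20/3.

```python
import math


def lr_at(epoch: float, base_lr: float, warmup_epochs: float, total_epochs: float,
          min_lr: float = 0.0) -> float:
    """Linear warmup, then half-cycle cosine decay to `min_lr` (assumed MAE-style)."""
    if epoch < warmup_epochs:
        return base_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


peak_lr = 1.5e-4  # hypothetical: blr 3e-4 at an effective batch of 128
for total, warmup in ((600, 24), (20, 3)):
    halfway = warmup + (total - warmup) / 2
    print(total, warmup, lr_at(halfway, peak_lr, warmup, total))  # ~half the peak
```

Because the curve is defined in epochs, its shape is the same in relative terms, but the number of steps spent in each phase differs. Under the step-count assumptions used elsewhere in this thread, 3 warmup epochs on about 0.5M samples at an effective batch of 128 is roughly 3 × 3,906 ≈ 11.7k steps, whereas 24 warmup epochs on 10k samples is roughly 24 × 78 ≈ 1.9k steps.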

@marcelomatheusgauy
Author

Thanks for the answer. I will be analyzing/testing how to schedule the learning rate for option 4 then. That was a very helpful comment that would have taken me time to figure out.

Option 5 is nice, and I will do it once I actually get the data. As I mentioned, the hospital patient data will be collected over the next months and might take a year or so to reach a sufficient quantity. Until then, I can work on the other four options.

I am closing the issue now as I believe it has been solved. You were a great help.

@marcelomatheusgauy
Author

By the way, one last question: do you have an intuition for what loss values to expect during pre-training? While the primary measure will be performance on downstream tasks, I am curious whether I can rule out certain runs early if they do not reach a low enough loss for M2D to have learned effective audio representations.

@daisukelab
Collaborator

daisukelab commented May 24, 2024

Regarding the loss values, I have included a log of ICBHI 2017 fur-PT here.
examples/logs/log_m2d_x_vit_base-80x200p16x4-230814-Ddffsd50ks5blr0003bs128a2nr.3-e600.out -> Fixed to example_logs.zip
As you can see, the loss would be around 0.4 in a fur-PT when using a noise ratio of 0.2. If using a noise ratio of 0.0, it would be around 0.25.

Regarding the performance check, we use linear evaluation with our evaluator, EVAR. Please check "2. Evaluating M2D" in README.md for details.
I recommend the spcv2 and cremad tasks for your purpose.

Regarding speech tasks, the 16x16 model does not perform well on tasks such as ASR and phoneme recognition.
Speech tasks that require phoneme-level information, such as ASR, would need finer time resolution, while frequency resolution is not needed. Therefore, the 80x2 speech models are suitable for this purpose. (But that may be another story for your use case.)

Lastly, in our experiments on ICBHI 2017, the 16x4 (40ms time frame) model outperformed the 16x16 model.
Therefore, I also recommend the weight m2d_vit_base-80x200p16x4-230529. Note that this model consumes more memory, so you might try 16x16 first. The 16x4 model would be a future option for you.
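The resolution trade-off can be made concrete with a little arithmetic. The sketch reads model names like 80x608p16x16 as <mel bins>x<frames>p<patch freq>x<patch time> and assumes a 10 ms frame hop, inferred from the "16x4 (40ms time frame)" remark above. Both readings are assumptions rather than documented facts, and the helper names are hypothetical.

```python
from typing import NamedTuple


class PatchGrid(NamedTuple):
    freq_patches: int
    time_patches: int
    patch_ms: float  # time span covered by one patch


def patch_grid(n_mels: int, n_frames: int, patch_f: int, patch_t: int,
               hop_ms: float = 10.0) -> PatchGrid:
    """Patch layout of an n_mels x n_frames log-mel input split into
    patch_f x patch_t patches (assumed 10 ms per frame)."""
    return PatchGrid(n_mels // patch_f, n_frames // patch_t, patch_t * hop_ms)


for name, args in {
    "80x608p16x16": (80, 608, 16, 16),
    "80x200p16x4": (80, 200, 16, 4),
    "80x400p80x2": (80, 400, 80, 2),
}.items():
    g = patch_grid(*args)
    print(name, g, "tokens:", g.freq_patches * g.time_patches)
```

Under these assumptions, 16x16 patches span 160 ms each, 16x4 patches span 40 ms, and the 80x2 speech model uses full-band 20 ms patches. This matches the point that phoneme-level tasks favor time resolution over frequency resolution. The 16x4 layout also yields more tokens per second of audio, which is consistent with its higher memory use.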

Please let me know if you have trouble with the EVAR setup.

@marcelomatheusgauy
Author

Thank you for your answer. So the losses I see seem to be more or less in line with what you observed, though it might be possible to do better as I am using a larger dataset and the learning rate schedule could be improved.
*FYI: the link to the ICBHI 2017 fur-PT log is broken for me. I also could not find it in the examples folder, as there is no logs folder.

I will set up EVAR as well as additional Brazilian Portuguese speech tasks (classification tasks such as emotion, speaker, or gender recognition). I intend to avoid ASR tasks, since my understanding matches yours: they need finer time resolution. I will ask again if I have problems with EVAR.

I also found out that I can access a server with eight A100 GPUs, which is probably enough for pre-training from scratch, as well as for most configurations of your model I can imagine.

An additional question not directly related to your work: in our experiments with multiple pre-trained models, we have found regression tasks to be typically hard, and a task often becomes easier when it is recast as some form of multiclass classification. This seems to be in line with what other researchers have told us: namely, it is better to turn a regression task into a multiclass classification. I am curious whether you have encountered similar issues and what your opinion is. My impression is that the losses we use (the ones I tried were masked reconstruction losses, while yours is a little different) give the model information about the data distribution but not about how the space of data points evolves from one state to another. That information is quite possibly necessary for a regression task, because the labels carry a metric. I wonder whether some additional loss term could help models also perform well on regression tasks via transfer learning.
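One common way to turn a regression target into a multiclass problem, as described above, is quantile binning. The stdlib sketch below is a generic illustration, not a method from M2D or from the researchers mentioned. The helper names and the choice of quantile edges (rather than equal-width bins) are assumptions, and the per-class means give a simple way to map predicted classes back to a numeric estimate.

```python
import bisect
from typing import List, Sequence


def make_bin_edges(values: Sequence[float], n_bins: int) -> List[float]:
    """Quantile-style inner edges so each class gets roughly equal counts."""
    ordered = sorted(values)
    return [ordered[(i * len(ordered)) // n_bins] for i in range(1, n_bins)]


def to_class(value: float, edges: Sequence[float]) -> int:
    """Map a continuous target to a class index in [0, len(edges)]."""
    return bisect.bisect_right(edges, value)


def class_means(values: Sequence[float], edges: Sequence[float]) -> List[float]:
    """Mean target per class, usable to decode a predicted class to a number."""
    sums = [0.0] * (len(edges) + 1)
    counts = [0] * (len(edges) + 1)
    for v in values:
        c = to_class(v, edges)
        sums[c] += v
        counts[c] += 1
    return [s / n if n else float("nan") for s, n in zip(sums, counts)]


targets = [float(x) for x in range(100)]  # toy continuous labels
edges = make_bin_edges(targets, 4)
print(edges, class_means(targets, edges))
```

The number of bins is a trade-off: more bins keep more of the label metric, while fewer bins give each class more training examples.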

@daisukelab
Collaborator

Please find logs here: example_logs.zip I added M2D, M2D-S logs in addition to the M2D-X for ICBHI2017 log.
(And, I also updated the guide document.)

And it's a great question. However, I have no experience with regression tasks.
If possible, could you share the title of the paper about "what other researchers reported us: namely, it is better to change a regression task into a multiclass classification."? (But maybe it is not a paper; if so, that's OK.)

Actually, further investigation is needed to understand how masked reconstruction or M2D models the input signal in the output features.
Does it learn the underlying mechanism of how sounds are generated from the enormous variety of cases found in the data samples? Or does it learn an enormous number of simple combinations of the patterns found in the spectrograms of the data samples?
While we expect the former to happen, the reality is likely closer to the latter.
I do not know of a loss that encourages the model to learn how the space of data points evolves from one state to another; that might be a good new research topic.
Good luck to you in that direction, and thanks for sharing your interest. That is indeed an important and unexplored topic.

@marcelomatheusgauy
Author

Thank you for the logs. They will be helpful.

Unfortunately, I have only exchanged words with other researchers who reported this issue. I have not found a paper documenting this problem. It would be an interesting research problem, in particular for us, as regression tasks will likely be part of the set of tasks we will have to consider for respiratory problems in hospital patients.

Indeed, I agree that M2D or masked reconstruction probably leads the models to learn simple combinations of the patterns found in the spectrograms rather than to understand the underlying mechanism of sound generation. That is an interesting research question as well.
