
"RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]" during reproducing mlm pre-training #9

Closed
Duemoo opened this issue Mar 12, 2023 · 5 comments

Duemoo commented Mar 12, 2023

Hello,
I was trying to pre-train the ATLAS model (base & large sizes) by running the provided example script atlas/example_scripts/mlm/train.sh on 4 40GB A100 GPUs, but I got this error:

Traceback (most recent call last):
  File "/home/work/atlas/atlas/train.py", line 223, in <module>
    train(
  File "/home/work/atlas/atlas/train.py", line 77, in train
    reader_loss, retriever_loss = model(
  File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/work/atlas/atlas/src/atlas.py", line 432, in forward
    passages, _ = self.retrieve(
  File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/work/atlas/atlas/src/atlas.py", line 181, in retrieve
    passages, scores = retrieve_func(*args, **kwargs)[:2]
  File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/work/atlas/atlas/src/atlas.py", line 170, in retrieve_with_rerank
    retriever_scores = torch.einsum("id, ijd->ij", [query_emb, passage_emb])
  File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/functional.py", line 328, in einsum
    return einsum(equation, *_operands)
  File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/functional.py", line 330, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [2, 768]->[2, 1, 768] [2, 100, 1536]->[2, 100, 1536]
srun: error: localhost: task 0: Exited with exit code 1
srun: error: localhost: task 3: Exited with exit code 1
srun: error: localhost: task 1: Exited with exit code 1
srun: error: localhost: task 2: Exited with exit code 1

I used the provided passages (Wikipedia Dec 2018 dump) and ran the script without changing any training arguments.
So the per-device batch size was 2 and the retriever returned 100 documents, which matches the shapes [2, 768]->[2, 1, 768] and [2, 100, 1536]->[2, 100, 1536] in the error message above.
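
For reference, here is a minimal standalone sketch (not ATLAS code; the shapes are taken from the error message) that reproduces the mismatch and shows what the equation computes when the embedding dimensions agree:

# Minimal repro of the error, using the shapes from the message above:
# the contraction "id, ijd->ij" requires the query and passage embeddings
# to share the same last dimension d.
import torch

query_emb = torch.randn(2, 768)          # (bsz, d) -> [2, 768]
passage_emb = torch.randn(2, 100, 1536)  # (bsz, n_docs, d') -> [2, 100, 1536]

try:
    torch.einsum("id, ijd->ij", query_emb, passage_emb)
except RuntimeError as e:
    print(e)  # einsum(): operands do not broadcast with remapped shapes ...

# With matching embedding dimensions the same equation yields one score per
# retrieved passage, i.e. a (bsz, n_docs) tensor:
scores = torch.einsum("id, ijd->ij", query_emb, torch.randn(2, 100, 768))
print(scores.shape)  # torch.Size([2, 100])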
In addition, I found that the script and the overall pre-training process worked fine after removing the following line from the script, i.e., re-indexing the whole passage set instead of re-ranking, although this resulted in lower few-shot performance than the scores reported in Table 19 of the paper. (I think the performance gap may be unrelated to removing this line, though.)
--retrieve_with_rerank --n_to_rerank_with_retrieve_with_rerank 100 \

Could you provide any hints to solve this issue? Thank you in advance!

jeffhj (Contributor) commented Mar 12, 2023

Hi @Duemoo, you may refer to this pull request to fix the reranking issue.
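
For anyone hitting the same error before applying that fix: the equation "id, ijd->ij" contracts the last axis, so both operands must share the embedding dimension. A hypothetical wrapper along these lines (illustrative only, not the code from that pull request) fails with a clearer message when the shapes disagree:

import torch

def rerank_scores(query_emb: torch.Tensor, passage_emb: torch.Tensor) -> torch.Tensor:
    # Illustrative helper (assumed name, not from the ATLAS repo or the PR):
    # expects query_emb of shape (bsz, d) and passage_emb of shape
    # (bsz, n_docs, d), and returns per-passage scores of shape (bsz, n_docs).
    if query_emb.size(-1) != passage_emb.size(-1):
        raise ValueError(
            f"embedding dims differ: query {tuple(query_emb.shape)} vs "
            f"passages {tuple(passage_emb.shape)}"
        )
    return torch.einsum("id, ijd->ij", query_emb, passage_emb)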

jeffhj (Contributor) commented Mar 12, 2023

I also cannot reproduce the 64-shot results reported in Table 2 with the pretraining/finetuning script and the settings described in the paper (without reranking in pretraining).

Duemoo (Author) commented Mar 13, 2023

Thank you! Applying your commits fixed the problem :)

Duemoo (Author) commented Mar 13, 2023

@jeffhj Thank you for the information. I suspect that also using the CCNet index could be an important factor for performance, so I'm trying to reproduce the results with the CCNet texts as well.

mlomeli1 (Contributor) commented May 5, 2023

Thank you @jeffhj for the fix; I've merged it into master, so I am closing this issue.

mlomeli1 closed this as completed May 5, 2023