
Alibi Tensor Parallel Fix #244

Merged: 6 commits merged into main on Feb 1, 2022
Conversation

DanielHesslow
Collaborator

Addressing issue #227. Added a test that loads and compares the output of the same model using different TP degrees. Validated that there was an issue and that it is now fixed.

@stas00
Member

stas00 commented Jan 31, 2022

For some reason the new test keeps on failing on CI - though it runs fine on my machine.

And while you're debugging, please s/test_tensor_paralell.py/test_tensor_parallel.py/ - I didn't want to step on your toes, Daniel.

If you want to speed up debugging on CI, temporarily change

run: pytest --timeout=600 tests

to run just this test:

run: pytest --timeout=600 tests/test_tensor_paralell.py

but let's not forget to undo the change before merging.

@stas00
Member

stas00 commented Jan 31, 2022

Aha, you're running into the really hard-to-debug problem that arises when you don't spawn an external process to run the GPU work: the memory never gets freed, and neither does the port, since pytest doesn't exit and the port is never explicitly released - something most programs never account for.

So in isolation it works just fine, but when combined with other similar tests things break in very subtle ways.

It's safer to use an external program, have it save the data, and then read the data back in the main process - far less convenient and slower, but far more resilient.

A lot of these complex programs use globals, so it's not enough to del some object to get them to release ports/GPUs, as they weren't written with that kind of use in mind.
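For reference, a minimal sketch of that external-process pattern (the script name and flags below are illustrative, not part of this repo): the GPU work runs in a fresh Python process, so its memory and port are released when it exits, and the test only reads back the data the process saved.

import json
import subprocess
import sys
import tempfile
from pathlib import Path

def run_gpu_job_externally(script: str, *args: str) -> dict:
    """Run `script` in a fresh process and load the JSON result it saves."""
    with tempfile.TemporaryDirectory() as tmp:
        out_file = Path(tmp) / "result.json"
        # the child process owns the GPU and the distributed port; both are
        # released when it exits, regardless of what pytest does afterwards
        subprocess.run(
            [sys.executable, script, *args, "--output", str(out_file)],
            check=True,
        )
        return json.loads(out_file.read_text())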

@DanielHesslow
Collaborator Author

DanielHesslow commented Feb 1, 2022

Seems like the third time was the charm. I think they should actually be spawning separate processes with mp.set_start_method('spawn', force=True), otherwise PyTorch won't be happy. The port already in use thing I guess would be an issue as long as we're on the same machine. Probably the previous test had just not properly cleaned up the port before this one started.

In either case, I think all should be good now. The critical alibi issue is certainly solved; if the test starts causing trouble I'll come back to it :)
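For context, a bare-bones sketch of the spawn-based pattern the test relies on (function and argument names here are illustrative, not the test's actual code):

import multiprocessing as mp

def infer(args):
    tp_index, tp_size = args
    # ... set MASTER_ADDR/MASTER_PORT/RANK for this rank, build the model,
    # run one forward pass and return the output ...
    return tp_index

if __name__ == "__main__":
    # CUDA cannot be re-initialized in a forked child, hence spawn
    mp.set_start_method("spawn", force=True)
    tp_size = 2
    with mp.Pool(tp_size) as pool:
        outputs = pool.map(infer, [(i, tp_size) for i in range(tp_size)])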

Member

@stas00 stas00 left a comment

Awesome fix and the test - thank you, Daniel.

please s/test_tensor_paralell.py/test_tensor_parallel.py/ and it's good to go.

Comment on lines 2 to 4
import sys, os
dir = os.path.abspath(os.path.join(os.path.dirname(__file__),os.path.pardir))
sys.path.append(dir)
Member

probably a more modern readable version would be:

import sys
from pathlib import Path
git_repo_path = Path(__file__).resolve().parents[1]
sys.path.insert(1, str(git_repo_path))

I wasn't sure whether it's one parent up or two that you wanted, though.

but it's fine as it is as well.

Collaborator Author

Oh right, I'll just remove it altogether. It was just for quicker iteration without going through pytest.

Member

in such cases it's easier to use:

PYTHONPATH=`pwd` tests/test.py

or something like that. :)

@stas00
Member

stas00 commented Feb 1, 2022

The port already in use thing I guess would be an issue as long as we're on the same machine. Probably the previous test had just not properly cleaned up the port before this one started.

That's what I was trying to say - when running deepspeed inside pytest, one either has to keep using the same port and hope that nobody parallelizes the test run, or, better, switch to a launcher, except that makes things much slower to start. I tried to ask for deepspeed to become more friendly to notebook-like envs (which pytest effectively is), but they have bigger fish to fry, as the product was designed for large jobs that exit on completion.
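One possible mitigation for the fixed-port collisions (not something this PR does, just a sketch): ask the OS for a free port in the parent process and pass it to the workers as MASTER_PORT.

import socket

def find_free_port() -> int:
    # bind to port 0 and let the OS pick an unused port; there is still a
    # small race window before the worker binds it, but in practice this
    # avoids most "address already in use" failures between tests
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]

dist_env = dict(MASTER_ADDR="localhost", MASTER_PORT=str(find_free_port()))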

Member

@thomasw21 thomasw21 left a comment

First review comments; I haven't read everything, will read more tomorrow morning! Either way, thanks for the clean fix and for the test!

#Select the part of the tensor that corresponds to our tensor parallel index.
tp_world_size = mpu.get_tensor_model_parallel_world_size()
tp_index = mpu.get_tensor_model_parallel_rank()
alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index]
Member

Suggested change
alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index]
num_attention_head_per_partition = mpu.divide(num_attention_heads, tp_world_size)
alibi = alibi[tp_index * num_attention_head_per_partition: (tp_index + 1) * num_attention_head_per_partition]

Personally I always find reshape to be weird magic.
We can probably do something more efficient by only computing what we need, but let's keep it simple for now since this is just done at init and should be quick.
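For what it's worth, the two variants select the same heads; a standalone sanity check with made-up sizes:

import torch

num_heads, seq_len = 8, 16
tp_world_size, tp_index = 2, 1
alibi = torch.randn(num_heads, 1, seq_len)

# reshape-and-index, as in the PR
a = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index]

# explicit slicing, as in the suggestion above
heads_per_partition = num_heads // tp_world_size
b = alibi[tp_index * heads_per_partition:(tp_index + 1) * heads_per_partition]

assert torch.equal(a, b)  # both pick heads 4..7 for tp_index=1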

Collaborator Author

Disagree, I think the reshape is clearer, I'll keep it as is.

Comment on lines 3 to 4
dir = os.path.abspath(os.path.join(os.path.dirname(__file__),os.path.pardir))
sys.path.append(dir)
Member

I'm so confused, our tests don't have that. I'm guessing you haven't installed the repo via pip. Please remove it.

Member

What should be installed via pip? Megatron-LM and its derivatives aren't installable.

Member

@stas00 stas00 Feb 1, 2022

but we are already adding the root dir automatically here for all tests to enjoy.

git_repo_path = abspath(join(dirname(dirname(__file__))))
sys.path.insert(1, git_repo_path)

So it's probably just redundant, and that's why Thomas suggested removing it.

@@ -0,0 +1,219 @@
from gc import get_referents
Member

There are some unused imports

Comment on lines 27 to 35
def flatten_arguments(args):
"""
Converts dictionary argument to a list.

Note: we add "IGNORED" at the beginning as this value is ignored by the argparser

Example: {"arg1": "value1", "arg2": "value2"} -> ["IGNORED", "arg1", "value1", "arg2", "value2"]
"""
return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]
Member

That's duplicated code. Ideally if you think it's helpful you can add it to testing_utils and import it directly.

Collaborator Author

Done

return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]


def equal_vectors(tensor1, tensor2, dim=-1):
Member

ditto

Comment on lines 73 to 74
# ALIBI:
"--position-embedding-type":"alibi",
Member

Maybe not put this in the defaults? The way I see it, the default should basically be the common config people would use. If you strongly disagree, let me know.

Collaborator Author

Good call, fixed

Comment on lines 83 to 84
# paralell args
"--tensor-model-parallel-size":str(tp_size),
Member

ditto, you can

args = get_default_args()
args["--tensor-model-parallel-size"] = str(tp_size)

Member

@thomasw21 thomasw21 left a comment

I believe the fix corresponds to what we want. All comments are on the testing, so if needed we can merge this for the training. It's mostly me not understanding some of the subtleties that DS has.

Comment on lines 194 to 198
command_args = self.get_default_args(tp_size = 1)
pool = Pool(1)
result = pool.map(MyTestCase.infer_model, [((0, 1, command_args, None, cp_dir, None))])
pool.close()
pool.join()
Member

@thomasw21 thomasw21 Feb 1, 2022

This blows my mind, that you could do it like that! Awesome!

dist_env = dict(
MASTER_ADDR="localhost", MASTER_PORT="9991", RANK=str(tp_index), LOCAL_RANK=str(tp_index), WORLD_SIZE=str(tp_size)
)
logging.getLogger().critical("Process: starting")
Member

You can create a logger in this file.
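i.e. something along these lines (a sketch, not the actual file):

import logging

logger = logging.getLogger(__name__)  # module-level logger instead of the root one
logger.critical("Process: starting")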

Comment on lines +110 to +112
#Hack
import megatron.initialize as init
init.git_ds_info = lambda: None
Member

Can you add more comments for the future? Flagging this as a hack is hard to understand just by reading the code.
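For example, the comment could spell out the reason for the monkey-patch; my reading of it (to be confirmed by the author) would be something like:

# initialize_megatron() calls megatron.initialize.git_ds_info(), which prints
# git / DeepSpeed build info. In a spawned test worker that output is just
# noise and can misbehave outside a normal launcher run, so we replace it
# with a no-op before calling initialize_megatron().
import megatron.initialize as init
init.git_ds_info = lambda: None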

initialize_megatron()
args = get_args()

args.vocab_size = args.padded_vocab_size = 1024
Member

Shouldn't that default to the tokenizer's vocab size? If gpt2 is too big, you can create a smaller one, no?

Collaborator Author

Yeah, vocab_size does default to the tokenizer's vocab size, but we're not using the tokenizer so it doesn't really matter. However the padded_vocab_size is different if you change TP, since it needs to be padded to a multiple of 128 on each tp-rank. And if the vocab size changes we can't load the model (we get a mismatch in the shapes).
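For context, a sketch of the padding rule (assuming Megatron's default --make-vocab-size-divisible-by of 128), which is why the padded size, and therefore the embedding shape, changes with TP:

def padded_vocab_size(vocab_size: int, tp_size: int, divisible_by: int = 128) -> int:
    # the padded vocab must be divisible by divisible_by * tp_size so that
    # every TP rank gets an equal, aligned shard of the embedding
    multiple = divisible_by * tp_size
    return ((vocab_size + multiple - 1) // multiple) * multiple

# e.g. for a 50257-token vocab: TP=1 pads to 50304, TP=2 pads to 50432,
# so checkpoints saved with different TP have differently shaped embeddings.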

Member

Yeah, I get that, but I think you should change most uses of padded_vocab_size to vocab_size in this file (typically when you change the input). You mean that you can't load the same checkpoint with TP = 2? That's a bit unfortunate. I'm guessing #239 fixes your issue? If so, can you add a comment linking to that PR, and we'll remove it once that other one is merged? Otherwise this seems like a real issue.

Collaborator Author

#239 sort of solves it. While this is an issue with deepspeed's reshaping utils, we could at that point restructure the code to first convert the checkpoint using that util, and then it should work.

if load is not None:
# Hack (same as in eval_harness/evaluate.py)
# Loading pipelined models in deepspeed with different TP than it was originally trained on fails
# due to a sanity check, that makes sure that all state_dicts that we merge contains attention layers.
Member

This is weird, so it doesn't support sharding the embedding layer? I thought it was a common practice...

Collaborator Author

Yeah, it is a bit weird. Sharding the embedding layer is fine, but currently the code for merging state dicts in the ds loader contains some asserts that all state_dicts contain attention layers. So this will trigger when we load across TP sizes in pipelined models.

Member

Hmm, maybe @stas00 has more insight on this. Should #239 fix this as well?

Collaborator Author

Nope, #239 won't fix this; it's an issue in deepspeed. It's a bit weird that they check the names of parameters considering it should be a general-purpose lib, but it is what it is.

Member

What I mean is that the reshaping code should encounter this issue as well, no?

Collaborator Author

They're doing their own merging and not using the deepspeed code for this, so they can fix up the mismatched shapes correctly.

Member

@stas00 stas00 Feb 1, 2022

re: #239

It's primarily because of TP. You need to know how to merge those params, so w/o names you can't do it; that's why it's not generic. Unless the source checkpoint comes up with a way to declare all these params, which is a good idea btw!

Please feel free to report additional issues with checkpoint merging while Tunji is working on it.

The idea is that the merging work will also be integrated with elastic checkpoint feature so that it should be able to reshape on the fly at load-time.

#output = model(*input_batch)
output = model.eval_batch(iter([token_ids]), compute_loss = False, reduce_output = None)[0]

output = gather_from_tensor_model_parallel_region(output)[..., :tokenizer.vocab_size]
Member

Is this linked to vocab padding in particular, i.e. the :tokenizer.vocab_size? I'm guessing you mean that the output can have different sizes depending on TP. If so, shouldn't you strip the padding from each output before?

Collaborator Author

You mean I should strip the padding before the gather? I guess that would be marginally faster but I don't think it matters.

Member

I didn't understand the :tokenizer.vocab_size. But after re-reading it's probably due to padded_vocab?

output = gather_from_tensor_model_parallel_region(output)[..., :tokenizer.vocab_size]
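In shape terms (made-up sizes, just to illustrate): each TP rank holds logits for its shard of the padded vocab; the gather concatenates the shards along the last dim, and the [..., :tokenizer.vocab_size] slice then drops the TP-dependent padding so outputs from different TP sizes are comparable.

import torch

vocab_size, padded_vocab, tp = 50257, 50432, 2
per_rank = padded_vocab // tp                  # 25216 logits per TP rank
shards = [torch.randn(1, 8, per_rank) for _ in range(tp)]
full = torch.cat(shards, dim=-1)               # (1, 8, 50432) after the gather
comparable = full[..., :vocab_size]            # (1, 8, 50257), padding stripped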

if save != None:
args.save = save
Member

Is this needed?

Collaborator Author

Ah no, it shouldn't be.

Collaborator Author

Changed my mind; I'm not sure which line you're talking about, but:
The gather is necessary because the output is split across TP (normally the loss_fn deals with this).

The save check is necessary because without it DS will complain about the save path being None.

args.save = save
save_checkpoint(0, [model], None, None)

return (output[0].detach().cpu().numpy(), token_ids.detach().cpu().numpy())
Member

Any reason to use numpy?


logging.getLogger().critical(output-output2)
import numpy as np
self.assertTrue(np.allclose(output,output2, atol=5e-3, rtol=0), "Different results when running with TP=1 and TP=2")
Member

Does this fail before? Just to make sure that the fix works.

Collaborator Author

Jupp :)


tokenizer = get_tokenizer()

model, _, _ = setup_model_and_optimizer(gpt_model_provider)
Member

Just to make sure, this is GPTModelPipe right?

Collaborator Author

Yeah

@DanielHesslow DanielHesslow merged commit 40e8b2a into main Feb 1, 2022
@DanielHesslow DanielHesslow deleted the alibi_tp branch February 1, 2022 15:01
@stas00
Member

stas00 commented Feb 14, 2022

Hi @DanielHesslow,

Was there a special reason to make this test use a largish model and take forever to finish? This makes it very difficult to move forward with the test suite.

Could we cut down the size and the number of iterations to a bare minimum like all the other training tests, w/o undermining the purpose of the test? We normally run for about 20 iterations and this test runs 5000!

Also, it currently times out. Probably your testing environment is much stronger than the one CI uses.

Thanks.

@DanielHesslow
Collaborator Author

Hmm, not quite sure what you mean. The model size is the same as in test_model.py, which is rather small: 2 layers with hidden_size=128, seq_len=256. It wasn't chosen for any particular reason, so feel free to change it if there's a better model size.

Re the number of iterations, it should just run one forward pass for each configuration. It doesn't do any training, just compares that we get the same output from the same random model with different TP sizes. I don't know why it would time out, but if it does, that clearly needs to be addressed; unfortunately I don't think I quite have the bandwidth to do it atm.

@stas00
Member

stas00 commented Feb 15, 2022

Oh, maybe test_model.py chose a large size. I will check.

But I was just checking to see if there were any constraints on the size or the number of iterations; I will work on making it finish faster.

Thanks a lot, Daniel.
