
Context managers #13900

Merged
lvwerra merged 3 commits into master from context_managers
Oct 20, 2021

Conversation

@lvwerra
Member

@lvwerra lvwerra commented Oct 6, 2021

What does this PR do?

This PR adds a ContextManagers class that wraps zero, one, or several context managers and acts as a single context manager, as discussed with @sgugger and @stas00. This should reduce code duplication in places where a context is conditional, e.g. can only be applied if a framework is available or a flag is set.
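For reference, such a wrapper can be sketched on top of contextlib.ExitStack. This is a minimal illustration of the idea, not necessarily the exact implementation added by the PR:

```python
from contextlib import ExitStack


class ContextManagers:
    """Wrap a list of context managers and enter/exit them as one.

    Minimal sketch of the idea discussed in this PR; the class
    actually added may differ in details.
    """

    def __init__(self, context_managers):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        # Enter each wrapped context in order; ExitStack unwinds them
        # in reverse order on exit, exactly like nested `with` blocks.
        for context_manager in self.context_managers:
            self.stack.enter_context(context_manager)

    def __exit__(self, *args):
        self.stack.__exit__(*args)
```

With an empty list this degenerates to a no-op context, which is what removes the `if`/`else` duplication in the use cases below.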

Example

The following example should illustrate how the ContextManagers class works:

import contextlib

@contextlib.contextmanager
def context_en():
    print('Welcome!')
    yield
    print('Bye!')
    
@contextlib.contextmanager
def context_fr():
    print('Bonjour!')
    yield
    print('Au revoir!')

We can then specify no, one, or several contexts:

with ContextManagers([]):
    print('Transformers are awesome!')
print()

with ContextManagers([context_en()]):
    print('Transformers are awesome!')
print()
    
with ContextManagers([context_en(), context_fr()]):
    print('Transformers are awesome!')

The context managers are nested inside the single ContextManagers context, producing the following output:

Transformers are awesome!

Welcome!
Transformers are awesome!
Bye!

Welcome!
Bonjour!
Transformers are awesome!
Au revoir!
Bye!

Use cases

With this class the following code snippets can be rewritten:

if is_deepspeed_zero3_enabled():
    import deepspeed

    with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
else:
    old_num_tokens, old_embedding_dim = old_embeddings.weight.size()

contexts = []
if is_deepspeed_zero3_enabled():
    import deepspeed

    contexts.append(deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None))

with ContextManagers(contexts):
    old_num_tokens, old_embedding_dim = old_embeddings.weight.size()

Contexts that come with additional conditional statements are probably best handled like this:

if is_deepspeed_zero3_enabled():
    import deepspeed

    with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
else:
    new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
return new_embeddings

contexts = []
if is_deepspeed_zero3_enabled():
    import deepspeed
    
    contexts.append(deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0))

with ContextManagers(contexts):
    if not is_deepspeed_zero3_enabled() or torch.distributed.get_rank() == 0:
        new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
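The shape of this pattern can be demonstrated without deepspeed. In the sketch below, gather_params and the zero3_enabled flag are stand-ins for deepspeed.zero.GatheredParameters and is_deepspeed_zero3_enabled(), and contextlib.ExitStack plays the role of ContextManagers:

```python
import contextlib

log = []


@contextlib.contextmanager
def gather_params():
    # Stand-in for deepspeed.zero.GatheredParameters in this sketch
    log.append("gathered")
    yield
    log.append("released")


def resize(zero3_enabled):
    contexts = []
    if zero3_enabled:
        contexts.append(gather_params())
    # An empty context list is a no-op, so the body runs
    # unconditionally and the branch duplication disappears.
    with contextlib.ExitStack() as stack:
        for cm in contexts:
            stack.enter_context(cm)
        log.append("resized")
```

Calling `resize(False)` runs the body alone, while `resize(True)` wraps it in the gathering context, with no duplicated body in the source.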

Another place where code duplication is even more severe is in upcasting added in #13573:

if is_amp_available:
    with autocast(enabled=False):
        q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
        attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
        attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
else:
    q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
    attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
    attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

What do you think about this implementation?

Collaborator

@sgugger sgugger left a comment


Looks great, thanks a lot for taking care of it! Could you add the (great) demos you put in the presentation as tests in test_file_utils? That would be awesome :-)

@stas00
Contributor

stas00 commented Oct 6, 2021

Hmm, the PR/implementation itself is awesome, thank you @lvwerra!

Looking at the before and after examples, I'm concerned that this makes readability/ease-of-understanding worse than before.

Perhaps it only appears that way because I have been staring at the old way for a long time and the new code looks unfamiliar, i.e. is it a case of habituation bias?

Thoughts?

with ContextManagers([]):
    print("Transformers are awesome!")
# The print statement adds a new line at the end of the output
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
Contributor

@stas00 stas00 Oct 6, 2021


btw we also have a context manager for catching std streams.
https://huggingface.co/transformers/master/testing.html#testing-the-stdout-stderr-output
scroll down to CaptureStd

No need to change anything - either way works; just sharing that we already have a built-in for that functionality.
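For comparison, the standard library offers a similar capture via contextlib.redirect_stdout; this sketch uses it instead of the transformers-specific CaptureStd so it runs standalone:

```python
import contextlib
import io

# Redirect everything printed inside the block into a buffer
buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    print("Transformers are awesome!")
captured = buf.getvalue()
```

Note that, unlike the fixed CaptureStd mentioned below, redirect_stdout fully swallows the output rather than copying it.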

Contributor


I'm not sure whether this method hides the output from the user and so prevents them from debugging what's going on - I just fixed CaptureStd so that it no longer hides output by default and only makes a copy of it. #13803

Again, nothing needs to be done in this case, as the output is trivial.

@sgugger
Collaborator

sgugger commented Oct 6, 2021

I think the better example is the code mentioned last, where there is a duplication of three lines. For a duplication of just one line, the benefit is not super visible indeed.

@stas00
Contributor

stas00 commented Oct 6, 2021

I'm not at all sure the number of duplicated lines is the issue.

With the way things are now, the reader can instantly choose the relevant branch using their in-head processor and flow through as they read the code.

With the addition of the context manager it is far more difficult to figure out what's going on.

And if additional conditions are added, as in the 2nd example in the OP, that makes the code even more difficult to follow.

IMHO, this line of work is mainly about improving maintainability - e.g. preventing cases where we fixed one copy of a branch and forgot the other, especially since the copies aren't even aligned the same way.

Would the following approach improve maintainability while keeping the readability easy?

def run():
    q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
    attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
    attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

if is_amp_available:
    with autocast(enabled=False):
        run()
else:
    run()

before:

if is_amp_available:
    with autocast(enabled=False):
        q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
        attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
        attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
else:
    q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
    attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
    attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
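One caveat with the run() refactor: names assigned inside the helper are local to it, so a result like attn_weights would need to be returned (or declared nonlocal) to reach the surrounding scope. A torch-free sketch of the return-value variant, with contextlib.nullcontext standing in for autocast(enabled=False) and toy arithmetic standing in for the attention math:

```python
import contextlib

is_amp_available = True  # stand-in flag for this sketch


def run(query, key):
    # Return the result instead of assigning an outer name:
    # an assignment inside the helper would stay local to it.
    q, k = query, key
    return q * k


if is_amp_available:
    # nullcontext stands in for autocast(enabled=False) here
    with contextlib.nullcontext():
        attn_weights = run(3, 4)
else:
    attn_weights = run(3, 4)
```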

@sgugger
Collaborator

sgugger commented Oct 20, 2021

Not sure what's blocking this PR from getting merged, @lvwerra? It's just a tool that people are free to use or not, and I personally would really like to use it in the Trainer to simplify all the code wrapping an autocast block in an if statement.

The particular pattern

if is_amp_available:
    with autocast(enabled=False):
        some_code
else:
    same_code

is present in so many places that I will create an autocast_if_amp_is_available context manager using your tools to simply do:

with autocast_if_amp_is_available():
    some_code
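Such a helper could be sketched as follows. This is a hypothetical illustration: the is_amp_available flag and contextlib.nullcontext stand in for the real AMP check and autocast(enabled=False) so the snippet runs without torch:

```python
import contextlib

is_amp_available = False  # stand-in for the real availability check


@contextlib.contextmanager
def autocast_if_amp_is_available():
    # Enter the (stand-in) autocast context only when AMP is
    # available; otherwise act as a plain no-op context.
    if is_amp_available:
        # nullcontext stands in for autocast(enabled=False) here
        with contextlib.nullcontext():
            yield
    else:
        yield


with autocast_if_amp_is_available():
    result = "some_code ran"
```

The call site then contains the body exactly once, whichever branch is taken.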

@lvwerra lvwerra merged commit 0270d44 into master Oct 20, 2021
@lvwerra lvwerra deleted the context_managers branch October 20, 2021 12:15