Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add CUDA Kernel for TreeSearch Ngram Blocking #4633

Merged
merged 39 commits into from
Jul 5, 2022

Conversation

pearlli98
Copy link
Contributor

@pearlli98 pearlli98 commented Jun 27, 2022

Major changes

  1. add availability of cuda kernel for running ngram blocking (both self-blocking and context-blocking), can be activated by the flag --gpu-beam-blocking True, functionalities are wrapped in the class NGramRepeatBlockFunction
  2. all subclasses of TreeSearch has new boolean attribute gpu_beam_blocking that we use in _block_ngrams()
  3. use tensors for the attributes self.partial_hyps and self.context of TorchGeneratorAgent instead of lists
  4. add unit tests for gpu ngram blocking
  5. add ninja and protobuf to dependency list to enable CUDA extensions

Testing steps

gpu tests: with cuda enabled

  1. run pytest tests/test_transformers.py -k test_beamsearch_blocking_gpu
  2. pytest tests/test_transformers.py -k test_beamsearch_contextblocking_gpu

cpu tests:

  1. run pytest tests/test_transformers.py -k test_beamsearch_blocking_cpu
  2. pytest tests/test_transformers.py -k test_beamsearch_contextblocking_cpu

Other information

  1. To run an interactive model with gpu beam blocking, do
    parlai interactive --model-file "zoo:tutorial_transformer_generator/model" --gpu-beam-blocking True

Evaluation
Have done evaluation on the convai2 teacher task on 3 settings: (1) code on main (2) new code with cpu (3) new code with gpu kernel, results shown below

  • Correctness: green check, utterances from all 3 settings are identical.
  • Runtime: We are seeing ~10% improvement with the new gpu kernel.
  main gpu kernel cpu
average of 10 runs 697.817s 620.084s 689.101s
change \ -11.14% -1.25%

@@ -1466,7 +1503,7 @@ def _block_block_list(self, logprobs: torch.Tensor) -> torch.Tensor:
logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype)
return logprobs

def advance(self, logprobs):
def advance(self, logprobs, step):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a easy walk around here to avoid breaking current tests, is to set a default parameter to the step here, step=0 or sth when it is not given, since it is used in non GPU beam blocking settings.

Copy link
Contributor

@dexterju27 dexterju27 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test failure seems to be cased by
NameError: name 'NGramRepeatBlock' is not defined indicates the CUDA cpp binding function is not handled properly by the CI.

@@ -956,6 +965,7 @@ def _treesearch_factory(self, device, verbose=False):
)
elif method == 'beam':
return BeamSearch(
self.opt['gpu_beam_blocking'],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have set self.gpu_beam_blocking, why not use it?

@@ -1443,15 +1459,36 @@ def _block_ngrams(
Source text to grab ngrams from. If None, it uses the current
hypothesis (i.e. self-blocking).
"""
context = None
if self.gpu_beam_blocking:
if if_context_blocking:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make it if self.gpu_beam_blocking and if_context_blocking?

@pearlli98 pearlli98 changed the title [WIP] Add CUDA Kernel for Beam Blocking [WIP] Add CUDA Kernel for TreeSearch Ngram Blocking Jun 28, 2022
@klshuster
Copy link
Contributor

is this ready for review? or still WIP?

@pearlli98 pearlli98 changed the title [WIP] Add CUDA Kernel for TreeSearch Ngram Blocking Add CUDA Kernel for TreeSearch Ngram Blocking Jun 30, 2022
@pearlli98
Copy link
Contributor Author

is this ready for review? or still WIP?

this is ready for review.

Copy link
Contributor

@klshuster klshuster left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks great! I'm approving but it looks like we still have cleaninstall failures; any chance you could try fixing those?

@@ -1367,11 +1391,14 @@ def set_context(self: TSType, context: torch.LongTensor) -> TSType:
a LongTensor representing the input context; used for context
ngram blocking, if supplied
"""
self.context = context.tolist()
self.context = torch.Tensor(context.tolist()).long()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is default dtype here? curious why .long() cast is needed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

outdated change, reverted to main code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cleaninstall failures seemed to have disappeared without me changing anything.

step,
beam_size,
no_repeat_ngram_size,
if_context_blocking=False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this parameter necessary? I.e., can't we assume that if context is passed, we block on it? or does this make the downstream logic easier to handle in the kernel

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this helps downstream kernel logic easier. Kernel params can't be none so I initialize a placeholder tensor as the empty context. Will need a bool to tell whether we are doing self-blocking or context-blocking.

@@ -1725,7 +1792,8 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection
voc_size = logprobs.size(-1)

# get the backtracking hypothesis id as a multiple of full voc_sizes
hyp_ids = best_idxs // voc_size
# hyp_ids = best_idxs // voc_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: feel free to delete. and also, thank you for changing this (I've seen this warning several times)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted

setup.py Outdated
@@ -9,6 +9,8 @@

from setuptools import setup, find_packages

# from torch.utils.cpp_extension import BuildExtension, CUDAExtension
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you please delete these commented lines

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted

@@ -0,0 +1,50 @@
/*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is our code freshly rewritten, or did we get it from fastseq. We need to maintain the original copyright headers (in addition to our own) if that's the case. Both are MIT so it's no problem, but we need to include Copyright (c) Microsoft etc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out! We definitely referenced their code but i've also made some significant changes. I think we should add the copyright here? can you let me know what's the right move? not super familiar with this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would sth like this work?

/*
Copyright (c) Facebook, Inc. and its affiliates.
Copyright (c) Microsoft Corporation.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*/

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed in our mtg, you can cite the original file and say we adapt from it. And maintain the Microsoft copyright header.

Copy link
Contributor

@stephenroller stephenroller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow great progress!

@@ -0,0 +1,112 @@
/*
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
Copy link
Contributor

@dexterju27 dexterju27 Jul 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And what @stephenroller mentioned about licensing should apply to this file as well, I suppose?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, edited this as well.

int vocab_size,
int no_repeat_ngram_size,
bool if_context_blocking) {
auto row = blockIdx.x;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add some comments here on what row and col means here?


// final thread writes the end of previous ngram array to tokens_shm
if (col == blockDim.x - 1) {
for (int i=1; i<no_repeat_ngram_size; i++){
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: spacing, you could apply a CPP formatter on this file

@@ -422,7 +422,59 @@ def test_beamsearch_blocking(self):
assert '34 34' not in text

@pytest.mark.nofbcode
def test_beamsearch_contextblocking(self):
@testing_utils.skipUnlessGPU
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding all these tests

@pearlli98 pearlli98 closed this Jul 5, 2022
@pearlli98 pearlli98 reopened this Jul 5, 2022
@pearlli98 pearlli98 marked this pull request as ready for review July 5, 2022 19:19
@pearlli98 pearlli98 merged commit dff9aab into main Jul 5, 2022
@pearlli98 pearlli98 deleted the pearlli-ngram-blocking-kernel branch July 5, 2022 19:53
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants