Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Logging token level losses at inference time #4169

Conversation

c-flaherty
Copy link
Contributor

@c-flaherty c-flaherty commented Nov 12, 2021

Intro

In this PR, I update TorchGeneratorAgent to support logging token level conditional probabilities and ranks at inference time when in verbose mode. This logging works whether the agent is using greedy, beam, or any other kind of token generation supported by the agent currently.

Here are two examples:

parlai dm --model-file zoo:unittest/transformer_generator2/model --truncate 1024 -v --task integration_tests:multiturn_nocandidate -ne 1 --inference beam --beam-size 3

Screen Shot 2021-11-12 at 1 45 35 PM

parlai dm --model-file zoo:unittest/transformer_generator2/model --truncate 1024 -v --task integration_tests:multiturn_nocandidate -ne 1 --inference greedy

Screen Shot 2021-11-12 at 1 46 40 PM

A brief explanation with code pointers:

The scores are here:

score = F.log_softmax(score, dim=-1, dtype=torch.float32) # type: ignore

To my understanding score is a tensor of shape (batch size, num of beams, vocab size) and score[b, i, :] for example contains the conditional probabilities of each token in vocab being the next token in the i-th beam of batch b. However, generation candidates are added to beam objects whenever an EOS token is found. Therefore, accumulating score does not really get us all the way there. We need to have each beam accumulate token probabilities of beams and store them whenever finished hypotheses are found. This requires

(1) adding an additional parameter to the TreeSearch object to store them:

# keeps tuples (score, time_step, hyp_id)
self.finished = []

(2) updating TreeSearch:select_paths method to output token probabilities of next paths in beam:

def select_paths(self, logprobs, prior_scores, current_length):

(3) updating store of beam token probabilities each time TreeSearch:advance method is called:

hyp_ids, tok_ids, self.scores = self.select_paths(
logprobs, self.scores, current_length
)
# use clone() here to ensure that self.all_scores will not be changed
# later due to any penalties to self.scores
self.all_scores.append(self.scores.clone())
self.outputs.append(tok_ids)
self.bookkeep.append(hyp_ids)
tok_id_list = tok_ids.tolist()
self.partial_hyps = [
self.partial_hyps[hyp_ids[i]] + [tok_id_list[i]]
for i in range(self.beam_size)
]

and finally (4) storing token probabilities along with candidate utterances in TreeSearch:finished parameter:

# check new hypos for eos label, if we have some, add to finished
for hypid in range(self.beam_size):
if self.outputs[-1][hypid] == self.eos:
if self.scores[hypid] <= neginf(self.scores.dtype):
continue
# this is finished hypo, adding to finished
eostail = _HypothesisTail(
timestep=len(self.outputs) - 1,
hypid=hypid,
score=self.all_scores[-1][hypid],
tokenid=self.eos,
)
self.finished.append(eostail)

Once, we do this, we will be able to easily able to pass through token probabilities through get_rescored_finished (the method introducing length penalty to utterance level probability) while keeping them stored alongside candidates, output them from TorchGeneratorAgent:_generate in both beam_preds_scores and in beams for free, and finally assign them to token_losses variable in TorchGeneratorAgent:eval_step, so we can output them in the same way they are outputted for examples with labels

Tests:

pytest tests/test_tga.py

parlai/core/torch_generator_agent.py Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
@emilydinan emilydinan self-requested a review November 22, 2021 18:48
Copy link
Contributor

@emilydinan emilydinan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the well-documented PR description @c-flaherty and nice job so far!! this is going to be a super useful feature.

i didn't finish reviewing but found some things that need fixing so i wanted to unblock you before EOD!

parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
Copy link
Contributor

@emilydinan emilydinan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great! this is looking so much simpler. added a few more questions/comments!

parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Show resolved Hide resolved
parlai/core/torch_generator_agent.py Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
tests/test_tga.py Show resolved Hide resolved
Copy link
Contributor

@emilydinan emilydinan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks nearly good to go, but need one last fix in beam search.

i'll also add one last request that you add a test that would capture some of the bugs you ran into here. For example, if you initialize the BeamSearch object and pass a set of (fake) logprobs to select_paths, can you ensure that it returns you the correct token scores? (& similarly for the other tree search objects)

parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
parlai/core/torch_generator_agent.py Outdated Show resolved Hide resolved
@c-flaherty
Copy link
Contributor Author

I couldn't get the code to go any faster after playing around with some suggestions from Stephen, so instead I moved all token level logging related code behind guard statements. This means that none of the tensor operations related to token level logging run unless verbose mode is on. As expected, this change means that running this code without verbose mode on is not any slower than running this code on main.

On main, running CUDA_LAUNCH_BLOCKING=1 perf stat -r 10 -d parlai eval_model --task "babi:Task1k:2" --model-file "zoo:tutorial_transformer_generator/model" --optimizer adam:

Performance counter stats for 'parlai eval_model --task babi:Task1k:2 --model-file zoo:tutorial_transformer_generator/model --optimizer adam' (10 runs):

          45868.78 msec task-clock                #    1.159 CPUs utilized            ( +-  1.66% )
           1986695      context-switches          #    0.043 M/sec                    ( +-  3.08% )
               146      cpu-migrations            #    0.003 K/sec                    ( +-  2.83% )
           1326937      page-faults               #    0.029 M/sec                    ( +-  0.29% )
      164034479904      cycles                    #    3.576 GHz                      ( +-  1.65% )  (49.98%)
      197481840141      instructions              #    1.20  insn per cycle           ( +-  0.41% )  (62.58%)
       41177732097      branches                  #  897.729 M/sec                    ( +-  0.48% )  (62.69%)
         414964167      branch-misses             #    1.01% of all branches          ( +-  0.57% )  (62.78%)
       61351690838      L1-dcache-loads           # 1337.548 M/sec                    ( +-  0.60% )  (62.66%)
        3185721966      L1-dcache-load-misses     #    5.19% of all L1-dcache hits    ( +-  0.70% )  (62.42%)
         204441820      LLC-loads                 #    4.457 M/sec                    ( +-  0.52% )  (49.72%)
          32770298      LLC-load-misses           #   16.03% of all LL-cache hits     ( +-  1.03% )  (49.74%)

            39.568 +- 0.758 seconds time elapsed  ( +-  1.92% )

On c-flaherty:add_token_level_probability_logging_TGA, running CUDA_LAUNCH_BLOCKING=1 perf stat -r 10 -d parlai eval_model --task "babi:Task1k:2" --model-file "zoo:tutorial_transformer_generator/model" --optimizer adam:

Performance counter stats for 'parlai eval_model --task babi:Task1k:2 --model-file zoo:tutorial_transformer_generator/model --optimizer adam' (10 runs):

          45668.64 msec task-clock                #    1.165 CPUs utilized            ( +-  1.09% )
           1804133      context-switches          #    0.040 M/sec                    ( +-  4.24% )
               139      cpu-migrations            #    0.003 K/sec                    ( +-  3.24% )
           1326329      page-faults               #    0.029 M/sec                    ( +-  0.13% )
      163236060522      cycles                    #    3.574 GHz                      ( +-  1.18% )  (49.82%)
      197468618586      instructions              #    1.21  insn per cycle           ( +-  0.55% )  (62.27%)
       41286514451      branches                  #  904.045 M/sec                    ( +-  0.59% )  (62.28%)
         423899974      branch-misses             #    1.03% of all branches          ( +-  1.77% )  (62.37%)
       61301661616      L1-dcache-loads           # 1342.314 M/sec                    ( +-  0.61% )  (62.54%)
        3163267796      L1-dcache-load-misses     #    5.16% of all L1-dcache hits    ( +-  1.22% )  (62.76%)
         204609470      LLC-loads                 #    4.480 M/sec                    ( +-  0.53% )  (50.20%)
          32725509      LLC-load-misses           #   15.99% of all LL-cache hits     ( +-  0.53% )  (50.04%)

            39.206 +- 0.522 seconds time elapsed  ( +-  1.33% )

On c-flaherty:add_token_level_probability_logging_TGA, running CUDA_LAUNCH_BLOCKING=1 perf stat -r 10 -d parlai eval_model --task "babi:Task1k:2" --model-file "zoo:tutorial_transformer_generator/model" --optimizer adam -v:

Performance counter stats for 'parlai eval_model --task babi:Task1k:2 --model-file zoo:tutorial_transformer_generator/model --optimizer adam -v' (10 runs):

          48766.12 msec task-clock                #    1.156 CPUs utilized            ( +-  1.02% )
           1852279      context-switches          #    0.038 M/sec                    ( +-  3.12% )
               145      cpu-migrations            #    0.003 K/sec                    ( +-  2.74% )
           1325857      page-faults               #    0.027 M/sec                    ( +-  0.22% )
      174661626177      cycles                    #    3.582 GHz                      ( +-  1.11% )  (50.21%)
      214397405556      instructions              #    1.23  insn per cycle           ( +-  0.42% )  (62.70%)
       45077919773      branches                  #  924.370 M/sec                    ( +-  0.43% )  (62.63%)
         432371445      branch-misses             #    0.96% of all branches          ( +-  0.99% )  (62.49%)
       67462839403      L1-dcache-loads           # 1383.396 M/sec                    ( +-  0.44% )  (62.33%)
        3388401715      L1-dcache-load-misses     #    5.02% of all L1-dcache hits    ( +-  0.62% )  (62.30%)
         210949249      LLC-loads                 #    4.326 M/sec                    ( +-  0.44% )  (49.97%)
          33144635      LLC-load-misses           #   15.71% of all LL-cache hits     ( +-  0.72% )  (50.08%)

            42.195 +- 0.429 seconds time elapsed  ( +-  1.02% )

Copy link
Contributor

@emilydinan emilydinan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for persisting through several rounds of edits @c-flaherty!!! This will be a super helpful change 😄

This looks good to go from my perspective, pending tests passing.

@stephenroller -- did you want to take a look before merge?

Copy link
Contributor

@emilydinan emilydinan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to confirm -- do the rag tests you skip pass locally?

@@ -88,8 +88,9 @@ def test_retrieval_none(self):
_test_bb2_rag(KnowledgeAccessMethod.NONE, n_docs=1)


@testing_utils.skipUnlessGPU
@unittest.skipIf(LOCAL, "Skipping Test because its slow and mem intensive")
# @testing_utils.skipUnlessGPU
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you may have left these comments in by mistake

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just testing to right now to see if skipping these tests fixes the cache issue. If it does, then yea will remove this comment, but if it doesn't, then I'll revert these changes related to skipping tests

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(And address the cache issue in a separate pr)

@@ -110,7 +110,7 @@
}


@testing_utils.skipUnlessGPU
@unittest.skip("Cache too large")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need to skip these all of the time? Or just when they run on Circle CI? If so, you can use the decorator skipIfCircleCI (in utils/testing.py).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, yeah I'll make that change if this indeed does fix the cache issue. I'll probably have to remove all these changes to the test decorators anyways though, as they probably won't fix the cache issues. Will keep this in mind!

@c-flaherty c-flaherty merged commit daa85bf into facebookresearch:main Jan 5, 2022
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants