This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
TorchScript BART models #3459
Merged
Changes from 101 commits (127 commits in total)
- 8e2e391 Checkpoint jit (stephenroller)
- b22ffc7 maryam checkpoint
- 79ad543 Changes so far; currently debugging incr decoding (EricMichaelSmith)
- 60c7364 Trying to wrangle types (EricMichaelSmith)
- 54bc0b8 Fully traced model, no incr decoding, generates past EOS (EricMichaelSmith)
- b349ea8 WIP test to mix scripting and tracing (EricMichaelSmith)
- f513488 Get scripting and tracing working, no incr decoding (EricMichaelSmith)
- 4823b3c Bit of cleanup (EricMichaelSmith)
- f8936a7 Minor (EricMichaelSmith)
- afe9656 Merge branch 'master' into jittest (EricMichaelSmith)
- b9039d0 Prototype for incr decoding (EricMichaelSmith)
- 0c97c1a Fixes (EricMichaelSmith)
- e5d9608 Minor (EricMichaelSmith)
- e58f07e Fixes (EricMichaelSmith)
- 70599c5 Fixes (EricMichaelSmith)
- 2162a63 Fixes (EricMichaelSmith)
- ed8a253 Save path (EricMichaelSmith)
- 2b8c560 Merge branch 'jittest' of github.com:facebookresearch/ParlAI into jit… (EricMichaelSmith)
- dd33305 Path manager (EricMichaelSmith)
- c29c2b8 Load exported model (EricMichaelSmith)
- c4404e9 Compare outputs (EricMichaelSmith)
- d9e87d8 Try original greedy search (EricMichaelSmith)
- 033092e Minor (EricMichaelSmith)
- 1ca2428 Merge branch 'master' into jittest (EricMichaelSmith)
- 2c377a7 Refactor a bit (EricMichaelSmith)
- ed0a917 Partial work on overhauling script code (EricMichaelSmith)
- a3b4d1b Minor (EricMichaelSmith)
- de0dfda Lots of reversions (EricMichaelSmith)
- 3997002 More reversion (EricMichaelSmith)
- c96195c Work on initial decoder input (EricMichaelSmith)
- f642c57 Finish applying fixes from before scripting (EricMichaelSmith)
- bbfafa4 Finish conversion (EricMichaelSmith)
- aded9b5 Fixes (EricMichaelSmith)
- b309164 Refactor for loading back in (EricMichaelSmith)
- 6d05381 Move scripts (EricMichaelSmith)
- 76f263f Merge branch 'master' into jittest (EricMichaelSmith)
- 2b806c6 Set up scripting tokenization (EricMichaelSmith)
- a7d8dcf Initial fixes (EricMichaelSmith)
- c5164c3 Don't rely on History (EricMichaelSmith)
- 23249b9 Minor (EricMichaelSmith)
- 8aacdc1 Start work in ScriptableDictionaryAgent (EricMichaelSmith)
- a93d581 Work on dict (EricMichaelSmith)
- da8eaba Enable more stuff (EricMichaelSmith)
- d163fc9 Enable more code (EricMichaelSmith)
- 817fca8 Set up more of the code (EricMichaelSmith)
- ff85524 More (EricMichaelSmith)
- 2981f25 Simplify (EricMichaelSmith)
- a80c637 Work on cleaning up BPE helper (EricMichaelSmith)
- dc23433 Minor (EricMichaelSmith)
- 8e6d108 Pass in args (EricMichaelSmith)
- 894746d Minor (EricMichaelSmith)
- 15b5600 Fuse keys (EricMichaelSmith)
- 59f9f7e Rearrange (EricMichaelSmith)
- ba16e1f Various fixes (EricMichaelSmith)
- b6e25f0 Various fixes (EricMichaelSmith)
- b5e1dd4 Start work on replacement function (EricMichaelSmith)
- 4cce680 Minor (EricMichaelSmith)
- f1877b4 Prototype (EricMichaelSmith)
- e813de9 Work on scripting (EricMichaelSmith)
- 805b32d Work on scripting (EricMichaelSmith)
- e123fe6 Detokenizing (EricMichaelSmith)
- 205d52f Runtime fixes (EricMichaelSmith)
- f5e3ba1 Make splitting test more powerful (EricMichaelSmith)
- 04412c7 Testing (EricMichaelSmith)
- 9ca4945 Cleanup (EricMichaelSmith)
- 0c75b9f Merge branch 'master' into jittest (EricMichaelSmith)
- b08e690 No for/else (EricMichaelSmith)
- 9c7696e A bit of cleanup (EricMichaelSmith)
- 7d81a91 More cleanups (EricMichaelSmith)
- c73925f Format (EricMichaelSmith)
- e7106d4 Start CI testing (EricMichaelSmith)
- de6a71b Try bumping to 1.7 (EricMichaelSmith)
- 96b70ea Properly fix tests? (EricMichaelSmith)
- 9d07efa Fix (EricMichaelSmith)
- 79a4137 Merge branch 'jittest' into jittest-ci-testing (EricMichaelSmith)
- df43c02 Partial fix (EricMichaelSmith)
- cb16dcd TODOs (EricMichaelSmith)
- 97cd0df Remove breakpoints (EricMichaelSmith)
- e3f0c79 Join incr states in the last dim instead (EricMichaelSmith)
- 4ca5826 Model parallel fixes (EricMichaelSmith)
- b5fb575 Minor (EricMichaelSmith)
- df8b0eb Cleanup (EricMichaelSmith)
- 9747589 Don't track history (EricMichaelSmith)
- e9a64a8 Minor (EricMichaelSmith)
- 1723af6 Update parlai/agents/transformer/modules.py (EricMichaelSmith)
- 84c816a Comments so far (EricMichaelSmith)
- 789a190 Merge branch 'jittest' of github.com:facebookresearch/ParlAI into jit… (EricMichaelSmith)
- 79a4974 Reversions (EricMichaelSmith)
- 0e208b1 ParlaiScript (EricMichaelSmith)
- a79a6d1 Unit test overhaul (EricMichaelSmith)
- ec6bae7 Minor (EricMichaelSmith)
- c97f3a2 Minor (EricMichaelSmith)
- 849c8ac Merge branch 'master' into jittest (EricMichaelSmith)
- e394cd9 Test agent (EricMichaelSmith)
- ba0cb1c Add second test (EricMichaelSmith)
- 45312f9 Fix tests (EricMichaelSmith)
- cd994b3 README (EricMichaelSmith)
- 8860766 Test fix (EricMichaelSmith)
- 0e7b38f Conditional import (EricMichaelSmith)
- 1611563 Fix import (EricMichaelSmith)
- 6b6c084 Minor (EricMichaelSmith)
- d7784ea Rename (EricMichaelSmith)
- 4bee499 Rename (EricMichaelSmith)
- 5174dc0 Reverting stuff (EricMichaelSmith)
- 4d5515c Reverting stuff (EricMichaelSmith)
- e7d601c Reverting stuff (EricMichaelSmith)
- 2fcd6ad Update parlai/scripts/torchscript.py (EricMichaelSmith)
- 00a1d26 Revert (EricMichaelSmith)
- 14ec40b Split ordering by layer (EricMichaelSmith)
- a7fb474 Merge branch 'jittest' of github.com:facebookresearch/ParlAI into jit… (EricMichaelSmith)
- 5572674 Minor (EricMichaelSmith)
- d8ad611 Re-add reorder_incremental_state method (EricMichaelSmith)
- 1d272a8 Fixes (EricMichaelSmith)
- 2919fa1 Split stuff to delay import (EricMichaelSmith)
- 0d7ae91 Remove list comprehension (EricMichaelSmith)
- b66d477 Revert some code (EricMichaelSmith)
- d898c24 Build out wrapping (EricMichaelSmith)
- 2d6ca34 Fixes (EricMichaelSmith)
- 18e68af Revert (EricMichaelSmith)
- 1d7d342 Merge branch 'master' into jittest (EricMichaelSmith)
- 3807be9 Skip unless fairseq (EricMichaelSmith)
- 6f16934 Move checks to nightly (EricMichaelSmith)
- 9862d25 Merge branch 'master' into jittest (EricMichaelSmith)
- feaf319 Merge branch 'master' into jittest (EricMichaelSmith)
- a4377df Try to speed up style_gen test (EricMichaelSmith)
- ea18a61 Revert (EricMichaelSmith)
- 8e15176 Reduce bs for wizard tasks (EricMichaelSmith)
@@ -0,0 +1,25 @@
# Agent exported to TorchScript (JIT compilation)

This agent reads in a ParlAI agent that has been exported to TorchScript with JIT compilation, for use in greedy-search inference on CPU. This allows inference to be run on models without any ParlAI overhead, either for tokenization or for the forward passes through the model. Currently, only BART models are supported.

Sample call for exporting a BART model to TorchScript:
```
parlai jit_export \
--model-file ${MODEL_FILE} \
--model bart \
--no-cuda \
--scripted-model-file ~/_test_scripted_model__bart.pt \
--input 'I am looking for a restaurant in the west part of town.|APIRESP: Restaurant 14 matches'
```

Interacting with an exported model using `parlai interactive`:
```
parlai interactive --model-file ~/_test_scripted_model__bart.pt --model jit
```

Loading in and running inference on an exported model, without any ParlAI overhead:
```
python parlai/agents/jit/scripts/test_exported_model.py \
--scripted-model-file ~/_test_scripted_model__bart.pt \
--input 'I am looking for a restaurant in the west part of town.|APIRESP: Restaurant 14 matches'
```
@@ -0,0 +1,5 @@
```python
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
```
@@ -0,0 +1,70 @@

> **Review comment:** i don't love exposing this in the main agents folder. Can we make this in the torchscript folder, as it's very narrow and only for testing innit?
>
> **Reply:** I guess I'd be happier about moving it into the main folder if we were in a spot where jit supported a LOT
>
> **Reply:** Yes, moving

```python
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch

from parlai.core.agents import Agent
from parlai.core.message import Message
from parlai.core.opt import Opt
from parlai.utils.io import PathManager


class JitAgent(Agent):
    """
    ParlAI agent exported to TorchScript with JIT compilation and then loaded from disk.

    Metrics and batch act are currently unsupported, and CUDA is unsupported because
    TorchScripting is currently CPU-only.
    """

    def __init__(self, opt: Opt, shared=None):
        super().__init__(opt=opt, shared=shared)
        with PathManager.open(self.opt['model_file'], "rb") as f:
            self.module = torch.jit.load(f)

        # Track incoming history strings
        self.history: List[str] = []

    def share(self):
        """
        Share the scripted module object.
        """
        shared = super().share()
        shared['module'] = self.module
        return shared

    def observe(self, observation: Message) -> Message:
        # TODO: support self._validate_observe_invariants() method of TorchAgent
        self.history.append(observation['text'])
        return super().observe(observation)

    def self_observe(self, self_message: Message) -> None:
        # TODO: support self._validate_self_observe_invariants() method of TorchAgent
        assert self.observation is not None
        if self.observation['episode_done']:
            # This was the last example in the episode, so reset the history
            self.history = []
            # Additionally, mark the last observation as invalid
            self.observation = None
        else:
            self.history.append(self_message['text'])

    def reset(self):
        super().reset()
        self.history = []

    def act(self) -> Message:
        response_text = self.module('\n'.join(self.history))
        response = Message({'text': response_text, 'episode_done': False})
        # self.observation will determine if we're going onto a new episode
        self.self_observe(response)
        return response
```
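The history bookkeeping is the only state this agent keeps. Stripped of the ParlAI base classes, it can be sketched as follows (`HistoryTracker` is a hypothetical helper for illustration, not part of the PR):

```python
class HistoryTracker:
    """Hypothetical sketch of JitAgent's history handling, minus ParlAI."""

    def __init__(self):
        self.history = []
        self.last_observation = None

    def observe(self, observation: dict) -> None:
        # Incoming text is appended to the running history
        self.history.append(observation['text'])
        self.last_observation = observation

    def self_observe(self, response_text: str) -> None:
        # After responding: reset on episode end, else record our own reply
        assert self.last_observation is not None
        if self.last_observation['episode_done']:
            self.history = []
            self.last_observation = None
        else:
            self.history.append(response_text)

    def context(self) -> str:
        # The newline-joined string that would be fed to the scripted module
        return '\n'.join(self.history)


tracker = HistoryTracker()
tracker.observe({'text': 'hi there', 'episode_done': False})
tracker.self_observe('hello!')
print(tracker.context())  # hi there\nhello!
```

This mirrors why `act()` can simply call `self.module('\n'.join(self.history))`: the scripted module never sees ParlAI's `History` object, only a flat string.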
@@ -0,0 +1,5 @@
```python
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
```
@@ -0,0 +1,48 @@
```python
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from typing import List

import torch.jit

from parlai.utils.io import PathManager


def test_exported_model(scripted_model_file: str, inputs: List[str]):
    with PathManager.open(scripted_model_file, "rb") as f:
        scripted_module = torch.jit.load(f)

    print('\nGenerating given the scripted module:')
    context = []
    for input_ in inputs:
        print(' TEXT: ' + input_)
        context.append(input_)
        label = scripted_module('\n'.join(context))
        print("LABEL: " + label)
        context.append(label)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-smf',
        '--scripted-model-file',
        type=str,
        help='Where to load the scripted model checkpoint from',
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        default="hello world",
        help="Test input string to pass into the encoder of the scripted model. Separate lines with a pipe",
    )
    args = parser.parse_args()
    test_exported_model(
        scripted_model_file=args.scripted_model_file, inputs=args.input.split('|')
    )
```
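The conversation loop in `test_exported_model` grows the context with both the user's pipe-separated turns and the model's own labels. With a stub in place of the scripted module (`stub_module` is a hypothetical stand-in, not the real model), the flow looks like:

```python
def stub_module(context: str) -> str:
    # Hypothetical stand-in for the loaded TorchScript module
    return 'ACK'


inputs = 'hello|how are you?'.split('|')  # pipe-separated turns, as in --input
context = []
for input_ in inputs:
    context.append(input_)                 # user turn enters the context
    label = stub_module('\n'.join(context))
    context.append(label)                  # model reply is fed back as history

print(context)  # ['hello', 'ACK', 'how are you?', 'ACK']
```

Each call therefore sees the entire dialogue so far as one newline-joined string, matching how `JitAgent` builds its input.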
> **Review comment:** are there plans to support all `transformer/generator`s? if not, should we name this `bart_jit`?
>
> **Reply:** yeah, I definitely would like to see us support more than just BART, especially once we see how well TorchScripting works for the BART models