This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

TorchScript BART models #3459

Merged
merged 127 commits into from Mar 19, 2021
Commits
8e2e391
Checkpoint jit
stephenroller Feb 21, 2021
b22ffc7
maryam checkpoint
Feb 23, 2021
79ad543
Changes so far; currently debugging incr decoding
EricMichaelSmith Feb 24, 2021
60c7364
Trying to wrangle types
EricMichaelSmith Feb 25, 2021
54bc0b8
Fully traced model, no incr decoding, generates past EOS
EricMichaelSmith Feb 26, 2021
b349ea8
WIP test to mix scripting and tracing
EricMichaelSmith Feb 26, 2021
f513488
Get scripting and tracing working, no incr decoding
EricMichaelSmith Feb 26, 2021
4823b3c
Bit of cleanup
EricMichaelSmith Feb 26, 2021
f8936a7
Minor
EricMichaelSmith Feb 26, 2021
afe9656
Merge branch 'master' into jittest
EricMichaelSmith Mar 1, 2021
b9039d0
Prototype for incr decoding
EricMichaelSmith Mar 2, 2021
0c97c1a
Fixes
EricMichaelSmith Mar 2, 2021
e5d9608
Minor
EricMichaelSmith Mar 2, 2021
e58f07e
Fixes
EricMichaelSmith Mar 2, 2021
70599c5
Fixes
EricMichaelSmith Mar 2, 2021
2162a63
Fixes
EricMichaelSmith Mar 2, 2021
ed8a253
Save path
EricMichaelSmith Mar 2, 2021
2b8c560
Merge branch 'jittest' of github.com:facebookresearch/ParlAI into jit…
EricMichaelSmith Mar 2, 2021
dd33305
Path manager
EricMichaelSmith Mar 2, 2021
c29c2b8
Load exported model
EricMichaelSmith Mar 2, 2021
c4404e9
Compare outputs
EricMichaelSmith Mar 4, 2021
d9e87d8
Try original greedy search
EricMichaelSmith Mar 4, 2021
033092e
Minor
EricMichaelSmith Mar 4, 2021
1ca2428
Merge branch 'master' into jittest
EricMichaelSmith Mar 5, 2021
2c377a7
Refactor a bit
EricMichaelSmith Mar 5, 2021
ed0a917
Partial work on overhauling script code
EricMichaelSmith Mar 5, 2021
a3b4d1b
Minor
EricMichaelSmith Mar 5, 2021
de0dfda
Lots of reversions
EricMichaelSmith Mar 5, 2021
3997002
More reversion
EricMichaelSmith Mar 5, 2021
c96195c
Work on initial decoder input
EricMichaelSmith Mar 5, 2021
f642c57
Finish applying fixes from before scripting
EricMichaelSmith Mar 5, 2021
bbfafa4
Finish conversion
EricMichaelSmith Mar 5, 2021
aded9b5
Fixes
EricMichaelSmith Mar 5, 2021
b309164
Refactor for loading back in
EricMichaelSmith Mar 5, 2021
6d05381
Move scripts
EricMichaelSmith Mar 5, 2021
76f263f
Merge branch 'master' into jittest
EricMichaelSmith Mar 9, 2021
2b806c6
Set up scripting tokenization
EricMichaelSmith Mar 9, 2021
a7d8dcf
Initial fixes
EricMichaelSmith Mar 9, 2021
c5164c3
Don't rely on History
EricMichaelSmith Mar 9, 2021
23249b9
Minor
EricMichaelSmith Mar 9, 2021
8aacdc1
Start work in ScriptableDictionaryAgent
EricMichaelSmith Mar 9, 2021
a93d581
Work on dict
EricMichaelSmith Mar 9, 2021
da8eaba
Enable more stuff
EricMichaelSmith Mar 9, 2021
d163fc9
Enable more code
EricMichaelSmith Mar 9, 2021
817fca8
Set up more of the code
EricMichaelSmith Mar 9, 2021
ff85524
More
EricMichaelSmith Mar 9, 2021
2981f25
Simplify
EricMichaelSmith Mar 9, 2021
a80c637
Work on cleaning up BPE helper
EricMichaelSmith Mar 9, 2021
dc23433
Minor
EricMichaelSmith Mar 9, 2021
8e6d108
Pass in args
EricMichaelSmith Mar 9, 2021
894746d
Minor
EricMichaelSmith Mar 9, 2021
15b5600
Fuse keys
EricMichaelSmith Mar 9, 2021
59f9f7e
Rearrange
EricMichaelSmith Mar 9, 2021
ba16e1f
Various fixes
EricMichaelSmith Mar 9, 2021
b6e25f0
Various fixes
EricMichaelSmith Mar 9, 2021
b5e1dd4
Start work on replacement function
EricMichaelSmith Mar 11, 2021
4cce680
Minor
EricMichaelSmith Mar 11, 2021
f1877b4
Prototype
EricMichaelSmith Mar 11, 2021
e813de9
Work on scripting
EricMichaelSmith Mar 11, 2021
805b32d
Work on scripting
EricMichaelSmith Mar 11, 2021
e123fe6
Detokenizing
EricMichaelSmith Mar 11, 2021
205d52f
Runtime fixes
EricMichaelSmith Mar 12, 2021
f5e3ba1
Make splitting test more powerful
EricMichaelSmith Mar 12, 2021
04412c7
Testing
EricMichaelSmith Mar 12, 2021
9ca4945
Cleanup
EricMichaelSmith Mar 12, 2021
0c75b9f
Merge branch 'master' into jittest
EricMichaelSmith Mar 12, 2021
b08e690
No for/else
EricMichaelSmith Mar 12, 2021
9c7696e
A bit of cleanup
EricMichaelSmith Mar 12, 2021
7d81a91
More cleanups
EricMichaelSmith Mar 12, 2021
c73925f
Format
EricMichaelSmith Mar 12, 2021
e7106d4
Start CI testing
EricMichaelSmith Mar 12, 2021
de6a71b
Try bumping to 1.7
EricMichaelSmith Mar 12, 2021
96b70ea
Properly fix tests?
EricMichaelSmith Mar 12, 2021
9d07efa
Fix
EricMichaelSmith Mar 12, 2021
79a4137
Merge branch 'jittest' into jittest-ci-testing
EricMichaelSmith Mar 12, 2021
df43c02
Partial fix
EricMichaelSmith Mar 15, 2021
cb16dcd
TODOs
EricMichaelSmith Mar 15, 2021
97cd0df
Remove breakpoints
EricMichaelSmith Mar 15, 2021
e3f0c79
Join incr states in the last dim instead
EricMichaelSmith Mar 15, 2021
4ca5826
Model parallel fixes
EricMichaelSmith Mar 15, 2021
b5fb575
Minor
EricMichaelSmith Mar 15, 2021
df8b0eb
Cleanup
EricMichaelSmith Mar 15, 2021
9747589
Don't track history
EricMichaelSmith Mar 15, 2021
e9a64a8
Minor
EricMichaelSmith Mar 15, 2021
1723af6
Update parlai/agents/transformer/modules.py
EricMichaelSmith Mar 15, 2021
84c816a
Comments so far
EricMichaelSmith Mar 15, 2021
789a190
Merge branch 'jittest' of github.com:facebookresearch/ParlAI into jit…
EricMichaelSmith Mar 15, 2021
79a4974
Reversions
EricMichaelSmith Mar 15, 2021
0e208b1
ParlaiScript
EricMichaelSmith Mar 15, 2021
a79a6d1
Unit test overhaul
EricMichaelSmith Mar 15, 2021
ec6bae7
Minor
EricMichaelSmith Mar 15, 2021
c97f3a2
Minor
EricMichaelSmith Mar 15, 2021
849c8ac
Merge branch 'master' into jittest
EricMichaelSmith Mar 15, 2021
e394cd9
Test agent
EricMichaelSmith Mar 15, 2021
ba0cb1c
Add second test
EricMichaelSmith Mar 15, 2021
45312f9
Fix tests
EricMichaelSmith Mar 15, 2021
cd994b3
README
EricMichaelSmith Mar 15, 2021
8860766
Test fix
EricMichaelSmith Mar 15, 2021
0e7b38f
Conditional import
EricMichaelSmith Mar 15, 2021
1611563
Fix import
EricMichaelSmith Mar 15, 2021
6b6c084
Minor
EricMichaelSmith Mar 15, 2021
d7784ea
Rename
EricMichaelSmith Mar 15, 2021
4bee499
Rename
EricMichaelSmith Mar 15, 2021
5174dc0
Reverting stuff
EricMichaelSmith Mar 16, 2021
4d5515c
Reverting stuff
EricMichaelSmith Mar 16, 2021
e7d601c
Reverting stuff
EricMichaelSmith Mar 16, 2021
2fcd6ad
Update parlai/scripts/torchscript.py
EricMichaelSmith Mar 16, 2021
00a1d26
Revert
EricMichaelSmith Mar 16, 2021
14ec40b
Split ordering by layer
EricMichaelSmith Mar 16, 2021
a7fb474
Merge branch 'jittest' of github.com:facebookresearch/ParlAI into jit…
EricMichaelSmith Mar 16, 2021
5572674
Minor
EricMichaelSmith Mar 16, 2021
d8ad611
Re-add reorder_incremental_state method
EricMichaelSmith Mar 16, 2021
1d272a8
Fixes
EricMichaelSmith Mar 16, 2021
2919fa1
Split stuff to delay import
EricMichaelSmith Mar 16, 2021
0d7ae91
Remove list comprehension
EricMichaelSmith Mar 18, 2021
b66d477
Revert some code
EricMichaelSmith Mar 18, 2021
d898c24
Build out wrapping
EricMichaelSmith Mar 18, 2021
2d6ca34
Fixes
EricMichaelSmith Mar 18, 2021
18e68af
Revert
EricMichaelSmith Mar 18, 2021
1d7d342
Merge branch 'master' into jittest
EricMichaelSmith Mar 18, 2021
3807be9
Skip unless fairseq
EricMichaelSmith Mar 18, 2021
6f16934
Move checks to nightly
EricMichaelSmith Mar 18, 2021
9862d25
Merge branch 'master' into jittest
EricMichaelSmith Mar 19, 2021
feaf319
Merge branch 'master' into jittest
EricMichaelSmith Mar 19, 2021
a4377df
Try to speed up style_gen test
EricMichaelSmith Mar 19, 2021
ea18a61
Revert
EricMichaelSmith Mar 19, 2021
8e15176
Reduce bs for wizard tasks
EricMichaelSmith Mar 19, 2021
7 changes: 4 additions & 3 deletions parlai/agents/transformer/modules.py
@@ -748,7 +748,7 @@ def forward_layers(
         tensor: torch.Tensor,
         encoder_output: torch.Tensor,
         encoder_mask: torch.Tensor,
-        incr_state: Dict[int, torch.Tensor],
+        incr_state: Dict[int, Dict[str, Dict[str, torch.Tensor]]],
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Forward pass of decoder layers.
@@ -797,8 +797,9 @@ def forward(self, input, encoder_state, incr_state=None):
         encoder_output, encoder_mask = encoder_state

         seq_len = input.size(1)
-        positions = input.new(seq_len).long()
-        positions = torch.arange(seq_len, out=positions).unsqueeze(0)
+        positions = torch.arange(
+            seq_len, dtype=torch.long, device=input.device
+        ).unsqueeze(0)

         if incr_state is not None:
             # We're doing incremental decoding, so select only the most recent position
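Note on the `positions` change in `forward`: the new lines build the position indices in a single `torch.arange` call with an explicit dtype and device, instead of allocating a tensor with `input.new(...)` and filling it through `out=`. A minimal sketch of the resulting behavior, assuming a batch-first tensor of token indices; the helper name `make_positions` is hypothetical and exists only for this illustration:

```python
import torch


def make_positions(tokens: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the new lines in forward():
    # one row of positions 0..seq_len-1, on the same device as the input.
    seq_len = tokens.size(1)
    return torch.arange(seq_len, dtype=torch.long, device=tokens.device).unsqueeze(0)


tokens = torch.zeros(2, 5, dtype=torch.long)  # batch of 2, sequence length 5
positions = make_positions(tokens)
print(positions.shape)  # (1, seq_len), broadcastable over the batch dimension

# The explicit form can also be compiled with TorchScript.
scripted_positions = torch.jit.script(make_positions)
print(torch.equal(scripted_positions(tokens), positions))
```

The result does not depend on the contents of the input tensor, only on its sequence length and device.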
107 changes: 107 additions & 0 deletions parlai/scripts/torchscript.py
@@ -0,0 +1,107 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch.jit
import torch.nn as nn
from packaging import version

from parlai.core.agents import create_agent
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.script import ParlaiScript, register_script
from parlai.utils.io import PathManager


def export_model(opt: Opt):
    """
    Export a model to TorchScript so that inference can be run outside of ParlAI.

    Currently, only CPU greedy-search inference on BART models is supported.
    """

    if version.parse(torch.__version__) < version.parse('1.7.0'):
        raise NotImplementedError(
            'TorchScript export is only supported for Torch 1.7 and higher!'
        )
    else:
        # Only load TorchScriptGreedySearch now, because this will trigger scripting of
        # associated modules
        from parlai.torchscript.modules import TorchScriptGreedySearch

    overrides = {
        'no_cuda': True,  # TorchScripting is CPU only
        'model_parallel': False,  # model_parallel is not currently supported when TorchScripting
    }
    if 'override' not in opt:
        opt['override'] = {}
    for k, v in overrides.items():
        opt[k] = v
        opt['override'][k] = v

    # Create the unscripted greedy-search module
    agent = create_agent(opt, requireModelExists=True)
    original_module = TorchScriptGreedySearch(agent)

    # Script the module and save
    scripted_module = torch.jit.script(TorchScriptGreedySearch(agent))
    with PathManager.open(opt['scripted_model_file'], 'wb') as f:
        torch.jit.save(scripted_module, f)

    # Compare the original module to the scripted module against the test inputs
    if len(opt['input']) > 0:
        inputs = opt['input'].split('|')
        print('\nGenerating given the original unscripted module:')
        _run_conversation(module=original_module, inputs=inputs)
        print('\nGenerating given the scripted module:')
        _run_conversation(module=scripted_module, inputs=inputs)


def setup_args() -> ParlaiParser:
    parser = ParlaiParser(add_parlai_args=True, add_model_args=True)
    parser.add_argument(
        '-smf',
        '--scripted-model-file',
        type=str,
        default='_scripted.pt',
        help='Where the scripted model checkpoint will be saved',
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        default='',
        help="Input string to pass into the encoder of the scripted model, to test it against the unscripted version. Separate lines with a pipe",
    )
    return parser


def _run_conversation(module: nn.Module, inputs: List[str]):
    """
    Run a conversation with the given module given the input strings.
    """
    context = []
    for input_ in inputs:
        print(' TEXT: ' + input_)
        context.append(input_)
        label = module('\n'.join(context))
        print("LABEL: " + label)
        context.append(label)


@register_script('torchscript', hidden=True)
class TorchScript(ParlaiScript):
    @classmethod
    def setup_args(cls):
        return setup_args()

    def run(self):
        return export_model(self.opt)


if __name__ == '__main__':
    TorchScript.main()
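Note on `--input` and `_run_conversation`: the pipe-separated input string is replayed turn by turn. Each input is appended to the running context, the module is called on the newline-joined context, and the generated label is appended in turn. A minimal sketch of that loop, using a hypothetical `stub_module` in place of a real `TorchScriptGreedySearch` module so that it runs without ParlAI or a model file (the function here returns the contexts instead of printing `TEXT:`/`LABEL:` lines, which is a change made for illustration):

```python
from typing import Callable, List


def run_conversation(module: Callable[[str], str], inputs: List[str]) -> List[str]:
    """Replay the inputs, returning every context string passed to the module."""
    context: List[str] = []
    contexts_seen: List[str] = []
    for input_ in inputs:
        context.append(input_)
        joined = '\n'.join(context)
        contexts_seen.append(joined)
        label = module(joined)
        context.append(label)  # the reply becomes part of the next context
    return contexts_seen


def stub_module(context: str) -> str:
    # Hypothetical stand-in: reports how many lines of context it received.
    return 'saw {} line(s)'.format(len(context.split('\n')))


raw_input = 'Hello there.|How are you?'  # same pipe-separated format as --input
for seen in run_conversation(stub_module, raw_input.split('|')):
    print(repr(seen))
```

Because the same inputs are fed to both the unscripted and the scripted module, any divergence in the printed labels points at a scripting discrepancy rather than at the inputs.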
27 changes: 27 additions & 0 deletions parlai/torchscript/README.md
@@ -0,0 +1,27 @@
# Agent exported to TorchScript (JIT compilation)

This agent will read in a ParlAI agent that has been exported to TorchScript with JIT compilation, for use in greedy-search inference on CPU. This allows inference to be run on models without using any ParlAI overhead, either for tokenization or for the forward passes through the model. Currently, only BART models are supported.

Sample call for exporting a BART model to TorchScript:
```
parlai torchscript \
--model-file ${MODEL_FILE} \
--model bart \
--no-cuda \
--scripted-model-file ~/_test_scripted_model__bart.pt \
--input 'I am looking for a restaurant in the west part of town.|APIRESP: Restaurant 14 matches'
```

Interacting with an exported model using `parlai interactive`:
```
parlai interactive \
--model-file ~/_test_scripted_model__bart.pt \
--model parlai.torchscript.agents:TorchScriptAgent
```

Loading in and running inference on an exported model, without any ParlAI overhead:
```
python parlai/torchscript/scripts/test_exported_model.py \
--scripted-model-file ~/_test_scripted_model__bart.pt \
--input 'I am looking for a restaurant in the west part of town.|APIRESP: Restaurant 14 matches'
```
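The "without any ParlAI overhead" claim rests on the exported file being an ordinary TorchScript archive: the loader in `parlai/torchscript/agents.py` only calls `torch.jit.load` and then invokes the module on a newline-joined history string. A minimal sketch of that round trip follows. It is not the contents of `test_exported_model.py`, which is not shown here, and `EchoModule` is a hypothetical stand-in for an exported BART module so that the example runs without a model file:

```python
import os
import tempfile

import torch
import torch.nn as nn


class EchoModule(nn.Module):
    """Hypothetical stand-in for an exported module: string in, string out."""

    def forward(self, context: str) -> str:
        lines = context.split('\n')
        return 'echo: ' + lines[len(lines) - 1]


# Export step: script the module and save it to disk.
path = os.path.join(tempfile.mkdtemp(), 'scripted_echo.pt')
torch.jit.save(torch.jit.script(EchoModule()), path)

# Inference step: only torch is needed from here on, with no ParlAI imports.
loaded = torch.jit.load(path)
history = ['I am looking for a restaurant in the west part of town.']
reply = loaded('\n'.join(history))
print(reply)
```

With a real export, `path` would be the file passed as `--scripted-model-file`, and the reply would be the model's greedy-search generation.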
5 changes: 5 additions & 0 deletions parlai/torchscript/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
70 changes: 70 additions & 0 deletions parlai/torchscript/agents.py
@@ -0,0 +1,70 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch

from parlai.core.agents import Agent
from parlai.core.message import Message
from parlai.core.opt import Opt
from parlai.utils.io import PathManager


class TorchScriptAgent(Agent):
    """
    ParlAI agent exported to TorchScript with JIT compilation and then loaded from disk.

    Metrics and batch act are currently unsupported, and CUDA is unsupported because
    TorchScripting is currently CPU-only.
    """

    def __init__(self, opt: Opt, shared=None):

        super().__init__(opt=opt, shared=shared)
        with PathManager.open(self.opt['model_file'], "rb") as f:
            self.module = torch.jit.load(f)

        # Track incoming history strings
        self.history: List[str] = []

    def share(self):
        """
        Share the scripted module object.
        """
        shared = super().share()
        shared['module'] = self.module
        return shared

    def observe(self, observation: Message) -> Message:
        # TODO: support self._validate_observe_invariants() method of TorchAgent

        self.history.append(observation['text'])

        return super().observe(observation)

    def self_observe(self, self_message: Message) -> None:
        # TODO: support self._validate_self_observe_invariants() method of TorchAgent

        assert self.observation is not None
        if self.observation['episode_done']:
            # oh this was the last example in the episode. reset the history
            self.history = []
            # additionally mark the last observation as invalid
            self.observation = None
        else:
            self.history.append(self_message['text'])

    def reset(self):
        super().reset()
        self.history = []

    def act(self) -> Message:
        response_text = self.module('\n'.join(self.history))
        response = Message({'text': response_text, 'episode_done': False})
        # self.observation will determine if we're going onto a new episode
        self.self_observe(response)
        return response
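Note on history handling in `TorchScriptAgent`: the agent keeps the dialogue as a flat list of strings. Observed text is appended, the module is called on the newline-joined list, and the reply is appended as well unless the observed message ended the episode, in which case the history is cleared. A minimal sketch of that bookkeeping in plain Python, assuming nothing from ParlAI; `HistoryTracker` and `stub_module` are hypothetical names used only for this illustration:

```python
from typing import Callable, Dict, List, Optional


class HistoryTracker:
    """Hypothetical reduction of the observe/act/self_observe cycle."""

    def __init__(self, module: Callable[[str], str]):
        self.module = module
        self.history: List[str] = []
        self.observation: Optional[Dict] = None

    def observe(self, observation: Dict) -> None:
        self.history.append(observation['text'])
        self.observation = observation

    def act(self) -> str:
        assert self.observation is not None
        reply = self.module('\n'.join(self.history))
        if self.observation['episode_done']:
            # Last turn of the episode: start the next one with no history.
            self.history = []
            self.observation = None
        else:
            self.history.append(reply)
        return reply


contexts: List[str] = []


def stub_module(context: str) -> str:
    # Hypothetical stand-in for the scripted module; records what it was given.
    contexts.append(context)
    return 'reply ' + str(len(contexts))


tracker = HistoryTracker(stub_module)
tracker.observe({'text': 'hi', 'episode_done': False})
tracker.act()
tracker.observe({'text': 'bye', 'episode_done': True})
tracker.act()
print(contexts)
print(tracker.history)
```

The second call sees the full episode so far, including the first reply, and the history is empty again once the episode has ended.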