This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

TorchScript BART models #3459

Merged
merged 127 commits into master from jittest on Mar 19, 2021
Changes from 101 commits

Commits (127)

8e2e391
Checkpoint jit
stephenroller Feb 21, 2021
b22ffc7
maryam checkpoint
Feb 23, 2021
79ad543
Changes so far; currently debugging incr decoding
EricMichaelSmith Feb 24, 2021
60c7364
Trying to wrangle types
EricMichaelSmith Feb 25, 2021
54bc0b8
Fully traced model, no incr decoding, generates past EOS
EricMichaelSmith Feb 26, 2021
b349ea8
WIP test to mix scripting and tracing
EricMichaelSmith Feb 26, 2021
f513488
Get scripting and tracing working, no incr decoding
EricMichaelSmith Feb 26, 2021
4823b3c
Bit of cleanup
EricMichaelSmith Feb 26, 2021
f8936a7
Minor
EricMichaelSmith Feb 26, 2021
afe9656
Merge branch 'master' into jittest
EricMichaelSmith Mar 1, 2021
b9039d0
Prototype for incr decoding
EricMichaelSmith Mar 2, 2021
0c97c1a
Fixes
EricMichaelSmith Mar 2, 2021
e5d9608
Minor
EricMichaelSmith Mar 2, 2021
e58f07e
Fixes
EricMichaelSmith Mar 2, 2021
70599c5
Fixes
EricMichaelSmith Mar 2, 2021
2162a63
Fixes
EricMichaelSmith Mar 2, 2021
ed8a253
Save path
EricMichaelSmith Mar 2, 2021
2b8c560
Merge branch 'jittest' of github.com:facebookresearch/ParlAI into jit…
EricMichaelSmith Mar 2, 2021
dd33305
Path manager
EricMichaelSmith Mar 2, 2021
c29c2b8
Load exported model
EricMichaelSmith Mar 2, 2021
c4404e9
Compare outputs
EricMichaelSmith Mar 4, 2021
d9e87d8
Try original greedy search
EricMichaelSmith Mar 4, 2021
033092e
Minor
EricMichaelSmith Mar 4, 2021
1ca2428
Merge branch 'master' into jittest
EricMichaelSmith Mar 5, 2021
2c377a7
Refactor a bit
EricMichaelSmith Mar 5, 2021
ed0a917
Partial work on overhauling script code
EricMichaelSmith Mar 5, 2021
a3b4d1b
Minor
EricMichaelSmith Mar 5, 2021
de0dfda
Lots of reversions
EricMichaelSmith Mar 5, 2021
3997002
More reversion
EricMichaelSmith Mar 5, 2021
c96195c
Work on initial decoder input
EricMichaelSmith Mar 5, 2021
f642c57
Finish applying fixes from before scripting
EricMichaelSmith Mar 5, 2021
bbfafa4
Finish conversion
EricMichaelSmith Mar 5, 2021
aded9b5
Fixes
EricMichaelSmith Mar 5, 2021
b309164
Refactor for loading back in
EricMichaelSmith Mar 5, 2021
6d05381
Move scripts
EricMichaelSmith Mar 5, 2021
76f263f
Merge branch 'master' into jittest
EricMichaelSmith Mar 9, 2021
2b806c6
Set up scripting tokenization
EricMichaelSmith Mar 9, 2021
a7d8dcf
Initial fixes
EricMichaelSmith Mar 9, 2021
c5164c3
Don't rely on History
EricMichaelSmith Mar 9, 2021
23249b9
Minor
EricMichaelSmith Mar 9, 2021
8aacdc1
Start work in ScriptableDictionaryAgent
EricMichaelSmith Mar 9, 2021
a93d581
Work on dict
EricMichaelSmith Mar 9, 2021
da8eaba
Enable more stuff
EricMichaelSmith Mar 9, 2021
d163fc9
Enable more code
EricMichaelSmith Mar 9, 2021
817fca8
Set up more of the code
EricMichaelSmith Mar 9, 2021
ff85524
More
EricMichaelSmith Mar 9, 2021
2981f25
Simplify
EricMichaelSmith Mar 9, 2021
a80c637
Work on cleaning up BPE helper
EricMichaelSmith Mar 9, 2021
dc23433
Minor
EricMichaelSmith Mar 9, 2021
8e6d108
Pass in args
EricMichaelSmith Mar 9, 2021
894746d
Minor
EricMichaelSmith Mar 9, 2021
15b5600
Fuse keys
EricMichaelSmith Mar 9, 2021
59f9f7e
Rearrange
EricMichaelSmith Mar 9, 2021
ba16e1f
Various fixes
EricMichaelSmith Mar 9, 2021
b6e25f0
Various fixes
EricMichaelSmith Mar 9, 2021
b5e1dd4
Start work on replacement function
EricMichaelSmith Mar 11, 2021
4cce680
Minor
EricMichaelSmith Mar 11, 2021
f1877b4
Prototype
EricMichaelSmith Mar 11, 2021
e813de9
Work on scripting
EricMichaelSmith Mar 11, 2021
805b32d
Work on scripting
EricMichaelSmith Mar 11, 2021
e123fe6
Detokenizing
EricMichaelSmith Mar 11, 2021
205d52f
Runtime fixes
EricMichaelSmith Mar 12, 2021
f5e3ba1
Make splitting test more powerful
EricMichaelSmith Mar 12, 2021
04412c7
Testing
EricMichaelSmith Mar 12, 2021
9ca4945
Cleanup
EricMichaelSmith Mar 12, 2021
0c75b9f
Merge branch 'master' into jittest
EricMichaelSmith Mar 12, 2021
b08e690
No for/else
EricMichaelSmith Mar 12, 2021
9c7696e
A bit of cleanup
EricMichaelSmith Mar 12, 2021
7d81a91
More cleanups
EricMichaelSmith Mar 12, 2021
c73925f
Format
EricMichaelSmith Mar 12, 2021
e7106d4
Start CI testing
EricMichaelSmith Mar 12, 2021
de6a71b
Try bumping to 1.7
EricMichaelSmith Mar 12, 2021
96b70ea
Properly fix tests?
EricMichaelSmith Mar 12, 2021
9d07efa
Fix
EricMichaelSmith Mar 12, 2021
79a4137
Merge branch 'jittest' into jittest-ci-testing
EricMichaelSmith Mar 12, 2021
df43c02
Partial fix
EricMichaelSmith Mar 15, 2021
cb16dcd
TODOs
EricMichaelSmith Mar 15, 2021
97cd0df
Remove breakpoints
EricMichaelSmith Mar 15, 2021
e3f0c79
Join incr states in the last dim instead
EricMichaelSmith Mar 15, 2021
4ca5826
Model parallel fixes
EricMichaelSmith Mar 15, 2021
b5fb575
Minor
EricMichaelSmith Mar 15, 2021
df8b0eb
Cleanup
EricMichaelSmith Mar 15, 2021
9747589
Don't track history
EricMichaelSmith Mar 15, 2021
e9a64a8
Minor
EricMichaelSmith Mar 15, 2021
1723af6
Update parlai/agents/transformer/modules.py
EricMichaelSmith Mar 15, 2021
84c816a
Comments so far
EricMichaelSmith Mar 15, 2021
789a190
Merge branch 'jittest' of github.com:facebookresearch/ParlAI into jit…
EricMichaelSmith Mar 15, 2021
79a4974
Reversions
EricMichaelSmith Mar 15, 2021
0e208b1
ParlaiScript
EricMichaelSmith Mar 15, 2021
a79a6d1
Unit test overhaul
EricMichaelSmith Mar 15, 2021
ec6bae7
Minor
EricMichaelSmith Mar 15, 2021
c97f3a2
Minor
EricMichaelSmith Mar 15, 2021
849c8ac
Merge branch 'master' into jittest
EricMichaelSmith Mar 15, 2021
e394cd9
Test agent
EricMichaelSmith Mar 15, 2021
ba0cb1c
Add second test
EricMichaelSmith Mar 15, 2021
45312f9
Fix tests
EricMichaelSmith Mar 15, 2021
cd994b3
README
EricMichaelSmith Mar 15, 2021
8860766
Test fix
EricMichaelSmith Mar 15, 2021
0e7b38f
Conditional import
EricMichaelSmith Mar 15, 2021
1611563
Fix import
EricMichaelSmith Mar 15, 2021
6b6c084
Minor
EricMichaelSmith Mar 15, 2021
d7784ea
Rename
EricMichaelSmith Mar 15, 2021
4bee499
Rename
EricMichaelSmith Mar 15, 2021
5174dc0
Reverting stuff
EricMichaelSmith Mar 16, 2021
4d5515c
Reverting stuff
EricMichaelSmith Mar 16, 2021
e7d601c
Reverting stuff
EricMichaelSmith Mar 16, 2021
2fcd6ad
Update parlai/scripts/torchscript.py
EricMichaelSmith Mar 16, 2021
00a1d26
Revert
EricMichaelSmith Mar 16, 2021
14ec40b
Split ordering by layer
EricMichaelSmith Mar 16, 2021
a7fb474
Merge branch 'jittest' of github.com:facebookresearch/ParlAI into jit…
EricMichaelSmith Mar 16, 2021
5572674
Minor
EricMichaelSmith Mar 16, 2021
d8ad611
Re-add reorder_incremental_state method
EricMichaelSmith Mar 16, 2021
1d272a8
Fixes
EricMichaelSmith Mar 16, 2021
2919fa1
Split stuff to delay import
EricMichaelSmith Mar 16, 2021
0d7ae91
Remove list comprehension
EricMichaelSmith Mar 18, 2021
b66d477
Revert some code
EricMichaelSmith Mar 18, 2021
d898c24
Build out wrapping
EricMichaelSmith Mar 18, 2021
2d6ca34
Fixes
EricMichaelSmith Mar 18, 2021
18e68af
Revert
EricMichaelSmith Mar 18, 2021
1d7d342
Merge branch 'master' into jittest
EricMichaelSmith Mar 18, 2021
3807be9
Skip unless fairseq
EricMichaelSmith Mar 18, 2021
6f16934
Move checks to nightly
EricMichaelSmith Mar 18, 2021
9862d25
Merge branch 'master' into jittest
EricMichaelSmith Mar 19, 2021
feaf319
Merge branch 'master' into jittest
EricMichaelSmith Mar 19, 2021
a4377df
Try to speed up style_gen test
EricMichaelSmith Mar 19, 2021
ea18a61
Revert
EricMichaelSmith Mar 19, 2021
8e15176
Reduce bs for wizard tasks
EricMichaelSmith Mar 19, 2021
22 changes: 9 additions & 13 deletions parlai/agents/bart/modules.py
@@ -7,7 +7,7 @@
 """
 import torch
 import torch.nn.functional as F
-from typing import Any, Dict, Union, List, Optional
+from typing import Dict
 from parlai.agents.transformer.modules import TransformerGeneratorModel


@@ -50,10 +50,8 @@ def _get_initial_forced_decoder_input(self, bsz: int, inputs: torch.LongTensor):
         return torch.cat([tens, inputs], 1)

     def reorder_decoder_incremental_state(
-        self,
-        incremental_state: Dict[str, Any],
-        inds: Union[List[int], torch.LongTensor],
-    ) -> Optional[Dict[str, Any]]:
+        self, incremental_state: Dict[str, torch.Tensor], inds: torch.LongTensor
+    ) -> Dict[str, torch.Tensor]:
         """
         Incremental state is weird to handle when we seed decoder with two inputs
         initially.
@@ -62,14 +60,12 @@ def reorder_decoder_incremental_state(
         assert incremental_state is not None
         assert len(incremental_state) > 0

-        for incr_state_l in incremental_state.values():
-            assert 'self_attn' in incr_state_l
-            assert 'prev_mask' in incr_state_l['self_attn']
-            self_attn_mask = incr_state_l['self_attn']['prev_mask']
-            # check this is on the very first run with incremental state
-            if self_attn_mask.ndim == 3 and tuple(self_attn_mask.shape[1:]) == (2, 2):
-                # cut off the inappropriate incremental state
-                incr_state_l['self_attn']['prev_mask'] = self_attn_mask[:, -1:, :]
+        assert 'self_attn_prev_mask' in incremental_state
+        self_attn_mask = incremental_state['self_attn_prev_mask']
+        # check this is on the very first run with incremental state
+        if self_attn_mask.ndim == 4 and tuple(self_attn_mask.shape[1:3]) == (2, 2):
+            # cut off the inappropriate incremental state
+            incremental_state['self_attn_prev_mask'] = self_attn_mask[:, -1:, :, :]

         return super().reorder_decoder_incremental_state(incremental_state, inds)
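To make the fused-key change above easier to follow, here is a minimal standalone sketch of the mask trimming on a dummy tensor. It only mirrors the shape check and slice from the diff; the dummy dimensions (batch, query, key, stacked per-layer state) are illustrative assumptions, not ParlAI's actual internal layout.

```
import torch

# Illustrative only: BART seeds the decoder with two tokens (EOS then BOS),
# so on the first step with incremental state the cached self-attention mask
# covers a 2 x 2 block. Later steps expect a single query row, so the stale
# row is sliced off. The last dimension here stands in for the fused
# per-layer state and is an assumption made for this sketch.
bsz, n_layers = 3, 12
incremental_state = {'self_attn_prev_mask': torch.ones(bsz, 2, 2, n_layers)}

mask = incremental_state['self_attn_prev_mask']
if mask.ndim == 4 and tuple(mask.shape[1:3]) == (2, 2):
    # keep only the mask row for the most recent decoder position
    incremental_state['self_attn_prev_mask'] = mask[:, -1:, :, :]

print(incremental_state['self_attn_prev_mask'].shape)  # torch.Size([3, 1, 2, 12])
```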
25 changes: 25 additions & 0 deletions parlai/agents/jit/README.md
@@ -0,0 +1,25 @@
# Agent exported to TorchScript (JIT compilation)

This agent loads a ParlAI agent that has been exported to TorchScript with JIT compilation and runs it for greedy-search inference on CPU. Inference then runs without any ParlAI overhead, either for tokenization or for the forward passes through the model. Currently, only BART models are supported.
Contributor: are there plans to support all transformer/generators? if not, should we name this bart_jit?

Contributor: yeah, I definitely would like to see us support more than just BART, especially once we see how well TorchScripting works for the BART models

Sample call for exporting a BART model to TorchScript:
```
parlai jit_export \
--model-file ${MODEL_FILE} \
--model bart \
--no-cuda \
--scripted-model-file ~/_test_scripted_model__bart.pt \
--input 'I am looking for a restaurant in the west part of town.|APIRESP: Restaurant 14 matches'
```

Interacting with an exported model using `parlai interactive`:
```
parlai interactive --model-file ~/_test_scripted_model__bart.pt --model jit
```

Loading in and running inference on an exported model, without any ParlAI overhead:
```
python parlai/agents/jit/scripts/test_exported_model.py \
--scripted-model-file ~/_test_scripted_model__bart.pt \
--input 'I am looking for a restaurant in the west part of town.|APIRESP: Restaurant 14 matches'
```
5 changes: 5 additions & 0 deletions parlai/agents/jit/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
70 changes: 70 additions & 0 deletions parlai/agents/jit/jit.py
@@ -0,0 +1,70 @@
#!/usr/bin/env python3
Contributor Author: i don't love exposing this in the main agents folder. Can we make this in the torchscript folder, as it's very narrow and only for testing innit?

Contributor Author: I guess I'd be happier about moving it into the main folder if we were in a spot where jit supported a LOT

Contributor: Yes, moving

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch

from parlai.core.agents import Agent
from parlai.core.message import Message
from parlai.core.opt import Opt
from parlai.utils.io import PathManager


class JitAgent(Agent):
"""
ParlAI agent exported to TorchScript with JIT compilation and then loaded from disk.

Metrics and batch act are currently unsupported, and CUDA is unsupported because
TorchScripting is currently CPU-only.
"""

def __init__(self, opt: Opt, shared=None):

super().__init__(opt=opt, shared=shared)
with PathManager.open(self.opt['model_file'], "rb") as f:
self.module = torch.jit.load(f)

# Track incoming history strings
self.history: List[str] = []

def share(self):
"""
Share the scripted module object.
"""
shared = super().share()
shared['module'] = self.module
return shared

def observe(self, observation: Message) -> Message:
# TODO: support self._validate_observe_invariants() method of TorchAgent

self.history.append(observation['text'])

return super().observe(observation)

def self_observe(self, self_message: Message) -> None:
# TODO: support self._validate_self_observe_invariants() method of TorchAgent

assert self.observation is not None
if self.observation['episode_done']:
# oh this was the last example in the episode. reset the history
self.history = []
# additionally mark the last observation as invalid
self.observation = None
else:
self.history.append(self_message['text'])

def reset(self):
super().reset()
self.history = []

def act(self) -> Message:
response_text = self.module('\n'.join(self.history))
response = Message({'text': response_text, 'episode_done': False})
# self.observation will determine if we're going onto a new episode
self.self_observe(response)
return response
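For orientation, here is a hypothetical usage sketch that drives one turn of inference with the agent above directly from Python. The model path is a placeholder, and the opt is trimmed to the single 'model_file' key the constructor reads; in practice the README's `parlai interactive` call is the intended entry point.

```
from parlai.agents.jit.jit import JitAgent
from parlai.core.message import Message
from parlai.core.opt import Opt

# Placeholder path to a model exported with the jit_export script.
opt = Opt({'model_file': '/path/to/_test_scripted_model__bart.pt'})
agent = JitAgent(opt)

# One greedy-search turn: observe() appends to the agent's history,
# act() runs the scripted module and self-observes its own response.
agent.observe(Message({'text': 'hello world', 'episode_done': False}))
response = agent.act()
print(response['text'])
```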
5 changes: 5 additions & 0 deletions parlai/agents/jit/scripts/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
48 changes: 48 additions & 0 deletions parlai/agents/jit/scripts/test_exported_model.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from typing import List

import torch.jit

from parlai.utils.io import PathManager


def test_exported_model(scripted_model_file: str, inputs: List[str]):

with PathManager.open(scripted_model_file, "rb") as f:
scripted_module = torch.jit.load(f)

print('\nGenerating given the scripted module:')
context = []
for input_ in inputs:
print(' TEXT: ' + input_)
context.append(input_)
label = scripted_module('\n'.join(context))
print("LABEL: " + label)
context.append(label)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'-smf',
'--scripted-model-file',
type=str,
help='Where to load the scripted model checkpoint from',
)
parser.add_argument(
"-i",
"--input",
type=str,
default="hello world",
help="Test input string to pass into the encoder of the scripted model. Separate lines with a pipe",
)
args = parser.parse_args()
test_exported_model(
scripted_model_file=args.scripted_model_file, inputs=args.input.split('|')
)