-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update Jupyter Notebook APIs
- Loading branch information
Showing
13 changed files
with
242 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
|
||
import unittest | ||
import os | ||
|
||
|
||
class VizSeqIpynbTestCase(unittest.TestCase): | ||
def setUp(self) -> None: | ||
dataset_root = 'examples/data/translation_wmt14_en_de_test' | ||
if not os.path.isdir(dataset_root): | ||
raise NotADirectoryError(f'{dataset_root} does not exist.') | ||
with open(f'{dataset_root}/src_0.txt') as f: | ||
self.source = {'src': [l.strip() for l in f]} | ||
with open(f'{dataset_root}/ref_0.txt') as f: | ||
self.references = {'ref': [l.strip() for l in f]} | ||
with open(f'{dataset_root}/pred_onlineA.0.txt') as f: | ||
self.hypothesis = {'hypo': [l.strip() for l in f]} | ||
|
||
def test_view_stats(self): | ||
pass | ||
|
||
def test_view_examples(self): | ||
pass | ||
|
||
def test_view_scores(self): | ||
pass | ||
|
||
def test_view_n_grams(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
|
||
from . import VizSeqIpynbTestCase | ||
from vizseq.ipynb.core import (view_stats, view_examples, view_n_grams, | ||
view_scores) | ||
|
||
|
||
class VizSeqIpynbCoreTestCase(VizSeqIpynbTestCase): | ||
def test_view_stats(self): | ||
_ = view_stats(self.source, self.references) | ||
|
||
def test_view_examples(self): | ||
_ = view_examples(self.source, self.references, self.hypothesis) | ||
|
||
def test_view_scores(self): | ||
_ = view_scores(self.references, self.hypothesis, ['bleu']) | ||
|
||
def test_view_n_grams(self): | ||
_ = view_n_grams(self.references) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
|
||
|
||
from . import VizSeqIpynbTestCase | ||
from vizseq.ipynb.fairseq_viz import (view_stats, view_examples, view_n_grams, | ||
view_scores) | ||
|
||
|
||
class VizSeqIpynbFairSeqTestCase(VizSeqIpynbTestCase): | ||
def setUp(self) -> None: | ||
self.log_paths = [ | ||
'examples/data/wmt14_fr_en_test.fairseq_generate.log', | ||
'examples/data/wmt14_fr_en_test.fairseq_generate.log' | ||
] | ||
|
||
def test_view_stats(self): | ||
_ = view_stats(self.log_paths[0]) | ||
|
||
def test_view_examples(self): | ||
_ = view_examples(self.log_paths[0]) | ||
_ = view_examples(self.log_paths) | ||
|
||
def test_view_scores(self): | ||
_ = view_scores(self.log_paths[0], ['bleu']) | ||
|
||
def test_view_n_grams(self): | ||
_ = view_n_grams(self.log_paths[0]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
|
||
|
||
if __name__ == '__main__': | ||
import unittest | ||
from tests.ipynb import test_core, test_fairseq | ||
|
||
# initialize the test suite | ||
loader = unittest.TestLoader() | ||
suite = unittest.TestSuite() | ||
|
||
# add tests to the test suite | ||
for m in [test_core, test_fairseq]: | ||
suite.addTests(loader.loadTestsFromModule(m)) | ||
|
||
# initialize a runner, pass it your suite and run it | ||
runner = unittest.TextTestRunner(verbosity=3) | ||
result = runner.run(suite) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
0.1.10 | ||
0.1.11 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
--- | ||
id: fairseq_api | ||
title: Fairseq Integration APIs | ||
sidebar_label: Fairseq Integration APIs | ||
--- | ||
|
||
## `vizseq.ipynb.fairseq_viz.*` | ||
VizSeq can directly import and analyze model predictions generated by | ||
[`fairseq-generate`](https://github.com/pytorch/fairseq/blob/master/generate.py) or | ||
[`fairseq-interactive`](https://github.com/pytorch/fairseq/blob/master/interactive.py) in Jupyter Notebook. The APIs | ||
are almost the same as [the normal Jupyter Notebook APIs](ipynb_api). | ||
### `view_stats()` | ||
#### Arguments | ||
- **`log_path`: str**: The path to `fairseq-generate` or `fairseq-interactive` log file. | ||
### `view_scores()` | ||
#### Arguments | ||
- **`log_path`: str**: The path to `fairseq-generate` or `fairseq-interactive` log file. | ||
- **`metrics`: List[str]**: List of scorer IDs. Use [`available_scorers()`](#available_scorers) to check all the available ones. | ||
### `view_examples()` | ||
#### Arguments | ||
- **`log_path`: str**: The path to `fairseq-generate` or `fairseq-interactive` log file. | ||
- **`metrics`: Optional[List[str]] = None**: List of scorer IDs. Default to `None`. Use [`available_scorers()`](#available_scorers) to check all the available ones. | ||
- **`query`: str = ''**: The keyword(s) for example filtering. Default to `''`. | ||
- **`page_sz`: int = 10**: Page size. Default to `10`. | ||
- **`page_no`: int = 1**: Page number. Default to `1`. | ||
- **`sorting`: VizSeqSortingType = VizSeqSortingType.original** | ||
- **`need_g_translate`: bool = False**: | ||
To show Google Translate results or not. Default to `False`. | ||
- **`disable_alignment`: bool = False**: | ||
Not to show source-reference and reference-hypothesis alignments for rendering speedup. Default to `False`. | ||
|
||
### `view_n_grams()` | ||
#### Arguments | ||
- **`log_path`: str**: The path to `fairseq-generate` or `fairseq-interactive` log file. | ||
- **`k`: int = 64**: | ||
Number of n-grams to be shown. Default to `64`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.