Skip to content

Commit

Permalink
Update Jupyter Notebook APIs (#25)
Browse files Browse the repository at this point in the history
Update Jupyter Notebook APIs
  • Loading branch information
kahne committed Dec 8, 2019
1 parent a6f3da3 commit a7ec574
Show file tree
Hide file tree
Showing 13 changed files with 242 additions and 83 deletions.
12 changes: 12 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ jobs:
- run:
no_output_timeout: 60m
command: ~/miniconda3/bin/conda run -n vizseq36 python -m tests.test_embedding_based_scorers
- run:
no_output_timeout: 60m
command: ~/miniconda3/bin/conda run -n vizseq36 python -m tests.test_ipynb

py37_mac:
macos:
Expand All @@ -48,6 +51,9 @@ jobs:
- run:
no_output_timeout: 60m
command: ~/miniconda3/bin/conda run -n vizseq37 python -m tests.test_embedding_based_scorers
- run:
no_output_timeout: 60m
command: ~/miniconda3/bin/conda run -n vizseq37 python -m tests.test_ipynb

py36_linux:
docker:
Expand All @@ -64,6 +70,9 @@ jobs:
- run:
no_output_timeout: 60m
command: sudo python -m tests.test_embedding_based_scorers
- run:
no_output_timeout: 60m
command: sudo python -m tests.test_ipynb

py37_linux:
docker:
Expand All @@ -80,6 +89,9 @@ jobs:
- run:
no_output_timeout: 60m
command: sudo python -m tests.test_embedding_based_scorers
- run:
no_output_timeout: 60m
command: sudo python -m tests.test_ipynb

workflows:
version: 2
Expand Down
34 changes: 34 additions & 0 deletions tests/ipynb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import unittest
import os


class VizSeqIpynbTestCase(unittest.TestCase):
    """Shared fixture for the Jupyter Notebook API tests.

    ``setUp`` loads the sample translation data found under
    ``examples/data/translation_wmt14_en_de_test`` (resolved relative to the
    current working directory) into ``self.source``, ``self.references`` and
    ``self.hypothesis``. The ``test_*`` methods are no-op placeholders for
    subclasses to override.
    """

    def setUp(self) -> None:
        """Load the sample sources, references and hypotheses.

        Raises:
            NotADirectoryError: if the sample dataset directory is missing.
        """
        dataset_root = 'examples/data/translation_wmt14_en_de_test'
        if not os.path.isdir(dataset_root):
            raise NotADirectoryError(f'{dataset_root} does not exist.')

        def read_sentences(filename):
            # The encoding is pinned so that decoding the non-ASCII sample
            # text does not depend on the platform's locale default.
            with open(f'{dataset_root}/{filename}', encoding='utf-8') as f:
                return [line.strip() for line in f]

        self.source = {'src': read_sentences('src_0.txt')}
        self.references = {'ref': read_sentences('ref_0.txt')}
        self.hypothesis = {'hypo': read_sentences('pred_onlineA.0.txt')}

    def test_view_stats(self):
        pass

    def test_view_examples(self):
        pass

    def test_view_scores(self):
        pass

    def test_view_n_grams(self):
        pass
24 changes: 24 additions & 0 deletions tests/ipynb/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from . import VizSeqIpynbTestCase
from vizseq.ipynb.core import (view_stats, view_examples, view_n_grams,
view_scores)


class VizSeqIpynbCoreTestCase(VizSeqIpynbTestCase):
    """Smoke tests for ``vizseq.ipynb.core``.

    Every test simply invokes one view function on the fixture data loaded
    by the parent ``setUp``; a test passes when the call does not raise.
    """

    def test_view_stats(self):
        # Only sources and references are passed here.
        view_stats(self.source, self.references)

    def test_view_examples(self):
        view_examples(self.source, self.references, self.hypothesis)

    def test_view_scores(self):
        # A single metric is requested.
        metrics = ['bleu']
        view_scores(self.references, self.hypothesis, metrics)

    def test_view_n_grams(self):
        view_n_grams(self.references)
32 changes: 32 additions & 0 deletions tests/ipynb/test_fairseq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from . import VizSeqIpynbTestCase
from vizseq.ipynb.fairseq_viz import (view_stats, view_examples, view_n_grams,
view_scores)


class VizSeqIpynbFairSeqTestCase(VizSeqIpynbTestCase):
    """Smoke tests for ``vizseq.ipynb.fairseq_viz``.

    Every test invokes one view function on a sample fairseq generation log;
    a test passes when the call does not raise.
    """

    def setUp(self) -> None:
        # Replaces the parent fixture (the parent setUp is not called): these
        # tests only need log paths. The same log appears twice, which gives
        # test_view_examples a two-element list to pass.
        sample_log = 'examples/data/wmt14_fr_en_test.fairseq_generate.log'
        self.log_paths = [sample_log, sample_log]

    def test_view_stats(self):
        first_log = self.log_paths[0]
        view_stats(first_log)

    def test_view_examples(self):
        # Exercise both accepted input forms: one path and a list of paths.
        first_log = self.log_paths[0]
        view_examples(first_log)
        view_examples(self.log_paths)

    def test_view_scores(self):
        first_log = self.log_paths[0]
        view_scores(first_log, ['bleu'])

    def test_view_n_grams(self):
        first_log = self.log_paths[0]
        view_n_grams(first_log)
23 changes: 23 additions & 0 deletions tests/test_ipynb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


if __name__ == '__main__':
    import sys
    import unittest
    from tests.ipynb import test_core, test_fairseq

    # initialize the test suite
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()

    # add tests to the test suite
    for m in [test_core, test_fairseq]:
        suite.addTests(loader.loadTestsFromModule(m))

    # initialize a runner, pass it your suite and run it
    runner = unittest.TextTestRunner(verbosity=3)
    result = runner.run(suite)

    # Report the outcome through the exit status. Without this the process
    # exits 0 even when tests fail or error, so a job that runs
    # `python -m tests.test_ipynb` can never go red.
    sys.exit(0 if result.wasSuccessful() else 1)
2 changes: 1 addition & 1 deletion vizseq/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.10
0.1.11
2 changes: 1 addition & 1 deletion vizseq/ipynb/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def view_stats(

n_src_plots = len(_src.text_indices)
n_plots = n_src_plots + _ref.n_sources
plt.ion()
fig, ax = plt.subplots(nrows=1, ncols=n_plots, figsize=(7 * n_plots, 5))
for i, idx in enumerate(_src.text_indices):
cur_ax = ax if n_plots == 1 else ax[i]
Expand All @@ -131,7 +132,6 @@ def view_stats(
_ = cur_ax.hist(cur_sent_lens, density=True, bins=25)
_ = cur_ax.axvline(x=np.mean(cur_sent_lens), color='red', linewidth=3)
cur_ax.set_title(f'Reference {name} Length')
plt.show()


# TODO: display one by one instead of all together
Expand Down
78 changes: 46 additions & 32 deletions vizseq/ipynb/fairseq_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
# LICENSE file in the root directory of this source tree.
#

import os
from typing import List, Optional
import os.path as op
from collections import Counter
from typing import List, Optional, Union

from vizseq.ipynb.core import (view_examples as _view_examples,
view_stats as _view_stats,
Expand All @@ -15,33 +16,46 @@
from vizseq._view import DEFAULT_PAGE_SIZE, DEFAULT_PAGE_NO, VizSeqSortingType


def _get_data(generation_log_path: str):
assert os.path.isfile(generation_log_path)
sources, references, hypothesis = {}, {}, {}
with open(generation_log_path) as f:
for l in f:
line = l.strip()
if line.startswith('H-'):
_id, _, sent = line.split('\t', 2)
hypothesis[_id[2:]] = sent
elif line.startswith('T-'):
_id, sent = line.split('\t', 1)
references[_id[2:]] = sent
elif line.startswith('S-'):
_id, sent = line.split('\t', 1)
sources[_id[2:]] = sent
assert set(sources.keys()) == set(references.keys()) == set(
hypothesis.keys())
ids = sorted(sources.keys())
sources = [sources[i] for i in ids]
references = [references[i] for i in ids]
hypothesis = [hypothesis[i] for i in ids]
return {'0': sources}, {'0': references}, {'fairseq': hypothesis}
def _get_data(log_path_or_paths: Union[str, List[str]]):
if isinstance(log_path_or_paths, str):
log_path_or_paths = [log_path_or_paths]
ids, src, ref, hypo = None, None, None, {}
names = Counter()
for k, log_path in enumerate(log_path_or_paths):
assert op.isfile(log_path)
cur_src, cur_ref, cur_hypo = {}, {}, {}
with open(log_path) as f:
for l in f:
line = l.strip()
if line.startswith('H-'):
_id, _, sent = line.split('\t', 2)
cur_hypo[_id[2:]] = sent
elif line.startswith('T-'):
_id, sent = line.split('\t', 1)
cur_ref[_id[2:]] = sent
elif line.startswith('S-'):
_id, sent = line.split('\t', 1)
cur_src[_id[2:]] = sent
cur_ids = sorted(cur_src.keys())
assert set(cur_ids) == set(cur_ref.keys()) == set(cur_hypo.keys())
cur_src = [cur_src[i] for i in cur_ids]
cur_ref = [cur_ref[i] for i in cur_ids]
if k == 0:
ids, src, ref = cur_ids, cur_src, cur_ref
else:
assert set(ids) == set(cur_ids) and set(src) == set(cur_src)
assert set(ref) == set(cur_ref)
name = op.splitext(op.basename(log_path))[0]
names.update([name])
if names[name] > 1:
name += f'.{names[name]}'
hypo[name] = [cur_hypo[i] for i in cur_ids]
return {'0': src}, {'0': ref}, hypo


# TODO: visualize alignment by attention
def view_examples(
generation_log_path: str,
log_path_or_paths: Union[str, List[str]],
metrics: Optional[List[str]] = None,
query: str = '',
page_sz: int = DEFAULT_PAGE_SIZE,
Expand All @@ -50,24 +64,24 @@ def view_examples(
need_g_translate: bool = False,
disable_alignment: bool = False
):
sources, references, hypothesis = _get_data(generation_log_path)
sources, references, hypothesis = _get_data(log_path_or_paths)
return _view_examples(
sources, references, hypothesis, metrics, query, page_sz=page_sz,
page_no=page_no, sorting=sorting, need_g_translate=need_g_translate,
disable_alignment=disable_alignment
)


def view_stats(generation_log_path: str):
sources, references, hypothesis = _get_data(generation_log_path)
def view_stats(log_path: str):
    """Parse a fairseq log and pass its sources, references and hypotheses
    to the core ``view_stats``. Nothing is returned."""
    src, ref, hypo = _get_data(log_path)
    _view_stats(src, ref, hypo)


def view_n_grams(generation_log_path: str, k: int = 64):
sources, references, hypothesis = _get_data(generation_log_path)
def view_n_grams(log_path: str, k: int = 64):
    """Parse a fairseq log and return the core ``view_n_grams`` result for
    its source sentences, forwarding ``k`` unchanged."""
    # Only the sources are used; references and hypotheses are discarded.
    src, _ref, _hypo = _get_data(log_path)
    n_grams_view = _view_n_grams(src, k=k)
    return n_grams_view


def view_scores(generation_log_path: str, metrics: List[str]):
sources, references, hypothesis = _get_data(generation_log_path)
def view_scores(log_path: str, metrics: List[str]):
    """Parse a fairseq log and return the core ``view_scores`` result for its
    references and hypotheses under the given ``metrics``."""
    # The sources are parsed but not needed for scoring.
    _src, ref, hypo = _get_data(log_path)
    scores_view = _view_scores(ref, hypo, metrics)
    return scores_view
36 changes: 36 additions & 0 deletions website/docs/features/fairseq_api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
---
id: fairseq_api
title: Fairseq Integration APIs
sidebar_label: Fairseq Integration APIs
---

## `vizseq.ipynb.fairseq_viz.*`
VizSeq can directly import and analyze model predictions generated by
[`fairseq-generate`](https://github.com/pytorch/fairseq/blob/master/generate.py) or
[`fairseq-interactive`](https://github.com/pytorch/fairseq/blob/master/interactive.py) in Jupyter Notebook. The APIs
are almost the same as [the normal Jupyter Notebook APIs](ipynb_api).
### `view_stats()`
#### Arguments
- **`log_path`: str**: The path to `fairseq-generate` or `fairseq-interactive` log file.
### `view_scores()`
#### Arguments
- **`log_path`: str**: The path to `fairseq-generate` or `fairseq-interactive` log file.
- **`metrics`: List[str]**: List of scorer IDs. Use [`available_scorers()`](ipynb_api#available_scorers) to check all the available ones.
### `view_examples()`
#### Arguments
- **`log_path_or_paths`: Union[str, List[str]]**: The path, or a list of paths, to `fairseq-generate` or `fairseq-interactive` log file(s).
- **`metrics`: Optional[List[str]] = None**: List of scorer IDs. Default to `None`. Use [`available_scorers()`](ipynb_api#available_scorers) to check all the available ones.
- **`query`: str = ''**: The keyword(s) for example filtering. Default to `''`.
- **`page_sz`: int = 10**: Page size. Default to `10`.
- **`page_no`: int = 1**: Page number. Default to `1`.
- **`sorting`: VizSeqSortingType = VizSeqSortingType.original**
- **`need_g_translate`: bool = False**:
Whether to show Google Translate results. Default to `False`.
- **`disable_alignment`: bool = False**:
Whether to hide the source-reference and reference-hypothesis alignments to speed up rendering. Default to `False`.

### `view_n_grams()`
#### Arguments
- **`log_path`: str**: The path to `fairseq-generate` or `fairseq-interactive` log file.
- **`k`: int = 64**:
Number of n-grams to be shown. Default to `64`.
32 changes: 3 additions & 29 deletions website/docs/features/ipynb_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ title: Jupyter Notebook APIs
sidebar_label: Jupyter Notebook APIs
---

## Core (``vizseq.*`` or `vizseq.ipynb.*`)
## ``vizseq.*`` or `vizseq.ipynb.*`
### `view_stats()`
Shows the dataset statistics, including example counts, token counts, sentence length distributions, etc.
It contains Matplotlib figures and you need to add `%matplotlib inline` before use.
#### Arguments
- **`sources`: Union[str, List[str], Dict[str, List[str]]]**:
Source-side data source. Can be a path, paths or lists of sentences. Refer to the [data](data) section for more details.
Expand Down Expand Up @@ -54,31 +55,4 @@ The data source. Can be a path, paths or lists of sentences. Refer to the [data]
- **`k`: int = 64**:
Number of n-grams to be shown. Default to `64`.

## Fairseq (`vizseq.ipynb.fairseq_viz.*`)
VizSeq can directly import and analyze model predictions generated by [`fairseq-generate`](https://github.com/pytorch/fairseq/blob/master/generate.py) and [`fairseq-interactive`](https://github.com/pytorch/fairseq/blob/master/interactive.py).
The APIs are almost the same.
### `view_stats()`
#### Arguments
- **`log_path`: str**: The path to `fairseq-generate` or `fairseq-interactive` log file.
### `view_scores()`
#### Arguments
- **`log_path`: str**: The path to `fairseq-generate` or `fairseq-interactive` log file.
- **`metrics`: List[str]**: List of scorer IDs. Use [`available_scorers()`](#available_scorers) to check all the available ones.
### `view_examples()`
#### Arguments
- **`log_path`: str**: The path to `fairseq-generate` or `fairseq-interactive` log file.
- **`metrics`: Optional[List[str]] = None**: List of scorer IDs. Default to `None`. Use [`available_scorers()`](#available_scorers) to check all the available ones.
- **`query`: str = ''**: The keyword(s) for example filtering. Default to `''`.
- **`page_sz`: int = 10**: Page size. Default to `10`.
- **`page_no`: int = 1**: Page number. Default to `1`.
- **`sorting`: VizSeqSortingType = VizSeqSortingType.original**
- **`need_g_translate`: bool = False**:
To show Google Translate results or not. Default to `False`.
- **`disable_alignment`: bool = False**:
Not to show source-reference and reference-hypothesis alignments for rendering speedup. Default to `False`.

### `view_n_grams()`
#### Arguments
- **`log_path`: str**: The path to `fairseq-generate` or `fairseq-interactive` log file.
- **`k`: int = 64**:
Number of n-grams to be shown. Default to `64`.
## [Fairseq Integration APIs](fairseq_api)
Loading

0 comments on commit a7ec574

Please sign in to comment.