tokenization within datasets (need to re-run json prep stage right before training stage to be compatible with this change) #58

Merged 1 commit on Mar 3, 2021
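As the title notes, the json prep stage has to be re-run because the data json entries now carry the raw transcript under "text" instead of the pre-tokenized "token_text"; tokenization is moved into the datasets. A minimal, runnable sketch of the new entry layout (values are taken from the docstring example in espresso/tasks/speech_recognition.py below; the output path is illustrative):

import json
from collections import OrderedDict

# One entry of the re-generated data json: raw "text" replaces the old
# pre-tokenized "token_text"; AsrTextDataset now does the tokenization.
entry = OrderedDict(
    {
        "011c0202": {
            "feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819",
            "text": "THE HOTEL",
            "utt2num_frames": "693",
        }
    }
)

with open("train.json", "w") as f:  # illustrative path; the recipes write data/train.json
    json.dump(entry, f, indent=4)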
3 changes: 1 addition & 2 deletions espresso/criterions/cross_entropy_v2.py
@@ -72,9 +72,8 @@ def forward(self, model, sample, reduce=True):
assert pred.size() == target.size()
with data_utils.numpy_seed(model.num_updates):
i = np.random.randint(0, len(sample["id"]))
ref_tokens = sample["target_raw_text"][i]
length = utils.strip_pad(target.data[i], self.padding_idx).size(0)
ref_one = self.dictionary.wordpiece_decode(ref_tokens)
ref_one = sample["text"][i]
pred_one = self.dictionary.wordpiece_decode(self.dictionary.string(pred.data[i][:length]))
logger.info("sample REF: " + ref_one)
logger.info("sample PRD: " + pred_one)
3 changes: 1 addition & 2 deletions espresso/criterions/label_smoothed_cross_entropy_v2.py
@@ -177,9 +177,8 @@ def forward(self, model, sample, reduce=True):
assert pred.size() == target.size()
with data_utils.numpy_seed(model.num_updates):
i = np.random.randint(0, len(sample["id"]))
ref_tokens = sample["target_raw_text"][i]
length = utils.strip_pad(target.data[i], self.padding_idx).size(0)
ref_one = self.dictionary.wordpiece_decode(ref_tokens)
ref_one = sample["text"][i]
pred_one = self.dictionary.wordpiece_decode(self.dictionary.string(pred.data[i][:length]))
logger.info("sample REF: " + ref_one)
logger.info("sample PRD: " + pred_one)
2 changes: 1 addition & 1 deletion espresso/data/asr_chain_dataset.py
@@ -283,7 +283,7 @@ def get_batch_shapes(self):

def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
text_item = self.text[index][1] if self.text is not None else None
text_item = self.text[index][2] if self.text is not None else None
src_item = self.src[index]
example = {
"id": index,
22 changes: 15 additions & 7 deletions espresso/data/asr_dataset.py
@@ -92,9 +92,13 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
else:
ntokens = src_lengths.sum().item()

target_raw_text = None
if samples[0].get("target_raw_text", None) is not None:
target_raw_text = [samples[i]["target_raw_text"] for i in sort_order.numpy()]
token_text = None
if samples[0].get("token_text", None) is not None:
token_text = [samples[i]["token_text"] for i in sort_order.numpy()]

text = None
if samples[0].get("text", None) is not None:
text = [samples[i]["text"] for i in sort_order.numpy()]

batch = {
"id": id,
@@ -103,7 +107,8 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
"ntokens": ntokens,
"net_input": {"src_tokens": src_frames, "src_lengths": src_lengths},
"target": target,
"target_raw_text": target_raw_text,
"token_text": token_text,
"text": text,
}
if prev_output_tokens is not None:
batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
@@ -253,14 +258,16 @@ def get_batch_shapes(self):

def __getitem__(self, index):
tgt_item = self.tgt[index][0] if self.tgt is not None else None
raw_text_item = self.tgt[index][1] if self.tgt is not None else None
token_text_item = self.tgt[index][1] if self.tgt is not None else None
text_item = self.tgt[index][2] if self.tgt is not None else None
src_item = self.src[index]
example = {
"id": index,
"utt_id": self.src.utt_ids[index],
"source": src_item,
"target": tgt_item,
"target_raw_text": raw_text_item,
"token_text": token_text_item,
"text": text_item,
}
if self.constraints is not None:
example["constraints"] = self.constraints[index]
@@ -304,7 +311,8 @@ def collater(self, samples, pad_to_length=None):
- `target` (LongTensor): a padded 2D Tensor of tokens in the
target sentence of shape `(bsz, tgt_len)`. Padding will appear
on the left if *left_pad_target* is ``True``.
- `target_raw_text` (List[str]): list of original text
- `token_text` (List[str]): list of token text
- `text` (List[str]): list of original text
- `tgt_lang_id` (LongTensor): a long Tensor which contains target language
IDs of each sample in the batch
"""
8 changes: 4 additions & 4 deletions espresso/data/asr_dictionary.py
@@ -8,8 +8,8 @@

import torch
from fairseq.data import Dictionary, encoders
from fairseq.dataclass import FairseqDataclass
from fairseq.file_io import PathManager
from omegaconf import DictConfig

# will automatically load modules defined from there
from espresso.data import encoders as encoders_espresso
@@ -99,12 +99,12 @@ def dummy_sentence(self, length):
t[-1] = self.eos()
return t

def build_tokenizer(self, cfg: Union[FairseqDataclass, Namespace]):
def build_tokenizer(self, cfg: Union[DictConfig, Namespace]):
self.tokenizer = encoders.build_tokenizer(cfg)

def build_bpe(self, cfg: Union[FairseqDataclass, Namespace]):
def build_bpe(self, cfg: Union[DictConfig, Namespace]):
if (
(isinstance(cfg, FairseqDataclass) and cfg._name == "characters_asr")
(isinstance(cfg, DictConfig) and cfg._name == "characters_asr")
or (isinstance(cfg, Namespace) and getattr(cfg, "bpe", None) == "characters_asr")
):
self.bpe = encoders.build_bpe(
2 changes: 1 addition & 1 deletion espresso/data/asr_xent_dataset.py
@@ -524,7 +524,7 @@ def get_batch_shapes(self):

def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
text_item = self.text[index][1] if self.text is not None else None
text_item = self.text[index][2] if self.text is not None else None
src_item = self.src[index]
example = {
"id": index,
51 changes: 29 additions & 22 deletions espresso/data/feat_text_dataset.py
@@ -233,38 +233,39 @@ def __getitem__(self, i):


class AsrTextDataset(torch.utils.data.Dataset):
"""Takes a text file as input and binarizes it in memory at instantiation.
Original lines are also kept in memory. Each line of the text file is in the
format of 'utt_id tokenized_text'."""
"""Takes a text file as input, tokenizes and tensorizes it in memory at instantiation.
Both original text and tokenized text are kept in memory."""

def __init__(self, utt_ids: List[str], token_text: List[str], dictionary=None, append_eos=True):
def __init__(self, utt_ids: List[str], texts: List[str], dictionary=None, append_eos=True):
super().__init__()
self.dtype = np.float
self.append_eos = append_eos
self.read_text(utt_ids, token_text, dictionary)
self.read_text(utt_ids, texts, dictionary)

def read_text(self, utt_ids: List[str], token_text: List[str], dictionary=None):
assert len(utt_ids) == len(token_text)
def read_text(self, utt_ids: List[str], texts: List[str], dictionary=None):
assert len(utt_ids) == len(texts)
self.utt_ids = utt_ids
self.tokens_list = token_text
self.tensor_list = []
self.texts = texts
self.size = len(self.utt_ids) # number of utterances
self.sizes = []
self.token_texts = None
self.tensor_list = None
if dictionary is not None:
for tokens in self.tokens_list:
tensor = dictionary.encode_line(
tokens, add_if_not_exist=False, append_eos=self.append_eos,
).long()
self.tensor_list.append(tensor)
self.sizes.append(len(self.tensor_list[-1]))
self.token_texts = [dictionary.wordpiece_encode(x) for x in texts]
self.tensor_list = [
dictionary.encode_line(tokens, add_if_not_exist=False, append_eos=self.append_eos).long()
for tokens in self.token_texts
]
self.sizes = [len(tensor) for tensor in self.tensor_list]
else:
self.sizes = [len(tokenize_line(tokens)) for tokens in self.tokens_list]
self.sizes = [len(tokenize_line(text)) for text in texts]

self.sizes = np.array(self.sizes, dtype=np.int32)

assert (
len(self.utt_ids) == len(self.tokens_list)
and (dictionary is None or len(self.utt_ids) == len(self.tensor_list))
(
dictionary is None
or (len(self.utt_ids) == len(self.tensor_list) and len(self.utt_ids) == len(self.token_texts))
)
and len(self.utt_ids) == len(self.sizes)
)

@@ -280,15 +281,21 @@ def filter_and_reorder(self, indices):
len(np.unique(indices)) == len(indices)
), "Duplicate elements in indices."
self.utt_ids = [self.utt_ids[i] for i in indices]
self.tokens_list = [self.tokens_list[i] for i in indices]
if len(self.tensor_list) > 0:
self.texts = [self.texts[i] for i in indices]
if self.token_texts is not None:
self.token_texts = [self.token_texts[i] for i in indices]
if self.tensor_list is not None:
self.tensor_list = [self.tensor_list[i] for i in indices]
self.sizes = self.sizes[indices]
self.size = len(self.utt_ids)

def __getitem__(self, i):
self.check_index(i)
return self.tensor_list[i] if len(self.tensor_list) > 0 else None, self.tokens_list[i]
return (
self.tensor_list[i] if self.tensor_list is not None else None,
self.token_texts[i] if self.token_texts is not None else None,
self.texts[i]
)

def __len__(self):
return self.size
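AsrTextDataset items are now 3-tuples of (tensor, token_text, text), which is why asr_chain_dataset.py and asr_xent_dataset.py above switched from index 1 to index 2 for the raw text. A minimal sketch of the new behavior, assuming Espresso is installed (utterance IDs and transcripts are made up; no dictionary is passed, so only the raw-text side is populated):

from espresso.data.feat_text_dataset import AsrTextDataset

utt_ids = ["utt1", "utt2"]
texts = ["THE HOTEL", "A STORY"]

dataset = AsrTextDataset(utt_ids, texts, dictionary=None)

tensor, token_text, text = dataset[0]
assert tensor is None and token_text is None  # no dictionary: nothing tokenized or tensorized
assert text == "THE HOTEL"

# With a real AsrDictionary, wordpiece_encode() runs once at instantiation to
# produce token_text, and encode_line() turns it into the LongTensor item.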
2 changes: 1 addition & 1 deletion espresso/speech_recognize.py
@@ -253,7 +253,7 @@ def decode_fn(x):

# Retrieve the original sentences
if has_target:
target_str = sample["target_raw_text"][i]
target_str = sample["token_text"][i]
if not cfg.common_eval.quiet:
detok_target_str = decode_fn(target_str)
print("T-{}\t{}".format(utt_id, detok_target_str), file=output_file)
18 changes: 9 additions & 9 deletions espresso/tasks/speech_recognition.py
@@ -109,7 +109,7 @@ def get_asr_dataset_from_json(
{
"011c0202": {
"feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819",
"token_text": "T H E <space> H O T E L",
"text": "THE HOTEL",
"utt2num_frames": "693",
},
"011c0203": {
@@ -133,12 +133,12 @@
with open(data_json_path, "rb") as f:
loaded_json = json.load(f, object_pairs_hook=OrderedDict)

utt_ids, feats, token_text, utt2num_frames = [], [], [], []
utt_ids, feats, texts, utt2num_frames = [], [], [], []
for utt_id, val in loaded_json.items():
utt_ids.append(utt_id)
feats.append(val["feat"])
if "token_text" in val:
token_text.append(val["token_text"])
if "text" in val:
texts.append(val["text"])
if "utt2num_frames" in val:
utt2num_frames.append(int(val["utt2num_frames"]))

@@ -148,10 +148,10 @@
specaugment_config=specaugment_config if split == "train" else None,
ordered_prefetch=True,
))
if len(token_text) > 0:
assert len(utt_ids) == len(token_text)
if len(texts) > 0:
assert len(utt_ids) == len(texts)
assert tgt_dict is not None
tgt_datasets.append(AsrTextDataset(utt_ids, token_text, tgt_dict))
tgt_datasets.append(AsrTextDataset(utt_ids, texts, tgt_dict))

logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1])))

@@ -196,7 +196,7 @@ def get_asr_dataset_from_json(
@register_task("speech_recognition_espresso", dataclass=SpeechRecognitionEspressoConfig)
class SpeechRecognitionEspressoTask(FairseqTask):
"""
Transcribe from speech (source) to token text (target).
Transcribe from speech (source) to text (target).

Args:
tgt_dict (~fairseq.data.AsrDictionary): dictionary for the output tokens
@@ -406,7 +406,7 @@ def _inference_with_wer(self, decoder, sample, model):
scorer.reset()
for i in range(target.size(0)):
utt_id = sample["utt_id"][i]
ref_tokens = sample["target_raw_text"][i]
ref_tokens = sample["token_text"][i]
pred_tokens = self.target_dictionary.string(pred.data[i])
scorer.add_evaluation(utt_id, ref_tokens, pred_tokens)
return (
12 changes: 6 additions & 6 deletions examples/asr_librispeech/run.sh
@@ -204,18 +204,18 @@ fi
if [ ${stage} -le 7 ]; then
echo "Stage 7: Dump Json Files"
train_feat=$train_feat_dir/feats.scp
train_token_text=data/$train_set/token_text
train_text=data/$train_set/text
train_utt2num_frames=data/$train_set/utt2num_frames
valid_feat=$valid_feat_dir/feats.scp
valid_token_text=data/$valid_set/token_text
valid_text=data/$valid_set/text
valid_utt2num_frames=data/$valid_set/utt2num_frames
asr_prep_json.py --feat-files $train_feat --token-text-files $train_token_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --token-text-files $valid_token_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
asr_prep_json.py --feat-files $train_feat --text-files $train_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --text-files $valid_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
for dataset in $test_set; do
feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp
token_text=data/$dataset/token_text
text=data/$dataset/text
utt2num_frames=data/$dataset/utt2num_frames
asr_prep_json.py --feat-files $feat --token-text-files $token_text --utt2num-frames-files $utt2num_frames --output data/$dataset.json
asr_prep_json.py --feat-files $feat --text-files $text --utt2num-frames-files $utt2num_frames --output data/$dataset.json
done
fi
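The --text-files inputs are the Kaldi-style text files (presumably one "utt_id transcript" per line, with the transcript left untokenized), replacing the previously required token_text files. A small illustration in Python (file name and contents are made up):

# Write a made-up Kaldi-style "text" file of the kind asr_prep_json.py
# --text-files now consumes; tokenization happens later, inside the dataset.
lines = [
    "011c0202 THE HOTEL",
    "011c0203 A STORY",
]
with open("text", "w") as f:  # in a real run this would be data/$train_set/text
    f.write("\n".join(lines) + "\n")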

10 changes: 5 additions & 5 deletions examples/asr_swbd/run.sh
@@ -242,18 +242,18 @@ fi
if [ $stage -le 6 ]; then
echo "Stage 6: Dump Json Files"
train_feat=$train_feat_dir/feats.scp
train_token_text=data/$train_set/token_text
train_text=data/$train_set/text
train_utt2num_frames=data/$train_set/utt2num_frames
valid_feat=$valid_feat_dir/feats.scp
valid_token_text=data/$valid_set/token_text
valid_text=data/$valid_set/text
valid_utt2num_frames=data/$valid_set/utt2num_frames
asr_prep_json.py --feat-files $train_feat --token-text-files $train_token_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --token-text-files $valid_token_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
asr_prep_json.py --feat-files $train_feat --text-files $train_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --text-files $valid_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
for dataset in $test_set; do
feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp
utt2num_frames=data/$dataset/utt2num_frames
# only score train_dev with built-in scorer
text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--token-text-files data/$dataset/token_text"
text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--text-files data/$dataset/text"
asr_prep_json.py --feat-files $feat $text_opt --utt2num-frames-files $utt2num_frames --output data/$dataset.json
done
fi
12 changes: 6 additions & 6 deletions examples/asr_wsj/run.sh
@@ -247,22 +247,22 @@ fi
if [ ${stage} -le 8 ]; then
echo "Stage 8: Dump Json Files"
train_feat=$train_feat_dir/feats.scp
train_token_text=data/$train_set/token_text
train_text=data/$train_set/text
train_utt2num_frames=data/$train_set/utt2num_frames
valid_feat=$valid_feat_dir/feats.scp
valid_token_text=data/$valid_set/token_text
valid_text=data/$valid_set/text
valid_utt2num_frames=data/$valid_set/utt2num_frames
asr_prep_json.py --feat-files $train_feat --token-text-files $train_token_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --token-text-files $valid_token_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
asr_prep_json.py --feat-files $train_feat --text-files $train_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --text-files $valid_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
for dataset in $valid_set $test_set; do
if [ "$dataset" == "$valid_set" ]; then
feat=$valid_feat_dir/feats.scp
elif [ "$dataset" == "$test_set" ]; then
feat=$test_feat_dir/feats.scp
fi
token_text=data/$dataset/token_text
text=data/$dataset/text
utt2num_frames=data/$dataset/utt2num_frames
asr_prep_json.py --feat-files $feat --token-text-files $token_text --utt2num-frames-files $utt2num_frames --output data/$dataset.json
asr_prep_json.py --feat-files $feat --text-files $text --utt2num-frames-files $utt2num_frames --output data/$dataset.json
done
fi
