Commit 7f2c317
tokenization within datasets (need to re-run json prep stage right before training stage to be compatible with this change) (#58)
freewym committed Jun 27, 2022
1 parent eb85929 commit 7f2c317
Showing 13 changed files with 89 additions and 74 deletions.
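In short, this commit moves tokenization out of the offline JSON preparation step and into the datasets themselves: the data JSONs now carry raw "text" instead of pre-tokenized "token_text", and AsrTextDataset tokenizes and tensorizes at instantiation (see espresso/data/feat_text_dataset.py below). A minimal sketch of the new flow, with tgt_dict standing in for an already-built AsrDictionary and the sample transcript made up:

# Sketch only: tokenization now happens inside the dataset, not in asr_prep_json.py.
raw_texts = ["THE HOTEL"]                        # the "text" field stored in data/*.json
token_texts = [tgt_dict.wordpiece_encode(t) for t in raw_texts]
#   e.g. "THE HOTEL" -> "T H E <space> H O T E L" (wordpiece/character encoding)
tensors = [
    tgt_dict.encode_line(tok, add_if_not_exist=False, append_eos=True).long()
    for tok in token_texts
]                                                # integer target tensors, as in read_text() below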
3 changes: 1 addition & 2 deletions espresso/criterions/cross_entropy_v2.py
@@ -72,9 +72,8 @@ def forward(self, model, sample, reduce=True):
assert pred.size() == target.size()
with data_utils.numpy_seed(model.num_updates):
i = np.random.randint(0, len(sample["id"]))
ref_tokens = sample["target_raw_text"][i]
length = utils.strip_pad(target.data[i], self.padding_idx).size(0)
ref_one = self.dictionary.wordpiece_decode(ref_tokens)
ref_one = sample["text"][i]
pred_one = self.dictionary.wordpiece_decode(self.dictionary.string(pred.data[i][:length]))
logger.info("sample REF: " + ref_one)
logger.info("sample PRD: " + pred_one)
3 changes: 1 addition & 2 deletions espresso/criterions/label_smoothed_cross_entropy_v2.py
@@ -177,9 +177,8 @@ def forward(self, model, sample, reduce=True):
assert pred.size() == target.size()
with data_utils.numpy_seed(model.num_updates):
i = np.random.randint(0, len(sample["id"]))
ref_tokens = sample["target_raw_text"][i]
length = utils.strip_pad(target.data[i], self.padding_idx).size(0)
ref_one = self.dictionary.wordpiece_decode(ref_tokens)
ref_one = sample["text"][i]
pred_one = self.dictionary.wordpiece_decode(self.dictionary.string(pred.data[i][:length]))
logger.info("sample REF: " + ref_one)
logger.info("sample PRD: " + pred_one)
2 changes: 1 addition & 1 deletion espresso/data/asr_chain_dataset.py
@@ -283,7 +283,7 @@ def get_batch_shapes(self):

def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
text_item = self.text[index][1] if self.text is not None else None
text_item = self.text[index][2] if self.text is not None else None
src_item = self.src[index]
example = {
"id": index,
22 changes: 15 additions & 7 deletions espresso/data/asr_dataset.py
@@ -92,9 +92,13 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
else:
ntokens = src_lengths.sum().item()

target_raw_text = None
if samples[0].get("target_raw_text", None) is not None:
target_raw_text = [samples[i]["target_raw_text"] for i in sort_order.numpy()]
token_text = None
if samples[0].get("token_text", None) is not None:
token_text = [samples[i]["token_text"] for i in sort_order.numpy()]

text = None
if samples[0].get("text", None) is not None:
text = [samples[i]["text"] for i in sort_order.numpy()]

batch = {
"id": id,
@@ -103,7 +107,8 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
"ntokens": ntokens,
"net_input": {"src_tokens": src_frames, "src_lengths": src_lengths},
"target": target,
"target_raw_text": target_raw_text,
"token_text": token_text,
"text": text,
}
if prev_output_tokens is not None:
batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
@@ -253,14 +258,16 @@ def get_batch_shapes(self):

def __getitem__(self, index):
tgt_item = self.tgt[index][0] if self.tgt is not None else None
raw_text_item = self.tgt[index][1] if self.tgt is not None else None
token_text_item = self.tgt[index][1] if self.tgt is not None else None
text_item = self.tgt[index][2] if self.tgt is not None else None
src_item = self.src[index]
example = {
"id": index,
"utt_id": self.src.utt_ids[index],
"source": src_item,
"target": tgt_item,
"target_raw_text": raw_text_item,
"token_text": token_text_item,
"text": text_item,
}
if self.constraints is not None:
example["constraints"] = self.constraints[index]
@@ -304,7 +311,8 @@ def collater(self, samples, pad_to_length=None):
- `target` (LongTensor): a padded 2D Tensor of tokens in the
target sentence of shape `(bsz, tgt_len)`. Padding will appear
on the left if *left_pad_target* is ``True``.
- `target_raw_text` (List[str]): list of original text
- `token_text` (List[str]): list of token text
- `text` (List[str]): list of original text
- `tgt_lang_id` (LongTensor): a long Tensor which contains target language
IDs of each sample in the batch
"""
8 changes: 4 additions & 4 deletions espresso/data/asr_dictionary.py
@@ -8,8 +8,8 @@

import torch
from fairseq.data import Dictionary, encoders
from fairseq.dataclass import FairseqDataclass
from fairseq.file_io import PathManager
from omegaconf import DictConfig

# will automatically load modules defined from there
from espresso.data import encoders as encoders_espresso
@@ -99,12 +99,12 @@ def dummy_sentence(self, length):
t[-1] = self.eos()
return t

def build_tokenizer(self, cfg: Union[FairseqDataclass, Namespace]):
def build_tokenizer(self, cfg: Union[DictConfig, Namespace]):
self.tokenizer = encoders.build_tokenizer(cfg)

def build_bpe(self, cfg: Union[FairseqDataclass, Namespace]):
def build_bpe(self, cfg: Union[DictConfig, Namespace]):
if (
(isinstance(cfg, FairseqDataclass) and cfg._name == "characters_asr")
(isinstance(cfg, DictConfig) and cfg._name == "characters_asr")
or (isinstance(cfg, Namespace) and getattr(cfg, "bpe", None) == "characters_asr")
):
self.bpe = encoders.build_bpe(
2 changes: 1 addition & 1 deletion espresso/data/asr_xent_dataset.py
@@ -524,7 +524,7 @@ def get_batch_shapes(self):

def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
text_item = self.text[index][1] if self.text is not None else None
text_item = self.text[index][2] if self.text is not None else None
src_item = self.src[index]
example = {
"id": index,
51 changes: 29 additions & 22 deletions espresso/data/feat_text_dataset.py
@@ -233,38 +233,39 @@ def __getitem__(self, i):


class AsrTextDataset(torch.utils.data.Dataset):
"""Takes a text file as input and binarizes it in memory at instantiation.
Original lines are also kept in memory. Each line of the text file is in the
format of 'utt_id tokenized_text'."""
"""Takes a text file as input, tokenizes and tensorizes it in memory at instantiation.
Both original text and tokenized text are kept in memory."""

def __init__(self, utt_ids: List[str], token_text: List[str], dictionary=None, append_eos=True):
def __init__(self, utt_ids: List[str], texts: List[str], dictionary=None, append_eos=True):
super().__init__()
self.dtype = np.float
self.append_eos = append_eos
self.read_text(utt_ids, token_text, dictionary)
self.read_text(utt_ids, texts, dictionary)

def read_text(self, utt_ids: List[str], token_text: List[str], dictionary=None):
assert len(utt_ids) == len(token_text)
def read_text(self, utt_ids: List[str], texts: List[str], dictionary=None):
assert len(utt_ids) == len(texts)
self.utt_ids = utt_ids
self.tokens_list = token_text
self.tensor_list = []
self.texts = texts
self.size = len(self.utt_ids) # number of utterances
self.sizes = []
self.token_texts = None
self.tensor_list = None
if dictionary is not None:
for tokens in self.tokens_list:
tensor = dictionary.encode_line(
tokens, add_if_not_exist=False, append_eos=self.append_eos,
).long()
self.tensor_list.append(tensor)
self.sizes.append(len(self.tensor_list[-1]))
self.token_texts = [dictionary.wordpiece_encode(x) for x in texts]
self.tensor_list = [
dictionary.encode_line(tokens, add_if_not_exist=False, append_eos=self.append_eos).long()
for tokens in self.token_texts
]
self.sizes = [len(tensor) for tensor in self.tensor_list]
else:
self.sizes = [len(tokenize_line(tokens)) for tokens in self.tokens_list]
self.sizes = [len(tokenize_line(text)) for text in texts]

self.sizes = np.array(self.sizes, dtype=np.int32)

assert (
len(self.utt_ids) == len(self.tokens_list)
and (dictionary is None or len(self.utt_ids) == len(self.tensor_list))
(
dictionary is None
or (len(self.utt_ids) == len(self.tensor_list) and len(self.utt_ids) == len(self.token_texts))
)
and len(self.utt_ids) == len(self.sizes)
)

@@ -280,15 +281,21 @@ def filter_and_reorder(self, indices):
len(np.unique(indices)) == len(indices)
), "Duplicate elements in indices."
self.utt_ids = [self.utt_ids[i] for i in indices]
self.tokens_list = [self.tokens_list[i] for i in indices]
if len(self.tensor_list) > 0:
self.texts = [self.texts[i] for i in indices]
if self.token_texts is not None:
self.token_texts = [self.token_texts[i] for i in indices]
if self.tensor_list is not None:
self.tensor_list = [self.tensor_list[i] for i in indices]
self.sizes = self.sizes[indices]
self.size = len(self.utt_ids)

def __getitem__(self, i):
self.check_index(i)
return self.tensor_list[i] if len(self.tensor_list) > 0 else None, self.tokens_list[i]
return (
self.tensor_list[i] if self.tensor_list is not None else None,
self.token_texts[i] if self.token_texts is not None else None,
self.texts[i]
)

def __len__(self):
return self.size
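With this change AsrTextDataset.__getitem__ returns a (tensor, token_text, text) triple rather than a (tensor, raw_text) pair, which is why the speech datasets above now index [0], [1] and [2]. A rough usage sketch, assuming text_dataset is an AsrTextDataset built with a dictionary:

tensor, token_text, text = text_dataset[0]
# tensor     -> LongTensor of token ids used as the training target (None if no dictionary)
# token_text -> wordpiece-encoded string, e.g. "T H E <space> H O T E L" (None if no dictionary)
# text       -> the original transcript, e.g. "THE HOTEL"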
4 changes: 3 additions & 1 deletion espresso/speech_recognize.py
@@ -83,6 +83,8 @@ def _main(cfg, output_file):
use_cuda = torch.cuda.is_available() and not cfg.common.cpu

task = tasks.setup_task(cfg.task)
task.build_tokenizer(cfg.tokenizer)
task.build_bpe(cfg.bpe)

# Set dictionary
dictionary = task.target_dictionary
@@ -253,7 +255,7 @@ def decode_fn(x):

# Retrieve the original sentences
if has_target:
target_str = sample["target_raw_text"][i]
target_str = sample["token_text"][i]
if not cfg.common_eval.quiet:
detok_target_str = decode_fn(target_str)
print("T-{}\t{}".format(utt_id, detok_target_str), file=output_file)
18 changes: 9 additions & 9 deletions espresso/tasks/speech_recognition.py
@@ -109,7 +109,7 @@ def get_asr_dataset_from_json(
{
"011c0202": {
"feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819",
"token_text": "T H E <space> H O T E L",
"text": "THE HOTEL",
"utt2num_frames": "693",
},
"011c0203": {
@@ -133,12 +133,12 @@ def get_asr_dataset_from_json(
with open(data_json_path, "rb") as f:
loaded_json = json.load(f, object_pairs_hook=OrderedDict)

utt_ids, feats, token_text, utt2num_frames = [], [], [], []
utt_ids, feats, texts, utt2num_frames = [], [], [], []
for utt_id, val in loaded_json.items():
utt_ids.append(utt_id)
feats.append(val["feat"])
if "token_text" in val:
token_text.append(val["token_text"])
if "text" in val:
texts.append(val["text"])
if "utt2num_frames" in val:
utt2num_frames.append(int(val["utt2num_frames"]))

@@ -148,10 +148,10 @@ def get_asr_dataset_from_json(
specaugment_config=specaugment_config if split == "train" else None,
ordered_prefetch=True,
))
if len(token_text) > 0:
assert len(utt_ids) == len(token_text)
if len(texts) > 0:
assert len(utt_ids) == len(texts)
assert tgt_dict is not None
tgt_datasets.append(AsrTextDataset(utt_ids, token_text, tgt_dict))
tgt_datasets.append(AsrTextDataset(utt_ids, texts, tgt_dict))

logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1])))

@@ -196,7 +196,7 @@ def get_asr_dataset_from_json(
@register_task("speech_recognition_espresso", dataclass=SpeechRecognitionEspressoConfig)
class SpeechRecognitionEspressoTask(FairseqTask):
"""
Transcribe from speech (source) to token text (target).
Transcribe from speech (source) to text (target).
Args:
tgt_dict (~fairseq.data.AsrDictionary): dictionary for the output tokens
@@ -406,7 +406,7 @@ def _inference_with_wer(self, decoder, sample, model):
scorer.reset()
for i in range(target.size(0)):
utt_id = sample["utt_id"][i]
ref_tokens = sample["target_raw_text"][i]
ref_tokens = sample["token_text"][i]
pred_tokens = self.target_dictionary.string(pred.data[i])
scorer.add_evaluation(utt_id, ref_tokens, pred_tokens)
return (
12 changes: 6 additions & 6 deletions examples/asr_librispeech/run.sh
@@ -204,18 +204,18 @@ fi
if [ ${stage} -le 7 ]; then
echo "Stage 7: Dump Json Files"
train_feat=$train_feat_dir/feats.scp
train_token_text=data/$train_set/token_text
train_text=data/$train_set/text
train_utt2num_frames=data/$train_set/utt2num_frames
valid_feat=$valid_feat_dir/feats.scp
valid_token_text=data/$valid_set/token_text
valid_text=data/$valid_set/text
valid_utt2num_frames=data/$valid_set/utt2num_frames
asr_prep_json.py --feat-files $train_feat --token-text-files $train_token_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --token-text-files $valid_token_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
asr_prep_json.py --feat-files $train_feat --text-files $train_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --text-files $valid_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
for dataset in $test_set; do
feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp
token_text=data/$dataset/token_text
text=data/$dataset/text
utt2num_frames=data/$dataset/utt2num_frames
asr_prep_json.py --feat-files $feat --token-text-files $token_text --utt2num-frames-files $utt2num_frames --output data/$dataset.json
asr_prep_json.py --feat-files $feat --text-files $text --utt2num-frames-files $utt2num_frames --output data/$dataset.json
done
fi

10 changes: 5 additions & 5 deletions examples/asr_swbd/run.sh
@@ -242,18 +242,18 @@ fi
if [ $stage -le 6 ]; then
echo "Stage 6: Dump Json Files"
train_feat=$train_feat_dir/feats.scp
train_token_text=data/$train_set/token_text
train_text=data/$train_set/text
train_utt2num_frames=data/$train_set/utt2num_frames
valid_feat=$valid_feat_dir/feats.scp
valid_token_text=data/$valid_set/token_text
valid_text=data/$valid_set/text
valid_utt2num_frames=data/$valid_set/utt2num_frames
asr_prep_json.py --feat-files $train_feat --token-text-files $train_token_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --token-text-files $valid_token_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
asr_prep_json.py --feat-files $train_feat --text-files $train_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --text-files $valid_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
for dataset in $test_set; do
feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp
utt2num_frames=data/$dataset/utt2num_frames
# only score train_dev with built-in scorer
text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--token-text-files data/$dataset/token_text"
text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--text-files data/$dataset/text"
asr_prep_json.py --feat-files $feat $text_opt --utt2num-frames-files $utt2num_frames --output data/$dataset.json
done
fi
12 changes: 6 additions & 6 deletions examples/asr_wsj/run.sh
@@ -247,22 +247,22 @@ fi
if [ ${stage} -le 8 ]; then
echo "Stage 8: Dump Json Files"
train_feat=$train_feat_dir/feats.scp
train_token_text=data/$train_set/token_text
train_text=data/$train_set/text
train_utt2num_frames=data/$train_set/utt2num_frames
valid_feat=$valid_feat_dir/feats.scp
valid_token_text=data/$valid_set/token_text
valid_text=data/$valid_set/text
valid_utt2num_frames=data/$valid_set/utt2num_frames
asr_prep_json.py --feat-files $train_feat --token-text-files $train_token_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --token-text-files $valid_token_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
asr_prep_json.py --feat-files $train_feat --text-files $train_text --utt2num-frames-files $train_utt2num_frames --output data/train.json
asr_prep_json.py --feat-files $valid_feat --text-files $valid_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json
for dataset in $valid_set $test_set; do
if [ "$dataset" == "$valid_set" ]; then
feat=$valid_feat_dir/feats.scp
elif [ "$dataset" == "$test_set" ]; then
feat=$test_feat_dir/feats.scp
fi
token_text=data/$dataset/token_text
text=data/$dataset/text
utt2num_frames=data/$dataset/utt2num_frames
asr_prep_json.py --feat-files $feat --token-text-files $token_text --utt2num-frames-files $utt2num_frames --output data/$dataset.json
asr_prep_json.py --feat-files $feat --text-files $text --utt2num-frames-files $utt2num_frames --output data/$dataset.json
done
fi
