Commit e0f9d83: document code and add readme

kenzheng99 committed Oct 11, 2022
1 parent 793b95f commit e0f9d83
Showing 4 changed files with 112 additions and 16 deletions.
36 changes: 36 additions & 0 deletions egs2/iam/ocr1/README.md
@@ -0,0 +1,36 @@
This is a recipe for the IAM handwriting recognition dataset, and an experiment in applying end-to-end
ASR models to an OCR task.

To run it, first create an account at https://fki.tic.heia-fr.ch/databases/iam-handwriting-database and fill
in the username and password in `local/data.sh`. Then run `./run.sh`.


<!-- Generated by scripts/utils/show_asr_result.sh -->
# RESULTS
## Environments
- date: `Fri Oct 7 05:52:11 EDT 2022`
- python version: `3.7.13 (default, Mar 29 2022, 02:18:16) [GCC 7.5.0]`
- espnet version: `espnet 202207`
- pytorch version: `pytorch 1.10.0`
- Git hash: `5a6319300231b8193f1b6e8465d572be63150119`
- Commit date: `Sat Sep 24 12:14:08 2022 -0400`

## asr_conformer_full_vocab
### WER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|inference_asr_model_valid.acc.ave/test|2915|25932|80.3|17.4|2.3|0.9|20.6|73.4|

### CER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|inference_asr_model_valid.acc.ave/test|2915|125616|93.9|4.4|1.8|0.7|6.8|73.4|

### TER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|inference_asr_model_valid.acc.ave/test|2915|128531|94.0|4.3|1.7|0.7|6.7|73.4|

8 changes: 6 additions & 2 deletions egs2/iam/ocr1/local/data.sh
@@ -18,14 +18,18 @@ SECONDS=0

stage=1
stop_stage=2
-# Set username/password from account on https://fki.tic.heia-fr.ch/register

# Fill in username/password from account on https://fki.tic.heia-fr.ch/register
iam_username=""
iam_password=""

-data_dir=data/
# Set parameters for the feature dimensions used during image extraction,
# see data_prep.py for details
feature_dim=100
downsampling_factor=0.5

data_dir=data/

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
log "Stage 1.1: Downloading the IAM Handwriting dataset with username ${iam_username} and password ${iam_password}"
mkdir -p ${IAM}
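For a sense of what these two knobs control: `feature_dim` is the height every line image is resized to (one feature vector per image column), and `downsampling_factor` scales the width, which plays the role of the time axis in ASR. A rough shape calculation (illustrative only; the ~1500-pixel average width comes from the help text in `data_prep.py`):

```python
# Rough shape arithmetic for the default settings (illustrative sketch)
feature_dim = 100          # every image is resized to this height
downsampling_factor = 0.5  # multiplier on the width ("time") axis

orig_width = 1500  # roughly the average IAM line-image width
time_steps = int(orig_width * downsampling_factor)
print((time_steps, feature_dim))  # resulting feature-matrix shape: (750, 100)
```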
79 changes: 67 additions & 12 deletions egs2/iam/ocr1/local/data_prep.py
@@ -1,4 +1,20 @@
-import string
"""Prepare the IAM handwriting dataset for ESPnet ASR training
Usage:
python local/data_prep.py [--feature_dim 100] [--downsampling_factor 0.5]
Expects data to be in:
downloads/
lines.txt # labels from IAM Handwriting dataset (from ascii.tgz)
lines/
*/*/*.png # "lines" images from IAM Handwriting dataset (from lines.tgz)
train.txt # ids for train/valid/test splits
valid.txt
test.txt
Required packages:
Pillow
"""
import os
import argparse
import numpy as np
@@ -7,7 +7,22 @@
from espnet.utils.cli_writers import file_writer_helper

def prepare_text(lines_file_path, output_dir, split_ids):
"""Create text file (map of ids to transcriptions) in Kaldi format"""
"""Create 'text' file (map of ids to transcriptions) in Kaldi format
Parameters
----------
lines_file_path : str
The file path of the full "lines.txt" label file of the IAM dataset
output_dir : str
The folder path for output, e.g. data/
split_ids : list of str
a list of example ids to process, used to define train/valid/test splits
Returns
-------
skipped_ids : list of str
a list of ids that were skipped due to having an empty transcription
"""
output_lines = []
skipped_ids = []
with open(lines_file_path) as lines_file:
@@ -27,9 +58,6 @@ def prepare_text(lines_file_path, output_dir, split_ids):
# extract and format transcription into Kaldi style
transcription = " ".join(line_split[8:])
transcription = transcription.replace("|", " ")
-# transcription = transcription.translate(str.maketrans('', '', string.punctuation))
-# transcription = transcription.replace(" ", " ")
-# transcription = transcription.strip().upper()
transcription = transcription.strip()

if transcription == "":
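For context, each non-comment row of the IAM `lines.txt` file carries eight metadata fields (id, segmentation status, graylevel, number of components, and an x/y/w/h bounding box) before the transcription, which uses `|` as the word separator. A minimal sketch of what the extraction above does to one such row (the sample row follows the published IAM format and is illustrative, not taken from this commit):

```python
# Illustrative sketch: one lines.txt row -> one Kaldi 'text' entry
line = "a01-000u-00 ok 154 19 408 746 1663 91 A|MOVE|to|stop|Mr.|Gaitskell"
line_split = line.split(" ")
line_id = line_split[0]                          # utterance id
transcription = " ".join(line_split[8:])         # fields 0-7 are metadata
transcription = transcription.replace("|", " ")  # '|' marks word boundaries
transcription = transcription.strip()
print(f"{line_id} {transcription}")
# -> a01-000u-00 A MOVE to stop Mr. Gaitskell
```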
@@ -47,7 +75,18 @@ def prepare_text(lines_file_path, output_dir, split_ids):
return skipped_ids

def prepare_utt2spk_spk2utt(output_dir, split_ids, ids_to_skip=[]):
"""Create (dummy) utt2spk and spk2utt files to satisfy Kaldi format"""
"""Create (dummy) utt2spk and spk2utt files to satisfy Kaldi format
Parameters
----------
output_dir : str
The folder path for output, e.g. data/
split_ids : list of str
a list of example ids to process, used to define train/valid/test splits
ids_to_skip : list of str
A list of ids to exclude from the output, e.g. used to ignore the ones
that were skipped due to an empty transcription
"""
output_lines = [f'{line_id} {line_id}\n' for line_id in split_ids if line_id not in ids_to_skip]
output_lines.sort()
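Since the recipe has no real speaker labels, utt2spk and spk2utt are identity maps: each utterance id doubles as its own speaker id, which satisfies Kaldi's directory format without affecting training. A tiny illustration (the ids are hypothetical):

```python
# Illustrative sketch: utt2spk and spk2utt both end up with "ID ID" lines
split_ids = ["a01-000u-01", "a01-000u-00"]
output_lines = sorted(f"{line_id} {line_id}\n" for line_id in split_ids)
print("".join(output_lines), end="")
# a01-000u-00 a01-000u-00
# a01-000u-01 a01-000u-01
```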

@@ -57,9 +96,25 @@ def prepare_utt2spk_spk2utt(output_dir, split_ids, ids_to_skip=[]):
with open(os.path.join(output_dir, 'spk2utt'), 'w') as out_file:
out_file.writelines(output_lines)

-def prepare_feats(img_dir, output_dir, split_ids, ids_to_skip=[], feature_dim=100, downsampling_factor=0.25):
-"""Create feats.scp file from OCR images"""

def prepare_feats(img_dir, output_dir, split_ids, ids_to_skip=[], feature_dim=100, downsampling_factor=0.5):
"""Create feats.scp file from OCR images

Parameters
----------
img_dir : str
The folder path of the input "lines" images, e.g. downloads/lines/
output_dir : str
The folder path for output, e.g. data/
split_ids : list of str
A list of example ids to process, used to define train/valid/test splits
ids_to_skip : list of str
A list of ids to exclude from the output, e.g. used to ignore the ones
that were skipped due to an empty transcription
feature_dim : int (default=100)
The feature dimension of each feature matrix; all images are resized to this height
downsampling_factor : float (default=0.5)
A multiplier for the width dimension of each image (analogous to the 'time'
dimension in ASR), used to reduce the length for faster training. Empirically,
a value of 0.5 was found to achieve the best performance
"""
writer = file_writer_helper(
wspecifier=f'ark,scp:{os.path.join(output_dir, "feats.ark")},{os.path.join(output_dir, "feats.scp")}',
filetype='mat',
@@ -89,7 +144,7 @@ def prepare_feats(img_dir, output_dir, split_ids, ids_to_skip=[], feature_dim=100, downsampling_factor=0.5):
# update counters for logging
num_processed += 1
total_length += img_arr.shape[0]

print(f'Extracted features for {num_processed} examples to {os.path.join(output_dir, "feats.scp")}, average length is {total_length / num_processed:.02f}')

if __name__ == '__main__':
@@ -98,9 +153,9 @@ def prepare_feats(img_dir, output_dir, split_ids, ids_to_skip=[], feature_dim=100, downsampling_factor=0.5):
data_dir = "data/"

parser = argparse.ArgumentParser()
parser.add_argument('--feature_dim', type=int, default=100,
help='Feature dimension to resize each image feature to')
-parser.add_argument('--downsampling_factor', type=float, default=0.25,
parser.add_argument('--downsampling_factor', type=float, default=0.5,
help='Factor to downsample the length of each image feature by; the average length will be about 1500 * downsampling_factor')

args = parser.parse_args()
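Putting the pieces of `prepare_feats` together: each line image is loaded, resized to a height of `feature_dim` with its width scaled by `downsampling_factor`, transposed so width becomes the leading ("time") axis, and written through `file_writer_helper` to a Kaldi-style `feats.ark`/`feats.scp` pair. A condensed sketch of that flow (the exact resize and grayscale-conversion calls are assumptions, and the image path is illustrative; see `data_prep.py` for the real logic):

```python
import os

import numpy as np
from PIL import Image
from espnet.utils.cli_writers import file_writer_helper

output_dir = "data/train"  # illustrative output directory
feature_dim, downsampling_factor = 100, 0.5

os.makedirs(output_dir, exist_ok=True)
writer = file_writer_helper(
    wspecifier=f'ark,scp:{os.path.join(output_dir, "feats.ark")},'
               f'{os.path.join(output_dir, "feats.scp")}',
    filetype='mat',
)

# Load one line image, resize height to feature_dim, scale width by the factor
img = Image.open("downloads/lines/a01/a01-000u/a01-000u-00.png").convert("L")
width, height = img.size
img = img.resize((int(width * downsampling_factor), feature_dim))

# Transpose so axis 0 is width ("time") and axis 1 is the feature dimension
img_arr = np.array(img, dtype=np.float32).T
writer["a01-000u-00"] = img_arr  # one (time, feature_dim) matrix per utterance
writer.close()
```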
5 changes: 3 additions & 2 deletions egs2/iam/ocr1/local/download_and_untar.sh
@@ -1,8 +1,9 @@
#!/bin/bash
# A script to download and extract the "lines" split of the IAM Handwriting Dataset

if [ $# -lt 3 ]; then
echo "Usage: $0 <data-base-path> <lrs3-username> <lrs3-password>"
echo "--args <data-base-path> : The path where to download the dataset"
echo "--args <data-base-path> : The path to download the dataset to"
echo "--args <iam-username> : The username required to download the dataset"
echo "--args <iam-password> : The password required to download the dataset"
echo "If you do not have a username/password, please request from: https://fki.tic.heia-fr.ch/register"
@@ -75,7 +76,7 @@ echo "Downloading train/dev/test splits"
# Using the IAM-B splits from https://github.com/shonenkov/IAM-Splitting for this recipe
for split in train valid test; do
if [ -e ${download_dir}/${split}.txt ]; then
rm ${download_dir}/${split}.txt
fi
if ! wget -nv ${iam_splits_base_url}${split}.txt -P ${download_dir}; then
echo "Error downloading ${iam_splits_base_url}${split}.txt"
