Recasing and punctuation model based on BERT

Benoit Favre 2021

This system converts a sequence of lowercase tokens without punctuation to a sequence of cased tokens with punctuation.

It is trained to predict both aspects at the token level in a multitask fashion, on top of fine-tuned BERT representations.
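
A rough sketch of this multitask setup is shown below: two token-level classification heads (four recasing labels and five punctuation labels, listed below) share a fine-tuned BERT encoder. Module and variable names are illustrative assumptions, not the actual implementation in recasepunc.py.

import torch.nn as nn
from transformers import AutoModel

class CasePuncModel(nn.Module):
    # Illustrative multitask model: one shared encoder, two per-token classifiers
    def __init__(self, flavor='bert-base-uncased', num_case_labels=4, num_punc_labels=5):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(flavor)      # fine-tuned BERT
        hidden = self.encoder.config.hidden_size
        self.case_head = nn.Linear(hidden, num_case_labels)   # recasing labels
        self.punc_head = nn.Linear(hidden, num_punc_labels)   # punctuation labels

    def forward(self, input_ids, attention_mask=None):
        states = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        # One prediction of each kind per sub-token
        return self.case_head(states), self.punc_head(states)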

The model predicts the following recasing labels:

  • lower: keep lowercase
  • upper: convert to upper case
  • capitalize: set first letter as upper case
  • other: leave as is; such tokens could be post-processed with a lookup list

And the following punctuation labels:

  • o: no punctuation
  • period: .
  • comma: ,
  • question: ?
  • exclamation: !
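
Applying a pair of predicted labels to a token is then a simple lookup; a minimal sketch (the function name is hypothetical):

def apply_labels(token, case_label, punc_label):
    # Recasing ('other' is left as is; it could be handled with a lookup list)
    if case_label == 'lower':
        token = token.lower()
    elif case_label == 'upper':
        token = token.upper()
    elif case_label == 'capitalize':
        token = token.capitalize()
    # Repunctuation: 'o' means no mark is appended
    marks = {'period': '.', 'comma': ',', 'question': '?', 'exclamation': '!'}
    return token + marks.get(punc_label, '')

# apply_labels('paris', 'capitalize', 'comma') -> 'Paris,'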

Input tokens are batched as sequences of length 256 that are processed independently without overlap.

During training, sequences containing fewer than 256 tokens are simulated by drawing a length uniformly at random and replacing all tokens and labels after that point with padding (a technique called Cut-drop).
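
A minimal sketch of both mechanisms, assuming sequences already padded to 256 ids and a padding id of 0 (both are assumptions; the actual values appear in the checkpoint configs below):

import random

def chunk(ids, max_length=256):
    # Non-overlapping windows, each processed independently
    return [ids[i:i + max_length] for i in range(0, len(ids), max_length)]

def cut_drop(token_ids, label_ids, max_length=256, pad_id=0):
    # Simulate a short input: draw a cut point uniformly, pad everything after it
    cut = random.randint(1, max_length)
    token_ids = token_ids[:cut] + [pad_id] * (max_length - cut)
    label_ids = label_ids[:cut] + [pad_id] * (max_length - cut)
    return token_ids, label_ids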

Changelog:

  • Add support for Zh and En models
  • Fix generation when input is smaller than max length

Installation

Use your favourite method for installing Python requirements. For example:

python -m venv env
. env/bin/activate
pip3 install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html

Prediction

Predict from raw text:

python recasepunc.py predict checkpoint/path < input.txt > output.txt
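
For example, with the French checkpoint listed below (illustrative input; the exact output depends on the model):

echo "bonjour comment allez vous" | python recasepunc.py predict checkpoints/fr.22000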

Models

All models are trained on the first 100M tokens of Common Crawl. In the checkpoint metadata below, valid_fscore appears to report one F-score per punctuation label (keys 0 to 4, in the order listed above).

checkpoints/zh.24000

{
  "iteration": "24000",
  "train_loss": "0.006788245493080467",
  "valid_loss": "0.007345725328494341",
  "valid_accuracy_case": "0.9963942307692307",
  "valid_accuracy_punc": "0.9692508012820513",
  "valid_fscore": "{0: 0.7727023363113403, 1: 0.7901785373687744, 2: 0.7293065190315247, 3: 0.7692307829856873, 4: 0.4615384638309479}",
  "config": "{'seed': 871253, 'lang': 'zh', 'flavor': 'ckiplab/bert-base-chinese', 'max_length': 256, 'batch_size': 16, 'updates': 24000, 'period': 1000, 'lr': 1e-05, 'dab_rate': 0.1, 'device': device(type='cuda'), 'debug': False, 'action': 'train', 'action_args': ['data/zh-100M.train.x', 'data/zh-100M.train.y', 'data/zh-100M.valid.x', 'data/zh-100M.valid.y', 'checkpoints/zh'], 'pad_token_id': 0, 'cls_token_id': 101, 'cls_token': '[CLS]', 'sep_token_id': 102, 'sep_token': '[SEP]'}"
}

checkpoints/en.23000

{
  "iteration": "23000",
  "train_loss": "0.014598741472698748",
  "valid_loss": "0.025432642453756087",
  "valid_accuracy_case": "0.9407051282051282",
  "valid_accuracy_punc": "0.9401041666666666",
  "valid_fscore": "{0: 0.6455026268959045, 1: 0.5925925970077515, 2: 0.7243649959564209, 3: 0.7027027010917664, 4: 0.03921568766236305}",                                                    
  "config": "{'seed': 871253, 'lang': 'en', 'flavor': 'bert-base-uncased', 'max_length': 256, 'batch_size': 16, 'updates': 24000, 'period': 1000, 'lr': 1e-05, 'dab_rate': 0.1, 'device': device(type='cuda'), 'debug': False, 'action': 'train', 'action_args': ['data/en-100M.train.x', 'data/en-100M.train.y', 'data/en-100M.valid.x', 'data/en-100M.valid.y', 'checkpoints/en'], 'pad_token_id': 0, 'cls_token_id': 101, 'cls_token': '[CLS]', 'sep_token_id': 102, 'sep_token': '[SEP]'}"                                                                                           
}

checkpoints/fr.22000

{
  "iteration": "22000",
  "train_loss": "0.02052250287961215",
  "valid_loss": "0.009240646392871171",
  "valid_accuracy_case": "0.9881810897435898",
  "valid_accuracy_punc": "0.9683493589743589",
  "valid_fscore": "{0: 0.802524745464325, 1: 0.7892595529556274, 2: 0.8360477685928345, 3: 0.8717948198318481, 4: 0.2068965584039688}",
  "config": "{'seed': 871253, 'lang': 'fr', 'flavor': 'flaubert/flaubert_base_uncased', 'max_length': 256, 'batch_size': 16, 'updates': 24000, 'period': 1000, 'lr': 1e-05, 'dab_rate': 0.1, 'device': device(type='cuda'), 'debug': False, 'action': 'train', 'action_args': ['data/fr-100M.train.x', 'data/fr-100M.train.y', 'data/fr-100M.valid.x', 'data/fr-100M.valid.y', 'checkpoints/fr'], 'pad_token_id': 2, 'cls_token_id': 0, 'cls_token': '<s>', 'sep_token_id': 1, 'sep_token': '</s>'}"
}

Training

Note: adjust file names to match your setup. Training tensors are precomputed and loaded into CPU memory; models and batches are moved to CUDA memory.

Stage 0: download text data

Stage 1: tokenize and normalize the text with the Moses tokenizer, and extract recasing and repunctuation labels

python recasepunc.py preprocess --lang $LANG < input.txt > input.case+punc

Stage 2: sub-tokenize with the language-specific BERT tokenizer (Flaubert in the French setup), and generate PyTorch tensors

python recasepunc.py tensorize input.case+punc input.case+punc.x input.case+punc.y --lang $LANG

Stage 3: train model

python recasepunc.py train train.x train.y valid.x valid.y checkpoint/path --lang $LANG

Stage 4: evaluate performance on a test set

python recasepunc.py eval checkpoint/path.iteration test.x test.y
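
End to end, a training run for one language might look like the following (file names are placeholders, and the train/valid/test split of the raw text is assumed to be done beforehand):

LANG=fr
for split in train valid test; do
    python recasepunc.py preprocess --lang $LANG < $split.txt > $split.case+punc
    python recasepunc.py tensorize $split.case+punc $split.x $split.y --lang $LANG
done
python recasepunc.py train train.x train.y valid.x valid.y checkpoints/$LANG --lang $LANG
python recasepunc.py eval checkpoints/$LANG.24000 test.x test.y   # pick a saved iteration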

Notes

This work was not published, but a similar model is described in "FullStop: Multilingual Deep Models for Punctuation Prediction", Guhr et al., 2021.
