Demo Notebook
burchim committed Sep 17, 2021
1 parent 19482bf commit 2f59ed2
Showing 4 changed files with 5 additions and 3 deletions.
3 changes: 1 addition & 2 deletions .gitignore
```diff
@@ -1,8 +1,7 @@
 __pycache__
 *.pyc
-.DS_Store
 .vscode
 .DS_Store
 env/*
 datasets/*
-callbacks/*
+callbacks/*
```
1 change: 1 addition & 0 deletions EfficientConformer.ipynb
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"EfficientConformer.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"yA9TPERPtBUL"},"source":["#Efficient Conformer Demo\n","A quick intro to using pretrained models and how to train/evaluate models.<br>\n","repo: [https://github.com/burchim/EfficientConformer](https://github.com/burchim/EfficientConformer)"]},{"cell_type":"markdown","metadata":{"id":"-I6v5ThmlRVp"},"source":["# Install"]},{"cell_type":"code","metadata":{"id":"ugmpSZEa3g13"},"source":["!git clone https://github.com/burchim/EfficientConformer.git "],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"iIvKGdgElYih"},"source":["import os\n","os.chdir('EfficientConformer/')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ClzKWI_TFKZK"},"source":["!pip install -r requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"xr-rBVhM7mBP"},"source":["!git clone --recursive https://github.com/parlance/ctcdecode.git\n","!cd ctcdecode && pip install ."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"61bc416smy2E"},"source":["# Download pretrained models and tokenizer"]},{"cell_type":"code","metadata":{"id":"l3KrXieZqSDm"},"source":["!pip install gdown"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"9Wg0BtnPTIpW"},"source":["pretrained_models = {\n"," \"EfficientConformerCTCSmall\": \"1MU49nbRONkOOGzvXHFDNfvWsyFmrrBam\",\n"," \"EfficientConformerCTCMedium\": \"1h5hRG9T_nErslm5eGgVzqx7dWDcOcGDB\",\n"," \"EfficientConformerCTCLarge\": \"1U4iBTKQogX4btE-S4rqCeeFZpj3gcweA\"\n","}"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"okSOe0wTT9zp"},"source":["# Select one of the official pretrained models\n","pretrained_model = \"EfficientConformerCTCSmall\""],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ihk81osmsHAz"},"source":["import gdown\n","\n","# Create model callback directory\n","if not os.path.exists(os.path.join(\"callbacks\", pretrained_model)):\n"," os.mkdir(os.path.join(\"callbacks\", pretrained_model))\n","\n","# Download pretrained model checkpoint\n","gdown.download(\"https://drive.google.com/uc?id=\" + pretrained_models[pretrained_model], os.path.join(\"callbacks\", pretrained_model, \"checkpoints_swa-equal-401-450.ckpt\"), quiet=False)\n","\n","# Create tokenizer directory\n","if not os.path.exists(\"datasets/LibriSpeech\"):\n"," os.mkdir(\"datasets/LibriSpeech\")\n","\n","# Download pretrained model tokenizer\n","gdown.download(\"https://drive.google.com/uc?id=1hx2s4ZTDsnOFtx5_h5R_qZ3R6gEFafRx\", \"datasets/LibriSpeech/LibriSpeech_bpe_256.model\", quiet=False)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PdXrPEoaslUq"},"source":["# Test model on LibriSpeech samples"]},{"cell_type":"code","metadata":{"id":"G9TRAOYhGKHH"},"source":["# Download LibriSPeech dev-clean subset\n","!cd datasets && wget https://www.openslr.org/resources/12/dev-clean.tar.gz && tar xzf dev-clean.tar.gz\n","\n","# Download LibriSPeech dev-other subset\n","!cd datasets && wget https://www.openslr.org/resources/12/dev-other.tar.gz && tar xzf dev-other.tar.gz"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"NaaXsV62ux3X"},"source":["import json\n","import glob\n","import 
torch\n","import torchaudio\n","import IPython.display as ipd\n","from functions import create_model\n","import matplotlib.pyplot as plt\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"V0dY7IBquiRC"},"source":["config_file = \"configs/\" + pretrained_model + \".json\"\n","\n","# Load model Config\n","with open(config_file) as json_config:\n"," config = json.load(json_config)\n","\n","# PyTorch Device\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","print(\"Device:\", device)\n","\n","# Create and Load pretrained model\n","model = create_model(config).to(device)\n","model.summary()\n","model.eval()\n","model.load(os.path.join(\"callbacks\", pretrained_model, \"checkpoints_swa-equal-401-450.ckpt\"))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"_EidY8gF2Z0_"},"source":["# Get audio files paths\n","audio_files = glob.glob(\"datasets/LibriSpeech/*/*/*/*.flac\")\n","print(len(audio_files), \"audio files\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"9qMnfHxZzqvX"},"source":["# Random indices\n","indices = torch.randint(0, len(audio_files), size=(10,))\n","\n","# Test model\n","for i in indices:\n","\n"," # Load audio file\n"," audio, sr = torchaudio.load(audio_files[i])\n","\n"," # Plot audio\n"," plt.title(audio_files[i].split(\"/\")[-1])\n"," plt.plot(audio[0])\n"," plt.show()\n"," print()\n","\n"," # Display\n"," ipd.display(ipd.Audio(audio, rate=sr))\n"," print()\n","\n"," # Predict sentence\n"," prediction = model.gready_search_decoding(audio.to(device), x_len=torch.tensor([len(audio[0])], device=device))[0]\n"," print(\"model prediction:\", prediction, '\\n')\n","\n"," for i in range(100):\n"," print('*', end='')\n"," print('\\n')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"zCc9S6BgQ1M5"},"source":["# Training\n","Download the LibriSpeech dataset using:\n","\n","- `cd datasets && bash ./download_LibriSpeech.sh`\n","\n","Or download LibriSpeech train-clean 100h subset with:\n","\n","- `cd datasets && wget https://www.openslr.org/resources/12/train-clean-100.tar.gz && tar xzf datasets/train-clean-100.tar.gz`"]},{"cell_type":"code","metadata":{"id":"y7A4c0x0RnJq"},"source":["# Download LibriSPeech train-clean-100 subset\n","!cd datasets && wget https://www.openslr.org/resources/12/train-clean-100.tar.gz && tar xzf train-clean-100.tar.gz"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HdYQ6GDeca6M"},"source":["Train an Efficient Conformer CTC Small model.<br>\n","The `--prepare_dataset` flag will tokenize text sequences and save samples length before training/evaluation.<br>\n","Use the `--create_tokenizer` flag if you need to create a new sentencepiece tokenizer.<br>\n","Training mode is selected by default."]},{"cell_type":"code","metadata":{"id":"0HTj66OxQ4in"},"source":["# Prepare dataset and train model\n","!python main.py --config_file configs/EfficientConformerCTCSmall.json --prepare_dataset"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Ea0W2NwVbOCe"},"source":["# Evaluation\n","Proceed to a gready search evaluation.\n","Use the `--mode` flag to select an evaluation mode:\n","\n","- `validation-clean` for evaluation on the LibriSpeech dev-clean validation set.\n","- `validation-other` for evaluation on the LibriSpeech dev-other validation set.\n","- `test-clean` for evaluation on the LibriSpeech test-clean test set.\n","- `test-other` for evaluation on 
# Evaluation

Run a greedy search evaluation.
Use the `--mode` flag to select an evaluation mode:

- `validation-clean` for evaluation on the LibriSpeech dev-clean validation set.
- `validation-other` for evaluation on the LibriSpeech dev-other validation set.
- `test-clean` for evaluation on the LibriSpeech test-clean test set.
- `test-other` for evaluation on the LibriSpeech test-other test set.
- `eval_time` to evaluate model inference time on the LibriSpeech dev-clean validation set.

Select a model checkpoint to load for evaluation using the `--initial_epoch` flag.
For example, `--initial_epoch swa-equal-401-450` will load the pretrained checkpoints_swa-equal-401-450.ckpt file.

```python
!python main.py --config_file configs/EfficientConformerCTCSmall.json --mode validation-clean --initial_epoch swa-equal-401-450 --gready
```
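
ASR evaluation on LibriSpeech is scored as word error rate (WER). For intuition, greedy (best-path) CTC decoding takes the argmax token at every frame, collapses repeats, and drops blanks; WER then compares the decoded text to the reference. A minimal, self-contained sketch (assuming blank id 0, a toy sentencepiece-style vocabulary, and the `jiwer` package; none of this is the repo's own code):

```python
import torch
import jiwer

# Sketch of greedy (best-path) CTC decoding: argmax per frame,
# collapse repeated ids, drop blanks. "▁" marks a word boundary,
# as in sentencepiece vocabularies.
def ctc_greedy_decode(logits, id_to_token, blank_id=0):
    ids = logits.argmax(dim=-1).tolist()     # best token per frame
    pieces, prev = [], blank_id
    for i in ids:
        if i != blank_id and i != prev:      # collapse repeats, skip blanks
            pieces.append(id_to_token[i])
        prev = i
    return "".join(pieces).replace("▁", " ").strip()

vocab = {0: "<blank>", 1: "▁he", 2: "llo", 3: "▁world"}
# Frames decode as: he, he (repeat), blank, llo, world, world (repeat)
frames = torch.tensor([1, 1, 0, 2, 3, 3])
logits = torch.nn.functional.one_hot(frames, num_classes=4).float()
hypothesis = ctc_greedy_decode(logits, vocab)
print(hypothesis)                                    # hello world
print("WER:", jiwer.wer("hello world", hypothesis))  # 0.0
```
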
2 changes: 2 additions & 0 deletions README.md
@@ -2,6 +2,8 @@

Official implementation of the Efficient Conformer, a progressively downsampled Conformer with grouped attention for Automatic Speech Recognition.

**Efficient Conformer [Paper](https://arxiv.org/abs/2109.01163) | [Demo Notebook](https://colab.research.google.com/github/burchim/EfficientConformer/blob/master/EfficientConformer.ipynb)**

## Efficient Conformer Encoder
Inspired by previous work in Automatic Speech Recognition and Computer Vision, the Efficient Conformer encoder is composed of three encoder stages, where each stage comprises a number of Conformer blocks using grouped attention. The encoded sequence is progressively downsampled and projected to wider feature dimensions, lowering the amount of computation while achieving better performance. Grouped multi-head attention reduces attention complexity by grouping neighbouring time elements along the feature dimension before applying scaled dot-product attention, as sketched below.
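
A rough, self-contained sketch of the grouping trick (an illustration of the idea, not the repo's implementation; query/key/value projections and attention heads are omitted for brevity):

```python
import torch
import torch.nn.functional as F

def grouped_attention(x, group_size):
    """Sketch of grouped attention: fold `group_size` neighbouring frames into
    the feature dimension, run scaled dot-product attention on the shorter
    sequence, then unfold back to the original shape."""
    B, T, D = x.shape
    g = group_size
    assert T % g == 0, "pad the sequence so T is a multiple of group_size"
    xg = x.reshape(B, T // g, D * g)                    # (B, T/g, D*g)
    scores = xg @ xg.transpose(1, 2) / (D * g) ** 0.5   # (B, T/g, T/g)
    out = F.softmax(scores, dim=-1) @ xg                # (B, T/g, D*g)
    return out.reshape(B, T, D)                         # unfold to (B, T, D)

# Attention cost falls from O(T^2 * D) to O((T/g)^2 * D*g) = O(T^2 * D / g)
y = grouped_attention(torch.randn(2, 16, 8), group_size=4)
print(y.shape)  # torch.Size([2, 16, 8])
```
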

2 changes: 1 addition & 1 deletion models/modules.py
```diff
@@ -423,7 +423,7 @@ def __init__(self, dim_model, num_heads, Pdrop, max_pos_encoding, relative_pos_e
         # Pre Norm
         self.norm = nn.LayerNorm(dim_model, eps=1e-6)
 
-        # Efficient Multi-Head Attention
+        # Multi-Head Linear Attention
         if linear_att:
             self.mhsa = MultiHeadLinearAttention(dim_model, num_heads)
```
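
The `linear_att` branch swaps in `MultiHeadLinearAttention`. For intuition, here is a hedged sketch of generic kernelized linear attention (in the spirit of Katharopoulos et al., 2020, not this file's actual code): replacing softmax with a feature map φ(x) = elu(x) + 1 lets K and V be summarized once, making the cost linear in sequence length.

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    """Illustrative linear attention (not this repo's implementation):
    with phi(x) = elu(x) + 1, softmax attention is approximated by
    phi(Q) (phi(K)^T V), which is linear in the sequence length T."""
    q, k = F.elu(q) + 1, F.elu(k) + 1               # (B, H, T, d)
    kv = torch.einsum("bhtd,bhte->bhde", k, v)      # (B, H, d, d) K/V summary
    z = 1 / (torch.einsum("bhtd,bhd->bht", q, k.sum(dim=2)) + 1e-6)  # normalizer
    return torch.einsum("bhtd,bhde,bht->bhte", q, kv, z)

q = k = v = torch.randn(2, 4, 16, 8)  # batch, heads, time, head_dim
print(linear_attention(q, k, v).shape)  # torch.Size([2, 4, 16, 8])
```
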

