Commit
Showing 4 changed files with 5 additions and 3 deletions.
`.gitignore`:

```diff
@@ -1,8 +1,7 @@
 __pycache__
 *.pyc
 .DS_Store
 .vscode
-.DS_Store
 env/*
 datasets/*
-callbacks/*
+callbacks/*
```
`EfficientConformer.ipynb` (new file; the single-line notebook JSON is rendered below):

# Efficient Conformer Demo

A quick intro to using pretrained models and how to train/evaluate models.
repo: [https://github.com/burchim/EfficientConformer](https://github.com/burchim/EfficientConformer)

# Install

```
!git clone https://github.com/burchim/EfficientConformer.git
```

```python
import os
os.chdir('EfficientConformer/')
```

```
!pip install -r requirements.txt
```

```
!git clone --recursive https://github.com/parlance/ctcdecode.git
!cd ctcdecode && pip install .
```

# Download pretrained models and tokenizer

```
!pip install gdown
```

```python
pretrained_models = {
    "EfficientConformerCTCSmall": "1MU49nbRONkOOGzvXHFDNfvWsyFmrrBam",
    "EfficientConformerCTCMedium": "1h5hRG9T_nErslm5eGgVzqx7dWDcOcGDB",
    "EfficientConformerCTCLarge": "1U4iBTKQogX4btE-S4rqCeeFZpj3gcweA"
}
```

```python
# Select one of the official pretrained models
pretrained_model = "EfficientConformerCTCSmall"
```

```python
import gdown

# Create model callback directory
if not os.path.exists(os.path.join("callbacks", pretrained_model)):
    os.mkdir(os.path.join("callbacks", pretrained_model))

# Download pretrained model checkpoint
gdown.download("https://drive.google.com/uc?id=" + pretrained_models[pretrained_model],
               os.path.join("callbacks", pretrained_model, "checkpoints_swa-equal-401-450.ckpt"),
               quiet=False)

# Create tokenizer directory
if not os.path.exists("datasets/LibriSpeech"):
    os.mkdir("datasets/LibriSpeech")

# Download pretrained model tokenizer
gdown.download("https://drive.google.com/uc?id=1hx2s4ZTDsnOFtx5_h5R_qZ3R6gEFafRx",
               "datasets/LibriSpeech/LibriSpeech_bpe_256.model",
               quiet=False)
```
torch\n","import torchaudio\n","import IPython.display as ipd\n","from functions import create_model\n","import matplotlib.pyplot as plt\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"V0dY7IBquiRC"},"source":["config_file = \"configs/\" + pretrained_model + \".json\"\n","\n","# Load model Config\n","with open(config_file) as json_config:\n"," config = json.load(json_config)\n","\n","# PyTorch Device\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","print(\"Device:\", device)\n","\n","# Create and Load pretrained model\n","model = create_model(config).to(device)\n","model.summary()\n","model.eval()\n","model.load(os.path.join(\"callbacks\", pretrained_model, \"checkpoints_swa-equal-401-450.ckpt\"))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"_EidY8gF2Z0_"},"source":["# Get audio files paths\n","audio_files = glob.glob(\"datasets/LibriSpeech/*/*/*/*.flac\")\n","print(len(audio_files), \"audio files\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"9qMnfHxZzqvX"},"source":["# Random indices\n","indices = torch.randint(0, len(audio_files), size=(10,))\n","\n","# Test model\n","for i in indices:\n","\n"," # Load audio file\n"," audio, sr = torchaudio.load(audio_files[i])\n","\n"," # Plot audio\n"," plt.title(audio_files[i].split(\"/\")[-1])\n"," plt.plot(audio[0])\n"," plt.show()\n"," print()\n","\n"," # Display\n"," ipd.display(ipd.Audio(audio, rate=sr))\n"," print()\n","\n"," # Predict sentence\n"," prediction = model.gready_search_decoding(audio.to(device), x_len=torch.tensor([len(audio[0])], device=device))[0]\n"," print(\"model prediction:\", prediction, '\\n')\n","\n"," for i in range(100):\n"," print('*', end='')\n"," print('\\n')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"zCc9S6BgQ1M5"},"source":["# Training\n","Download the LibriSpeech dataset using:\n","\n","- `cd datasets && bash ./download_LibriSpeech.sh`\n","\n","Or download LibriSpeech train-clean 100h subset with:\n","\n","- `cd datasets && wget https://www.openslr.org/resources/12/train-clean-100.tar.gz && tar xzf datasets/train-clean-100.tar.gz`"]},{"cell_type":"code","metadata":{"id":"y7A4c0x0RnJq"},"source":["# Download LibriSPeech train-clean-100 subset\n","!cd datasets && wget https://www.openslr.org/resources/12/train-clean-100.tar.gz && tar xzf train-clean-100.tar.gz"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HdYQ6GDeca6M"},"source":["Train an Efficient Conformer CTC Small model.<br>\n","The `--prepare_dataset` flag will tokenize text sequences and save samples length before training/evaluation.<br>\n","Use the `--create_tokenizer` flag if you need to create a new sentencepiece tokenizer.<br>\n","Training mode is selected by default."]},{"cell_type":"code","metadata":{"id":"0HTj66OxQ4in"},"source":["# Prepare dataset and train model\n","!python main.py --config_file configs/EfficientConformerCTCSmall.json --prepare_dataset"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Ea0W2NwVbOCe"},"source":["# Evaluation\n","Proceed to a gready search evaluation.\n","Use the `--mode` flag to select an evaluation mode:\n","\n","- `validation-clean` for evaluation on the LibriSpeech dev-clean validation set.\n","- `validation-other` for evaluation on the LibriSpeech dev-other validation set.\n","- `test-clean` for evaluation on the LibriSpeech test-clean test set.\n","- `test-other` for evaluation on 
# Evaluation

Proceed to a greedy search evaluation.
Use the `--mode` flag to select an evaluation mode:

- `validation-clean` for evaluation on the LibriSpeech dev-clean validation set.
- `validation-other` for evaluation on the LibriSpeech dev-other validation set.
- `test-clean` for evaluation on the LibriSpeech test-clean test set.
- `test-other` for evaluation on the LibriSpeech test-other test set.
- `eval_time` to evaluate model inference time on the LibriSpeech dev-clean validation set.

Select a model checkpoint to load for evaluation using the `--initial_epoch` flag.
For example, `--initial_epoch swa-equal-401-450` will load the pretrained checkpoints_swa-equal-401-450.ckpt file.

```
!python main.py --config_file configs/EfficientConformerCTCSmall.json --mode validation-clean --initial_epoch swa-equal-401-450 --gready
```
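For intuition about what greedy search (the repo's `--gready` flag) does with a CTC model: take the argmax token per frame, collapse consecutive repeats, and drop blanks. A minimal sketch, not the repo's implementation, assuming per-frame log-probabilities of shape `(time, vocab)` with blank id 0:

```python
import torch

def greedy_ctc_decode(log_probs: torch.Tensor, blank_id: int = 0) -> list:
    """Sketch of greedy CTC decoding (not the repo's implementation).

    log_probs: (time, vocab) per-frame log-probabilities.
    Returns token ids with repeats collapsed and blanks removed.
    """
    best = log_probs.argmax(dim=-1)        # most likely token per frame
    tokens, prev = [], blank_id
    for t in best.tolist():
        if t != prev and t != blank_id:    # collapse repeats, drop blanks
            tokens.append(t)
        prev = t
    return tokens
```

The resulting ids would then be mapped back to text with the sentencepiece model, e.g. `sp.decode(tokens)`.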