dayyass/language-modeling


About

Pipeline for training Language Models using PyTorch.
Inspired by the Yandex Data School NLP Course (week03: Language Modeling).

Usage

First, install dependencies:

# clone repo
git clone https://github.com/dayyass/language_modeling.git

# install dependencies
cd language_modeling
pip install -r requirements.txt

Data Format

A prepared text file with space-separated words on each line.
More about it here.
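
For illustration, a training file could look like this (hypothetical lines; each line is one sentence of space-separated, already tokenized words):

deep convolutional neural networks achieve state of the art results on image classification
we propose a novel approach to unsupervised language modeling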

Statistical Language Modeling

Training

Script for training statistical language models:

python statistical_lm/train.py --path_to_data "data/arxiv_train.txt" --n 3 --path_to_save "models/3_gram_language_model.pkl"

Required arguments:

  • --path_to_data - path to train data
  • --n - n-gram order

Optional arguments:

  • --smoothing - smoothing method (available: None, "add-k") (default: None)
  • --delta - smoothing additive parameter (only for add-k smoothing) (default: 1.0)
  • --path_to_save - path to save model (default: "models/language_model.pkl")
  • --verbose - verbose (default: True)
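
Under the hood, this kind of training amounts to counting n-grams and normalizing the counts into conditional probabilities. A minimal sketch of the idea (illustrative only, with hypothetical names; not the repository's actual code), including the add-k / delta option described above:

from collections import Counter, defaultdict

def train_ngram_lm(path_to_data, n=3, delta=0.0):
    # Count n-grams: counts[context][word] = how often `word` follows `context`.
    counts = defaultdict(Counter)
    vocab = set()
    with open(path_to_data, encoding="utf-8") as f:
        for line in f:
            # pad with <BOS> so the first words also have an (n-1)-word context
            tokens = ["<BOS>"] * (n - 1) + line.split() + ["<EOS>"]
            vocab.update(tokens)
            for i in range(n - 1, len(tokens)):
                context = tuple(tokens[i - n + 1 : i])
                counts[context][tokens[i]] += 1

    # Turn counts into conditional probabilities.
    # delta == 0.0 is plain maximum likelihood; delta > 0 is add-k smoothing
    # (unseen words then implicitly get delta / (total + delta * V) each).
    V = len(vocab)
    probs = {}
    for context, counter in counts.items():
        total = sum(counter.values())
        probs[context] = {
            word: (count + delta) / (total + delta * V)
            for word, count in counter.items()
        }
    return probs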

Validation

Script for validating statistical language models using perplexity:

python statistical_lm/validate.py --path_to_data "data/arxiv_test.txt" --path_to_model "models/3_gram_language_model.pkl"

Required arguments:

  • --path_to_data - path to validation data
  • --path_to_model - path to language model

Optional arguments:

  • --verbose - verbose (default: True)
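
For reference, perplexity is the exponential of the average negative log-probability per token. A rough sketch of how it can be computed with the n-gram probabilities from the sketch above (illustrative; a real implementation handles unseen n-grams via smoothing rather than a fixed floor):

import math

def perplexity(probs, path_to_data, n=3, floor=1e-10):
    log_prob_sum, num_tokens = 0.0, 0
    with open(path_to_data, encoding="utf-8") as f:
        for line in f:
            tokens = ["<BOS>"] * (n - 1) + line.split() + ["<EOS>"]
            for i in range(n - 1, len(tokens)):
                context = tuple(tokens[i - n + 1 : i])
                # unseen (context, word) pairs fall back to a small floor probability
                p = probs.get(context, {}).get(tokens[i], floor)
                log_prob_sum += math.log(p)
                num_tokens += 1
    # perplexity = exp(mean negative log-likelihood per token)
    return math.exp(-log_prob_sum / num_tokens)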

Inference

Script for generating new sequences using statistical language models:

python statistical_lm/inference.py --path_to_model "models/3_gram_language_model.pkl" --prefix "artificial" --temperature 0.5

Required arguments:

  • --path_to_model - path to language model

Optional arguments:

  • --prefix - prefix before sequence generation (default: "")
  • --strategy - decoding strategy (available: "sampling", "top-k-uniform", "top-k", "top-p-uniform", "top-p" and "beam search") (default: "sampling")
  • --temperature - sampling temperature; if temperature == 0.0, the most likely token is always taken, i.e. greedy decoding (only for "sampling" decoding strategy) (default: 0.0)
  • --k - top-k parameter (only for "top-k-uniform" and "top-k" decoding strategies) (default: 10)
  • --p - top-p parameter (only for "top-p-uniform" and "top-p" decoding strategies) (default: 0.9)
  • --max_length - max number of generated words (default: 100)
  • --seed - random seed (default: 42)
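
As a rough illustration of the temperature parameter (a hypothetical helper, not the repository's code): the next-word distribution is rescaled by 1 / temperature before sampling, and temperature == 0.0 degenerates to taking the argmax (greedy decoding).

import numpy as np

def sample_next_word(word_probs, temperature, rng):
    """Sample the next word from a {word: probability} dict (illustrative sketch)."""
    words = list(word_probs)
    p = np.array([word_probs[w] for w in words], dtype=float)
    if temperature == 0.0:
        return words[int(p.argmax())]          # greedy decoding
    p = p ** (1.0 / temperature)               # temperature < 1 sharpens, > 1 flattens
    p /= p.sum()                               # renormalize
    return rng.choice(words, p=p)

# usage
rng = np.random.default_rng(42)
next_word = sample_next_word({"neural": 0.5, "network": 0.3, "the": 0.2}, temperature=0.5, rng=rng)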

Command output of a 3-gram language model trained on arxiv.txt, with prefix "artificial" and temperature 0.5:

artificial neural network ( cnn ) architectures on h2o platform for real - world applications . <EOS>

RNN Language Modeling

Training

Script for training RNN language models:

python rnn_lm/train.py --path_to_data "data/arxiv_train.txt" --path_to_save_folder "models/rnn_language_model" --n_epoch 5 --max_length 512 --batch_size 128 --embedding_dim 64 --rnn_hidden_size 256

Required arguments:

  • --path_to_data - path to train data
  • --n_epoch - number of epochs
  • --batch_size - dataloader batch_size
  • --embedding_dim - embedding dimension
  • --rnn_hidden_size - LSTM hidden size

Optional arguments:

  • --path_to_save_folder - path to save folder (default: "models/rnn_language_model")
  • --max_length - max sentence length (chars) (default: None)
  • --shuffle - dataloader shuffle (default: True)
  • --rnn_num_layers - number of LSTM layers (default: 1)
  • --rnn_dropout - LSTM dropout (default: 0.0)
  • --train_eval_freq - evaluation frequency (number of batches) (default: 50)
  • --clip_grad_norm - max_norm parameter in clip_grad_norm (default: 1.0)
  • --seed - random seed (default: 42)
  • --device - torch device (available: "cpu", "cuda") (default: "cuda")
  • --verbose - verbose (default: True)
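
For orientation, the hyperparameters above map naturally onto an embedding → LSTM → vocabulary-projection model. A minimal PyTorch sketch under that assumption (the repository's actual module may differ in details):

import torch.nn as nn

class RNNLanguageModel(nn.Module):
    """Embedding -> LSTM -> projection to vocabulary logits (illustrative sketch)."""

    def __init__(self, vocab_size, embedding_dim=64, rnn_hidden_size=256,
                 rnn_num_layers=1, rnn_dropout=0.0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=rnn_hidden_size,
            num_layers=rnn_num_layers,
            dropout=rnn_dropout,
            batch_first=True,
        )
        self.head = nn.Linear(rnn_hidden_size, vocab_size)

    def forward(self, token_ids, hidden=None):
        emb = self.embedding(token_ids)        # [batch, seq_len, embedding_dim]
        out, hidden = self.lstm(emb, hidden)   # [batch, seq_len, rnn_hidden_size]
        return self.head(out), hidden          # logits: [batch, seq_len, vocab_size]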

Validation

Script for validating RNN language models using perplexity:

python rnn_lm/validate.py --path_to_data "data/arxiv_test.txt" --path_to_model_folder "models/rnn_language_model" --max_length 512

Required arguments:

  • --path_to_data - path to validation data
  • --path_to_model_folder - path to language model folder

Optional arguments:

  • --max_length - max sentence length (chars) (default: None)
  • --seed - random seed (default: 42)
  • --device - torch device (available: "cpu", "cuda") (default: "cuda")
  • --verbose - verbose (default: True)
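
For the RNN model, perplexity is simply the exponential of the average next-token cross-entropy. A short sketch under the assumption of a model like the one above and a dataloader of padded token-id batches (hypothetical helper names):

import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def rnn_perplexity(model, dataloader, pad_id, device="cuda"):
    model.eval()
    total_loss, total_tokens = 0.0, 0
    for batch in dataloader:                       # batch: [batch_size, seq_len] token ids
        batch = batch.to(device)
        inputs, targets = batch[:, :-1], batch[:, 1:]
        logits, _ = model(inputs)
        # sum of token-level cross-entropy, ignoring padding positions
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=pad_id,
            reduction="sum",
        )
        total_loss += loss.item()
        total_tokens += (targets != pad_id).sum().item()
    return math.exp(total_loss / total_tokens)     # perplexity = exp(mean cross-entropy)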

Inference

Script for generating new sequences using RNN language models:

python rnn_lm/inference.py --path_to_model_folder "models/rnn_language_model" --prefix "artificial" --temperature 0.5

Required arguments:

  • --path_to_model_folder - path to language model folder

Optional arguments:

  • --prefix - prefix before sequence generation (default: "")
  • --temperature - sampling temperature; if temperature == 0.0, the most likely token is always taken, i.e. greedy decoding (default: 0.0)
  • --max_length - max number of generated tokens (chars) (default: 100)
  • --seed - random seed (default: 42)
  • --device - torch device (available: "cpu", "cuda") (default: "cuda")
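
Generation is a standard autoregressive loop: feed the prefix, then repeatedly sample a token and feed it back while carrying the LSTM hidden state. A hedged sketch, assuming the model class sketched in the training section and hypothetical id2token / prefix_ids objects:

import torch

@torch.no_grad()
def generate(model, id2token, prefix_ids, max_length=100, temperature=0.5, device="cuda"):
    model.eval()
    generated = list(prefix_ids)
    inputs = torch.tensor([generated], device=device)       # [1, prefix_len]
    hidden = None
    for _ in range(max_length):
        logits, hidden = model(inputs, hidden)
        logits = logits[0, -1]                               # distribution over the next token
        if temperature == 0.0:
            next_id = int(logits.argmax())                   # greedy decoding
        else:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_id = int(torch.multinomial(probs, num_samples=1))
        if id2token[next_id] == "<EOS>":
            break
        generated.append(next_id)
        inputs = torch.tensor([[next_id]], device=device)    # feed only the new token back
    return " ".join(id2token[i] for i in generated)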

Command output of an RNN language model trained on arxiv.txt, with prefix "artificial" and temperature 0.5:

artificial visual information of the number , using an intervidence for detection for order to the recognition

Models

List of implemented models:

Decoding Strategy

  • greedy
  • temperature sampling
  • top-k-uniform
  • top-k
  • top-p-uniform
  • top-p
  • beam search
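
For intuition, top-k and top-p differ only in how the candidate set is truncated before sampling; the "-uniform" variants presumably sample uniformly from the same truncated set instead of proportionally. A minimal sketch of the two filters (illustrative, not the repository's code):

def filter_top_k(word_probs, k=10):
    # keep the k most probable words, then renormalize
    top = sorted(word_probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    total = sum(p for _, p in top)
    return {word: p / total for word, p in top}

def filter_top_p(word_probs, p=0.9):
    # keep the smallest set of words whose cumulative probability reaches p, then renormalize
    kept, cumulative = {}, 0.0
    for word, prob in sorted(word_probs.items(), key=lambda kv: kv[1], reverse=True):
        kept[word] = prob
        cumulative += prob
        if cumulative >= p:
            break
    total = sum(kept.values())
    return {word: prob / total for word, prob in kept.items()}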

Smoothing (only for N-gram Language Models)

  • no smoothing
  • add-k / Laplace smoothing
  • interpolation smoothing
  • back-off / Katz smoothing
  • Kneser-Ney smoothing
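
For reference, the standard textbook forms of the first two smoothing methods (the repository's exact formulation may differ), with counts C, vocabulary V, and additive constant δ:

P_{\text{add-}k}(w \mid \text{context}) = \frac{C(\text{context},\, w) + \delta}{C(\text{context}) + \delta \, |V|}

P_{\text{interp}}(w_i \mid w_{i-2}\, w_{i-1}) = \lambda_3 P(w_i \mid w_{i-2}\, w_{i-1}) + \lambda_2 P(w_i \mid w_{i-1}) + \lambda_1 P(w_i), \qquad \lambda_1 + \lambda_2 + \lambda_3 = 1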

Models Comparison

Generation comparison available here.

Statistical Language Modeling

perplexity (train / test)    none                 add-k / Laplace       interpolation    back-off / Katz    Kneser-Ney
1-gram                       881.27 / 1832.23     882.63 / 1838.22      -                -                  -
2-gram                       95.32 / 8.57e+7      1106.79 / 1292.02     -                -                  -
3-gram                       12.78 / 6.2e+22      7032.91 / 10499.24    -                -                  -