<a href="https://colab.research.google.com/github/himkt/optuna-allennlp/blob/master/Optuna_AllenNLP_custom_loop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the pinned AllenNLP release plus Optuna into the Colab runtime.
# NOTE(review): optuna is unpinned here, but the cells below use
# `optuna.integration.AllenNLPPruningCallback` as an AllenNLP 1.x
# `epoch_callbacks` entry. Newer Optuna releases may have changed or moved that
# integration, so pin an Optuna version known to work with allennlp 1.1.0rc3
# (TODO confirm which) to keep this notebook reproducible.
!pip install --quiet "allennlp==v1.1.0rc3" optuna

[K     |████████████████████████████████| 491kB 12.8MB/s 
[K     |████████████████████████████████| 235kB 53.9MB/s 
[K     |████████████████████████████████| 778kB 52.1MB/s 
[K     |████████████████████████████████| 317kB 60.0MB/s 
[K     |████████████████████████████████| 266kB 56.3MB/s 
[K     |████████████████████████████████| 1.1MB 53.3MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 81kB 11.2MB/s 
[K     |████████████████████████████████| 3.0MB 51.0MB/s 
[K     |████████████████████████████████| 890kB 54.4MB/s 
[K     |████████████████████████████████| 1.1MB 52.4MB/s 
[K     |████████████████████████████████| 81kB 10.6MB/s 
[K     |████████████████████████████████| 112kB 65.5MB/s 
[K     |████████████████████████████████| 51kB 8.4MB/s 
[K     |████████████████████████████████| 133kB 60.2MB/s 
[?25h  Bu

In [2]:
import functools
import random

from allennlp.data import Vocabulary, allennlp_collate
from allennlp.data.dataset_readers import TextClassificationJsonReader
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import WhitespaceTokenizer
from allennlp.models import BasicClassifier
from allennlp.modules import Embedding
from allennlp.modules.seq2vec_encoders import CnnEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.training import GradientDescentTrainer
import numpy
import optuna
from optuna import Trial
from optuna.integration import AllenNLPPruningCallback
import torch
from torch.optim import SGD
from torch.utils.data import DataLoader

In [3]:
@functools.lru_cache(maxsize=None)
def prepare_data():
    """Read the IMDB train/dev splits and build a vocabulary from the train split.

    The result is memoised. ``objective`` calls this once per Optuna trial, but
    nothing here depends on the trial's hyperparameters. Re-reading and
    re-tokenising both splits and rebuilding the vocabulary on every trial was
    pure overhead.

    Every call returns the *same* objects, so callers must treat them as
    read-only.

    Returns:
        Tuple ``(train_dataset, valid_dataset, vocab)``. Both datasets have
        already been indexed with ``vocab``.
    """
    reader = TextClassificationJsonReader(
        token_indexers={"tokens": SingleIdTokenIndexer()},
        tokenizer=WhitespaceTokenizer(),
    )
    train_dataset = reader.read("https://s3-us-west-2.amazonaws.com/allennlp/datasets/imdb/train.jsonl")  # NOQA
    valid_dataset = reader.read("https://s3-us-west-2.amazonaws.com/allennlp/datasets/imdb/dev.jsonl")  # NOQA
    # The vocabulary is built from the training split only; the dev split is
    # indexed against it rather than contributing tokens of its own.
    vocab = Vocabulary.from_instances(train_dataset)
    train_dataset.index_with(vocab)
    valid_dataset.index_with(vocab)
    return train_dataset, valid_dataset, vocab

In [4]:
def build_model(
        vocab: Vocabulary,
        embedding_dim: int,
        max_filter_size: int,
        num_filters: int,
        output_dim: int,
        dropout: float,
):
    """Assemble a CNN-based ``BasicClassifier`` for one set of hyperparameters.

    Args:
        vocab: Vocabulary used to size the token embedding and the label space.
        embedding_dim: Width of each token embedding.
        max_filter_size: Exclusive upper bound on the n-gram filter widths.
            The encoder uses widths ``2 .. max_filter_size - 1``.
        num_filters: Number of filters per n-gram width.
        output_dim: Size of the encoder's projected output.
        dropout: Dropout probability passed to the classifier.

    Returns:
        The (untrained) ``BasicClassifier``.
    """
    # Components are built in the same order as they feed into each other
    # (embedding -> embedder -> encoder), which also fixes the order in which
    # their parameters are randomly initialised.
    token_embedding = Embedding(
        embedding_dim=embedding_dim,
        trainable=True,
        vocab=vocab,
    )
    embedder = BasicTextFieldEmbedder({"tokens": token_embedding})

    cnn_encoder = CnnEncoder(
        ngram_filter_sizes=range(2, max_filter_size),  # upper bound is exclusive
        num_filters=num_filters,
        embedding_dim=embedding_dim,
        output_dim=output_dim,
    )

    return BasicClassifier(
        text_field_embedder=embedder,
        seq2vec_encoder=cnn_encoder,
        dropout=dropout,
        vocab=vocab,
    )

In [5]:
def objective(trial: Trial):
    """Optuna objective: train one CNN classifier and report its best dev accuracy.

    Hyperparameters are sampled from ``trial``. ``AllenNLPPruningCallback``
    reports ``validation_accuracy`` to Optuna after every epoch so the study's
    pruner can stop unpromising trials early.

    Args:
        trial: The Optuna trial being evaluated.

    Returns:
        ``best_validation_accuracy`` from the trainer's metrics.
    """
    embedding_dim = trial.suggest_int("embedding_dim", 128, 256)
    max_filter_size = trial.suggest_int("max_filter_size", 3, 6)
    num_filters = trial.suggest_int("num_filters", 128, 256)
    output_dim = trial.suggest_int("output_dim", 128, 512)
    # Capped at 0.5: with p close to 1 dropout zeroes nearly every feature
    # during training, so such trials cannot learn and only waste budget.
    dropout = trial.suggest_float("dropout", 0, 0.5)
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)

    train_dataset, valid_dataset, vocab = prepare_data()
    model = build_model(vocab, embedding_dim, max_filter_size, num_filters, output_dim, dropout)

    # Use the first GPU when present, otherwise stay on CPU (the previous
    # hard-coded "cuda:0" crashed on CPU-only runtimes). The same device id is
    # handed to the trainer so batches are moved to where the model lives.
    if torch.cuda.is_available():
        cuda_device = 0
        model.to(torch.device("cuda", cuda_device))
    else:
        cuda_device = -1

    optimizer = SGD(model.parameters(), lr=lr)
    data_loader = DataLoader(train_dataset, batch_size=10, collate_fn=allennlp_collate)
    validation_data_loader = DataLoader(valid_dataset, batch_size=64, collate_fn=allennlp_collate)
    trainer = GradientDescentTrainer(
        model=model,
        optimizer=optimizer,
        data_loader=data_loader,
        validation_data_loader=validation_data_loader,
        validation_metric="+accuracy",
        patience=None,  # `patience=None` since it could conflict with AllenNLPPruningCallback
        num_epochs=10,
        cuda_device=cuda_device,
        # Namespace by study: trial numbers restart at 0 in every new study.
        # The trainer may resume from checkpoints it finds in
        # serialization_dir, so a shared "result/<n>" would mix runs.
        serialization_dir=f"result/{trial.study.study_name}/{trial.number}",
        epoch_callbacks=[AllenNLPPruningCallback(trial, "validation_accuracy")],
    )
    return trainer.train()["best_validation_accuracy"]

In [None]:
# Run the hyperparameter search and report the best trial.
SEED = 41
N_TRIALS = 50

random.seed(SEED)
torch.manual_seed(SEED)
numpy.random.seed(SEED)

# TPESampler keeps its own random number generator, so the global seeds above
# do not make the sampled hyperparameters reproducible; seed the sampler too.
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=SEED),
    pruner=optuna.pruners.HyperbandPruner(),
)
study.optimize(objective, n_trials=N_TRIALS)

print("Number of finished trials: ", len(study.trials))
print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)
print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

reading instances: 0it [00:00, ?it/s]
downloading:   0%|          | 0/26939140 [00:00<?, ?B/s]
downloading:   0%|          | 17408/26939140 [00:00<04:05, 109854.47B/s]
downloading:   0%|          | 52224/26939140 [00:00<03:29, 128270.00B/s]
downloading:   0%|          | 104448/26939140 [00:00<02:51, 156103.41B/s]
downloading:   1%|          | 243712/26939140 [00:00<02:09, 206601.87B/s]
downloading:   2%|1         | 522240/26939140 [00:00<01:34, 280373.97B/s]
downloading:   4%|3         | 1061888/26939140 [00:00<01:06, 386318.59B/s]
downloading:   8%|8         | 2158592/26939140 [00:01<00:46, 538457.91B/s]
downloading:  16%|#6        | 4365312/26939140 [00:01<00:29, 756155.51B/s]
downloading:  24%|##4       | 6577152/26939140 [00:01<00:19, 1054665.82B/s]
downloading:  33%|###2      | 8756224/26939140 [00:01<00:12, 1456691.48B/s]
downloading:  41%|####      | 10935296/26939140 [00:01<00:08, 1986769.02B/s]
downloading:  49%|####8     | 13097984/26939140 [00:01<00:05, 2664945.68B/s]
downlo