In [None]:
import dataclasses
import importlib

import numpy as np

from awe import awe_model, gym, awe_trainer

# Reload the project modules so that edits made to them on disk are picked up
# without restarting the kernel. importlib.reload re-executes each module in
# place, so the names imported above keep pointing at the refreshed modules.
for module in (awe_model, gym, awe_trainer):
    importlib.reload(module)

In [None]:
# Smoke test: drain a one-element tqdm iterator to confirm that progress bars
# render in this environment. The result is bound to `_` so the cell shows
# nothing but the bar itself.
from tqdm.auto import tqdm
_ = [*tqdm(range(1))]

### Specify parameters

In [None]:
# Baseline training configuration. The parameter grid defined after it is
# built from copies of this object.
base_params = awe_trainer.AweTrainingParams(
    use_gpu=True,
    epochs=50,
    batch_size=64,
    num_workers=8,
    train_pages_subset=100,
    version_name='lstm-output-2nd-try',
    # NOTE(review): presumably removes a previously stored version with the
    # same name before training; confirm before re-running.
    delete_existing_version=True,
    model=awe_model.AweModelParams(
        use_lstm=True,
        lstm_args=dict(bidirectional=True),
        use_char_lstm=False,
        # Supplied even though use_char_lstm is switched off above.
        char_lstm_args=dict(bidirectional=True),
        disable_direct_features=False,
        use_word_vectors=True,
    ),
)
# One grid entry per override dict. The single empty dict yields one unchanged
# (shallow) copy of base_params; add more dicts to widen the grid.
param_grid = [
    dataclasses.replace(base_params, **overrides)
    for overrides in [{}]
]

### Extract features

In [None]:
# Build a trainer from the first parameter set in the grid and load its data.
trainer = awe_trainer.AweTrainer(param_grid[0])
trainer.load_data()

In [None]:
# Run feature extraction on the trainer whose data was loaded in the previous cell.
# NOTE(review): presumably operates on the loaded data and may be slow — confirm.
trainer.extract_features()

### Find learning rate

In [None]:
# Create a fresh trainer from the first parameter set and load its data again.
# This rebinds `trainer`, replacing the instance used for feature extraction.
trainer = awe_trainer.AweTrainer(param_grid[0])
trainer.load_data()

In [None]:
# Prepare the model on the freshly created trainer before searching for a learning rate.
trainer.prepare_model()

In [None]:
# Run the trainer's learning-rate finder.
# NOTE(review): presumably requires prepare_model() to have been called first — confirm.
trainer.find_lr()

### Grid training

In [None]:
# Hand the whole parameter grid to the project's grid-training entry point.
# NOTE(review): base_params sets delete_existing_version=True, so this presumably
# overwrites any stored version named 'lstm-output-2nd-try' — confirm before running.
awe_trainer.train_grid(param_grid)