# Imports

In [1]:
# Add the project's `src` directory to the import path so the custom modules below
# can be imported. Set the ABCNN_SRC environment variable to point at your own
# `src` directory instead of editing the default here.
import os
import sys

SRC_DIR = os.environ.get("ABCNN_SRC", "/home/cody/abcnn/ABCNN/src")
# Remove any existing entry first so re-running this cell leaves exactly one copy,
# still at the front of the search path (highest import priority).
while SRC_DIR in sys.path:
    sys.path.remove(SRC_DIR)
sys.path.insert(0, SRC_DIR)

In [3]:
# Import standard-library and third-party modules
import os
from torch.utils.data import TensorDataset

# Import custom modules
from setup import read_config
from setup import setup
from trainer.factories import loss_fn_factory
from trainer.factories import optimizer_factory
from trainer.factories import scheduler_factory
from trainer.multiclass_classifier_trainer import MulticlassClassifierTrainer
from trainer.utils import move_to_device
from utils import abcnn_model_loader

# Setup

In [4]:
# Input values -- edit these before running the notebook.

# Path to the YAML configuration file.
CONFIG_PATH = "/home/cody/abcnn/ABCNN/src/config.yaml"

# Dataset split names; each should be a key in `data_paths`.
TRAINSET = "moveworks_train"  # training split
VALSET = "moveworks_val"      # validation split (falsy value skips validation)
TESTSET = "moveworks_test"    # test split

# Checkpoint file to load the model from; leave as None to skip loading.
LOAD_PATH = None

In [5]:
# Sanity-check the input values before doing any expensive setup.
# `assert` is a statement, not a function: writing `assert(cond, msg)` would build a
# tuple that is always truthy, so the condition and message are given unparenthesized.
assert os.path.isfile(CONFIG_PATH), f"config file not found: {CONFIG_PATH!r}"
assert LOAD_PATH is None or os.path.isfile(LOAD_PATH), (
    f"checkpoint file not found: {LOAD_PATH!r}"
)

In [6]:
# Setup modules: build the data, model, loss, optimizer, scheduler and trainer
# from the YAML config.
config = read_config(CONFIG_PATH)
features, labels, model = setup(config["model"])
model = move_to_device(config["trainer"]["device"], model)  # hacky, but necessary for trainer

# One TensorDataset per split; `features` and `labels` are indexed by the same
# split names (presumably the keys TRAINSET/VALSET/TESTSET refer to -- confirm).
datasets = {
    name: TensorDataset(features[name], labels[name])
    for name in features
}

loss_fn = loss_fn_factory(config["loss_fn"])
optimizer = optimizer_factory(config["optimizer"], model.parameters())
# NOTE(review): the scheduler presumably keeps a reference to this optimizer
# object. If `abcnn_model_loader` returns a *different* optimizer instance, the
# scheduler would keep driving the stale one -- confirm the loader restores
# state in place.
scheduler = scheduler_factory(config["scheduler"], optimizer)
trainer = MulticlassClassifierTrainer(config["trainer"])

# Optionally restore model/optimizer state from the checkpoint named in LOAD_PATH.
# `is not None` matches the sanity check above, which already guarantees that a
# non-None LOAD_PATH is an existing file.
if LOAD_PATH is not None:
    model, optimizer = abcnn_model_loader(LOAD_PATH, model, optimizer)

moveworks_train: 100%|██████████| 7868/7868 [00:02<00:00, 2756.40it/s]
moveworks_val: 100%|██████████| 438/438 [00:00<00:00, 2760.01it/s]
moveworks_test: 100%|██████████| 437/437 [00:00<00:00, 2752.92it/s]


Loading FastText word vectors from: /home/cody/abcnn/embeddings/fasttext/tickets/word_vector_from_tickets_skipgram_dim300_subword_min2_max6.bin


embedding matrix: 100%|██████████| 2666/2666 [00:00<00:00, 128035.89it/s]


Creating the ABCNN model...


# Train the model

In [None]:
# The training split is required; validation only runs when VALSET is truthy.
trainset = datasets[TRAINSET]
if VALSET:
    valset = datasets[VALSET]
else:
    valset = None

trainer.train(
    loss_fn,
    model,
    optimizer,
    trainset,
    scheduler=scheduler,
    valset=valset,
)

epochs:   0%|          | 0/10 [00:00<?, ?it/s]

train:   1%|          | 1/123 [00:00<00:46,  2.63it/s][A
train:   2%|▏         | 2/123 [00:00<00:38,  3.14it/s][A
train:   2%|▏         | 3/123 [00:00<00:35,  3.36it/s][A
train:   3%|▎         | 4/123 [00:01<00:34,  3.48it/s][A
train:   4%|▍         | 5/123 [00:01<00:33,  3.56it/s][A
train:   5%|▍         | 6/123 [00:01<00:32,  3.61it/s][A
train:   6%|▌         | 7/123 [00:01<00:31,  3.65it/s][A
train:   7%|▋         | 8/123 [00:02<00:31,  3.67it/s][A
train:   7%|▋         | 9/123 [00:02<00:30,  3.70it/s][A
train:   8%|▊         | 10/123 [00:02<00:30,  3.71it/s][A
train:   9%|▉         | 11/123 [00:02<00:30,  3.73it/s][A
train:  10%|▉         | 12/123 [00:03<00:29,  3.74it/s][A
train:  11%|█         | 13/123 [00:03<00:29,  3.75it/s][A
train:  11%|█▏        | 14/123 [00:03<00:28,  3.76it/s][A
train:  12%|█▏        | 15/123 [00:03<00:28,  3.76it/s][A
train:  13%|█▎        | 16/123 [00:04<00:28,  3.77it/s][A
train:  14%|█▍   

train:   1%|          | 1/123 [00:00<00:50,  2.43it/s][A
train:   2%|▏         | 2/123 [00:00<00:40,  2.97it/s][A
train:   2%|▏         | 3/123 [00:00<00:37,  3.21it/s][A
train:   3%|▎         | 4/123 [00:01<00:35,  3.35it/s][A
train:   4%|▍         | 5/123 [00:01<00:34,  3.44it/s][A
train:   5%|▍         | 6/123 [00:01<00:33,  3.50it/s][A
train:   6%|▌         | 7/123 [00:01<00:32,  3.54it/s][A
train:   7%|▋         | 8/123 [00:02<00:32,  3.58it/s][A
train:   7%|▋         | 9/123 [00:02<00:31,  3.61it/s][A
train:   8%|▊         | 10/123 [00:02<00:31,  3.63it/s][A
train:   9%|▉         | 11/123 [00:03<00:30,  3.65it/s][A
train:  10%|▉         | 12/123 [00:03<00:30,  3.66it/s][A
train:  11%|█         | 13/123 [00:03<00:29,  3.68it/s][A
train:  11%|█▏        | 14/123 [00:03<00:29,  3.69it/s][A
train:  12%|█▏        | 15/123 [00:04<00:29,  3.70it/s][A
train:  13%|█▎        | 16/123 [00:04<00:28,  3.71it/s][A
train:  14%|█▍        | 17/123 [00:04<00:28,  3.71it/s][A
train:

train:   6%|▌         | 7/123 [00:02<00:33,  3.42it/s][A
train:   7%|▋         | 8/123 [00:02<00:33,  3.48it/s][A
train:   7%|▋         | 9/123 [00:02<00:32,  3.52it/s][A
train:   8%|▊         | 10/123 [00:02<00:31,  3.55it/s][A
train:   9%|▉         | 11/123 [00:03<00:31,  3.58it/s][A
train:  10%|▉         | 12/123 [00:03<00:30,  3.60it/s][A
train:  11%|█         | 13/123 [00:03<00:30,  3.62it/s][A
train:  11%|█▏        | 14/123 [00:03<00:30,  3.63it/s][A
train:  12%|█▏        | 15/123 [00:04<00:29,  3.64it/s][A
train:  13%|█▎        | 16/123 [00:04<00:29,  3.65it/s][A
train:  14%|█▍        | 17/123 [00:04<00:28,  3.66it/s][A
train:  15%|█▍        | 18/123 [00:04<00:28,  3.67it/s][A
train:  15%|█▌        | 19/123 [00:05<00:28,  3.68it/s][A
train:  16%|█▋        | 20/123 [00:05<00:27,  3.69it/s][A
train:  17%|█▋        | 21/123 [00:05<00:27,  3.69it/s][A
train:  18%|█▊        | 22/123 [00:05<00:27,  3.70it/s][A
train:  19%|█▊        | 23/123 [00:06<00:26,  3.70it/s][A


train:  11%|█         | 13/123 [00:03<00:31,  3.51it/s][A
train:  11%|█▏        | 14/123 [00:03<00:30,  3.53it/s][A
train:  12%|█▏        | 15/123 [00:04<00:30,  3.55it/s][A
train:  13%|█▎        | 16/123 [00:04<00:30,  3.56it/s][A
train:  14%|█▍        | 17/123 [00:04<00:29,  3.58it/s][A
train:  15%|█▍        | 18/123 [00:05<00:29,  3.59it/s][A
train:  15%|█▌        | 19/123 [00:05<00:28,  3.60it/s][A
train:  16%|█▋        | 20/123 [00:05<00:28,  3.60it/s][A
train:  17%|█▋        | 21/123 [00:05<00:28,  3.61it/s][A
train:  18%|█▊        | 22/123 [00:06<00:27,  3.62it/s][A
train:  19%|█▊        | 23/123 [00:06<00:27,  3.63it/s][A
train:  20%|█▉        | 24/123 [00:06<00:27,  3.64it/s][A
train:  20%|██        | 25/123 [00:06<00:26,  3.65it/s][A
train:  21%|██        | 26/123 [00:07<00:26,  3.65it/s][A
train:  22%|██▏       | 27/123 [00:07<00:26,  3.66it/s][A
train:  23%|██▎       | 28/123 [00:07<00:25,  3.66it/s][A
train:  24%|██▎       | 29/123 [00:07<00:25,  3.67it/s]

train:  15%|█▌        | 19/123 [00:05<00:28,  3.62it/s][A
train:  16%|█▋        | 20/123 [00:05<00:28,  3.63it/s][A
train:  17%|█▋        | 21/123 [00:05<00:28,  3.64it/s][A
train:  18%|█▊        | 22/123 [00:06<00:27,  3.65it/s][A
train:  19%|█▊        | 23/123 [00:06<00:27,  3.65it/s][A
train:  20%|█▉        | 24/123 [00:06<00:27,  3.66it/s][A
train:  20%|██        | 25/123 [00:06<00:26,  3.67it/s][A
train:  21%|██        | 26/123 [00:07<00:26,  3.67it/s][A
train:  22%|██▏       | 27/123 [00:07<00:26,  3.68it/s][A
train:  23%|██▎       | 28/123 [00:07<00:25,  3.68it/s][A
train:  24%|██▎       | 29/123 [00:07<00:25,  3.69it/s][A
train:  24%|██▍       | 30/123 [00:08<00:25,  3.69it/s][A
train:  25%|██▌       | 31/123 [00:08<00:24,  3.69it/s][A
train:  26%|██▌       | 32/123 [00:08<00:24,  3.70it/s][A
train:  27%|██▋       | 33/123 [00:08<00:24,  3.70it/s][A
train:  28%|██▊       | 34/123 [00:09<00:24,  3.70it/s][A
train:  28%|██▊       | 35/123 [00:09<00:23,  3.71it/s]

train:  20%|██        | 25/123 [00:06<00:26,  3.63it/s][A
train:  21%|██        | 26/123 [00:07<00:26,  3.64it/s][A
train:  22%|██▏       | 27/123 [00:07<00:26,  3.65it/s][A
train:  23%|██▎       | 28/123 [00:07<00:26,  3.65it/s][A
train:  24%|██▎       | 29/123 [00:07<00:25,  3.66it/s][A
train:  24%|██▍       | 30/123 [00:08<00:25,  3.66it/s][A
train:  25%|██▌       | 31/123 [00:08<00:25,  3.67it/s][A
train:  26%|██▌       | 32/123 [00:08<00:24,  3.67it/s][A
train:  27%|██▋       | 33/123 [00:08<00:24,  3.68it/s][A
train:  28%|██▊       | 34/123 [00:09<00:24,  3.68it/s][A
train:  28%|██▊       | 35/123 [00:09<00:23,  3.68it/s][A
train:  29%|██▉       | 36/123 [00:09<00:23,  3.69it/s][A
train:  30%|███       | 37/123 [00:10<00:23,  3.69it/s][A
train:  31%|███       | 38/123 [00:10<00:23,  3.69it/s][A
train:  32%|███▏      | 39/123 [00:10<00:22,  3.70it/s][A
train:  33%|███▎      | 40/123 [00:10<00:22,  3.70it/s][A
train:  33%|███▎      | 41/123 [00:11<00:22,  3.70it/s]

# Make predictions

In [None]:
# Run prediction on the held-out test split. The result is bound to a name so
# later cells can use it without relying on `_` / `Out[N]`.
# NOTE(review): unlike `trainer.train`, `trainer.predict` is not passed the model --
# confirm the trainer retains a reference to it after training.
testset = datasets[TESTSET]
test_predictions = trainer.predict(testset)
test_predictions