# Imports

In [1]:
# Make the project's `src` directory importable from this notebook.
# NOTE(review): this absolute path is machine-specific — point SRC_DIR at the
# `src` directory of your own checkout.
import sys

SRC_DIR = "/home/cody/abcnn/ABCNN/src"
sys.path.insert(0, SRC_DIR)

In [2]:
# Import standard-library and third-party modules
import os
from torch.utils.data import TensorDataset

# Import custom modules
from setup import read_config
from setup import setup
from trainer.factories import loss_fn_factory
from trainer.factories import optimizer_factory
from trainer.factories import scheduler_factory
from trainer.multiclass_classifier_trainer import MulticlassClassifierTrainer
from trainer.utils import move_to_device
from utils import abcnn_model_loader

# Setup

In [3]:
# Set the input values.

# Path to the YAML configuration file.
CONFIG_PATH = "/home/cody/abcnn/ABCNN/src/config.yaml"

# Dataset split names; each should be a key in `data_paths`.
TRAINSET = "moveworks_train"  # training split
VALSET = "moveworks_val"      # validation split
TESTSET = "moveworks_test"    # test split

# Checkpoint file to load a model from; None means no checkpoint is loaded.
LOAD_PATH = None

In [4]:
# Sanity-check the input values so that a mistyped path fails here with a
# readable message instead of deeper inside the setup code.
assert os.path.isfile(CONFIG_PATH), f"config file not found: {CONFIG_PATH}"
assert LOAD_PATH is None or os.path.isfile(LOAD_PATH), (
    f"checkpoint file not found: {LOAD_PATH}"
)

In [5]:
# Setup modules: read the config, build features/labels/model, wrap each split
# in a TensorDataset, then build the loss, optimizer, scheduler and trainer.
# NOTE(review): the saved output of this cell ends in `KeyError: 'environment'`
# (raised after "Creating the ABCNN model..."); presumably the config lacks a key
# that one of these helpers expects — TODO confirm against config.yaml.
config = read_config(CONFIG_PATH)
features, labels, model = setup(config["model"])
# Move the model to the configured device (hacky, but necessary for the trainer).
model = move_to_device(config["trainer"]["device"], model)
datasets = {
    name: TensorDataset(features[name], labels[name])
    for name in features
}
# Fail fast if a split named in the input values was not produced by `setup`.
# A falsy VALSET is allowed (training then runs without a validation set).
missing_splits = [
    name for name in (TRAINSET, VALSET, TESTSET) if name and name not in datasets
]
assert not missing_splits, (
    f"splits {missing_splits} not found; available: {list(datasets)}"
)
loss_fn = loss_fn_factory(config["loss_fn"])
optimizer = optimizer_factory(config["optimizer"], model.parameters())
scheduler = scheduler_factory(config["scheduler"], optimizer)
trainer = MulticlassClassifierTrainer(config["trainer"])
if LOAD_PATH:
    # Restore model and optimizer from the checkpoint at LOAD_PATH.
    # NOTE(review): `scheduler` was built from the optimizer above; if the loader
    # returns a new optimizer object (rather than restoring state in place), the
    # scheduler still points at the old one — TODO confirm.
    model, optimizer = abcnn_model_loader(LOAD_PATH, model, optimizer)

moveworks_train: 100%|██████████| 7868/7868 [00:02<00:00, 2742.42it/s]
moveworks_val: 100%|██████████| 438/438 [00:00<00:00, 2770.56it/s]
moveworks_test: 100%|██████████| 437/437 [00:00<00:00, 2770.79it/s]


Loading FastText word vectors from: /home/cody/abcnn/embeddings/fasttext/tickets/word_vector_from_tickets_skipgram_dim300_subword_min2_max6.bin


embedding matrix: 100%|██████████| 2666/2666 [00:00<00:00, 124026.87it/s]


Creating the ABCNN model...


KeyError: 'environment'

# Train the model

In [None]:
# Train on the training split; validate on VALSET only when one is configured.
trainset = datasets[TRAINSET]
if VALSET:
    valset = datasets[VALSET]
else:
    valset = None
trainer.train(
    loss_fn,
    model,
    optimizer,
    trainset,
    valset=valset,
    scheduler=scheduler,
)

# Make predictions

In [None]:
# Run prediction on the test split. Whatever `trainer.predict` returns is kept
# in `predictions` and shown as this cell's output.
# NOTE(review): `predict` is not given the model; presumably the trainer keeps
# a reference to it from `train` — TODO confirm.
testset = datasets[TESTSET]
predictions = trainer.predict(testset)
predictions