# Train Model

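Train a trajectory classifier on a pactus dataset, persist it to disk with joblib, and check that the reloaded model reproduces the original test-set predictions.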

In [2]:
import os
import sys

# Make the directory two levels above the notebook's working directory
# importable (presumably the repository root that contains `traj_xai`).
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "../.."))
# Guard so re-running this cell does not keep appending duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
from joblib import dump, load
from pactus import Dataset
from pactus import featurizers
from pactus.models import KNeighborsModel, RandomForestModel, LSTMModel

from traj_xai.src import SimpleGRUModel, SimpleTransformerModel, TrajFormerModel

In [4]:
# Fixed seed for the train/test split and model initialisation.
# Set SEED = None for a non-deterministic run.
SEED = 42

# Trajectory dataset loaders from pactus; the trailing number on each entry
# is the value to use for `dataset_idx`.
datasets = [
    Dataset.animals,  # 0
    Dataset.geolife,  # 1
    Dataset.hurdat2,  # 2
    Dataset.uci_pen_digits,  # 3
    Dataset.uci_movement_libras,  # 4
    Dataset.uci_characters,  # 5
    Dataset.mnist_stroke,  # 6
]

# Load the selected dataset
dataset_idx = 1
dataset = datasets[dataset_idx]()
print(f"Dataset loaded: {len(dataset.trajs)} trajectories ({datasets[dataset_idx].__name__})")

# Split into train and test subsets (80% train / 20% test)
TRAIN_FRACTION = 0.8
train, test = dataset.split(TRAIN_FRACTION, random_state=SEED)
print(f"Train set: {len(train.trajs)} trajectories")
print(f"Test set: {len(test.trajs)} trajectories")

Dataset loaded: 9288 trajectories (geolife)
Train set: 7430 trajectories
Test set: 1858 trajectories


In [5]:
# Model factories (not instances) so only the selected model is built; the
# trailing number on each entry is the value to use for `model_idx`.
models = [
    lambda: RandomForestModel(featurizer=featurizers.UniversalFeaturizer(), n_jobs=-1, random_state=SEED),  # 0
    # NOTE(review): k-NN is deterministic; confirm KNeighborsModel actually accepts `random_state`.
    lambda: KNeighborsModel(featurizer=featurizers.UniversalFeaturizer(), random_state=SEED),  # 1
    lambda: LSTMModel(random_state=SEED),  # 2
    lambda: SimpleGRUModel(random_state=SEED),  # 3
    lambda: SimpleTransformerModel(random_state=SEED),  # 4
    lambda: TrajFormerModel(c_out=len(dataset.classes), random_state=SEED)  # 5
]

CV_FOLDS = 5     # cross-validation folds for the featurizer-based models
EPOCHS = 10      # training epochs for the sequence models
BATCH_SIZE = 64  # mini-batch size for the sequence models

# build and train the model
model_idx = 0
model = models[model_idx]()
# Dispatch on the model's type rather than its position in `models`, so
# reordering or extending the list cannot pick the wrong `train` call.
if isinstance(model, (RandomForestModel, KNeighborsModel)):
    # Featurizer-based models: the second positional argument is the number
    # of cross-validation folds (the run log reports "Fitting 5 folds").
    model.train(train, CV_FOLDS)
else:
    # Sequence models also receive the full dataset alongside the train split.
    model.train(train, dataset, epochs=EPOCHS, batch_size=BATCH_SIZE)

Fitting 5 folds for each of 1 candidates, totalling 5 fits




[CV 1/5] END ..................................., score=0.861 total time=   0.6s
[CV 2/5] END ..................................., score=0.859 total time=   0.6s
[CV 3/5] END ..................................., score=0.866 total time=   0.6s
[CV 4/5] END ..................................., score=0.873 total time=   0.6s
[CV 5/5] END ..................................., score=0.861 total time=   0.6s


In [6]:
# Evaluate the model on the test set (held out by the split earlier in the
# notebook, so the scores are not computed on training data) and print the report.
evaluation = model.evaluate(test)
evaluation.show()

12:11:18 [INFO] Evaluating the random_forest model



General statistics:

Accuracy: 0.876
F1-score: 0.655
Mean precision: 0.694
Mean recall: 0.626

Confusion matrix:

airplane  bike      boat      bus       car       run       subway    taxi      train     walk      precision 
66.67     0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       100.0     
0.0       86.73     0.0       1.64      1.29      0.0       0.85      0.98      0.0       1.04      93.71     
0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       
0.0       4.53      0.0       85.52     13.55     0.0       11.11     27.45     6.06      1.43      77.86     
0.0       0.0       100.0     2.46      81.94     0.0       2.56      8.82      3.03      0.13      84.11     
0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       
0.0       0.32      0.0       2.73      0.0       0.0       76.07     4.9       0.0       0.65      80.91   

In [7]:
def get_filename(model_index=None, dataset_index=None):
    """Return the file stem used to persist a trained model.

    The stem has the form ``<model abbreviation>_<dataset loader name>``,
    e.g. ``rf_geolife``.

    Args:
        model_index: Index into the notebook's ``models`` list. When omitted,
            the global ``model_idx`` is used.
        dataset_index: Index into the notebook's ``datasets`` list. When
            omitted, the global ``dataset_idx`` is used.

    Returns:
        The file stem string (no directory, no extension).

    Raises:
        IndexError: if either index is out of range for its list.
    """
    # Abbreviations are matched to `models` by position — keep both lists
    # in the same order when adding or reordering models.
    model_str = ["rf", "knn", "lstm", "gru", "transformer", "trajformer"]
    if model_index is None:
        model_index = model_idx
    if dataset_index is None:
        dataset_index = dataset_idx
    return f"{model_str[model_index]}_{datasets[dataset_index].__name__}"

In [8]:
# Persist the trained model as models/<model>_<dataset>.joblib
persist_dir = "models"
persist_filename = get_filename()
# os.path.join instead of a hand-built "/" path so the separator is platform-correct.
persist_path = os.path.join(persist_dir, f"{persist_filename}.joblib")

os.makedirs(persist_dir, exist_ok=True)
dump(model, persist_path)

['models/rf_geolife.joblib']

In [9]:
# Reload the persisted model from disk and evaluate it on the same test split;
# its report should match `evaluation` from the in-memory model.
# NOTE(review): joblib.load unpickles arbitrary objects — only load model files
# you trust (here, the file this notebook just wrote).
new_model = load(persist_path)

evaluation2 = new_model.evaluate(test)
evaluation2.show()

12:11:47 [INFO] Evaluating the random_forest model



General statistics:

Accuracy: 0.876
F1-score: 0.655
Mean precision: 0.694
Mean recall: 0.626

Confusion matrix:

airplane  bike      boat      bus       car       run       subway    taxi      train     walk      precision 
66.67     0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       100.0     
0.0       86.73     0.0       1.64      1.29      0.0       0.85      0.98      0.0       1.04      93.71     
0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       
0.0       4.53      0.0       85.52     13.55     0.0       11.11     27.45     6.06      1.43      77.86     
0.0       0.0       100.0     2.46      81.94     0.0       2.56      8.82      3.03      0.13      84.11     
0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       
0.0       0.32      0.0       2.73      0.0       0.0       76.07     4.9       0.0       0.65      80.91   

In [10]:
# Sanity check that the persisted model gets the same results. Assert (not just
# display) so Restart & Run All fails loudly on a mismatch; compare as lists so
# this also works if `y_pred` is an array rather than a list.
predictions_match = list(evaluation.y_pred) == list(evaluation2.y_pred)
assert predictions_match, "reloaded model predictions differ from the in-memory model"
predictions_match

True