# Train Model (Colab version)

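Train a trajectory classification model on a pactus dataset, evaluate it on a held-out test split, and persist it to Google Drive with joblib.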

In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive; the traj_xai codebase and the saved models
# used later in this notebook live under MyDrive/XAI4Traj.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Install required packages if needed
!pip install rdp fastdtw pactus fvcore "yupi==0.12.5" | grep -v 'already satisfied'

Collecting rdp
  Downloading rdp-0.8.tar.gz (4.4 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting fastdtw
  Downloading fastdtw-0.3.4.tar.gz (133 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 133.4/133.4 kB 10.8 MB/s eta 0:00:00
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting pactus
  Downloading pactus-0.4.3-py3-none-any.whl.metadata (4.7 kB)
Collecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 50.2/50.2 kB 4.1 MB/s eta 0:00:00
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting yupi==0.12.5
  Downloading yupi-0.12.5-py3-none-any.whl.metadata (3.3 kB)
Collecting nudged>=0.3.1 (from yupi==0.12.5)
  Downloading nudged-0.3.1-py2.py3-none-any.whl.metadata (5.8 kB)
Collecting yacs>=0.1.6 (from fvcore)
  Downloading yacs-0.1.8-py3-no

In [3]:
# form symbolic link with traj_xai codebase
!ln -s "/content/drive/MyDrive/XAI4Traj/traj_xai" "/content/traj_xai"

In [4]:
import sys

# Make the linked traj_xai checkout importable. The membership guard keeps re-runs of
# this cell from appending duplicate entries to sys.path.
TRAJ_XAI_PATH = '/content/traj_xai'
if TRAJ_XAI_PATH not in sys.path:
    sys.path.append(TRAJ_XAI_PATH)

In [5]:
import os
from joblib import dump, load
from pactus import Dataset
from pactus import featurizers
from pactus.models import KNeighborsModel, RandomForestModel, LSTMModel

from traj_xai.src import SimpleGRUModel, SimpleTransformerModel, TrajFormerModel

In [6]:
SEED = 42  # set to None to run without a fixed random_state

# Dataset loaders, selected by position through `dataset_idx`.
datasets = [
    Dataset.animals,              # 0
    Dataset.geolife,              # 1
    Dataset.hurdat2,              # 2
    Dataset.uci_pen_digits,       # 3
    Dataset.uci_movement_libras,  # 4
    Dataset.uci_characters,       # 5
    Dataset.mnist_stroke          # 6
]

# Pick a loader and load the dataset
dataset_idx = 1
dataset_loader = datasets[dataset_idx]
dataset = dataset_loader()
print(f"Dataset loaded: {len(dataset.trajs)} trajectories ({dataset_loader.__name__})")

# 80/20 train/test split, reproducible through SEED
train, test = dataset.split(0.8, random_state=SEED)
print(f"Train set: {len(train.trajs)} trajectories")
print(f"Test set: {len(test.trajs)} trajectories")

Dataset loaded: 9288 trajectories (geolife)
Train set: 7430 trajectories
Test set: 1858 trajectories


In [7]:
# Inspect trajectory lengths (len of each trajectory's position array `r`).
# Output is tab-separated: min, mean (2 decimals), max.
lens = [len(traj.r) for traj in dataset.trajs]
print(f"{min(lens)}\t{sum(lens)/len(lens):.2f}\t{max(lens)}")

6	523.06	39419


In [8]:
# Model factories, selected by position through `model_idx`. The lambdas defer
# construction so only the chosen model is built (TrajFormer also reads
# `dataset.classes` at that point).
models = [
    lambda: RandomForestModel(featurizer=featurizers.UniversalFeaturizer(), n_jobs=-1, random_state=SEED),  # 0
    lambda: KNeighborsModel(featurizer=featurizers.UniversalFeaturizer(), random_state=SEED),  # 1
    lambda: LSTMModel(random_state=SEED),  # 2
    lambda: SimpleGRUModel(random_state=SEED),  # 3
    lambda: SimpleTransformerModel(random_state=SEED),  # 4
    lambda: TrajFormerModel(c_out=len(dataset.classes), random_state=SEED)  # 5
]

CV_FOLDS = 5     # cross-validation folds for the feature-based models (RF, kNN)
EPOCHS = 20      # training epochs for the neural models
BATCH_SIZE = 64  # mini-batch size for the neural models

# build and train the model
model_idx = 5
model = models[model_idx]()
if isinstance(model, (RandomForestModel, KNeighborsModel)):
    # As in the neural branch, the second positional argument is the full dataset
    # (pactus' `original_data`). The fold count must therefore be passed by keyword:
    # a positional 5 would land in the dataset slot and no cross-validation would run.
    model.train(train, dataset, cross_validation=CV_FOLDS)
else:
    model.train(train, dataset, epochs=EPOCHS, batch_size=BATCH_SIZE)



In [9]:
# Evaluate the model on the test set.
# `evaluation` is kept: the sanity check at the end compares its predictions with
# those of the reloaded model. `show()` prints the summary metrics and the
# confusion matrix.
evaluation = model.evaluate(test)
evaluation.show()


General statistics:

Accuracy: 0.462
F1-score: 0.190
Mean precision: 0.358
Mean recall: 0.192

Confusion matrix:

airplane  bike      boat      bus       car       run       subway    taxi      train     walk      precision 
0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       
0.0       6.15      0.0       0.0       0.65      0.0       0.0       1.96      0.0       2.08      50.0      
0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       
0.0       1.94      0.0       17.21     9.68      0.0       9.4       6.86      6.06      7.91      38.18     
33.33     1.94      100.0     3.01      56.77     0.0       6.84      6.86      18.18     1.04      64.71     
0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       0.0       
0.0       0.32      0.0       3.28      1.94      0.0       8.55      1.96      0.0       0.78      29.41   

In [10]:
def get_filename(model_index=None, dataset_index=None):
    """Build the persistence file stem ``"<model>_<dataset>"``, e.g. ``"trajformer_geolife"``.

    Args:
        model_index: Position in the ``models`` factory list. Defaults to the
            notebook-global ``model_idx`` (the original zero-argument behavior).
        dataset_index: Position in the ``datasets`` loader list. Defaults to the
            notebook-global ``dataset_idx``.

    Returns:
        The file name without directory or extension.

    Raises:
        IndexError: if an index is outside ``model_str`` or ``datasets``.
    """
    # Short model names; must stay in the same order as the `models` factory list.
    model_str = ["rf", "knn", "lstm", "gru", "transformer", "trajformer"]
    if model_index is None:
        model_index = model_idx
    if dataset_index is None:
        dataset_index = dataset_idx
    return f"{model_str[model_index]}_{datasets[dataset_index].__name__}"

In [11]:
# Persist the trained model to Drive as <persist_dir>/<model>_<dataset>.joblib.
# The path is absolute (Drive is mounted at /content/drive). A relative "drive/..."
# path only resolves while the working directory happens to be /content.
persist_dir = "/content/drive/MyDrive/XAI4Traj/models"
persist_filename = get_filename()
persist_path = os.path.join(persist_dir, f"{persist_filename}.joblib")

os.makedirs(persist_dir, exist_ok=True)
# Last expression: dump returns the list of written file names, displayed as output.
dump(model, persist_path)

['drive/MyDrive/XAI4Traj/models/trajformer_geolife.joblib']

In [12]:
# Reload the model just written to disk to verify the persistence round trip.
# NOTE: joblib.load unpickles the file and can execute arbitrary code, so only load
# model files from a trusted location such as this notebook's own output directory.
new_model = load(persist_path)

In [None]:
# Evaluate the reloaded model on the same test split. Its predictions are compared
# with the in-memory model's `evaluation` in the next cell.
evaluation2 = new_model.evaluate(test)
evaluation2.show()

In [None]:
# Sanity check that the persisted model gets the same results as the in-memory one.
# Assert (rather than only display) so a mismatch fails loudly under Run All.
# Comparing as lists yields a single bool whether y_pred is a list or a NumPy array
# (array `==` would be elementwise).
predictions_match = list(evaluation.y_pred) == list(evaluation2.y_pred)
assert predictions_match, "persisted model predictions differ from the in-memory model"
predictions_match

True