# Track filtering/fitting with LSTMs

This is a continuous space model using the ACTS data.

In this notebook we load a pre-trained model and simply make some evaluation plots.

In [1]:
from __future__ import print_function

# System imports
import os

# Data libraries
import numpy as np
import pandas as pd

# Torch imports
import torch

# Visualization
import matplotlib.pyplot as plt

# Local imports
import torchutils
torchutils.set_cuda(False)
from torchutils import np_to_torch, torch_to_np
from track_filter import coord_scale

# Magic
%matplotlib notebook

In [2]:
torch.__version__

'0.3.0'

## Read the data

In [3]:
data_dir = '/global/cscratch1/sd/sfarrell/heptrkx/filter_data_005'
train_dir = '/global/cscratch1/sd/sfarrell/heptrkx/RNNFilter_006'
#train_dir = '/global/cscratch1/sd/sfarrell/heptrkx/RNNFilter_20180203_2242'
#train_dir = '/global/cscratch1/sd/sfarrell/heptrkx/RNNFilter_20180202_1422'
#train_dir = '/global/cscratch1/sd/sfarrell/heptrkx/RNNFilter_20171212_1455'
model_file = os.path.join(train_dir, 'model')
losses_file = os.path.join(train_dir, 'losses.npz')

In [4]:
#train_data = np.load(os.path.join(data_dir, 'train_data.npy'))
test_data = np.load(os.path.join(data_dir, 'test_data.npy'))

losses_data = np.load(losses_file)
train_losses = losses_data['train_losses']
valid_losses = losses_data['valid_losses']

# Load the pre-trained model
model = torch.load(model_file)
model.eval()
model

HitPredictor(
  (lstm): LSTM(3, 32, batch_first=True)
  (fc): Linear(in_features=32, out_features=2)
)

In [5]:
# Inputs are the hits from [0, N-1)
# Targets are the hits from [1, N) without the 'r' feature.
#train_input = np_to_torch(train_data[:,:-1])
#train_target = np_to_torch(train_data[:,1:,:2])
test_input = np_to_torch(test_data[:,:-1])
test_target = np_to_torch(test_data[:,1:,:2])

## Training loss

In [6]:
plt.figure()
plt.semilogy(train_losses, label='Training set')
plt.semilogy(valid_losses, label='Validation set')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training loss')
plt.legend(loc=0)
plt.tight_layout()

<IPython.core.display.Javascript object>

## Evaluate model performance

In [7]:
# Get the predictions and errors for the full training set
#train_output = model(train_input)
#train_error = train_output - train_target
#train_resid = train_error.cpu().data.numpy() * coord_scale[:2]

# Get the predictions and errors for the full test set
test_output = model(test_input)
test_error = test_output - test_target
test_resid = test_error.cpu().data.numpy() * coord_scale[:2]

Let's draw those curves with the different layers separated out.

In [8]:
plt.figure(figsize=(9,4))
plt.subplot(121)
hist_args = dict(bins=100, range=(-0.015, 0.015),
                 normed=False, log=True, histtype='step')
# Draw errors on first, last, and all prediction layers separately
plt.hist(test_resid[:,:1,0].flatten(), label='First layer', **hist_args)
plt.hist(test_resid[:,1:-2,0].flatten(), label='Middle layers', **hist_args)
plt.hist(test_resid[:,-2:,0].flatten(), label='Last two layers', **hist_args)
plt.hist(test_resid[:,:,0].flatten(), label='All predictions', **hist_args)
plt.xlabel('Error in $\phi$ [rad]')
plt.ylabel('Predicted hits')
plt.ylim(ymin=1)
plt.legend(loc=0)

plt.subplot(122)
hist_args = dict(bins=100, range=(-80, 80),
                 normed=False, log=True, histtype='step')
plt.hist(test_resid[:,:1,1].flatten(), label='First layer', **hist_args)
plt.hist(test_resid[:,1:-2,1].flatten(), label='Middle layers', **hist_args)
plt.hist(test_resid[:,-2:,1].flatten(), label='Last two layers', **hist_args)
plt.hist(test_resid[:,:,1].flatten(), label='All predictions', **hist_args)
plt.xlabel('Error in z [mm]')
plt.ylabel('Predicted hits')
plt.ylim(ymin=10)
plt.legend(loc=0)
plt.tight_layout()

<IPython.core.display.Javascript object>

In [9]:
plt.figure(figsize=(9,4))
plt.subplot(121)
hist_args = dict(bins=100, range=(-0.0025, 0.0025),
                 normed=False, log=False, histtype='step')
# Draw errors on first, last, and all prediction layers separately
plt.hist(test_resid[:,:1,0].flatten(), label='First layer', **hist_args)
plt.hist(test_resid[:,1:-2,0].flatten(), label='Middle layers', **hist_args)
plt.hist(test_resid[:,-2:,0].flatten(), label='Last two layers', **hist_args)
plt.hist(test_resid[:,:,0].flatten(), label='All predictions', **hist_args)
plt.xlabel('Error in $\phi$ [rad]')
plt.ylabel('Predicted hits')
plt.ylim(ymin=1)
plt.legend(loc=0)

plt.subplot(122)
hist_args = dict(bins=100, range=(-5, 5),
                 normed=False, log=False, histtype='step')
plt.hist(test_resid[:,:1,1].flatten(), label='First layer', **hist_args)
plt.hist(test_resid[:,1:-2,1].flatten(), label='Middle layers', **hist_args)
plt.hist(test_resid[:,-2:,1].flatten(), label='Last two layers', **hist_args)
plt.hist(test_resid[:,:,1].flatten(), label='All predictions', **hist_args)
plt.xlabel('Error in z [mm]')
plt.ylabel('Predicted hits')
#plt.ylim(ymin=100)
plt.legend(loc=0)
plt.tight_layout()

<IPython.core.display.Javascript object>

## Visualize trajectories

In [10]:
for i in range(10):
    layers = np.arange(10)
    inputs = torch_to_np(test_input[i]) * coord_scale
    outputs = torch_to_np(test_output[i]) * coord_scale[:2]
    targets = torch_to_np(test_target[i]) * coord_scale[:2]

    plt.figure(figsize=(9,3))

    plt.subplot(121)
    plt.plot(layers[:-1], inputs[:,0], 'b.-')
    plt.plot(layers[1:], targets[:,0], 'b.-', label='Data')
    plt.plot(layers[1:], outputs[:,0], 'r.-', label='Filter')
    plt.xlabel('Detector layer')
    plt.ylabel('$\phi$ [rad]')
    plt.legend(loc=0)

    plt.subplot(122)
    plt.plot(layers[:-1], inputs[:,1], 'b.-')
    plt.plot(layers[1:], targets[:,1], 'b.-', label='Data')
    plt.plot(layers[1:], outputs[:,1], 'r.-', label='Filter')
    plt.xlabel('Detector layer')
    plt.ylabel('z [mm]')
    plt.legend(loc=0)

    plt.tight_layout()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

#### Observations

- After centering all tracks in $\phi$ on the first hit, performance has improved dramatically. Now the trajectories look great and the residuals are pretty well behaved.