# LSTM track finder for 2D toy data in PyTorch

Here, I'm going to get familiar with PyTorch by reproducing the models in LSTM_Toy2D.ipynb.

In [1]:
from __future__ import print_function

from timeit import default_timer as timer

import numpy as np

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

from data import generate_straight_tracks, generate_uniform_noise, generate_track_bkg
from drawing import draw_2d_event, draw_2d_input_and_pred

%matplotlib notebook

## Prepare some data

In [2]:
# Detector parameters
det_width = 50    # columns of det_shape; also the LSTM input/output width
det_depth = 50    # rows of det_shape; presumably the number of detector layers
det_shape = (det_depth, det_width)
seed_size = 5     # passed to generate_track_bkg as skip_layers

# Data config
n_events = 102400
n_bkg_tracks = 5  # passed to generate_track_bkg as tracks_per_event
noise_prob = 0    # passed to generate_uniform_noise as prob

# Random seeds, so data generation and model initialization are reproducible
# under Restart & Run All.
# NOTE(review): assumes the data generators draw from numpy's global RNG — confirm.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)

In [3]:
# Build the toy dataset from three hit sources: signal tracks, background
# tracks and uniform noise. Each array is cast to float32, the dtype of the
# model's default parameters, so torch.from_numpy later yields FloatTensors.
sig_tracks = generate_straight_tracks(n_events, det_shape).astype(np.float32)
bkg_tracks = generate_track_bkg(
    n_events,
    det_shape,
    tracks_per_event=n_bkg_tracks,
    skip_layers=seed_size,
).astype(np.float32)
noise = generate_uniform_noise(
    n_events, det_shape, prob=noise_prob
).astype(np.float32)

# The network input is the elementwise sum of all three sources.
events = sig_tracks + bkg_tracks + noise

## Define the model

In [4]:
class LSTMTrackFinder(nn.Module):
    """Single-layer LSTM that emits ``input_dim`` class scores at every step.

    The LSTM is built with ``batch_first=True``, so inputs are laid out as
    (batch, steps, input_dim). A linear layer maps each step's hidden state
    back to ``input_dim`` unnormalized scores (logits).

    Args:
        input_dim: size of each input step and number of output classes.
        hidden_dim: number of LSTM hidden units.
    """

    def __init__(self, input_dim, hidden_dim):
        super(LSTMTrackFinder, self).__init__()
        # One recurrent layer; batch_first=True means (batch, step, feature).
        self.lstm = nn.LSTM(input_dim, hidden_dim, 1, batch_first=True)
        # Shared projection from each hidden state back onto input_dim classes.
        self.fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        """Return logits of shape (batch, steps, input_dim).

        No softmax is applied here: the logits go straight into
        F.cross_entropy for training, which applies log-softmax itself.
        """
        x, _ = self.lstm(x)
        batch_size, n_steps, n_hidden = x.size()
        # Apply the shared linear layer to all steps in a single call by folding
        # the step axis into the batch axis, instead of a Python loop over steps.
        # contiguous() because the batch-first LSTM output may not be contiguous.
        flat_hidden = x.contiguous().view(batch_size * n_steps, n_hidden)
        flat_logits = self.fc(flat_hidden)
        return flat_logits.view(batch_size, n_steps, -1)

def logits_to_probs(logits):
    """Turn logits into probabilities with a softmax over the last axis.

    All leading axes are folded into one, the softmax is taken over the
    resulting 2-D view, and the output is reshaped to ``logits``' shape.
    """
    n_classes = logits.size(-1)
    probs_2d = F.softmax(logits.view(-1, n_classes))
    return probs_2d.view(logits.size())

def cost_function(logits, labels):
    """Mean cross-entropy between per-step logits and integer class labels.

    The batch and detector-layer axes are folded together so every
    (event, layer) pair counts as one sample: logits become (N, n_classes)
    and labels become (N,).
    """
    n_classes = logits.size(-1)
    return F.cross_entropy(logits.view(-1, n_classes), labels.view(-1))

In [5]:
# Wrap the numpy data as autograd Variables (torch.from_numpy shares memory).
train_input = Variable(torch.from_numpy(events))
# Class label per (event, row): the argmax along axis 2 of the signal-track
# array, i.e. integer targets as expected by cost_function / cross_entropy.
train_labels = Variable(torch.from_numpy(np.argmax(sig_tracks, axis=2)))

In [6]:
# Model hyperparameters
hidden_dim = 100

# Training hyperparameters
n_epochs = 10
batch_size = 64

# Mini-batches per epoch (ceiling division: the last batch may be partial).
n_samples = len(train_input)
n_batches = -(-n_samples // batch_size)

In [7]:
# Instantiate the network (det_width classes per step, hidden_dim LSTM units)
# and report its size.
model = LSTMTrackFinder(det_width, hidden_dim)
print(model)
print('Parameters:', sum(p.numel() for p in model.parameters()))

# Adam with PyTorch's default hyperparameters.
# (torch.optim.RMSprop(model.parameters()) is a drop-in alternative.)
optimizer = torch.optim.Adam(model.parameters())

LSTMTrackFinder (
  (lstm): LSTM(50, 100, batch_first=True)
  (fc): Linear (100 -> 50)
)
Parameters: 65850


In [9]:
# Shape and element count of every parameter tensor
# (the LSTM's weights and biases, then the Linear layer's).
for param in model.parameters():
    print(param.size(), param.numel())

torch.Size([400, 50]) 20000
torch.Size([400, 100]) 40000
torch.Size([400]) 400
torch.Size([400]) 400
torch.Size([50, 100]) 5000
torch.Size([50]) 50


In [9]:
# Training loop over epochs
for i in range(n_epochs):

    print('Epoch', i)
    start_time = timer()
    sum_loss = 0.

    # Loop over mini-batches in a fixed (unshuffled) order
    for j in range(0, n_samples, batch_size):
        batch_input = train_input[j:j+batch_size]
        batch_labels = train_labels[j:j+batch_size]
        model.zero_grad()
        batch_logits = model(batch_input)
        batch_loss = cost_function(batch_logits, batch_labels)
        batch_loss.backward()
        optimizer.step()
        # Accumulate a plain Python float. Adding the loss Variable itself would
        # chain every batch's autograd graph into sum_loss and keep all of them
        # alive until the end of the epoch.
        sum_loss += batch_loss.data.numpy().item()

    end_time = timer()
    avg_loss = sum_loss / n_batches
    print('  average loss', avg_loss, 'time %gs' % (end_time - start_time))

Epoch 0
  average loss 2.16838989258 time 163.721s
Epoch 1
  average loss 1.53918838501 time 164.938s
Epoch 2
  average loss 1.35406570435 time 161.92s
Epoch 3
  average loss 1.2037386322 time 165.939s
Epoch 4
  average loss 1.09688018799 time 158.775s
Epoch 5
  average loss 0.954774932861 time 158.378s
Epoch 6
  average loss 0.876244430542 time 163.463s
Epoch 7
  average loss 0.776243286133 time 162.096s
Epoch 8
  average loss 0.747665252686 time 158.022s
Epoch 9
  average loss 0.682240219116 time 158.33s


In [None]:
# Evaluate the trained model on every training event in one forward pass.
# NOTE(review): this builds an autograd graph over the whole training set even
# though no gradients are needed; a no-grad (volatile) evaluation would save memory.
train_logits = model(train_input)
train_loss = cost_function(train_logits, train_labels)
# Per-step softmax probabilities, used when drawing events.
train_preds = logits_to_probs(train_logits)

print('Final train loss', train_loss.data.numpy()[0])

Final train loss 0.618681


In [None]:
# Compute accuracy: the fraction of (event, layer) pairs whose highest-scoring
# class equals the label. The boolean match tensor must be cast to float and
# averaged; dividing it directly by numel() gives an elementwise tensor (all
# zeros under integer division), not a scalar.
train_correct = (train_logits.max(dim=2)[1] == train_labels).float()
train_accuracy = train_correct.mean().data.numpy().item()
print('Train accuracy', train_accuracy)

In [None]:
# Draw one event's input hits next to the model's per-layer predicted probabilities.
# A descriptive name avoids silently reusing the training loop's `i`.
event_index = 6
draw_2d_input_and_pred(train_input[event_index].data.numpy(),
                       train_preds[event_index].data.numpy(),
                       cmap='gray_r');