In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from trackml.dataset import load_event, load_dataset 
from trackml.randomize import shuffle_hits
from trackml.score import score_event 
import os
import math
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
import time
from multiprocessing import Pool
%matplotlib notebook
%run utils.ipynb

In [None]:
# Seeds for track following; consumed via next_seed(hits_from_seeds) in later cells.
# NOTE(review): later cells append to seed[0] like a list, so this may be an
# object (ragged) array; if so, np.load needs allow_pickle=True on NumPy >= 1.16.3.
hits_from_seeds = np.load('./hits_from_seeds.npy')

In [None]:
# Quick sanity check of the loaded seed array's dimensions.
hits_from_seeds.shape

In [None]:

hits, cells, particles, truth = load_data_single_event(1050)
# Distinct (column 4, column 5) value pairs of `hits`, in first-seen order.
# dict.fromkeys dedupes in O(n); a `pair not in list` check is O(n^2).
# NOTE(review): the disabled code below keys on pair[0]-pair[1], while
# get_fcl(layer_id, volume_id) looks up "<layer_id>-<volume_id>", so pair[0]
# must be the layer id -- confirm against the column order of `hits`.
unique_lv_pairs = list(dict.fromkeys(hits.iloc[:, 4:6].itertuples(index=False, name=None)))

# Disabled experiment, kept for reference: one fully connected layer per pair.
# While this stays commented out, `fcl_dict` is undefined, so any lookup that
# relies on the global `fcl_dict` raises NameError.
# fcl_dict = {}
# for lv_pair in unique_lv_pairs:
#     lv = str(lv_pair[0]) + "-" + str(lv_pair[1])
#     inputs = tf.placeholder(tf.float32, [None, 1, 3])
#     num_outputs = 30
#     layer = tf.contrib.layers.fully_connected(inputs, num_outputs, activation_fn=tf.nn.relu)
#     fcl_dict[lv] = layer
def get_fcl(layer_id, volume_id, layers=None):
    """Look up the layer registered for a (layer_id, volume_id) pair.

    Args:
        layer_id: first part of the "<layer_id>-<volume_id>" key.
        volume_id: second part of the key.
        layers: mapping from such keys to layers; defaults to the
            module-level ``fcl_dict``.

    Returns:
        The matching value, or None when the key is absent.
    """
    if layers is None:
        # NameError here if the code that builds fcl_dict has not been run.
        layers = fcl_dict
    # Direct dict lookup instead of scanning every item.
    return layers.get(str(layer_id) + "-" + str(volume_id))

In [None]:
"""
Encapsulates the NN used for next-hit prediction 
"""
class predict_engine():
    """Encapsulates the NN used for next-hit prediction.

    Restores the graph stored in MODEL_PATH (plus its weights) into a private
    ``tf.Graph`` and keeps one ``tf.Session`` open for repeated inference.
    Release it with ``close()`` or by using the engine as a context manager.
    """
    # Input geometry passed to next_batch by pred().
    N_steps = 10
    N_INPUT_FEATURES = 5
    # NOTE(review): N_NEURONS / N_OUTPUTS are not read by this class; they
    # presumably describe the restored model.
    N_NEURONS = 200
    N_OUTPUTS = 3
    MODEL_PATH = './checkpoints/baseline-idealv4/-33103.meta'

    @staticmethod
    def _checkpoint_prefix(meta_path):
        """Return the checkpoint prefix for ``meta_path`` (its ``.meta`` suffix removed)."""
        root, ext = os.path.splitext(meta_path)
        return root if ext == '.meta' else meta_path

    def __init__(self):
        # A dedicated graph keeps the imported tensor names ("print_op:0",
        # "input_ph:0") from colliding with nodes other cells add to the
        # default graph.
        self.graph = tf.Graph()
        with self.graph.as_default():
            saver = tf.train.import_meta_graph(self.MODEL_PATH)
        infSess = tf.Session(graph=self.graph)
        try:
            # restore() loads the weights into infSess (and returns None), so
            # no variable initializer needs to run beforehand.
            saver.restore(infSess, self._checkpoint_prefix(self.MODEL_PATH))
            self.predict_op = self.graph.get_tensor_by_name("print_op:0")
            self.input_tensor = self.graph.get_tensor_by_name('input_ph:0')
        except Exception:
            infSess.close()  # don't leak the session if loading fails
            raise
        self.infSess = infSess

    def pred(self):
        """Run the restored model on one batch, print input and output, return the output.

        NOTE: the batch comes from next_batch(...), i.e. the training data;
        this still needs to be changed to use held-out data.
        """
        x_data, y_data, _, __ = next(next_batch(self.N_steps, 1, self.N_INPUT_FEATURES, ideal = True))
        prediction = self.infSess.run(self.predict_op, feed_dict={self.input_tensor: x_data})
        print('input was: ')
        print(x_data)
        print('\n')
        print('pred is: ')
        print(prediction)
        return prediction

    def close(self):
        """Close the inference session."""
        self.infSess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
        
    
        

In [None]:
# Restores the checkpoint and opens a TF session; call engine.close() to release it.
engine = predict_engine()

In [None]:
# One prediction (input and output are printed); the batch comes from next_batch, not held-out data.
engine.pred()

In [None]:
def predict_next_hit(hits_so_far, checkpoint_dir='./checkpoints/lstm3/', log_dir='./logs/lstm3'):
    """Predict the next hit of a track with the 'lstm3' LSTM.

    The network is built in a fresh ``tf.Graph`` on every call. Building it in
    the default graph would add another LSTM, ``global_step`` variable and
    Saver per call, so the graph grows without bound and later Savers track
    uniquified names (e.g. ``global_step_1``) that the checkpoint lacks.

    Args:
        hits_so_far: batch of hit sequences fed to the ``input`` placeholder;
            assumed shape (batch, n_steps, 5) -- TODO confirm. ``n_steps`` is
            taken from ``len(hits_so_far[0])``.
        checkpoint_dir: directory passed (via os.path.dirname) to
            ``tf.train.get_checkpoint_state``.
        log_dir: if not None, the graph is written there for TensorBoard.

    Returns:
        List of the 3 network outputs at the last time step of the first
        sequence in the batch.
    """
    n_steps = len(hits_so_far[0])
    n_input_features = 5
    n_neurons = 200
    n_outputs = 3

    graph = tf.Graph()
    with graph.as_default():
        X = tf.placeholder(tf.float32, [None, n_steps, n_input_features], name='input')
        lstm = tf.contrib.rnn.LSTMCell(num_units=n_neurons, use_peepholes=True)
        lstm_cell = tf.contrib.rnn.OutputProjectionWrapper(lstm, output_size=n_outputs, reuse=tf.AUTO_REUSE)
        rnn_outputs, _states = tf.nn.dynamic_rnn(lstm_cell, X, dtype=tf.float32)
        # Unused for inference; kept so the Saver's variable set is the same as before.
        tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')
        print_op = tf.Print(rnn_outputs, [rnn_outputs], name='print_pred')
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as lstm3sess:
        init.run(session=lstm3sess)
        if log_dir is not None:
            writer = tf.summary.FileWriter(log_dir, graph)
            writer.close()  # flush and release the event file
        ckpt = tf.train.get_checkpoint_state(os.path.dirname(checkpoint_dir))
        # Restore trained weights if a checkpoint exists.
        # NOTE(review): otherwise the prediction uses freshly initialized weights.
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(lstm3sess, ckpt.model_checkpoint_path)
        test_print = lstm3sess.run(print_op, feed_dict={X: hits_so_far})

    last_step = test_print[0][-1]
    return [last_step[0], last_step[1], last_step[2]]

In [None]:
hits, cells, particles, truth = load_data_single_event(1050)
num_seeds = len(hits_from_seeds)
seeded_hits = next_seed(hits_from_seeds)
# One independent argument row per seed: [max_hits_per_track, error, num_seeds, seed].
# Each row is a new list; `[row] * n` would alias a single list n times.
# NOTE(review): these rows are passed to pool.starmap(predict_tracks, ...) in the
# next cells, so predict_tracks must accept the seed as a 4th argument.
parameters = [[10, 100, num_seeds, next(seeded_hits)] for _ in range(num_seeds)]

In [None]:
# Rows mix scalars with nested seed sequences (ragged), which np.asarray
# rejects unless dtype=object is given (NumPy >= 1.24).
test = np.asarray(parameters, dtype=object)

In [None]:
# NOTE: predict_tracks is defined in the cell below; on a fresh kernel run that
# cell first. Pool workers look the function up in __main__, which presumably
# requires the 'fork' start method (Linux default) when run from a notebook.
start = time.time()
with Pool(4) as pool:
    results = pool.starmap(predict_tracks, parameters)
# starmap returns one (tracks, distances) tuple per parameter row; merge them
# rather than unpacking the whole result list into two names.
tracks = [track for row_tracks, _ in results for track in row_tracks]
mse = [dists for _, row_dists in results for dists in row_dists]
end = time.time()
print(end - start)

In [None]:
def predict_tracks(max_hits_per_track, error, num_seeds, seed=None):
    """Extend seeds hit-by-hit until a track is full or a prediction misses.

    For every seed, repeatedly predicts the next hit with ``predict_next_hit``,
    snaps it to the closest hit in the module-level ``hits`` and appends it,
    stopping when the track holds ``max_hits_per_track`` hits or the distance
    between the prediction and the snapped hit exceeds ``error``.

    Args:
        max_hits_per_track: stop extending once ``len(seed[0])`` reaches this.
        error: maximum allowed distance; the first step exceeding it ends the
            track and that hit is not appended.
        num_seeds: number of seeds taken from ``next_seed(hits_from_seeds)``.
            Ignored when ``seed`` is given.
        seed: optional single seed to extend instead (e.g. the 4th element of
            a ``Pool.starmap`` argument row).

    Returns:
        ``(hits_from_tracks, mse_list)``: the extended seeds, and per seed the
        list of step distances. Despite the name, the values are whatever
        ``distance_between_two_points`` returns, not a mean squared error.
        NOTE(review): seeds are extended in place via ``seed[0].append``; if
        next_seed yields objects shared with ``hits_from_seeds``, those change too.
    """
    if seed is None:
        seed_source, n_tracks = next_seed(hits_from_seeds), num_seeds
    else:
        seed_source, n_tracks = iter([seed]), 1
    hits_from_tracks = []
    mse_list = []
    for _ in range(n_tracks):
        feed_seed = next(seed_source)
        distances = []
        # Use the parameter (not a same-named-looking global) as the length cap.
        while len(feed_seed[0]) < max_hits_per_track:
            predicted_hit = predict_next_hit(feed_seed)
            next_hit = get_xyz(closestHit(predicted_hit[0], predicted_hit[1], predicted_hit[2], hits))
            hit_info = get_hit_info(next_hit[0], next_hit[1], next_hit[2])
            next_hit.extend((hit_info[0], hit_info[1]))
            distance = distance_between_two_points(predicted_hit, next_hit)
            distances.append(distance)
            if distance > error:
                break
            feed_seed[0].append(next_hit)
        mse_list.append(distances)
        hits_from_tracks.append(feed_seed)
    return hits_from_tracks, mse_list

In [None]:
hits, cells, particles, truth = load_data_single_event(1050)
# Extend every seed and record, per step, the distance between the predicted
# hit and the closest real hit. predict_tracks reads the module-level `hits`
# loaded above; reusing it avoids a second copy of its loop here.
error = 100
max_hits_per_tracks = 10
start = time.time()
hits_from_tracks, mse_list = predict_tracks(max_hits_per_tracks, error, len(hits_from_seeds))
end = time.time()
print(end-start)