In [11]:
%run utils.ipynb
from trackml.dataset import load_event, load_dataset
from trackml.randomize import shuffle_hits
from trackml.score import score_event

In [12]:
def is_cylinder(volume_id, layer_id):
    """Return True when (volume_id, layer_id) is one of the ten cylinder layers.

    Volumes 8 and 13 contribute layers 2, 4, 6, 8; volume 17 contributes
    layers 2 and 4. Any other combination is not a cylinder.
    """
    if volume_id == 17:
        return layer_id in [2, 4]
    if volume_id == 8 or volume_id == 13:
        return layer_id in [2, 4, 6, 8]
    return False


def layer_number(volume_id, layer_id):
    """Map a cylinder's (volume_id, layer_id) to a 1-based layer index.

    Order: volume 8 layers 2/4/6/8 -> 1..4, volume 13 layers 2/4/6/8 -> 5..8,
    volume 17 layers 2/4 -> 9, 10.

    Raises:
        AssertionError: if the pair is not one of those ten cylinder layers.
    """
    assert (volume_id in [8, 13, 17]), "This volume_id is not in a cylinder!"
    allowed_layers = [2, 4] if volume_id == 17 else [2, 4, 6, 8]
    assert layer_id in allowed_layers, "This layer_id is not in this volume_id"
    ordering = [(8, 2), (8, 4), (8, 6), (8, 8),
                (13, 2), (13, 4), (13, 6), (13, 8),
                (17, 2), (17, 4)]
    rank_of = {pair: rank for rank, pair in enumerate(ordering, start=1)}
    return rank_of[(volume_id, layer_id)]

def get_cylinder_hits(hit_dataframe, truth_dataframe ):
    """Return the event's hits that lie on the ten cylinder layers.

    The hits and truth tables are inner-joined on ``hit_id``. Rows are then
    collected per (volume_id, layer_id) pair in cylinder order, so the result
    is grouped by layer in that order. Each row keeps its index from the
    merged frame.
    """
    merged = hit_dataframe.merge(truth_dataframe, how='inner', on='hit_id')
    volumes = merged['volume_id'].values
    layers = merged['layer_id'].values
    cylinder_pairs = ((8, 2), (8, 4), (8, 6), (8, 8),
                      (13, 2), (13, 4), (13, 6), (13, 8),
                      (17, 2), (17, 4))
    return pd.concat([merged[(volumes == vol) & (layers == lay)]
                      for vol, lay in cylinder_pairs])
    
def gen_simple_tracks(hit_df, truth_df, full_tracks = False, ideal = False, debug=False):
    """Build ground-truth tracks restricted to the ten cylinder layers.

    Args:
        hit_df: hits table; needs hit_id, volume_id and layer_id columns.
        truth_df: truth table, inner-joined to hit_df on hit_id.
        full_tracks: keep only tracks with exactly 10 hits.
        ideal: additionally require that the track's hits map, in order, to
            layer_number 1..10 (one hit per layer). Setting it forces
            full_tracks on.
        debug: print how many tracks each filter removed.

    Returns:
        ``gen_tracks(...).values`` for the cylinder hits, filtered per the flags.
    """
    if ideal:
        full_tracks = True  # ideal tracks are a subset of full tracks
    assert isinstance(hit_df, pd.DataFrame) and isinstance(truth_df, pd.DataFrame)
    hit_and_truth = get_cylinder_hits(hit_df, truth_df)
    prelim_tracks = gen_tracks(hit_and_truth).values
    if not full_tracks:
        return prelim_tracks

    # Full tracks have one hit per volume/layer pair, i.e. exactly 10 hits.
    # Compare by value: `is not 10` tested object identity.
    short_rows = [row for row, track in enumerate(prelim_tracks) if len(track) != 10]
    complete_tracks = np.delete(prelim_tracks, short_rows)
    if not ideal:
        if debug:
            print('prelim track num: '+str(len(prelim_tracks)))
            print('removed tracks: '+str(len(short_rows)))
            print('remaining tracks: '+str(len(prelim_tracks) - len(short_rows)))
        return complete_tracks

    # Ideal tracks: hit k of the track must lie on layer_number k.
    # Look up each hit's (volume_id, layer_id) in the caller's hit_df, not a
    # notebook global. Build the table once rather than scanning per hit.
    vol_layer = dict(zip(hit_df['hit_id'].tolist(),
                         zip(hit_df['volume_id'].tolist(), hit_df['layer_id'].tolist())))
    expected = list(range(1, 11))
    bad_rows = [idx for idx, track in enumerate(complete_tracks)
                if [layer_number(*vol_layer[hit]) for hit in track] != expected]
    ideal_tracks = np.delete(complete_tracks, bad_rows)
    if debug:
        print('full tracks: '+str(len(complete_tracks)))
        print('non-ideal tracks: '+str(len(bad_rows)))
        print('ideal tracks: '+str(len(ideal_tracks)))
    return ideal_tracks
            
def simple_batch_iter(hit_df, truth_df, batch_size, full_tracks = False, ideal = False):
    """Yield shuffled batches of tracks from gen_simple_tracks.

    Tracks are shuffled in place, and any tail that does not fill a whole
    batch is dropped. Each batch is ``tracks.reshape(-1, batch_size, 1)[i]``,
    i.e. shape (batch_size, 1). This assumes gen_simple_tracks returns a 1-D
    array of hit-id sequences (TODO confirm).

    Args:
        hit_df, truth_df: one event's hits and truth tables.
        batch_size: tracks per yielded batch.
        full_tracks, ideal: forwarded to gen_simple_tracks; ideal forces
            full_tracks on.
    """
    if ideal:
        full_tracks = True
    tracks = gen_simple_tracks(hit_df, truth_df, full_tracks, ideal)
    np.random.shuffle(tracks)
    # The old remainder statement was split mid-identifier and raised NameError.
    # Its conditional reduces to a plain modulo.
    remainder = len(tracks) % batch_size
    modded_tracks = tracks[:len(tracks) - remainder]
    for batch in modded_tracks.reshape(-1, batch_size, 1):
        yield batch
        
        
        
  

In [13]:
def _fit_to_length(rows, length, width):
    """Return ``rows`` as an array with ``length`` rows: zero-pad, truncate, or keep."""
    missing = length - len(rows)
    if missing > 0:
        # Padding with float zeros promotes integer rows to float.
        return np.append(np.array(rows), np.zeros((missing, width)), axis=0)
    if missing < 0:
        return np.array(rows[:missing])
    return np.array(rows)


def get_data(max_seq_len, batch_size, feature_len, truth_df, hits_df, simple = False, full_tracks = False, ideal = False):
    """Yield padded inputs and next-hit labels for each batch of tracks.

    Args:
        max_seq_len: rows per track after padding/truncation. It is forced to
            10 when full_tracks or ideal is set.
        batch_size: tracks per batch.
        feature_len: width of the zero rows used to pad inputs. Presumably it
            must equal 5 (the features built below); TODO confirm.
        truth_df, hits_df: one event's truth and hits tables.
        simple: use simple_batch_iter (with full_tracks/ideal) instead of
            batch_iter.
        full_tracks, ideal: see simple_batch_iter; ideal forces full_tracks.

    Yields:
        (inputs, labels, inputs_lv, labels_lv), one row per track_list:
          inputs:    per hit r, phi, z, volume_id, layer_id.
          labels:    r, phi, z of the next hit; the last row is zeros.
          inputs_lv: per hit volume_id, layer_id.
          labels_lv: volume_id, layer_id of the next hit; the last row is zeros.
    """
    if ideal:
        full_tracks = True
    if full_tracks:
        max_seq_len = 10  # one hit per cylinder layer
    if simple:
        all_data = list(simple_batch_iter(hits_df, truth_df, batch_size, full_tracks, ideal))
    else:
        all_data = list(batch_iter(truth_df, batch_size))

    # Map hit_id -> (x, y, z, volume_id, layer_id), built once to avoid
    # scanning the whole hits table for each coordinate of each hit.
    # tolist() yields plain Python scalars, like Series.item().
    hit_info = dict(zip(
        hits_df['hit_id'].tolist(),
        zip(hits_df['x'].tolist(), hits_df['y'].tolist(), hits_df['z'].tolist(),
            hits_df['volume_id'].tolist(), hits_df['layer_id'].tolist()),
    ))

    for result in all_data:
        batch, batch_lv, labels_tensor, labels_tensor_lv = [], [], [], []
        for track_list in result:
            for hit_ids in track_list:
                track, track_lb, lv_tensor = [], [], []
                for elem in hit_ids:
                    x, y, z, volume_id, layer_id = hit_info[elem]
                    r, phi, z = cartesian_to_3d_polar(x, y, z)
                    track.append([r, phi, z, volume_id, layer_id])
                    track_lb.append([r, phi, z])
                    lv_tensor.append([volume_id, layer_id])
                padded_track_data = _fit_to_length(track, max_seq_len, feature_len)
                padded_track_data_lb = _fit_to_length(track_lb, max_seq_len, 3)
                padded_track_data_lv = _fit_to_length(lv_tensor, max_seq_len, 2)

            # NOTE(review): this runs once per track_list, after the loop above,
            # so only the last hit-id sequence in each track_list is used. With
            # simple_batch_iter's reshape(-1, batch_size, 1) each track_list
            # presumably holds one sequence; confirm the intent for batch_iter.
            # Labels are the inputs shifted up one hit, with a zero final row.
            labels_tensor.append(np.append(padded_track_data_lb[1:], np.zeros((1, 3)), axis=0))
            labels_tensor_lv.append(np.append(padded_track_data_lv[1:], np.zeros((1, 2)), axis=0))
            batch.append(padded_track_data)
            batch_lv.append(padded_track_data_lv)

        yield (np.array(batch), np.array(labels_tensor),
               np.array(batch_lv), np.array(labels_tensor_lv))
        
def next_batch(max_seq_len, batch_size, feature_len, path='data/train_sample/'):
    """Yield get_data batches for every event found under ``path``.

    Each event's hits and truth tables come from trackml's load_dataset with
    parts=['hits', 'truth']. They are passed to get_data with its default
    (non-simple) options.
    """
    for _event_id, hit_df, truth_df in load_dataset(path, parts=['hits', 'truth']):
        yield from get_data(max_seq_len, batch_size, feature_len, truth_df, hit_df)
        

In [14]:
def load_data_single_event(event_number, path='data/train_sample/'):
    """Load the four TrackML tables for one event.

    The event number is zero-padded to nine digits (1000 -> 'event000001000').
    A fixed 'event00000' prefix only produced correct names for four-digit
    event numbers.

    Returns:
        (hits, cells, particles, truth) as returned by trackml's load_event.
    """
    import os

    event_id = 'event' + str(event_number).zfill(9)
    hits, cells, particles, truth = load_event(os.path.join(path, event_id))
    return hits, cells, particles, truth
# Load event 1000 from data/train_sample/; `hits` and `truth_df` are used by the cells below.
hits, cells, particles, truth_df = load_data_single_event(1000)

In [None]:
# This cell was not executed (In [None]). simple_batch_iter is a generator
# function, so this only creates the iterator; no tracks are built until next(itr).
itr = simple_batch_iter(hits,truth_df,5, full_tracks=True)

In [16]:
# max_seq_len=10, batch_size=1, feature_len=5; keep only ideal tracks (ideal also
# forces full_tracks). get_data is a generator, so no work happens until next(a).
a = get_data(10,1,5,truth_df,hits, simple = True, full_tracks = True, ideal = True)

In [17]:
# First batch: inputs, next-hit labels, and the (volume_id, layer_id) versions of both.
data, label, data_lv, label_lv = next(a)

In [18]:
# Input batch; per hit: r, phi, z, volume_id, layer_id (see output below).
data

array([[[  32.50954756,   -2.60091572,   -3.52813005,    8.        ,
            2.        ],
        [  71.69335542,   -2.57546852,   -5.94687986,    8.        ,
            4.        ],
        [ 115.59768922,   -2.54950431,   -8.87187958,    8.        ,
            6.        ],
        [ 171.74125643,   -2.51454731,  -12.69690037,    8.        ,
            8.        ],
        [ 260.65754778,   -2.45791464,  -18.60000038,   13.        ,
            2.        ],
        [ 362.64783232,   -2.39263539,  -24.60000038,   13.        ,
            4.        ],
        [ 500.95353664,   -2.30318113,  -34.20000076,   13.        ,
            6.        ],
        [ 660.62577792,   -2.19581995,  -45.        ,   13.        ,
            8.        ],
        [ 816.22507131,   -2.08794826,  -54.40000153,   17.        ,
            2.        ],
        [1016.6482781 ,   -1.9347321 ,  -76.        ,   17.        ,
            4.        ]]])

In [19]:
# Labels: row i is the (r, phi, z) of input hit i+1; the last row is zero padding.
label

array([[[  71.69335542,   -2.57546852,   -5.94687986],
        [ 115.59768922,   -2.54950431,   -8.87187958],
        [ 171.74125643,   -2.51454731,  -12.69690037],
        [ 260.65754778,   -2.45791464,  -18.60000038],
        [ 362.64783232,   -2.39263539,  -24.60000038],
        [ 500.95353664,   -2.30318113,  -34.20000076],
        [ 660.62577792,   -2.19581995,  -45.        ],
        [ 816.22507131,   -2.08794826,  -54.40000153],
        [1016.6482781 ,   -1.9347321 ,  -76.        ],
        [   0.        ,    0.        ,    0.        ]]])

In [20]:
# (volume_id, layer_id) of each input hit.
data_lv

array([[[ 8,  2],
        [ 8,  4],
        [ 8,  6],
        [ 8,  8],
        [13,  2],
        [13,  4],
        [13,  6],
        [13,  8],
        [17,  2],
        [17,  4]]])

In [21]:
# (volume_id, layer_id) of the next hit; the last row is zero padding, and the
# zero-row append makes this array float.
label_lv

array([[[ 8.,  4.],
        [ 8.,  6.],
        [ 8.,  8.],
        [13.,  2.],
        [13.,  4.],
        [13.,  6.],
        [13.,  8.],
        [17.,  2.],
        [17.,  4.],
        [ 0.,  0.]]])