In [20]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import KDTree

In [21]:
def merge_naive(pred_1, pred_2, cutoff=20):
    """
    naive cluster merging:
    iterate over hits; if a hit belongs to a larger cluster in pred_2, it is reassigned
    """
    if pred_1 is None:
        return pred_2
    c1, c2 = Counter(pred_1), Counter(pred_2)  # track id -> track size
    n1, n2 = np.vectorize(c1.__getitem__)(pred_1), np.vectorize(
        c2.__getitem__)(pred_2)  # hit id -> track size
    pred = pred_1.copy()
    idx = (n2 > n1) & (n2 < cutoff)
    pred[idx] = max(pred_1) + 1 + pred_2[idx]
    return label_encode(pred)

In [23]:
df=pd.read_csv("submission.csv")
df2=pd.read_csv("hough2.csv")
df=df[['event_id', 'hit_id', 'track_id']]
df2=df2[['event_id', 'hit_id', 'track_id']]

df['track_id']=merge(df['track_id'],df2['track_id'])
df=df[['event_id', 'hit_id', 'track_id']]

In [25]:
root="/home/alexanderliao/data/Kaggle/competitions/trackml-particle-identification/test/"

for i in tqdm(range(125)):
    if i==0:
        hits=pd.read_csv(root+"event"+str(i).zfill(9)+"-hits.csv")
    else:
        temp=pd.read_csv(root+"event"+str(i).zfill(9)+"-hits.csv")
        hits=pd.concat([hits,temp])

100%|██████████| 125/125 [01:12<00:00,  1.72it/s]


In [26]:
hits[['hit_id', 'x','y','z']].to_csv("feature.csv",index=False)

In [27]:
hits=hits[['hit_id','x','y','z']]

In [31]:
index = 0
for i in tqdm(range(125)):
    event = df.loc[df['event_id'] == i]
    temp = pd.merge(event,hits.iloc[index:index+event.shape[0]],on='hit_id')
    index=index+event.shape[0]
    #print(index)
    temp[['event_id','hit_id', 'track_id','x','y','z']].to_csv("./events/event"+str(i).zfill(3)+".csv",index=False)

100%|██████████| 125/125 [01:17<00:00,  1.61it/s]


In [32]:
index

13741466

In [36]:
def extend(df,limit=0.04, num_neighbours=18):

    #df = submission.merge(hits,  on=['hit_id'], how='left')
    df = df.assign(d = np.sqrt( df.x**2 + df.y**2 + df.z**2 ))
    df = df.assign(r = np.sqrt( df.x**2 + df.y**2))
    df = df.assign(arctan2 = np.arctan2(df.z, df.r))

    for angle in [x / 10.0 for x in range(-900, 900, 5)]:

        print ('\r %f'%angle, end='',flush=True)
        #df1 = df.loc[(df.arctan2>(angle-0.5)/180*np.pi) & (df.arctan2<(angle+0.5)/180*np.pi)]
        df1 = df.loc[(df.arctan2>(angle-1.5)/180*np.pi) & (df.arctan2<(angle+1.5)/180*np.pi)]

        min_num_neighbours = len(df1)
        if min_num_neighbours<3: continue

        hit_ids = df1.hit_id.values
        x,y,z = df1[['x', 'y', 'z']].values.T
        r  = (x**2 + y**2)**0.5
        r  = r/1000
        a  = np.arctan2(y,x)
        c = np.cos(a)
        s = np.sin(a)
        #tree = KDTree(np.column_stack([a,r]), metric='euclidean')
        tree = KDTree(np.column_stack([c, s, r]), metric='euclidean')


        track_ids = list(df1.track_id.unique())
        num_track_ids = len(track_ids)
        min_length=3

        for i in range(num_track_ids):
            p = track_ids[i]
            if p==0: continue

            idx = np.where(df1.track_id==p)[0]
            if len(idx)<min_length: continue

            if angle>0:
                idx = idx[np.argsort( z[idx])]
            else:
                idx = idx[np.argsort(-z[idx])]


            ## start and end points  ##
            idx0,idx1 = idx[0],idx[-1]
            a0 = a[idx0]
            a1 = a[idx1]
            r0 = r[idx0]
            r1 = r[idx1]
            c0 = c[idx0]
            c1 = c[idx1]
            s0 = s[idx0]
            s1 = s[idx1]

            da0 = a[idx[1]] - a[idx[0]]  #direction
            dr0 = r[idx[1]] - r[idx[0]]
            direction0 = np.arctan2(dr0,da0)

            da1 = a[idx[-1]] - a[idx[-2]]
            dr1 = r[idx[-1]] - r[idx[-2]]
            direction1 = np.arctan2(dr1,da1)



            ## extend start point
            ns = tree.query([[c0, s0, r0]], k=min(num_neighbours, min_num_neighbours), return_distance=False)
            ns = np.concatenate(ns)

            direction = np.arctan2(r0 - r[ns], a0 - a[ns])
            diff = 1 - np.cos(direction - direction0)
            ns = ns[(r0 - r[ns] > 0.01) & (diff < (1 - np.cos(limit)))]
            for n in ns: df.loc[df.hit_id == hit_ids[n], 'track_id'] = p

            ## extend end point
            ns = tree.query([[c1, s1, r1]], k=min(num_neighbours, min_num_neighbours), return_distance=False)
            ns = np.concatenate(ns)

            direction = np.arctan2(r[ns] - r1, a[ns] - a1)
            diff = 1 - np.cos(direction - direction1)
            ns = ns[(r[ns] - r1 > 0.01) & (diff < (1 - np.cos(limit)))]
            for n in ns:  df.loc[df.hit_id == hit_ids[n], 'track_id'] = p

    #print ('\r')
    df = df[['event_id', 'hit_id', 'track_id','x','y','z']]
    return df

In [None]:
for i in tqdm(range(125)):
    result=pd.read_csv("./events/event"+str(i).zfill(3)+".csv")
    #print("before:{}".format(result.shape[0]))
    for j in range(8):
        result=extend(result,limit=0.05, num_neighbours=18)
    #print("after:{}".format(result.shape[0]))
    result[['event_id', 'hit_id', 'track_id']].to_csv("./enhanced/enhanced"+str(i).zfill(3)+".csv",index=False)

  0%|          | 0/125 [00:00<?, ?it/s]

 89.5000000

  1%|          | 1/125 [07:43<15:57:02, 463.09s/it]

 89.5000000

  2%|▏         | 2/125 [15:29<15:53:10, 464.97s/it]

 89.5000000

  2%|▏         | 3/125 [22:18<15:06:55, 446.03s/it]

 89.5000000

  3%|▎         | 4/125 [29:09<14:42:11, 437.45s/it]

 89.5000000

  4%|▍         | 5/125 [36:48<14:43:17, 441.65s/it]

 89.5000000

  5%|▍         | 6/125 [43:42<14:26:43, 437.00s/it]

 89.5000000

  6%|▌         | 7/125 [50:40<14:14:08, 434.31s/it]

 89.5000000

  6%|▋         | 8/125 [58:01<14:08:37, 435.19s/it]

 89.5000000

  7%|▋         | 9/125 [1:05:19<14:01:59, 435.51s/it]

 89.5000000

  8%|▊         | 10/125 [1:12:35<13:54:45, 435.52s/it]

 89.5000000

  9%|▉         | 11/125 [1:19:57<13:48:36, 436.11s/it]

 89.5000000

 10%|▉         | 12/125 [1:26:51<13:37:54, 434.29s/it]

 89.5000000

 10%|█         | 13/125 [1:33:33<13:25:58, 431.77s/it]

 89.5000000

 11%|█         | 14/125 [1:41:23<13:23:54, 434.54s/it]

 89.5000000

 12%|█▏        | 15/125 [1:47:59<13:11:54, 431.95s/it]

 89.5000000

 13%|█▎        | 16/125 [1:54:00<12:56:44, 427.56s/it]

 89.5000000

 14%|█▎        | 17/125 [2:00:09<12:43:18, 424.06s/it]

 89.5000000

 14%|█▍        | 18/125 [2:08:30<12:43:55, 428.37s/it]

 89.5000000

 15%|█▌        | 19/125 [2:14:43<12:31:40, 425.47s/it]

 89.5000000

 16%|█▌        | 20/125 [2:21:09<12:21:04, 423.47s/it]

 89.5000000

 17%|█▋        | 21/125 [2:25:58<12:02:56, 417.09s/it]

 89.5000000

 18%|█▊        | 22/125 [2:32:04<11:51:57, 414.73s/it]

 89.5000000

 18%|█▊        | 23/125 [2:38:55<11:44:47, 414.58s/it]

 89.5000000

 19%|█▉        | 24/125 [2:45:43<11:37:26, 414.33s/it]

 89.5000000

 20%|██        | 25/125 [2:52:26<11:29:45, 413.86s/it]

 89.5000000

 21%|██        | 26/125 [2:58:49<11:20:52, 412.66s/it]

 89.5000000

 22%|██▏       | 27/125 [3:04:27<11:09:30, 409.90s/it]

 89.5000000

 22%|██▏       | 28/125 [3:10:16<10:59:09, 407.72s/it]

 89.5000000

 23%|██▎       | 29/125 [3:16:34<10:50:43, 406.70s/it]

 89.5000000

 24%|██▍       | 30/125 [3:21:43<10:38:46, 403.43s/it]

 89.5000000

 25%|██▍       | 31/125 [3:28:12<10:31:20, 402.98s/it]

 89.5000000

 26%|██▌       | 32/125 [3:34:08<10:22:21, 401.52s/it]

 89.5000000

 26%|██▋       | 33/125 [3:39:02<10:10:38, 398.25s/it]

 89.5000000

 27%|██▋       | 34/125 [3:44:05<9:59:46, 395.46s/it] 

 89.5000000

 28%|██▊       | 35/125 [3:51:15<9:54:40, 396.45s/it]

 89.5000000

 29%|██▉       | 36/125 [3:57:19<9:46:43, 395.55s/it]

 89.5000000

 30%|██▉       | 37/125 [4:03:43<9:39:40, 395.23s/it]

 89.5000000

 30%|███       | 38/125 [4:10:46<9:34:08, 395.96s/it]

 89.5000000

 31%|███       | 39/125 [4:16:54<9:26:30, 395.24s/it]

 89.5000000

 32%|███▏      | 40/125 [4:23:48<9:20:34, 395.70s/it]

 89.5000000

 33%|███▎      | 41/125 [4:30:38<9:14:28, 396.05s/it]

 89.5000000

 34%|███▎      | 42/125 [4:37:03<9:07:31, 395.80s/it]

 89.5000000

 34%|███▍      | 43/125 [4:43:22<9:00:22, 395.40s/it]

 89.5000000

 35%|███▌      | 44/125 [4:49:51<8:53:35, 395.26s/it]

 89.5000000

 36%|███▌      | 45/125 [4:57:43<8:49:16, 396.96s/it]

 89.5000000

 37%|███▋      | 46/125 [5:03:30<8:41:15, 395.89s/it]

 89.5000000

 38%|███▊      | 47/125 [5:09:07<8:33:00, 394.62s/it]

 89.5000000

 38%|███▊      | 48/125 [5:14:22<8:24:18, 392.97s/it]

 89.5000000

 39%|███▉      | 49/125 [5:20:25<8:16:58, 392.35s/it]

 89.5000000

 40%|████      | 50/125 [5:27:14<8:10:52, 392.69s/it]

 89.5000000

 41%|████      | 51/125 [5:34:01<8:04:39, 392.96s/it]

 89.5000000

 42%|████▏     | 52/125 [5:40:27<7:57:56, 392.83s/it]

 89.5000000

 42%|████▏     | 53/125 [5:47:01<7:51:25, 392.86s/it]

 89.5000000

 43%|████▎     | 54/125 [5:53:15<7:44:27, 392.50s/it]

 89.5000000

 44%|████▍     | 55/125 [6:00:02<7:38:14, 392.78s/it]

 89.5000000

 45%|████▍     | 56/125 [6:05:44<7:30:39, 391.87s/it]

 89.5000000

 46%|████▌     | 57/125 [6:12:15<7:24:06, 391.85s/it]

 89.5000000

 46%|████▋     | 58/125 [6:19:20<7:18:11, 392.42s/it]

 89.5000000

 47%|████▋     | 59/125 [6:26:17<7:12:07, 392.85s/it]

 89.5000000

 48%|████▊     | 60/125 [6:32:44<7:05:28, 392.75s/it]

 89.5000000

 49%|████▉     | 61/125 [6:38:47<6:58:23, 392.25s/it]

 89.5000000

 50%|████▉     | 62/125 [6:44:52<6:51:24, 391.82s/it]

 89.5000000

 50%|█████     | 63/125 [6:51:21<6:44:49, 391.77s/it]

 89.5000000

 51%|█████     | 64/125 [6:58:08<6:38:32, 392.01s/it]

 89.5000000

 52%|█████▏    | 65/125 [7:04:48<6:32:08, 392.13s/it]

 89.5000000

 53%|█████▎    | 66/125 [7:10:17<6:24:39, 391.17s/it]

 89.5000000

 54%|█████▎    | 67/125 [7:16:40<6:18:00, 391.04s/it]

 89.5000000

 54%|█████▍    | 68/125 [7:23:49<6:12:01, 391.60s/it]

 89.5000000

 55%|█████▌    | 69/125 [7:30:08<6:05:19, 391.43s/it]

 89.5000000

 56%|█████▌    | 70/125 [7:36:00<5:58:17, 390.87s/it]

 89.5000000

 57%|█████▋    | 71/125 [7:42:00<5:51:22, 390.42s/it]

 89.5000000

 58%|█████▊    | 72/125 [7:48:53<5:45:09, 390.75s/it]

 89.5000000

 58%|█████▊    | 73/125 [7:55:05<5:38:24, 390.48s/it]

 89.5000000

 59%|█████▉    | 74/125 [8:01:30<5:31:50, 390.41s/it]

 89.5000000

 60%|██████    | 75/125 [8:08:05<5:25:23, 390.47s/it]

 89.5000000

 61%|██████    | 76/125 [8:13:50<5:18:23, 389.88s/it]

 89.5000000

 62%|██████▏   | 77/125 [8:20:12<5:11:49, 389.78s/it]

 -39.000000

In [40]:
#del temp
#del final
for i in tqdm(range(125)):
    temp=pd.read_csv("./enhanced/enhanced"+str(i).zfill(3)+".csv")
    temp=temp[['event_id', 'hit_id', 'track_id']]
    print(temp.shape[0])
    print(i)
    if i==0:
        final=temp
    else:
        final=pd.concat([final,temp])
        #print(final.shape[0])

  3%|▎         | 4/125 [00:00<00:07, 16.20it/s]

119016
0
122319
1
105772
2
109438
3


  7%|▋         | 9/125 [00:00<00:06, 18.99it/s]

119419
4
107928
5
107605
6
116719
7
115088
8


 10%|█         | 13/125 [00:00<00:05, 18.83it/s]

113699
9
117029
10
107033
11
104052
12


 12%|█▏        | 15/125 [00:00<00:06, 18.19it/s]

125834
13
102454
14
90141
15
97005
16


 16%|█▌        | 20/125 [00:01<00:05, 18.06it/s]

131777
17
98028
18
104746
19
80539
20


 19%|█▉        | 24/125 [00:01<00:05, 17.20it/s]

106517
21
119762
22
116472
23


 21%|██        | 26/125 [00:01<00:05, 16.55it/s]

116987
24
109625
25
95306
26


 22%|██▏       | 28/125 [00:01<00:06, 15.74it/s]

100485
27
107973
28
88746
29


 26%|██▌       | 32/125 [00:02<00:06, 14.69it/s]

112836
30
102966
31


 27%|██▋       | 34/125 [00:02<00:06, 14.21it/s]

83605
32
85798
33


 29%|██▉       | 36/125 [00:02<00:06, 13.77it/s]

124264
34
104358
35


 30%|███       | 38/125 [00:02<00:06, 13.47it/s]

110932
36
123111
37
106542
38
121116
39


 33%|███▎      | 41/125 [00:03<00:06, 12.72it/s]

118201
40
110271
41


 34%|███▍      | 43/125 [00:03<00:06, 12.10it/s]

109476
42
110596
43


 36%|███▌      | 45/125 [00:03<00:06, 11.95it/s]

138107
44
99048
45


 38%|███▊      | 47/125 [00:04<00:06, 11.62it/s]

95210
46
89037
47


 40%|████      | 50/125 [00:04<00:06, 11.11it/s]

104819
48
120085
49


 42%|████▏     | 52/125 [00:04<00:06, 10.89it/s]

118453
50
112472
51


 42%|████▏     | 53/125 [00:04<00:06, 10.81it/s]

113710
52
108741
53


 44%|████▍     | 55/125 [00:05<00:06, 10.59it/s]

118718
54
98611
55


 46%|████▌     | 57/125 [00:05<00:06, 10.28it/s]

113164
56
121400
57


 47%|████▋     | 59/125 [00:05<00:06, 10.04it/s]

120011
58
112732
59


 49%|████▉     | 61/125 [00:06<00:06,  9.88it/s]

102260
60
105668
61


 50%|█████     | 63/125 [00:06<00:06,  9.72it/s]

113210
62
118523
63


 52%|█████▏    | 65/125 [00:06<00:06,  9.55it/s]

114683
64
95282
65


 54%|█████▎    | 67/125 [00:07<00:06,  9.42it/s]

111204
66
124975
67


 55%|█████▌    | 69/125 [00:07<00:06,  9.28it/s]

108355
68
100970
69


 57%|█████▋    | 71/125 [00:07<00:05,  9.10it/s]

101706
70
116858
71


 58%|█████▊    | 73/125 [00:08<00:05,  8.99it/s]

106262
72
110843
73


 60%|██████    | 75/125 [00:08<00:05,  8.88it/s]

114323
74
98515
75


 62%|██████▏   | 77/125 [00:08<00:05,  8.78it/s]

103049
76
91801
77


 63%|██████▎   | 79/125 [00:09<00:05,  8.65it/s]

111263
78
92759
79


 65%|██████▍   | 81/125 [00:09<00:05,  8.56it/s]

128605
80
106909
81


 66%|██████▋   | 83/125 [00:09<00:04,  8.44it/s]

107450
82
124662
83


 68%|██████▊   | 85/125 [00:10<00:04,  8.32it/s]

96579
84
113056
85


 70%|██████▉   | 87/125 [00:10<00:04,  8.23it/s]

121306
86
120271
87


 71%|███████   | 89/125 [00:10<00:04,  8.13it/s]

117696
88
118666
89


 73%|███████▎  | 91/125 [00:11<00:04,  7.98it/s]

111010
90
103530
91


 74%|███████▍  | 93/125 [00:11<00:04,  7.89it/s]

99246
92
133648
93


 76%|███████▌  | 95/125 [00:12<00:03,  7.80it/s]

97598
94


 77%|███████▋  | 96/125 [00:12<00:03,  7.73it/s]

115003
95


 78%|███████▊  | 97/125 [00:12<00:03,  7.70it/s]

104648
96
112038
97


 79%|███████▉  | 99/125 [00:12<00:03,  7.62it/s]

115998
98


 80%|████████  | 100/125 [00:13<00:03,  7.58it/s]

115489
99


 81%|████████  | 101/125 [00:13<00:03,  7.51it/s]

129107
100


 82%|████████▏ | 102/125 [00:13<00:03,  7.45it/s]

127561
101
113312
102


 83%|████████▎ | 104/125 [00:14<00:02,  7.38it/s]

102112
103


 84%|████████▍ | 105/125 [00:14<00:02,  7.34it/s]

117401
104


 85%|████████▍ | 106/125 [00:14<00:02,  7.28it/s]

92018
105


 86%|████████▌ | 107/125 [00:14<00:02,  7.24it/s]

113413
106


 86%|████████▋ | 108/125 [00:15<00:02,  7.18it/s]

128091
107


 87%|████████▋ | 109/125 [00:15<00:02,  7.14it/s]

113817
108
116426
109


 89%|████████▉ | 111/125 [00:15<00:02,  7.00it/s]

113769
110


 90%|████████▉ | 112/125 [00:16<00:01,  6.96it/s]

117204
111


 90%|█████████ | 113/125 [00:16<00:01,  6.90it/s]

108254
112
126621
113


 92%|█████████▏| 115/125 [00:16<00:01,  6.80it/s]

98236
114


 93%|█████████▎| 116/125 [00:17<00:01,  6.77it/s]

99882
115


 94%|█████████▎| 117/125 [00:17<00:01,  6.74it/s]

118413
116
116556
117


 94%|█████████▍| 118/125 [00:17<00:01,  6.70it/s]

105372
118


 96%|█████████▌| 120/125 [00:18<00:00,  6.63it/s]

87086
119


 97%|█████████▋| 121/125 [00:18<00:00,  6.61it/s]

106090
120


 98%|█████████▊| 122/125 [00:18<00:00,  6.57it/s]

116194
121
92200
122


 99%|█████████▉| 124/125 [00:19<00:00,  6.49it/s]

102393
123


100%|██████████| 125/125 [00:19<00:00,  6.47it/s]

108117
124





In [42]:
final[['event_id', 'hit_id', 'track_id']]=final[['event_id', 'hit_id', 'track_id']].fillna(0).astype('int') 
final.to_csv('./enhanced.csv',index=False)


In [None]:
a=pd.read_csv("enhanced.csv")

In [None]:
a #sanity checks