In [1]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import KNeighborsClassifier
import time

In [2]:
# Columns handed to the classifier: the two coordinates, seven shifted weekday
# encodings (w0..w6), the calendar fields, accuracy, and 24 shifted hour
# encodings (h0..h23).
feature_list = [
    'x', 'y',
    'w0', 'w1', 'w2', 'w3', 'w4', 'w5', 'w6',
    'day', 'month', 'year', 'accuracy',
] + ['h' + str(hour_shift) for hour_shift in range(24)]

In [3]:
def prepare_data(df):
    """Build the weighted feature columns from a raw check-in frame.

    Each derived feature is multiplied by its entry in the module-level
    weight vector ``fw`` (indices 0-8 are read here), so that the distance
    metric used downstream sees every feature on its tuned scale.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``x``, ``y``, ``accuracy`` and ``time`` columns. ``time``
        is interpreted as a whole number of minutes after 2014-01-02 01:01.

    Returns
    -------
    pandas.DataFrame
        A new frame holding the original columns (``x``, ``y`` and
        ``accuracy`` rescaled), the added ``h0..h23``, ``w0..w6``, ``day``,
        ``month``, ``year`` and ``log_month`` columns, and no ``time``
        column. The frame passed in is left unmodified.
    """
    # Work on a copy so the caller's frame is not rescaled as a side effect.
    df = df.copy()

    df['x'] = df['x'].values * fw[0]
    df['y'] = df['y'].values * fw[1]

    # Build all timestamps in one vectorised step instead of one Python-level
    # np.timedelta64 per row. The unit is passed positionally, which is the
    # documented np.datetime64(value, unit) form.
    initial_date = np.datetime64('2014-01-02T01:01', 'm')
    minutes = df['time'].values.astype(np.int64).astype('timedelta64[m]')
    d_times = pd.DatetimeIndex(initial_date + minutes)

    # Fractional hour of day, shifted by 0..23 hours and wrapped at 24.
    hour_of_day = d_times.hour + d_times.minute / 60
    for i in range(24):
        df['h' + str(i)] = ((hour_of_day + i) % 24) * fw[2]

    # Day of week, shifted by 0..6 days and wrapped at 7.
    weekday = d_times.weekday
    for i in range(7):
        df['w' + str(i)] = ((weekday + i) % 7) * fw[3]

    # The weighted day of year is truncated to an integer.
    df['day'] = (d_times.dayofyear * fw[4]).astype(int)
    df['month'] = d_times.month * fw[5]
    df['year'] = (d_times.year - 2013) * fw[6]
    df['accuracy'] = np.log10(df['accuracy']) * fw[7]
    # Elapsed time in 30-day units, offset by 3 before taking the logarithm.
    df['log_month'] = np.log10(3 + df['time'] / (60 * 24 * 30)) * fw[8]

    return df.drop(['time'], axis=1)

In [4]:
def calculate_distance(distances):
    """Convert neighbour distances into vote weights.

    Passed as the ``weights`` callable of the KNN classifier: each distance
    is raised to the power stored in the module-level ``fw[10]``.
    """
    exponent = fw[10]
    return distances ** exponent

In [5]:
def process_one_cell(df_train, df_test, th):
    """Fit a weighted KNN on one cell and rank the three likeliest places.

    Parameters
    ----------
    df_train : pandas.DataFrame
        Prepared training rows; needs the ``feature_list`` columns and
        ``place_id``.
    df_test : pandas.DataFrame
        Prepared rows to predict; needs the ``feature_list`` columns and
        ``row_id``.
    th : int
        Minimum number of occurrences a ``place_id`` must have in
        ``df_train`` for its rows to be kept.

    Returns
    -------
    pandas.DataFrame
        Indexed like ``df_test`` with columns ``row_id``, ``p1``, ``p2`` and
        ``p3`` (place ids ordered from most to least probable).

    Notes
    -----
    The filtered training data must contain at least three distinct
    ``place_id`` values, otherwise there is no third-ranked class to report.
    """
    # Keep only rows whose place occurs at least `th` times in this cell.
    place_counts = df_train.place_id.value_counts()
    mask = (place_counts[df_train.place_id.values] >= th).values
    df_train = df_train.loc[mask]

    n_train = len(df_train.index)
    # Neighbourhood size scales with the square root of the training size
    # (fw[9] is the tunable factor). For small cells the floor reaches 0,
    # which the classifier rejects, so clamp it to the valid range [1, n_train].
    best_k = int(np.floor(np.sqrt(n_train / 4 * fw[9]) / 5))
    best_k = max(1, min(best_k, n_train))

    # Applying the classifier
    clf = KNeighborsClassifier(n_neighbors=best_k, weights=calculate_distance,
                               metric='manhattan')
    clf.fit(df_train[feature_list], df_train.place_id)
    predictions = clf.predict_proba(df_test[feature_list])

    # Column indices of the three highest probabilities, best first.
    result_index = np.argsort(predictions, axis=1)[:, ::-1][:, :3]
    # Look the labels up once and assign plain 1-D columns.
    top_places = clf.classes_[result_index]

    result = pd.DataFrame(df_test.row_id)
    result['p1'] = top_places[:, 0]
    result['p2'] = top_places[:, 1]
    result['p3'] = top_places[:, 2]

    return result

In [6]:
def run_prediction(df_train, df_test, th=1):
    """Prepare both frames and predict the top-3 places for ``df_test``.

    Parameters
    ----------
    df_train : pandas.DataFrame
        Raw training rows for one cell (passed through ``prepare_data``).
    df_test : pandas.DataFrame
        Raw rows to predict for the same cell (passed through
        ``prepare_data``).
    th : int, optional
        Minimum occurrences of a place in the training rows, forwarded to
        ``process_one_cell``. Defaults to 1, the value previously hard-coded.

    Returns
    -------
    pandas.DataFrame
        The frame produced by ``process_one_cell``, sorted by index.
    """
    df_train = prepare_data(df_train)
    df_test = prepare_data(df_test)

    # No p1/p2/p3 placeholder columns are added to df_test:
    # process_one_cell builds its result from row_id alone.
    prediction_result = process_one_cell(df_train, df_test, th)
    return prediction_result.sort_index()

In [7]:
def run_one_cell(n_cell_x, n_cell_y, x_index, y_index, x_length, y_length, df):
    """Score the model on one grid cell using a time-based hold-out.

    Rows with ``time <= 786239 * 0.875`` inside the cell (widened by 0.1 on
    every side) form the training set; later rows strictly inside the cell
    form the validation set. Predictions come from ``run_prediction``.

    Returns the mean per-row score, where a row earns 1 if ``p1`` is the
    true ``place_id``, 0.5 if ``p2`` is, and 0.33 if ``p3`` is.
    """
    x_lo, x_hi = x_index * x_length, (x_index + 1) * x_length
    y_lo, y_hi = y_index * y_length, (y_index + 1) * y_length

    # The last row/column of cells is stretched by 0.1 so that points lying
    # exactly on the outer edge are included.
    if x_index + 1 == n_cell_x:
        x_hi += 0.1
    if y_index + 1 == n_cell_y:
        y_hi += 0.1

    split_time = 786239 * 0.875

    in_train_window = (
        (df.time <= split_time)
        & (df.x >= x_lo - 0.1)
        & (df.x < x_hi + 0.1)
        & (df.y >= y_lo - 0.1)
        & (df.y < y_hi + 0.1)
    )
    in_validation_window = (
        (df.time > split_time)
        & (df.x >= x_lo)
        & (df.x < x_hi)
        & (df.y >= y_lo)
        & (df.y < y_hi)
    )
    df_train_cell = df[in_train_window].copy()
    df_validation_cell = df[in_validation_window].copy()

    prediction_result = run_prediction(df_train_cell, df_validation_cell)

    # Calculate score
    prediction_result.sort_index(inplace=True)
    truth = df_validation_cell.place_id
    hits = (prediction_result.p1 == truth) * 1
    # NOTE(review): 0.33 looks like an approximation of 1/3 — confirm whether
    # the exact value is wanted before changing it (thresholds depend on it).
    for column, weight in (('p2', 0.5), ('p3', 0.33)):
        hits = hits + (prediction_result[column] == truth) * weight
    prediction_result['score'] = hits

    return prediction_result.score.mean()

In [8]:
def run_validation(train_path='../../train.csv', score_threshold=0.5484,
                   ratios=(-0.1, -0.05, 0.05, 0.1)):
    """Validate every grid cell and tune the feature weights of weak cells.

    The training file is split into a 10 x 20 grid. Each cell is scored with
    the base weights via ``run_one_cell``. When the score is below
    ``score_threshold``, every weight is perturbed in turn by each of
    ``ratios`` (all other weights held at their base value), and the best
    perturbation is kept for that weight if it beats the base score.

    Parameters
    ----------
    train_path : str, optional
        CSV with ``row_id``, ``x``, ``y``, ``accuracy``, ``time`` and
        ``place_id`` columns.
    score_threshold : float, optional
        Cells scoring below this value get their weights tuned.
    ratios : iterable of float, optional
        Relative perturbations tried for each weight.

    Returns
    -------
    list of list
        One weight vector per cell, ordered by ``x_index`` then ``y_index``.
        The same list is also stored in the module-level ``fw_final``.

    Notes
    -----
    The module-level ``fw`` is rebound repeatedly because the feature and
    distance helpers read their weights from it. The reported "Final" value
    averages the base-weight scores; the combination of adjusted weights is
    not re-scored.
    """
    global fw_final
    global fw

    df = pd.read_csv(train_path,
                     usecols=['row_id', 'x', 'y', 'accuracy', 'time', 'place_id'])

    n_cell_x = 10
    n_cell_y = 20
    x_length = 10 / n_cell_x
    y_length = 10 / n_cell_y
    total_score = 0
    score_count = 0
    base_fw = [400, 1000, 1/10.5, 1/2.0, 1./22., 2, 9, 23, 4.5, 0.6, -2]

    # Start from a fresh list so the function neither depends on a global
    # defined in another cell nor accumulates entries when it is re-run.
    fw_final = []

    for x_index in range(n_cell_x):
        start_time = time.time()
        for y_index in range(n_cell_y):
            fw = base_fw[:]
            score = run_one_cell(n_cell_x, n_cell_y, x_index, y_index,
                                 x_length, y_length, df)

            if score < score_threshold:
                print('adjusting weight on ', x_index, y_index, ' with score ', score, flush=True)

                adjust_fw = base_fw[:]
                for i in range(len(base_fw)):
                    max_score = 0
                    max_score_weight = 0
                    for ratio in ratios:
                        # Perturb only weight i; all others stay at base.
                        fw = base_fw[:]
                        fw[i] = fw[i] * (1 + ratio)
                        temp_score = run_one_cell(n_cell_x, n_cell_y, x_index, y_index,
                                                  x_length, y_length, df)
                        if temp_score > max_score:
                            max_score = temp_score
                            max_score_weight = fw[i]

                    if max_score > score:
                        print('Found new weight imporved the socre to ', max_score, 'with weight', max_score_weight, 'on feature ', i, flush=True)
                        adjust_fw[i] = max_score_weight

                fw_final.append(adjust_fw)
            else:
                fw_final.append(fw)
            print("Cell is done", x_index, y_index, flush=True)
            total_score += score
            score_count += 1

        print("Elapsed time overall: %s seconds" % (time.time() - start_time), x_index, flush = True)
    print("Final:", total_score/score_count, flush=True)
    return fw_final


In [9]:
# Module-level state shared with the helpers: `fw` holds the weight vector
# currently in use, `fw_final` collects one weight vector per grid cell.
fw, fw_final = [], []
fw_final = run_validation()
print(fw_final, flush=True)

Cell is done 0 0
Cell is done 0 1
Cell is done 0 2
Cell is done 0 3
adjusting weight on  0 4  with score  0.5451112028058555
Found new weight imporved the socre to  0.545281617495354 with weight 360.0 on feature  0
Found new weight imporved the socre to  0.5453261811429713 with weight 0.05 on feature  4
Found new weight imporved the socre to  0.5453340210439408 with weight 20.7 on feature  7
Found new weight imporved the socre to  0.5452267381885664 with weight 0.57 on feature  9
Found new weight imporved the socre to  0.5451235815968603 with weight -2.1 on feature  10
Cell is done 0 4
Cell is done 0 5
Cell is done 0 6
Cell is done 0 7
Cell is done 0 8
Cell is done 0 9
Cell is done 0 10
Cell is done 0 11
Cell is done 0 12
Cell is done 0 13
Cell is done 0 14
Cell is done 0 15
Cell is done 0 16
Cell is done 0 17
Cell is done 0 18
Cell is done 0 19
Elapsed time overall: 1135.6765894889832 seconds 0
Cell is done 1 0
Cell is done 1 1
adjusting weight on  1 2  with score  0.5409229194771319


In [10]:
fw_final = [[400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [360.0, 1000, 0.09523809523809523, 0.5, 0.05, 2, 9, 20.7, 4.5, 0.57, -2.1],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 1100.0, 0.09047619047619046, 0.45, 0.04090909090909091, 1.8, 8.549999999999999, 24.150000000000002, 4.5, 0.66, -2.2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [360.0, 900.0, 0.1, 0.45, 0.04772727272727273, 1.8, 9.450000000000001, 21.849999999999998, 4.5, 0.54, -2.2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [420.0, 900.0, 0.09047619047619046, 0.45, 0.05, 2.1, 8.1, 21.849999999999998, 4.5, 0.66, -1.9],\
 [440.00000000000006, 900.0, 0.10476190476190476, 0.55, 0.04772727272727273, 1.9, 9, 20.7, 4.5, 0.63, -2.2],\
 [440.00000000000006, 900.0, 0.09047619047619046, 0.525, 0.05, 1.9, 9, 24.150000000000002, 4.5, 0.54, -2.1],\
 [400, 950.0, 0.1, 0.5, 0.04090909090909091, 1.8, 9, 20.7, 4.5, 0.63, -2.1],\
 [360.0, 900.0, 0.09523809523809523, 0.55, 0.045454545454545456, 2.1, 9.9, 25.3, 4.5, 0.66, -2.2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.45, 0.045454545454545456, 1.9, 9.9, 25.3, 4.5, 0.6, -2],\
 [380.0, 1000, 0.09047619047619046, 0.525, 0.045454545454545456, 1.8, 9.9, 25.3, 4.5, 0.66, -1.9],\
 [360.0, 900.0, 0.1, 0.55, 0.04772727272727273, 2.2, 8.549999999999999, 20.7, 4.5, 0.57, -2.1],\
 [420.0, 1050.0, 0.1, 0.45, 0.05, 1.8, 9.9, 20.7, 4.5, 0.63, -2.1],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09047619047619046, 0.525, 0.04772727272727273, 1.9, 9, 25.3, 4.5, 0.57, -2],\
 [400, 900.0, 0.09523809523809523, 0.525, 0.05, 2, 9.9, 23, 4.5, 0.6, -2.1],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 900.0, 0.09523809523809523, 0.525, 0.045454545454545456, 2.1, 9.450000000000001, 25.3, 4.5, 0.63, -1.8],\
 [420.0, 1050.0, 0.08571428571428572, 0.475, 0.04318181818181818, 1.8, 9.9, 20.7, 4.5, 0.54, -2.2],\
 [440.00000000000006, 1050.0, 0.09523809523809523, 0.475, 0.04772727272727273, 2.1, 8.1, 21.849999999999998, 4.5, 0.63, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 900.0, 0.08571428571428572, 0.45, 0.04318181818181818, 1.9, 9.9, 20.7, 4.5, 0.54, -2.2],\
 [380.0, 950.0, 0.09523809523809523, 0.55, 0.04318181818181818, 1.8, 9, 24.150000000000002, 4.5, 0.57, -2.1],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [380.0, 900.0, 0.1, 0.45, 0.05, 2.2, 8.549999999999999, 20.7, 4.5, 0.54, -2.2],\
 [400, 1050.0, 0.08571428571428572, 0.525, 0.04318181818181818, 2.1, 8.1, 25.3, 4.5, 0.6, -2.1],\
 [440.00000000000006, 900.0, 0.08571428571428572, 0.475, 0.04090909090909091, 1.8, 9.450000000000001, 25.3, 4.5, 0.54, -2.2],\
 [440.00000000000006, 900.0, 0.09047619047619046, 0.475, 0.04772727272727273, 2.1, 8.549999999999999, 23, 4.5, 0.6, -2.1],\
 [380.0, 900.0, 0.09047619047619046, 0.475, 0.04772727272727273, 1.8, 9.450000000000001, 25.3, 4.5, 0.6, -2.2],\
 [360.0, 900.0, 0.1, 0.45, 0.05, 2.2, 8.549999999999999, 24.150000000000002, 4.5, 0.57, -2.2],\
 [420.0, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 1.9, 9.450000000000001, 23, 4.5, 0.6, -2.2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 950.0, 0.09047619047619046, 0.5, 0.04090909090909091, 1.9, 8.1, 25.3, 4.5, 0.54, -2.1],\
 [440.00000000000006, 950.0, 0.1, 0.45, 0.04090909090909091, 1.8, 9.9, 20.7, 4.5, 0.57, -2.1],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [380.0, 950.0, 0.1, 0.5, 0.045454545454545456, 2.1, 8.549999999999999, 23, 4.5, 0.66, -2.2],\
 [360.0, 950.0, 0.09523809523809523, 0.55, 0.045454545454545456, 1.9, 9.450000000000001, 25.3, 4.5, 0.63, -1.9],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [420.0, 900.0, 0.09523809523809523, 0.55, 0.04090909090909091, 2, 9.9, 25.3, 4.5, 0.63, -2.1],\
 [380.0, 900.0, 0.10476190476190476, 0.525, 0.04772727272727273, 2.1, 9.450000000000001, 20.7, 4.5, 0.63, -2.2],\
 [420.0, 950.0, 0.09047619047619046, 0.45, 0.04772727272727273, 1.8, 9.9, 24.150000000000002, 4.5, 0.6, -2.1],\
 [360.0, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -1.9],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [380.0, 1050.0, 0.09047619047619046, 0.475, 0.04090909090909091, 1.9, 9.450000000000001, 24.150000000000002, 4.5, 0.57, -1.8],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [420.0, 900.0, 0.09047619047619046, 0.475, 0.04318181818181818, 2.2, 9.450000000000001, 23, 4.5, 0.66, -2.1],\
 [420.0, 900.0, 0.08571428571428572, 0.45, 0.045454545454545456, 1.8, 9, 25.3, 4.5, 0.66, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [420.0, 1100.0, 0.09047619047619046, 0.475, 0.04318181818181818, 1.8, 9.9, 20.7, 4.5, 0.66, -2.2],\
 [400, 900.0, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [360.0, 900.0, 0.08571428571428572, 0.45, 0.04090909090909091, 1.8, 9.9, 25.3, 4.5, 0.63, -1.8],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 1050.0, 0.08571428571428572, 0.55, 0.04772727272727273, 2.2, 9.9, 25.3, 4.5, 0.66, -1.8],\
 [420.0, 1050.0, 0.09523809523809523, 0.45, 0.05, 2, 8.549999999999999, 20.7, 4.5, 0.57, -1.9],\
 [440.00000000000006, 950.0, 0.08571428571428572, 0.5, 0.05, 1.8, 9.9, 24.150000000000002, 4.5, 0.54, -2.1],\
 [400, 1000, 0.09523809523809523, 0.5, 0.04090909090909091, 1.9, 9.9, 21.849999999999998, 4.5, 0.57, -2],\
 [420.0, 1050.0, 0.10476190476190476, 0.55, 0.04318181818181818, 1.8, 9.9, 20.7, 4.5, 0.54, -2.1],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2.2],\
 [420.0, 1050.0, 0.09047619047619046, 0.45, 0.04090909090909091, 1.8, 9.9, 25.3, 4.5, 0.66, -1.8],\
 [380.0, 950.0, 0.08571428571428572, 0.45, 0.04772727272727273, 2.2, 9.9, 21.849999999999998, 4.5, 0.54, -2.2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [420.0, 1000, 0.09523809523809523, 0.475, 0.04318181818181818, 2, 8.549999999999999, 25.3, 4.5, 0.6, -2.1],\
 [360.0, 950.0, 0.1, 0.55, 0.04772727272727273, 1.8, 9.450000000000001, 21.849999999999998, 4.5, 0.54, -1.9],\
 [400, 1000, 0.09523809523809523, 0.45, 0.04318181818181818, 1.9, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 1100.0, 0.09047619047619046, 0.475, 0.04090909090909091, 1.8, 9.9, 25.3, 4.5, 0.66, -1.8],\
 [400, 1000, 0.09523809523809523, 0.525, 0.045454545454545456, 1.8, 9.9, 24.150000000000002, 4.5, 0.6, -2.2],\
 [440.00000000000006, 950.0, 0.09047619047619046, 0.475, 0.04772727272727273, 1.9, 9.9, 24.150000000000002, 4.5, 0.66, -2.2],\
 [440.00000000000006, 1050.0, 0.08571428571428572, 0.525, 0.04090909090909091, 2.1, 9.9, 24.150000000000002, 4.5, 0.63, -1.8],\
 [360.0, 950.0, 0.1, 0.475, 0.045454545454545456, 1.8, 9, 25.3, 4.5, 0.66, -2.2],\
 [360.0, 1000, 0.08571428571428572, 0.55, 0.04318181818181818, 1.9, 9.9, 20.7, 4.5, 0.54, -2.2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.10476190476190476, 0.5, 0.04318181818181818, 2, 9, 24.150000000000002, 4.5, 0.6, -1.9],\
 [400, 1000, 0.08571428571428572, 0.525, 0.045454545454545456, 2.2, 9.450000000000001, 23, 4.5, 0.6, -2],\
 [400, 950.0, 0.08571428571428572, 0.55, 0.04090909090909091, 1.8, 9.9, 21.849999999999998, 4.5, 0.63, -2],\
 [360.0, 1000, 0.10476190476190476, 0.5, 0.04772727272727273, 1.8, 9, 21.849999999999998, 4.5, 0.6, -2.1],\
 [420.0, 1000, 0.08571428571428572, 0.525, 0.04772727272727273, 1.8, 9, 23, 4.5, 0.6, -2.2],\
 [440.00000000000006, 950.0, 0.09047619047619046, 0.475, 0.045454545454545456, 2.2, 9.9, 20.7, 4.5, 0.57, -2.2],\
 [380.0, 900.0, 0.09047619047619046, 0.55, 0.04772727272727273, 1.9, 9.9, 20.7, 4.5, 0.57, -2.2],\
 [380.0, 1000, 0.08571428571428572, 0.525, 0.045454545454545456, 2, 9.9, 20.7, 4.5, 0.54, -2.2],\
 [400, 900.0, 0.10476190476190476, 0.55, 0.045454545454545456, 2.2, 8.1, 25.3, 4.5, 0.66, -1.9],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 900.0, 0.08571428571428572, 0.45, 0.04772727272727273, 1.8, 9, 20.7, 4.5, 0.54, -2.2],\
 [400, 900.0, 0.08571428571428572, 0.55, 0.04318181818181818, 2.1, 8.549999999999999, 24.150000000000002, 4.5, 0.63, -2],\
 [420.0, 900.0, 0.09523809523809523, 0.475, 0.05, 2.2, 8.549999999999999, 23, 4.5, 0.57, -2.2],\
 [360.0, 950.0, 0.09047619047619046, 0.475, 0.04772727272727273, 2.1, 9.9, 24.150000000000002, 4.5, 0.54, -2.1],\
 [380.0, 950.0, 0.09523809523809523, 0.525, 0.04772727272727273, 1.9, 9.9, 24.150000000000002, 4.5, 0.6, -2.2],\
 [440.00000000000006, 1100.0, 0.10476190476190476, 0.45, 0.04318181818181818, 1.8, 8.1, 20.7, 4.5, 0.57, -2.1],\
 [420.0, 900.0, 0.1, 0.475, 0.05, 2.1, 9.450000000000001, 25.3, 4.5, 0.66, -1.9],\
 [360.0, 1100.0, 0.09047619047619046, 0.45, 0.05, 2.2, 9.450000000000001, 20.7, 4.5, 0.63, -1.8],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [420.0, 1050.0, 0.1, 0.55, 0.04090909090909091, 1.8, 8.549999999999999, 25.3, 4.5, 0.66, -2],\
 [420.0, 1000, 0.09523809523809523, 0.5, 0.04090909090909091, 1.9, 9.9, 23, 4.5, 0.63, -2.2],\
 [420.0, 900.0, 0.08571428571428572, 0.55, 0.04772727272727273, 2.1, 9.450000000000001, 24.150000000000002, 4.5, 0.54, -1.8],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [360.0, 900.0, 0.08571428571428572, 0.475, 0.04772727272727273, 2.2, 9.9, 20.7, 4.5, 0.54, -2],\
 [440.00000000000006, 900.0, 0.1, 0.55, 0.04772727272727273, 1.8, 9, 20.7, 4.5, 0.57, -1.9],\
 [440.00000000000006, 950.0, 0.09047619047619046, 0.45, 0.05, 2.1, 8.1, 23, 4.5, 0.54, -1.9],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [380.0, 950.0, 0.1, 0.5, 0.04772727272727273, 2.1, 9.9, 20.7, 4.5, 0.6, -2.1],\
 [380.0, 950.0, 0.08571428571428572, 0.55, 0.04090909090909091, 2.1, 8.549999999999999, 24.150000000000002, 4.5, 0.66, -2.2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [420.0, 1100.0, 0.09047619047619046, 0.45, 0.04318181818181818, 2.2, 9.9, 21.849999999999998, 4.5, 0.54, -2.1],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.10476190476190476, 0.5, 0.045454545454545456, 2.1, 9, 23, 4.5, 0.6, -2.1],\
 [400, 950.0, 0.09047619047619046, 0.55, 0.04090909090909091, 1.9, 9.9, 21.849999999999998, 4.5, 0.54, -2.1],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 1100.0, 0.09047619047619046, 0.5, 0.05, 2.2, 9.9, 20.7, 4.5, 0.54, -2.2],\
 [440.00000000000006, 950.0, 0.09047619047619046, 0.525, 0.04090909090909091, 1.9, 8.549999999999999, 20.7, 4.5, 0.54, -2.1],\
 [440.00000000000006, 1000, 0.09047619047619046, 0.45, 0.04772727272727273, 2, 9, 25.3, 4.5, 0.66, -1.8],\
 [360.0, 950.0, 0.10476190476190476, 0.475, 0.04318181818181818, 1.8, 9.450000000000001, 24.150000000000002, 4.5, 0.63, -1.8],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 1050.0, 0.08571428571428572, 0.45, 0.045454545454545456, 2.1, 8.549999999999999, 24.150000000000002, 4.5, 0.66, -2.1],\
 [380.0, 900.0, 0.1, 0.525, 0.045454545454545456, 1.9, 9.450000000000001, 20.7, 4.5, 0.54, -2.1],\
 [380.0, 950.0, 0.1, 0.475, 0.05, 1.8, 8.1, 20.7, 4.5, 0.63, -1.8],\
 [360.0, 950.0, 0.09047619047619046, 0.525, 0.04090909090909091, 2.2, 9.9, 24.150000000000002, 4.5, 0.57, -2.1],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 900.0, 0.09523809523809523, 0.475, 0.045454545454545456, 2, 9, 23, 4.5, 0.66, -1.8],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [420.0, 1000, 0.09523809523809523, 0.475, 0.045454545454545456, 1.9, 8.549999999999999, 23, 4.5, 0.6, -2.1],\
 [400, 1000, 0.08571428571428572, 0.475, 0.045454545454545456, 2.1, 9.450000000000001, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [380.0, 900.0, 0.1, 0.55, 0.05, 1.8, 9.9, 20.7, 4.5, 0.54, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [420.0, 950.0, 0.1, 0.525, 0.05, 2.2, 8.549999999999999, 21.849999999999998, 4.5, 0.66, -2.2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [420.0, 950.0, 0.09523809523809523, 0.525, 0.04318181818181818, 1.9, 9, 20.7, 4.5, 0.54, -2.1],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 1100.0, 0.09047619047619046, 0.45, 0.05, 1.9, 9.9, 21.849999999999998, 4.5, 0.57, -1.8],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 950.0, 0.1, 0.5, 0.04318181818181818, 2.1, 8.549999999999999, 25.3, 4.5, 0.63, -2],\
 [400, 900.0, 0.1, 0.5, 0.05, 2.1, 8.549999999999999, 21.849999999999998, 4.5, 0.6, -2.2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [360.0, 1000, 0.09047619047619046, 0.475, 0.05, 1.9, 9.450000000000001, 24.150000000000002, 4.5, 0.54, -2.1],\
 [420.0, 950.0, 0.09047619047619046, 0.55, 0.04318181818181818, 1.8, 9, 20.7, 4.5, 0.6, -2.2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [360.0, 1000, 0.09523809523809523, 0.525, 0.04772727272727273, 2, 9.450000000000001, 24.150000000000002, 4.5, 0.57, -2],\
 [380.0, 950.0, 0.09047619047619046, 0.475, 0.045454545454545456, 1.9, 9.9, 20.7, 4.5, 0.54, -2],\
 [380.0, 1000, 0.09047619047619046, 0.55, 0.045454545454545456, 1.9, 8.1, 24.150000000000002, 4.5, 0.54, -2],\
 [420.0, 900.0, 0.09047619047619046, 0.475, 0.04090909090909091, 1.8, 9.450000000000001, 21.849999999999998, 4.5, 0.63, -1.9],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [440.00000000000006, 900.0, 0.08571428571428572, 0.45, 0.04772727272727273, 1.8, 9.9, 20.7, 4.5, 0.54, -2.2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2],\
 [400, 1000, 0.09523809523809523, 0.5, 0.045454545454545456, 2, 9, 23, 4.5, 0.6, -2]]

In [11]:
def run_test(train_path='../../train.csv', test_path='../../test.csv',
             output_path='Baseline620.csv'):
    """Predict the top-3 places for every test row and write the submission.

    The space is split into a 10 x 20 grid. For each cell the model is
    trained on all training rows inside the cell widened by 0.1 on every
    side, using that cell's weight vector from the module-level
    ``fw_final`` (ordered by ``x_index`` then ``y_index``), and applied to
    the test rows strictly inside the cell.

    Parameters
    ----------
    train_path : str, optional
        CSV with ``row_id``, ``x``, ``y``, ``accuracy``, ``time`` and
        ``place_id`` columns.
    test_path : str, optional
        CSV with ``row_id``, ``x``, ``y``, ``accuracy`` and ``time`` columns.
    output_path : str, optional
        Destination CSV; receives ``row_id`` and a ``place_id`` column
        holding the three predicted ids separated by spaces.
    """
    global fw

    # Run test
    df = pd.read_csv(train_path,
                     usecols=['row_id', 'x', 'y', 'accuracy', 'time', 'place_id'])
    df_test = pd.read_csv(test_path,
                          usecols=['row_id', 'x', 'y', 'accuracy', 'time'])

    n_cell_x = 10
    n_cell_y = 20
    x_length = 10 / n_cell_x
    y_length = 10 / n_cell_y
    counter = 0
    # Per-cell results are gathered here and concatenated once at the end;
    # DataFrame.append was removed in pandas 2.0 and copied the whole
    # accumulated frame on every iteration.
    cell_results = []

    for x_index in range(n_cell_x):
        start_time = time.time()
        for y_index in range(n_cell_y):
            min_x = x_index * x_length
            max_x = (x_index + 1) * x_length
            min_y = y_index * y_length
            max_y = (y_index + 1) * y_length

            # include the edge
            if y_index + 1 == n_cell_y:
                max_y += 0.1
            if x_index + 1 == n_cell_x:
                max_x += 0.1

            # Select this cell's weights and advance the counter before any
            # early exit so later cells keep their own weight vectors.
            fw = fw_final[counter]
            counter += 1

            df_test_cell = df_test[(df_test.x >= min_x) &
                                   (df_test.x < max_x) &
                                   (df_test.y >= min_y) &
                                   (df_test.y < max_y)].copy()
            if df_test_cell.empty:
                # Nothing to predict in this cell.
                continue

            df_train_cell = df[(df.x >= min_x - 0.1) &
                               (df.x < max_x + 0.1) &
                               (df.y >= min_y - 0.1) &
                               (df.y < max_y + 0.1)].copy()

            cell_results.append(run_prediction(df_train_cell, df_test_cell))
        print("Elapsed time overall: %s seconds" % (time.time() - start_time), x_index, flush = True)

    if cell_results:
        total_result = pd.concat(cell_results)
    else:
        total_result = pd.DataFrame(columns=['row_id', 'p1', 'p2', 'p3'])

    total_result = total_result.sort_index()
    total_result['place_id'] = total_result.p1.astype(str) + " " + \
                               total_result.p2.astype(str) + " " + \
                               total_result.p3.astype(str)
    total_result[['row_id', 'place_id']].to_csv(output_path, index=False)

In [12]:
run_test()

Elapsed time overall: 663.994193315506 seconds 0
Elapsed time overall: 740.8766412734985 seconds 1
Elapsed time overall: 745.5363931655884 seconds 2
Elapsed time overall: 704.3780963420868 seconds 3
Elapsed time overall: 690.7077469825745 seconds 4
Elapsed time overall: 720.6914947032928 seconds 5
Elapsed time overall: 692.8503804206848 seconds 6
Elapsed time overall: 722.5725965499878 seconds 7
Elapsed time overall: 689.5900194644928 seconds 8
Elapsed time overall: 624.4675269126892 seconds 9


In [None]:
# NOTE(review): bare literal in an unexecuted cell; presumably the score
# obtained by the Baseline620.csv submission — TODO confirm.
0.58065