In [73]:
import glob
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import normalize
import csv
import pandas as pd
import random
from collections import defaultdict
import pickle

In [74]:
locations = ['not_crossing_wen', 'not_crossing_nick', 'not_crossing']
locations_crossing = ['crossing_wen', 'crossing_nick', 'crossing']

not_crossing = defaultdict(list)
crossing = defaultdict(list)

In [75]:
# Read only the valid data into dict

for location in locations:
    path = 'data/' + location + '/'
    for filename in glob.glob(path + '*.csv'):
        name = filename.replace(path, '')[:-4]

        df=pd.read_csv(filename)
        if len(df.columns) >= 5 * 2: # 2 columns per point for x & y
            for data in df.values:
                not_crossing[name].append(data)
                
for location in locations_crossing:
    path = 'data/' + location + '/'
    for filename in glob.glob(path + '*.csv'):
        name = filename.replace(path, '')[:-4]

        df=pd.read_csv(filename)
        if len(df.columns) >= 5 * 2: # 2 columns per point for x & y
            for data in df.values:
                crossing[name].append(data)

In [84]:
# sanity check
crossing['01234567891011121314151617'][:5]

[array([0.38805532, 0.08132222, 0.3928495 , 0.07436835, 0.38376334,
        0.07147983, 0.39474568, 0.07911147, 0.37062904, 0.07041246,
        0.380716  , 0.11870236, 0.35417375, 0.10275031, 0.38235566,
        0.1763957 , 0.35862297, 0.17170343, 0.40344015, 0.21594571,
        0.39685455, 0.21892008, 0.35804254, 0.23037648, 0.34468532,
        0.22811249, 0.35048506, 0.32761222, 0.34842077, 0.32894176,
        0.34485534, 0.42615405, 0.33247641, 0.43560019, 0.36754808,
        0.11209422]),
 array([0.38343269, 0.09680031, 0.38688397, 0.09064046, 0.37991419,
        0.08724332, 0.38666096, 0.09162083, 0.36820725, 0.08272316,
        0.37677792, 0.12523876, 0.34593552, 0.10810407, 0.38498998,
        0.16869476, 0.32637277, 0.15388227, 0.41926196, 0.1748613 ,
        0.32543752, 0.20542498, 0.34980398, 0.2237018 , 0.33314329,
        0.22108464, 0.35381913, 0.30551109, 0.34165591, 0.3070375 ,
        0.33114693, 0.39260769, 0.32559067, 0.4013333 , 0.36137345,
        0.11714577]),
 arr

In [98]:
# Create models

def extract(datum, label):
#     print(datum)
    return (label, [float(d) for d in datum])

model_scores = []

for key in not_crossing:
    if key in crossing:
        data = [extract(d, 1) for d in not_crossing[key]]
        data += [extract(d, 2) for d in crossing[key]]
        
        random.shuffle(data)
        
        X = normalize([d[1] for d in data])
        y = [d[0] for d in data]
        
        N = len(X)
        
        if N < 8:
            continue
        
        X_train = X[:3*N//4]
        X_test = X[3*N//4:]
        y_train = y[:3*N//4]
        y_test = y[3*N//4:]
        
#         print(y_train)
#         break;
        
        clf = make_pipeline(StandardScaler(), SVC(gamma='auto'))
        
        try:
            clf.fit(X_train, y_train)
            score = clf.score(X_test, y_test)
            
            if score < 0.6:
                print("⚠️ low score " + key + ' count: ' + str(N), 'score: ' + str(score))
            else:
                print('✔️ ' + key + ' score: ' + str(score), 'count: ' + str(N))
                pickle.dump(clf, open('models/' + key + '_{:.2f}'.format(score) + '.sav', 'wb'))
            
            model_scores.append((key, score, N))
        except:
            # This means there is only one classification in training data. Caused by train/test split
            print("❌ omitted " + key + ' count: ' + str(N))

✔️ 34567891011121314151617 score: 1.0 count: 73
✔️ 011121314151617 score: 1.0 count: 13
✔️ 611121314151617 score: 0.6666666666666666 count: 9
✔️ 567891011121314151617 score: 0.7 count: 39
✔️ 111213141516 score: 0.8918918918918919 count: 148
⚠️ low score 012345678911121314151617 count: 13 score: 0.5
✔️ 0124681011121314151617 score: 1.0 count: 28
⚠️ low score 01234567811121314151617 count: 10 score: 0.3333333333333333
✔️ 5681011121314151617 score: 0.6666666666666666 count: 12
✔️ 01234567891011121314151617 score: 0.9252336448598131 count: 426
✔️ 11121314151617 score: 0.6470588235294118 count: 65
✔️ 024567891011121314151617 score: 1.0 count: 33
✔️ 0124567891011121314151617 score: 0.6900826446280992 count: 967
✔️ 0123456781011121314151617 score: 0.8181818181818182 count: 41
✔️ 0234567891011121314151617 score: 0.75 count: 16
⚠️ low score 1112131416 count: 8 score: 0.5
⚠️ low score 0245678911121314151617 count: 9 score: 0.3333333333333333
⚠️ low score 02456811121314151617 count: 17 score: 0.4

In [99]:
sorted(model_scores, key=lambda tup: (tup[1], tup[2]), reverse=True)

[('34567891011121314151617', 1.0, 73),
 ('012456781011121314151617', 1.0, 68),
 ('024567891011121314151617', 1.0, 33),
 ('0124681011121314151617', 1.0, 28),
 ('011121314151617', 1.0, 13),
 ('01245681011121314151617', 1.0, 12),
 ('111213141617', 1.0, 8),
 ('01234567891011121314151617', 0.9252336448598131, 426),
 ('0124567811121314151617', 0.9090909090909091, 44),
 ('111213141516', 0.8918918918918919, 148),
 ('024567811121314151617', 0.8888888888888888, 34),
 ('0123456781011121314151617', 0.8181818181818182, 41),
 ('0234567891011121314151617', 0.75, 16),
 ('01245678911121314151617', 0.7272727272727273, 41),
 ('567891011121314151617', 0.7, 39),
 ('0124567891011121314151617', 0.6900826446280992, 967),
 ('5681011121314151617', 0.6666666666666666, 12),
 ('0567891011121314151617', 0.6666666666666666, 11),
 ('611121314151617', 0.6666666666666666, 9),
 ('11121314151617', 0.6470588235294118, 65),
 ('024681011121314151617', 0.6, 18),
 ('012345678911121314151617', 0.5, 13),
 ('1112131416', 0.5, 8)

### The best model: 01234567891011121314151617