In [1]:
import numpy as np
import pandas as pd
import os

In [42]:
def read(filepath):
    try:
        df = pd.DataFrame(columns=['number', 'id', 'time_sequence'])
        lines = [line.strip() for line in open(filepath,'r')]
        number = int(lines[1].split(" ")[3])
        id = int(lines[2].split(" ")[3])
        matrix = []
        for i in range(5,len(lines)):
            line = lines[i].split(",")
            line = np.array(line).astype(np.float64)
            matrix.append([line[3], *line[:3]])
        df.loc[0] = [number, id, matrix]
        return df
    except IOError as e:
        print("Unable to read dataset file!\n")

In [43]:
directory = 'Sketch-Data-master\SketchData\SketchData\Domain01'

df = pd.DataFrame(columns=['number', 'id', 'time_sequence'])
for filename in os.listdir(directory):
    f = os.path.join(directory, filename)
    # checking if it is a file
    if os.path.isfile(f):
        df=df.append(read(f), ignore_index=True)
        '''le = len(read(f)["time_sequence"].values[0])
        if min > le:
            if le==31:
                print(f)
            min = le'''
# min = 31
print(df)

    number  id                                      time_sequence
0        1   1  [[6.0, 0.042075, 0.036799, 0.25838], [37.0, 0....
1        1   1  [[27.0, -0.030138, 0.061892, 0.356577], [58.0,...
2       10   1  [[9.0, 0.102607, 0.069246, 0.302644], [41.0, 0...
3       10  10  [[38.0, 0.125652, 0.054037, 0.366598], [69.0, ...
4        1   2  [[11.0, 0.018762, 0.056202, 0.334956], [60.0, ...
..     ...  ..                                                ...
995     10  10  [[30.0, 0.113368, 0.034153, 0.334738], [63.0, ...
996     10  10  [[32.0, 0.111792, 0.018394, 0.323929], [62.0, ...
997     10  10  [[47.0, 0.120695, 0.035941, 0.33836], [78.0, 0...
998     10  10  [[30.0, 0.129327, 0.026097, 0.343372], [79.0, ...
999     10  10  [[8.0, 0.128029, 0.035957, 0.35456], [53.0, 0....

[1000 rows x 3 columns]


In [52]:
def cross_validation_split():
    dataset_split = []
    for i in range(10):
        fold=[x for x in range(100*i, 100*(i+1))]
        other=[x for x in range(1000) if x not in fold]
        dataset_split+=[(fold,other)]
    return dataset_split


In [21]:
def DTWdistance(data1,data2, w):
    n=len(data1)
    m=len(data2)
    DTW=np.zeros((n,m))
    w=max(w,abs(n-m))

    for i in range(n):
        for j in range(m):
            DTW[i,j]=999999
    DTW[0,0]=0

    for i in range(1,n):
        for j in range(max(1,i-w),min(m,i+w)):
            DTW[i,j]=0
    
    for i in range(1,n):
        for j in range(max(1,i-w),min(m,i+w)):
            cost=distance(data1[i],data2[j])
            DTW[i,j]=cost+np.min([DTW[i-1,j],#insertion
                                DTW[i,j-1],#deletion
                                DTW[i-1,j-1]])#match
    
    return DTW[n,m]

def distance(a,b):
    return np.linalg.norm(a-b)

In [None]:
def get_neighbors(train, test_row, num_neighbors):
	distances = list()
	for train_row in train:
		dist = DTWdistance(test_row, train_row)
		distances.append((train_row, dist))
	distances.sort(key=lambda tup: tup[1])
	neighbors = list()
	for i in range(num_neighbors):
		neighbors.append(distances[i][0])
	return neighbors

In [41]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
from sklearn import neighbors, datasets

n_neighbors = 10

# we only take the first two features. We could avoid this ugly
# slicing by using a two-dim dataset
X = df["time_sequence"]
y = df["number"]

for test_ind, train_ind in cross_validation_split():
    

    x_train, y_train = X[train_ind], y[train_ind]
    x_test, y_test = X[test_ind], y[test_ind]

    clf = neighbors.KNeighborsClassifier(n_neighbors, metric=DTWdistance)
    clf.fit(x_train, y_train)


h = 0.02  # step size in the mesh

# Create color maps
cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
cmap_bold = ["darkorange", "c", "darkblue"]

for weights in [DTWdistance]:
    # we create an instance of Neighbours Classifier and fit the data.
    clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights, metric=DTWdistance)
    clf.fit(X, y)

    # Plot the decision boundary. For that, we will assign a color to each
    # point in the mesh [x_min, x_max]x[y_min, y_max].
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, Z) #cmap=cmap_light)

    # Plot also the training points
    sns.scatterplot(
        x=X[:, 0],
        y=X[:, 1],
        #hue=iris.target_names[y],
        palette=cmap_bold,
        alpha=1.0,
        edgecolor="black",
    )
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title(
        "3-Class classification (k = %i, weights = '%s')" % (n_neighbors, weights)
    )
    '''plt.xlabel(iris.feature_names[0])
    plt.ylabel(iris.feature_names[1])'''

plt.show()

ValueError: could not convert string to float: '[[6.0, 0.042075, 0.036799, 0.25838], [37.0, 0.041904, 0.037187, 0.258505], [67.0, 0.041739, 0.037219, 0.258619], [101.0, 0.041446, 0.037573, 0.258731], [132.0, 0.041303, 0.037331, 0.25883], [180.0, 0.040491, 0.037559, 0.25893], [214.0, 0.039765, 0.038245, 0.259025], [244.0, 0.038568, 0.038757, 0.259128], [277.0, 0.037329, 0.038885, 0.259251], [308.0, 0.035251, 0.039654, 0.2594], [340.0, 0.033623, 0.03984, 0.259559], [374.0, 0.031522, 0.040281, 0.259734], [403.0, 0.02965, 0.041355, 0.259901], [435.0, 0.028052, 0.042328, 0.26007], [468.0, 0.025998, 0.042693, 0.260237], [499.0, 0.023849, 0.043293, 0.260413], [533.0, 0.021556, 0.043995, 0.260606], [579.0, 0.01919, 0.044105, 0.260821], [612.0, 0.016729, 0.044236, 0.261096], [644.0, 0.013483, 0.045102, 0.261364], [678.0, 0.010871, 0.045586, 0.261658], [707.0, 0.007973, 0.045779, 0.262019], [740.0, 0.004794, 0.045779, 0.262441], [772.0, 0.000913, 0.045793, 0.262818], [803.0, -0.003038, 0.046076, 0.263215], [835.0, -0.006461, 0.045424, 0.263665], [867.0, -0.010195, 0.044215, 0.264101], [902.0, -0.014213, 0.042907, 0.264484], [932.0, -0.018004, 0.041523, 0.264813], [980.0, -0.02206, 0.04083, 0.2651], [1012.0, -0.025689, 0.039176, 0.265417], [1043.0, -0.028318, 0.035665, 0.265667], [1076.0, -0.031583, 0.033724, 0.265877], [1109.0, -0.034402, 0.030736, 0.266043], [1139.0, -0.036349, 0.027647, 0.266176], [1171.0, -0.038617, 0.024838, 0.266281], [1207.0, -0.040964, 0.022257, 0.266363], [1236.0, -0.042711, 0.020107, 0.266429], [1268.0, -0.044082, 0.017021, 0.266476], [1301.0, -0.045256, 0.014252, 0.266513], [1331.0, -0.046242, 0.010491, 0.266536], [1379.0, -0.046988, 0.007606, 0.266544], [1411.0, -0.04754, 0.004798, 0.266543], [1444.0, -0.048024, 0.001943, 0.266519], [1478.0, -0.048027, -0.00232, 0.266433], [1508.0, -0.04814, -0.005166, 0.266303], [1540.0, -0.04787, -0.007695, 0.266044], [1571.0, -0.04756, -0.011263, 0.265619], [1603.0, -0.046608, -0.01485, 0.265152], [1636.0, -0.045501, -0.017491, 0.264489], [1667.0, -0.044534, -0.021251, 0.263734], [1699.0, -0.043245, -0.025386, 0.262858], [1748.0, -0.041717, -0.029843, 0.261822], [1780.0, -0.040339, -0.033472, 0.260614], [1812.0, -0.039061, -0.036682, 0.259292], [1844.0, -0.038104, -0.039319, 0.257727], [1878.0, -0.036051, -0.042312, 0.255203], [1908.0, -0.033501, -0.046032, 0.25244], [1939.0, -0.031555, -0.048771, 0.249144], [1971.0, -0.029347, -0.051989, 0.246163], [2004.0, -0.02686, -0.053903, 0.243743], [2035.0, -0.024591, -0.055939, 0.241779], [2068.0, -0.021561, -0.058998, 0.240168], [2101.0, -0.018631, -0.061471, 0.238734], [2147.0, -0.015398, -0.06209, 0.23737], [2180.0, -0.012524, -0.062422, 0.23613], [2212.0, -0.008863, -0.064021, 0.234969], [2244.0, -0.005463, -0.064655, 0.233834], [2275.0, -0.001519, -0.06487, 0.232746], [2307.0, 0.001623, -0.064687, 0.231685], [2341.0, 0.006279, -0.063482, 0.230566], [2371.0, 0.010256, -0.063999, 0.229597], [2403.0, 0.014998, -0.063606, 0.228678], [2439.0, 0.018835, -0.063598, 0.227862], [2471.0, 0.022759, -0.063341, 0.227126], [2500.0, 0.027472, -0.063506, 0.226487], [2548.0, 0.031457, -0.063506, 0.225925], [2580.0, 0.035122, -0.063698, 0.225447], [2612.0, 0.039107, -0.062225, 0.225008], [2645.0, 0.043232, -0.060843, 0.22463], [2676.0, 0.046738, -0.05933, 0.224277], [2708.0, 0.050232, -0.058749, 0.223973], [2740.0, 0.053536, -0.057178, 0.223709], [2772.0, 0.056334, -0.055669, 0.223481], [2804.0, 0.059089, -0.052402, 0.22327], [2836.0, 0.061955, -0.051082, 0.223086], [2868.0, 0.064056, -0.047978, 0.222927], [2900.0, 0.066252, -0.045846, 0.222795], [2948.0, 0.068221, -0.041842, 0.222687], [2980.0, 0.070113, -0.038198, 0.222599], [3012.0, 0.071511, -0.034995, 0.22253], [3044.0, 0.072783, -0.029621, 0.222478], [3076.0, 0.07364, -0.025878, 0.222458], [3109.0, 0.07467, -0.020566, 0.222468], [3141.0, 0.075431, -0.016742, 0.222573], [3172.0, 0.076042, -0.015285, 0.22292], [3204.0, 0.076304, -0.007655, 0.223523], [3236.0, 0.076318, -0.002282, 0.224486], [3268.0, 0.076263, 0.00256, 0.225765], [3300.0, 0.075736, 0.00758, 0.227136], [3348.0, 0.074704, 0.009141, 0.228709], [3380.0, 0.07373, 0.011899, 0.230313], [3412.0, 0.072873, 0.018105, 0.231853], [3443.0, 0.072018, 0.020522, 0.233354], [3476.0, 0.070717, 0.023221, 0.234926], [3508.0, 0.069183, 0.026617, 0.23639], [3539.0, 0.067667, 0.030646, 0.237899], [3571.0, 0.065678, 0.033118, 0.239343], [3603.0, 0.064407, 0.036251, 0.240914], [3636.0, 0.062749, 0.037236, 0.242651], [3667.0, 0.06065, 0.03928, 0.24403], [3715.0, 0.058844, 0.039828, 0.245485], [3747.0, 0.056679, 0.041922, 0.24672], [3780.0, 0.054777, 0.042892, 0.247861], [3810.0, 0.052633, 0.044689, 0.248932], [3843.0, 0.051022, 0.044523, 0.250171], [3875.0, 0.048838, 0.045577, 0.251211], [3907.0, 0.046721, 0.045427, 0.252141], [3939.0, 0.044803, 0.045115, 0.253041], [3971.0, 0.043454, 0.044662, 0.253887], [4003.0, 0.042299, 0.045267, 0.254491], [4035.0, 0.041354, 0.044888, 0.254947], [4067.0, 0.040567, 0.044672, 0.255309], [4115.0, 0.040085, 0.044518, 0.255587], [4147.0, 0.0399, 0.043792, 0.255815], [4179.0, 0.039687, 0.043451, 0.255995], [4211.0, 0.039401, 0.043323, 0.256133]]'