# [How to do Unsupervised Clustering with Keras](https://www.dlology.com/blog/how-to-do-unsupervised-clustering-with-keras/) | DLology

Read my blog post for details.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import keras.backend as K
import pickle, math, datetime
from pprint import pformat
from graphviz import Digraph
from itertools import combinations
from scipy.interpolate import interp1d
from time import time
from keras.engine.topology import Layer, InputSpec
from keras.layers import Dense, Input
from keras.models import Model
from keras.optimizers import SGD
from keras import callbacks
from keras.initializers import VarianceScaling
from sklearn.cluster import KMeans
from sklearn import metrics
import hmms
import os
import math

from TrajectoryProcess.processdata import *
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
np.random.seed(10)

Using TensorFlow backend.


In [2]:
# Configuration parameters
dataFile = 'atc-20121114-down-0-5.txt'
social_dist = 4000
debug = False
save_time = 600  # every ten minutes
show_time = 10  # every 10 seconds
free_cores = 1
n_digits = 8  # len(str(n_tr))
min_trajectory_duration = 10  # in seconds
N_nodes = 81

# constants
# these are related to the data file
colUnits = ['time_s', '--', 'mm', 'mm', 'mm', 'mm/s', 'rad', 'rad']
colNames = ['t', 'id', 'x', 'y', 'z', 'v', 'h', 'fh']
dtypes = {'t': np.float64, 'id': np.int64, 'x': np.int64, 'y': np.int64,
          'z': np.float64, 'v': np.float64, 'h': np.float64, 'fh': np.float64}

# Derived parametes: they are not free input but obtained from other params
code = dataFile[0:-4]
setBaseFolder()  # current
dataDir = getDataFolder()
dataURI = dataDir + dataFile

# Clean and Load data: This cell is copyied from 1_RAW
allData = loadRawData(code, min_trajectory_duration, colNames, dtypes, dataURI)
(trajectoriesIDs, n_trajs) = getTrajectoryIDs(code, allData)

Dataframe has (13220) trajectories, (1289981) entries 


In [3]:
# Get QTC-C sequences from trajectory pairs
QTC_C_seqs_full = []
traj_pairs = []
with open(getResultsFolder(code) + 'pair_set_under_' + str(social_dist) + '.pickle', 'rb') as fp:
    pair_set = pickle.load(fp)
    downsample_secs = 0.5

    for pair in pair_set:
        dataframe_p = getDataFrame(pair[0], trajectoriesIDs, allData)
        dataframe_q = getDataFrame(pair[1], trajectoriesIDs, allData)

        df_1 = dataframe_p
        df_2 = dataframe_q

        # we need a common timeline
        t1 = df_1.index
        t2 = df_2.index

        # time where the two trajs happened simultaneously:
        tmin = max(t1.min(), t2.min())
        tmax = min(t1.max(), t2.max())
        t = np.sort(np.unique(np.concatenate((t1, t2))))
        t = t[t >= tmin]
        t = t[t <= tmax]

        df_1 = df_1.reindex(t)
        df_1 = df_1.ffill()
        df_1 = df_1.bfill()

        df_2 = df_2.reindex(t)
        df_2 = df_2.ffill()
        df_2 = df_2.bfill()

        # Resample dataframes to every 'downsample_secs' seconds
        df_1.index = pd.to_datetime(df_1.index.values, unit="s")
        df_1 = df_1.resample(str(downsample_secs) + "S").mean()

        df_2.index = pd.to_datetime(df_2.index.values, unit="s")
        df_2 = df_2.resample(str(downsample_secs) + "S").mean()

        # Convert df indices into float seconds timestamps
        df_1.index = df_1.index.values.astype('timedelta64[s]').astype("float") / (10 ** 9)
        df_2.index = df_2.index.values.astype('timedelta64[s]').astype("float") / (10 ** 9)

        dataframe_p = df_1
        dataframe_q = df_2
        
        p_x = dataframe_p.x.values
        p_y = dataframe_p.y.values
        q_x = dataframe_q.x.values
        q_y = dataframe_q.y.values

        # Filter out datapoints where distance > 4m
        distances = [math.sqrt(((p_x[i] - q_x[i]) ** 2) + ((p_y[i] - q_y[i]) ** 2)) for i in range(0, len(p_x))]
        in_dist_ix = [i for i in range(0, len(distances)) if distances[i] <= 4000]

        dataframe_p = dataframe_p.iloc[in_dist_ix,:]
        dataframe_q = dataframe_q.iloc[in_dist_ix,:]

        QSR_seq = obtainQRSeq(dataframe_p, dataframe_q, "QTCc")
        QTC_C_seqs_full.append(QSR_seq)
        traj_pairs.append([dataframe_p, dataframe_q])

QTC_C_seqs_full

No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR interaction found!
No QSR inter

[[('+', '-', '-', '-', 3787.3173883370273, 1352858144.0),
  ('+', '-', '-', '-', 3678.3110254572002, 1352858144.5),
  ('+', '-', '-', '-', 3439.375015900418, 1352858145.0),
  ('-', '-', '+', '-', 3259.522971233674, 1352858145.5),
  ('+', '-', '+', '-', 3087.093455015575, 1352858146.0),
  ('-', '-', '-', '-', 2883.5769887415872, 1352858146.5),
  ('+', '-', '-', '-', 2609.5923053228066, 1352858147.0),
  ('+', '-', '+', '-', 2361.0953496205952, 1352858147.5),
  ('-', '-', '-', '-', 1885.5156456524035, 1352858148.0),
  ('-', '-', '+', '+', 1354.5538010725156, 1352858148.5),
  ('+', '-', '-', '+', 1112.645720793461, 1352858149.0),
  ('-', '-', '-', '+', 929.5186926576571, 1352858149.5),
  ('-', '-', '+', '+', 706.6972477659723, 1352858150.0),
  ('-', '-', '-', '+', 660.4975397380372, 1352858150.5),
  ('-', '-', '-', '+', 653.6094017071664, 1352858151.0),
  ('+', '-', '-', '-', 631.476048635259, 1352858151.5),
  ('+', '+', '-', '+', 640.2925112790247, 1352858152.0),
  ('-', '+', '+', '+', 75

In [4]:
# Filter out states with dist > 4m, and filter out sequences of len < 2
QTC_C_seqs_full = [[state for state in seq if state[4] <= 4000] for seq in QTC_C_seqs_full]

filt_traj_pairs_indices = []
QTC_C_seqs_filt = []
for i in range(0, len(QTC_C_seqs_full)):
    if len(QTC_C_seqs_full[i]) > 1:
        QTC_C_seqs_filt.append(QTC_C_seqs_full[i])
        filt_traj_pairs_indices.append(i)

# QTC_C_seqs_full = [seq for seq in QTC_C_seqs_full if len(seq) > 1]
QTC_C_seqs_full = QTC_C_seqs_filt

print(len(QTC_C_seqs_full))
QTC_C_seqs_full

523


[[('+', '-', '-', '-', 3787.3173883370273, 1352858144.0),
  ('+', '-', '-', '-', 3678.3110254572002, 1352858144.5),
  ('+', '-', '-', '-', 3439.375015900418, 1352858145.0),
  ('-', '-', '+', '-', 3259.522971233674, 1352858145.5),
  ('+', '-', '+', '-', 3087.093455015575, 1352858146.0),
  ('-', '-', '-', '-', 2883.5769887415872, 1352858146.5),
  ('+', '-', '-', '-', 2609.5923053228066, 1352858147.0),
  ('+', '-', '+', '-', 2361.0953496205952, 1352858147.5),
  ('-', '-', '-', '-', 1885.5156456524035, 1352858148.0),
  ('-', '-', '+', '+', 1354.5538010725156, 1352858148.5),
  ('+', '-', '-', '+', 1112.645720793461, 1352858149.0),
  ('-', '-', '-', '+', 929.5186926576571, 1352858149.5),
  ('-', '-', '+', '+', 706.6972477659723, 1352858150.0),
  ('-', '-', '-', '+', 660.4975397380372, 1352858150.5),
  ('-', '-', '-', '+', 653.6094017071664, 1352858151.0),
  ('+', '-', '-', '-', 631.476048635259, 1352858151.5),
  ('+', '+', '-', '+', 640.2925112790247, 1352858152.0),
  ('-', '+', '+', '+', 75

In [5]:
print(len(QTC_C_seqs_full))

523


In [6]:
traj_pairs = [traj_pairs[i] for i in filt_traj_pairs_indices]
len(traj_pairs)

523

In [7]:
QTC_C_seqs = []
QTC_C_seqs = [[state[0] + state[1] +
               state[2] + state[3] for state in seq] for seq in QTC_C_seqs_full]
print(QTC_C_seqs)
len(QTC_C_seqs)

[['+---', '+---', '+---', '--+-', '+-+-', '----', '+---', '+-+-', '----', '--++', '+--+', '---+', '--++', '---+', '---+', '+---', '++-+', '-+++', '--++', '+--+', '-+++', '-+++', '+-++', '----', '--++', '+--+', '+---', '++--', '-++-', '+++-', '++--', '+++-', '-++-', '-++-', '-+--', '-+--', '-++-', '-++-', '+++-', '++--', '+++-', '+-+-', '+-+-', '+-++', '+-+-', '+--+', '+-+-', '+-+-', '+---', '+---', '+-+-', '+-+-', '+-+-', '--+-', '+-+-'], ['0+0-', '+++-', '+++-', '++++'], ['-++-', '-++-'], ['--+-', '+-+-', '+++-', '++--'], ['+---', '+---', '+---'], ['+-+-', '+-+-', '--+-', '+-+-', '----', '+---', '+++-', '-++-', '++--', '-++-', '-++-', '-+--'], ['--++', '--++', '++++', '++++'], ['0+0+', '---+', '+--+', '++--', '++--', '++-+', '++-+', '-+-+', '++-+', '+--+', '--+-', '-+-+', '+-+-', '+---', '+++-', '+-++', '-+-+', '--+-', '----', '++++', '-++-', '--++'], ['-++-', '-++-', '-++-', '-++-', '-++-', '-++-', '-++-', '+-+-', '+++-', '+++-', '+++-', '-++-', '-++-', '-++-', '-++-', '-++-', '-++-'

523

In [8]:
print(len(QTC_C_seqs))

523


In [9]:
# Create list of QTC_C states so that indices can be used as integer state IDs compatible with HMM library
QTC_symbols = []
for i in range(0,4):
    QTC_symbols.append("-")
    QTC_symbols.append("0")
    QTC_symbols.append("+")
print("QTC symbols:", QTC_symbols[:3])
QTC_C_states = list(combinations(QTC_symbols, 4))
QTC_C_states = [state[0] + state[1] + state[2] + state[3] for state in QTC_C_states]
QTC_C_states = list(np.unique(QTC_C_states))
print("QTC_C states:\n", QTC_C_states)
print(len(QTC_C_states), "states total")

QTC symbols: ['-', '0', '+']
QTC_C states:
 ['++++', '+++-', '+++0', '++-+', '++--', '++-0', '++0+', '++0-', '++00', '+-++', '+-+-', '+-+0', '+--+', '+---', '+--0', '+-0+', '+-0-', '+-00', '+0++', '+0+-', '+0+0', '+0-+', '+0--', '+0-0', '+00+', '+00-', '+000', '-+++', '-++-', '-++0', '-+-+', '-+--', '-+-0', '-+0+', '-+0-', '-+00', '--++', '--+-', '--+0', '---+', '----', '---0', '--0+', '--0-', '--00', '-0++', '-0+-', '-0+0', '-0-+', '-0--', '-0-0', '-00+', '-00-', '-000', '0+++', '0++-', '0++0', '0+-+', '0+--', '0+-0', '0+0+', '0+0-', '0+00', '0-++', '0-+-', '0-+0', '0--+', '0---', '0--0', '0-0+', '0-0-', '0-00', '00++', '00+-', '00+0', '00-+', '00--', '00-0', '000+', '000-', '0000']
81 states total


In [10]:
def QTC_C_to_num(QTC_C):
    return QTC_C_states.index(QTC_C)


def QTC_C_seq_to_num_seq(QTC_C_seq):
    num_seq = []
    for QTC_C in QTC_C_seq:
        num_seq.append(QTC_C_to_num(QTC_C))

    return num_seq


def num_to_QTC_C(num):
    return QTC_C_states[num]


def num_seq_to_QTC_C_seq(num_seq):
    QTC_C_seq = []
    for num in num_seq:
        QTC_C_seq.append(num_to_QTC_C(num))

    return QTC_C_seq


print(QTC_C_to_num("++--"))
print(num_to_QTC_C(8))
print(num_seq_to_QTC_C_seq([0, 1, 2, 3]))
print(QTC_C_seq_to_num_seq(num_seq_to_QTC_C_seq([0, 1, 2, 3])))

symbol_seqs = []
for QTC_C_seq in QTC_C_seqs:
    symbol_seq = [QTC_C_to_num(QTC_C) for QTC_C in QTC_C_seq]
    symbol_seqs.append(np.array(symbol_seq))

symbol_seqs

4
++00
['++++', '+++-', '+++0', '++-+']
[0, 1, 2, 3]


[array([13, 13, 13, 37, 10, 40, 13, 10, 40, 36, 12, 39, 36, 39, 39, 13,  3,
        27, 36, 12, 27, 27,  9, 40, 36, 12, 13,  4, 28,  1,  4,  1, 28, 28,
        31, 31, 28, 28,  1,  4,  1, 10, 10,  9, 10, 12, 10, 10, 13, 13, 10,
        10, 10, 37, 10]),
 array([61,  1,  1,  0]),
 array([28, 28]),
 array([37, 10,  1,  4]),
 array([13, 13, 13]),
 array([10, 10, 37, 10, 40, 13,  1, 28,  4, 28, 28, 31]),
 array([36, 36,  0,  0]),
 array([60, 39, 12,  4,  4,  3,  3, 30,  3, 12, 37, 30, 10, 13,  1,  9, 30,
        37, 40,  0, 28, 36]),
 array([28, 28, 28, 28, 28, 28, 28, 10,  1,  1,  1, 28, 28, 28, 28, 28, 28,
        28, 37, 28, 10,  1, 37, 28, 37, 28, 28, 28, 28, 28, 28, 28, 28, 28,
        28, 28,  1, 28, 28, 28, 28, 28]),
 array([40, 13,  4,  4]),
 array([12, 10, 13, 13, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 13, 13]),
 array([40, 40,  4]),
 array([36,  9, 12, 39,  3, 27]),
 array([28, 28, 28, 40, 40, 30, 39, 31]),
 array([27, 28, 31, 13,  9, 12, 36, 30, 36,  3, 

In [11]:
seq_lens = [len(seq) for seq in symbol_seqs]
seq_len = int(round(np.max(seq_lens)))
# seq_len = int(round(np.mean(seq_lens)))
seq_len

115

In [12]:
print(len(symbol_seqs), len(traj_pairs))

523 523


In [13]:
# seq_len_filt_ix = []
# mean_len_symbol_seqs = []
# for i in range(0,len(symbol_seqs)):
#     if len(symbol_seqs[i]) >= seq_len:
#         seq_len_filt_ix.append(i)
#         mean_len_symbol_seqs.append(symbol_seqs[i][:seq_len])    

# mean_len_symbol_seqs

# seq_len_filt_ix = []
max_len_symbol_seqs = []
for i in range(0,len(symbol_seqs)):
    if len(symbol_seqs[i]) < seq_len:
        max_len_symbol_seqs.append(np.pad(symbol_seqs[i], (0, seq_len - len(symbol_seqs[i]))))
    else:
        max_len_symbol_seqs.append(symbol_seqs[i])

max_len_symbol_seqs

[array([13, 13, 13, 37, 10, 40, 13, 10, 40, 36, 12, 39, 36, 39, 39, 13,  3,
        27, 36, 12, 27, 27,  9, 40, 36, 12, 13,  4, 28,  1,  4,  1, 28, 28,
        31, 31, 28, 28,  1,  4,  1, 10, 10,  9, 10, 12, 10, 10, 13, 13, 10,
        10, 10, 37, 10,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 array([61,  1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0

In [14]:
# seq_len_filt_ix
# mean_len_traj_pairs = [traj_pairs[i] for i in range(0,len(traj_pairs)) if i in seq_len_filt_ix]
# len(mean_len_traj_pairs)
len(max_len_symbol_seqs)

523

In [15]:
# symbol_seq[:seq_len] for symbol_seq in symbol_seqs if len(symbol_seq) >= seq_len]
x = np.array(max_len_symbol_seqs)
# x = x.reshape((x.shape[0], len(x[0])))
# x = np.divide(x, 81.)
x

array([[13, 13, 13, ...,  0,  0,  0],
       [61,  1,  1, ...,  0,  0,  0],
       [28, 28,  0, ...,  0,  0,  0],
       ...,
       [ 9, 12, 12, ...,  0,  0,  0],
       [39,  3, 39, ...,  0,  0,  0],
       [36, 36, 36, ...,  0,  0,  0]])

In [16]:
x.shape

(523, 115)

## K-Means clustering

In [17]:
scores = []
for n_clusters in range(2, 11):
    kmeans = KMeans(n_clusters=n_clusters, n_init=20, n_jobs=8)
    labels = kmeans.fit_predict(x)
    scores.append(metrics.davies_bouldin_score(x, labels))
    
n_clusters = list(range(2, 11))[np.argmin(scores)]

In [18]:
scores

[1.438255709479031,
 1.7174148075131586,
 1.6770183490025343,
 1.9146860607227798,
 1.7970515630462363,
 1.708269518958564,
 1.8211063845407551,
 1.7561023059901448,
 1.9734761815683146]

In [19]:
n_clusters

2

In [20]:
kmeans = KMeans(n_clusters=n_clusters, n_init=20, n_jobs=8)

In [21]:
y_pred_kmeans = kmeans.fit_predict(x)

In [22]:
len(y_pred_kmeans)

523

In [26]:
# Save cluster labels to csv
with open("k-means_2_cluster_labels_QTC_C.csv", "a") as csv_file:
    for pred in y_pred_kmeans:
        pred = int(not(pred))
        csv_file.write(str(pred)+"\n")

In [29]:
plt.rcParams['figure.figsize'] = [20, 20]

In [27]:
y_pred_kmeans

array([0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1,
       1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1,
       0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1,
       1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1,
       1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
       0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1,

In [50]:
# sample_no = -1

In [33]:
for sample_no in range(0,len(max_len_symbol_seqs)):
    # Plot trajectory pairs
    # sample_no += 1

    # s_seqs = np.array([mean_len_symbol_seqs[sample_no]])
    # e_seqs = np.array([mean_len_symbol_seqs[sample_no]])

    # #     print(list(map(num_to_QTC_C, s_seqs[0])))

    # ll_pb = pbHMM.full_data_estimate(s_seqs, e_seqs)
    # # print(ll_pb)
    # prob_pb = np.exp(ll_pb)

    # # print("The probability of the dataset being generated by the pass-by model is:", \
    # #       prob_pb, ".")

    # ll_CND = uniform_HMM.full_data_estimate(s_seqs, e_seqs)
    # # print(ll_CND)
    # prob_CND = np.exp(ll_CND)

    # # print("The probability of the dataset being generated by the CND model is:", \
    # #       prob_CND, ".")

    # ll_ot = otHMM.full_data_estimate(s_seqs, e_seqs)
    # # print(ll_ot)
    # prob_ot = np.exp(ll_ot)

    # # print("The probability of the dataset being generated by the overtake model is:", prob_ot, ".")

    # class_id = np.argmax([ll_pb, ll_ot, ll_CND])
    # print("Classified as", classes[class_id])

    # # Plot classified trajectory pair
    # traj_pair = traj_pairs[sample_no]
    # p_x = traj_pair[0].x.values
    # p_y = traj_pair[0].y.values
    # q_x = traj_pair[1].x.values
    # q_y = traj_pair[1].y.values

    # p_colour = np.array([[shade, 0, 0] for shade in np.linspace(0, 255, len(p_x))]) / 255.0
    # q_colour = np.array([[0, 0, shade] for shade in np.linspace(0, 255, len(q_x))]) / 255.0

    # plt.scatter(p_x, p_y, c=p_colour, zorder=10)
    # plt.plot(p_x, p_y, c="r")

    # plt.scatter(q_x, q_y, c=q_colour, zorder=10)
    # plt.plot(q_x, q_y, c="b")

    # plt.legend(["Human1", "Human2"])
    # plt.title("QTC_C model's " + classes[class_id] + " classification from ATC dataset")

    # Plot classified trajectory pair
    plt.rcParams.update({'font.size': 22})

    traj_pair = traj_pairs[sample_no]
    p_x = traj_pair[0].x.values
    p_y = traj_pair[0].y.values
    q_x = traj_pair[1].x.values
    q_y = traj_pair[1].y.values

    p_interp = interp1d(p_x, p_y)
    p_x = np.linspace(p_x[0], p_x[-1], len(p_x)*1000)
    #     p_x = np.linspace(np.min(p_x), np.max(p_x), len(p_x)*1000)
    p_y = p_interp(p_x)

    q_interp = interp1d(q_x, q_y)
    q_x = np.linspace(q_x[0], q_x[-1], len(q_x)*1000)
    #     q_x = np.linspace(np.min(q_x), np.max(q_x), len(q_x)*1000)
    q_y = q_interp(q_x)

    # Plot human1 trajectory
    # lwidths=1+p_x[:-1]
    # points = np.array([p_x, p_y]).T.reshape(-1, 1, 2)
    # segments = np.concatenate([points[:-1], points[1:]], axis=1)
    # lc = LineCollection(segments, linewidths=lwidths,color='blue')
    fig,a = plt.subplots()
    # a.add_collection(lc)
    # a.set_xlim(0,4*np.pi)
    # a.set_ylim(-1.1,1.1)

    # # Plot human2 trajectory
    # lwidths=1+(np.array(q_x)[:-1])[::-1]
    # points = np.array([q_x, q_y]).T.reshape(-1, 1, 2)
    # segments = np.concatenate([points[:-1], points[1:]], axis=1)
    # lc = LineCollection(segments, linewidths=lwidths, color='blue')

    # a.add_collection(lc)
    # a.legend(["Human1", "Human2"])

    # fig.show()

    p_colour = np.array([[shade, 0, 0] for shade in np.linspace(50, 255, len(p_x))]) / 255.0
    q_colour = np.array([[0, 0, shade] for shade in np.linspace(50, 255, len(q_x))]) / 255.0

    p_size = np.linspace(10, 100, len(p_x))
    q_size = np.linspace(10, 100, len(q_x))
    #     p_size = [i / 500 for i in range(0, len(p_x))]
    #     q_size = [i / 500 for i in range(0, len(q_x))]

    a.scatter(p_x, p_y, c=p_colour, zorder=10, s=p_size)
    a.plot(p_x, p_y, c="r")

    a.scatter(q_x, q_y, c=q_colour, zorder=10, s=p_size)
    a.plot(q_x, q_y, c="b")

    a.legend(["Human1", "Human2"])
    plt.title("Sequence "+ str(sample_no)+": cluster " + str(y_pred_kmeans[sample_no]))
    plt.savefig("plots/clusters/"+ str(y_pred_kmeans[sample_no]) + "/" + str(sample_no) + ".png")
    plt.close()
        
    print(sample_no + 1, "/", len(traj_pairs))

1 / 523
2 / 523
3 / 523
4 / 523
5 / 523
6 / 523
7 / 523
8 / 523
9 / 523
10 / 523
11 / 523
12 / 523
13 / 523
14 / 523
15 / 523
16 / 523
17 / 523
18 / 523
19 / 523
20 / 523
21 / 523
22 / 523
23 / 523
24 / 523
25 / 523
26 / 523
27 / 523
28 / 523
29 / 523
30 / 523
31 / 523
32 / 523
33 / 523
34 / 523
35 / 523
36 / 523
37 / 523
38 / 523
39 / 523
40 / 523
41 / 523
42 / 523
43 / 523
44 / 523
45 / 523
46 / 523
47 / 523
48 / 523
49 / 523
50 / 523
51 / 523
52 / 523
53 / 523
54 / 523
55 / 523
56 / 523
57 / 523
58 / 523
59 / 523
60 / 523
61 / 523
62 / 523
63 / 523
64 / 523
65 / 523
66 / 523




67 / 523
68 / 523
69 / 523
70 / 523
71 / 523
72 / 523
73 / 523
74 / 523
75 / 523
76 / 523
77 / 523
78 / 523
79 / 523
80 / 523
81 / 523
82 / 523
83 / 523
84 / 523
85 / 523
86 / 523
87 / 523
88 / 523
89 / 523
90 / 523
91 / 523
92 / 523
93 / 523
94 / 523
95 / 523
96 / 523
97 / 523
98 / 523
99 / 523
100 / 523
101 / 523
102 / 523
103 / 523
104 / 523
105 / 523
106 / 523
107 / 523
108 / 523
109 / 523
110 / 523
111 / 523
112 / 523
113 / 523
114 / 523
115 / 523
116 / 523
117 / 523
118 / 523
119 / 523
120 / 523
121 / 523
122 / 523
123 / 523
124 / 523
125 / 523
126 / 523
127 / 523
128 / 523
129 / 523
130 / 523
131 / 523
132 / 523
133 / 523
134 / 523
135 / 523
136 / 523
137 / 523
138 / 523
139 / 523
140 / 523
141 / 523
142 / 523
143 / 523
144 / 523
145 / 523
146 / 523
147 / 523
148 / 523
149 / 523
150 / 523
151 / 523
152 / 523
153 / 523
154 / 523
155 / 523
156 / 523
157 / 523
158 / 523
159 / 523
160 / 523
161 / 523
162 / 523
163 / 523
164 / 523
165 / 523
166 / 523
167 / 523
168 / 523
169 / 523
170