In [1]:
%cd '/scratch/sk7898/deep_radar'
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from keras.models import Model
from sklearn.cluster import KMeans, SpectralClustering
from sklearn import metrics
from data import get_fft_data
from clustering_utils import *

/scratch/sk7898/deep_radar


Using TensorFlow backend.


In [2]:
def get_embeddings(model_path, data_dir, cls, layer_name, pca_train):
    """Return intermediate-layer activations for the train and test splits.

    Parameters
    ----------
    model_path : path of the saved model passed to ``load_model``.
    data_dir   : directory handed to ``get_fft_data``.
    cls        : forwarded to ``get_fft_data`` as ``sel_cls``.
    layer_name : name of the layer whose output is used as the embedding.
    pca_train  : when truthy, both embeddings are passed through
                 ``get_pca_comps`` before being returned.

    Returns
    -------
    (train_embeddings, test_embeddings, old_y, old_y_test), where the two label
    arrays are the flattened 5th and 6th values returned by ``get_fft_data``.
    """
    fft_splits = get_fft_data(data_dir, sel_cls=cls, data_mode='amp')
    X, X_test, y, y_test, old_y, old_y_test, _, _ = fft_splits
    old_y, y_test, old_y_test = (arr.flatten() for arr in (old_y, y_test, old_y_test))

    # Truncate the trained network at `layer_name` so predict() yields embeddings.
    full_model = load_model(model_path)
    embedder = Model(inputs=full_model.input,
                     outputs=full_model.get_layer(layer_name).output)
    emb_train = embedder.predict(x=X)
    emb_test = embedder.predict(x=X_test)

    if pca_train:
        emb_train, emb_test = get_pca_comps(emb_train, emb_test)

    return emb_train, emb_test, old_y, old_y_test

In [3]:
def print_row(cls, cluster_id, n_samples,
              dist_sum=None, dist_mean=None, dist_max=None, 
              thresh_clusters=None,
              end_cls=False):
    """Print one formatted line of the per-cluster statistics table.

    The row layout is selected by which optional values are supplied:

    * ``dist_mean`` given: label, cluster id, sample count, mean distance,
      max distance and ``thresh_clusters``.
    * ``dist_sum`` given and ``end_cls`` false: label, cluster id, sample
      count and the distance.
    * ``dist_sum`` given and ``end_cls`` true: a per-class summary framed by
      two dashed rules (``cluster_id`` is printed as the cluster count).
    * neither given: the text 'Unrecognized options for print!'.

    Values are tested with ``is not None`` rather than truthiness so that a
    distance of exactly 0 is still printed instead of being mistaken for a
    missing argument.
    """
    if dist_mean is not None:
        print(
            f'{cls: ^7}'
            f'{cluster_id: ^10}'
            f'{n_samples:^10}'
            f'{" ":<4} {dist_mean:.2f}'
            f'{" ":<4} {dist_max:.2f}'
            f'{" ":<8} {thresh_clusters}'
            )
    elif dist_sum is not None and not end_cls:
        print(
            f'{cls: ^7}'
            f'{cluster_id: ^10}'
            f'{n_samples:^10}'
            f'{" ":<4} {dist_sum:.2f}'
            )
    elif dist_sum is not None:
        # end_cls is true here: per-class summary block.
        print('--------------------------------------------------------')
        print(
            f'{cls: ^7}'
            f'{"#Clusters:"}{cluster_id: ^4}'
            f'{"#Samples:"}{n_samples:^8}'
            f'{"Avg Distance:"} {dist_sum:.2f}'
            )
        print('--------------------------------------------------------')
    else:
        print('Unrecognized options for print!')
    

In [4]:
def print_h_score_clusters(count_cls, 
                           h_clusters, h_modes, n_samples, 
                           cluster_dist_matrix,
                           mean_stats=False):
    """Print a table of distance statistics for the homogeneous clusters, grouped by class.

    Parameters
    ----------
    count_cls           : iterable of class labels to report on.
    h_clusters          : array of cluster ids (used to index rows of
                          ``cluster_dist_matrix``).
    h_modes             : array aligned with ``h_clusters`` giving each
                          cluster's class label.
    n_samples           : array aligned with ``h_clusters`` giving the sample
                          count of each cluster.
    cluster_dist_matrix : square cluster-to-cluster distance matrix.
    mean_stats          : if true print mean / max distance and the number of
                          clusters farther than the mean; otherwise print the
                          average distance plus a per-class summary row.

    Classes that own no cluster are skipped.  (The previous ``len(mask) > 0``
    test looked at the length of the boolean mask, which is never zero for a
    non-empty ``h_modes``, so such classes produced a summary row with a
    ``nan`` average.)
    """
    if mean_stats:
        print(f'{"Label": ^5} | {"Cluster Id": ^8} | {"Samples": ^8} |'
              f'{"Mean Dist": ^8} | {"Max Dist": ^8} | {"Clusters with Dist > Mean Dist"}\n')
    else:
        print(f'{"Label": ^5} | {"Cluster Id": ^8} | {"Samples": ^8} | {"Avg. Dist": ^10}') 

    n_clusters = cluster_dist_matrix.shape[0]
    for cls in count_cls:
        indices = h_modes == cls
        if not np.any(indices):
            continue

        cls_clusters = h_clusters[indices]
        cls_samples = n_samples[indices]
        # One row of distances (to every cluster) per cluster of this class.
        clust_dists = cluster_dist_matrix[cls_clusters]
        dist_sums = np.sum(clust_dists, axis=1)

        if mean_stats:
            dist_means = np.mean(clust_dists, axis=1)
            dist_maxs = np.max(clust_dists, axis=1)
            clust_with_thresh_dist = [len(clust[clust > dist])
                                      for clust, dist in zip(clust_dists, dist_means)]

        for i, (clust, smpls) in enumerate(zip(cls_clusters, cls_samples)):
            if mean_stats:
                # dist_means is already an average over the row; it must not be
                # divided by n_clusters again, otherwise the printed mean is out
                # of scale with both the max column and the threshold used for
                # the "clusters with dist > mean" count.
                print_row(cls, clust, smpls, 
                          dist_sum=dist_sums[i],
                          dist_mean=dist_means[i], 
                          dist_max=dist_maxs[i], 
                          thresh_clusters=clust_with_thresh_dist[i])
            else:
                print_row(cls, clust, smpls, 
                          dist_sum=dist_sums[i]/n_clusters)

        if not mean_stats:
            print_row(cls, len(cls_clusters), cls_samples.sum(), 
                      dist_sum=np.mean(dist_sums)/n_clusters, 
                      end_cls=True)

In [5]:
def get_clust_dist_matrix(cluster_cents, cluster_modes, n_clusters):
    """Return centroid-to-centroid Euclidean distances, scaled by class difference.

    Each pairwise centroid distance is multiplied by
    ``1 + |mode_i - mode_j|``, so clusters with the same mode keep their plain
    distance and clusters whose modes differ are pushed further apart.
    ``n_clusters`` is accepted for interface compatibility and is not used.
    """
    modes_col = cluster_modes
    if modes_col.ndim == 1:
        # Column vector so that subtracting its transpose broadcasts to a square matrix.
        modes_col = modes_col.reshape(-1, 1)

    mode_gap = np.abs(modes_col - modes_col.T)
    centroid_dists = metrics.pairwise.euclidean_distances(cluster_cents)

    return (1 + mode_gap) * centroid_dists

In [6]:
def get_homogeneity_score(labels, threshold, verbose=True):
    """Return the largest label share in ``labels`` if it reaches ``threshold``.

    The relative frequency of every distinct value in ``labels`` is computed;
    the function returns the maximum frequency among those that are
    ``>= threshold``, or ``0`` when none qualifies (including when ``labels``
    is empty).

    Parameters
    ----------
    labels    : array-like of labels for the members of one cluster.
    threshold : minimum share a label must reach to count.
    verbose   : when true (the default, matching the historical behaviour)
                the frequency vector is printed.  Pass ``False`` to silence
                the per-cluster output.
    """
    _, label_idx = np.unique(labels, return_inverse=True)
    # ravel(): bincount needs a 1-D input, and the shape of `return_inverse`
    # for multi-dimensional input differs between NumPy releases.
    # float instead of the removed `np.float` alias (gone since NumPy 1.24).
    counts = np.bincount(np.ravel(label_idx)).astype(float)
    counts = counts[counts > 0]
    probs = counts / np.sum(counts)

    if verbose:
        print(probs)

    scores_above_thresh = probs[probs >= threshold]
    return scores_above_thresh.max() if scores_above_thresh.size > 0 else 0

In [7]:
def X_kmeans(X, y, n_clusters, n_classes=4, random_state=None):
    """Fit k-means on ``X`` and summarise the resulting clusters.

    Parameters
    ----------
    X            : samples to cluster.
    y            : labels aligned with ``X``; forwarded to ``get_cluster_mode``.
    n_clusters   : number of k-means clusters.
    n_classes    : passed to ``KMeans`` as ``n_init`` (the number of
                   initialisations tried), as in the original code.
    random_state : seed forwarded to ``KMeans``.  The default ``None`` keeps the
                   previous non-deterministic behaviour; pass an int to make
                   cluster ids reproducible across runs.

    Returns
    -------
    (kmeans, cluster_labels, cluster_cents, cluster_modes): the fitted
    estimator, the predicted cluster id for every row of ``X``, the cluster
    centres, and the value returned by ``get_cluster_mode``.
    """
    kmeans = KMeans(init='k-means++',
                    n_clusters=n_clusters,
                    n_init=n_classes,
                    random_state=random_state)
    kmeans.fit(X)

    cluster_labels = kmeans.predict(X)
    cluster_cents = kmeans.cluster_centers_
    cluster_modes = get_cluster_mode(y, cluster_labels, n_clusters=n_clusters)

    return kmeans, cluster_labels, cluster_cents, cluster_modes

In [8]:
def ideal_homogeneous_clusters(y_true, cluster_labels, cluster_modes, n_clusters, homogeneity_threshold=0.8):
    """Collect the clusters whose homogeneity score is positive.

    Every cluster id in ``range(n_clusters)`` is scored with
    ``get_homogeneity_score`` on the true labels of its members; clusters with
    a score above zero are kept.

    Returns three aligned arrays: the scores, the cluster ids, and the
    corresponding entries of ``cluster_modes``.
    """
    scores, clusters, labels = [], [], []

    for cluster_id in range(n_clusters):
        members = cluster_labels.flatten() == cluster_id
        purity = get_homogeneity_score(y_true[members], homogeneity_threshold)
        if not purity > 0:
            continue
        clusters.append(cluster_id)
        labels.append(cluster_modes[cluster_id])
        scores.append(purity)

    return np.array(scores), np.array(clusters), np.array(labels)

In [9]:
def dist_check(to_add, to_remove, dists, threshold):
    """Tell whether every distance except the one at ``to_remove`` exceeds ``threshold``.

    ``dists`` is sliced around position ``to_remove`` and the minimum of what
    is left is compared with ``threshold``.  ``to_add`` is not used.
    """
    before, after = dists[:to_remove], dists[to_remove + 1:]
    remaining = np.concatenate([before, after])
    return bool(np.min(remaining) > threshold)

In [10]:
def ideal_dist_clusters(X,
                        h_clusters, 
                        cluster_labels,
                        cluster_modes, 
                        n_clusters,
                        dist_threshold=0):
    """Select the homogeneous clusters that are separated from other-class clusters.

    For every cluster in ``h_clusters`` the smallest sample-to-sample Euclidean
    distance to each homogeneous cluster with a *different* mode is computed.

    1. A cluster whose closest other-class sample is farther than
       ``dist_threshold`` is accepted; that distance is recorded for it.
    2. Otherwise every other-class cluster lying within ``dist_threshold`` is
       recorded as a violation ``(cluster, rival) -> closest distance``.
    3. For pairs that violate each other and of which neither has been
       accepted, one member is accepted if its partner is the only cluster
       violating it.  The first member is tried only when its class has fewer
       accepted clusters (counted after step 1) than the partner's class;
       otherwise, or if it fails, the partner is tried.

    Parameters
    ----------
    X              : sample matrix, rows aligned with ``cluster_labels``.
    h_clusters     : ids of the homogeneous clusters to consider.
    cluster_labels : cluster id of every row of ``X`` (any shape; flattened).
    cluster_modes  : indexable by cluster id, giving the cluster's class.
    n_clusters     : unused; kept for interface compatibility.
    dist_threshold : minimum required distance to other-class samples.

    Returns
    -------
    (ideal_clusters, X_min_dists): list of accepted cluster ids and, aligned
    with it, the distance recorded for each.  A cluster that has no
    other-class cluster to compare against is accepted with distance
    ``np.inf``; clusters without samples are skipped.

    Notes
    -----
    Fixes relative to the earlier implementation: column ranges of the
    distance matrix are now mapped to rival clusters with exact offsets (the
    old running "end index" drifted by one per cluster and raised ``KeyError``
    for column 0), and step 3 checks real per-pair distances instead of
    indexing the last loop iteration's sample-distance matrix by cluster id.
    """
    flat_labels = np.asarray(cluster_labels).flatten()
    members = {c: np.flatnonzero(flat_labels == c) for c in h_clusters}

    ideal_clusters, X_min_dists = [], []
    rival_min = {}    # cluster id -> {other-class cluster id: closest sample distance}
    violations = {}   # (cluster id, rival id) -> closest distance, when <= dist_threshold

    for clust_id in h_clusters:
        own = members[clust_id]
        if len(own) == 0:
            continue

        rivals = [c for c in h_clusters
                  if cluster_modes[c] != cluster_modes[clust_id] and len(members[c]) > 0]
        if not rivals:
            # Nothing of another class to be close to.
            rival_min[clust_id] = {}
            ideal_clusters.append(clust_id)
            X_min_dists.append(np.inf)
            continue

        # One distance matrix against all rival samples; columns of rival k
        # occupy the half-open range [starts[k], stops[k]).
        sizes = np.array([len(members[c]) for c in rivals])
        stops = np.cumsum(sizes)
        starts = stops - sizes
        rival_rows = np.concatenate([members[c] for c in rivals])
        dists = metrics.pairwise.euclidean_distances(X[own], X[rival_rows])

        per_rival = {rival: dists[:, lo:hi].min()
                     for rival, lo, hi in zip(rivals, starts, stops)}
        rival_min[clust_id] = per_rival
        min_dist = min(per_rival.values())

        if min_dist > dist_threshold:
            X_min_dists.append(min_dist)
            ideal_clusters.append(clust_id)
        else:
            for rival, closest in per_rival.items():
                if closest <= dist_threshold:
                    violations[(clust_id, rival)] = closest

    # Class representation among the clusters accepted so far (snapshot, as before).
    accepted_modes = np.asarray(cluster_modes)[ideal_clusters]

    def _n_accepted(mode):
        return int(np.sum(accepted_modes == mode))

    def _only_violated_by(cluster, partner):
        # True when every rival of `cluster` other than `partner` is beyond the threshold.
        return all(closest > dist_threshold
                   for rival, closest in rival_min[cluster].items()
                   if rival != partner)

    for (c1, c2), val in violations.items():
        if (c2, c1) not in violations or c1 in ideal_clusters or c2 in ideal_clusters:
            continue
        c1_underrepresented = _n_accepted(cluster_modes[c1]) < _n_accepted(cluster_modes[c2])
        if c1_underrepresented and _only_violated_by(c1, c2):
            ideal_clusters.append(c1)
            X_min_dists.append(val)
        elif _only_violated_by(c2, c1):
            ideal_clusters.append(c2)
            X_min_dists.append(val)

    return ideal_clusters, X_min_dists

In [11]:
def get_ideal_clusters(X, y_true, 
                       cluster_labels, 
                       n_clusters, 
                       cluster_cents, 
                       cluster_modes,
                       count_cls=[1, 2, 3, 4],
                       homogeneity_threshold=0.8,
                       dist_threshold=0,
                       mean_stats=False,
                       verbose=0):
    """Pick homogeneous clusters and, optionally, filter them by separation.

    Steps: build the weighted centroid distance matrix, select homogeneous
    clusters, optionally print their statistics (``verbose``), and, when
    ``dist_threshold`` is truthy, filter them with ``ideal_dist_clusters``.

    Returns a 5-tuple ``(ideal_clusters, h_clusters, h_scores, h_modes, dists)``.
    NOTE: the last element differs by mode - with a truthy ``dist_threshold``
    it is the list of minimum distances from ``ideal_dist_clusters``;
    otherwise ``ideal_clusters`` is just ``h_clusters`` and the last element
    is the full cluster distance matrix.
    """
    centroid_dists = get_clust_dist_matrix(cluster_cents, cluster_modes, n_clusters)
    h_scores, h_clusters, h_modes = ideal_homogeneous_clusters(
        y_true,
        cluster_labels,
        cluster_modes,
        n_clusters,
        homogeneity_threshold=homogeneity_threshold)
    samples_per_cluster = get_n_samples(h_clusters, cluster_labels)

    if verbose:
        print_h_score_clusters(count_cls, h_clusters, h_modes,
                               samples_per_cluster, centroid_dists,
                               mean_stats=mean_stats)

    if not dist_threshold:
        return h_clusters, h_clusters, h_scores, h_modes, centroid_dists

    separated, min_dists = ideal_dist_clusters(X,
                                               h_clusters,
                                               cluster_labels,
                                               cluster_modes,
                                               n_clusters,
                                               dist_threshold=dist_threshold)
    return separated, h_clusters, h_scores, h_modes, min_dists

In [12]:
def relabel_points(X_subset, y_subset, 
                   non_h_indexes,
                   ideal_clusters, 
                   cluster_cents, 
                   cluster_modes,
                   dist_diff_thresh=None):
    """Relabel samples with the class of their nearest ideal-cluster centroid.

    For each sample the nearest centroid among ``ideal_clusters`` is found.
    The sample takes that cluster's class when the ratio

        (distance to nearest centroid of a *different* class)
        / (distance to nearest centroid)

    is greater than ``dist_diff_thresh``.

    Parameters
    ----------
    X_subset         : samples to relabel, one per row.
    y_subset         : current labels of those samples (not modified).
    non_h_indexes    : unused; kept for interface compatibility.
    ideal_clusters   : ids of the reference clusters.
    cluster_cents    : centroid matrix indexable by cluster id.
    cluster_modes    : class of each cluster, indexable by cluster id.
    dist_diff_thresh : ratio threshold.  ``None`` (the default) disables the
                       ratio test, so every sample takes its nearest cluster's
                       class; previously ``None`` raised ``TypeError``.

    Returns
    -------
    (new_y, was_changed): a copy of ``y_subset`` with accepted relabels
    applied, and an int array with 1 where the ratio test passed.

    Edge cases: with no samples or no ideal clusters nothing is relabelled.
    If no ideal cluster has a different class, the ratio is treated as
    infinite.  A sample lying exactly on its nearest centroid yields an
    infinite ratio (or NaN, which fails the test, when the other-class
    distance is zero too).
    """
    new_y = np.array(y_subset)
    ideal_idx = np.asarray(ideal_clusters, dtype=int)
    if len(new_y) == 0 or ideal_idx.size == 0:
        return new_y, np.zeros(len(new_y), dtype='int')

    ideal_modes = np.ravel(np.asarray(cluster_modes))[ideal_idx]
    dists = metrics.pairwise.euclidean_distances(X_subset, cluster_cents[ideal_idx])

    nearest_pos = np.argmin(dists, axis=1)
    nearest_dist = dists[np.arange(dists.shape[0]), nearest_pos]
    nearest_cls = ideal_modes[nearest_pos]

    # Distance to the closest centroid of any other class; inf when none exists.
    other_class = ideal_modes[np.newaxis, :] != nearest_cls[:, np.newaxis]
    rival_dist = np.where(other_class, dists, np.inf).min(axis=1)

    if dist_diff_thresh is None:
        confident = np.ones(len(new_y), dtype=bool)
    else:
        with np.errstate(divide='ignore', invalid='ignore'):
            confident = (rival_dist / nearest_dist) > dist_diff_thresh

    new_y[confident] = nearest_cls[confident]
    was_changed = confident.astype('int')

    return new_y, was_changed

In [13]:
def get_cluster_samples(X_train, y_train,
                        ideal_clusters,
                        h_clusters,
                        cluster_labels):
    """Split the training set into samples inside / outside the ideal clusters.

    Parameters
    ----------
    X_train        : sample matrix.
    y_train        : labels aligned with ``X_train``.
    ideal_clusters : cluster ids whose members form the "inside" subset.
    h_clusters     : unused; kept for interface compatibility.
    cluster_labels : cluster id per sample (any shape; flattened internally).

    Returns
    -------
    (X_subset, c_labels, X_non_h_subset, y_non_h_subset, indexes, non_h_indexes)

    * ``indexes``: list of row indices belonging to ``ideal_clusters``,
      grouped in the order the clusters are listed (ascending within each).
    * ``X_subset`` / ``c_labels``: those rows of ``X_train`` and their
      cluster ids.
    * ``non_h_indexes``: ascending list of all remaining row indices, with the
      matching ``X_non_h_subset`` / ``y_non_h_subset``.
    """
    flat_labels = np.asarray(cluster_labels).reshape(-1)

    indexes = []
    for clust_id in ideal_clusters:
        indexes.extend(np.flatnonzero(flat_labels == clust_id))

    X_subset = X_train[indexes]
    c_labels = flat_labels[indexes]

    # Boolean membership mask instead of `i not in indexes` on a list, which
    # was quadratic in the number of samples.
    in_ideal = np.zeros(X_train.shape[0], dtype=bool)
    in_ideal[indexes] = True
    non_h_indexes = np.flatnonzero(~in_ideal).tolist()

    X_non_h_subset = X_train[non_h_indexes]
    y_non_h_subset = y_train[non_h_indexes]

    return X_subset, c_labels, X_non_h_subset, y_non_h_subset, indexes, non_h_indexes

In [14]:
def _annotate_centroids(ax, centroids, texts, boxed):
    """Write ``texts[i]`` at the first two coordinates of ``centroids[i]`` on ``ax``."""
    if boxed:
        style = dict(bbox=dict(facecolor='red', alpha=0.5),
                     size='large', color='white', weight='bold')
    else:
        style = dict(size='large', color='black', weight='semibold')

    for cent, text in zip(centroids, texts):
        ax.text(cent[0], cent[1], text, horizontalalignment='left', **style)


def plots(axs, X, y,
          row_plt, 
          col_idx,
          row_idx=None,
          palette=None,
          param_str=None,
          X_changed=None,
          ideal_clusters=None,
          plot_cents=None,
          annotation=None):
    """Draw one scatter panel of the first two columns of ``X`` on a subplot grid.

    Parameters
    ----------
    axs            : array of Axes; indexed ``axs[row_idx, col_idx]`` or, when
                     ``row_idx`` is None, ``axs[col_idx]``.
    X, y           : points (first two columns are plotted) and hue values.
    row_plt        : panel type -
                     'pca'         coloured scatter with a full legend;
                     'samples'     coloured scatter, ``annotation[i]`` written
                                   at ``plot_cents[i]``;
                     'dist_thresh' as 'samples' but the annotation is rounded
                                   to one decimal and drawn in a red box;
                     'relabel'     ``X`` in black, ``X_changed`` coloured by
                                   ``y``, and ``annotation[ideal_clusters[i]]``
                                   written at ``plot_cents[i]``.
                     Any other value leaves the axes without data.
    palette        : seaborn palette; defaults to "bright" with one colour per
                     distinct value of ``y``.
    param_str      : title, applied only on the first row (or when
                     ``row_idx`` is None).

    Raises
    ------
    ValueError if an argument required by the chosen panel type is missing.

    The x/y data are passed to ``sns.scatterplot`` by keyword: seaborn 0.12+
    no longer accepts them positionally.
    """
    c_palette = palette if palette is not None else sns.color_palette("bright", len(np.unique(y)))
    ax = axs[row_idx, col_idx] if row_idx is not None else axs[col_idx]
    # Fixed x-range shared by all panels.
    ax.set_xlim([-20, 20])

    if (row_idx == 0 or row_idx is None) and param_str:
        ax.set_title(param_str)

    if row_plt == 'pca':
        sns.scatterplot(x=X[:, 0], y=X[:, 1],
                        hue=y,
                        legend='full',
                        palette=c_palette,
                        ax=ax)

    if row_plt == 'samples':
        if plot_cents is None or annotation is None:
            raise ValueError('Missing list of number of samples to plot')

        sns.scatterplot(x=X[:, 0], y=X[:, 1],
                        hue=y,
                        legend=False,
                        palette=c_palette,
                        ax=ax)
        _annotate_centroids(ax, plot_cents,
                            [annotation[i] for i in range(len(plot_cents))],
                            boxed=False)

    if row_plt == 'dist_thresh':
        if plot_cents is None or annotation is None:
            raise ValueError('Missing list of number of samples to plot')

        sns.scatterplot(x=X[:, 0], y=X[:, 1],
                        hue=y,
                        legend=False,
                        palette=c_palette,
                        ax=ax)
        _annotate_centroids(ax, plot_cents,
                            [round(annotation[i], 1) for i in range(len(plot_cents))],
                            boxed=True)

    if row_plt == 'relabel':
        if (plot_cents is None or annotation is None
                or ideal_clusters is None or X_changed is None):
            raise ValueError('Missing list of number of samples to plot')

        # Reference samples in black, relabelled samples coloured on top.
        sns.scatterplot(x=X[:, 0], y=X[:, 1],
                        legend=False,
                        color='black',
                        ax=ax)
        sns.scatterplot(x=X_changed[:, 0], y=X_changed[:, 1],
                        hue=y,
                        legend=False,
                        palette=c_palette,
                        ax=ax)
        _annotate_centroids(ax, plot_cents,
                            [annotation[ideal_clusters[i]] for i in range(len(plot_cents))],
                            boxed=True)

In [15]:
# --- Experiment configuration -------------------------------------------------
# `regression` selects which trained model (and which of its layers) supplies
# the embeddings: the LSTM model when true, the upstream 'Bumblebee' model
# otherwise.
regression = True
sel_cls = [1, 2, 3, 4]
n_classes = len(sel_cls)

radar_dir = '/scratch/sk7898/radar_data/pedbike'
data_dir = os.path.join(radar_dir, 'regression_fft_data')

if not regression:
    radar_type = 'Bumblebee'
    model_dir = os.path.join(radar_dir, 'models/upstream', radar_type)
    layer_name = 'dense_1'
    model_path = os.path.join(model_dir, 'best_model.h5')
else:
    cls_str = '1_2_3_4'
    model_dir = os.path.join(radar_dir, 'models/lstm')
    layer_name = 'counting_dense_2'
    # Path of the checkpoint relative to model_dir.
    model_str = cls_str + '_amp_512_hidden_128/model_best_valid_loss_dp_4.h5'
    model_path = os.path.join(model_dir, model_str)

In [16]:
# --- Driver: embed, cluster, pick ideal clusters, relabel, (optionally) plot ---
n_clusters, homogeneity_threshold, dist_threshold, diff_thresh = 60, 0.98, 0, 1.5
cluster_list = [60]
#h_thresh_list = [0.98, 0.99, 1]
#dist_thresh_list = [0.5, 1.5, 2.5, 3.5]
plot_row_list = ['samples'] #'dist_thresh', 'relabel']

pca_train = False
plot_pca = False
plot_ideal = True
plot_rows = len(plot_row_list)
# One column per clustering run, plus a leading column for the PCA panel when
# it is requested (it is drawn at column 0 and the runs start at column 1).
plot_cols = len(cluster_list) + (1 if plot_pca else 0)
plot_height = 7 * plot_rows
plot_width = 9 * plot_cols

X_train, X_test, old_y, old_y_test = get_embeddings(model_path, 
                                                    data_dir,
                                                    cls=sel_cls, 
                                                    layer_name=layer_name, 
                                                    pca_train=pca_train)
print(X_train.shape, X_test.shape)

# Plotting is only done when the embeddings were reduced with PCA.
if pca_train:
    sns.set(rc={'figure.figsize':(plot_width, plot_height)})
    # squeeze=False keeps `axs` two-dimensional even for a 1x1 or 1xN grid, so
    # panels are always addressed as axs[row, col].
    fig, axs = plt.subplots(nrows=plot_rows, ncols=plot_cols, sharey=True, squeeze=False)

    if plot_pca:
        plots(axs, X_train, old_y, 
              row_plt='pca', 
              param_str='X_train (PCA)',
              row_idx=0, col_idx=0)

for idx, n_clusters in enumerate(cluster_list): 
    kmeans, cluster_labels, cluster_cents, cluster_modes = X_kmeans(X_train, old_y,
                                                                    n_clusters=n_clusters,
                                                                    n_classes=n_classes)

    ideal_clusters, h_clusters, scores, modes, dists = get_ideal_clusters(X_train, 
                                                                          old_y, 
                                                                          cluster_labels,
                                                                          n_clusters,
                                                                          cluster_cents,
                                                                          cluster_modes,
                                                                          count_cls=sel_cls,
                                                                          homogeneity_threshold=homogeneity_threshold,
                                                                          dist_threshold=dist_threshold)

    print(ideal_clusters, h_clusters, scores)

    # NOTE(review): the 5th value returned by get_cluster_samples holds the
    # sample indices of `ideal_clusters`, although it is named h_indexes here.
    X_subset, c_labels, X_non_h_subset, y_non_h_subset, h_indexes, non_h_indexes = get_cluster_samples(X_train, old_y,
                                                                                                       ideal_clusters,
                                                                                                       h_clusters,
                                                                                                       cluster_labels)
    new_y, was_changed = relabel_points(X_non_h_subset, y_non_h_subset, 
                                        non_h_indexes, 
                                        ideal_clusters, 
                                        cluster_cents, 
                                        cluster_modes, 
                                        dist_diff_thresh=diff_thresh)

    changed_mask = was_changed == 1
    X_changed = X_non_h_subset[changed_mask]
    # Labels assigned by relabel_points (previously the old labels were taken
    # here, so `new_y` was never used and the relabel panel showed old classes).
    y_changed = new_y[changed_mask]

    param_str = f'n_clusters:{n_clusters} {"|"} h_thresh:{homogeneity_threshold}'
    if dist_threshold:
        param_str += f' {"|"} d_thresh:{dist_threshold}'
    if diff_thresh:
        param_str += f' {"|"} diff_thresh:{diff_thresh}'

    if pca_train:
        for i, row_plt in enumerate(plot_row_list):
            row_idx = i
            col_idx = idx+1 if plot_pca else idx
            plot_cents = [cluster_cents[clust_id] for clust_id in ideal_clusters]

            if row_plt == 'samples':
                # NOTE(review): with plot_ideal False the annotations are counted
                # for h_clusters while plot_cents and h_indexes still refer to
                # ideal_clusters - confirm this is intended.
                clusters, X, y = (ideal_clusters, X_subset, c_labels) if plot_ideal else (h_clusters, X_train[h_indexes], cluster_labels[h_indexes])
                n_samples = get_n_samples(clusters, cluster_labels.flatten())
                plots(axs, X, y.flatten(),
                      row_plt,
                      param_str=param_str,
                      plot_cents=plot_cents,
                      annotation=n_samples,
                      row_idx=row_idx, col_idx=col_idx)

            if row_plt == 'dist_thresh':
                # NOTE(review): `dists` is the per-cluster minimum-distance list
                # only when dist_threshold is truthy; otherwise it is the full
                # cluster distance matrix.
                plots(axs, X_subset, c_labels,
                      row_plt,
                      param_str=param_str,
                      plot_cents=plot_cents,
                      annotation=dists,
                      row_idx=row_idx, col_idx=col_idx)

            if row_plt == 'relabel': 
                # Count samples whose label actually differs from the original;
                # len() of the comparison only gave the array length.
                n_relabeled = int(np.sum(y_changed != y_non_h_subset[changed_mask]))
                print(f'Relabeled Samples: {n_relabeled}')
                plots(axs, X_subset, y_changed,
                      row_plt,
                      param_str=param_str,
                      X_changed=X_changed,
                      plot_cents=plot_cents,
                      annotation=cluster_modes,
                      ideal_clusters=ideal_clusters,
                      row_idx=row_idx, col_idx=col_idx)

(16017, 64) (1780, 64)
[0.64895636 0.25616698 0.06641366 0.028463  ]
[0.00735294 0.99264706]
[0.00564972 0.94350282 0.05084746]
[0.95752896 0.02702703 0.01544402]
[1.]
[0.00406504 0.98780488 0.00813008]
[0.10638298 0.89361702]
[0.02766798 0.21343874 0.72727273 0.03162055]
[0.00483092 0.0531401  0.94202899]
[0.99009901 0.00990099]
[0.00584795 0.91520468 0.07894737]
[0.0026738  0.94385027 0.02941176 0.02406417]
[0.04487179 0.89102564 0.06410256]
[0.92995169 0.0410628  0.02898551]
[0.18479685 0.31323722 0.20314548 0.29882045]
[0.98901099 0.01098901]
[0.02564103 0.97435897]
[0.02339181 0.96491228 0.01169591]
[0.98958333 0.01041667]
[0.18402778 0.81597222]
[0.97452229 0.02547771]
[0.01408451 0.94014085 0.04577465]
[1.]
[0.00990099 0.00990099 0.01485149 0.96534653]
[0.01857143 0.03571429 0.84       0.10571429]
[0.51476793 0.17299578 0.31223629]
[1.]
[0.03603604 0.94594595 0.01801802]
[0.03351955 0.96089385 0.00558659]
[0.01639344 0.0273224  0.95628415]
[0.04552846 0.63414634 0.24390244 0.076