In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import os
from PIL import Image
import yaml
from utils import *
from helpers import *
import pickle

import pdb
from tqdm import tqdm

from dlclive import DLCLive, Processor

In [None]:
class DataManager:
    """Collects per-frame values under fixed keys and saves/loads them as ``.npy`` files.

    ``None`` may be added as a placeholder for a frame without a value; such
    placeholders are back-filled with NaN arrays as soon as a real value (and
    therefore its shape) is added for the same key (see ``add_nan``).
    """

    def __init__(self, keys=None):
        """Create an empty value list for each key in ``keys`` (no keys when None)."""
        # None instead of a shared mutable ``[]`` default.
        if keys is None:
            keys = []
        self.data = {key: [] for key in keys}

    def _print_key_error(self, key):
        print(f"Key '{key}' does not exist in the data manager.")

    def add_nan(self, key, shape):
        """Replace the trailing ``None`` placeholders of ``key`` with NaN arrays of ``shape``."""
        data_ = self.data[key]
        pop_count = 0
        while len(data_) > 0 and data_[-1] is None:
            data_.pop()
            pop_count += 1
        for _ in range(pop_count):
            data_.append(np.full(shape, np.nan))

    def add(self, key, value):
        """Append ``value`` (or a ``None`` placeholder) to ``key``.

        A non-None value first back-fills any pending placeholders with NaN
        arrays of its own shape. Unknown keys only print a message.
        """
        if key in self.data:
            if value is not None:
                self.add_nan(key, value.shape)
            self.data[key].append(value)
        else:
            self._print_key_error(key)

    def to_numpy(self):
        """Convert each key's list into a stacked numpy array, in place.

        A key with no entries, or whose first entry is still ``None`` (only
        possible when every entry is ``None``, because adding a real value
        back-fills earlier placeholders), is set to ``None``.
        """
        for key in self.data:
            data_ = self.data[key]
            if len(data_) == 0 or data_[0] is None:
                self.data[key] = None
                # Skip only this key; the remaining keys still need converting.
                continue
            # Fill placeholders left after the last real value.
            self.add_nan(key, data_[0].shape)
            self.data[key] = np.array(data_)

    def save(self, dir):
        """Write each key to ``<dir>/<key>.npy``.

        A key whose value is ``None`` is written as a 0-d object array, which
        ``load`` (``np.load`` with its default ``allow_pickle=False``) rejects
        with ``ValueError``.
        """
        for key in self.data:
            np.save(f'{dir}/{key}.npy', self.data.get(key))

    def load(self, dir, keys):
        """Load ``<dir>/<key>.npy`` for each key; all current keys when ``keys`` is None.

        Returns:
            The internal ``data`` dict (updated in place).

        Raises:
            FileNotFoundError: a key's file is missing.
            ValueError: a file holds an object array (e.g. a key saved as None).
        """
        if keys is None:
            # Snapshot the keys: the loop below assigns into self.data.
            keys = list(self.data.keys())
        for key in keys:
            self.data[key] = np.load(f'{dir}/{key}.npy')
        return self.data

    def remove(self, key):
        """Delete ``key`` and its values; unknown keys only print a message."""
        if key in self.data:
            del self.data[key]
        else:
            self._print_key_error(key)

    def get(self, key):
        """Return the values stored for ``key``, or None (with a message) if unknown."""
        if key in self.data:
            return self.data[key]
        else:
            self._print_key_error(key)
            return None

In [None]:
class DLCManager:
    """Runs DLCLive pose inference frame by frame and records per-frame pose features."""

    def __init__(self,
            model_path,
            processor=None,
            pcutoff=0.2,
            resize=1):
        """Build the DLCLive model.

        Args:
            model_path: forwarded to ``DLCLive(model_path=...)``.
            processor: DLCLive ``Processor``; a fresh ``Processor()`` is created
                for this instance when None.
            pcutoff, resize: forwarded unchanged to ``DLCLive``.
        """
        if processor is None:
            # A ``Processor()`` default argument would be built once at class
            # definition time and shared by every DLCManager instance.
            processor = Processor()

        self.model = DLCLive(
            model_path=model_path,
            processor=processor,
            pcutoff=pcutoff,
            # display=True,
            resize=resize)
        self.frame = None
        self.is_first_frame = None
        self.dm = DataManager()
        # Defined here as well so detect_pose works even before init_data.
        self.prev_pose_xy = None

    def init_data(self, feature_keys):
        """Start a fresh DataManager with ``feature_keys`` and forget the previous pose."""
        self.dm = DataManager(feature_keys)
        self.prev_pose_xy = None

    def update_frame(self, frame, is_first_frame):
        """Set the frame that the next ``detect_pose`` call will process."""
        self.frame = frame
        self.is_first_frame = is_first_frame

    def detect_pose_helper(self):
        """Return the raw pose for the current frame.

        The first frame goes through ``init_inference``; later frames use ``get_pose``.
        """
        if self.is_first_frame:
            pose = self.model.init_inference(self.frame)
        else:
            pose = self.model.get_pose(self.frame)
        return pose

    def detect_pose(self):
        """Run inference on the current frame and append xy, p, angles and speed to ``dm``.

        The last column of the pose is stored as ``p`` and the other columns as
        ``xy`` (presumably per-keypoint x, y, likelihood; confirm against
        DLCLive). Angles and speed come from ``extract_pose_features`` applied
        to the previous and current xy. On the first call, the previous pose is
        the current one.
        """
        curr_pose = self.detect_pose_helper()

        curr_pose_xy, curr_pose_p = curr_pose[:, :-1], curr_pose[:, -1]
        if self.prev_pose_xy is None:
            self.prev_pose_xy = curr_pose_xy

        pose = np.stack([self.prev_pose_xy, curr_pose_xy])
        feature_angles_item, pose_speed_item = extract_pose_features(pose)

        self.dm.add('xy', curr_pose_xy)
        self.dm.add('p', curr_pose_p)
        self.dm.add('angles', feature_angles_item)
        self.dm.add('speed', pose_speed_item)

        self.prev_pose_xy = curr_pose_xy

    def save_data(self, dir):
        """Stack the accumulated features and write them as ``.npy`` files into ``dir``."""
        self.dm.to_numpy()
        self.dm.save(dir)

    def load_data(self, dir, feature_keys):
        """Load ``feature_keys`` from ``dir``; returns the DataManager's data dict."""
        return self.dm.load(dir, feature_keys)

In [None]:
class Driver:
    """Runs DLC pose extraction over a set of videos and aggregates the per-video features.

    Configuration comes from ``configs/<config_file>.yaml``. The keys used are
    ``path.root``, ``path.video``, ``path.xls``, ``path.working``,
    ``path.model`` and ``info.fps``. When ``path.xls`` is set, the spreadsheet
    (filtered to ``'Stim Method' == 'Electrical'``) supplies the file names and
    labels. Otherwise every file in the video directory is used, unlabelled.
    """

    def __init__(self, config_file, dev_mode = False, verbose = True):
        """Load the config and set up file lists, the DLC model and timing constants.

        Args:
            config_file: config name, resolved as ``configs/<config_file>.yaml``.
            dev_mode: if True, keep only the third video (index 2) for quick runs.
            verbose: if True, print how many videos will be processed.
        """
        with open(f'configs/{config_file}.yaml', 'r') as file:
            config = yaml.safe_load(file)
        self.is_metadata_present = (config['path']['xls'] is not None)
        self.fps = config['info']['fps']
        root_dir = config['path']['root']
        self.video_dir = f"{root_dir}/{config['path']['video']}"

        # Always defined, so later code can test it instead of raising
        # AttributeError when no spreadsheet is configured.
        self.files_info = None
        if self.is_metadata_present:
            xls_path = f"{root_dir}/{config['path']['xls']}"
            self.files_info = read_octopus_xlsx(xls_path)
            self.files_info = self.files_info[self.files_info['Stim Method'] == 'Electrical']
            self.filenames = self.files_info["File Name"].to_list()
        else:
            self.filenames = [os.path.splitext(file)[0] for file in os.listdir(self.video_dir)]

        self.working_dir = f"{root_dir}/{config['path']['working']}" # to save processed data and figures

        model_path = f"{root_dir}/{config['path']['model']}"

        self.feature_keys = ['xy', 'p', 'angles', 'speed']
        self.dlc = DLCManager(model_path)

        self.prosvd_k = 4 # no. of dims to reduce to

        self.time_margin = (-120, 180) # to trim videos
        self.total_f = self.time_margin[1] - self.time_margin[0]

        self.init_frame_crop = 10 # No of initial frames used to set cropping info
        self.init_frame_prosvd = 90 # No of initial frames used to initialize proSVD
        self.init_frame = self.init_frame_crop + self.init_frame_prosvd

        # Time axis in seconds for the trimmed window that post_process keeps.
        self.tx = np.arange(self.init_frame + self.time_margin[0], self.time_margin[1])/self.fps

        # stim_class labels are stored as indices into this list (see post_process),
        # so it is a plain list in both branches.
        if self.is_metadata_present:
            self.stim_class_list = sorted(self.files_info['Stimulation Class'].unique().tolist())
        else:
            self.stim_class_list = ['Cord Electrical', 'Distal Electrical', 'Proximal Electrical']

        if dev_mode:
            if self.files_info is not None:
                # Keep files_info and filenames row-aligned: post_process
                # indexes files_info by position in filenames.
                self.files_info = self.files_info.iloc[[2]]
                self.filenames = self.files_info["File Name"].to_list()
            else:
                self.filenames = self.filenames[2:3]

        if verbose:
            print(f"Processing {len(self.filenames)} videos from {self.video_dir}")
            if len(self.filenames) < 4:
                print('\t', end='')
                print(*self.filenames, sep="\n\t")

    def get_fig_dir(self, filename):
        """Return ``<working_dir>/figs``, creating it if needed.

        NOTE(review): ``filename`` is currently unused; all figures share one directory.
        """
        figs_dir = f"{self.working_dir}/figs"
        os.makedirs(figs_dir, exist_ok=True)
        return figs_dir

    def get_data_dir(self, filename=None):
        """Return ``working_dir`` or ``<working_dir>/data/<filename>``, creating it if needed."""
        if filename is None:
            data_dir = f"{self.working_dir}"
        else:
            data_dir = f"{self.working_dir}/data/{filename}"
        os.makedirs(data_dir, exist_ok=True)
        return data_dir

    def load_video(self, video_idx):
        """Open video ``video_idx`` and reset per-video state. Return True on success."""
        video_filename = self.filenames[video_idx]

        self.video = None
        try:
            # Module-level load_video helper (not this method).
            self.video = load_video(self.video_dir, video_filename)
        except Exception:
            # Any failure to open means "skip this video"; KeyboardInterrupt /
            # SystemExit still propagate.
            return False

        self.index = -1

        self.dlc.init_data(feature_keys=self.feature_keys)

        return (self.video is not None)

    def is_video_empty(self):
        """True when no video is loaded or the loaded video is not open."""
        return self.video is None or not self.video.isOpened()

    def read_video(self):
        """Read the next frame, pass it to the DLC manager and return the read status."""
        ret, self.frame = self.video.read()
        self.index += 1
        self.dlc.update_frame(self.frame, self.index == 0)
        return ret

    def release_video(self):
        """Release the currently loaded video."""
        self.video.release()

    def detect_pose(self):
        """Run pose detection on the current frame."""
        self.dlc.detect_pose()

    def save_data(self, video_idx):
        """Save the accumulated features of video ``video_idx`` under its data directory."""
        video_filename = self.filenames[video_idx]
        data_dir = self.get_data_dir(video_filename)
        self.dlc.save_data(data_dir)

    def post_process(self, verbose=False):
        """Assemble the saved per-video features into ``<working_dir>/features.pkl``.

        For each video, this loads its ``.npy`` features and trims them to the
        frame window from ``load_metadata_new`` (start offset by
        ``init_frame``). It then attaches the movement / stimulation labels
        (0 when there is no metadata). Finally it pickles one DataFrame with
        MultiIndex columns under ``metadata`` / ``features`` / ``labels``.
        Videos whose features fail to load are skipped and listed when
        ``verbose`` is True. Any other exception propagates to the caller.
        """
        err_log = {
            'poor_pose': [],
            'file_missing': []
        }

        columns = ['filename', *self.feature_keys, 'move_class', 'stim_class']
        # Collect rows and build the frame once: DataFrame.append was removed
        # in pandas 2.0, and appending row by row is quadratic anyway.
        records = []

        for video_idx, video_filename in enumerate(self.filenames):
            data_dir = self.get_data_dir(video_filename)

            try:
                features = self.dlc.load_data(data_dir, self.feature_keys)
            except ValueError:
                # Presumably np.load rejecting an object array saved for a
                # feature with no valid frames, i.e. unusable pose.
                err_log['poor_pose'].append(video_filename)
                continue
            except FileNotFoundError:
                err_log['file_missing'].append(video_filename)
                continue

            row = None
            if self.is_metadata_present:
                row = self.files_info.iloc[video_idx]
            # md[0] / md[1] are used as the start / end frame of the window.
            md = load_metadata_new(row, time_margin = self.time_margin)

            start_f, end_f = self.init_frame + md[0], md[1]

            data = {key: value[start_f: end_f, ...] for key, value in features.items()}

            move_idx, stim_idx = 0, 0
            if self.is_metadata_present:
                move_idx = int(row['Classification'])
                stim_class = row["Stimulation Class"]
                stim_idx = self.stim_class_list.index(stim_class)

            data['filename'] = video_filename
            data['move_class'] = move_idx
            data['stim_class'] = stim_idx

            records.append(data)

        # Each feature cell holds that video's whole (trimmed) array.
        df = pd.DataFrame(records, columns=columns)

        if verbose:
            for key, items in err_log.items():
                print(f"{key}: {len(items)}")
                for item in items:
                    print(f"\t{item}")

        data_dir = self.get_data_dir()

        feature_key_tuple = [('features', key) for key in self.feature_keys]

        df.columns = pd.MultiIndex.from_tuples([
            ('metadata', 'filename'),
            *feature_key_tuple,
            ('labels', 'move_class'),
            ('labels', 'stim_class')
        ])

        # df = df.fillna(0)

        with open(f'{data_dir}/features.pkl', 'wb') as f:
            pickle.dump(df, f)

    def visualize_results(self):
        """Plot the mean of each ``self.feature_angles_dict`` entry (3 angle traces per panel) and save it.

        NOTE(review): ``feature_angles_dict`` is not assigned anywhere in this
        class. It must be set externally first, otherwise AttributeError. The
        traces are plotted against sample index although the x-axis is
        labelled "Time (s)" and the onset line is drawn at x=0. Confirm whether
        ``self.tx`` was intended as the x-axis.
        """
        num_lines = 3 # for 3 angles
        colors = plt.cm.Paired(np.linspace(0, 1, num=num_lines))

        num_rows, num_cols = 1, 3
        fig, axs = plt.subplots(num_rows, num_cols, figsize=(16, 4), gridspec_kw={'top': 0.85})

        for i, key in enumerate(self.feature_angles_dict.keys()):
            feature_angle_mean = np.mean(self.feature_angles_dict[key], axis=0)

            for k in range(3):
                data = feature_angle_mean[..., k]
                data = smooth_data(data, 5)
                axs[i].plot(data, label=k, linewidth=2, c=colors[k])
            axs[i].set_ylim(-10, 190)
            axs[i].set_title(key)
            axs[i].set_xlabel("Time (s)")
            axs[i].axvline(x=0, color='orange', linewidth=2, alpha=0.3)

        legend_elements = [Line2D([0], [0], color=colors[i], lw=4, label=i) for i in range(3)]
        fig.legend(title="Angle", handles=legend_elements, loc='upper right')
        fig.suptitle(f'Electrical Stimulations - Angle')

        figs_dir = self.get_fig_dir("")
        figs_dir_full = f'{figs_dir}/dlc-summary'
        os.makedirs(figs_dir_full, exist_ok=True)
        fig.savefig(f'{figs_dir_full}/Electrical Stimulations - Angle.png', facecolor='white')
    

In [None]:
# Build the pipeline driver from configs/<name>.yaml. Alternative config:
# driver = Driver("octo-2-unlabelled-small", dev_mode = False)
driver = Driver(config_file="octo")

In [None]:
# for video_idx in tqdm(range(len(driver.filenames))):
#     ret = driver.load_video(video_idx)
#     if not ret:
#         continue

#     while not driver.is_video_empty():
#         ret = driver.read_video()
#         if not ret:
#             break

#         driver.detect_pose()

#     driver.release_video()

#     driver.save_data(video_idx)

# driver.post_process()

In [None]:
# Load the aggregated feature table written by Driver.post_process.
# pickle is acceptable here only because this pipeline produced the file.
data_dir = driver.get_data_dir()
features_path = f'{data_dir}/features.pkl'
with open(features_path, 'rb') as fh:
    df_features = pickle.load(fh)
print(df_features.shape)
display(df_features.head(2))

In [None]:
# Quantitative analysis of DLC performance
#
# Three heuristic per-video error scores (lower is better):
#   err_1: 1 - mean keypoint likelihood p.
#   err_2: variance / mean of the total keypoint-chain length across frames
#          (length jitter), clipped at 1.
#   err_3: mean over frames of the L2 distance between each segment's share
#          of the chain length and IDEAL_LENGTH.

# Expected share of the total chain length per consecutive-keypoint segment.
# Assumes 5 keypoints (4 segments); TODO confirm against the DLC model.
IDEAL_LENGTH = np.array([1/2, 1/4, 1/8, 1/8])

p_like = df_features['features', 'p'].values
XY = df_features['features', 'xy'].values

N = len(p_like)

def plot_scatter(xy):
    """Scatter one frame's keypoints at (x, -y), shaded from dark to light blue."""
    colors = plt.cm.Blues(np.linspace(1, 0.2, num=xy.shape[0]))
    for i, (x, y) in enumerate(xy):
        plt.scatter(x, -y, color=colors[i], label=f"Point {i+1}")

    plt.gca().set_aspect('equal')
    plt.show()

def pose_quality_errors(p_el, xy):
    """Return ``[err_1, err_2, err_3]`` (defined above) for one video.

    Args:
        p_el: keypoint likelihoods of the video.
        xy: keypoint coordinates, with keypoints on axis -2 and (x, y) on axis -1.
    """
    seg_len = np.sqrt(np.sum(np.square(np.diff(xy, axis=-2)), axis=-1))
    L = np.sum(seg_len, axis=-1, keepdims=True)
    seg_share = seg_len / L

    err_1 = 1 - p_el.mean()
    err_2 = min(np.var(L) / np.mean(L), 1)
    err_3 = np.mean(np.linalg.norm(seg_share - IDEAL_LENGTH, axis=-1))
    return [err_1, err_2, err_3]

# Shape (3, N): one row per error score. The reshape keeps 3 rows even when N == 0.
err_list = np.array(
    [pose_quality_errors(p_like[i], XY[i]) for i in range(N)], dtype=float
).reshape(-1, 3).T

print("Fraction of good data:")
print(f"1: {100*np.sum(err_list[0]<0.2)/N : 0.1f}%")
print(f"2: {100*np.sum(err_list[1]<0.1)/N : 0.1f}%")
print(f"All: {100*np.sum((err_list[0]<0.2) & (err_list[1]<0.1))/N : 0.1f}%")

plt.plot(err_list[0], label="Error 1")
plt.plot(err_list[1], label="Error 2")
plt.plot(err_list[2], label="Error 3")
plt.legend()
plt.show()

In [None]:
# Per-video angle statistics. For each frame, take the angle between the first
# and last consecutive-keypoint segments. Summarise its mean / max, plus the
# mean / max of its absolute frame-to-frame change scaled by fps * 1e-3
# (the plots below label these ° and °/ms).
XY = df_features['features', 'xy'].values
N = len(XY)
mean_angle, max_angle = np.zeros(N), np.zeros(N)
mean_angular_speed, max_angular_speed = np.zeros(N), np.zeros(N)

for i, xy in enumerate(XY):
    segments = np.diff(xy, axis=-2)
    angle_array = np.array(
        [angle_between(frame_segs[0], frame_segs[-1]) for frame_segs in segments],
        dtype=float,
    )

    mean_angle[i], max_angle[i] = np.mean(angle_array), np.max(angle_array)

    angular_speed = driver.fps * 1e-3 * np.abs(np.diff(angle_array))
    mean_angular_speed[i], max_angular_speed[i] = np.mean(angular_speed), np.max(angular_speed)

summary_columns = {
    'Mean angle': mean_angle,
    'Max angle': max_angle,
    'Mean angular speed': mean_angular_speed,
    'Max angular speed': max_angular_speed,
}
df_final = pd.DataFrame(summary_columns)

result_concat = pd.concat([df_final, df_features['labels']], axis=1)

display(result_concat.head(5))

In [None]:
def find_unit(feature):
    """Return the display unit for a feature name, or '' when none applies.

    Matching is by case-sensitive substring, and the first match wins:
    'angle' maps to degrees, then 'angular speed' maps to degrees per millisecond.
    """
    unit_rules = (('angle', '°'), ('angular speed', '°/ms'))
    for token, unit in unit_rules:
        if token in feature:
            return unit
    return ''

In [None]:
features_list = df_final.columns.to_list()
# (x, y) index pairs into features_list, one per panel of the 2x2 grids, row-major.
pair_indices = [(0, 1), (2, 3), (0, 2), (1, 3)]
features_to_compare = [[features_list[x_idx], features_list[y_idx]] for x_idx, y_idx in pair_indices]
features_to_compare

In [None]:
# Electrical Stimulations - Movement
# 2x2 grid of pairwise feature scatter plots, coloured by movement class.

movement_types = [
    "No movement",
    "Movement",
    "Movement with arm curl"
]
class_type = 'move_class'
class_list = movement_types
no_of_points = len(class_list)
colors = plt.cm.Blues(np.linspace(0.4, 1, num=no_of_points))

num_rows, num_cols = 2, 2
fig, axs = plt.subplots(num_rows, num_cols, figsize=(8, 8), gridspec_kw={'top': 0.85})

for c, class_name in enumerate(class_list):
    class_data = result_concat.query(f'{class_type} == {c}')
    # axs.flat walks the grid row by row, matching features_to_compare order.
    for k, ax in enumerate(axs.flat):
        x_lb, y_lb = features_to_compare[k]
        ax.scatter(class_data[x_lb], class_data[y_lb],
            color=colors[c], label=class_name, s=20)
        ax.set_xlabel(f'{x_lb} ({find_unit(x_lb)})')
        ax.set_ylabel(f'{y_lb} ({find_unit(y_lb)})')

legend_elements = [
    Line2D([0], [0], color=colors[idx], lw=4, label=name)
    for idx, name in enumerate(class_list)
]
fig.legend(title="", handles=legend_elements, loc='upper right')
fig.suptitle('Electrical Stimulations - Movement')
fig.subplots_adjust(wspace=0.3, hspace=0.3)
plt.show()

In [None]:
# Electrical Stimulations - Location
# The same 2x2 pairwise feature grid, coloured by stimulation location.

class_type = 'stim_class'
class_list = driver.stim_class_list
no_of_points = len(class_list)
colors = plt.cm.Paired(np.linspace(0.2, 0.8, num=no_of_points))

num_rows, num_cols = 2, 2
fig, axs = plt.subplots(num_rows, num_cols, figsize=(8, 8), gridspec_kw={'top': 0.85})

for c in range(no_of_points):
    class_data = result_concat.query(f'{class_type} == {c}')
    for k in range(num_rows * num_cols):
        ax = axs[divmod(k, num_cols)]
        x_lb, y_lb = features_to_compare[k]
        ax.scatter(class_data[x_lb], class_data[y_lb],
            color=colors[c], label=class_list[c], s=20)
        ax.set_xlabel(f'{x_lb} ({find_unit(x_lb)})')
        ax.set_ylabel(f'{y_lb} ({find_unit(y_lb)})')

legend_elements = []
for c in range(no_of_points):
    legend_elements.append(Line2D([0], [0], color=colors[c], lw=4, label=class_list[c]))
fig.legend(title="", handles=legend_elements, loc='upper right')
fig.suptitle('Electrical Stimulations - Location')
fig.subplots_adjust(wspace=0.3, hspace=0.3)
plt.show()

In [None]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Standardise each summary feature (zero mean, unit variance), then project
# the videos onto the first two principal components.
data = df_final.to_numpy()

scaler = StandardScaler()
pca = PCA(n_components=2)

scaled_data = scaler.fit_transform(data)
reduced_data = pca.fit_transform(scaled_data)

In [None]:
# PCA projection coloured by movement class.
class_type = 'move_class'

class_list = movement_types
no_of_points = len(class_list)
colors = plt.cm.Blues(np.linspace(0.4, 1, num=no_of_points))

for c, class_name in enumerate(class_list):
    # Assumes result_concat's index labels are the row positions 0..N-1, so
    # they can index reduced_data (same row order as df_final) directly.
    rows = reduced_data[result_concat.query(f'{class_type} == {c}').index]
    plt.scatter(rows[:, 0], rows[:, 1], color=colors[c], label=class_name, s=25)

plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.xticks([])
plt.yticks([])
plt.legend()
plt.title('Electrical Stimulations - Movement')
plt.show()

In [None]:
# PCA projection coloured by stimulation location.
class_type = 'stim_class'

class_list = driver.stim_class_list
no_of_points = len(class_list)
colors = plt.cm.Paired(np.linspace(0.2, 0.8, num=no_of_points))

for c in range(no_of_points):
    # Index labels of result_concat are assumed to be row positions into reduced_data.
    member_idx = result_concat.query(f'{class_type} == {c}').index
    points = reduced_data[member_idx]
    plt.scatter(points[:, 0], points[:, 1],
                color=colors[c], label=class_list[c], s=25)

for axis_label, set_label in (("PC 1", plt.xlabel), ("PC 2", plt.ylabel)):
    set_label(axis_label)
plt.xticks([])
plt.yticks([])
plt.legend()
plt.title('Electrical Stimulations - Location')
plt.show()

In [None]:
# Clustering - KMeans
# Partition the 2-D PCA projection into NO_OF_CLUSTERS groups (fixed seed)
# and plot each group together with its centroid.

from sklearn.cluster import KMeans

data = reduced_data

NO_OF_CLUSTERS = 3
# One colour per cluster; extend this list if NO_OF_CLUSTERS grows.
colors = ['red', 'blue', 'green']

kmeans = KMeans(n_clusters=NO_OF_CLUSTERS, random_state=42)
cluster_labels = kmeans.fit_predict(data)
cluster_centers = kmeans.cluster_centers_

for c in range(NO_OF_CLUSTERS):
    members = data[cluster_labels == c]
    plt.scatter(members[:, 0], members[:, 1], label=f'Cluster {c+1}', color=colors[c])

plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1], marker='X', s=100, c='black', label='Centroids')

plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.xticks([])
plt.yticks([])
plt.title('Clustering with K-Means')
plt.legend()
plt.show()

In [None]:
# Clustering - AgglomerativeClustering
# Group the same 2-D PCA projection with hierarchical (agglomerative)
# clustering. When cells run in order, these cluster_labels replace the
# K-Means ones and are what gets attached to df_features.

from sklearn.cluster import AgglomerativeClustering

data = reduced_data

NO_OF_CLUSTERS = 3
colors = ['red', 'blue', 'green']

# Agglomerative clustering with a fixed cluster count; it yields no centroids to plot.
agg_clustering = AgglomerativeClustering(n_clusters=NO_OF_CLUSTERS)
cluster_labels = agg_clustering.fit_predict(data)

for c, colour in zip(range(NO_OF_CLUSTERS), colors):
    in_cluster = cluster_labels == c
    plt.scatter(data[in_cluster, 0], data[in_cluster, 1], label=f'Cluster {c+1}', color=colour)

plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.xticks([])
plt.yticks([])
plt.title('Agglomerative Clustering')
plt.legend()
plt.show()

In [None]:
print(f"{df_features.shape}")  # (rows, columns) of the feature table; one row per video

In [None]:
# Attach the most recently computed cluster_labels (row order of df_final / df_features).
# NOTE(review): group name is 'label', not 'labels' like the other label columns;
# later cells read it under this name, so rename them all together if at all.
df_features[('label', 'cluster')] = cluster_labels

In [None]:
# Preview only the first rows: dumping the whole frame (each cell holds a full
# per-video array) floods the notebook output.
df_features.head()

In [None]:
df_features.sample(n=5)  # random rows; unseeded, so they differ on every run

In [None]:
df_features[('label', 'cluster')].value_counts()  # videos per cluster label

In [None]:
# # Plot angles and p-likelihood

# N = 2
# X = df_features['features', 'angles'].values[N-1:N]
# y = df_features['labels'].iloc[N-1:N]

# movement_types = [
#     "No movement",
#     "Movement",
#     "Movement with arm curl"
# ]

# tx = driver.tx

# stim_class_list = driver.stim_class_list

# for i in range(1):
#     feature = X[i]
#     label = y.iloc[i]
#     print(df_features['metadata', 'filename'].iloc[i])
#     print(movement_types[label['move_class']],"|", stim_class_list[label['stim_class']])
#     for k in range(feature.shape[-1]):
#         plt.plot(tx, X[i][:, k], label=k)
#     plt.legend()
#     plt.show()

# X = df_features['features', 'p'].values[N-1:N]
# y = df_features['labels'].iloc[N-1:N]

# movement_types = [
#     "No movement",
#     "Movement",
#     "Movement with arm curl"
# ]

# tx = driver.tx

# stim_class_list = driver.stim_class_list

# for i in range(1):
#     feature = X[i]
#     label = y.iloc[i]
#     for k in range(feature.shape[-1]):
#         plt.plot(tx, X[i][:, k], label=k)
#     plt.legend()
#     plt.show()

In [None]:
# %matplotlib tk

# import matplotlib.pyplot as plt
# import numpy as np

# WIDTH, HEIGHT = 640, 480

# R = 1
# df_ = df_features.iloc[R]
# print(df_['metadata', 'filename'])
# xy = df_['features', 'xy']
# x_coords, y_coords = xy[..., 0], xy[..., 1]

# y_coords = HEIGHT-y_coords

# num_points = xy.shape[1]  # Number of points to visualize

# # Create a figure and axis
# fig, ax = plt.subplots()

# colors = plt.cm.Blues(np.linspace(1, 0.2, num=num_points))

# # Create an empty scatter plot for each point
# scatters = [ax.scatter([], [], color=colors[i]) for i in range(num_points)]

# # Set up the axis limits
# ax.set_xlim(0, WIDTH)
# ax.set_ylim(0, HEIGHT)

# # Function to update the scatter plots
# def update_plots(i):
#     # Iterate through each point
#     for j in range(num_points):
#         # Get the current x and y coordinates for the point
#         x = x_coords[i, j]
#         y = y_coords[i, j]
        
#         # Update the scatter plot data for the point
#         scatters[j].set_offsets([(x, y)])
        
#     # Set the title to the current index
#     ax.set_title(f"Time Step: {driver.tx[i]:0.1f} s")
    
#     # Pause for a short duration (in seconds) to observe each point
#     plt.pause(1.5*1/30)

# # Iterate through each time step and update the plots
# for i in range(xy.shape[0]):
#     update_plots(i)

# # Show the final plots in a new window
# plt.show()


In [None]:
df_features[('labels', 'move_class')].value_counts()  # class balance of movement labels

In [None]:

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, LSTM

# Classify stimulation location (stim_class) from the per-frame joint angles.

# stim_class values are indices into driver.stim_class_list, so the output
# layer needs one unit per entry rather than a hard-coded count.
num_classes = len(driver.stim_class_list)

# Stack each video's angle array along a new last axis, then flatten so each
# row of X is one video. np.stack requires equal shapes, so this assumes every
# video has the same number of frames; TODO confirm.
X = np.stack(df_features['features', 'angles'].values, axis=-1)
X = X.reshape(-1, X.shape[-1]).T
y = df_features['labels', 'stim_class'].values.astype(int)

# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)

# LSTM input layout: (samples, timesteps=1, features).
# NOTE(review): with a single timestep the LSTM sees the whole flattened series
# at once and cannot use temporal order; (samples, frames, angles) may be intended.
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))

# Build the RNN model
model = Sequential()
model.add(LSTM(64, input_shape=(1, X_train.shape[2]), activation='relu'))
# softmax: the classes are mutually exclusive, and sparse_categorical_crossentropy
# expects a probability distribution over them.
model.add(Dense(num_classes, activation='softmax'))

# Compile and train the model
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=300, batch_size=4, verbose=0);

In [None]:
%matplotlib inline
# Evaluate the model
loss, accuracy = model.evaluate(X_test, y_test)
print('Test loss:', loss)
print('Test accuracy:', accuracy)

# Perform predictions on the test data
y_pred_prob = model.predict(X_test)
y_pred = np.argmax(y_pred_prob, axis=1)

class_names = [0, 1, 2]

# Generate the confusion matrix
conf_matrix = confusion_matrix(y_test, y_pred, labels=class_names)

# Visualize the confusion matrix as a heatmap
first_words = [string.split()[0] for string in driver.stim_class_list]
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=first_words, yticklabels=first_words)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Stimulation point of contact - using DLC')
plt.show()