In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import os
from PIL import Image
import yaml
import pickle
from utils import *
from helpers import *

import pdb
from tqdm import tqdm

from proSVD import proSVD
from dlclive import DLCLive, Processor

In [None]:
# --- Configuration ------------------------------------------------------------
# Name (without extension) of the YAML file under configs/ describing the dataset.
config_file = "octo-2-unlabelled-small"
with open(f'configs/{config_file}.yaml', 'r') as file:
    config = yaml.safe_load(file)

# The metadata spreadsheet is optional; without it videos are discovered on disk.
IS_METADATA_PRESENT = (config['path']['xls'] is not None)
files_info = None
filenames = None

fps = config['info']['fps']

root_dir = config['path']['root']
video_dir = f"{root_dir}/{config['path']['video']}"

if IS_METADATA_PRESENT:
    xls_path = f"{root_dir}/{config['path']['xls']}"
    files_info = read_octopus_xlsx(xls_path)
    # Keep only the rows whose stimulation method is electrical.
    files_info = files_info[files_info['Stim Method'] == 'Electrical']
    filenames = files_info["File Name"].to_list()
else:
    # os.listdir returns entries in arbitrary order and includes hidden files and
    # sub-directories; sort for a reproducible order and keep regular files only.
    filenames = [
        os.path.splitext(entry)[0]
        for entry in sorted(os.listdir(video_dir))
        if not entry.startswith('.') and os.path.isfile(os.path.join(video_dir, entry))
    ]

###

working_dir = f"{root_dir}/{config['path']['working']}" # to save processed data and figures

model_path = f"{root_dir}/{config['path']['model']}"

dlc_live = DLCLive(
    model_path,
    processor=Processor(),
    pcutoff=0.2,
    resize=1)

PROSVD_K = 4 # no. of dims to reduce to

TIME_MARGIN = (-120, 180) # to trim videos
total_f = TIME_MARGIN[1] - TIME_MARGIN[0]

init_frame_crop = 10 # No of initial frames used to set cropping info
init_frame_prosvd = 90 # No of initial frames used to initialize proSVD
init_frame = init_frame_crop + init_frame_prosvd


# Development switch: restrict the run to a small, fixed subset of videos.
DEV_MODE = False
DEV_INDICES = [0, 3, 13, 14, 15, 16] #[3:4]
if DEV_MODE:
    # Drop indices beyond the available videos so a small dataset does not raise.
    dev_indices = [i for i in DEV_INDICES if i < len(filenames)]
    if files_info is not None:
        files_info = files_info.iloc[dev_indices]
        filenames = files_info["File Name"].to_list()
    else:
        # No metadata table to slice: subset the discovered file names directly.
        filenames = [filenames[i] for i in dev_indices]

print(f"Processing {len(filenames)} videos from {video_dir}")
# For a handful of videos, also list them by name.
if len(filenames) < 4:
    print('\t', end='')
    print(*filenames, sep="\n\t")

In [None]:
def get_fig_dir(filename):
    """Return the figures directory under ``working_dir``, creating it on demand.

    ``filename`` is accepted for symmetry with ``get_data_dir`` but is not used:
    every video shares the same figures directory.
    """
    target = f"{working_dir}/figs"
    if not os.path.isdir(target):
        os.makedirs(target, exist_ok=True)
    return target

def get_data_dir(filename):
    """Return the per-video data directory ``<working_dir>/data/<filename>``.

    The directory (and any missing parents) is created if it does not exist yet.
    """
    target = f"{working_dir}/data/{filename}"
    if not os.path.isdir(target):
        os.makedirs(target, exist_ok=True)
    return target

def is_badpose(likely, threshold=0.8):
    """Tell whether a pose estimate is unreliable.

    Args:
        likely: array-like of likelihood values for the detected pose.
        threshold: minimum acceptable mean likelihood.

    Returns:
        True when the mean of ``likely`` falls strictly below ``threshold``.
    """
    mean_likelihood = np.mean(likely)
    return mean_likelihood < threshold

def calculate_nancount(data):
    """Return the percentage (0-100) of entries in ``data`` that are NaN.

    ``data`` must expose ``.size`` (e.g. a NumPy array).
    """
    missing = np.isnan(data).sum()
    return (100 * missing) / data.size

def fill_nan_linear_interpolation_axis(arr, axis):
    """Fill NaNs by linear interpolation along one axis, without touching ``arr``.

    Each 1-D slice along ``axis`` is processed independently. Interior gaps are
    linearly interpolated between the neighbouring valid samples; leading and
    trailing NaNs take the nearest valid value (``np.interp`` clamps outside
    the sample range).

    Args:
        arr: array-like of floats.
        axis: axis along which to interpolate.

    Returns:
        A new array of the same shape. Slices that are entirely NaN have no
        sample to interpolate from and are returned unchanged.
    """
    def fill_nan_linear_interpolation(row):
        # np.apply_along_axis passes views into ``arr``; work on a copy so the
        # caller's array is not modified as a side effect.
        row = row.copy()
        nan_mask = np.isnan(row)
        # Nothing to fill, or nothing to interpolate from.
        if not nan_mask.any() or nan_mask.all():
            return row
        indices = np.arange(len(row))
        row[nan_mask] = np.interp(indices[nan_mask], indices[~nan_mask], row[~nan_mask])
        return row

    return np.apply_along_axis(fill_nan_linear_interpolation, axis, arr)

class DataManager:
    """Collect per-frame arrays under named keys, with ``None`` marking skipped frames.

    Values are appended in order. A ``None`` entry is a placeholder whose shape
    is not yet known; it is replaced by a NaN-filled array as soon as a real
    value (and therefore a shape) is available.
    """

    def __init__(self):
        # One list of per-frame entries per key.
        self.data = {
            'Q': []
        }

    def _print_key_error(self, key):
        """Report (without raising) that ``key`` is not a known key."""
        print(f"Key '{key}' does not exist in the data manager.")

    def add_nan(self, key, shape):
        """Replace the trailing run of ``None`` placeholders for ``key`` with NaN arrays of ``shape``."""
        entries = self.data[key]
        idx = len(entries) - 1
        while idx >= 0 and entries[idx] is None:
            entries[idx] = np.full(shape, np.nan)
            idx -= 1

    def add(self, key, value):
        """Append ``value`` (an array, or ``None`` for a skipped frame) under ``key``.

        When a real value arrives, any ``None`` placeholders queued just before
        it are back-filled with NaN arrays of the same shape.
        """
        if key in self.data:
            if value is not None:
                self.add_nan(key, value.shape)
            self.data[key].append(value)
        else:
            self._print_key_error(key)

    def to_numpy(self):
        """Convert every key's list into a single stacked NumPy array, in place.

        A key with no entries, or whose first entry is still ``None`` (which,
        given the back-filling in ``add``, means no real value was ever added),
        is set to ``None``. Keys that were already converted are left alone, so
        calling this method twice is safe.
        """
        for key in list(self.data):
            entries = self.data[key]
            if entries is None or isinstance(entries, np.ndarray):
                continue  # already converted by a previous call
            if len(entries) == 0 or entries[0] is None:
                self.data[key] = None
                # Keep going: the remaining keys still need converting.
                continue
            # Fill placeholders left after the last real value.
            self.add_nan(key, entries[0].shape)
            self.data[key] = np.array(entries)

    def remove(self, key):
        """Delete ``key`` and its data; prints a message if the key is unknown."""
        if key in self.data:
            del self.data[key]
        else:
            self._print_key_error(key)

    def get(self, key):
        """Return the data stored under ``key``, or ``None`` (with a message) if unknown."""
        if key in self.data:
            return self.data[key]
        else:
            self._print_key_error(key)
            return None

In [None]:
# Per-video pipeline: seek to the trial window, estimate a crop box from the first
# frames, initialise proSVD on the following frames, then stream the remaining
# frames through proSVD and store the basis Q after each update.

# NOTE(review): the loop was hard-coded to the first 2 videos, which looks like a
# debugging leftover. The cap is kept but bounded by the number of available
# videos; raise MAX_VIDEOS (e.g. to len(filenames)) to process everything.
MAX_VIDEOS = 2
n_videos = min(MAX_VIDEOS, len(filenames))

for video_idx in tqdm(range(n_videos)):
    video_filename = filenames[video_idx]

    # Without a metadata sheet there is no row for this video.
    # NOTE(review): confirm load_metadata_new accepts row=None.
    row = files_info.iloc[video_idx] if files_info is not None else None
    md = load_metadata_new(row, time_margin = TIME_MARGIN)

    start_f, end_f = md[0], md[1]
    total_f = end_f - start_f

    video = None

    try:
        video = load_video(video_dir, video_filename)
        # NOTE(review): cv2 is not imported in the notebook's import cell;
        # presumably it is re-exported by `from utils import *` — verify.
        video.set(cv2.CAP_PROP_POS_FRAMES, start_f)
    except Exception as exc:
        # Skip unreadable videos, but say so instead of hiding the failure.
        print(f"Skipping '{video_filename}': could not open/seek video ({exc!r})")
        if video is not None:
            video.release()
        continue

    try:
        figs_dir = get_fig_dir(video_filename)
        data_dir = get_data_dir(video_filename)

        # Accumulates the per-frame crop boxes; averaged once init_frame_crop frames were seen.
        crop_box = np.zeros(4, int)

        index = -1

        frames = []  # for proSVD initialization
        dm = DataManager()
        pro = None

        while video.isOpened():
            index += 1
            if index >= total_f:
                break

            ret, frame = video.read()
            if not ret:
                break

            # Cropping starts
            if index < init_frame_crop:
                crop_box += detect_crop_box_wrapper(dlc_live, frame, index, margin=max(40, int(frame.size//(3*5*1e4))))
                continue

            if index == init_frame_crop:
                # Turn the accumulated sum into the mean crop box.
                crop_box //= init_frame_crop

            frame2 = frame[crop_box[0]:crop_box[1],crop_box[2]:crop_box[3],:]

            # Cropping ends

            # Pose is estimated on the full (uncropped) frame.
            pose = detect_pose(dlc_live, frame, index)

            frame2 = rgb_to_grayscale(frame2)
            frame2 = downsample_image(frame2)

            # save the cropped frame for checking if cropping done correctly
            if index == init_frame_crop:
                crop_dir = f'{figs_dir}/cropped'
                os.makedirs(crop_dir, exist_ok=True)
                im = Image.fromarray(frame2)
                im.save(f'{crop_dir}/{video_filename}.png')

            frame2 = frame2.flatten()

            if index < init_frame:
                frames.append(frame2)
                continue

            # proSVD starts

            if index == init_frame:
                # One flattened frame per column.
                frames = np.array(frames).T
                pro = proSVD(k=PROSVD_K, w_len=1,history=0, decay_alpha=1, trueSVD=True)
                pro.initialize(frames)
                del frames

            if is_badpose(pose[...,-1]):
                # Placeholder; DataManager turns it into NaNs once a shape is known.
                dm.add('Q', None)
                continue

            pro.preupdate()
            pro.updateSVD(frame2[:, None])
            pro.postupdate()

            # pro_coordi = frame2 @ pro.Q
            # NOTE(review): if proSVD updates Q in place rather than rebinding it,
            # every stored entry would alias the same array — verify.
            dm.add('Q', pro.Q)
    finally:
        # Always free the capture, even if a frame fails to process.
        video.release()

    dm.to_numpy()
    # dm.get('Q') is None when no frame produced a basis (e.g. all poses rejected).
    np.save(f'{data_dir}/Q.npy', dm.get('Q'))

In [None]:
# Show how many videos fall into each stimulation class and each movement classification.
for column_name in ('Stimulation Class', 'Classification'):
    display(files_info[column_name].value_counts())

In [None]:
# # Post processing!

# stim_class_list = sorted(files_info['Stimulation Class'].unique().tolist())
# stim_count = len(stim_class_list)

# movement_types = [
#     "No movement",
#     "Movement",
#     "Movement with arm curl"
# ]

# tx = np.arange(init_frame + TIME_MARGIN[0], TIME_MARGIN[1])/fps

# err_log = {
#     'poor_pose': [],
#     'file_missing': []
# }

# feature_all = []
# label_all = []

# for video_idx in tqdm(range(len(filenames))):
#     video_filename = filenames[video_idx]

#     figs_dir = get_fig_dir(video_filename)
#     data_dir = get_data_dir(video_filename)

#     try:
#         Q_full = np.load(f'{data_dir}/Q.npy')
#     except ValueError:
#         err_log['poor_pose'].append(video_filename)
#         continue
#     except FileNotFoundError:
#         err_log['file_missing'].append(video_filename)
#         continue
#     except:
#         raise("Uncaught exception")

#     total_frames = Q_full.shape[0]

#     row = None
#     if IS_METADATA_PRESENT:
#         row = files_info.iloc[video_idx]

#     stim_class = row["Stimulation Class"]

#     Q_diff = np.diff(Q_full, axis=0)
#     Q_norm_diff = np.linalg.norm(Q_diff, axis=1)
#     Q_norm_diff = np.insert(Q_norm_diff, 0, 0, axis=0)

#     move_idx = int(row['Classification'])
#     stim_idx = stim_class_list.index(stim_class)

#     feature_all.append(Q_norm_diff)
#     label_all.append(stim_idx)

#     del Q_full
#     del Q_diff
#     del Q_norm_diff

# df = pd.DataFrame({
#     'features': feature_all,
#     'labels': label_all,
# })

# with open(f'data/pro_features.pkl', 'wb') as f:
#     pickle.dump(df, f)

In [None]:
# Reload the cached feature table and sanity-check the shape of its first entry.
# NOTE: pickle can execute arbitrary code on load; only open files produced by this notebook.
with open(f'data/pro_features.pkl', 'rb') as f:
    df_features = pickle.load(f)

first_feature = df_features['features'].iloc[0]
print(first_feature.shape)
df_features.head(1)

In [None]:
# # Plot 2: all experiments separately with basis as separate plots

# stim_class_list = sorted(files_info['Stimulation Class'].unique().tolist())
# stim_count = len(stim_class_list)

# num_figures = PROSVD_K
# num_rows = int(np.ceil(np.sqrt(num_figures)))
# num_cols = num_figures//num_rows

# movement_types = [
#     "No movement",
#     "Movement",
#     "Movement with arm curl"
# ]

# fig_all = []
# axs_all = []

# num_lines = 3 # for 3 movement classes
# colors = plt.cm.Blues(np.linspace(0.2, 1, num=num_lines))

# titles = [
#     "n-Norm of t-Diff of Q",
# ]

# tx = np.arange(init_frame + TIME_MARGIN[0], TIME_MARGIN[1])/fps

# err_log = {
#     'poor_pose': [],
#     'file_missing': []
# }

# for r in range(1): 
#     fig, axs = plt.subplots(num_rows, num_cols, figsize=(8, 6), gridspec_kw={'top': 0.8})
#     fig_all.append(fig)
#     axs_all.append(axs)

# feature_all = []
# label_all = []

# for video_idx in tqdm(range(len(filenames))):
#     video_filename = filenames[video_idx]

#     figs_dir = get_fig_dir(video_filename)
#     data_dir = get_data_dir(video_filename)

#     try:
#         Q_full = np.load(f'{data_dir}/Q.npy')
#     except ValueError:
#         err_log['poor_pose'].append(video_filename)
#         continue
#     except FileNotFoundError:
#         err_log['file_missing'].append(video_filename)
#         continue
#     except:
#         raise("Uncaught exception")

#     total_frames = Q_full.shape[0]

#     row = None
#     if IS_METADATA_PRESENT:
#         row = files_info.iloc[video_idx]

#     stim_class = row["Stimulation Class"]

#     metadata = load_metadata(video_filename, total_frames, row, fps, init_frame)

#     Q_diff = np.diff(Q_full, axis=0)
#     Q_norm_diff = np.linalg.norm(Q_diff, axis=1)
#     Q_norm_diff = np.insert(Q_norm_diff, 0, 0, axis=0)

#     move_class = int(row['Classification'])
#     stim_idx = stim_class_list.index(stim_class)

#     feature_all.append(Q_norm_diff[:, 1])
#     label_all.append(stim_idx)

#     ## PLOTTING ##

#     data = [
#         Q_norm_diff,
#     ]

#     start_f = metadata['start']['f']
#     end_f = metadata['end']['f']

#     for r in range(len(titles)):
#         for i in range(num_rows):
#             for j in range(num_cols):
#                 k = j + i*num_rows
#                 data_ = data[r][:, k]
#                 # data_ = smooth_data(data_, kernel_size = 3)
#                 axs_all[m][i, j].plot(tx, data_, c=colors[move_class], alpha=0.7, label=move_class, linewidth=1)
#                 axs_all[m][i, j].set_title(f'Basis {k}')

#     for x in data:
#         del x

#     del Q_full

# feature_all = np.array(feature_all)
# label_all = np.array(label_all)
            
    
# # legend_elements = [Line2D([0], [0], color=colors[i], lw=4, label=movement_types[i]) for i in range(3)]

# # ylims = [1e-2, 4e-1, 4e-1, 4e-1]
# # for r in range(len(titles)):
# #     fig_all[r].legend(handles=legend_elements, loc='upper right')
# #     fig_all[r].subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9, hspace=0.3, wspace=0.3)

# #     for i in range(num_rows):
# #         for j in range(num_cols):
# #             # axs_all[r][i].set_ylim(0, ylims[r])
# #             axs_all[r][i, j].axvline(x=0, color='orange', linewidth=2, alpha=0.3)
# #             axs_all[r][-1, j].set_xlabel("Time (s)")

# #     fig_all[r].suptitle(f'All Data')

# #     figs_dir_full = f'{figs_dir}/{titles[r]}'
# #     os.makedirs(figs_dir_full, exist_ok=True)
# #     fig_all[r].savefig(f'{figs_dir_full}/all_data.png', facecolor='white')

# # for key, items in err_log.items():
# #     print(f"{key}: {len(items)}")
# #     for item in items:
# #         print(f"\t{item}")

In [None]:

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, LSTM


num_classes = 3

# Load and preprocess the data
# NOTE: pickle can execute arbitrary code on load; only open files produced by this notebook.
with open(f'data/pro_features.pkl', 'rb') as f:
    df_features = pickle.load(f)

# Stack the per-video feature arrays with the video index last.
# Assumes every video's feature array has the same shape, time on axis 0 — TODO confirm.
X = np.stack(df_features['features'].values, axis=-1)

# linear interpolation to fill nan values
# Interpolate along time (axis 0) *before* flattening. Flattening first would
# interleave any trailing feature dimensions with time, so a NaN gap would be
# filled across unrelated columns instead of along each column's own time course.
X = fill_nan_linear_interpolation_axis(X, axis=0)

# Flatten everything but the video axis: one row of features per video.
X = X.reshape(-1, X.shape[-1]).T

y = df_features['labels'].values.astype(int)
# sparse_categorical_crossentropy needs integer labels in [0, num_classes).
if y.min() < 0 or y.max() >= num_classes:
    raise ValueError(
        f"labels must lie in [0, {num_classes}); got range [{y.min()}, {y.max()}]"
    )

# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=50)

# Reshape the input data for LSTM: (samples, timesteps=1, features)
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))

# Build the RNN model
model = Sequential()
model.add(LSTM(32, input_shape=(1, X_train.shape[2]), activation='relu'))
# Softmax yields a probability distribution over the mutually exclusive classes,
# which is what sparse_categorical_crossentropy expects (sigmoid does not).
model.add(Dense(num_classes, activation='softmax'))

# Compile and train the model
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=300, batch_size=4, verbose=0);


In [None]:
%matplotlib inline
# Evaluate the model
loss, accuracy = model.evaluate(X_test, y_test)
print('Test loss:', loss)
print('Test accuracy:', accuracy)

# Perform predictions on the test data
y_pred_prob = model.predict(X_test)
y_pred = np.argmax(y_pred_prob, axis=1)

class_names = le.classes_

# Generate the confusion matrix
conf_matrix = confusion_matrix(y_test, y_pred, labels=class_names)

# Visualize the confusion matrix as a heatmap
first_words = [string.split()[0] for string in stim_class_list]
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=first_words, yticklabels=first_words)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Stimulation point of contact - using proSVD')
plt.show()
