In [11]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import os
from PIL import Image
import yaml
from utils import *
from helpers import *
import pickle

import pdb
from tqdm import tqdm

from dlclive import DLCLive, Processor

In [12]:
class DataManager:
    def __init__(self, keys=[]):
        self.data = {key: [] for key in keys}
    
    def _print_key_error(self, key):
        print(f"Key '{key}' does not exist in the data manager.")

    def add_nan(self, key, shape):
        data_ = self.data[key]
        pop_count = 0
        while len(data_) > 0 and data_[-1] is None:
            data_.pop()
            pop_count += 1
        while pop_count > 0:
            data_.append(np.full(shape, np.nan))
            pop_count -= 1

    def add(self, key, value):
        if key in self.data:
            if value is not None:
                self.add_nan(key, value.shape)
            self.data[key].append(value)
        else:
            self._print_key_error(key)

    def to_numpy(self):
        for key in self.data:
            data_ = self.data[key]
            if len(data_) == 0 or data_[0] is None:
                self.data[key] = None
                return
            self.add_nan(key, data_[0].shape)
            self.data[key] = np.array(data_)

    def save(self, dir):
        for key in self.data:
            np.save(f'{dir}/{key}.npy', self.data.get(key))

    def load(self, dir, keys):
        if keys is None:
            keys = self.data.keys
        for key in keys:
            self.data[key] = np.load(f'{dir}/{key}.npy')
        return self.data

    def remove(self, key):
        if key in self.data:
            del self.data[key]
        else:
            self._print_key_error(key)

    def get(self, key):
        if key in self.data:
            return self.data[key]
        else:
            self._print_key_error(key)
            return None

In [13]:
class DLCManager:
    """Wraps a ``DLCLive`` model and records per-frame pose features in a ``DataManager``."""

    def __init__(self, 
            model_path,
            processor=None,
            pcutoff=0.2,
            resize=1,
            display=True):
        """
        Args:
            model_path: Forwarded to ``DLCLive`` as ``model_path``.
            processor: ``Processor`` instance for ``DLCLive``. A fresh one is
                created per manager when omitted, so instances never share one.
            pcutoff: Forwarded to ``DLCLive``.
            resize: Forwarded to ``DLCLive``.
            display: Forwarded to ``DLCLive``.
        """
        if processor is None:
            processor = Processor()

        self.model = DLCLive(
            model_path=model_path,
            processor=processor,
            pcutoff=pcutoff,
            display=display,
            resize=resize)
        self.frame = None
        self.is_first_frame = None
        # Initialised here as well as in init_data so detect_pose can run
        # without a prior init_data call raising AttributeError.
        self.prev_pose_xy = None
        self.dm = DataManager()

    def init_data(self, feature_keys):
        """Start a fresh recording: new ``DataManager`` for ``feature_keys``, no previous pose."""
        self.dm = DataManager(feature_keys)
        self.prev_pose_xy = None
        
    def update_frame(self, frame, is_first_frame):
        """Store the frame to analyse next and whether it is the first of its video."""
        self.frame = frame
        self.is_first_frame = is_first_frame
        
    def detect_pose_helper(self):
        """Run the model on the stored frame.

        The first frame goes through ``init_inference``; all later frames go
        through ``get_pose``.
        """
        if self.is_first_frame:
            pose = self.model.init_inference(self.frame)
        else:
            pose = self.model.get_pose(self.frame)
        return pose
        
    def detect_pose(self):
        """Estimate the pose for the stored frame and record its features.

        Records under the keys 'xy', 'p', 'angles' and 'speed'; the
        ``DataManager`` reports and ignores any key it was not initialised with.
        """
        curr_pose = self.detect_pose_helper()
        
        # All but the last column go to 'xy', the last column to 'p'
        # (presumably x, y and likelihood per keypoint -- confirm against DLCLive).
        curr_pose_xy, curr_pose_p = curr_pose[:, :-1], curr_pose[:, -1]
        if self.prev_pose_xy is None:
            # No earlier frame yet: pair the pose with itself.
            self.prev_pose_xy = curr_pose_xy

        pose = np.stack([self.prev_pose_xy, curr_pose_xy])
        feature_angles_item, pose_speed_item = extract_pose_features(pose)

        self.dm.add('xy', curr_pose_xy)
        self.dm.add('p', curr_pose_p)
        self.dm.add('angles', feature_angles_item)
        self.dm.add('speed', pose_speed_item)

        self.prev_pose_xy = curr_pose_xy

    def save_data(self, dir):
        """Convert the recorded features to arrays and write them into ``dir``."""
        self.dm.to_numpy()
        self.dm.save(dir)

    def load_data(self, dir, feature_keys):
        """Load ``feature_keys`` from ``dir`` and return the data dict."""
        return self.dm.load(dir, feature_keys)

In [14]:
class ConfigManager:
    """Reads a YAML experiment config and drives the per-video pose pipeline.

    Holds the list of videos, the ``DLCManager`` used for pose estimation,
    the currently open video, and the trimming / output settings.
    """

    def __init__(self, config_file, dev_mode = False, verbose = True):
        """
        Args:
            config_file: Base name (without extension) of a file in ``configs/``.
            dev_mode: If True, keep only the entry at position 2 of the video list.
            verbose: If True, print how many videos will be processed.
        """
        with open(f'configs/{config_file}.yaml', 'r') as file:
            config = yaml.safe_load(file)
        xls_name = config['path'].get('xls')
        self.is_metadata_present = (xls_name is not None)
        self.fps = config['info']['fps']
        root_dir = config['path']['root']
        self.video_dir = f"{root_dir}/{config['path']['video']}"

        if self.is_metadata_present:
            xls_path = f"{root_dir}/{xls_name}"
            self.files_info = read_octopus_xlsx(xls_path)
            self.files_info = self.files_info[self.files_info['Stim Method'] == 'Electrical']
            self.filenames = self.files_info["File Name"].to_list()
        else:
            # No spreadsheet: process every file found in the video directory.
            self.files_info = None
            self.filenames = [os.path.splitext(file)[0] for file in os.listdir(self.video_dir)]

        self.working_dir = f"{root_dir}/{config['path']['working']}" # to save processed data and figures

        model_path = f"{root_dir}/{config['path']['model']}"

        self.feature_keys = ['xy', 'p', 'angles', 'speed']
        self.dlc = DLCManager(model_path)
        
        self.prosvd_k = 4 # no. of dims to reduce to

        self.time_margin = (-120, 180) # to trim videos
        self.total_f = self.time_margin[1] - self.time_margin[0]

        self.init_frame_crop = 10 # No of initial frames used to set cropping info
        self.init_frame_prosvd = 90 # No of initial frames used to initialize proSVD
        self.init_frame = self.init_frame_crop + self.init_frame_prosvd

        self.tx = np.arange(self.init_frame + self.time_margin[0], self.time_margin[1])/self.fps

        # Computed before the dev-mode trim so class indices do not depend on
        # which subset of videos is processed.
        if self.is_metadata_present:
            self.stim_class_list = sorted(self.files_info['Stimulation Class'].unique().tolist())
        else:
            self.stim_class_list = []

        if dev_mode:
            if self.is_metadata_present:
                self.files_info = self.files_info.iloc[[2]]
                self.filenames = self.files_info["File Name"].to_list()
            else:
                self.filenames = self.filenames[2:3]

        # State of the currently open video; populated by load_video / read_video.
        self.video = None
        self.frame = None
        self.index = -1

        # Read by visualize_results but not populated anywhere in this class.
        self.feature_angles_dict = None

        if verbose:
            print(f"Processing {len(self.filenames)} videos from {self.video_dir}")
            if len(self.filenames) < 4:
                print('\t', end='')
                print(*self.filenames, sep="\n\t")

    def get_fig_dir(self, filename):
        """Return ``<working_dir>/figs``, creating it if needed (``filename`` is unused)."""
        figs_dir = f"{self.working_dir}/figs"
        os.makedirs(figs_dir, exist_ok=True)
        return figs_dir

    def get_data_dir(self, filename=None):
        """Return the working directory, or ``<working_dir>/data/<filename>``; create it if needed."""
        if filename is None:
            data_dir = f"{self.working_dir}"
        else:
            data_dir = f"{self.working_dir}/data/{filename}"
        os.makedirs(data_dir, exist_ok=True)
        return data_dir
    
    def load_video(self, video_idx):
        """Open the video at ``video_idx`` and reset the per-video state.

        Returns:
            True when a video object was obtained; False when loading raised
            or returned ``None``.
        """
        video_filename = self.filenames[video_idx]

        self.video = None
        try:
            # Resolves to the module-level load_video helper, not this method.
            self.video = load_video(self.video_dir, video_filename)
        except Exception as exc:
            # Best-effort: report and skip instead of aborting the whole run.
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            print(f"Could not load video '{video_filename}': {exc}")
            return False

        self.index = -1

        self.dlc.init_data(feature_keys=self.feature_keys)

        return (self.video is not None)
    
    def is_video_empty(self):
        """Return True when no video is loaded or the loaded video is not open."""
        return self.video is None or not self.video.isOpened()
    
    def read_video(self):
        """Read the next frame and hand it to the pose manager.

        Returns:
            The success flag from the read; on failure the frame index is not
            advanced and the pose manager is not updated.
        """
        ret, self.frame = self.video.read()
        if not ret:
            return ret
        self.index += 1
        self.dlc.update_frame(self.frame, self.index == 0)
        return ret
    
    def release_video(self):
        """Release the current video, if one is loaded."""
        if self.video is not None:
            self.video.release()
    
    def detect_pose(self):
        """Run pose estimation on the most recently read frame."""
        self.dlc.detect_pose()

    def save_data(self, video_idx):
        """Save the features recorded for the video at ``video_idx`` to its data directory."""
        video_filename = self.filenames[video_idx]
        data_dir = self.get_data_dir(video_filename)
        self.dlc.save_data(data_dir)

    def post_process(self):
        """Trim each video's saved features, attach labels, and pickle the table.

        Videos whose features raise ``ValueError`` on load are logged as
        'poor_pose'; those raising ``FileNotFoundError`` as 'file_missing'.
        Any other exception propagates unchanged. The table is written to
        ``<working_dir>/features.pkl`` with two-level column labels.
        """
        err_log = {
            'poor_pose': [],
            'file_missing': []
        }

        columns = ['filename', *self.feature_keys, 'move_class', 'stim_class']
        # Rows are collected here and turned into a DataFrame once;
        # DataFrame.append no longer exists in pandas 2.x.
        records = []

        for video_idx, video_filename in enumerate(self.filenames):
            data_dir = self.get_data_dir(video_filename)

            try:
                features = self.dlc.load_data(data_dir, self.feature_keys)
            except ValueError:
                err_log['poor_pose'].append(video_filename)
                continue
            except FileNotFoundError:
                err_log['file_missing'].append(video_filename)
                continue

            row = None
            if self.is_metadata_present:
                row = self.files_info.iloc[video_idx]
            md = load_metadata_new(row, time_margin = self.time_margin)

            start_f, end_f = self.init_frame + md[0], md[1]

            data = {'filename': video_filename}

            for key in self.feature_keys:
                data[key] = features[key][start_f: end_f, ...]

            if row is not None:
                data['move_class'] = int(row['Classification'])
                data['stim_class'] = self.stim_class_list.index(row["Stimulation Class"])
            else:
                # No spreadsheet row, hence no labels for this video.
                data['move_class'] = None
                data['stim_class'] = None

            records.append(data)

        df = pd.DataFrame(records, columns=columns)

        for key, items in err_log.items():
            print(f"{key}: {len(items)}")
            for item in items:
                print(f"\t{item}")

        data_dir = self.get_data_dir()

        feature_key_tuple = [('features', key) for key in self.feature_keys]

        df.columns = pd.MultiIndex.from_tuples([
            ('metadata', 'filename'),
            *feature_key_tuple,
            ('labels', 'move_class'),
            ('labels', 'stim_class')
        ])

        # df = df.fillna(0)

        with open(f'{data_dir}/features.pkl', 'wb') as f:
            pickle.dump(df, f)

    def visualize_results(self):
        """Plot the smoothed mean of three angle traces per entry of ``feature_angles_dict``.

        One panel is drawn per dictionary key (the figure has three panels)
        and the figure is saved under ``<working_dir>/figs/dlc-summary``.

        Raises:
            AttributeError: If ``feature_angles_dict`` has not been assigned;
                nothing in this class populates it.
        """
        feature_angles_dict = getattr(self, 'feature_angles_dict', None)
        if not feature_angles_dict:
            raise AttributeError(
                "feature_angles_dict is not populated; assign a dict mapping a "
                "panel title to an array of angle data before calling "
                "visualize_results()."
            )

        num_lines = 3 # for 3 angles
        colors = plt.cm.Paired(np.linspace(0, 1, num=num_lines))

        num_rows, num_cols = 1, 3
        fig, axs = plt.subplots(num_rows, num_cols, figsize=(16, 4), gridspec_kw={'top': 0.85})

        for i, (key, angles) in enumerate(feature_angles_dict.items()):
            feature_angle_mean = np.mean(angles, axis=0)

            for k in range(num_lines):
                data = feature_angle_mean[..., k]
                data = smooth_data(data, 5)
                axs[i].plot(data, label=k, linewidth=2, c=colors[k])
            axs[i].set_ylim(-10, 190)
            axs[i].set_title(key)
            axs[i].set_xlabel("Time (s)")
            axs[i].axvline(x=0, color='orange', linewidth=2, alpha=0.3)

        legend_elements = [Line2D([0], [0], color=colors[i], lw=4, label=i) for i in range(num_lines)]
        fig.legend(title="Angle", handles=legend_elements, loc='upper right')
        fig.suptitle('Electrical Stimulations - Angle')

        figs_dir = self.get_fig_dir("")
        figs_dir_full = f'{figs_dir}/dlc-summary'
        os.makedirs(figs_dir_full, exist_ok=True)
        fig.savefig(f'{figs_dir_full}/Electrical Stimulations - Angle.png', facecolor='white')
    

In [15]:
# Build the pipeline from configs/octo.yaml; dev_mode keeps a single video.
cfg_manager = ConfigManager(config_file="octo", dev_mode=True)

Processing 1 videos from /home/sachinks/Code/MyProjects/OctopusVideos1/videos
	elec_proximal_100Hz_5mA_220616_125117_000


In [16]:
# Run pose estimation over every configured video and save its features.
for video_idx in range(len(cfg_manager.filenames)):
    if not cfg_manager.load_video(video_idx):
        # Video could not be opened; move on to the next one.
        continue

    try:
        while not cfg_manager.is_video_empty():
            if not cfg_manager.read_video():
                break  # no more frames

            cfg_manager.detect_pose()
    finally:
        # Release the capture even if pose detection raises part-way through.
        cfg_manager.release_video()

    cfg_manager.save_data(video_idx)

In [17]:
# Trim each video's saved features, attach labels, and pickle the combined table.
cfg_manager.post_process()

poor_pose: 0
file_missing: 0


In [18]:
# Reload the feature table written by post_process() and preview it.
# NOTE: pickle.load executes arbitrary code; only open files this notebook produced.
data_dir = cfg_manager.get_data_dir()
features_pkl = f'{data_dir}/features.pkl'
with open(features_pkl, 'rb') as pkl_file:
    df = pickle.load(pkl_file)
df.head(10)


Unnamed: 0_level_0,metadata,features,features,features,features,labels,labels
Unnamed: 0_level_1,filename,xy,p,angles,speed,move_class,stim_class
0,elec_proximal_100Hz_5mA_220616_125117_000,"[[[397.75977, 200.55101], [369.5135, 221.11324...","[[0.99994606, 0.994769, 0.95619035, 0.99879336...","[[14.844358378202074, 33.78292325306843, 92.85...","[[1.602511, 0.12227834, 0.42795083, 0.18256009...",1,2


In [19]:
# N = 1
# X = df['features', 'angles'].values[:N]
# y = df['labels'].iloc[:N]

# movement_types = [
#     "No movement",
#     "Movement",
#     "Movement with arm curl"
# ]

# tx = cfg_manager.tx

# stim_class_list = cfg_manager.stim_class_list

# for i in range(N):
#     feature = X[i]
#     label = y.iloc[i]
#     print(df['metadata', 'filename'].iloc[i])
#     print(movement_types[label['move_class']],"|", stim_class_list[label['stim_class']])
#     for k in range(feature.shape[-1]):
#         plt.plot(tx, X[i][:, k], label=k)
#     plt.legend()
#     plt.show()

# X = df['features', 'p'].values[:N]
# y = df['labels'].iloc[:N]

# movement_types = [
#     "No movement",
#     "Movement",
#     "Movement with arm curl"
# ]

# tx = cfg_manager.tx

# stim_class_list = cfg_manager.stim_class_list

# for i in range(N):
#     feature = X[i]
#     label = y.iloc[i]
#     print(df['metadata', 'filename'].iloc[i])
#     print(movement_types[label['move_class']],"|", stim_class_list[label['stim_class']])
#     for k in range(feature.shape[-1]):
#         plt.plot(tx, X[i][:, k], label=k)
#     plt.legend()
#     plt.show()

In [20]:
%matplotlib tk

import matplotlib.pyplot as plt
import numpy as np

WIDTH, HEIGHT = 640, 480

R = 0
df_ = df.iloc[R]
print(df_['metadata', 'filename'])
xy = df_['features', 'xy']
x_coords, y_coords = xy[..., 0], xy[..., 1]

y_coords = HEIGHT-y_coords

num_points = xy.shape[1]  # Number of points to visualize

# Create a figure and axis
fig, ax = plt.subplots()

colors = plt.cm.Blues(np.linspace(1, 0.2, num=num_points))

# Create an empty scatter plot for each point
scatters = [ax.scatter([], [], color=colors[i]) for i in range(num_points)]

# Set up the axis limits
ax.set_xlim(0, WIDTH)
ax.set_ylim(0, HEIGHT)

# Function to update the scatter plots
def update_plots(i, pause=1.5 / 30):
    """Move every scatter point to its position at time step ``i`` and redraw.

    Args:
        i: Index along the first axis of ``x_coords`` / ``y_coords`` and into
            ``cfg_manager.tx``.
        pause: Seconds to pause after updating so the frame is visible.
            The default equals the previously hard-coded ``1.5*1/30``.
    """
    for j in range(num_points):
        # set_offsets expects a sequence of (x, y) pairs; each artist holds one point.
        scatters[j].set_offsets([(x_coords[i, j], y_coords[i, j])])

    # Title shows the time value for this step.
    ax.set_title(f"Time Step: {cfg_manager.tx[i]:0.1f} s")

    plt.pause(pause)

# Iterate through each time step and update the plots
# Step through the first axis of xy, redrawing the points each time.
for frame_idx, _ in enumerate(xy):
    update_plots(frame_idx)

# Keep the final frame on screen.
plt.show()


elec_proximal_100Hz_5mA_220616_125117_000


In [None]:
# cfg_manager.visualize_results()

In [None]:
# Show the first row of the feature table.
df.head(1)

In [None]:

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, LSTM


# Load and preprocess the data: stack the per-video angle arrays on a new last
# axis, then flatten so that each row of X belongs to one video.
df_new = df
X = np.stack(df['features', 'angles'].values, axis=-1)
X = X.reshape(-1, X.shape[-1]).T
y = df['labels', 'stim_class'].values

# Encode the target variable as consecutive integers 0..n_classes-1
le = LabelEncoder()
y = le.fit_transform(y)

# Size the output layer from the encoded labels instead of a hard-coded 3, so
# it always matches the label range the sparse loss expects.
num_classes = len(le.classes_)

# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)

# Reshape the input data for LSTM: (samples, timesteps=1, features)
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))

# Build the RNN model
model = Sequential()
model.add(LSTM(32, input_shape=(1, X_train.shape[2]), activation='relu'))
# softmax gives one probability distribution over the mutually exclusive
# classes, which is what sparse_categorical_crossentropy expects; sigmoid
# scores each class independently.
model.add(Dense(num_classes, activation='softmax'))

# Compile and train the model
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=100, batch_size=4)

# Evaluate the model
loss, accuracy = model.evaluate(X_test, y_test)
print('Test loss:', loss)
print('Test accuracy:', accuracy)
