In [37]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import os
from PIL import Image
import yaml
from utils import *
from helpers import *

import pdb
from tqdm import tqdm

from proSVD import proSVD
from dlclive import DLCLive, Processor

In [38]:
# Load the run configuration (paths, frame rate, ...) from YAML.
with open('configs/octo-1.yaml', 'r') as file:
    config = yaml.safe_load(file)

# Per-video metadata is only available when the config names a spreadsheet.
IS_METADATA_PRESENT = (config['path']['xls'] is not None)
files_info = None
filenames = None

fps = config['info']['fps']

root_dir = config['path']['root']
video_dir = f"{root_dir}/{config['path']['video']}"

if IS_METADATA_PRESENT:
    xls_path = f"{root_dir}/{config['path']['xls']}"
    files_info = read_octopus_xlsx(xls_path)
    files_info['Stimulation Class'] = files_info['Stimulation Type'].apply(get_stim_class)
    # Re-number the rows 0..N-1 so that a row's position matches its position
    # in `filenames`; later cells index both with the same integer.
    files_info.reset_index(inplace=True)
    filenames = files_info["File Name"].to_list()
else:
    # os.listdir returns entries in arbitrary order; sort them so the
    # processing order is reproducible from run to run.
    files = sorted(os.listdir(video_dir))
    filenames = [os.path.splitext(name)[0] for name in files]

print(f"Processing {len(filenames)} videos from {video_dir}")
if len(filenames) < 4:
    # Only list the individual names when there are few enough to read.
    print('\t', end='')
    print(*filenames, sep="\n\t")

###

working_dir = f"{root_dir}/{config['path']['working']}" # to save processed data and figures

model_path = f"{root_dir}/{config['path']['model']}"

# Pose-estimation model; the processing cell passes it to the crop-box detector.
dlc_proc = Processor()
dlc_live = DLCLive(
    model_path,
    processor=dlc_proc,
    pcutoff=0.2,
    resize=1)

PROSVD_K = 4 # no. of dims to reduce to

init_frame_crop = 10 # No of initial frames used to set cropping info
init_frame_prosvd = 90 # No of initial frames used to initialize proSVD
# Index of the first frame that is streamed through proSVD.
init_frame = init_frame_crop + init_frame_prosvd

Processing 89 videos from /home/sachinks/Code/MyProjects/OctopusVideos1/videos


In [39]:
def detect_crop_box_wrapper(dlc, frame, index, threshold=0.9, margin=40):
    """Estimate the pose in ``frame`` and turn it into a crop box.

    Runs ``detect_pose(dlc, frame, index)`` and hands the resulting pose,
    together with ``frame.shape``, ``threshold`` and ``margin``, to
    ``detect_crop_box``; whatever that helper returns is returned unchanged.
    """
    return detect_crop_box(
        detect_pose(dlc, frame, index),
        frame.shape,
        threshold,
        margin,
    )

def get_fig_dir(filename):
    """Return the ``figs`` directory under ``working_dir``, creating it if missing.

    ``filename`` is accepted but not used: the same directory is returned
    for every video.
    """
    target = "{}/figs".format(working_dir)
    os.makedirs(target, exist_ok=True)
    return target

def get_data_dir(filename):
    """Return the per-video directory ``<working_dir>/data/<filename>``.

    The directory (and any missing parents) is created on first use.
    """
    target = "{}/data/{}".format(working_dir, filename)
    os.makedirs(target, exist_ok=True)
    return target

In [None]:
# Per-video pipeline:
#   1. sum the detected crop box over the first `init_frame_crop` frames and
#      take the integer mean,
#   2. crop / grayscale / downsample / flatten every later frame,
#   3. collect frames up to `init_frame` to initialise proSVD,
#   4. from `init_frame` on, update proSVD with each frame and record the
#      basis `Q` and the frame's coordinates in that basis.
# Results are written to `<data_dir>/Q.npy` and `<data_dir>/score.npy`.
for video_filename in tqdm(filenames):
    try:
        video = load_video(video_dir, video_filename)
    except Exception as err:
        # Best-effort batch run: report the failure and move on. Catching
        # Exception (not a bare except) keeps Ctrl-C able to stop the loop.
        print(f"Skipping {video_filename}: could not load video ({err})")
        continue

    figs_dir = get_fig_dir(video_filename)
    data_dir = get_data_dir(video_filename)

    crop_box = np.zeros(4, int)  # running sum, then mean, of the detected boxes
    index = -1  # index of the frame currently being handled

    frames = []  # for proSVD initialization
    coordi_full = []  # per-frame coordinates in the current basis
    Q_full = []  # per-frame basis
    pro = None

    try:
        while video.isOpened():
            ret, frame = video.read()
            if not ret:
                break

            index += 1

            # Cropping starts
            if index < init_frame_crop:
                crop_box += detect_crop_box_wrapper(
                    dlc_live, frame, index,
                    margin=max(40, int(frame.size//(3*5*1e4))))
                continue

            if index == init_frame_crop:
                # Turn the accumulated sum into an (integer) mean box.
                crop_box //= init_frame_crop

            frame = frame[crop_box[0]:crop_box[1], crop_box[2]:crop_box[3], :]
            # Cropping ends

            frame = rgb_to_grayscale(frame)
            frame = downsample_image(frame)

            # save the cropped frame for checking if cropping done correctly
            if index == init_frame_crop:
                crop_dir = f'{figs_dir}/cropped'
                os.makedirs(crop_dir, exist_ok=True)
                im = Image.fromarray(frame)
                im.save(f'{crop_dir}/{video_filename}.png')

            frame = frame.flatten()

            if index < init_frame:
                frames.append(frame)
                continue

            # proSVD starts
            if index == init_frame:
                # One column per initialisation frame.
                frames = np.array(frames).T
                pro = proSVD(k=PROSVD_K, w_len=1, history=0, decay_alpha=1, trueSVD=True)
                pro.initialize(frames)
                del frames  # the initialisation buffer is no longer needed

            pro.preupdate()
            pro.updateSVD(frame[:, None])
            pro.postupdate()

            pro_coordi = frame @ pro.Q
            Q_full.append(pro.Q)
            coordi_full.append(pro_coordi)
    finally:
        # Release the capture handle even if a frame fails part-way through.
        video.release()

    if not Q_full:
        # No frame reached the proSVD stage. Saving the empty arrays would
        # produce 1-D files that break the axis=1 reductions in the plotting
        # cells, so skip this video instead.
        print(f"Skipping {video_filename}: only {index + 1} frames read, "
              f"need more than {init_frame}")
        continue

    Q_full = np.array(Q_full)
    coordi_full = np.array(coordi_full)
    np.save(f'{data_dir}/Q.npy', Q_full)
    np.save(f'{data_dir}/score.npy', coordi_full)

In [101]:

# NOTE(review): scratch check. `colors` is only assigned in later cells of this
# notebook, so this line raises NameError on a fresh top-to-bottom run.
colors.shape

(6, 4)

In [108]:
# Plot mean-variance of experiments with basis as separate plots, and stimulation type as separate subplots
#
# One figure, one subplot per proSVD basis vector. Every video classified as
# moving contributes one line per subplot, coloured by its stimulation class,
# with time measured relative to the stimulus.
if files_info is None:
    raise RuntimeError(
        "This cell groups videos by stimulation class and needs the metadata "
        "spreadsheet (config['path']['xls'])."
    )

stim_class_list = sorted(files_info['Stimulation Class'].unique().tolist())
indices = files_info.index

# Near-square grid with at least PROSVD_K cells.
num_figures = PROSVD_K
num_rows = int(np.ceil(np.sqrt(num_figures)))
num_cols = int(np.ceil(num_figures / num_rows))

movement_types = [
    "No movement",
    "Movement",
    "Movement with arm curl"
]

# One colour per stimulation class: first half blues, second half greens.
# The blue half takes the extra class when the count is odd, so that
# `colors` always has exactly one row per class.
num_lines = len(stim_class_list)
num_blue = (num_lines + 1) // 2
num_green = num_lines - num_blue
colorsB = plt.cm.Blues(np.linspace(0.3, 1, num=num_blue))
colorsG = plt.cm.Greens(np.linspace(0.3, 1, num=num_green))
colors = np.concatenate([colorsB, colorsG], axis=0)

titles = [
    "n-Norm of t-Diff of Q",
]

ylims = [1e-2, 4e-1, 4e-1, 4e-1]  # upper y-limit per basis index

legend_elements = [
    Line2D([0], [0], color=colors[i], lw=4, label=stim_class_list[i])
    for i in range(len(stim_class_list))
]

# Shared figures directory; get_fig_dir does not use its argument.
figs_dir = get_fig_dir(None)

plt.ioff()
# squeeze=False keeps `axs` two-dimensional for any grid size.
fig, axs = plt.subplots(num_rows, num_cols, figsize=(8, 8),
                        gridspec_kw={'top': 0.75}, squeeze=False)
fig_all = [fig]
axs_all = [axs]

for video_idx in tqdm(indices):
    video_filename = filenames[video_idx]
    row = files_info.iloc[video_idx]

    move_class = int(row['Classification'])
    if move_class == 0: #skip videos which shows no movement
        continue

    stim_class = row["Stimulation Class"]
    stim_idx = stim_class_list.index(stim_class)

    data_dir = get_data_dir(video_filename)

    try:
        Q_full = np.load(f'{data_dir}/Q.npy')
    except Exception as err:
        # Best-effort: skip videos whose basis file is missing or unreadable.
        # Not a bare except, so Ctrl-C still interrupts the loop.
        print(f"Could not load {data_dir}/Q.npy: {err}")
        continue

    total_frames = Q_full.shape[0]

    metadata = load_metadata(video_filename, total_frames, row, fps, init_frame)

    # Frame-to-frame change of the basis, reduced with a norm over axis 1.
    # Assumes Q_full is (frames, pixels, PROSVD_K), as stacked from `pro.Q`
    # by the processing cell, so the result is indexed [frame, basis].
    Q_diff = np.diff(Q_full, axis=0)
    Q_norm_diff = np.linalg.norm(Q_diff, axis=1)
    # Pad with a leading zero so the series has one entry per frame again.
    Q_norm_diff = np.insert(Q_norm_diff, 0, 0, axis=0)

    ## PLOTTING ##

    # Time of each stored frame in seconds, relative to the stimulus.
    tx = np.arange(init_frame, init_frame + total_frames, dtype=float)
    tx /= fps
    tx -= metadata['stim']['t']
    start_f = metadata['start']['f']
    end_f = metadata['end']['f']

    data = [
        Q_norm_diff,
    ]

    for series in data:
        for k in range(PROSVD_K):
            i, j = divmod(k, num_cols)  # row-major position of basis k
            data_smoothed = smooth_data(series[:, k])[start_f: end_f]
            axs[i, j].plot(tx[start_f: end_f], data_smoothed, c=colors[stim_idx],
                           label=stim_class, linewidth=1)

fig.legend(handles=legend_elements, loc='upper right')
fig.subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9, hspace=0.3, wspace=0.3)

for k in range(PROSVD_K):
    i, j = divmod(k, num_cols)
    axs[i, j].set_title(f'Basis {k}')
    if k < len(ylims):
        axs[i, j].set_ylim(0, ylims[k])
    # Stimulus onset (t = 0 after the shift above).
    axs[i, j].axvline(x=0, color='orange', linewidth=2, alpha=0.3)

for j in range(num_cols):
    axs[-1, j].set_xlabel("Time (s)")

# TODO: replace the placeholder figure title.
fig.suptitle('Super Title')

figs_dir_full = f'{figs_dir}/{titles[0]}'
os.makedirs(figs_dir_full, exist_ok=True)
fig.savefig(f'{figs_dir_full}/{titles[0]}.png', facecolor='white')


distal_pinch_220616_115848_000
-0.5099999999999998
elec_proximal_100Hz_5mA_220616_124802_000
-0.5200000000000014
elec_proximal_100Hz_5mA_220616_125117_000
-0.5
distal_pinch_220616_120052_000
-0.5133333333333336
distal_pinch_to_check_reaction_220616_132852_000
-0.5066666666666659
elec_distal_100Hz_5mA_220616_125918_000
-0.5266666666666673
pinch_distal_220823_123854_000
-0.5133333333333336
elec_distal_100Hz_5mA_220616_130314_000
-0.5
elec_distal_100Hz_5mA_220616_131207_000
-0.5099999999999998
elec_proximal_100Hz_5mA_220616_125500_000
-0.5266666666666673
elec_right_cord_100Hz_5mA_220616_131822_000
-0.5033333333333339
elec_right_cord_100Hz_5mA_220616_132230_000
-0.5033333333333321
elec_right_cord_100Hz_5mA_220616_132733_000
-0.5
distal_pinch_220712_121401_000
-0.5
distal_pinch_220712_121651_000
-0.5
distal_pinch_220712_121902_000
-0.5
proximal_pinch_220712_122016_000
-0.5166666666666657
proximal_pinch_220712_122610_000
-0.5
pinch_distal_220823_125213_000
-0.5033333333333339
pinch_distal_22

KeyboardInterrupt: 

In [None]:
# Plot all experiments separately with basis as separate plots, and stimulation type as separate subplots
#
# One figure per proSVD basis vector; inside each figure one subplot per
# stimulation class. Lines are coloured by the video's movement class.
if files_info is None:
    raise RuntimeError(
        "This cell groups videos by stimulation class and needs the metadata "
        "spreadsheet (config['path']['xls'])."
    )

stim_class_list = sorted(files_info['Stimulation Class'].unique().tolist())
indices = files_info.index

# Near-square grid with at least one cell per stimulation class.
num_figures = len(stim_class_list)
num_rows = int(np.ceil(np.sqrt(num_figures)))
num_cols = int(np.ceil(num_figures / num_rows))

movement_types = [
    "No movement",
    "Movement",
    "Movement with arm curl"
]

num_lines = 3 # for 3 movement classes
colors = plt.cm.Paired(np.linspace(0, 1, num=num_lines))

titles = [
    # "n-Norm of Q",
    # "t-Diff of n-Norm of Q",
    "n-Norm of t-Diff of Q",
    # "Coordinates in redu. space",
]

ylims = [1e-2, 4e-1, 4e-1, 4e-1]  # upper y-limit per basis index

# Shared figures directory; get_fig_dir does not use its argument.
figs_dir = get_fig_dir(None)

fig_all = []
axs_all = []
for k in range(PROSVD_K):
    # squeeze=False keeps `axs` two-dimensional for any grid size.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(8, 8),
                            gridspec_kw={'top': 0.85}, squeeze=False)
    fig_all.append(fig)
    axs_all.append(axs)

for video_idx in tqdm(indices):
    video_filename = filenames[video_idx]
    data_dir = get_data_dir(video_filename)

    try:
        Q_full = np.load(f'{data_dir}/Q.npy')
    except Exception as err:
        # Best-effort: skip videos whose basis file is missing or unreadable.
        # Not a bare except, so Ctrl-C still interrupts the loop.
        print(f"Could not load {data_dir}/Q.npy: {err}")
        continue

    total_frames = Q_full.shape[0]

    row = files_info.iloc[video_idx]
    stim_class = row["Stimulation Class"]
    move_class = int(row['Classification'])

    metadata = load_metadata(video_filename, total_frames, row, fps, init_frame)

    # Frame-to-frame change of the basis, reduced with a norm over axis 1.
    # Assumes Q_full is (frames, pixels, PROSVD_K), as stacked from `pro.Q`
    # by the processing cell, so the result is indexed [frame, basis].
    Q_diff = np.diff(Q_full, axis=0)
    Q_norm_diff = np.linalg.norm(Q_diff, axis=1)
    # Pad with a leading zero so the series has one entry per frame again.
    Q_norm_diff = np.insert(Q_norm_diff, 0, 0, axis=0)

    ## PLOTTING ##

    # Time of each stored frame in seconds, relative to the stimulus.
    tx = np.arange(init_frame, init_frame + total_frames, dtype=float)
    tx /= fps
    tx -= metadata['stim']['t']

    data = [
        # Q_norm,
        # Q_diff_norm,
        Q_norm_diff,
        # coordi_full,
    ]

    start_f = metadata['start']['f']
    end_f = metadata['end']['f']

    # Column-major placement of the stimulation class in the subplot grid.
    stim_idx = stim_class_list.index(stim_class)
    i = stim_idx % num_rows
    j = stim_idx // num_rows

    for series in data:
        for k in range(PROSVD_K):
            data_smoothed = smooth_data(series[:, k])
            axs_all[k][i, j].plot(tx[start_f: end_f], data_smoothed[start_f: end_f],
                                  c=colors[move_class], label=move_class, linewidth=2)
            axs_all[k][i, j].set_title(f'{stim_class}')

legend_elements = [Line2D([0], [0], color=colors[i], lw=4, label=movement_types[i]) for i in range(3)]

# All figures of this cell are saved under the (single) active title.
figs_dir_full = f'{figs_dir}/{titles[0]}'
os.makedirs(figs_dir_full, exist_ok=True)

for k in range(PROSVD_K):
    fig_all[k].legend(handles=legend_elements, loc='upper right')
    fig_all[k].subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9, hspace=0.3, wspace=0.3)

    for i in range(num_rows):
        for j in range(num_cols):
            if k < len(ylims):
                axs_all[k][i, j].set_ylim(0, ylims[k])
            # Stimulus onset (t = 0 after the shift above).
            axs_all[k][i, j].axvline(x=0, color='orange', linewidth=2, alpha=0.3)

    for j in range(num_cols):
        # Axes expose set_xlabel(); there is no Axes.xlabel method.
        axs_all[k][-1, j].set_xlabel("Time (s)")

    fig_all[k].suptitle(f'Basis {k}')
    fig_all[k].savefig(f'{figs_dir_full}/basis_{k}.png', facecolor='white')

In [None]:
# One figure per stimulation class; inside it one subplot per proSVD basis
# vector. Each video of the class contributes one line per subplot, coloured
# by its movement class.
if files_info is None:
    raise RuntimeError(
        "This cell groups videos by stimulation class and needs the metadata "
        "spreadsheet (config['path']['xls'])."
    )

# Everything that does not depend on the video is set up once, so it is
# defined even when a class has no loadable video.
movement_types = [
    "No movement",
    "Movement",
    "Movement with arm curl"
]

num_lines = 3 # for 3 movement classes
colors = plt.cm.Paired(np.linspace(0, 1, num=num_lines))
legend_elements = [Line2D([0], [0], color=colors[i], lw=4, label=movement_types[i]) for i in range(3)]

titles = [
    # "n-Norm of Q",
    # "t-Diff of n-Norm of Q",
    "n-Norm of t-Diff of Q",
    # "Coordinates in redu. space",
    # "Difference of coordinates in redu. space",
]

num_rows = int(np.ceil(np.sqrt(PROSVD_K)))
num_cols = num_rows

# Shared figures directory; get_fig_dir does not use its argument.
figs_dir = get_fig_dir(None)
figs_dir_full = f'{figs_dir}/{titles[0]}'
os.makedirs(figs_dir_full, exist_ok=True)

for stim_class in files_info['Stimulation Class'].unique():
    indices = files_info[files_info['Stimulation Class'] == stim_class].index

    # squeeze=False keeps `axs` two-dimensional for any grid size.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(8, 6),
                            gridspec_kw={'top': 0.8}, squeeze=False)

    for video_idx in indices:
        video_filename = filenames[video_idx]
        data_dir = get_data_dir(video_filename)

        try:
            Q_full = np.load(f'{data_dir}/Q.npy')
        except Exception as err:
            # Best-effort: skip videos whose basis file is missing or
            # unreadable. Not a bare except, so Ctrl-C still interrupts.
            print(f"Could not load data for {video_filename}: {err}")
            continue

        total_frames = Q_full.shape[0]

        row = files_info.iloc[video_idx]
        move_class = int(row['Classification'])

        metadata = load_metadata(video_filename, total_frames, row, fps, init_frame)

        # Frame-to-frame change of the basis, reduced with a norm over axis 1.
        # Assumes Q_full is (frames, pixels, PROSVD_K), as stacked from
        # `pro.Q` by the processing cell, so the result is [frame, basis].
        Q_diff = np.diff(Q_full, axis=0)
        Q_norm_diff = np.linalg.norm(Q_diff, axis=1)
        # Pad with a leading zero so the series has one entry per frame again.
        Q_norm_diff = np.insert(Q_norm_diff, 0, 0, axis=0)

        ## PLOTTING ##

        # Time of each stored frame in seconds, relative to the stimulus.
        tx = np.arange(init_frame, init_frame + total_frames, dtype=float)
        tx /= fps
        tx -= metadata['stim']['t']

        data = [
            # Q_norm,
            # Q_diff_norm,
            Q_norm_diff,
            # coordi_full,
            # coodi_diff,
        ]

        start_f = metadata['start']['f']
        end_f = metadata['end']['f']

        for series in data:
            # Only PROSVD_K bases exist; the grid may have spare cells.
            for k in range(PROSVD_K):
                i, j = divmod(k, num_cols)
                data_smoothed = smooth_data(series[:, k])
                axs[i, j].plot(tx[start_f: end_f], data_smoothed[start_f: end_f],
                               c=colors[move_class], label=move_class, linewidth=2)

    # Decorate each used subplot once (not once per video, which would stack
    # the semi-transparent stimulus marker).
    for k in range(PROSVD_K):
        i, j = divmod(k, num_cols)
        axs[i, j].set_title(f'Basis {k}')
        axs[i, j].axvline(x=0, color='orange', linewidth=2, alpha=0.3)

    fig.legend(handles=legend_elements, loc='upper right')
    # plt.tight_layout(pad=5.0)
    fig.subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9, hspace=0.3, wspace=0.3)

    fig.suptitle(stim_class)

    fig.savefig(f'{figs_dir_full}/{stim_class}.png', facecolor='white')
    plt.close(fig)

In [None]:
# Per-video diagnostic plots: for every video and every quantity in `titles`
# save one figure with one line per proSVD basis vector.
titles = [
    "n-Norm of Q",
    "t-Diff of n-Norm of Q",
    "n-Norm of t-Diff of Q",
    "Coordinates in redu. space",
    # "Difference of coordinates in redu. space",
]

num_lines = PROSVD_K
colors = plt.cm.Paired(np.linspace(0, 1, num=num_lines))

for video_idx in range(len(filenames)): #range(5): #tqdm(range(len(filenames))):
    video_filename = filenames[video_idx]

    figs_dir = get_fig_dir(video_filename)
    data_dir = get_data_dir(video_filename)

    try:
        Q_full = np.load(f'{data_dir}/Q.npy')
        coordi_full = np.load(f'{data_dir}/score.npy')
    except Exception as err:
        # Best-effort: skip videos whose saved arrays are missing or
        # unreadable. Not a bare except, so Ctrl-C still interrupts.
        print(f"Could not load data for {video_filename}: {err}")
        continue

    total_frames = Q_full.shape[0]

    row = files_info.iloc[video_idx] if IS_METADATA_PRESENT else None
    metadata = load_metadata(video_filename, total_frames, row, fps, init_frame)

    # Assumes Q_full is (frames, pixels, PROSVD_K), as stacked from `pro.Q`
    # by the processing cell; reducing axis 1 leaves [frame, basis] arrays.
    Q_norm = np.linalg.norm(Q_full, axis=1)
    Q_diff_norm = np.diff(Q_norm, axis=0, prepend=1)
    Q_diff = np.diff(Q_full, axis=0)
    Q_norm_diff = np.linalg.norm(Q_diff, axis=1)
    # Pad with a leading zero so the series has one entry per frame again.
    Q_norm_diff = np.insert(Q_norm_diff, 0, 0, axis=0)

    ## PLOTTING ##

    # Time of each stored frame in seconds from the start of the video
    # (deliberately NOT shifted by the stimulus time in this cell).
    tx = np.arange(init_frame, init_frame + coordi_full.shape[0], dtype=float)
    tx /= fps

    # Same order as `titles`.
    data = [
        Q_norm,
        Q_diff_norm,
        Q_norm_diff,
        coordi_full,
        # coodi_diff,
    ]

    start_f = metadata['start']['f']
    end_f = metadata['end']['f']

    for r in range(len(titles)):
        fig, ax = plt.subplots()
        for i in range(PROSVD_K):
            data_smoothed = smooth_data(data[r][:, i])
            ax.plot(tx[start_f: end_f], data_smoothed[start_f: end_f],
                    c=colors[i], label=i, linewidth=2)
        ax.set_title(f'{metadata["title"]} ({titles[r]})')
        ax.set_xlabel("Time (s)")
        ax.legend(title="Basis vectors")
        if IS_METADATA_PRESENT:
            # The time axis is not stimulus-relative here, so the stimulus
            # marker belongs at the stimulus time rather than at x = 0.
            ax.axvline(x=metadata['stim']['t'], color='orange', linewidth=3, alpha=0.5)

        figs_dir_full = f'{figs_dir}/{titles[r]}'
        os.makedirs(figs_dir_full, exist_ok=True)
        fig.savefig(f'{figs_dir_full}/{metadata["filename"]}.png', facecolor='white')
        # Close each figure so a long batch does not accumulate open figures.
        plt.close(fig)