In [None]:
import sys
import os

# Make the directory one level above the kernel's working directory importable, so
# project-local packages (e.g. `models.dataset`) can be imported further down.
# NOTE: this is resolved from os.getcwd(), not from the notebook file's location,
# so it assumes the kernel was started in the notebook's own folder.
src_path = os.path.abspath(os.path.join(os.getcwd(), '..'))

# Prepend (highest import priority) only when missing, so re-running the cell
# does not stack duplicate entries onto sys.path.
needs_insert = src_path not in sys.path
if needs_insert:
    sys.path.insert(0, src_path)

# Imports

import os
import gc
from glob import glob
import joblib
import numpy as np
import pandas as pd
import time

# Scipy
from scipy.signal import butter, filtfilt, iirnotch, hilbert
from scipy.stats import kurtosis
from scipy.io import savemat 

# Scikit-Learn
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score

# Pytorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau



# Settings
CONTRALATERAL_BASE_PATH =  '../data/'
DATA_ROOT_PATH = '/home/linux-pc/gh/projects/NeuralNexus/New-Features/Thought-to-Motion/CRCNS/src/motor_cortex/data/data/Contralateral/2018-04-12_(S4)/'

ECOG_DATA_FILENAME = 'Contralateral_2018-04-12_(S4)_cleaned_aligned_ecog_data.csv'
MOTION_DATA_FILENAME = 'Contralateral_2018-04-12_(S4)_cleaned_aligned_motion_data.csv'

ECOG_DATA_FILENAME_DATA_ONLY = 'Contralateral_2018-04-12_(S4)_cleaned_aligned_ecog_data_DATA_ONLY.csv'
MOTION_DATA_FILENAME_DATA_ONLY = 'Contralateral_2018-04-12_(S4)_cleaned_aligned_motion_data_DATA_ONLY.csv'

CONTRALATERAL_ECOG_DATA_FULL_FILE_PATH = CONTRALATERAL_BASE_PATH + ECOG_DATA_FILENAME
CONTRALATERAL_MOTION_DATA_FULL_FILE_PATH = CONTRALATERAL_BASE_PATH + MOTION_DATA_FILENAME


MOTION_NP = "../data/motion_values_normalized.npy"
ECOG_NP = "../data/ecog_values_normalized.npy"

from models.dataset import MotionECoGDataset

from torch.utils.data import Subset
from torch.utils.data import DataLoader


## Load Data

In [None]:
# Display the kernel's working directory: the relative paths used in this notebook
# ('../data/...', '../models/...') are resolved against it.
os.getcwd()

In [None]:
dataset = MotionECoGDataset(MOTION_NP, ECOG_NP)

# Evaluate on a fixed held-out subset whose indices are read from disk rather than
# re-split here. NOTE(review): presumably written by the training run — confirm.
test_indices = torch.load("../models/test_indices.pt")
test_dataset = Subset(dataset, test_indices)
# shuffle=False / drop_last=False: deterministic order and every test sample is visited once.
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=False)

# Tabular motion data read by the visualization cells below (they use `motion_df`).
# Loading it here lets the notebook run top-to-bottom on a fresh kernel.
# Assumes the cleaned/aligned motion CSV is the table holding the Left_Wrist_* columns;
# the assertion below fails loudly if that assumption is wrong.
motion_df = pd.read_csv(CONTRALATERAL_MOTION_DATA_FULL_FILE_PATH)
WRIST_COLUMNS = ['Left_Wrist_X', 'Left_Wrist_Y', 'Left_Wrist_Z']
missing_wrist_columns = set(WRIST_COLUMNS) - set(motion_df.columns)
assert not missing_wrist_columns, (
    f"{CONTRALATERAL_MOTION_DATA_FULL_FILE_PATH} lacks columns needed for plotting: "
    f"{sorted(missing_wrist_columns)}"
)

## Visualization


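### Single Instance: Static 3D Wrist Trajectory

The full left-wrist trajectory from `motion_df`, drawn as a single rotatable 3D line.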


In [None]:
import plotly.graph_objects as go

# One static, rotatable 3D line through every left-wrist sample in `motion_df`.
wrist_trajectory = go.Scatter3d(
    name='Trajectory',
    mode='lines',
    line=dict(width=3, color='teal'),
    x=motion_df['Left_Wrist_X'],
    y=motion_df['Left_Wrist_Y'],
    z=motion_df['Left_Wrist_Z'],
)

fig = go.Figure(data=[wrist_trajectory])
fig.update_layout(
    title="Interactive 3D Wrist Motion",
    # Tight margins; only the top keeps room for the title.
    margin=dict(l=0, r=0, b=0, t=40),
    scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
)
fig.show()


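### Temporal Development Visualization

The same (downsampled) left-wrist trajectory, drawn progressively over time; use Play/Pause or the slider to step through it.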

In [None]:
import plotly.graph_objects as go
import numpy as np

# Animation settings.
step = 10               # keep every `step`-th sample to lighten the figure
max_frames = 300        # cap on animation frames (use >= 2); None animates every kept sample
slider_stride = 20      # one slider tick every `slider_stride` frames
frame_duration_ms = 20  # playback delay per frame

# Downsampled left-wrist coordinates as NumPy arrays.
x_vals = motion_df['Left_Wrist_X'].values[::step]
y_vals = motion_df['Left_Wrist_Y'].values[::step]
z_vals = motion_df['Left_Wrist_Z'].values[::step]
n_points = len(x_vals)
assert n_points > 0, "motion_df has no samples to animate"

# Frame k draws the prefix [:k], so one frame per sample makes the serialized figure grow
# quadratically with n_points and can stall the browser. Spread at most `max_frames`
# frame ends evenly over 1..n_points instead; linspace's endpoint guarantees the last
# frame shows the complete trajectory. (When n_frames == n_points this is exactly 1..n_points.)
n_frames = n_points if max_frames is None else min(n_points, max_frames)
frame_ends = np.linspace(1, n_points, num=n_frames, dtype=int).tolist()

# Create frames for animation
frames = [
    go.Frame(
        data=[go.Scatter3d(
            x=x_vals[:k],
            y=y_vals[:k],
            z=z_vals[:k],
            mode='lines+markers',
            line=dict(color='teal', width=4),
            marker=dict(size=3, color='purple')
        )],
        name=str(k)
    )
    for k in frame_ends
]

# Before playback, only the first sample is shown.
initial_trace = go.Scatter3d(
    x=[x_vals[0]],
    y=[y_vals[0]],
    z=[z_vals[0]],
    mode='markers',
    marker=dict(size=5, color='purple')
)

# Slider ticks: every `slider_stride`-th frame, plus the final frame, so the slider can
# always reach the complete trajectory (a plain stride usually skipped it).
slider_frame_ends = frame_ends[::slider_stride]
if slider_frame_ends[-1] != frame_ends[-1]:
    slider_frame_ends.append(frame_ends[-1])

# Build the figure
fig = go.Figure(
    data=[initial_trace],
    frames=frames
)

# Layout with play/pause buttons and slider
fig.update_layout(
    title='Animated 3D Wrist Motion',
    # Fixed axis ranges (NaN-safe) so the view does not rescale as the path grows.
    scene=dict(
        xaxis=dict(range=[np.nanmin(x_vals), np.nanmax(x_vals)], title='X'),
        yaxis=dict(range=[np.nanmin(y_vals), np.nanmax(y_vals)], title='Y'),
        zaxis=dict(range=[np.nanmin(z_vals), np.nanmax(z_vals)], title='Z'),
    ),
    updatemenus=[dict(
        type='buttons',
        showactive=False,
        buttons=[
            dict(
                label='Play',
                method='animate',
                # `None` animates all frames; fromcurrent resumes where playback stopped.
                args=[
                    None,
                    dict(frame=dict(duration=frame_duration_ms, redraw=True),
                         fromcurrent=True, mode='immediate')
                ]
            ),
            dict(
                label='Pause',
                method='animate',
                # `[None]` with a zero-duration frame interrupts the running animation.
                args=[
                    [None],
                    dict(frame=dict(duration=0, redraw=False),
                         mode='immediate')
                ]
            )
        ],
        x=0.1,
        y=0,
        xanchor='right',
        yanchor='top'
    )],
    sliders=[dict(
        active=0,
        pad=dict(t=50),
        steps=[
            dict(
                method='animate',
                args=[[str(k)], dict(mode='immediate', frame=dict(duration=0, redraw=True), transition=dict(duration=0))],
                label=str(k)
            )
            for k in slider_frame_ends
        ]
    )]
)

fig.show()
