In [None]:
import sys
import os

# Put the parent of the notebook's working directory (presumably the project's
# 'src' folder) first on sys.path so the project-local `models.*` imports resolve.
src_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Imports -- every name is imported exactly once, here at the top.

# Standard library
import gc
import time
from glob import glob

# Third party
import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from tqdm import tqdm

# Scipy
from scipy.signal import butter, filtfilt, iirnotch, hilbert
from scipy.stats import kurtosis
from scipy.io import savemat

# Scikit-Learn
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score

# Pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from torch.utils.tensorboard import SummaryWriter

# Project-local (resolvable thanks to the sys.path entry above)
from models.dataset import MotionECoGDataset
from models.variational_motion_encoder import VariationalMotionEncoder
from models.waveform_decoder import WaveformDecoder
from models.variational_waveform_encoder import VariationalWaveformEncoder
from models.motion_decoder import MotionDecoder

# Settings
CONTRALATERAL_BASE_PATH = '../data/'
# NOTE(review): absolute, machine-specific path; it is not referenced anywhere else
# in this notebook -- consider making it relative to CONTRALATERAL_BASE_PATH.
DATA_ROOT_PATH = '/home/linux-pc/gh/projects/NeuralNexus/New-Features/Thought-to-Motion/CRCNS/src/motor_cortex/data/data/Contralateral/2018-04-12_(S4)/'

ECOG_DATA_FILENAME = 'Contralateral_2018-04-12_(S4)_cleaned_aligned_ecog_data.csv'
MOTION_DATA_FILENAME = 'Contralateral_2018-04-12_(S4)_cleaned_aligned_motion_data.csv'

ECOG_DATA_FILENAME_DATA_ONLY = 'Contralateral_2018-04-12_(S4)_cleaned_aligned_ecog_data_DATA_ONLY.csv'
MOTION_DATA_FILENAME_DATA_ONLY = 'Contralateral_2018-04-12_(S4)_cleaned_aligned_motion_data_DATA_ONLY.csv'

CONTRALATERAL_ECOG_DATA_FULL_FILE_PATH = CONTRALATERAL_BASE_PATH + ECOG_DATA_FILENAME
CONTRALATERAL_MOTION_DATA_FULL_FILE_PATH = CONTRALATERAL_BASE_PATH + MOTION_DATA_FILENAME

# Normalised arrays consumed by MotionECoGDataset
MOTION_NP = "../data/motion_values_normalized.npy"
ECOG_NP = "../data/ecog_values_normalized.npy"

# Paths to the best checkpoint of each sub-model
model_paths = {
    "motion_encoder": "../models/motion_encoder_best.pt",
    "waveform_decoder": "../models/waveform_decoder_best.pt",
    "waveform_encoder": "../models/waveform_encoder_best.pt",
    "motion_decoder": "../models/motion_decoder_best.pt",
}

# Run on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Load Data

In [None]:
# Held-out evaluation split: the indices saved next to the trained checkpoints.
dataset = MotionECoGDataset(MOTION_NP, ECOG_NP)
test_indices = torch.load("../models/test_indices.pt")
test_dataset = Subset(dataset=dataset, indices=test_indices)
test_loader = DataLoader(
    dataset=test_dataset,
    shuffle=False,      # keep the saved order
    drop_last=False,    # keep the final partial batch
    batch_size=64,
)

In [None]:
# Instantiate the four sub-networks of the motion -> ECoG -> motion pipeline,
# load their best checkpoints, and switch them to inference mode.

# Shared latent size; every model below takes it from this one variable so the
# pipeline cannot silently disagree with itself if it is changed.
latent_dim = 128

motion_encoder = VariationalMotionEncoder(latent_dim=latent_dim).to(device)
waveform_decoder = WaveformDecoder(latent_dim=latent_dim).to(device)
# input_dim / hidden_dim must match the saved checkpoint (input_dim is presumably
# the number of ECoG channels -- TODO confirm).
waveform_encoder = VariationalWaveformEncoder(
    input_dim=64, hidden_dim=128, latent_dim=latent_dim
).to(device)
motion_decoder = MotionDecoder(latent_dim).to(device)

# Load the weights and set eval mode (affects dropout / batch-norm layers, if any).
for model_name, model in {
    "motion_encoder": motion_encoder,
    "waveform_decoder": waveform_decoder,
    "waveform_encoder": waveform_encoder,
    "motion_decoder": motion_decoder,
}.items():
    model.load_state_dict(torch.load(model_paths[model_name], map_location=device))
    model.eval()


In [None]:
# Statistics used to map normalised motion back to original units
# (value * STD + MEAN in the Reconstruction section).
GLOBAL_MOTION_MEAN = np.load('../data/global_motion_mean.npy')
GLOBAL_MOTION_STD = np.load('../data/global_motion_std.npy')

## Analysis

In [None]:
# Same normalised arrays as MOTION_NP / ECOG_NP from the settings cell; aliased
# instead of repeating the path literals so the two definitions cannot drift apart.
motion_np = MOTION_NP
ecog_np = ECOG_NP

In [None]:
# Split fractions (test taken first, then val from the remainder) and loader settings.
test_split = 0.1
val_split = 0.2
epochs = 300
batch_size = 128

In [None]:
# Ad-hoc random train / val / test split of the full dataset.
#
# NOTE: the checkpoints were evaluated against the split saved in
# ../models/test_indices.pt (`test_dataset` / `test_loader` above). A fresh random
# split is not guaranteed to be disjoint from the data the models were trained on,
# so its test part is stored under separate `split_*` names instead of silently
# overwriting `test_dataset` / `test_loader` that the Reconstruction section uses.
SPLIT_SEED = 42

dataset = MotionECoGDataset(motion_np, ecog_np)
total_length = len(dataset)
test_size = int(test_split * total_length)
remaining_size = total_length - test_size
val_size = int(val_split * remaining_size)
train_size = remaining_size - val_size
train_dataset, val_dataset, split_test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),  # reproducible split
)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, drop_last=True
)
split_test_loader = DataLoader(
    split_test_dataset, batch_size=batch_size, shuffle=False, drop_last=False
)

In [None]:
# Sanity check: one pass over the test loader confirms that batches load and can
# be moved to `device`. Repeating the pass per "epoch" would add no information.
for batch in tqdm(test_loader, desc="Checking test batches"):
    motion = batch["motion"].to(device)
    ecog = batch["ecog"].to(device)


In [None]:
# Display the DataLoader object currently bound to `test_loader`.
test_loader

In [None]:
# Summarise the test indices (count + first few) instead of dumping the whole sequence.
len(test_dataset.indices), test_dataset.indices[:10]

In [None]:
# Independent copy of the held-out split defined by the saved test indices.
dataset_sample = MotionECoGDataset(MOTION_NP, ECOG_NP)
test_indices_sample = torch.load("../models/test_indices.pt")
test_dataset_sample = Subset(dataset=dataset_sample, indices=test_indices_sample)
test_loader_sample = DataLoader(
    dataset=test_dataset_sample, shuffle=False, drop_last=False, batch_size=64
)

In [None]:
# Sanity check: one pass over the saved-index test loader confirms that batches
# load and can be moved to `device`; repeated passes would add no information.
for batch in tqdm(test_loader_sample, desc="Checking saved-index test batches"):
    motion = batch["motion"].to(device)
    ecog = batch["ecog"].to(device)


In [None]:
# Display the Subset built from the saved test indices.
test_dataset_sample

In [None]:
# Display the Subset currently bound to `test_dataset`.
test_dataset

In [None]:
# Second loader over the saved-index test subset, using the `batch_size` setting.
test_loader_2 = DataLoader(
    dataset=test_dataset_sample,
    drop_last=False,
    shuffle=False,
    batch_size=batch_size,
)

## Reconstruction

In [None]:

# Run the held-out test split through the full motion -> ECoG -> motion pipeline
# and collect predicted and actual wrist coordinates in original (de-normalised) units.
#
# The loader is rebuilt here from the saved test indices so that the evaluation
# never uses a `test_loader` that an ad-hoc random re-split may have overwritten
# (such a split can overlap with the training data).
eval_test_indices = torch.load("../models/test_indices.pt")
eval_test_loader = DataLoader(
    Subset(dataset, eval_test_indices),
    batch_size=64,
    shuffle=False,
    drop_last=False,
)

predicted_batches = []
actual_batches = []

with torch.no_grad():
    for batch in tqdm(eval_test_loader, desc="Generating Motion Predictions"):
        # Only motion is needed: the pipeline synthesises ECoG from it.
        motion = batch["motion"].to(device)

        # Full latent pass
        motion_latent, _, _ = motion_encoder(motion)
        ecog_synth = waveform_decoder(motion_latent)
        waveform_latent, _, _ = waveform_encoder(ecog_synth)
        motion_reconstructed = motion_decoder(waveform_latent)

        # Undo the normalisation: value * STD + MEAN
        predicted_batches.append(motion_reconstructed.cpu().numpy() * GLOBAL_MOTION_STD + GLOBAL_MOTION_MEAN)
        actual_batches.append(motion.cpu().numpy() * GLOBAL_MOTION_STD + GLOBAL_MOTION_MEAN)

predicted_coords = np.concatenate(predicted_batches, axis=0)
actual_coordinates = np.concatenate(actual_batches, axis=0)

# Downstream plots compare the two arrays point by point.
assert predicted_coords.shape == actual_coordinates.shape, (
    f"prediction shape {predicted_coords.shape} != target shape {actual_coordinates.shape}"
)

## Visualization


### Single Instance


In [None]:
import plotly.graph_objects as go
import numpy as np

# Keep every 10th sample so the 3D lines stay readable.
step = 10
actual = actual_coordinates[::step]
predicted = predicted_coords[::step]

fig = go.Figure()

# One line+marker trace per trajectory: ground truth first, then the model output.
for points, color, label in (
    (actual, 'green', 'Actual Motion'),
    (predicted, 'red', 'Predicted Motion'),
):
    fig.add_trace(go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        mode='lines+markers',
        marker=dict(size=3, color=color),
        line=dict(color=color, width=4),
        name=label,
    ))

fig.update_layout(
    title="Predicted vs Actual 3D Wrist Motion",
    scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
    margin=dict(l=0, r=0, b=0, t=40),
    legend=dict(x=0.7, y=0.9),
)

fig.show()


### Temporal Development Visualization

In [None]:
import plotly.graph_objects as go
import numpy as np

# Animate the actual and predicted trajectories growing point by point.
n_frames = 1   # stride between slider steps
step = 10      # downsampling factor
actual = actual_coordinates[::step]
predicted = predicted_coords[::step]
N = actual.shape[0]

# (points, colour, legend label) for each animated trajectory
trajectories = ((actual, 'green', 'Actual'), (predicted, 'red', 'Predicted'))


def _growing_trace(points, k, color, label):
    """Line+marker Scatter3d of the first `k` rows of `points`."""
    return go.Scatter3d(
        x=points[:k, 0], y=points[:k, 1], z=points[:k, 2],
        mode='lines+markers',
        line=dict(color=color, width=4),
        marker=dict(size=3, color=color),
        name=label,
    )


# Frame k shows both trajectories up to (and including) point k.
frames = [
    go.Frame(
        data=[_growing_trace(pts, k, color, label) for pts, color, label in trajectories],
        name=str(k),
    )
    for k in range(1, N + 1)
]

# Before playback only the starting point of each trajectory is drawn.
start_traces = [
    go.Scatter3d(
        x=[pts[0, 0]], y=[pts[0, 1]], z=[pts[0, 2]],
        mode='markers', marker=dict(size=5, color=color), name=label,
    )
    for pts, color, label in trajectories
]
fig = go.Figure(data=start_traces, frames=frames)

# Fix each axis to the joint extent of both trajectories so the view does not jump.
scene_axes = {}
for dim, axis_label in enumerate('XYZ'):
    lo = min(actual[:, dim].min(), predicted[:, dim].min())
    hi = max(actual[:, dim].max(), predicted[:, dim].max())
    scene_axes[axis_label.lower() + 'axis'] = dict(range=[lo, hi], title=axis_label)

play_button = dict(
    label='Play',
    method='animate',
    args=[None, dict(frame=dict(duration=20, redraw=True), fromcurrent=True, mode='immediate')],
)
pause_button = dict(
    label='Pause',
    method='animate',
    args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate')],
)
slider_steps = [
    dict(
        method='animate',
        args=[[str(k)], dict(mode='immediate', frame=dict(duration=0, redraw=True), transition=dict(duration=0))],
        label=str(k),
    )
    for k in range(1, N + 1, n_frames)
]

fig.update_layout(
    title='Animated 3D Wrist Motion: Predicted vs Actual',
    scene=scene_axes,
    updatemenus=[dict(
        type='buttons',
        showactive=False,
        buttons=[play_button, pause_button],
        x=0.1, y=0, xanchor='right', yanchor='top',
    )],
    sliders=[dict(active=0, pad=dict(t=50), steps=slider_steps)],
)

fig.show()


## Actual Motion Only

In [None]:
# Inspect the de-normalised ground-truth coordinates collected in the Reconstruction section.
actual_coordinates

In [None]:
import plotly.graph_objects as go
import numpy as np

# Downsample for clarity (every 10th point)
step = 10
actual = actual_coordinates[::step]

fig = go.Figure()

# Actual (ground-truth) trajectory only -- no model output in this view
fig.add_trace(go.Scatter3d(
    x=actual[:, 0], y=actual[:, 1], z=actual[:, 2],
    mode='lines+markers',
    marker=dict(size=3, color='purple'),
    line=dict(color='teal', width=4),
    name='Actual Motion'
))

fig.update_layout(
    title="Actual 3D Wrist Motion",
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
    ),
    margin=dict(l=0, r=0, b=0, t=40),
    legend=dict(x=0.7, y=0.9)
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

n_frames = 1  # stride between slider steps
# Downsample (every 2nd sample). Frame k holds the first k points, so the total
# number of plotted points grows as N*(N+1)/2 -- raise `step` for long recordings.
step = 2
actual = actual_coordinates[::step]
N = actual.shape[0]

# Frame k draws the actual trajectory up to point k
frames = [
    go.Frame(
        data=[
            go.Scatter3d(
                x=actual[:k, 0], y=actual[:k, 1], z=actual[:k, 2],
                mode='lines+markers',
                line=dict(color='purple', width=4),
                marker=dict(size=3, color='teal'),
                name='Actual'
            ),
        ],
        name=str(k)
    )
    for k in range(1, N + 1)
]

# Initial trace: only the starting point is shown before playback
fig = go.Figure(data=[
    go.Scatter3d(
        x=[actual[0, 0]], y=[actual[0, 1]], z=[actual[0, 2]],
        mode='markers', marker=dict(size=5, color='teal'), name='Actual'
    ),
], frames=frames)

# Layout and controls; axes fixed to the trajectory's extent so the view does not jump
fig.update_layout(
    title='Animated 3D Wrist Motion: Actual',
    scene=dict(
        xaxis=dict(range=[actual[:, 0].min(),
                          actual[:, 0].max()], title='X'),
        yaxis=dict(range=[actual[:, 1].min(),
                          actual[:, 1].max()], title='Y'),
        zaxis=dict(range=[actual[:, 2].min(),
                          actual[:, 2].max()], title='Z'),
    ),
    updatemenus=[dict(
        type='buttons',
        showactive=False,
        buttons=[
            dict(
                label='Play',
                method='animate',
                args=[None, dict(frame=dict(duration=20, redraw=True), fromcurrent=True, mode='immediate')]
            ),
            dict(
                label='Pause',
                method='animate',
                args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate')]
            )
        ],
        x=0.1, y=0, xanchor='right', yanchor='top'
    )],
    sliders=[dict(
        active=0,
        pad=dict(t=50),
        steps=[
            dict(
                method='animate',
                args=[[str(k)], dict(mode='immediate', frame=dict(duration=0, redraw=True), transition=dict(duration=0))],
                label=str(k)
            )
            for k in range(1, N + 1, n_frames)
        ]
    )]
)

fig.show()


## Simulating Spasms

In [None]:
# Create Physical Movements with introduced jitter
# Create Synthetic Waveforms from this new dataset
# Detect Anomalies in the synthetic waveforms in real time
# Calculate the Counter Current
# Calculate the Adjusted Signal after this current is introduced
# Reconstruct the physical motion (proceed as normal) 
# Evaluate Results (Original Motion should be the same as Hyperpolarized Reconstructed Spasm Motion)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def add_spasms_to_motion(data, num_spasms=10, spasm_duration=5, max_amplitude=0.2, seed=None):
    """
    Adds simulated spasms to a motion trajectory.

    Each spasm starts at a distinct, randomly chosen frame and perturbs
    `spasm_duration` consecutive frames with independent Gaussian noise.
    Only the start frames are distinct, so spasm windows may overlap.

    Parameters:
    - data: np.ndarray of shape (N, D), e.g. (N, 3) xyz coordinates.
      The input array is not modified.
    - num_spasms: number of distinct spasm events to simulate.
    - spasm_duration: how many consecutive frames each spasm lasts (1 <= duration <= N).
    - max_amplitude: standard deviation of the per-axis Gaussian noise. Despite the
      name this is not a hard bound; individual displacements can exceed it.
    - seed: random seed for reproducibility. When given, a private RandomState is
      used, so the global NumPy RNG state is left untouched. When None, the global
      NumPy RNG is used.

    Returns:
    - spasm_data: float copy of `data` with the spasms added.
    - spasm_indices: sorted start frames of the spasms.

    Raises:
    - ValueError: if `spasm_duration` is outside [1, N] or `num_spasms` exceeds
      the number of possible start frames.
    """
    # RandomState(seed) yields the same stream as np.random.seed(seed) followed by
    # np.random.* calls, without mutating global state.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    spasm_data = np.array(data, copy=True)
    if not np.issubdtype(spasm_data.dtype, np.floating):
        # An integer array cannot hold the fractional noise added below.
        spasm_data = spasm_data.astype(float)
    total_frames = spasm_data.shape[0]

    if not 1 <= spasm_duration <= total_frames:
        raise ValueError(
            f"spasm_duration must be in [1, {total_frames}], got {spasm_duration}"
        )
    # A window [start, start + duration) fits for every start in 0 .. N - duration.
    num_starts = total_frames - spasm_duration + 1
    if num_spasms > num_starts:
        raise ValueError(
            f"cannot place {num_spasms} spasms with distinct starts: only "
            f"{num_starts} start frames available"
        )

    spasm_indices = np.sort(rng.choice(num_starts, num_spasms, replace=False))

    # One noise block per spasm covering every frame and every coordinate column.
    noise_shape = (spasm_duration,) + spasm_data.shape[1:]
    for idx in spasm_indices:
        spasm_data[idx:idx + spasm_duration] += rng.normal(
            loc=0.0, scale=max_amplitude, size=noise_shape
        )

    return spasm_data, spasm_indices


In [None]:
# Ground-truth coordinates that the simulated spasms below are added to.
actual_coordinates

In [None]:
# `spasm_indices` is created by the spasm-simulation cell further down; on a fresh
# kernel (Restart & Run All) it does not exist yet, so show it only when available.
spasm_indices if "spasm_indices" in globals() else None

In [None]:
import numpy as np
import matplotlib.pyplot as plt

SPASM_DURATION = 50

# Inject 15 spasms of SPASM_DURATION frames each into the ground-truth trajectory.
spasm_data, spasm_indices = add_spasms_to_motion(
    actual_coordinates,
    num_spasms=15,
    spasm_duration=SPASM_DURATION,
    max_amplitude=0.2,
    seed=42,
)

spasm_fig, spasm_ax = plt.subplots(figsize=(12, 6))
spasm_ax.plot(spasm_data[:, 0], label='Spasm X', linewidth=1.2)
spasm_ax.plot(actual_coordinates[:, 0], label='Original X', linewidth=0.6)

# Shade every spasm window starting at its start frame.
for start in spasm_indices:
    spasm_ax.axvspan(start, start + SPASM_DURATION, color='red', alpha=0.15)

spasm_ax.set_title('X Coordinate with Highlighted Spasm Durations')
spasm_ax.set_xlabel('Frame Index')
spasm_ax.set_ylabel('X Position')
spasm_ax.legend()
spasm_fig.tight_layout()
plt.show()


Rationale behind the simulated-spasm hyperparameters:


The rhesus macaque has an arm span of approximately 2 feet. The units of this data are not documented, so I am assuming they are feet, given that the monkey is reaching for food on the table and bringing it to its mouth. Under that assumption, `max_amplitude=0.2` corresponds to 0.2 ft (2.4 inches, about 6 cm) of spasm. Note that `max_amplitude` is the standard deviation of the Gaussian noise added independently to each axis at every frame, so 2.4 inches is a typical per-axis displacement rather than a hard maximum. Each spasm lasts 50 samples; assuming the motion data is sampled at 50 Hz, that is one second per spasm. Spasms typically reach displacements of up to 15 cm, so a 2.4-inch spasm is not outlandish and stays within the realistic range for this problem.

In [None]:
import plotly.graph_objects as go
import numpy as np

# Downsample for clarity
step = 10
actual = actual_coordinates[::step]
predicted = spasm_data[::step]  # motion with simulated spasms (not a model prediction)

# Mark every frame inside a spasm window [idx, idx + SPASM_DURATION). Reusing the
# constant that generated `spasm_data` keeps the highlighting in sync with it.
spasm_duration = SPASM_DURATION
spasm_mask = np.zeros(len(spasm_data), dtype=bool)
for idx in spasm_indices:
    spasm_mask[idx:idx + spasm_duration] = True

# Keep only spasm frames that survive downsampling (multiples of `step`);
# position j in the downsampled arrays corresponds to frame j * step.
spasm_frame_indices = np.flatnonzero(spasm_mask[::step]) * step
spasm_coords = predicted[spasm_frame_indices // step]

# Hover text for each highlighted spasm point
hover_texts = [f"Spasm at Frame {i}" for i in spasm_frame_indices]

fig = go.Figure()

# Actual (ground-truth) trajectory
fig.add_trace(go.Scatter3d(
    x=actual[:, 0], y=actual[:, 1], z=actual[:, 2],
    mode='lines+markers',
    marker=dict(size=3, color='purple'),
    line=dict(color='teal', width=4),
    name='Actual Motion'
))

# Trajectory with simulated spasms
fig.add_trace(go.Scatter3d(
    x=predicted[:, 0], y=predicted[:, 1], z=predicted[:, 2],
    mode='lines+markers',
    marker=dict(size=3, color='red'),
    line=dict(color='cyan', width=4),
    name='Motion with Spasms'
))

# Spasm points with hover text
fig.add_trace(go.Scatter3d(
    x=spasm_coords[:, 0], y=spasm_coords[:, 1], z=spasm_coords[:, 2],
    mode='markers',
    marker=dict(size=6, color='yellow', symbol='diamond'),
    text=hover_texts,
    hoverinfo='text',
    name='Spasm Points'
))

fig.update_layout(
    title="Actual vs Spasm-Injected 3D Motion with Spasms Highlighted",
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
    ),
    margin=dict(l=0, r=0, b=0, t=40),
    legend=dict(x=0.7, y=0.9)
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

n_frames = 1  # stride between slider steps
# Downsample (every 2nd sample). Frame k holds the first k points, so the total
# number of plotted points grows as N*(N+1)/2 -- raise `step` for long recordings.
step = 2
spasm_motion = spasm_data[::step]  # trajectory with simulated spasms
N = spasm_motion.shape[0]

# Frame k draws the spasm-injected trajectory up to point k
frames = [
    go.Frame(
        data=[
            go.Scatter3d(
                x=spasm_motion[:k, 0], y=spasm_motion[:k, 1], z=spasm_motion[:k, 2],
                mode='lines+markers',
                line=dict(color='purple', width=4),
                marker=dict(size=3, color='teal'),
                name='Spasm Motion'
            ),
        ],
        name=str(k)
    )
    for k in range(1, N + 1)
]

# Initial trace: only the starting point is shown before playback
fig = go.Figure(data=[
    go.Scatter3d(
        x=[spasm_motion[0, 0]], y=[spasm_motion[0, 1]], z=[spasm_motion[0, 2]],
        mode='markers', marker=dict(size=5, color='teal'), name='Spasm Motion'
    ),
], frames=frames)

# Layout and controls; axes fixed to the trajectory's extent so the view does not jump
fig.update_layout(
    title='Animated 3D Wrist Motion with Simulated Spasms',
    scene=dict(
        xaxis=dict(range=[spasm_motion[:, 0].min(),
                          spasm_motion[:, 0].max()], title='X'),
        yaxis=dict(range=[spasm_motion[:, 1].min(),
                          spasm_motion[:, 1].max()], title='Y'),
        zaxis=dict(range=[spasm_motion[:, 2].min(),
                          spasm_motion[:, 2].max()], title='Z'),
    ),
    updatemenus=[dict(
        type='buttons',
        showactive=False,
        buttons=[
            dict(
                label='Play',
                method='animate',
                args=[None, dict(frame=dict(duration=20, redraw=True), fromcurrent=True, mode='immediate')]
            ),
            dict(
                label='Pause',
                method='animate',
                args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate')]
            )
        ],
        x=0.1, y=0, xanchor='right', yanchor='top'
    )],
    sliders=[dict(
        active=0,
        pad=dict(t=50),
        steps=[
            dict(
                method='animate',
                args=[[str(k)], dict(mode='immediate', frame=dict(duration=0, redraw=True), transition=dict(duration=0))],
                label=str(k)
            )
            for k in range(1, N + 1, n_frames)
        ]
    )]
)

fig.show()


## Detecting the Onset of Spasms (Anomalies)

## Hodgkin–Huxley Calculation of the Current Needed to Induce Hyperpolarization and Prevent Simulated Spasms

## Simulating Spasm-Onset Detection and Prevention to Restore Smooth (Laminar), Healthy Motion