In [None]:
import sys
import os

# Get absolute path to 'src' folder relative to this notebook
src_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Imports

import os
import gc
from glob import glob
import joblib
import numpy as np
import pandas as pd
import time

# Scipy
from scipy.signal import butter, filtfilt, iirnotch, hilbert
from scipy.stats import kurtosis
from scipy.io import savemat 

# Scikit-Learn
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score

# Pytorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Settings
# Data directory relative to the notebook's working directory.
CONTRALATERAL_BASE_PATH = '../data/'

# Absolute location of the raw session recordings. Set the CRCNS_DATA_ROOT
# environment variable to override the machine-specific default below instead of
# editing the notebook on another machine.
DATA_ROOT_PATH = os.environ.get(
    'CRCNS_DATA_ROOT',
    '/home/linux-pc/gh/projects/NeuralNexus/New-Features/Thought-to-Motion/CRCNS/src/motor_cortex/data/data/Contralateral/2018-04-12_(S4)/',
)

# Every file of this recording session shares this prefix.
SESSION_PREFIX = 'Contralateral_2018-04-12_(S4)'

ECOG_DATA_FILENAME = f'{SESSION_PREFIX}_cleaned_aligned_ecog_data.csv'
MOTION_DATA_FILENAME = f'{SESSION_PREFIX}_cleaned_aligned_motion_data.csv'

ECOG_DATA_FILENAME_DATA_ONLY = f'{SESSION_PREFIX}_cleaned_aligned_ecog_data_DATA_ONLY.csv'
MOTION_DATA_FILENAME_DATA_ONLY = f'{SESSION_PREFIX}_cleaned_aligned_motion_data_DATA_ONLY.csv'

CONTRALATERAL_ECOG_DATA_FULL_FILE_PATH = os.path.join(CONTRALATERAL_BASE_PATH, ECOG_DATA_FILENAME)
CONTRALATERAL_MOTION_DATA_FULL_FILE_PATH = os.path.join(CONTRALATERAL_BASE_PATH, MOTION_DATA_FILENAME)

# Pre-normalized arrays consumed by MotionECoGDataset
MOTION_NP = "../data/motion_values_normalized.npy"
ECOG_NP = "../data/ecog_values_normalized.npy"

from models.dataset import MotionECoGDataset

from torch.utils.data import Subset
from torch.utils.data import DataLoader

# Model Classes
from models.variational_motion_encoder import VariationalMotionEncoder
from models.waveform_decoder import WaveformDecoder
from models.variational_waveform_encoder import VariationalWaveformEncoder
from models.motion_decoder import MotionDecoder
from models.spasm_dataset import SpasmDataset

import torch
from torch.utils.data import DataLoader, Subset
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from tqdm import tqdm

# Checkpoint file for each pretrained sub-network, keyed by the variable name the
# model is bound to when it is instantiated.
model_paths = dict(
    motion_encoder="../models/motion_encoder_best.pt",
    waveform_decoder="../models/waveform_decoder_best.pt",
    waveform_encoder="../models/waveform_encoder_best.pt",
    motion_decoder="../models/motion_decoder_best.pt",
)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def add_spasms_to_motion(data, num_spasms=10, spasm_duration=5, max_amplitude=0.2, seed=None):
    """
    Adds simulated spasms (bursts of per-frame Gaussian jitter) to a motion trajectory.

    Parameters:
    - data: array of shape (N, D), e.g. (N, 3) xyz coordinates. Not modified.
    - num_spasms: number of distinct spasm events to simulate.
    - spasm_duration: how many consecutive frames each spasm lasts.
    - max_amplitude: standard deviation of the zero-mean Gaussian noise added to
      every coordinate of every spasm frame. Despite the name this is not a hard
      bound; individual displacements can exceed it.
    - seed: random seed for reproducibility. When given, a private
      ``np.random.RandomState(seed)`` is used. It produces the same numbers that
      ``np.random.seed(seed)`` would, but leaves the global NumPy RNG untouched.
      When None, the global NumPy RNG is used (and advanced).

    Returns:
    - spasm_data: copy of ``data`` with the spasm noise added.
    - spasm_indices: sorted START frame of each spasm. Spasm k covers frames
      ``spasm_indices[k] .. spasm_indices[k] + spasm_duration - 1``. Windows may
      overlap; overlapping noise adds up.

    Raises:
    - ValueError: if fewer than ``num_spasms`` distinct start frames are available.
    """
    rng = np.random.RandomState(seed) if seed is not None else np.random

    spasm_data = np.array(data, copy=True)
    total_frames = spasm_data.shape[0]

    # Start frames are drawn from [0, total_frames - spasm_duration). This excludes the
    # last valid start (total_frames - spasm_duration). The population size is kept
    # unchanged so that seeded runs reproduce previously generated spasm data.
    num_candidates = total_frames - spasm_duration
    if num_spasms > num_candidates:
        raise ValueError(
            f"Cannot place {num_spasms} spasms of {spasm_duration} frames in a trajectory "
            f"of {total_frames} frames ({max(num_candidates, 0)} candidate start frames)."
        )
    spasm_indices = np.sort(rng.choice(num_candidates, num_spasms, replace=False))

    # One noise value per frame and per coordinate. Values are drawn in frame order,
    # so the random stream matches a frame-by-frame loop.
    noise_shape = (spasm_duration,) + spasm_data.shape[1:]
    for start in spasm_indices:
        spasm_data[start:start + spasm_duration] += rng.normal(
            loc=0.0, scale=max_amplitude, size=noise_shape
        )

    return spasm_data, spasm_indices


## Load Data

In [None]:
dataset = MotionECoGDataset(MOTION_NP, ECOG_NP)
test_indices = torch.load("../models/test_indices.pt")
test_dataset = Subset(dataset, test_indices)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=False)

In [None]:
# Instantiate the four sub-networks of the motion <-> ECoG pipeline and restore
# their best checkpoints. Every network uses the same latent size, because the
# waveform encoder's latent is fed straight into the motion decoder.
latent_dim = 128

motion_encoder = VariationalMotionEncoder(latent_dim=latent_dim).to(device)
waveform_decoder = WaveformDecoder(latent_dim=latent_dim).to(device)
# input_dim=64 presumably matches the 64 ECoG channels plotted further down — TODO confirm
waveform_encoder = VariationalWaveformEncoder(
    input_dim=64, hidden_dim=128, latent_dim=latent_dim
).to(device)
motion_decoder = MotionDecoder(latent_dim).to(device)

# Load each state dict (model_paths is keyed by the same names) and switch every
# network to eval mode for inference.
for _name, _net in (
    ("motion_encoder", motion_encoder),
    ("waveform_decoder", waveform_decoder),
    ("waveform_encoder", waveform_encoder),
    ("motion_decoder", motion_decoder),
):
    _net.load_state_dict(torch.load(model_paths[_name], map_location=device))
    _net.eval()


In [None]:
GLOBAL_MOTION_STD = np.load('../data/global_motion_std.npy')
GLOBAL_MOTION_MEAN = np.load('../data/global_motion_mean.npy')

## Analysis

In [None]:
motion_np = "../data/motion_values_normalized.npy"
ecog_np = "../data/ecog_values_normalized.npy"

In [None]:
val_split = 0.2
test_split = 0.1
batch_size = 128
epochs=300

In [None]:
# Re-create a train/val/test partition of the full dataset.
# The evaluation test set is the one restored from ../models/test_indices.pt in the
# "Load Data" section (test_dataset / test_loader), and it is NOT overwritten here.
# random_split draws a fresh partition that presumably differs from the one used
# in training, so its "test" part could overlap the training data. The split is
# seeded so that this cell is reproducible.
SPLIT_SEED = 42

dataset = MotionECoGDataset(motion_np, ecog_np)
total_length = len(dataset)
test_size = int(test_split * total_length)
remaining_size = total_length - test_size
val_size = int(val_split * remaining_size)
train_size = remaining_size - val_size
train_dataset, val_dataset, split_test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, drop_last=True
)
split_test_loader = DataLoader(
    split_test_dataset, batch_size=batch_size, shuffle=False, drop_last=False
)

In [None]:
# Smoke test of the data pipeline: iterate the test loader `epochs` times and move
# each batch's motion and ECoG tensors to the device. Nothing else is computed.
for epoch in range(epochs):
    for batch in tqdm(test_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        motion = batch["motion"].to(device)
        ecog = batch["ecog"].to(device)


In [None]:
test_loader

In [None]:
test_dataset.indices

In [None]:
dataset_sample = MotionECoGDataset(MOTION_NP, ECOG_NP)
test_indices_sample = torch.load("../models/test_indices.pt")
test_dataset_sample = Subset(dataset_sample, test_indices_sample)
test_loader_sample = DataLoader(test_dataset_sample, batch_size=64, shuffle=False, drop_last=False)

In [None]:
for epoch in range(epochs):
    for batch in tqdm(test_loader_sample, desc=f"Epoch {epoch+1}/{epochs}"):
                motion = batch["motion"].to(device)
                ecog = batch["ecog"].to(device)


In [None]:
test_dataset_sample

In [None]:
test_dataset

In [None]:
test_loader_2 = DataLoader(
    test_dataset_sample, batch_size=batch_size, shuffle=False, drop_last=False
)

## Reconstruction

In [None]:
# Round-trip every test batch through the full latent cycle
#   motion -> motion latent -> synthetic ECoG -> waveform latent -> reconstructed motion
# and keep both the input and the reconstruction in real units (x * std + mean).
predicted_coordinates, actual_coordinates = [], []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Generating Motion Predictions"):
        motion = batch["motion"].to(device)

        motion_latent, _, _ = motion_encoder(motion)
        ecog_synth = waveform_decoder(motion_latent)
        waveform_latent, _, _ = waveform_encoder(ecog_synth)
        motion_reconstructed = motion_decoder(waveform_latent)  # (B, 3), still normalized

        denorm_pred = motion_reconstructed.cpu().numpy() * GLOBAL_MOTION_STD + GLOBAL_MOTION_MEAN
        denorm_true = motion.cpu().numpy() * GLOBAL_MOTION_STD + GLOBAL_MOTION_MEAN
        predicted_coordinates.append(denorm_pred)
        actual_coordinates.append(denorm_true)

predicted_coords = np.concatenate(predicted_coordinates, axis=0)
actual_coordinates = np.concatenate(actual_coordinates, axis=0)

## Visualization


### Single Instance


In [None]:
import plotly.graph_objects as go
import numpy as np

# Downsample for clarity (every 10th point)
step = 10
actual = actual_coordinates[::step]
predicted = predicted_coords[::step]

fig = go.Figure()

# Actual Trajectory
fig.add_trace(go.Scatter3d(
    x=actual[:, 0], y=actual[:, 1], z=actual[:, 2],
    mode='lines+markers',
    marker=dict(size=3, color='green'),
    line=dict(color='green', width=4),
    name='Actual Motion'
))

# Predicted Trajectory
fig.add_trace(go.Scatter3d(
    x=predicted[:, 0], y=predicted[:, 1], z=predicted[:, 2],
    mode='lines+markers',
    marker=dict(size=3, color='red'),
    line=dict(color='red', width=4),
    name='Predicted Motion'
))

fig.update_layout(
    title="Predicted vs Actual 3D Wrist Motion",
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
    ),
    margin=dict(l=0, r=0, b=0, t=40),
    legend=dict(x=0.7, y=0.9)
)

fig.show()


### Temporal Development Visualization

In [None]:
import plotly.graph_objects as go
import numpy as np
n_frames = 1
# Downsample
step = 10
actual = actual_coordinates[::step]
predicted = predicted_coords[::step]
N = actual.shape[0]

# Create frames
frames = [
    go.Frame(
        data=[
            go.Scatter3d(
                x=actual[:k, 0], y=actual[:k, 1], z=actual[:k, 2],
                mode='lines+markers',
                line=dict(color='green', width=4),
                marker=dict(size=3, color='green'),
                name='Actual'
            ),
            go.Scatter3d(
                x=predicted[:k, 0], y=predicted[:k, 1], z=predicted[:k, 2],
                mode='lines+markers',
                line=dict(color='red', width=4),
                marker=dict(size=3, color='red'),
                name='Predicted'
            )
        ],
        name=str(k)
    )
    for k in range(1, N + 1)
]

# Initial trace
fig = go.Figure(data=[
    go.Scatter3d(
        x=[actual[0, 0]], y=[actual[0, 1]], z=[actual[0, 2]],
        mode='markers', marker=dict(size=5, color='green'), name='Actual'
    ),
    go.Scatter3d(
        x=[predicted[0, 0]], y=[predicted[0, 1]], z=[predicted[0, 2]],
        mode='markers', marker=dict(size=5, color='red'), name='Predicted'
    )
], frames=frames)

# Layout and controls
fig.update_layout(
    title='Animated 3D Wrist Motion: Predicted vs Actual',
    scene=dict(
        xaxis=dict(range=[min(actual[:, 0].min(), predicted[:, 0].min()),
                          max(actual[:, 0].max(), predicted[:, 0].max())], title='X'),
        yaxis=dict(range=[min(actual[:, 1].min(), predicted[:, 1].min()),
                          max(actual[:, 1].max(), predicted[:, 1].max())], title='Y'),
        zaxis=dict(range=[min(actual[:, 2].min(), predicted[:, 2].min()),
                          max(actual[:, 2].max(), predicted[:, 2].max())], title='Z'),
    ),
    updatemenus=[dict(
        type='buttons',
        showactive=False,
        buttons=[
            dict(
                label='Play',
                method='animate',
                args=[None, dict(frame=dict(duration=20, redraw=True), fromcurrent=True, mode='immediate')]
            ),
            dict(
                label='Pause',
                method='animate',
                args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate')]
            )
        ],
        x=0.1, y=0, xanchor='right', yanchor='top'
    )],
    sliders=[dict(
        active=0,
        pad=dict(t=50),
        steps=[
            dict(
                method='animate',
                args=[[str(k)], dict(mode='immediate', frame=dict(duration=0, redraw=True), transition=dict(duration=0))],
                label=str(k)
            )
            for k in range(1, N + 1, n_frames)
        ]
    )]
)

fig.show()


## Actual Motion Only

In [None]:
actual_coordinates

In [None]:
# Static 3D view of the de-normalized test-set wrist trajectory only.
# Downsample for clarity (every 10th point)
step = 10
actual = actual_coordinates[::step]

fig = go.Figure()

# Actual Trajectory
fig.add_trace(go.Scatter3d(
    x=actual[:, 0], y=actual[:, 1], z=actual[:, 2],
    mode='lines+markers',
    marker=dict(size=3, color='purple'),
    line=dict(color='teal', width=4),
    name='Actual Motion'
))

fig.update_layout(
    title="Actual 3D Wrist Motion",
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
    ),
    margin=dict(l=0, r=0, b=0, t=40),
    legend=dict(x=0.7, y=0.9)
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np
n_frames = 1
# Downsample
step = 2
actual = actual_coordinates[::step]
# predicted = predicted_coords[::step]
N = actual.shape[0]

# Create frames
frames = [
    go.Frame(
        data=[
            go.Scatter3d(
                x=actual[:k, 0], y=actual[:k, 1], z=actual[:k, 2],
                mode='lines+markers',
                line=dict(color='purple', width=4),
                marker=dict(size=3, color='teal'),
                name='Actual'
            ),
            # go.Scatter3d(
            #     x=predicted[:k, 0], y=predicted[:k, 1], z=predicted[:k, 2],
            #     mode='lines+markers',
            #     line=dict(color='red', width=4),
            #     marker=dict(size=3, color='red'),
            #     name='Predicted'
            # )
        ],
        name=str(k)
    )
    for k in range(1, N + 1)
]

# Initial trace
fig = go.Figure(data=[
    go.Scatter3d(
        x=[actual[0, 0]], y=[actual[0, 1]], z=[actual[0, 2]],
        mode='markers', marker=dict(size=5, color='teal'), name='Actual'
    ),
    # go.Scatter3d(
    #     x=[predicted[0, 0]], y=[predicted[0, 1]], z=[predicted[0, 2]],
    #     mode='markers', marker=dict(size=5, color='red'), name='Predicted'
    # )
], frames=frames)

# Layout and controls
fig.update_layout(
    title='Animated 3D Wrist Motion: Actual',
    scene=dict(
        xaxis=dict(range=[actual[:, 0].min(),
                          actual[:, 0].max()], title='X'),
        yaxis=dict(range=[actual[:, 1].min(),
                          actual[:, 1].max()], title='Y'),
        zaxis=dict(range=[actual[:, 2].min(),
                          actual[:, 2].max()], title='Z'),
    ),
    updatemenus=[dict(
        type='buttons',
        showactive=False,
        buttons=[
            dict(
                label='Play',
                method='animate',
                args=[None, dict(frame=dict(duration=20, redraw=True), fromcurrent=True, mode='immediate')]
            ),
            dict(
                label='Pause',
                method='animate',
                args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate')]
            )
        ],
        x=0.1, y=0, xanchor='right', yanchor='top'
    )],
    sliders=[dict(
        active=0,
        pad=dict(t=50),
        steps=[
            dict(
                method='animate',
                args=[[str(k)], dict(mode='immediate', frame=dict(duration=0, redraw=True), transition=dict(duration=0))],
                label=str(k)
            )
            for k in range(1, N + 1, n_frames)
        ]
    )]
)

fig.show()


## Simulating spasms

In [None]:
# Create Physical Movements with introduced jitter
# Create Synthetic Waveforms from this new dataset
# Detect Anomalies in the synthetic waveforms in real time
# Calculate the Counter Current
# Calculate the Adjusted Signal after this current is introduced
# Reconstruct the physical motion (proceed as normal) 
# Evaluate Results (Original Motion should be the same as Hyperpolarized Reconstructed Spasm Motion)

In [None]:
# Simulate spasms on the de-normalized test-set trajectory (actual_coordinates,
# produced in the Reconstruction section) and plot the X coordinate before and after.
SPASM_DURATION = 50  # Samples per spasm (sampling rate is 50 Hz -> 1 s)
NUM_SPASMS = 15
SPASM_AMPLITUDE = 0.2  # Std of the per-frame Gaussian jitter (see rationale below)
SPASM_SEED = 42

spasm_data, spasm_indices = add_spasms_to_motion(
    actual_coordinates,
    num_spasms=NUM_SPASMS,
    spasm_duration=SPASM_DURATION,
    max_amplitude=SPASM_AMPLITUDE,
    seed=SPASM_SEED,
)

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(spasm_data[:, 0], label='Spasm X', linewidth=1.2)
ax.plot(actual_coordinates[:, 0], label='Original X', linewidth=0.6)

# Shade each spasm window, starting at its start frame and spanning SPASM_DURATION frames
for start in spasm_indices:
    ax.axvspan(start, start + SPASM_DURATION, color='red', alpha=0.15)

ax.set_title('X Coordinate with Highlighted Spasm Durations')
ax.set_xlabel('Frame Index')
ax.set_ylabel('X Position')
ax.legend()
fig.tight_layout()
plt.show()


Rationale behind the simulated-spasm hyperparameters:


The rhesus macaque has an arm span of approximately 2 feet. The units of this data are not documented, so I assume they are feet. This is consistent with the monkey reaching for food on the table and bringing it to its mouth. Under that assumption, `max_amplitude=0.2` adds zero-mean Gaussian jitter with a standard deviation of 0.2 ft (about 2.4 inches, or roughly 6 cm) to each coordinate of every spasm frame. Note that this is a standard deviation, not a hard maximum, so individual displacements can be larger. Each spasm lasts 50 samples, which at 50 Hz is one second (assuming consecutive test samples are consecutive in time). Spasms typically span up to 15 cm, so a 2.4-inch spasm is not outlandish and stays within the realistic range for this problem.

In [None]:
import plotly.graph_objects as go
import numpy as np
n_frames = 1
# Downsample
step = 2
spasm = spasm_data[::step]

N = spasm.shape[0]

# Create frames
frames = [
    go.Frame(
        data=[
            go.Scatter3d(
                x=spasm[:k, 0], y=spasm[:k, 1], z=spasm[:k, 2],
                mode='lines+markers',
                line=dict(color='purple', width=4),
                marker=dict(size=3, color='teal'),
                name='Actual'
            ),
        ],
        name=str(k)
    )
    for k in range(1, N + 1)
]

# Initial trace
fig = go.Figure(data=[
    go.Scatter3d(
        x=[spasm[0, 0]], y=[spasm[0, 1]], z=[spasm[0, 2]],
        mode='markers', marker=dict(size=5, color='teal'), name='Actual'
    ),

], frames=frames)

# Layout and controls
fig.update_layout(
    title='Animated 3D Wrist Motion: Simulated Spasms',
    scene=dict(
        xaxis=dict(range=[spasm[:, 0].min(),
                          spasm[:, 0].max()], title='X'),
        yaxis=dict(range=[spasm[:, 1].min(),
                          spasm[:, 1].max()], title='Y'),
        zaxis=dict(range=[spasm[:, 2].min(),
                          spasm[:, 2].max()], title='Z'),
    ),
    updatemenus=[dict(
        type='buttons',
        showactive=False,
        buttons=[
            dict(
                label='Play',
                method='animate',
                args=[None, dict(frame=dict(duration=20, redraw=True), fromcurrent=True, mode='immediate')]
            ),
            dict(
                label='Pause',
                method='animate',
                args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate')]
            )
        ],
        x=0.1, y=0, xanchor='right', yanchor='top'
    )],
    sliders=[dict(
        active=0,
        pad=dict(t=50),
        steps=[
            dict(
                method='animate',
                args=[[str(k)], dict(mode='immediate', frame=dict(duration=0, redraw=True), transition=dict(duration=0))],
                label=str(k)
            )
            for k in range(1, N + 1, n_frames)
        ]
    )]
)

fig.show()


## Creating Random Synthetic Spasm Waveforms

In [None]:
actual_coordinates

In [None]:
torch.tensor(actual_coordinates[0], device=device, dtype=torch.float32)

In [None]:
latents, mean, logvar = motion_encoder(torch.tensor(actual_coordinates[0], device=device, dtype=torch.float32).unsqueeze(0))

In [None]:
latents.shape

In [None]:
spasm_data[0]

In [None]:
len(test_loader)

In [None]:
len(spasm_data)

In [None]:
spasm_indices

In [None]:
spasm_data = spasm_data.astype(np.float32)

In [None]:
spasm_data.dtype

In [None]:
spasm_dataset = SpasmDataset(spasm_data, spasm_indices)
spasm_data_loader = DataLoader(spasm_dataset, batch_size=64, shuffle=False)

# Motion -> latent -> synthetic ECoG for the spasm-corrupted trajectory. Only the
# synthetic waveforms are kept; no motion reconstruction is done here.
ecog_synth_spasm_all = []
with torch.no_grad():
    for batch in tqdm(spasm_data_loader, desc="Generating Simulated Spasm Waveforms"):
        spasm = batch["spasm"].to(device)
        motion_latent, _, _ = motion_encoder(spasm)
        ecog_synth_spasm = waveform_decoder(motion_latent)
        # Collect CPU/NumPy copies so the per-batch chunks can be concatenated
        ecog_synth_spasm_all.append(ecog_synth_spasm.cpu().numpy())

spasm_dataset.ecog_synth_spasms = ecog_synth_spasm_all = np.concatenate(ecog_synth_spasm_all, axis=0)

In [None]:
spasm_dataset.ecog_synth_spasms.shape

In [None]:
# np.save("../data/simulated_spasm_data.npy", spasm_dataset.spasm_data)
# np.save("../data/simulated_spasm_indices.npy", spasm_dataset.spasm_indices)
# np.save("../data/ecog_synth_spasms.npy", spasm_dataset.ecog_synth_spasms)

## Load Simulated Spasm Data

In [None]:
spasm_data = np.load("../data/simulated_spasm_data.npy")
spasm_indices = np.load("../data/simulated_spasm_indices.npy")
ecog_synth_spasms = np.load("../data/ecog_synth_spasms.npy")

spasm_dataset = SpasmDataset(spasm_data, spasm_indices, ecog_synth_spasms)

In [None]:
spasm_dataset.ecog_synth_spasms

In [None]:
spasm_dataset.spasm_data

## Detecting the Onset of Spasms (Anomalies)

In [None]:
spasm_dataset.ecog_synth_spasms

### Analysis: comparing the synthetic ECoG of non-spasm motion with the synthetic ECoG of spasm motion at the same timesteps

In [None]:
actual_coordinates

In [None]:
"""
    Note: These are NOT spasms. They are real physical movements translated into synthetic waveforms with indices corresponding to the spasm dataset
"""

actual_coordinates_with_spasm_indices_dataset = SpasmDataset(actual_coordinates.astype(np.float32), spasm_indices)
actual_coordinates_with_spasm_indices_dataloader = DataLoader(actual_coordinates_with_spasm_indices_dataset, batch_size=64, shuffle=False)

ecog_synth_spasm_all = []

# Collect motion predictions
with torch.no_grad():
    for batch in tqdm(actual_coordinates_with_spasm_indices_dataloader, desc="Generating Simulated ECoG Waveforms"):
        spasm = batch["spasm"].to(device)

        # Full latent pass
        motion_latent, _, _ = motion_encoder(spasm)
        ecog_synth_spasm = waveform_decoder(motion_latent)
        # waveform_latent, _, _ = waveform_encoder(ecog_synth)
        # motion_reconstructed = motion_decoder(waveform_latent)

        # motion_reconstructed shape: (B, 3)
        # Normalize back to the real space

        ecog_synth_spasm_all.append(ecog_synth_spasm.cpu().numpy())

ecog_synth_spasm_all = np.concatenate(ecog_synth_spasm_all, axis=0)
actual_coordinates_with_spasm_indices_dataset.ecog_synth_spasms = ecog_synth_spasm_all

In [None]:
predicted_coordinates

In [None]:
spasm_indices

In [None]:
max_indices = len(actual_coordinates_with_spasm_indices_dataset.ecog_synth_spasms)

In [None]:
max_indices

In [None]:
# Get all indices from 0 to N-1
all_indices = np.arange(max_indices)

In [None]:
# Frames that lie outside every spasm window. spasm_indices holds only the START
# frame of each spasm, and each spasm spans SPASM_DURATION consecutive frames
# (see add_spasms_to_motion), so the whole window is excluded, not just its start.
# NOTE(review): assumes the saved spasm_indices were generated with the current SPASM_DURATION.
_spasm_starts = np.asarray(actual_coordinates_with_spasm_indices_dataset.spasm_indices).reshape(-1, 1)
spasm_frame_indices = (_spasm_starts + np.arange(SPASM_DURATION)).ravel()
non_spasm_indices = np.setdiff1d(all_indices, spasm_frame_indices)

In [None]:
non_spasm_indices

In [None]:
actual_coordinates_with_spasm_indices_dataset.ecog_synth_spasms[non_spasm_indices][0]

In [None]:
# Plot every channel of one synthetic ECoG window generated from clean (non-spasm) motion.
sample_idx = 0
# Select the single window once. Indexing with the whole non_spasm_indices array
# inside the loop would copy the entire filtered array once per channel.
non_spasm_window = actual_coordinates_with_spasm_indices_dataset.ecog_synth_spasms[non_spasm_indices[sample_idx]]
num_steps, num_channels = non_spasm_window.shape[:2]

plt.figure(figsize=(12, 6))
for ch in range(num_channels):
    plt.plot(non_spasm_window[:, ch], label=f'Ch {ch}')

plt.title(f'Sample #{sample_idx}: True ECoG Synthetic Waveform Across Time')
plt.xlabel(f'Time Step (0–{num_steps - 1})')
plt.ylabel('Amplitude')
# plt.legend()
plt.grid(True)
plt.show()


In [None]:
actual_coordinates_with_spasm_indices_dataset.ecog_synth_spasms

In [None]:
# The first Spasm
sample_idx = 0
spasm_idx = 0


In [None]:
spasm_idx

In [None]:
actual_coordinates_with_spasm_indices_dataset.ecog_synth_spasms[actual_coordinates_with_spasm_indices_dataset.spasm_indices[spasm_idx]]

In [None]:
spasm_dataset.ecog_synth_spasms[spasm_dataset.spasm_indices[spasm_idx]]

In [None]:
# Advance to the next spasm (wrapping around). For that spasm's start frame, plot
# the synthetic ECoG of the clean motion, then that of the spasm-corrupted motion.
# Re-run this cell to step through the spasms one by one.
spasm_idx = (spasm_idx + 1) % len(spasm_dataset.spasm_indices)
# Number of channels
num_ch = 64

clean_window = actual_coordinates_with_spasm_indices_dataset.ecog_synth_spasms[
    actual_coordinates_with_spasm_indices_dataset.spasm_indices[spasm_idx]
]
spasm_window = spasm_dataset.ecog_synth_spasms[spasm_dataset.spasm_indices[spasm_idx]]

for window, description in (
    (clean_window, 'True ECoG Synthetic Waveform'),
    (spasm_window, 'Synthetic Spasm ECoG Synthetic Waveform'),
):
    plt.figure(figsize=(12, 6))
    for ch in range(0, num_ch):
        plt.plot(window[:, ch], label=f'Ch {ch}')
    plt.title(f'Sample #{spasm_idx}: {description} Across Time')
    plt.xlabel('Time Step (0–19)')
    plt.ylabel('Amplitude')
    # plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
# The first Spasm
sample_idx = 0
spasm_idx = 0
plt.figure(figsize=(12, 6))
for ch in range(0,64):
    plt.plot(spasm_dataset.ecog_synth_spasms[spasm_dataset.spasm_indices[0]][:, ch], label=f'Ch {ch}')
    
plt.title(f'Sample #{sample_idx}: True ECoG Synthetic Waveform Across Time')
plt.xlabel('Time Step (0–19)')
plt.ylabel('Amplitude')
# plt.legend()
plt.grid(True)
plt.show()

In [None]:
spasm_dataset.ecog_synth_spasms[spasm_dataset.spasm_indices[0]]

In [None]:
spasm_dataset.ecog_synth_spasms[spasm_dataset.spasm_indices[0]]

## Hodgkin–Huxley Calculation of the Current Needed to Induce Hyperpolarization and Prevent Simulated Spasms

## Simulate Detecting Spasm Onset and Preventing Spasms to Restore Smooth, Healthy Motion

## Creating Novel Input to the Motion Reconstruction to Test Utility

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def generate_upward_spiral_3d(num_points=500, num_turns=3):
    """
    Build a 3D spiral that climbs from z = -1 to z = 1 inside the cube [-1, 1]^3.

    The radius is 0.8 * (1 - |z|). It is zero at both ends (z = -1 and z = 1) and
    largest (0.8) at z = 0, so the curve widens and then narrows again as it rises.

    Args:
        num_points (int): Number of samples along the curve.
        num_turns (int): Number of full revolutions from bottom to top.

    Returns:
        np.ndarray: Array of shape (num_points, 3); each row is (x, y, z).
    """
    z = np.linspace(-1, 1, num_points)
    angle = np.linspace(0, num_turns * 2 * np.pi, num_points)
    radius = 0.8 * (1 - np.abs(z))

    return np.column_stack((radius * np.cos(angle), radius * np.sin(angle), z))

# Generate points
points = generate_upward_spiral_3d()

# Plotting
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(points[:, 0], points[:, 1], points[:, 2], color='blue')

# Set bounds
ax.set_xlim([-1, 1])
ax.set_ylim([-1, 1])
ax.set_zlim([-1, 1])
ax.set_title('3D Upward Spiral (Normalized)')

plt.show()


In [None]:
points = points.astype(np.float32)

In [None]:
points.dtype

In [None]:
# Helpers that move motion coordinates between real (recorded) units and the
# z-scored space the models operate in, using the global motion statistics.


def normalize_data(motion_real_normal_space, mean=None, std=None):
    """
    Z-score motion coordinates: (x - mean) / std.

    Args:
        motion_real_normal_space: array-like of coordinates in real units,
            broadcast against ``mean``/``std`` (presumably (N, 3) with per-axis stats).
        mean: mean to subtract; defaults to GLOBAL_MOTION_MEAN.
        std: standard deviation to divide by; defaults to GLOBAL_MOTION_STD.

    Returns:
        The coordinates in normalized (model) space.
    """
    if mean is None:
        mean = GLOBAL_MOTION_MEAN
    if std is None:
        std = GLOBAL_MOTION_STD
    motion_reduced_space = (motion_real_normal_space - mean) / std
    return motion_reduced_space


def denormalize_data(motion_reduced_space, mean=None, std=None):
    """
    Invert normalize_data: x * std + mean.

    Args:
        motion_reduced_space: array-like of coordinates in normalized (model) space.
        mean: mean to add back; defaults to GLOBAL_MOTION_MEAN.
        std: standard deviation to scale by; defaults to GLOBAL_MOTION_STD.

    Returns:
        The coordinates in real units.
    """
    if mean is None:
        mean = GLOBAL_MOTION_MEAN
    if std is None:
        std = GLOBAL_MOTION_STD
    motion_real_normal_space = (motion_reduced_space * std) + mean
    return motion_real_normal_space


In [None]:
# Converting to a Tensor
points = torch.tensor(points, device=device)

In [None]:
actual_coordinates

In [None]:
MotionECoGDataset(points, ECOG_NP)

In [None]:
simulated_points_loader = DataLoader(points, batch_size=64, shuffle=False, drop_last=False)

In [None]:
# Push the synthetic spiral (already in normalized model space) through the full
# motion -> ECoG -> motion cycle.
#   simulated_coordinates: reconstruction de-normalized to real units (as before).
#   simulated_coordinates_normalized: the same reconstruction left in normalized
#     space, which is the space the spiral input lives in, so the two compare directly.
simulated_coordinates = []
simulated_coordinates_normalized = []

with torch.no_grad():
    for coordinate_set in tqdm(simulated_points_loader, desc="Generating Motion Predictions"):
        coordinate_set = coordinate_set.to(device)

        # Full latent pass
        motion_latent, _, _ = motion_encoder(coordinate_set)
        ecog_synth = waveform_decoder(motion_latent)
        waveform_latent, _, _ = waveform_encoder(ecog_synth)
        motion_reconstructed = motion_decoder(waveform_latent)  # (B, 3), normalized

        reconstructed_np = motion_reconstructed.cpu().numpy()
        simulated_coordinates_normalized.append(reconstructed_np)
        simulated_coordinates.append((reconstructed_np * GLOBAL_MOTION_STD) + GLOBAL_MOTION_MEAN)

simulated_coordinates = np.concatenate(simulated_coordinates, axis=0)
simulated_coordinates_normalized = np.concatenate(simulated_coordinates_normalized, axis=0)

In [None]:
original_points = points

In [None]:
simulated_coordinates

In [None]:
# Regenerate the spiral input as a fresh NumPy array (the earlier `points` was
# converted to a tensor for the DataLoader). generate_upward_spiral_3d is defined in
# the cell that first plotted the spiral. It is not redefined here, which avoids two
# copies that could drift apart.
points = generate_upward_spiral_3d()


In [None]:
simulated_coordinates

In [None]:
original_points

In [None]:
import plotly.graph_objects as go
import numpy as np

def plot_spiral(points, title='3D Upward Spiral (Normalized)', axis_range=(-1, 1)):
    """
    Render a 3D trajectory as a blue Plotly line and show it.

    Args:
        points: array of shape (N, 3) whose columns are x, y, z.
        title: figure title. The default suits the normalized spiral input.
        axis_range: (low, high) applied to the x, y and z axes, or None to let
            Plotly autoscale. The default (-1, 1) fits normalized data. Data in
            real (de-normalized) units, such as simulated_coordinates, may fall
            outside it and be cut off.
    """
    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        mode='lines',
        line=dict(color='blue', width=4),
    )])

    if axis_range is None:
        scene = dict()
    else:
        low, high = axis_range
        scene = dict(
            xaxis=dict(range=[low, high]),
            yaxis=dict(range=[low, high]),
            zaxis=dict(range=[low, high]),
        )

    fig.update_layout(
        title=title,
        scene=scene,
        margin=dict(l=0, r=0, b=0, t=40),
    )

    fig.show()


In [None]:
# Ensure `points` is a NumPy array. Depending on which cells ran, it is either the
# device tensor built for the DataLoader or the NumPy spiral regenerated in an
# earlier cell. Calling .cpu() on the NumPy array would raise AttributeError.
points = points.detach().cpu().numpy() if torch.is_tensor(points) else np.asarray(points)

In [None]:
# Convert the saved spiral input to NumPy. The tensor check makes re-running this cell harmless.
original_points = original_points.detach().cpu().numpy() if torch.is_tensor(original_points) else np.asarray(original_points)

In [None]:
simulated_coordinates

In [None]:
plot_spiral(original_points)

In [None]:
plot_spiral(simulated_coordinates)

In [None]:
original_points

In [None]:
simulated_coordinates