In [17]:
import pandas as pd
import numpy as np
import xarray as xr

import os
from tqdm.auto import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

from src.helpers import *
from src.visualize import *
from src.trainer import *
from Models.models import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from sklearn.metrics import *
from copy import deepcopy
import torch.utils.data as data
from torch.utils.data import Dataset

import pickle
import math

# Dimensionality-reduction and interactive-plotting dependencies used by the
# projection cells. Kept after the wildcard imports above so that these names
# cannot be shadowed by anything those modules export.
import plotly.express as px
import umap
from sklearn.decomposition import PCA

In [18]:
# Location of the pickled dataset that the rest of the notebook works on.
data_path = './Data/Processed_Data/Tidy_Sansa_13_04.pkl'

# NOTE: unpickling can execute arbitrary code -- only load files from a
# trusted source.
with open(data_path, mode='rb') as pickle_handle:
    df = pickle.load(pickle_handle)

In [4]:
# List the distinct stimulation-parameter entries, compared by their string
# form (these strings are the keys used when the column is recoded).
pd.unique(df['stim_params'].astype(str))

array(['[0 0 0 0 0]', '[2.0e+00 4.5e+02 3.0e-01 5.0e+01     nan]',
       '[2.e+00 4.e+02 3.e-01 5.e+01    nan]',
       '[2.e+00 4.e+02 3.e-01 8.e+01    nan]',
       '[2.e+00 4.e+02 3.e-01 1.e+02    nan]',
       '[8.0e+00 5.5e+02 3.0e-01 5.0e+01     nan]',
       '[8.e+00 6.e+02 3.e-01 5.e+01    nan]'], dtype=object)

In [5]:
# Recode each stimulation-parameter vector (matched by its string form) into
# an integer class id.
# NOTE(review): two different parameter vectors share class 1 and two share
# class 4 -- confirm that merging them is intentional.
mapping = {
    '[0 0 0 0 0]': 0,
    '[2.0e+00 4.5e+02 3.0e-01 5.0e+01     nan]': 1,
    '[2.e+00 4.e+02 3.e-01 5.e+01    nan]': 1,
    '[2.e+00 4.e+02 3.e-01 8.e+01    nan]': 2,
    '[2.e+00 4.e+02 3.e-01 1.e+02    nan]': 3,
    '[8.0e+00 5.5e+02 3.0e-01 5.0e+01     nan]': 4,
    '[8.e+00 6.e+02 3.e-01 5.e+01    nan]': 4
}

# Match on the string form so the lookup works whether the column stores the
# raw vectors or their text representation.
stim_text = df['stim_params'].astype(str)

# Already-encoded class ids map to themselves, so re-running this cell is a
# no-op instead of turning the column into NaN.
stim_lookup = {str(class_id): class_id for class_id in mapping.values()}
stim_lookup.update(mapping)
stim_class = stim_text.map(stim_lookup)

# Fail loudly on unknown parameter sets rather than silently leaving a
# mixed-type column behind.
unmapped_mask = stim_class.isna().to_numpy()
if unmapped_mask.any():
    raise ValueError(
        f"Unmapped stim_params values: {sorted(set(stim_text[unmapped_mask]))}"
    )

df['stim_params'] = stim_class.astype('int64')

# Show the resulting class ids
df['stim_params'].unique()

array([0, 1, 2, 3, 4])

In [6]:
import umap

In [7]:
from umap import UMAP

In [8]:
# Inspect the table after `stim_params` was recoded to integer class ids
# (pandas elides the middle rows in the rendered output).
df

Unnamed: 0,num,type,stim_params,trial_num,reach_num,time_sample,x,y,z,angles,both_spikes,both_rates,target_pos,id
0,1,BASELINE,0,0,0,0,"[142.3787841796875, 120.14838027954102, 141.06...","[68.0731086730957, 86.92890548706055, 102.2939...","[247.4068832397461, 205.18695068359375, 206.01...","[140.39723531825678, 96.92269467518659, 38.203...","[0.0, 0.0, 0.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.732...","[3.3464184, 49.63821, 14.292594, 8.748773, 75....","[219.76524353027344, 130.76919555664062, 213.9...",0_0
1,1,BASELINE,0,0,0,1,"[143.18093872070312, 121.68184661865234, 142.9...","[67.82530975341797, 87.30451202392578, 102.091...","[247.6982879638672, 205.6662368774414, 206.309...","[140.3775790359186, 96.64710391489533, 37.9640...","[0.0, 1.0, 0.0, 0.0, 1.7320508075688772, 0.0, ...","[1.5401073, 50.027313, 11.921029, 6.977782, 81...","[221.51820373535156, 128.51687622070312, 215.1...",0_0
2,1,BASELINE,0,0,0,2,"[144.1379623413086, 122.84801864624023, 144.52...","[68.13214492797852, 88.26971435546875, 102.012...","[247.58580780029297, 205.85302734375, 206.2166...","[140.68175551713975, 96.2359836269319, 35.7343...","[0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.7320508075688...","[0.6093754, 52.325947, 10.036302, 5.9196334, 8...","[223.21636199951172, 126.55763626098633, 215.4...",0_0
3,1,BASELINE,0,0,0,3,"[144.51087951660156, 123.78560638427734, 145.5...","[68.0840950012207, 88.68589401245117, 102.2706...","[247.51468658447266, 205.95562744140625, 206.3...","[140.2870518362514, 96.55560738764017, 34.7878...","[0.0, 1.4142135623730951, 0.0, 0.0, 2.0, 0.0, ...","[0.20924558, 55.87233, 9.250959, 6.472432, 90....","[224.46119689941406, 126.08711242675781, 215.8...",0_0
4,1,BASELINE,0,0,0,4,"[144.9309539794922, 125.23460388183594, 147.02...","[68.29928588867188, 89.41569900512695, 102.825...","[247.50634002685547, 205.8799591064453, 206.73...","[139.69127227150585, 96.40998970387525, 35.131...","[0.0, 1.0, 0.0, 0.0, 2.449489742783178, 0.0, 1...","[0.07273595, 59.696102, 9.341586, 9.241178, 93...","[225.9362335205078, 125.02190780639648, 217.48...",0_0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4570,15,BASELINE,0,11,4,70,"[164.6456069946289, 201.37726593017578, 221.37...","[103.03368377685547, 123.88500595092773, 111.0...","[228.7010269165039, 182.9957504272461, 194.898...","[120.54061306741454, 100.11610158084144, 122.3...","[0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, ...","[33.209927, 59.340115, 39.64349, 20.251204, 57...","[270.8739013671875, 29.489012718200684, 224.14...",11_4
4571,15,BASELINE,0,11,4,71,"[164.1747055053711, 201.19762420654297, 221.03...","[103.66886138916016, 124.7026481628418, 111.90...","[228.82796478271484, 182.82857513427734, 195.0...","[120.46390127138038, 99.27031148639416, 118.68...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[38.185688, 59.119392, 35.88647, 20.18418, 60....","[270.0938415527344, 30.61710262298584, 226.046...",11_4
4572,15,BASELINE,0,11,4,72,"[163.94602966308594, 201.20968627929688, 220.7...","[103.90918731689453, 125.61049270629883, 112.7...","[229.28658294677734, 182.6895980834961, 195.25...","[120.6254669757123, 98.03421822854327, 115.294...","[1.4142135623730951, 1.0, 0.0, 1.0, 1.41421356...","[41.25244, 57.70841, 32.870136, 18.964466, 63....","[269.0299530029297, 31.600306510925293, 227.36...",11_4
4573,15,BASELINE,0,11,4,73,"[164.01731872558594, 200.6312026977539, 220.45...","[103.52156829833984, 127.15031051635742, 113.9...","[230.02338409423828, 182.86080932617188, 195.5...","[120.99403256501938, 96.36142249228692, 115.47...","[1.0, 2.0, 0.0, 0.0, 1.4142135623730951, 0.0, ...","[42.173332, 54.70367, 31.565449, 16.965788, 66...","[267.66375732421875, 31.173575401306152, 226.1...",11_4


In [12]:
# Stack the per-row `both_rates` vectors into a single 2-D feature matrix
# (one row per DataFrame row).
features = np.vstack(df['both_rates'].to_list())

# Stimulation classes are labels, not magnitudes: passing them as strings
# makes plotly draw a discrete legend instead of a continuous colour bar.
stim_labels = df['stim_params'].astype(str).to_numpy()

# A fixed random_state keeps the embeddings reproducible; UMAP then runs
# single-threaded, which is what its "n_jobs ... overridden" message reports.
umap_2d = umap.UMAP(n_components=2, init='random', random_state=0)
umap_3d = umap.UMAP(n_components=3, init='random', random_state=0)

proj_2d = umap_2d.fit_transform(features)
proj_3d = umap_3d.fit_transform(features)

fig_2d = px.scatter(
    proj_2d, x=0, y=1,
    color=stim_labels, labels={'color': 'stim_params'},
    title='UMAP of both_rates (2 components)',
)
fig_3d = px.scatter_3d(
    proj_3d, x=0, y=1, z=2,
    color=stim_labels, labels={'color': 'stim_params'},
    title='UMAP of both_rates (3 components)',
)
fig_3d.update_traces(marker_size=5)

# Figure size is a layout property. Keyword arguments given to show() only
# configure the active renderer, and the interactive notebook renderers have
# no width/height setting, so the size would otherwise be ignored.
fig_3d.update_layout(width=1000, height=1000)
fig_3d.show()


n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [11]:
# Render the 2-D UMAP scatter (`fig_2d`) built in the previous cell.
fig_2d.show()

### Trying PCA to compare results

In [16]:
# Stack the per-row `both_rates` vectors into a single 2-D feature matrix
# (one row per DataFrame row).
features = np.vstack(df['both_rates'].to_list())

# Stimulation classes are labels, not magnitudes: passing them as strings
# makes plotly draw a discrete legend instead of a continuous colour bar.
stim_labels = df['stim_params'].astype(str).to_numpy()

# PCA with 2 components
pca_2d = PCA(n_components=2, random_state=0)
proj_2d_pca = pca_2d.fit_transform(features)

# PCA with 3 components
pca_3d = PCA(n_components=3, random_state=0)
proj_3d_pca = pca_3d.fit_transform(features)

# 2-D scatter of the first two principal components
fig_2d_pca = px.scatter(
    x=proj_2d_pca[:, 0], y=proj_2d_pca[:, 1],
    color=stim_labels,
    labels={'color': 'stim_params', 'x': 'PC1', 'y': 'PC2'},
    title='PCA of both_rates (2 components)',
)

# 3-D scatter of the first three principal components
fig_3d_pca = px.scatter_3d(
    x=proj_3d_pca[:, 0], y=proj_3d_pca[:, 1], z=proj_3d_pca[:, 2],
    color=stim_labels,
    labels={'color': 'stim_params', 'x': 'PC1', 'y': 'PC2', 'z': 'PC3'},
    title='PCA of both_rates (3 components)',
)
fig_3d_pca.update_traces(marker_size=5)

# Figure size is a layout property. Keyword arguments given to show() only
# configure the active renderer, and the interactive notebook renderers have
# no width/height setting, so the size would otherwise be ignored.
for pca_figure in (fig_2d_pca, fig_3d_pca):
    pca_figure.update_layout(width=1000, height=1000)
    pca_figure.show()
