In [1]:
from utils.data.create_local_t5data import get_trial_data
from datasets import get_testing_data
from utils_f import get_config
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from datetime import datetime
from utils.plot.plot_true_vs_pred_mvmnt import plot_true_vs_pred_mvmnt
import torch
from utils_f import get_config_from_file, set_seeds, set_device
from datasets import get_trial_data, chop, smooth_spikes
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
import plotly.graph_objects as go
from matplotlib import cm, colors
import matplotlib.pyplot as plt
import shutil
import os
import sys
import wandb
import pandas as pd
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [37]:
# Checkpoint under analysis. Other runs previously examined from the same
# runs/train directory (swap the run directory name below to revisit one):
#   honest-sweep-10, stellar-sweep-1, sweet-serenity-164, polar-violet-159,
#   neat-sweep-6, resilient-sweep-7
path = '/home/dmifsud/Projects/NDT-U/runs/train/bumbling-sweep-28/last.pt'

In [None]:
# Resolve the run directory from the checkpoint path, load the run's training
# config, and keep a copy of that config next to the plots this notebook writes.
run_dir = os.path.dirname(path)
name = os.path.basename(run_dir)  # the run directory name is used as the run name
config_path = os.path.join(run_dir, 'config.yaml')
config = get_config_from_file(config_path)
os.makedirs(f"plots/{name}", exist_ok=True)
shutil.copyfile(config_path, f"plots/{name}/config.yaml")

set_seeds(config)
set_device(config, {})
# Prefer the first GPU, but fall back to CPU so the notebook can still run on a
# machine without CUDA (previously `.to('cuda:0')` raised there).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# NOTE(review): torch.load unpickles the entire nn.Module, so only load trusted
# checkpoints. PyTorch >= 2.6 defaults to weights_only=True, which rejects full
# module pickles -- pass weights_only=False there if loading fails.
# map_location remaps tensors saved on another device directly onto `device`.
model = torch.load(path, map_location=device).to(device)
model.name = name
model.eval()  # inference mode for all subsequent forward passes

datasets = get_testing_data(config)
session_csv = pd.read_csv(f'{config.data.dir}/sessions.csv')

In [None]:
# NOTE(review): this value appears to be unused -- every later read of `session`
# in this notebook is inside a `for session in config.data.pretrain_sessions:`
# loop, which rebinds it before use.
session = 't5.2021.05.05'
# session = 't5.2021.05.17'

In [None]:
# Trialize every pretraining session. For each session:
#   1. run the model over the open-loop block's held-in channels,
#   2. attach the model factors (then smooth them) to the continuous data,
#   3. cut aligned trials, dropping failed, center-target and closed-loop trials,
#   4. label each trial with a condition id derived from its target position.
# The result is stored in trialized_data[session].

# Window length passed to chop(). Assumes chop(data, window, overlap) with
# overlap = window - 1 yields one window per time step, so window k ends at
# sample k + CHOP_LEN - 1 (this is how factors are re-indexed below) --
# TODO confirm against datasets.chop.
CHOP_LEN = 30
CHOP_OVERLAP = CHOP_LEN - 1

# Loop-invariant: read the session table once rather than per session.
session_csv = pd.read_csv(f'{config.data.dir}/sessions.csv')

trialized_data = {}
for session in config.data.pretrain_sessions:
    dataset = copy.deepcopy(datasets[session])  # do not want to run xcorr on test data

    if config.data.rem_xcorr:
        # NOTE(review): the return values are unused; presumably zero_chans=True
        # zeroes the correlated channels on `dataset` itself -- confirm.
        corr, corr_chans = dataset.get_pair_xcorr('spikes', threshold=0.2, zero_chans=True)

    dataset.resample(config.data.bin_size / 1000)
    dataset.smooth_spk(20, name='smth')  # for use if we want to take mean and std of smth values

    failed_trials = ~dataset.trial_info['is_successful']
    center_trials = dataset.trial_info['is_center_target']
    ol_block = session_csv.loc[session_csv['session_id'] == session, 'ol_blocks'].item()
    cl_blocks = ~dataset.trial_info['block_num'].isin([ol_block]).values.squeeze()

    # Continuous spikes (and their time index) restricted to the open-loop block.
    ol_mask = dataset.data['blockNums'].isin([ol_block]).values.squeeze()
    ol_spikes = dataset.data[ol_mask].spikes
    spks = ol_spikes.to_numpy()
    spks_idx = ol_spikes.index

    # Held-in / held-out channel split. Re-seeding before every draw gives the
    # same split for every session with the same channel count (presumably the
    # split used at training time -- confirm).
    n_channels = dataset.data.spikes.shape[-1]
    n_heldout = int(config.data.heldout_pct * n_channels)
    np.random.seed(config.setup.seed)
    heldout_channels = np.random.choice(n_channels, n_heldout, replace=False)
    # numpy bool mask, since it indexes a numpy array below.
    heldin_channels = np.ones(n_channels, dtype=bool)
    heldin_channels[heldout_channels] = False

    chopped_spks = chop(np.array(spks[:, heldin_channels]), CHOP_LEN, CHOP_OVERLAP)
    hi_chopped_spks = torch.Tensor(chopped_spks).to(device)

    names = [session] * hi_chopped_spks.shape[0]
    with torch.no_grad():
        _, output = model(hi_chopped_spks, names)

    # Keep the factors at the last step of each window (assumes output is
    # (windows, time, factors) -- TODO confirm) and index them by the time stamp
    # of that step. The first CHOP_LEN - 1 open-loop samples, and all samples
    # outside the open-loop block, end up with NaN factors after the concat.
    factors_df = pd.DataFrame(
        output[:, -1, :].cpu().numpy(),
        index=spks_idx[CHOP_LEN - 1:],
        columns=pd.MultiIndex.from_tuples([('factors', f'{i}') for i in range(output.shape[-1])]),
    )
    dataset.data = pd.concat([dataset.data, factors_df], axis=1)

    dataset.smooth_spk(config['data']['smth_std'], signal_type='factors', name='smth')

    ignored_trials = failed_trials | center_trials | cl_blocks
    # NOTE(review): the entry keyed 1 is always excluded; the reason is not recorded.
    ignored_trials[1] = True

    trial_data = dataset.make_trial_data(
        align_field='start_time',
        align_range=(0, config.data.trial_len),
        allow_overlap=True,
        ignored_trials=ignored_trials,
    )

    trial_data.sort_index(axis=1, inplace=True)
    trial_data['X&Y'] = list(zip(trial_data['targetPos']['x'], trial_data['targetPos']['y']))

    # Label each distinct target position with a condition id 1..K. Targets are
    # sorted rather than taken in order of first appearance, so when sessions
    # share the same target set a given target gets the same id -- and therefore
    # the same plot colour -- in every session. Every distinct target gets an id.
    trial_data['condition'] = 0
    for cond_id, xy in enumerate(sorted(trial_data['X&Y'].unique()), start=1):
        indices = trial_data.index[trial_data['X&Y'] == xy]
        trial_data.loc[indices, 'condition'] = cond_id

    trialized_data[session] = trial_data

In [None]:
# Fit one shared projection (z-score, then 3-component PCA) on the smoothed
# factors of every trialized trial from every pretraining session. The plotting
# cells below project trials into this space with `pca.transform`.
factor_blocks = []
for session in config.data.pretrain_sessions:
    for cond_id, trials in trialized_data[session].groupby('condition'):
        for trial_id, trial in trials.groupby('trial_id'):
            factor_blocks.append(trial.factors_smth.to_numpy())

assert factor_blocks, "no trials found in trialized_data for config.data.pretrain_sessions"

# Stack all trials along time -> (total time steps, n_factors). Unlike
# np.array(...).reshape(...), this does not require every trial to have the same length.
factors = np.concatenate(factor_blocks, axis=0)
pca = Pipeline([('scaling', StandardScaler()), ('pca', PCA(n_components=3))])
pca.fit(factors)

Pipeline(steps=[('scaling', StandardScaler()), ('pca', PCA(n_components=3))])

In [None]:
# COND AVG
# Condition-averaged trajectories in the shared PC space: for each session and
# condition, project every trial with `pca` and average the projections across
# trials. One line per (session, condition), coloured by condition id.

fig = go.Figure()
for session in config.data.pretrain_sessions:
    session_trials = trialized_data[session]
    for cond_id, cond_trials in session_trials.groupby('condition'):
        projected = [
            pca.transform(single.factors_smth)
            for _, single in cond_trials.groupby('trial_id')
        ]
        mean_pcs = np.mean(projected, axis=0)
        line_colour = colors.rgb2hex(cm.tab10(cond_id))
        fig.add_trace(go.Scatter3d(
            x=mean_pcs[:, 0],
            y=mean_pcs[:, 1],
            z=mean_pcs[:, 2],
            mode='lines',
            line=dict(color=line_colour),
        ))

# Fixed-size 3D scene with equal axis scaling and a fixed camera.
scene_layout = dict(
    xaxis_showspikes=False,
    yaxis_showspikes=False,
    zaxis_showspikes=False,
    xaxis_title="PC1",
    yaxis_title="PC2",
    zaxis_title="PC3",
    camera=dict(
        center=dict(x=0.065, y=0.0, z=-0.075),
        eye=dict(x=1.3, y=1.3, z=1.3),
    ),
    aspectratio=dict(x=1, y=1, z=1),
    aspectmode='manual',
)
fig.update_layout(
    width=430,
    height=410,
    autosize=False,
    showlegend=False,
    title={'text': "Condition Averaged PCs", 'y': 0.96, 'yanchor': 'bottom'},
    scene=scene_layout,
    margin=dict(r=0, l=0, b=0, t=20),
)

config2 = {'displayModeBar': False}
fig.show(config=config2)

In [None]:
# SINGLE TRIAL
# Every individual trial projected into the shared PC space, coloured by its
# condition id. Traces are added in session -> condition -> trial order.

traces = []
for session in config.data.pretrain_sessions:
    for cond_id, cond_trials in trialized_data[session].groupby('condition'):
        line_colour = colors.rgb2hex(cm.tab10(cond_id))
        for _, single in cond_trials.groupby('trial_id'):
            pcs = pca.transform(single.factors_smth)
            traces.append(go.Scatter3d(
                x=pcs[:, 0],
                y=pcs[:, 1],
                z=pcs[:, 2],
                mode='lines',
                line=dict(color=line_colour),
            ))
fig = go.Figure(data=traces)

# Fixed-size 3D scene with equal axis scaling and a fixed camera.
scene_layout = dict(
    xaxis_showspikes=False,
    yaxis_showspikes=False,
    zaxis_showspikes=False,
    xaxis_title="PC1",
    yaxis_title="PC2",
    zaxis_title="PC3",
    camera=dict(
        center=dict(x=0.065, y=0.0, z=-0.075),
        eye=dict(x=1.3, y=1.3, z=1.3),
    ),
    aspectratio=dict(x=1, y=1, z=1),
    aspectmode='manual',
)
fig.update_layout(
    width=430,
    height=410,
    autosize=False,
    showlegend=False,
    title={'text': "Single Trial PCs", 'y': 0.96, 'yanchor': 'bottom'},
    scene=scene_layout,
    margin=dict(r=0, l=0, b=0, t=20),
)

config2 = {'displayModeBar': False}
fig.show(config=config2)