In [None]:
import sys

# Make the project-local packages importable (presumably where cebra_v2, package and
# tembedding live — TODO confirm). Guarded so re-running the cell does not append
# the same directory again.
PROJECT_CODE_DIR = '/volatile/aurelien_stumpf_mascles/project/code/'
if PROJECT_CODE_DIR not in sys.path:
    sys.path.append(PROJECT_CODE_DIR)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
import torch.nn as nn
import cebra_v2 as cebra2
from collections import defaultdict
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA,FastICA
from torch.utils.data.sampler import BatchSampler
import package
import pandas as pd
import matplotlib.animation as animation
import tembedding
import scipy as sc
from skimage.metrics import structural_similarity as ssim
from sklearn.cluster import KMeans,OPTICS
import math
from joblib import Memory,Parallel,delayed,parallel_backend
import time
from multiprocessing import Lock, Process, Queue
import multiprocessing
import queue 
import os
import networkx as nx
from scipy.spatial.transform import Rotation as R

In [None]:
import importlib
# Re-import project modules (presumably edited while the kernel was running) so the
# cells below use their latest code. Order kept as-is: later modules may import
# earlier ones.
# NOTE(review): `%load_ext autoreload` + `%autoreload 2` would make this manual step
# unnecessary; objects created before a reload keep referring to the old classes.
importlib.reload(package.preprocessing)
importlib.reload(cebra2.distribution)
importlib.reload(cebra2.dataset)

## Présentation

Dans ce notebook, on essaie de distinguer les conditions awake, light-propofol et deep-propofol : d'abord à l'aide de distances aux états cérébraux moyens, puis grâce à la méthode CEBRA.

## Fonctions utiles

In [None]:
def display(states, li_titles):
    """Draw a sequence of 2-D arrays as a grid of images, one title per image.

    NOTE: this name shadows IPython's built-in ``display`` for the rest of the
    notebook; it is kept because later cells call it by this name.

    Parameters
    ----------
    states : sequence of 2-D arrays
        Images passed one by one to ``plt.imshow``.
    li_titles : sequence
        Title of each image; must have at least ``len(states)`` entries.

    Returns
    -------
    None. Draws into a new 11x11-inch figure; does nothing when ``states`` is empty.
    """
    n = len(states)
    if n == 0:
        # Nothing to draw (the previous grid formula divided by zero here).
        return

    # Near-square grid: ceil(sqrt(n)) columns and just enough rows to hold n images
    # (the previous formula often added an empty extra row).
    columns = math.ceil(math.sqrt(n))
    rows = math.ceil(n / columns)

    fig = plt.figure(figsize=(11, 11))
    for i in range(n):
        fig.add_subplot(rows, columns, i + 1)
        plt.imshow(states[i])
        plt.axis('off')
        plt.title(li_titles[i])

In [None]:
def single_session_solver(data_loader, **kwargs):
    """Build a single-session CEBRA solver (model + InfoNCE criterion + Adam).

    Parameters
    ----------
    data_loader :
        Not used here; kept for call-site compatibility.
    **kwargs :
        Must contain ``distance`` ('euclidean' or 'cosine'), ``model``,
        ``temperature``, ``beta`` and ``learning_rate``. Any other keys are ignored.

    Returns
    -------
    cebra2.solver.SingleSessionSolver wrapping the model, criterion and optimizer.

    Raises
    ------
    ValueError
        If ``distance`` is neither 'euclidean' nor 'cosine'.
    """
    distance = kwargs['distance']
    # Validate the distance up front: an unknown value used to surface later as an
    # UnboundLocalError on `criterion`.
    if distance == 'euclidean':
        criterion_cls = cebra2.criterion.EuclideanInfoNCE
    elif distance == 'cosine':
        # Was `cebra2.crite55rion` (typo) -> AttributeError on every cosine run.
        criterion_cls = cebra2.criterion.CosineInfoNCE
    else:
        raise ValueError(
            "distance must be 'euclidean' or 'cosine', got {!r}".format(distance))

    model = kwargs["model"]
    criterion = criterion_cls(temperature=kwargs['temperature'], beta=kwargs['beta'])
    optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['learning_rate'])

    return cebra2.solver.SingleSessionSolver(model=model,
                                             criterion=criterion,
                                             optimizer=optimizer)

@torch.no_grad()
def get_emissions(model, dataset):
    """Run ``model`` on ``dataset`` without gradients and return a NumPy array.

    The model is moved to CUDA when available (else CPU) and stays there, as before.
    The input tensor is moved to the same device, and the output is copied back to
    the CPU. The model runs in eval mode (dropout disabled), and its previous
    train/eval flag is restored afterwards.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    was_training = model.training
    model.eval()
    try:
        # The input must live on the same device as the weights; previously it was
        # left on the CPU, which fails as soon as CUDA is available.
        return model(dataset.to(device)).cpu().numpy()
    finally:
        model.train(was_training)

def _compute_emissions_single(solver, dataset):
    """Embed ``dataset`` with the model held by ``solver`` via ``get_emissions``."""
    trained_model = solver.model
    return get_emissions(trained_model, dataset)

## Data Loading

Pour l'entraînement, on garde les singes 'almira', 'khali', 'kimiko' et 'rana' (conditions awake, light-propofol et deep-propofol).
Pour le test, on utilise le singe 'jade'.

In [None]:
# Pre-computed dFC windows and their metadata table (presumably one metadata row
# per window, since later cells index `dfc` with boolean masks built from `meta`).
dfc = np.load(file='/neurospin/lbi/monkeyfmri/deepstim/database/ANESTHETIC_database/derivatives/reference_kmeans/inputs/inputs.npy')
meta = pd.read_csv(filepath_or_buffer="/neurospin/lbi/monkeyfmri/deepstim/database/ANESTHETIC_database/derivatives/reference_kmeans/inputs/metadata.tsv",
                   sep="\t")

In [None]:
# Train / test split by monkey. Both sets are restricted to the same three
# conditions, so every test session has a matching row in `labels_test`.
TRAIN_MONKEYS = ['almira', 'khali', 'kimiko', 'rana']
TEST_MONKEYS = ['jade']
CONDITIONS = ['awake', 'light-propofol', 'deep-propofol']

n_runs = len(set(meta["unique_id"] + meta["monkey"]))
n_wins = 464  # windows per run (presumably the sliding-window count — TODO confirm)
assert dfc.shape[0] == len(meta), "dfc rows and metadata rows must align"
assert len(meta) % n_wins == 0, "metadata is not a whole number of runs"

_keep_condition = meta["condition"].isin(CONDITIONS)
dfc_train = dfc[meta["monkey"].isin(TRAIN_MONKEYS) & _keep_condition].reshape((-1, n_wins, 82, 82))
# Previously the test set kept every jade condition while `labels_test` keeps only
# these three, so session i of the test data was not session i of the labels.
dfc_test = dfc[meta["monkey"].isin(TEST_MONKEYS) & _keep_condition].reshape((-1, n_wins, 82, 82))
dfc_all = dfc.reshape((-1, n_wins, 82, 82))

In [None]:
labels_session = np.array([meta["condition"].iloc[i*464] for i in range(len(meta)//464)])

In [None]:
np.unique(labels_session)

## Animation

In [None]:
# Indices of the runs recorded under deep-propofol; the animation cell iterates over them.
idx = np.flatnonzero(labels_session == 'deep-propofol')

In [None]:
# One .mp4 per deep-propofol run: the frames are the run's successive dFC windows.
ANIMATION_DIR = "/volatile/aurelien_stumpf_mascles/project/code/visuals/animation/deep-propofol/"
os.makedirs(ANIMATION_DIR, exist_ok=True)

for a in idx:
    print(a)
    fig, ax = plt.subplots()
    # ArtistAnimation takes one list of artists per frame; here one image per frame.
    ims = [[ax.imshow(dfc_all[a, i, :, :], animated=True)]
           for i in range(dfc_all.shape[1])]
    ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
                                    repeat_delay=1000)
    fig.colorbar(ims[-1][0], ax=ax)

    out_path = os.path.join(ANIMATION_DIR, "movie{}.mp4".format(a))
    ani.save(out_path)
    # Report the path actually written (the old message dropped "deep-propofol/").
    print("Saved at {}".format(out_path))
    # Close each figure: otherwise every run's images stay in memory.
    plt.close(fig)

## Classification des états Awake / Light-Propofol / Deep-Propofol

In [None]:
def STRUCTURE(x, y):
    """Dissimilarity between two matrices: ``1 - package.preprocessing.structure(x, y)``.

    ``is_batch=False`` is forwarded to the helper (presumably: compare one pair,
    not a batch — see package.preprocessing.structure).
    """
    similarity = package.preprocessing.structure(x, y, is_batch=False)
    return 1 - similarity


metric = STRUCTURE

In [None]:
# jade windows split by condition, each reshaped to (runs, windows, 82, 82).
_is_jade = meta["monkey"].isin(['jade'])
dfc_test_deep_propofol, dfc_test_light_propofol, dfc_test_awake = (
    dfc[_is_jade & (meta["condition"] == cond)].reshape(-1, n_wins, 82, 82)
    for cond in ("deep-propofol", "light-propofol", "awake")
)

### Basis : Deep-propofol mean states

In [None]:
mean_states = np.load("./BrainStates/deep-propofol/mean_states.npy")

In [None]:
display(mean_states,np.arange(len(mean_states)))

In [None]:
list_session_deep_propofol = np.zeros((dfc_test_deep_propofol.shape[0],len(mean_states)))
mean_states_ = torch.from_numpy(mean_states)
dfc_test_deep_propofol_ = torch.from_numpy(dfc_test_deep_propofol)
for i in range(len(mean_states)):
    print(i)
    for session in range(dfc_test_deep_propofol.shape[0]):
        min_val = 1
        for t in range(464):
            accu = STRUCTURE(dfc_test_deep_propofol_[session,t,:,:],mean_states_[i,:,:])
            min_val = min(accu,min_val)
        list_session_deep_propofol[session,i] = min_val

In [None]:
plt.plot(list_session_deep_propofol,label = list(np.arange(len(mean_states))))
#leg = plt.legend(loc='upper center')
plt.show()

In [None]:
list_session_awake = np.zeros((dfc_test_awake.shape[0],len(mean_states)))
mean_states_ = torch.from_numpy(mean_states)
dfc_test_awake_ = torch.from_numpy(dfc_test_awake)
for i in range(len(mean_states)):
    for session in range(dfc_test_awake.shape[0]):
        min_val = 1
        for t in range(464):
            accu = STRUCTURE(dfc_test_awake_[session,t,:,:],mean_states_[i,:,:])
            min_val = min(accu,min_val)
        list_session_awake[session,i] = min_val

In [None]:
plt.plot(list_session_awake,label = list(np.arange(len(mean_states))))
#leg = plt.legend(loc='upper center')
plt.show()

In [None]:
list_session_light_propofol = np.zeros((dfc_test_light_propofol.shape[0],len(mean_states)))
mean_states_ = torch.from_numpy(mean_states)
dfc_test_light_propofol_ = torch.from_numpy(dfc_test_light_propofol)
for i in range(len(mean_states)):
    for session in range(dfc_test_light_propofol_.shape[0]):
        min_val = 1
        for t in range(464):
            accu = STRUCTURE(dfc_test_light_propofol_[session,t,:,:],mean_states_[i,:,:])
            min_val = min(accu,min_val)
        list_session_light_propofol[session,i] = min_val

In [None]:
plt.plot(list_session_light_propofol,label = list(np.arange(len(mean_states))))
#leg = plt.legend(loc='upper center')
plt.show()

In [None]:
list_session_light_propofol

In [None]:
plt.plot(np.mean(list_session_light_propofol,axis=0),color="r")
plt.plot(np.mean(list_session_deep_propofol,axis=0),color = "b")
plt.plot(np.mean(list_session_awake,axis=0), color = "k")

In [None]:
np.argsort(np.mean(list_session_awake,axis=0) - np.mean(list_session_deep_propofol,axis=0))

In [None]:
np.argsort(np.mean(list_session_light_propofol,axis=0) - np.mean(list_session_deep_propofol,axis=0))

In [None]:
np.sort(np.mean(list_session_light_propofol,axis=0) - np.mean(list_session_deep_propofol,axis=0))

In [None]:
plt.imshow(mean_states[74,:])

### Basis : Awake mean states

In [None]:
mean_states = np.load("./BrainStates/awake/mean_states.npy")

In [None]:
list_session_awake.shape

In [None]:
list_session_awake = np.zeros((dfc_test_awake.shape[0],len(mean_states)))
mean_states_ = torch.from_numpy(mean_states)
dfc_test_awake_ = torch.from_numpy(dfc_test_awake)
for i in range(len(mean_states)):
    for session in range(dfc_test_awake.shape[0]):
        min_val = 1
        for t in range(464):
            accu = STRUCTURE(dfc_test_awake_[session,t,:,:],mean_states_[i,:,:])
            min_val = min(accu,min_val)
        list_session_awake[session,i] = min_val

In [None]:
plt.plot(list_session_awake,label = list(np.arange(len(mean_states))))
leg = plt.legend(loc='upper center')
plt.show()

In [None]:
list_session_deep_propofol = np.zeros((dfc_test_deep_propofol.shape[0],len(mean_states)))
mean_states_ = torch.from_numpy(mean_states)
dfc_test_deep_propofol_ = torch.from_numpy(dfc_test_deep_propofol)
for i in range(len(mean_states)):
    for session in range(dfc_test_deep_propofol.shape[0]):
        min_val = 1
        for t in range(464):
            accu = STRUCTURE(dfc_test_deep_propofol_[session,t,:,:],mean_states_[i,:,:])
            min_val = min(accu,min_val)
        list_session_deep_propofol[session,i] = min_val

In [None]:
plt.plot(list_session_deep_propofol,label = list(np.arange(len(mean_states))))
leg = plt.legend(loc='upper center')
plt.show()

In [None]:
list_session_light_propofol = np.zeros((dfc_test_light_propofol.shape[0],len(mean_states)))
mean_states_ = torch.from_numpy(mean_states)
dfc_test_light_propofol_ = torch.from_numpy(dfc_test_light_propofol)
for i in range(len(mean_states)):
    for session in range(dfc_test_light_propofol.shape[0]):
        min_val = 1
        for t in range(464):
            accu = STRUCTURE(dfc_test_light_propofol_[session,t,:,:],mean_states_[i,:,:])
            min_val = min(accu,min_val)
        list_session_light_propofol[session,i] = min_val

In [None]:
plt.plot(list_session_light_propofol,label = list(np.arange(len(mean_states))))
leg = plt.legend(loc='upper center')
plt.show()

### Basis : Light-propofol mean states

In [None]:
mean_states = np.load("./BrainStates/light-propofol/mean_states.npy")

In [None]:
display(mean_states,np.arange(len(mean_states)))

In [None]:
list_session_awake = np.zeros((dfc_test_awake.shape[0],len(mean_states)))
mean_states_ = torch.from_numpy(mean_states)
dfc_test_awake_ = torch.from_numpy(dfc_test_awake)
for i in range(len(mean_states)):
    for session in range(dfc_test_awake.shape[0]):
        min_val = 1
        for t in range(464):
            accu = STRUCTURE(dfc_test_awake_[session,t,:,:],mean_states_[i,:,:])
            min_val = min(accu,min_val)
        list_session_awake[session,i] = min_val

In [None]:
plt.plot(list_session_awake,label = list(np.arange(len(mean_states))))
#leg = plt.legend(loc='upper center')
plt.show()

In [None]:
list_session_deep_propofol = np.zeros((dfc_test_deep_propofol.shape[0],len(mean_states)))
mean_states_ = torch.from_numpy(mean_states)
dfc_test_deep_propofol_ = torch.from_numpy(dfc_test_deep_propofol)
for i in range(len(mean_states)):
    print(i)
    for session in range(dfc_test_deep_propofol.shape[0]):
        min_val = 1
        for t in range(464):
            accu = STRUCTURE(dfc_test_deep_propofol_[session,t,:,:],mean_states_[i,:,:])
            min_val = min(accu,min_val)
        list_session_deep_propofol[session,i] = min_val

In [None]:
plt.plot(list_session_deep_propofol,label = list(np.arange(len(mean_states))))
#leg = plt.legend(loc='upper center')
plt.show()

In [None]:
list_session_light_propofol = np.zeros((dfc_test_light_propofol.shape[0],len(mean_states)))
mean_states_ = torch.from_numpy(mean_states)
dfc_test_light_propofol_ = torch.from_numpy(dfc_test_light_propofol)
for i in range(len(mean_states)):
    for session in range(dfc_test_light_propofol_.shape[0]):
        min_val = 1
        for t in range(464):
            accu = STRUCTURE(dfc_test_light_propofol_[session,t,:,:],mean_states_[i,:,:])
            min_val = min(accu,min_val)
        list_session_light_propofol[session,i] = min_val

In [None]:
plt.plot(list_session_light_propofol,label = list(np.arange(len(mean_states))))
#leg = plt.legend(loc='upper center')
plt.show()

In [None]:
plt.plot(np.mean(list_session_light_propofol,axis=0),color="r")
plt.plot(np.mean(list_session_deep_propofol,axis=0),color = "b")
plt.plot(np.mean(list_session_awake,axis=0), color = "k")

In [None]:
np.argsort(np.mean(list_session_awake,axis=0) - np.mean(list_session_light_propofol,axis=0))

In [None]:
np.argsort(np.mean(list_session_deep_propofol,axis=0) - np.mean(list_session_light_propofol,axis=0))

In [None]:
plt.imshow(mean_states[6])

## Classification avec les labels

In [None]:
dict_labels = {"awake" : 0, "light-propofol" : 1, "deep-propofol" : 2}

In [None]:
meta

In [None]:
meta_test = meta[meta["monkey"].isin(['jade']) & (meta["condition"].isin(["awake","light-propofol","deep-propofol"]))]["condition"]
labels_test_condition = np.array([meta_test.iloc[i] for i in range(len(meta_test))])
labels_test = np.array([dict_labels[elem] for elem in labels_test_condition]).reshape((len(labels_test_condition)//464,464))

In [None]:
meta_train = meta[meta["monkey"].isin(['almira', 'khali', 'kimiko', 'rana']) & (meta["condition"].isin(["awake","light-propofol","deep-propofol"]))]["condition"]
labels_train_condition = np.array([meta_train.iloc[i] for i in range(len(meta_train))])
labels_train = np.array([dict_labels[elem] for elem in labels_train_condition]).reshape((len(labels_train_condition)//464,464))

In [None]:
# Encoder: vectorised FC matrix (num_neurons features; presumably the 82*81/2 = 3321
# strict upper-triangle entries — TODO confirm) -> num_output-dimensional embedding.
num_output = 6
normalize = True
num_neurons = 3321

# NOTE(review): there is no non-linearity between the Linear layers, so this stack is
# equivalent to a single linear map num_neurons -> num_output. If a non-linear
# encoder is intended, insert activations (e.g. nn.GELU()) between the layers.
model = cebra2.model.Model(
    nn.Dropout(0),  # p=0: currently a no-op
    nn.Linear(num_neurons, 1000),
    nn.Linear(1000, 500),
    nn.Linear(500, 100),
    nn.Linear(100, 50),
    nn.Linear(50, 20),
    # Last layer tied to num_output (was a literal 6) so the two cannot drift apart.
    nn.Linear(20, num_output),
    num_input=num_neurons,
    num_output=num_output,
    normalize=normalize,
)

In [None]:
# Training tensor of shape (sessions, 464, 1, 82, 82). `dfc` is no longer overwritten:
# earlier cells index the raw `dfc` with `meta` masks and would break on re-run.
fc_train = torch.from_numpy(dfc_train.reshape((dfc_train.shape[0], 464, 1, 82, 82)))
discrete = labels_train

In [None]:
# (sessions, 464, 1, 82, 82) -> (sessions, 464, 82, 82)
fc_train = fc_train.squeeze(2)
# Vectorise each FC matrix (upper triangle, per the helper's name).
_upper_triangle = package.preprocessing.flatten_higher_triangular(fc_train)
fc_train_vector = torch.from_numpy(_upper_triangle)
fc_dataset = cebra2.dataset.SimpleMultiSessionDataset(fc_train_vector, discrete=discrete)
fc_loader = cebra2.dataset.MultiSessionLoader(
    fc_dataset,
    num_steps=1000,
    batch_size=6000,
    time_delta=5,
    matrix_delta=0.5,
)

In [None]:
# single_session_solver reads only distance, temperature, beta, learning_rate and
# model; the remaining keyword arguments are accepted but ignored.
cebra_fc = single_session_solver(
    data_loader=fc_loader,
    model=model,
    distance='cosine',
    temperature=1,
    beta=1,
    learning_rate=3e-4,
    model_architecture='offset1-model',
    num_hidden_units=128,
    output_dimension=128,
    verbose=True,
)

In [None]:
# Train the encoder on the training monkeys (fc_loader: num_steps=1000, batch_size=6000).
cebra_fc.fit(fc_loader)

In [None]:
# Embed every training window and plot two embedding dimensions, coloured by condition.
a, b, c = fc_dataset.neural.shape
data = fc_dataset.neural.reshape(a * b, c)  # `Tensor.resize` is deprecated
fc_emb = _compute_emissions_single(cebra_fc, data)
fc_emb = fc_emb.reshape((a, b, -1))  # (session, window, embedding dim) for any output size
fig = plt.figure(figsize=(12, 5))

ax1 = plt.subplot(121)
ax1.set_title('Embedding du train')
colors = ["black", 'red', 'green', 'blue', 'purple', 'yellow']
dim_x, dim_y = 0, 3  # embedding dimensions shown
for i in range(a):  # every training session (was hard-coded to 58)
    # Explicit colour strings: label k -> colors[k] (a cmap would be ignored here).
    ax1.scatter(fc_emb[i, :, dim_x], fc_emb[i, :, dim_y], c=colors[labels_train[i, 0]], s=1)
ax1.set_xlabel("dim {}".format(dim_x))
ax1.set_ylabel("dim {}".format(dim_y))
ax1.axis('on')

In [None]:
# Vectorise the jade test FC matrices like the train set, cast to float32 for the encoder.
_test_upper_triangle = package.preprocessing.flatten_higher_triangular(torch.from_numpy(dfc_test))
dfc_test_vector = torch.from_numpy(_test_upper_triangle).to(torch.float32)

In [None]:
# Embed every jade test window, flattened to (sessions * windows, features).
# NOTE(review): the next cell recomputes the same embedding; this cell can be dropped.
a, b, c = dfc_test_vector.shape
data = dfc_test_vector.reshape(a * b, c)  # `Tensor.resize` is deprecated
fc_emb = _compute_emissions_single(cebra_fc, data)

In [None]:
# Embed the jade test windows and plot two embedding dimensions, coloured by true condition.
a, b, c = dfc_test_vector.shape
assert len(labels_test) == a, (
    "labels_test has {} sessions but dfc_test_vector has {}: test data and labels must "
    "be built with the same monkey/condition filter".format(len(labels_test), a))
data = dfc_test_vector.reshape(a * b, c)  # `Tensor.resize` is deprecated
fc_emb = _compute_emissions_single(cebra_fc, data)
# Use the encoder's actual output size (was hard-coded to 3, which fails for 6).
fc_emb = fc_emb.reshape((a, b, -1))
fig = plt.figure(figsize=(12, 5))

ax1 = plt.subplot(121)
ax1.set_title('Embedding du test')  # jade (test) windows, not the train set
colors = ["black", 'red', 'green', 'blue', 'purple', 'yellow']
# labels_test only holds dict_labels values (0, 1, 2), so no filtering is needed.
for i in range(a):  # every test session (was hard-coded to 28)
    ax1.scatter(fc_emb[i, :, 0], fc_emb[i, :, 1], c=colors[labels_test[i, 0]], s=1)
ax1.axis('on')

### Algorithme de classification

In [None]:
# Embed all training windows, then split them to train / validate a small classifier.
a, b, c = fc_dataset.neural.shape
data = fc_dataset.neural.reshape(a * b, c)  # `Tensor.resize` is deprecated
fc_emb = _compute_emissions_single(cebra_fc, data)
# One row per window; width follows the encoder output (was hard-coded to 3, which
# silently split each 6-d row into two rows and misaligned X with y).
X = fc_emb.reshape((-1, fc_emb.shape[-1]))
y = labels_train.reshape((-1))
assert len(X) == len(y), "embedding rows and labels are misaligned"
net = package.torch_classifier.MLP(input_dim=X.shape[1])
# NOTE(review): windows are split at random, so correlated windows of the same session
# can land on both sides; this validation score is likely optimistic.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
train = package.torch_classifier.SimpleDataset(X_train, y_train)
test = package.torch_classifier.SimpleDataset(X_test, y_test)
balanced_batch_sampler = package.torch_classifier.BalancedBatchSampler(train, n_classes=3, n_samples=1000)
train_loader = torch.utils.data.DataLoader(train, batch_sampler=balanced_batch_sampler)
test_loader = torch.utils.data.DataLoader(test, batch_size=10, shuffle=True, num_workers=6)

In [None]:
# Training split of the embedding (first two dims), coloured by true condition.
fig = plt.figure(figsize=(12, 5))

ax1 = plt.subplot(121)
ax1.set_title('Embedding du train')
colors = ["black", 'red', 'green', 'blue', 'purple', 'yellow']
# Index the palette directly so label k gets colors[k], as in the other plots;
# integer labels with a 6-colour cmap were rescaled from 0..2 onto the whole palette.
ax1.scatter(X_train[:, 0], X_train[:, 1], c=np.asarray(colors)[np.asarray(y_train)], s=1)
ax1.axis('on')

In [None]:
import importlib
# Pick up edits to package.torch_classifier before training.
# NOTE(review): `net` was created before this reload, so it is an instance of the
# previous MLP class while `Train` below comes from the reloaded module; re-create
# `net` after reloading if the class itself changed.
importlib.reload(package.torch_classifier)

In [None]:
# Train the classifier on the balanced loader; 100 is presumably the number of
# epochs (TODO confirm against package.torch_classifier.Train).
package.torch_classifier.Train(net,train_loader,test_loader,100,lr = 0.1)

In [None]:
# Classifier decision regions on the unit sphere's z >= 0 hemisphere, seen from above.
# Only meaningful when the classifier takes 3-d inputs (a 3-d embedding).
_phi, _theta = np.meshgrid(np.linspace(0, 2 * np.pi, 300), np.linspace(0, np.pi, 200))
# Separate names so the training labels `y` are not overwritten.
_sx = (np.cos(_phi) * np.cos(_theta)).reshape(-1, 1)
_sy = (np.sin(_phi) * np.cos(_theta)).reshape(-1, 1)
_sz = np.sin(_theta).reshape(-1, 1)
sphere = np.concatenate((_sx, _sy, _sz), axis=1)

net.eval()
with torch.no_grad():
    # exp() of the output: presumably the MLP returns log-probabilities — TODO confirm.
    probas = torch.exp(net(torch.from_numpy(sphere).type(torch.float32)))
res = torch.argmax(probas, dim=1)

fig = plt.figure(figsize=(12, 5))
ax1 = plt.subplot(121)
ax1.set_title('Régions de décision (hémisphère z >= 0)')
colors = ["black", 'red', 'green', 'blue', 'purple', "yellow"]
# Predicted class k -> colors[k], consistent with the embedding plots.
ax1.scatter(sphere[:, 0], sphere[:, 1], c=np.asarray(colors)[res.numpy()], s=1)
ax1.axis('on')


In [None]:
# Classifier decision regions on the unit sphere's z <= 0 hemisphere, seen from above.
# Only meaningful when the classifier takes 3-d inputs (a 3-d embedding).
_phi, _theta = np.meshgrid(np.linspace(0, 2 * np.pi, 300), -np.linspace(0, np.pi, 200))
# Separate names so the training labels `y` are not overwritten.
_sx = (np.cos(_phi) * np.cos(_theta)).reshape(-1, 1)
_sy = (np.sin(_phi) * np.cos(_theta)).reshape(-1, 1)
_sz = np.sin(_theta).reshape(-1, 1)
sphere = np.concatenate((_sx, _sy, _sz), axis=1)

net.eval()
with torch.no_grad():
    # exp() of the output: presumably the MLP returns log-probabilities — TODO confirm.
    probas = torch.exp(net(torch.from_numpy(sphere).type(torch.float32)))
res = torch.argmax(probas, dim=1)

fig = plt.figure(figsize=(12, 5))
ax1 = plt.subplot(121)
ax1.set_title('Régions de décision (hémisphère z <= 0)')
colors = ["black", 'red', 'green', 'blue', 'purple', "yellow"]
# Predicted class k -> colors[k], consistent with the embedding plots.
ax1.scatter(sphere[:, 0], sphere[:, 1], c=np.asarray(colors)[res.numpy()], s=1)
ax1.axis('on')

### Résultats

In [None]:
# Session-level prediction on jade: sum the per-window class scores (exp of the MLP
# output, presumably probabilities) over the session, then take the arg-max.
net.eval()
n_test_sessions = dfc_test_vector.shape[0]
assert len(labels_test) == n_test_sessions, "test data and labels are misaligned"
n_correct = 0
with torch.no_grad():
    for session in range(n_test_sessions):  # was hard-coded to 28
        emb = torch.from_numpy(_compute_emissions_single(cebra_fc, dfc_test_vector[session, :, :]))
        probas = torch.sum(torch.exp(net(emb.type(torch.float32))), dim=0)
        res = torch.argmax(probas, dim=0).item()  # plain int, not a tensor repr
        n_correct += int(res == labels_test[session, 0])
        print("True label : {}, Predicted label : {}".format(labels_test[session, 0], res))
print("Session accuracy : {}/{}".format(n_correct, n_test_sessions))