In [1]:
%pip install optuna cebra==0.4.0 matplotlib==3.9.2 numpy pandas scipy seaborn umap_learn pyspark python-dotenv

Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
sys.path.append("/main/external/dimensionality-reduction")

In [3]:
import cebra
import torch
import torch.utils
import pandas as pd
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import optuna
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold
from utils import transform
import dotenv
import os
dotenv.load_dotenv()
dotenv.load_dotenv("/main/external/dimensionality-reduction/.env")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [4]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F


MAX_MEMORY = "120g"

spark = SparkSession \
    .builder \
    .appName("UranusCluster") \
    .appName("Foo") \
    .config("spark.executor.memory", MAX_MEMORY) \
    .config("spark.driver.memory", MAX_MEMORY) \
    .config("spark.memory.offHeap.enabled",True)\
    .config("spark.memory.offHeap.size","16g")   \
    .getOrCreate()

# Verify the SparkContext
print(spark.sparkContext.getConf().getAll())

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/17 14:54:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


[('spark.driver.memory', '120g'), ('spark.app.id', 'local-1734447249800'), ('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false'), ('spark.app.submitTime', '17344

# Data Load

In [5]:
#Taking just one because this still doesn`t support multi session data
DATA_PATH = "ID18150/Day6/DataFrame_Imaging_spiking_18150_day6"

In [6]:
df = spark.read.format("parquet").load("/main/external/data/transformed")\
    .select(["index","neural_data", "positional_encoding", "file_name"])\
    .where(F.col("file_name") == DATA_PATH)
df.show(5)

                                                                                

+------------+--------------------+--------------------+--------------------+
|       index|         neural_data| positional_encoding|           file_name|
+------------+--------------------+--------------------+--------------------+
|317827579904|[0.00353814595291...|0.014736815295795708|ID18150/Day6/Data...|
|317827707084|[7.32697662897408...|0.002792727523517...|ID18150/Day6/Data...|
|317827579905|[0.00376488942492...|0.014736815295795708|ID18150/Day6/Data...|
|317827707085|[5.27972821146249...|0.002792727523517...|ID18150/Day6/Data...|
|317827579906|[0.00368517413153...| 0.01473313965622418|ID18150/Day6/Data...|
+------------+--------------------+--------------------+--------------------+
only showing top 5 rows



In [7]:
# n_neurons = max([row.neural_size for row in df.select(F.max(F.size("neural_data")).alias("neural_size")).collect()])
# n_rows = df.count()
indices = df.select('index').rdd.flatMap(lambda x: x).collect()
len(indices), max(indices), min(indices)

                                                                                

(131777, 317827711680, 317827579904)

In [8]:
def split(X, p):
    d = len(X) // p
    X_reduced = X[:(d * p)]
    partitions = np.split(np.array(X_reduced), p)
    partitions = [a.tolist() for a in partitions]
    if len(X[d*p:]) != 0:
        partitions += [X[d*p:]]
    return partitions

In [9]:
delta = 5000
def get_x_ticks(L:int):
    x_ticks = np.arange(0,L,delta)
    x_tick_labels = [f"{t[i]/100:.2f}" for i in x_ticks]
    return x_ticks, x_tick_labels

# CEBRA Encoder

# Models

In [None]:
import logging

In [None]:
partition_size = 512
kf = KFold(n_splits = 4)
latent_dimension = 8
scores = []

from cebra import CEBRA

multi_cebra_model = CEBRA(model_architecture='offset1-model',
                    batch_size=64,
                    learning_rate=3e-4,
                    temperature_mode='auto',
                    output_dimension=latent_dimension,
                    max_iterations=1000,
                    distance='cosine',
                    conditional='time_delta',
                    device='cuda_if_available',
                    verbose=True,
                    time_offsets=1)

In [12]:
decoder = torch.nn.Sequential(
    torch.nn.Linear(latent_dimension, 3),
    torch.nn.GELU(),
    torch.nn.Linear(3,3),
    torch.nn.GELU(),
    torch.nn.Linear(3,1),
    torch.nn.Tanh()
)

criterion = torch.nn.MSELoss()
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=3e-4)


In [None]:
scores_hist = []
for train_indices, test_indices in kf.split(indices):
    logging.info(f"Started with {len(train_indices)=} and {len(test_indices)=}")
    for train_partition_indices, test_partition_indices in zip(split(train_indices, partition_size),split(test_indices, partition_size)):
        logging.info("Starting new iteration")
        train_partition_indices = [[indices[i]] for i in train_partition_indices]
        test_partition_indices  = [[indices[i]] for i in test_partition_indices]
        
        logging.info("Filtering dataframe")
        df_train_index = spark.createDataFrame(train_partition_indices[:], ["index"])
        df_test_index  = spark.createDataFrame(test_partition_indices[:], ["index"])
        df_train = df_train_index.join(df, how = "left", on = "index")
        df_test  = df_test_index.join(df, how = "left", on = "index")
        
        logging.info("Grouping data")
        groups_train = [x[0] for x in df_train.select("file_name").distinct().collect()]
        groups_list_train = [df_train.filter(F.col("file_name")==x).collect() for x in groups_train]
        groups_test = [x[0] for x in df_test.select("file_name").distinct().collect()]
        groups_list_test = [df_test.filter(F.col("file_name")==x).collect() for x in groups_test]
        
        X_train = [torch.Tensor(list(map(lambda x: x.neural_data, group))) for group in groups_list_train]
        X_test = [torch.Tensor(list(map(lambda x: x.neural_data, group))) for group in groups_list_test]
        y_train = [torch.Tensor(list(map(lambda x: x.positional_encoding, group))) for group in groups_list_train]
        y_test = [torch.Tensor(list(map(lambda x: x.positional_encoding, group))) for group in groups_list_test]
        
        logging.info("Performing partial fits")
        for X,y in zip(X_train, y_train):
            with device:
                # Train Embedding
                multi_cebra_model.partial_fit(X, y) #Partial fit doesn`t work on multi session data :(
                embedding = torch.Tensor(transform(multi_cebra_model,X))
                
                # Train Decoder
                decoder.train()
                decoder_optimizer.zero_grad()
                predicted_embedding = decoder(embedding)
                loss = criterion(predicted_embedding, y)
                loss.backward()
                decoder_optimizer.step()
        
        scores = []
        for X,y in zip(X_test, y_test):
            with device:
                # Test
                test_embedding = torch.Tensor(transform(multi_cebra_model, X))
                decoder.eval()
                pred = decoder(test_embedding)
                score = r2_score(y.detach().numpy(), pred.detach().numpy())
                scores.append(score)
        mean_score = np.mean(scores)
        
        logging.info(f"Average score: {mean_score}")
        
        scores_hist.append(mean_score)
        
        

    

2024-12-17 15:03:10,685 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:03:10,686 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:03:10,708 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data

2024-12-17 15:03:17,377 - STDOUT - INFO - Performing partial fits


INFO:STDOUT:Performing partial fits                                             
pos: -1.1442 neg:  3.8109 total:  2.6667 temperature:  0.5114: 100%|██████████| 1000/1000 [00:02<00:00, 364.14it/s]
  return func(*args, **kwargs)


2024-12-17 15:03:20,128 - STDOUT - INFO - Average score: -234.38885498046875


INFO:STDOUT:Average score: -234.38885498046875


2024-12-17 15:03:20,128 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:03:20,129 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:03:20,150 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data

2024-12-17 15:03:26,824 - STDOUT - INFO - Performing partial fits


INFO:STDOUT:Performing partial fits                                             
pos: -1.3127 neg:  4.0164 total:  2.7037 temperature:  0.4490: 100%|██████████| 1000/1000 [00:02<00:00, 395.37it/s]
  return func(*args, **kwargs)


2024-12-17 15:03:29,359 - STDOUT - INFO - Average score: -237.61734008789062


INFO:STDOUT:Average score: -237.61734008789062


2024-12-17 15:03:29,359 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:03:29,359 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:03:29,379 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data

2024-12-17 15:03:36,076 - STDOUT - INFO - Performing partial fits


INFO:STDOUT:Performing partial fits                                             
pos: -1.5214 neg:  4.1707 total:  2.6493 temperature:  0.3971: 100%|██████████| 1000/1000 [00:02<00:00, 397.34it/s]
  return func(*args, **kwargs)


2024-12-17 15:03:38,601 - STDOUT - INFO - Average score: -226.2185516357422


INFO:STDOUT:Average score: -226.2185516357422


2024-12-17 15:03:38,602 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:03:38,602 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:03:38,620 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data

2024-12-17 15:03:45,351 - STDOUT - INFO - Performing partial fits


INFO:STDOUT:Performing partial fits                                             
pos: -1.7819 neg:  4.2536 total:  2.4717 temperature:  0.3554: 100%|██████████| 1000/1000 [00:02<00:00, 373.73it/s]
  return func(*args, **kwargs)


2024-12-17 15:03:48,031 - STDOUT - INFO - Average score: -253.38893127441406


INFO:STDOUT:Average score: -253.38893127441406


2024-12-17 15:03:48,032 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:03:48,032 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:03:48,054 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data

2024-12-17 15:03:54,749 - STDOUT - INFO - Performing partial fits


INFO:STDOUT:Performing partial fits                                             
pos: -2.2554 neg:  4.5425 total:  2.2872 temperature:  0.3186: 100%|██████████| 1000/1000 [00:02<00:00, 388.48it/s]
  return func(*args, **kwargs)


2024-12-17 15:03:57,328 - STDOUT - INFO - Average score: -80.11901092529297


INFO:STDOUT:Average score: -80.11901092529297


2024-12-17 15:03:57,329 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:03:57,329 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:03:57,347 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data

2024-12-17 15:04:04,020 - STDOUT - INFO - Performing partial fits


INFO:STDOUT:Performing partial fits                                             
pos: -2.3257 neg:  4.6719 total:  2.3462 temperature:  0.2879: 100%|██████████| 1000/1000 [00:03<00:00, 327.01it/s]
  return func(*args, **kwargs)


2024-12-17 15:04:07,084 - STDOUT - INFO - Average score: -89.7493896484375


INFO:STDOUT:Average score: -89.7493896484375


2024-12-17 15:04:07,085 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:04:07,085 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:04:07,105 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data

2024-12-17 15:04:13,799 - STDOUT - INFO - Performing partial fits


INFO:STDOUT:Performing partial fits                                             
pos: -3.2738 neg:  5.5297 total:  2.2559 temperature:  0.2631: 100%|██████████| 1000/1000 [00:03<00:00, 331.15it/s]

2024-12-17 15:04:16,823 - STDOUT - INFO - Average score: -81.24207305908203



  return func(*args, **kwargs)
INFO:STDOUT:Average score: -81.24207305908203


2024-12-17 15:04:16,824 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:04:16,824 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:04:16,843 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data

2024-12-17 15:04:23,528 - STDOUT - INFO - Performing partial fits


INFO:STDOUT:Performing partial fits                                             
pos: -2.8065 neg:  5.5940 total:  2.7876 temperature:  0.2427: 100%|██████████| 1000/1000 [00:02<00:00, 349.41it/s]
  return func(*args, **kwargs)


2024-12-17 15:04:26,395 - STDOUT - INFO - Average score: -73.75111389160156


INFO:STDOUT:Average score: -73.75111389160156


2024-12-17 15:04:26,396 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:04:26,396 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:04:26,414 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data

2024-12-17 15:04:33,076 - STDOUT - INFO - Performing partial fits


INFO:STDOUT:Performing partial fits                                             
pos: -3.8285 neg:  5.9093 total:  2.0808 temperature:  0.2244: 100%|██████████| 1000/1000 [00:02<00:00, 406.05it/s]
  return func(*args, **kwargs)


2024-12-17 15:04:35,543 - STDOUT - INFO - Average score: -71.16227722167969


INFO:STDOUT:Average score: -71.16227722167969


2024-12-17 15:04:35,544 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:04:35,544 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:04:35,563 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data

2024-12-17 15:04:42,244 - STDOUT - INFO - Performing partial fits


INFO:STDOUT:Performing partial fits                                             
pos: -3.8293 neg:  6.2521 total:  2.4228 temperature:  0.2100: 100%|██████████| 1000/1000 [00:03<00:00, 327.54it/s]


2024-12-17 15:04:45,302 - STDOUT - INFO - Average score: -70.93733215332031


  return func(*args, **kwargs)
INFO:STDOUT:Average score: -70.93733215332031


2024-12-17 15:04:45,302 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:04:45,303 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:04:45,321 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data

2024-12-17 15:04:51,993 - STDOUT - INFO - Performing partial fits


INFO:STDOUT:Performing partial fits                                             
pos: -3.6132 neg:  6.2063 total:  2.5931 temperature:  0.1963: 100%|██████████| 1000/1000 [00:03<00:00, 325.92it/s]
  return func(*args, **kwargs)


2024-12-17 15:04:55,066 - STDOUT - INFO - Average score: -75.11351776123047


INFO:STDOUT:Average score: -75.11351776123047


2024-12-17 15:04:55,067 - STDOUT - INFO - Starting new iteration


INFO:STDOUT:Starting new iteration


2024-12-17 15:04:55,067 - STDOUT - INFO - Filtering dataframe


INFO:STDOUT:Filtering dataframe


2024-12-17 15:04:55,087 - STDOUT - INFO - Grouping data


INFO:STDOUT:Grouping data
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/socket.py", line 706, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 

Exception in thread "serve-DataFrame" java.net.SocketTimeoutException: Accept timed out
	at java.net.PlainSocketImpl.socketAccept(Native Method)
	at java.net.AbstractPlainSocketImpl.accept(AbstractPlainSocketImpl.java:409)
	at java.net.ServerSocket.implAccept(ServerSocket.java:560)
	at java.net.ServerSocket.accept(ServerSocket.java:528)
	at org.apache.spark.security.SocketAuthServer$$anon$1.run(SocketAuthServer.scala:65)
