In [1]:
%pip install optuna cebra==0.4.0 matplotlib==3.9.2 numpy pandas scipy seaborn umap_learn pyspark python-dotenv

Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
sys.path.append("/main/external/dimensionality-reduction")

In [3]:
import cebra
import torch
import torch.utils
import pandas as pd
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import optuna
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold
from utils import TensorDataset, SimpleTensorDataset, SupervisedNNSolver, RDDDataset
import dotenv
import os
dotenv.load_dotenv()
dotenv.load_dotenv("/main/external/dimensionality-reduction/.env")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cpu')

In [4]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F


MAX_MEMORY = "120g"

spark = SparkSession \
    .builder \
    .appName("UranusCluster") \
    .appName("Foo") \
    .config("spark.executor.memory", MAX_MEMORY) \
    .config("spark.driver.memory", MAX_MEMORY) \
    .config("spark.memory.offHeap.enabled",True)\
    .config("spark.memory.offHeap.size","16g")   \
    .getOrCreate()

# Verify the SparkContext
print(spark.sparkContext.getConf().getAll())

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/16 13:36:14 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


[('spark.driver.memory', '120g'), ('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false'), ('spark.driver.host', 'b918cc307c33'), ('spark.app.name', 'Foo'), ('spar

# Data Load

In [5]:
df = spark.read.format("parquet").load("/main/external/data/transformed")\
    .select(["index","neural_data", "positional_encoding", "file_name"])
df.show(5)

                                                                                

+------------+--------------------+--------------------+--------------------+
|       index|         neural_data| positional_encoding|           file_name|
+------------+--------------------+--------------------+--------------------+
|317827579904|[0.00353814595291...|0.014736815295795708|ID18150/Day6/Data...|
|317827707084|[7.32697662897408...|0.002792727523517...|ID18150/Day6/Data...|
|317827579905|[0.00376488942492...|0.014736815295795708|ID18150/Day6/Data...|
|317827707085|[5.27972821146249...|0.002792727523517...|ID18150/Day6/Data...|
|317827579906|[0.00368517413153...| 0.01473313965622418|ID18150/Day6/Data...|
+------------+--------------------+--------------------+--------------------+
only showing top 5 rows



In [6]:
# n_neurons = max([row.neural_size for row in df.select(F.max(F.size("neural_data")).alias("neural_size")).collect()])
# n_rows = df.count()
indices = df.select('index').rdd.flatMap(lambda x: x).collect()
len(indices), max(indices), min(indices)

                                                                                

(5308650, 395137123064, 0)

In [7]:
def split(X, p):
    d = len(X) // p
    X_reduced = X[:(d * p)]
    partitions = np.split(np.array(X_reduced), p)
    partitions = [a.tolist() for a in partitions]
    if len(X[d*p:]) != 0:
        partitions += [X[d*p:]]
    return partitions

In [8]:
delta = 5000
def get_x_ticks(L:int):
    x_ticks = np.arange(0,L,delta)
    x_tick_labels = [f"{t[i]/100:.2f}" for i in x_ticks]
    return x_ticks, x_tick_labels

# CEBRA Encoder

# Models

In [None]:
partition_size = 512
kf = KFold(n_splits = 4)
scores = []

from cebra import CEBRA

multi_cebra_model = CEBRA(model_architecture='offset1-model',
                    batch_size=16,
                    learning_rate=3e-4,
                    temperature=1,
                    output_dimension=3,
                    max_iterations=1000,
                    distance='cosine',
                    conditional='time_delta',
                    device='cuda_if_available',
                    verbose=True,
                    time_offsets=1)

for train_indices, test_indices in kf.split(indices):
    for train_partition_indices, test_partition_indices in zip(split(train_indices, partition_size),split(test_indices, partition_size)):
        train_partition_indices = [[indices[i]] for i in train_partition_indices]
        test_partition_indices  = [[indices[i]] for i in test_partition_indices]
        
        df_train_index = spark.createDataFrame(train_partition_indices[:], ["index"])
        df_test_index  = spark.createDataFrame(test_partition_indices[:], ["index"])
        df_train = df_train_index.join(df, how = "left", on = "index").collect()
        df_test  = df_test_index.join(df, how = "left", on = "index").collect()
        
        X_train_tensor = torch.Tensor(list(map(lambda x: x.neural_data, df_train)))
        X_test_tensor = torch.Tensor(list(map(lambda x: x.neural_data, df_test)))
        y_train_tensor = torch.Tensor(list(map(lambda x: x.positional_encoding, df_train)))
        y_test_tensor = torch.Tensor(list(map(lambda x: x.positional_encoding, df_test)))
        
        multi_cebra_model.partial_fit(X_train_tensor, y_train_tensor)

    

pos: -0.7827 neg:  3.0400 total:  2.2572 temperature:  1.0000: 100%|██████████| 1000/1000 [00:02<00:00, 421.09it/s]
ERROR:root:KeyboardInterrupt while sending command.              (0 + 64) / 192]
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/socket.py", line 706, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 

[Stage 45:====>                                                 (15 + 64) / 192]

In [None]:
torch.Tensor([[1,2,3],[4]])

ValueError: expected sequence of length 3 at dim 1 (got 1)

24/12/16 14:52:26 WARN JavaUtils: Attempt to delete using native Unix OS command failed for path = /tmp/blockmgr-2b504557-d8c8-485e-9e8b-3a519153c191. Falling back to Java IO way
java.io.IOException: Failed to delete: /tmp/blockmgr-2b504557-d8c8-485e-9e8b-3a519153c191
	at org.apache.spark.network.util.JavaUtils.deleteRecursivelyUsingUnixNative(JavaUtils.java:174)
	at org.apache.spark.network.util.JavaUtils.deleteRecursively(JavaUtils.java:109)
	at org.apache.spark.network.util.JavaUtils.deleteRecursively(JavaUtils.java:90)
	at org.apache.spark.util.SparkFileUtils.deleteRecursively(SparkFileUtils.scala:121)
	at org.apache.spark.util.SparkFileUtils.deleteRecursively$(SparkFileUtils.scala:120)
	at org.apache.spark.util.Utils$.deleteRecursively(Utils.scala:1126)
	at org.apache.spark.storage.DiskBlockManager.$anonfun$doStop$1(DiskBlockManager.scala:368)
	at org.apache.spark.storage.DiskBlockManager.$anonfun$doStop$1$adapted(DiskBlockManager.scala:364)
	at scala.collection.IndexedSeqOptimize