In [1]:
%pip install optuna cebra==0.4.0 matplotlib==3.9.2 numpy pandas scipy seaborn umap_learn pyspark python-dotenv tensorboardX

Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
sys.path.append("/main/external/dimensionality-reduction")

In [None]:
from cebra import CEBRA
import torch
import torch.utils
import numpy as np
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from utils.overrides import transform
from utils.utils import spark_rdd_to_tensor, k_folds
from tqdm import tqdm
import dotenv
import os
dotenv.load_dotenv()
dotenv.load_dotenv("/main/external/dimensionality-reduction/.env")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter("/main/external/tensorboard_runs")
device

device(type='cuda')

In [None]:
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn

# Data Load

In [4]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

MAX_MEMORY = "120g"

spark = SparkSession \
    .builder \
    .appName("UranusCluster") \
    .appName("Foo") \
    .config("spark.executor.memory", MAX_MEMORY) \
    .config("spark.driver.memory", MAX_MEMORY) \
    .config("spark.memory.offHeap.enabled",True)\
    .config("spark.memory.offHeap.size","16g")   \
    .getOrCreate()

# Verify the SparkContext
print(spark.sparkContext.getConf().getAll())

[('spark.driver.port', '38345'), ('spark.driver.memory', '120g'), ('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false'), ('spark.app.name', 'Foo'), ('spark.execu

In [5]:
#Taking just one because this still doesn`t support multi session data
DATA_PATH = "ID18150/Day6/DataFrame_Imaging_spiking_18150_day6"

In [6]:
df = spark.read.format("parquet").load("/main/external/data/transformed")\
    .select(["index","neural_data", "positional_encoding", "file_name"])\
    .where(F.col("file_name") == DATA_PATH)
df.show(5)



+------------+--------------------+--------------------+--------------------+
|       index|         neural_data| positional_encoding|           file_name|
+------------+--------------------+--------------------+--------------------+
|317827579904|[0.00353814595291...|0.014736815295795708|ID18150/Day6/Data...|
|317827707084|[7.32697662897408...|0.002792727523517...|ID18150/Day6/Data...|
|317827579905|[0.00376488942492...|0.014736815295795708|ID18150/Day6/Data...|
|317827707085|[5.27972821146249...|0.002792727523517...|ID18150/Day6/Data...|
|317827579906|[0.00368517413153...| 0.01473313965622418|ID18150/Day6/Data...|
+------------+--------------------+--------------------+--------------------+
only showing top 5 rows



                                                                                

In [8]:
def split(X, p):
    d = len(X) // p
    X_reduced = X[:(d * p)]
    partitions = np.split(np.array(X_reduced), p)
    partitions = [a.tolist() for a in partitions]
    if len(X[d*p:]) != 0:
        partitions += [X[d*p:]]
    return partitions

In [9]:
delta = 5000
def get_x_ticks(L:int):
    x_ticks = np.arange(0,L,delta)
    x_tick_labels = [f"{t[i]/100:.2f}" for i in x_ticks]
    return x_ticks, x_tick_labels

# CEBRA Encoder

# Models

In [10]:
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [11]:
number_of_partitions = 32
batch_size = 128 #Used for both embdding and decoder
n_splits = 16
latent_dimension = 8
scores = []

In [12]:
criterion = torch.nn.MSELoss()

class Dataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [None]:
indices = df.select('index').rdd.flatMap(lambda x: x).collect()

In [None]:
if not os.path.exists('/main/external/models'):
    os.makedirs('/main/external/models')

In [None]:
counter = 0
for it, (train_indices, test_indices) in enumerate(k_folds(n_splits, indices)):
    logger.info(f"Started with {len(train_indices)=} and {len(test_indices)=}")
    decoder = torch.nn.Sequential(
        torch.nn.Linear(latent_dimension, 3),
        torch.nn.GELU(),
        torch.nn.Linear(3,3),
        torch.nn.GELU(),
        torch.nn.Linear(3,1),
        torch.nn.Tanh()
    )
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=3e-4)
    multi_cebra_model = CEBRA(
        model_architecture='offset1-model',
        batch_size=batch_size,
        learning_rate=3e-4,
        temperature_mode='auto',
        output_dimension=latent_dimension,
        max_iterations=1000,
        min_temperature=0.001,
        distance='cosine',
        conditional='time_delta',
        device='cuda_if_available',
        verbose=False,
        time_offsets=10)
    train_splits, test_splits = split(train_indices, number_of_partitions), split(test_indices, number_of_partitions)
    for train_partition_indices, test_partition_indices in tqdm(zip(train_splits,test_splits), total = number_of_partitions):
        logger.debug("Starting new iteration")
        counter += 1
        
        logger.debug("Filtering dataframe")
        df_train_index = spark.createDataFrame(train_partition_indices[:], ["index"])
        df_test_index  = spark.createDataFrame(test_partition_indices[:], ["index"])
        df_train = df_train_index.join(df, how = "left", on = "index")
        df_test  = df_test_index.join(df, how = "left", on = "index")
        
        logger.debug("Grouping data")
        groups_train = [x[0] for x in df_train.select("file_name").distinct().collect()]
        groups_list_train = [df_train.filter(F.col("file_name")==x).collect() for x in groups_train]
        groups_test = [x[0] for x in df_test.select("file_name").distinct().collect()]
        groups_list_test = [df_test.filter(F.col("file_name")==x).collect() for x in groups_test]
        
        X_train = [spark_rdd_to_tensor(group, "neural_data") for group in groups_list_train]
        X_test = [spark_rdd_to_tensor(group, "neural_data") for group in groups_list_test]
        y_train = [spark_rdd_to_tensor(group, "positional_encoding") for group in groups_list_train]
        y_test = [spark_rdd_to_tensor(group, "positional_encoding") for group in groups_list_test]
        
        logger.debug("Performing partial fits")
        losses = []
        for X,y in zip(X_train, y_train):
            with device:
                # Train Embedding
                multi_cebra_model.partial_fit(X, y) #Partial fit doesn`t work on multi session data :(
                
                # Train Decoder
                for X_batch, y_batch in DataLoader(Dataset(X,y), batch_size=batch_size, shuffle=False):
                    decoder.train()
                    decoder_optimizer.zero_grad()
                    embedding = torch.Tensor(transform(multi_cebra_model,X_batch))
                    predicted_embedding = decoder(embedding)
                    loss = criterion(predicted_embedding, y_batch.unsqueeze(1))
                    loss.backward()
                    decoder_optimizer.step()
                    losses.append(loss.item())
        writer.add_scalar(f"AverageDecoderTrainingLoss/{str(it)}", np.mean(losses), counter)
        
        test_losses = []
        for X,y in zip(X_test, y_test):
            with device:
                # Test
                test_embedding = torch.Tensor(transform(multi_cebra_model, X))
                decoder.eval()
                pred = decoder(test_embedding)
                loss = criterion(pred, y.unsqueeze(1))
                test_losses.append(loss.item())
        writer.add_scalar(f"AverageDecoderTestLoss/{str(it)}", np.mean(test_losses), counter)
    multi_cebra_model.save(f"/main/external/models/embedding_{str(it)}.pt")
    torch.save(decoder, f"/main/external/models/decoder_{str(it)}.pt")

INFO:__main__:Started with len(train_indices)=98832 and len(test_indices)=32945
  1%|          | 1025/98832 [2:39:20<253:25:15,  9.33s/it]                      
INFO:__main__:Started with len(train_indices)=98833 and len(test_indices)=32944
  1%|          | 1025/98833 [2:39:51<254:14:21,  9.36s/it]                      
INFO:__main__:Started with len(train_indices)=98833 and len(test_indices)=32944
  1%|          | 1025/98833 [2:40:00<254:28:42,  9.37s/it]                      
INFO:__main__:Started with len(train_indices)=98833 and len(test_indices)=32944
  1%|          | 1025/98833 [2:39:11<253:11:01,  9.32s/it]                      
