In [39]:
%pip install optuna cebra==0.4.0 matplotlib==3.9.2 numpy pandas scipy seaborn umap_learn pyspark python-dotenv tensorboardX

IOStream.flush timed out
Note: you may need to restart the kernel to use updated packages.


In [40]:
import sys
sys.path.append("/main/external/dimensionality-reduction")

In [41]:
from cebra import CEBRA
import torch
import torch.utils
import numpy as np
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from utils.overrides import transform
from utils.utils import pandas_series_to_pytorch, k_folds, split
from tqdm import tqdm
import dotenv
import os
dotenv.load_dotenv()
dotenv.load_dotenv("/main/external/dimensionality-reduction/.env")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter("/main/external/tensorboard_runs")
device

device(type='cuda')

In [42]:
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn

# Data Load

In [43]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

MAX_MEMORY = "120g"

spark = SparkSession \
    .builder \
    .appName("UranusCluster") \
    .config("spark.executor.memory", MAX_MEMORY) \
    .config("spark.driver.memory", MAX_MEMORY) \
    .config("spark.memory.offHeap.enabled",True)\
    .config("spark.memory.offHeap.size","16g")   \
    .getOrCreate()

# Verify the SparkContext
print(spark.sparkContext.getConf().getAll())

[('spark.driver.memory', '120g'), ('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false'), ('spark.app.startTime', '1734693139538'), ('spark.executor.memory', '120

In [44]:
df = spark.read.format("parquet").load("/main/external/data/transformed")\
    .select(["index","neural_data", "positional_encoding", "file_name"])
df.show(5)

2024-12-20 11:20:20 [INFO] Error while sending or receiving.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/py4j/clientserver.py", line 503, in send_command
    self.socket.sendall(command.encode("utf-8"))
ConnectionResetError: [Errno 104] Connection reset by peer
2024-12-20 11:20:20 [INFO] Closing down clientserver connection
2024-12-20 11:20:20 [INFO] Exception while sending command.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/py4j/clientserver.py", line 503, in send_command
    self.socket.sendall(command.encode("utf-8"))
ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/si

+------------+--------------------+--------------------+--------------------+
|       index|         neural_data| positional_encoding|           file_name|
+------------+--------------------+--------------------+--------------------+
|317827579904|[0.00353814595291...|0.014736815295795708|ID18150/Day6/Data...|
|317827707084|[7.32697662897408...|0.002792727523517...|ID18150/Day6/Data...|
|317827579905|[0.00376488942492...|0.014736815295795708|ID18150/Day6/Data...|
|317827707085|[5.27972821146249...|0.002792727523517...|ID18150/Day6/Data...|
|317827579906|[0.00368517413153...| 0.01473313965622418|ID18150/Day6/Data...|
+------------+--------------------+--------------------+--------------------+
only showing top 5 rows



                                                                                

# Models

In [45]:
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

In [46]:
number_of_partitions = 32
batch_size = 128 #Used for both embdding and decoder
n_splits = 16
latent_dimension = 8
test_ratio = 0.2
learning_rate = 1e-4
scores = []

In [47]:
criterion = torch.nn.MSELoss()

class Dataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [48]:
indices = df.select('index').rdd.flatMap(lambda x: x).collect()
files = [x.file_name for x in df.select("file_name").distinct().collect()]
len(files)

                                                                                

51

In [49]:
if not os.path.exists('/main/external/models'):
    os.makedirs('/main/external/models')

In [None]:
for file in tqdm(files):
    logging.info(f"Started training on file {file}")
    decoder = torch.nn.Sequential(
        torch.nn.Linear(latent_dimension, 3),
        torch.nn.GELU(),
        torch.nn.Linear(3,3),
        torch.nn.GELU(),
        torch.nn.Linear(3,1),
        torch.nn.Tanh()
    ).to(device)
    
    logging.debug("Trying to load dataframe into memory")
    du = df.where(F.col('file_name') == file).toPandas()
    
    n_samples = du.shape[0]
    logging.debug(f"There are {n_samples}")
    n_test = int(test_ratio * n_samples)
    
    with device:
        X_test = pandas_series_to_pytorch(du.neural_data[-n_test:], device)
        y_test = pandas_series_to_pytorch(du.positional_encoding[-n_test:], device)
        X_train = pandas_series_to_pytorch(du.neural_data[:n_test], device)
        y_train = pandas_series_to_pytorch(du.positional_encoding[:n_test], device)
    
    #Train embedding
    logging.info("Training embedding")
    embedding = CEBRA(
        model_architecture='offset1-model',
        batch_size=batch_size,
        learning_rate=learning_rate,
        temperature_mode='auto',
        output_dimension=latent_dimension,
        max_iterations=50000,
        min_temperature=0.001,
        distance='cosine',
        conditional='time_delta',
        device=str(device),
        verbose=False,
        time_offsets=10
    )
    embedding.fit(X_train.detach().cpu().numpy(), y_train.detach().cpu().numpy())
    
    logging.info("Training decoder")
    with device:
        # Train Decoder
        decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)
        for i, (X_batch, y_batch) in enumerate(DataLoader(Dataset(X_train,y_train), batch_size=batch_size, shuffle=False)):
            decoder.train()
            decoder_optimizer.zero_grad()
            U = torch.Tensor(transform(embedding,X_batch.detach().cpu().numpy())).to(device)
            y_pred = decoder(U)
            loss = criterion(y_pred, y_batch.unsqueeze(1))
            loss.backward()
            decoder_optimizer.step()
            writer.add_scalar(f"{file}/decoder/train", loss.item(), i)
        
            # Test Decoder
            losses = []
            for (X_batch, y_batch) in DataLoader(Dataset(X_test,y_test), batch_size=batch_size, shuffle=False):
                decoder.eval()
                U = torch.Tensor(transform(embedding,X_batch.detach().cpu().numpy())).to(device)
                y_pred = decoder(U)
                loss = criterion(y_pred, y_batch.unsqueeze(1))
                losses.append(loss.item())
            writer.add_scalar(f"{file}/decoder/test", np.mean(losses), i)
        
        logging.info("Logging reconstruction")
        # Reconstruction comparison
        counter = 0
        for (X_batch, y_batch) in DataLoader(Dataset(X_test,y_test), batch_size=batch_size, shuffle=False):
            decoder.eval()
            U = torch.Tensor(transform(embedding,X_batch.detach().cpu().numpy())).to(device)
            y_pred = decoder(U).flatten()
            for y_true, y_pred in zip(y_batch, y_pred):
                writer.add_scalar(f"{file}/series/prediction_difference", y_true - y_pred, i)
                i += 1
    
    logging.info("Saving models")
    savepath = os.path.join("/main/external/models", file)
    if not os.path.exists(savepath):
        os.makedirs(savepath, exist_ok=True)
    embedding.save(os.path.join(savepath, "embedding.pt"))
    torch.save(decoder, os.path.join(savepath, "decoder.pt"))

  0%|          | 0/51 [00:00<?, ?it/s]2024-12-20 11:20:22 [INFO] Started training on file ID18150/Day6/DataFrame_Imaging_spiking_18150_day6


2024-12-20 11:20:34 [INFO] Training embedding                                   
2024-12-20 11:20:54 [INFO] Training decoder
2024-12-20 11:21:39 [INFO] Logging reconstruction
2024-12-20 11:21:43 [INFO] Saving models
  2%|▏         | 1/51 [01:20<1:06:51, 80.22s/it]2024-12-20 11:21:43 [INFO] Started training on file ID18150/Day12/DataFrame_Imaging_spiking_18150_day12
2024-12-20 11:21:56 [INFO] Training embedding                                   
2024-12-20 11:22:14 [INFO] Training decoder
2024-12-20 11:22:46 [INFO] Logging reconstruction
2024-12-20 11:22:49 [INFO] Saving models
  4%|▍         | 2/51 [02:26<58:56, 72.17s/it]  2024-12-20 11:22:49 [INFO] Started training on file ID18150/Day9/DataFrame_Imaging_dFF_18150_day9
2024-12-20 11:22:51 [ERROR] KeyboardInterrupt while sending command.9 + 5) / 64]
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
        

KeyboardInterrupt: 

Exception in thread "serve-DataFrame" java.net.SocketTimeoutException: Accept timed out
	at java.net.PlainSocketImpl.socketAccept(Native Method)
	at java.net.AbstractPlainSocketImpl.accept(AbstractPlainSocketImpl.java:409)
	at java.net.ServerSocket.implAccept(ServerSocket.java:560)
	at java.net.ServerSocket.accept(ServerSocket.java:528)
	at org.apache.spark.security.SocketAuthServer$$anon$1.run(SocketAuthServer.scala:65)
