In [1]:
%pip install --upgrade optuna cebra>=0.4.0 matplotlib==3.9.2 numpy pandas scipy seaborn umap_learn pyspark python-dotenv tensorboardX optuna-dashboard duckdb-engine PyMySQL

Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
sys.path.append("/main/external/dimensionality-reduction")

In [3]:
from cebra import CEBRA
import torch
import torch.utils
import numpy as np
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from utils.overrides import transform
from utils.utils import pandas_series_to_pytorch, Decoder, Dataset
from functools import partial
from tqdm import tqdm
import optuna
import dotenv
import os
import time
dotenv.load_dotenv()
dotenv.load_dotenv("/main/external/dimensionality-reduction/.env")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter("/main/external/tensorboard_runs")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [4]:
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn

# Data Load

In [5]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

MAX_MEMORY = "120g"

spark = SparkSession \
    .builder \
    .appName("UranusCluster") \
    .config("spark.executor.memory", MAX_MEMORY) \
    .config("spark.driver.memory", MAX_MEMORY) \
    .config("spark.memory.offHeap.enabled",True)\
    .config("spark.memory.offHeap.size","16g")   \
    .getOrCreate()

# Verify the SparkContext
print(spark.sparkContext.getConf().getAll())

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/25 18:32:50 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


[('spark.driver.memory', '120g'), ('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false'), ('spark.driver.host', '1ba80379b8a2'), ('spark.app.id', 'local-173515157

In [6]:
EXPERIMENT_ID = "ID18170/DataFrame_Imaging_dFF_18170_day4"

In [7]:
df = spark.read.format("parquet").load("/main/external/data/transformed")\
    .select(["index","neural_data", "positional_encoding", "file_name"])\
    .where(F.col("file_name") == EXPERIMENT_ID)\
    .where(F.abs(F.col("velocity")) > 0.1)
df.show(5)

                                                                                

+-----------+--------------------+--------------------+--------------------+
|      index|         neural_data| positional_encoding|           file_name|
+-----------+--------------------+--------------------+--------------------+
|51539695483|[-0.0045404555276...|[-0.9954778664531...|ID18170/DataFrame...|
|51539695484|[0.05658219009637...|[-0.9968232372491...|ID18170/DataFrame...|
|51539695485|[-0.0102527663111...|[-0.9979180828244...|ID18170/DataFrame...|
|51539695486|[0.01038869749754...|[-0.9987675465275...|ID18170/DataFrame...|
|51539695487|[-0.0462875999510...|[-0.9993926506164...|ID18170/DataFrame...|
+-----------+--------------------+--------------------+--------------------+
only showing top 5 rows



# Models

In [8]:
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

In [9]:
batch_size = 512 #Used for both embedding and decoder
test_ratio = 0.3

In [10]:
indices = df.select('index').rdd.flatMap(lambda x: x).collect()
files = [x.file_name for x in df.select("file_name").distinct().collect()]
len(files)

                                                                                

1

In [11]:
if not os.path.exists('/main/external/models'):
    os.makedirs('/main/external/models')

In [12]:
criterion = torch.nn.MSELoss(reduction="sum")
def objective(trial, file):
    learning_rate_embedding = trial.suggest_float('learning_rate_embedding', 1e-5, 0.1, log = True)
    learning_rate_decoder = trial.suggest_float('learning_rate_decoder', 1e-5, 0.1, log = True)
    latent_dimension = trial.suggest_int('latent_dimension', 3, 100)
    decoder_epochs = trial.suggest_int('decoder_epochs', 10, 100)
    time_offsets = trial.suggest_int('time_offsets', 5, 100)
    num_hidden_units = trial.suggest_int('num_hidden_units', 2,100)
    embedding_version = trial.suggest_categorical('model',choices=[
        'offset1-model',
        'offset1-model-v2',
        'offset1-model-v3',
        'offset1-model-v4',
        'offset1-model-v5',
        'offset5-model',
        'offset10-model'
    ])
    
    logging.info(f"Started training on file {file}")
    
    logging.debug("Trying to load dataframe into memory")
    du = df.where(F.col('file_name') == file).toPandas()
    
    n_samples = du.shape[0]
    logging.debug(f"There are {n_samples}")
    n_test = int(test_ratio * n_samples)
    
    with device:
        X_test = pandas_series_to_pytorch(du.neural_data[-n_test:], device)
        y_test = pandas_series_to_pytorch(du.positional_encoding[-n_test:], device)
        X_train = pandas_series_to_pytorch(du.neural_data[:n_test], device)
        y_train = pandas_series_to_pytorch(du.positional_encoding[:n_test], device)
    
    #Train embedding
    logging.info("Training embedding")
    embedding = CEBRA(
        model_architecture=embedding_version,
        batch_size=batch_size,
        learning_rate=learning_rate_embedding,
        temperature_mode='auto',
        output_dimension=latent_dimension,
        max_iterations=10000,
        min_temperature=0.001,
        distance='cosine',
        conditional='time_delta',
        device=str(device),
        verbose=False,
        time_offsets=time_offsets,
        num_hidden_units=num_hidden_units
    )
    embedding.fit(X_train.detach().cpu().numpy(), y_train.detach().cpu().numpy())
    
    logging.info("Training decoder")
    decoder = Decoder(latent_dimension).to(device)
    with device:
        # Train Decoder
        decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate_decoder)
        for _ in range(decoder_epochs):
            for i, (X_batch, y_batch) in enumerate(DataLoader(Dataset(X_train,y_train), batch_size=batch_size, shuffle=False)):
                decoder.train()
                decoder_optimizer.zero_grad()
                U = torch.Tensor(transform(embedding,X_batch.detach().cpu().numpy())).to(device)
                y_pred = decoder(U)
                loss = criterion(y_pred, y_batch)
                loss.backward()
                decoder_optimizer.step()
                writer.add_scalar(f"{file}/decoder/train", loss.item(), i)
    
    U = torch.Tensor(transform(embedding, X_test.detach().cpu().numpy())).to(device)
    total_loss = criterion(decoder(U), y_test)
        
    #Calculating metric
    if (hasattr(objective, "best_loss") and total_loss < objective.best_loss) or not hasattr(objective, "best_loss"):
        objective.best_loss = total_loss

        logging.info("Saving models")
        savepath = os.path.join("/main/external/models", file)
        if not os.path.exists(savepath):
            os.makedirs(savepath, exist_ok=True)
        embedding.save(os.path.join(savepath, "embedding.pt"))
        torch.save(decoder, os.path.join(savepath, "decoder.pt"))
    
    return total_loss

In [13]:
for file in tqdm(files):
    objective_ = partial(objective, file = file)
    study = optuna.create_study(
        storage = "mysql+pymysql://root:password@131.220.127.56/optuna",
        load_if_exists=True,
        study_name=file + "_" + str(int(round(time.time() * 1000))),
        direction="minimize"
    )
    study.optimize(objective_, n_trials=1000, n_jobs=2)

  0%|          | 0/1 [00:00<?, ?it/s][I 2024-12-25 18:32:57,204] A new study created in RDB with name: ID18170/DataFrame_Imaging_dFF_18170_day4_1735151576949
2024-12-25 18:32:57 [INFO] Started training on file ID18170/DataFrame_Imaging_dFF_18170_day4
2024-12-25 18:32:57 [INFO] Started training on file ID18170/DataFrame_Imaging_dFF_18170_day4
2024-12-25 18:33:07 [INFO] Training embedding                                   
2024-12-25 18:33:07 [INFO] Training embedding
2024-12-25 18:33:45 [INFO] Training decoder
2024-12-25 18:33:49 [INFO] Saving models
[I 2024-12-25 18:33:49,646] Trial 0 finished with value: 18088.26953125 and parameters: {'learning_rate_embedding': 0.004814383677494288, 'learning_rate_decoder': 0.0001648633734075282, 'latent_dimension': 82, 'decoder_epochs': 20, 'time_offsets': 11, 'num_hidden_units': 94, 'model': 'offset1-model-v3'}. Best is trial 0 with value: 18088.26953125.
2024-12-25 18:33:49 [INFO] Started training on file ID18170/DataFrame_Imaging_dFF_18170_day4
2