In [1]:
%pip install --upgrade optuna cebra>=0.4.0 matplotlib==3.9.2 numpy pandas scipy seaborn umap_learn pyspark python-dotenv tensorboardX optuna-dashboard duckdb-engine

Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
sys.path.append("/main/external/dimensionality-reduction")

In [3]:
from cebra import CEBRA
import torch
import torch.utils
import numpy as np
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from utils.overrides import transform
from utils.utils import pandas_series_to_pytorch
from functools import partial
from tqdm import tqdm
import optuna
import dotenv
import os
import time
dotenv.load_dotenv()
dotenv.load_dotenv("/main/external/dimensionality-reduction/.env")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter("/main/external/tensorboard_runs")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [4]:
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn

# Data Load

In [5]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

MAX_MEMORY = "120g"

spark = SparkSession \
    .builder \
    .appName("UranusCluster") \
    .config("spark.executor.memory", MAX_MEMORY) \
    .config("spark.driver.memory", MAX_MEMORY) \
    .config("spark.memory.offHeap.enabled",True)\
    .config("spark.memory.offHeap.size","16g")   \
    .getOrCreate()

# Verify the SparkContext
print(spark.sparkContext.getConf().getAll())

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/23 09:45:07 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


[('spark.driver.memory', '120g'), ('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false'), ('spark.app.submitTime', '1734947107366'), ('spark.executor.memory', '12

In [6]:
EXPERIMENT_ID = "ID18170/DataFrame_Imaging_dFF_18170_day4"

In [7]:
df = spark.read.format("parquet").load("/main/external/data/transformed")\
    .select(["index","neural_data", "positional_encoding", "file_name"])\
    .where(F.col("file_name") == EXPERIMENT_ID)
df.show(5)

                                                                                

+-----------+--------------------+--------------------+--------------------+
|      index|         neural_data| positional_encoding|           file_name|
+-----------+--------------------+--------------------+--------------------+
|51539752512|[0.57670706510543...|-0.01445482831236...|ID18170/DataFrame...|
|51539695187|[-0.0095341661944...|-0.00437369369180...|ID18170/DataFrame...|
|51539752513|[0.54644668102264...|-0.01685933642369...|ID18170/DataFrame...|
|51539695188|[-0.0045404555276...|0.010276542098878195|ID18170/DataFrame...|
|51539752514|[0.51589649915695...|-0.00761060092563...|ID18170/DataFrame...|
+-----------+--------------------+--------------------+--------------------+
only showing top 5 rows



# Models

In [8]:
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

In [9]:
number_of_partitions = 32
batch_size = 128 #Used for both embdding and decoder
n_splits = 16
latent_dimension = 3
test_ratio = 0.3
scores = []

In [10]:
criterion = torch.nn.MSELoss()

class Dataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [11]:
indices = df.select('index').rdd.flatMap(lambda x: x).collect()
files = [x.file_name for x in df.select("file_name").distinct().collect()]
len(files)

                                                                                

1

In [12]:
if not os.path.exists('/main/external/models'):
    os.makedirs('/main/external/models')

In [21]:
def objective(trial, file):
    learning_rate_embedding = trial.suggest_float('learning_rate_embedding', 1e-10, 0.1, log = True)
    learning_rate_decoder = trial.suggest_float('learning_rate_decoder', 1e-10, 0.1, log = True)
    embedding_version = trial.suggest_categorical('model',choices=[
        'offset1-model',
        'offset1-model-v1',
        'offset1-model-v2',
        'offset1-model-v3',
        'offset1-model-v4',
        'offset1-model-v5'
    ])
    
    logging.info(f"Started training on file {file}")
    decoder = torch.nn.Sequential(
        torch.nn.Linear(latent_dimension, 3),
        torch.nn.GELU(),
        torch.nn.Linear(3,3),
        torch.nn.GELU(),
        torch.nn.Linear(3,1),
        torch.nn.Tanh()
    ).to(device)
    
    logging.debug("Trying to load dataframe into memory")
    du = df.where(F.col('file_name') == file).toPandas()
    
    n_samples = du.shape[0]
    logging.debug(f"There are {n_samples}")
    n_test = int(test_ratio * n_samples)
    
    with device:
        X_test = pandas_series_to_pytorch(du.neural_data[-n_test:], device)
        y_test = pandas_series_to_pytorch(du.positional_encoding[-n_test:], device)
        X_train = pandas_series_to_pytorch(du.neural_data[:n_test], device)
        y_train = pandas_series_to_pytorch(du.positional_encoding[:n_test], device)
    
    #Train embedding
    logging.info("Training embedding")
    embedding = CEBRA(
        model_architecture=embedding_version,
        batch_size=batch_size,
        learning_rate=learning_rate_embedding,
        temperature_mode='auto',
        output_dimension=latent_dimension,
        max_iterations=50000,
        min_temperature=0.001,
        distance='cosine',
        conditional='time_delta',
        device=str(device),
        verbose=False,
        time_offsets=10
    )
    embedding.fit(X_train.detach().cpu().numpy(), y_train.detach().cpu().numpy())
    
    logging.info("Training decoder")
    with device:
        # Train Decoder
        decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate_decoder)
        for i, (X_batch, y_batch) in enumerate(DataLoader(Dataset(X_train,y_train), batch_size=batch_size, shuffle=False)):
            decoder.train()
            decoder_optimizer.zero_grad()
            U = torch.Tensor(transform(embedding,X_batch.detach().cpu().numpy())).to(device)
            y_pred = decoder(U)
            loss = criterion(y_pred, y_batch.unsqueeze(1))
            loss.backward()
            decoder_optimizer.step()
            writer.add_scalar(f"{file}/decoder/train", loss.item(), i)
        
            # Test Decoder
            losses = []
            for (X_batch, y_batch) in DataLoader(Dataset(X_test,y_test), batch_size=batch_size, shuffle=False):
                decoder.eval()
                U = torch.Tensor(transform(embedding,X_batch.detach().cpu().numpy())).to(device)
                y_pred = decoder(U)
                loss = criterion(y_pred, y_batch.unsqueeze(1))
                losses.append(loss.item())
            writer.add_scalar(f"{file}/decoder/test", np.mean(losses), i)

        logging.info("Logging reconstruction")
        # Reconstruction comparison
        counter = 0
        total_loss = 0
        for (X_batch, y_batch) in DataLoader(Dataset(X_test,y_test), batch_size=batch_size, shuffle=False):
            decoder.eval()
            U = torch.Tensor(transform(embedding,X_batch.detach().cpu().numpy())).to(device)
            y_pred = decoder(U).flatten()
            for y_true, y_pred in zip(y_batch, y_pred):
                diff = y_true - y_pred
                total_loss += diff*diff
                writer.add_scalar(f"{file}/series/prediction_difference_relative", diff/y_true, i)
                i += 1
        
    #Calculating metric
    if (hasattr(objective, "best_loss") and total_loss < objective.best_loss) or not hasattr(objective, "best_loss"):
        objective.best_loss = total_loss

        logging.info("Saving models")
        savepath = os.path.join("/main/external/models", file)
        if not os.path.exists(savepath):
            os.makedirs(savepath, exist_ok=True)
        embedding.save(os.path.join(savepath, "embedding.pt"))
        torch.save(decoder, os.path.join(savepath, "decoder.pt"))
    
    return total_loss

In [None]:
for file in tqdm(files):
    objective_ = partial(objective, file = file)
    study = optuna.create_study(
        storage = "sqlite:///optuna.db",
        load_if_exists=True,
        study_name=file + "_" + str(int(round(time.time() * 1000))),
        direction="minimize"
    )
    study.optimize(objective_, n_trials=100)

  0%|          | 0/1 [00:00<?, ?it/s][I 2024-12-23 09:51:40,173] A new study created in RDB with name: ID18170/DataFrame_Imaging_dFF_18170_day4_1734947500160


2024-12-23 09:51:40 [INFO] Started training on file ID18170/DataFrame_Imaging_dFF_18170_day4
2024-12-23 09:51:50 [INFO] Training embedding                                   
2024-12-23 09:53:18 [INFO] Training decoder
2024-12-23 09:54:04 [INFO] Logging reconstruction
2024-12-23 09:54:08 [INFO] Saving models
[I 2024-12-23 09:54:08,654] Trial 0 finished with value: 476.3794250488281 and parameters: {'learning_rate_embedding': 6.3391233170425225e-06, 'learning_rate_decoder': 2.2964145483889253e-08, 'model': 'offset1-model'}. Best is trial 0 with value: 476.3794250488281.
2024-12-23 09:54:08 [INFO] Started training on file ID18170/DataFrame_Imaging_dFF_18170_day4
2024-12-23 09:54:18 [INFO] Training embedding                                   
