# Lipid2Position

In [None]:
import warnings
import scanpy as sc
import pandas as pd
import numpy as np
import os
import sys
import os
import random

from tqdm import tqdm
import matplotlib.pyplot as plt

warnings.simplefilter(action='ignore')
%reload_ext autoreload
%autoreload 2

sys.path.append(os.path.abspath("./lipid2position"))

sc.set_figure_params(frameon=False)
sc.set_figure_params(dpi=200)
sc.set_figure_params(figsize=(4, 4))

np.random.seed(4242)
random.seed(4242)
project_name = "lipid2position"
project_path = "."

In [None]:
import torch
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Working directories: sources under <project_path>/<project_name>, inputs in
# its data/ sub-folder, outputs in <project_path>/results.
src_path = os.path.join(project_path, project_name)
data_path = os.path.join(src_path, "data/")
results_path = os.path.join(project_path, "results")

directories = [src_path, data_path, results_path]
for directory in directories:
    # exist_ok replaces the separate os.path.exists test, which could race
    # with another process creating the same directory.
    os.makedirs(directory, exist_ok=True)

# Format names forwarded to DataHandler in the next cell.
initial_format = 'exp'
final_format = 'norm_exp'

# Keep the rows of the reference atlas that carry a non-null 'name'.
datavignettes = pd.read_parquet("./zenodo/maindata_2.parquet")
df = datavignettes.loc[datavignettes['Sample'] == "ReferenceAtlas", :]
df = df[~df['name'].isna()]

# Split along z: rows above the threshold become `allenbrain_df`, rows below
# stay in `df`; rows exactly at the threshold are dropped from both.
# NOTE(review): (max - min) / 2 is half the z range, not the midpoint; it only
# equals the midpoint when z_index starts at 0 -- confirm this is intended.
zcenter = (df['z_index'].max() - df['z_index'].min()) / 2
allenbrain_df = df.loc[df['z_index'] > zcenter, :]
df = df.loc[df['z_index'] < zcenter, :]
print(df.shape)

# Bounds of the upper-z half; the y bounds are negated (and therefore
# swapped) because the section plots draw -y_index.
allen_global_min_z = allenbrain_df['z_index'].min()
allen_global_max_z = allenbrain_df['z_index'].max()
allen_global_min_y = -allenbrain_df['y_index'].max()
allen_global_max_y = -allenbrain_df['y_index'].min()

In [None]:
from data_handler import DataHandler # it is in "assets"

# Wrap the filtered frame in the project's DataHandler, converting from
# `initial_format` to `final_format`.
data_handler = DataHandler(df, initial_format=initial_format, final_format=final_format)

# Drop the notebook-level name for the input frame.
del df

# AnnData-like object exposed by the handler; displayed as the cell output.
adata = data_handler.adata
adata

In [None]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
import logging
from lipid2position import compute_loss

def train(X, Y, 
          model, optimizer, scheduler,
          loss_type,
          k=None, a=None, b=None,
          num_epochs=100, batch_size=64,
          savepath=None,
          device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
          patience=10):
    """Train `model` on (X, Y) with LR scheduling and early stopping.

    After every epoch the model is evaluated on the whole of `X`; the
    resulting 'standardmse' loss and the predictions are kept as history.

    Args:
        X: Input features, anything `torch.tensor(..., dtype=float32)` accepts.
        Y: Regression targets, converted the same way as `X`.
        model: Module to optimise; assumed to already live on `device`.
        optimizer: Optimizer bound to `model`'s parameters.
        scheduler: Scheduler stepped once per epoch with the mean training
            'standardmse' loss (e.g. ReduceLROnPlateau).
        loss_type: Loss name forwarded to `compute_loss` for the backward pass.
        k, a, b: Extra arguments forwarded unchanged to `compute_loss`.
        num_epochs: Maximum number of epochs.
        batch_size: Mini-batch size of the shuffled training loader.
        savepath: Existing directory receiving `logging.log`, `model.pth`,
            `history_predictions.npy` and `history_losses.npy`. When None,
            nothing is written and file logging is not configured.
        device: Device the batches are moved to.
        patience: Number of consecutive epochs without improvement of the
            training loss after which training stops.

    Returns:
        None. Results are written to `savepath` when it is given.
    """
    if savepath is not None:
        # Drop existing root handlers so basicConfig takes effect on every
        # call and each run logs into its own directory.
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
        logging.basicConfig(filename=os.path.join(savepath, "logging.log"),
                            level=logging.INFO,
                            format="%(asctime)s - %(message)s")
    logging.info("Starting training session")

    # Convert once; the same tensors feed the mini-batch loader and the
    # per-epoch evaluation on the full data.
    X_tensor = torch.tensor(X, dtype=torch.float32)
    Y_tensor = torch.tensor(Y, dtype=torch.float32)
    train_loader = DataLoader(TensorDataset(X_tensor, Y_tensor),
                              batch_size=batch_size, shuffle=True)

    # Early-stopping state and per-epoch history.
    best_loss = float('inf')
    patience_counter = 0
    losses = []
    predictions = []

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = compute_loss(outputs.squeeze(), targets,
                                loss_type=loss_type,
                                k=k, a=a, b=b)
            loss.backward()
            optimizer.step()
            # The monitored metric is always 'standardmse', whatever loss is
            # optimised; it is only read out, so no autograd graph is needed.
            with torch.no_grad():
                total_loss += compute_loss(outputs.squeeze(), targets,
                                           loss_type='standardmse').item()

        total_loss /= len(train_loader)
        # No validation split is used: the scheduler and early stopping both
        # follow the mean training loss.
        scheduler.step(total_loss)

        # Predict the whole X at every epoch.
        model.eval()
        with torch.no_grad():
            epoch_pred = model(X_tensor.to(device))
            epoch_loss = compute_loss(epoch_pred.squeeze(),
                                      Y_tensor.to(device),
                                      loss_type='standardmse').item()
        losses.append(epoch_loss)
        predictions.append(epoch_pred.cpu().numpy())

        logging.info(f'Epoch {epoch+1} (LR: {scheduler.get_last_lr()}), Training Loss: {total_loss}')
        print(f'Epoch {epoch+1} (LR: {scheduler.get_last_lr()}), Training Loss: {total_loss}')

        # Early stopping
        if total_loss < best_loss:
            best_loss = total_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping.")
                logging.info("Early stopping.")
                break

    # Save model and history
    if savepath is not None:
        torch.save(model.state_dict(), os.path.join(savepath, "model.pth"))
        if predictions:
            # Same layout as before: one block of columns per epoch.
            history = np.concatenate(predictions, axis=1)
        else:
            # No epoch ran; store an empty history instead of failing.
            history = np.empty((len(X_tensor), 0), dtype=np.float32)
        np.save(os.path.join(savepath, "history_predictions.npy"), history)
        np.save(os.path.join(savepath, "history_losses.npy"), losses)

    logging.info("Training completed.")

In [None]:
from sklearn.preprocessing import StandardScaler
# Switch: when True, the table stored in "<dataset_name>_latent.h5ad" is
# divided by 1000 and appended column-wise to the expression matrix.
add_LP = False

# NOTE(review): `dataset_name` is not defined anywhere in the visible
# notebook; the add_LP branch raises NameError unless it is set elsewhere.
X = adata.X
if add_LP:
    LP = pd.read_hdf(os.path.join(data_path, f"{dataset_name}_latent.h5ad"), key='table').loc[adata.obs_names,:]
    X = np.concatenate([adata.X, LP/1000], axis=1)

# Standardize the features; `scaler` keeps the fitted training statistics.
scaler = StandardScaler()
XTRAIN = scaler.fit_transform(X)

# Targets are the three spatial index columns.
YTRAIN = adata.obs[['x_index', 'y_index', 'z_index']].values

# Column-wise 10% / 90% quantiles of the targets, moved to `device`.
a, b = (
    torch.quantile(torch.tensor(YTRAIN), q, axis=0).to(device)
    for q in (0.1, 0.9)
)
k = 1

In [None]:
from lipid2position import Lipid2Position
# Network sized to the number of standardized input features.
model = Lipid2Position(XTRAIN.shape[1]).to(device)

# Adam with L2 regularisation through weight_decay; the learning rate is
# divided by 10 after 5 epochs without improvement of the monitored loss.
l2_reg = 1e-3
optimizer = optim.Adam(model.parameters(),
                       lr=0.001,
                       weight_decay=l2_reg)
scheduler = ReduceLROnPlateau(optimizer,
                              'min',
                              factor=0.1,
                              patience=5,
                              verbose=True)

# Named `format_tag` rather than `format` so the builtin format() is not
# shadowed for the rest of the notebook session.
format_tag = final_format.replace("_", "")
latents = "+LPs" if add_LP else ""
loss_type = 'standardmse'
# NOTE(review): `dataset_name` is not defined anywhere in the visible
# notebook; this line raises NameError unless it is set elsewhere.
modelname = f"lipid2position_{dataset_name.split('_')[-1]}_{format_tag}{latents}_{loss_type}"

savepath = os.path.join(results_path, modelname)
# exist_ok replaces the separate os.path.exists check.
os.makedirs(savepath, exist_ok=True)
print(f"Results will be saved in {savepath}")

In [None]:
# Rebuild the network and restore the weights and history files written by
# train() into `savepath`.
model = Lipid2Position(X.shape[1]).to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
model.load_state_dict(torch.load(os.path.join(savepath, "model.pth"),
                                 map_location=device))
history_losses = np.load(os.path.join(savepath, "history_losses.npy"))
history_predictions = np.load(os.path.join(savepath, "history_predictions.npy"))
# Switch to inference mode (also displayed as the cell output).
model.eval()

In [None]:
# Predict the coordinates of the training pixels. The model is fed the
# standardized features (XTRAIN), matching the held-out path where inputs are
# passed through the scaler before prediction; raw X would be on a different
# scale. no_grad avoids building an autograd graph for the whole dataset.
with torch.no_grad():
    outputs = model(torch.tensor(XTRAIN, dtype=torch.float32).to(device)).cpu().numpy()

# One prediction column per spatial axis, in x, y, z order.
adata.obs['x_index_pred'] = outputs[:,0]
adata.obs['y_index_pred'] = outputs[:,1]
adata.obs['z_index_pred'] = outputs[:,2]

In [None]:
# Hex colours grouped by brain division. Rows of `adata.obs` are assigned to a
# division by testing their 'allencolor' value against these lists; the first
# colour of each list is used to draw that division in the scatter plots.
allen_colors = {
    # greens
    'Cerebral cortex': [
        '#1f9d5a', '#40a666', '#2fa850', '#59b363', '#219866', '#188064',
        '#248a5e', '#a4daa4', '#9ad2bd', '#6acbba', '#54bf94', '#a0ee9d',
        '#62d09f', '#009c75', '#8ada87', '#11ad83', '#1aa698', '#7ed04b',
        '#66a83d', '#90eb8d', '#9de79c', '#84ea81', '#61e7b7', '#59daab',
        '#32b825', '#58ba48', '#4fc244', '#97ec93', '#48c83c', '#33b932',
        '#a8ecd3', '#59b947', '#72d569',
    ],
    # blues
    'Cerebral Nuclei': [
        '#80cdf8', '#98d6f9', '#a2b1d8', '#90cbed', '#b3c0df', '#96a7d3',
        '#8599cc', '#80c0e2', '#009fac', '#019399', '#08858c', '#15b0b3',
        '#0d9f91', '#0e9684',
    ],
    # reds
    'InterBrain': [
        '#e64438', '#f2483b', '#ff5547', '#ff909f',
        '#ff7080', '#ff8084', '#ff4c3e', '#ff9b88',
    ],
    # pinks
    'MidBrain': [
        '#ff90ff', '#ff64ff', '#ff7aff', '#ffa6ff',
        '#ffc395', '#ffb3d9', '#ff9bcd', '#ffa5d2',
    ],
    # oranges
    'HindBrain': ['#ffba86', '#ffae6f'],
    # yellows
    'Cerebellum': ['#f0f080', '#fffc91', '#fffdbc'],
    # greys
    'Fiber Tracts': ['#cccccc', '#aaaaaa'],
}

# '#ffffff', # white 



In [None]:
# Bounds of the predicted coordinates on the training brain.
global_min_x = adata.obs['x_index_pred'].min()
global_max_x = adata.obs['x_index_pred'].max()
global_min_z = adata.obs['z_index_pred'].min()
global_max_z = adata.obs['z_index_pred'].max()
# The section plots draw -y, so the y bounds are negated (which swaps the
# roles of max and min), exactly like allen_global_min_y / allen_global_max_y.
global_min_y = -adata.obs['y_index_pred'].max()
global_max_y = -adata.obs['y_index_pred'].min()

In [None]:
# One row per brain division, one column per axis (x, y, z): true index on the
# horizontal axis against predicted index on the vertical axis.
n_regions = len(allen_colors)
fig, axes = plt.subplots(n_regions, 3, figsize=(15, 3*n_regions))

for row, (region, palette) in enumerate(allen_colors.items()):
    # Pixels whose Allen colour belongs to this division.
    region_obs = adata.obs[adata.obs['allencolor'].isin(palette)]

    for column, axis_name in enumerate('xyz'):
        truth = region_obs[f'{axis_name}_index']
        predicted = region_obs[f'{axis_name}_index_pred']

        ax = axes[row][column]
        ax.scatter(truth, predicted, s=1, alpha=0.5, c=palette[0])
        # Label only the first panel of each row with the division name.
        ax.set_ylabel('' if column else region)

        mse = np.mean((truth - predicted)**2)
        ax.set_title(f"MSE: {mse:.3f}", fontsize=10)

        # Identity line spanning the combined x/y limits of the panel.
        panel_limits = [ax.get_xlim(), ax.get_ylim()]
        diagonal = [np.min(panel_limits), np.max(panel_limits)]
        ax.plot(diagonal, diagonal, 'k--', alpha=0.75, zorder=0)

plt.suptitle(f"Predictions on {dataset_name}")
plt.tight_layout()
plt.show()

In [None]:
# One panel per section: predicted positions, plus the true positions mirrored
# along z (allen_global_max_z - z_index) so both are visible in the same panel.
sections_to_plot = np.sort(adata.obs['Section'].unique())

# Keep the original 4x7 layout, but add rows when there are more than 28
# sections so that indexing the axes cannot run out of panels.
n_cols = 7
n_rows = max(4, int(np.ceil(len(sections_to_plot) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 7 * n_rows / 4))
axes = axes.flatten()
dot_size = 2

allen_global_min_z = allenbrain_df['z_index'].min()
allen_global_max_z = allenbrain_df['z_index'].max()
allen_global_min_y = -allenbrain_df['y_index'].max()
allen_global_max_y = -allenbrain_df['y_index'].min()

for ax, section_num in zip(axes, sections_to_plot):
    xx = adata.obs[adata.obs["Section"] == section_num]
    # Predicted coordinates.
    ax.scatter(xx['z_index_pred'], -xx['y_index_pred'],
               c=xx['allencolor'],
               s=dot_size,
               alpha=0.3)
    # True coordinates, mirrored along z.
    ax.scatter(allen_global_max_z-xx['z_index'], -xx['y_index'],
               c=xx['allencolor'],
               s=dot_size,
               alpha=0.3)
    ax.set_title(f"{int(section_num)}")
    ax.axis('off')
    ax.set_aspect('equal')

# Remove the unused panels; slicing does not depend on a loop index, so it
# also works when there are no sections at all.
for unused_ax in axes[len(sections_to_plot):]:
    fig.delaxes(unused_ax)

plt.suptitle(f"Predictions on {dataset_name}")
plt.tight_layout()
plt.show()

### Test on Brain3

In [None]:
# Same preparation as for the reference atlas, applied to the "SecondAtlas"
# sample: keep rows with a non-null 'name', then split along z.
datavignettes = pd.read_parquet("./zenodo/maindata_2.parquet")
df = datavignettes[datavignettes['Sample'] == "SecondAtlas"]
df = df[df['name'].notna()]

# NOTE(review): (max - min) / 2 is half the z range, not the midpoint; it only
# equals the midpoint when z_index starts at 0 -- confirm this is intended.
zcenter = (df['z_index'].max() - df['z_index'].min()) /2
allenbrain_df = df[df['z_index'] > zcenter]
df = df[df['z_index'] < zcenter]
print(df.shape)

# Bounds of the upper-z half; y bounds are negated because plots draw -y.
allen_global_min_z, allen_global_max_z = (
    allenbrain_df['z_index'].min(),
    allenbrain_df['z_index'].max(),
)
allen_global_min_y, allen_global_max_y = (
    -allenbrain_df['y_index'].max(),
    -allenbrain_df['y_index'].min(),
)

# Build the handler for the held-out brain with the same format settings.
data_handler = DataHandler(df, initial_format=initial_format, final_format=final_format)

del df

# Held-out data object; displayed as the cell output.
adata_test = data_handler.adata
adata_test

In [None]:
from sklearn.preprocessing import StandardScaler

# Held-out feature matrix, optionally extended with the latent table in the
# same way as the training features.
XTEST = adata_test.X
if add_LP:
    LPTEST = pd.read_hdf(os.path.join(data_path, f"{dataset_name}_latent.h5ad"), key='table').loc[adata_test.obs_names,:]
    XTEST = np.concatenate([adata_test.X, LPTEST/1000], axis=1)

# Standardize the held-out features with statistics fitted on the training
# matrix X, never on the held-out data itself.
scaler = StandardScaler().fit(X)
XTEST = scaler.transform(XTEST)

# True coordinates of the held-out pixels.
YTEST = adata_test.obs[['x_index', 'y_index', 'z_index']].values

In [None]:
# Predict the coordinates of the held-out pixels. no_grad avoids recording an
# autograd graph for the full-dataset forward pass, which only costs memory
# at inference time.
with torch.no_grad():
    outputs = model(torch.tensor(XTEST, dtype=torch.float32).to(device)).cpu().numpy()

# One prediction column per spatial axis, in x, y, z order.
adata_test.obs['x_index_pred'] = outputs[:,0]
adata_test.obs['y_index_pred'] = outputs[:,1]
adata_test.obs['z_index_pred'] = outputs[:,2]

In [None]:
# Bounds of the predicted coordinates on the held-out brain.
global_min_x = adata_test.obs['x_index_pred'].min()
global_max_x = adata_test.obs['x_index_pred'].max()
global_min_z = adata_test.obs['z_index_pred'].min()
global_max_z = adata_test.obs['z_index_pred'].max()
# The section plots draw -y, so the y bounds are negated (which swaps the
# roles of max and min), exactly like allen_global_min_y / allen_global_max_y.
global_min_y = -adata_test.obs['y_index_pred'].max()
global_max_y = -adata_test.obs['y_index_pred'].min()

# Held-out brain: one row per brain division, one column per axis (x, y, z),
# true index on the horizontal axis against predicted index on the vertical.
n_regions = len(allen_colors)
fig, axes = plt.subplots(n_regions, 3, figsize=(15, 3*n_regions))

for row, (region, palette) in enumerate(allen_colors.items()):
    # Held-out pixels whose Allen colour belongs to this division.
    region_obs = adata_test.obs[adata_test.obs['allencolor'].isin(palette)]

    for column, axis_name in enumerate('xyz'):
        truth = region_obs[f'{axis_name}_index']
        predicted = region_obs[f'{axis_name}_index_pred']

        ax = axes[row][column]
        ax.scatter(truth, predicted, s=1, alpha=0.5, c=palette[0])
        # Label only the first panel of each row with the division name.
        ax.set_ylabel('' if column else region)

        mse = np.mean((truth - predicted)**2)
        ax.set_title(f"MSE: {mse:.3f}", fontsize=10)

        # Identity line spanning the combined x/y limits of the panel.
        panel_limits = [ax.get_xlim(), ax.get_ylim()]
        diagonal = [np.min(panel_limits), np.max(panel_limits)]
        ax.plot(diagonal, diagonal, 'k--', alpha=0.75, zorder=0)

plt.suptitle(f"Predictions on {dataset_name}")
plt.tight_layout()
plt.show()

In [None]:
# Held-out brain, one panel per section: predicted positions, plus the true
# positions mirrored along z (allen_global_max_z - z_index).
sections_to_plot = np.sort(adata_test.obs['Section'].unique())

# Keep the original 4x7 layout, but add rows when there are more than 28
# sections so that indexing the axes cannot run out of panels.
n_cols = 7
n_rows = max(4, int(np.ceil(len(sections_to_plot) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 7 * n_rows / 4))
axes = axes.flatten()
dot_size = 2

allen_global_min_z = allenbrain_df['z_index'].min()
allen_global_max_z = allenbrain_df['z_index'].max()
allen_global_min_y = -allenbrain_df['y_index'].max()
allen_global_max_y = -allenbrain_df['y_index'].min()

for ax, section_num in zip(axes, sections_to_plot):
    xx = adata_test.obs[adata_test.obs["Section"] == section_num]
    # Predicted coordinates; rasterized to keep the saved PDF small.
    ax.scatter(xx['z_index_pred'], -xx['y_index_pred'],
               c=xx['allencolor'],
               s=dot_size,
               alpha=0.3, rasterized=True)
    # True coordinates, mirrored along z.
    ax.scatter(allen_global_max_z-xx['z_index'], -xx['y_index'],
               c=xx['allencolor'],
               s=dot_size,
               alpha=0.3, rasterized=True)
    ax.axis('off')
    ax.set_aspect('equal')

# Remove the unused panels; slicing does not depend on a loop index, so it
# also works when there are no sections at all.
for unused_ax in axes[len(sections_to_plot):]:
    fig.delaxes(unused_ax)

plt.tight_layout()
plt.savefig("recon_holdoutbrain.pdf")
plt.show()

In [None]:
import plotly.graph_objects as go

# Interactive 3-D view of the held-out predictions, one trace per section,
# written to a standalone HTML file.
sections_to_plot = sorted(adata_test.obs['Section'].unique())
dot_size = 6

fig = go.Figure()

for sec in sections_to_plot:
    df = adata_test.obs[adata_test.obs['Section'] == sec]

    # Diagnostics: section size and predicted x before and after filtering.
    print(df.shape)
    print(df['x_index_pred'])
    # Only pixels with a predicted x_index above 250 are drawn.
    df = df.loc[df['x_index_pred'] > 250,:]
    print(df.shape)

    marker_style = dict(size=dot_size, color=df['allencolor'], opacity=1.0)
    section_trace = go.Scatter3d(
        x=df['z_index_pred'],
        y=-df['y_index_pred'],
        z=df['x_index_pred'],
        mode='markers',
        marker=marker_style,
        showlegend=False,
    )
    fig.add_trace(section_trace)

# Hide every axis decoration and make all backgrounds transparent.
hidden_axis = dict(showbackground=False, showgrid=False, zeroline=False, showticklabels=False)
fig.update_layout(
    scene=dict(
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        zaxis=dict(hidden_axis),
        bgcolor='rgba(0,0,0,0)',
    ),
    paper_bgcolor='rgba(0,0,0,0)',
    margin=dict(t=0, l=0, r=0, b=0),
)

fig.write_html(f"{dataset_name}_3d_sections3.html", include_plotlyjs='cdn')