In [None]:
"""
Purpose: Implementation of DiffPool
(pooling graphs in a graph-coarsening manner)

"""

In [1]:
# Re-import edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import datajoint as dj
import trimesh
from tqdm.notebook import tqdm
from pathlib import Path

from os import sys
# Project repositories whose modules are imported flat below (e.g. `import system_utils`).
for _repo_dir in [
    "/datasci_tools/datasci_tools",
    "/machine_learning_tools/machine_learning_tools/",
    "/pytorch_tools/pytorch_tools/",
    "/neuron_morphology_tools/neuron_morphology_tools/",
    "/meshAfterParty/meshAfterParty/",
]:
    sys.path.append(_repo_dir)

from importlib import reload

In [6]:
data_path = Path("../data/")
# iterdir() yields entries in filesystem order; sorting keeps the displayed listing reproducible.
sorted(data_path.iterdir())

[PosixPath('../data/.ipynb_checkpoints'),
 PosixPath('../data/df_cell_type_fine_h01.pbz2'),
 PosixPath('../data/df_morphometrics_h01.pbz2')]

In [7]:
#datasci_tools modules
import system_utils as su
import pandas_utils as pu
import pandas as pd
import numpy as np
import numpy_utils as nu
import networkx_utils as xu
from tqdm_utils import tqdm

#neuron_morphology_tools modules
import neuron_nx_io as nxio

In [8]:
import torch
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.utils import train_test_split_edges
from torch_geometric.data import Data
from torch_geometric import transforms

# for the dataset object
from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.loader import DataLoader
from torch_geometric.data import DenseDataLoader

In [9]:
#pytorch_tools modules
import preprocessing_utils as pret
import geometric_models as gm

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


# Step 0: Choosing the Model

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device = {device}")

# Whether the skeleton-direction angle features stay in the node features.
with_skeleton = True

# Node features that are always excluded.
features_to_delete = ["mesh_volume", "apical_label", "basal_label"]
if not with_skeleton:
    # Without skeletons, the four skeleton-vector angles are excluded as well.
    features_to_delete = features_to_delete + [
        "skeleton_vector_downstream_phi",
        "skeleton_vector_downstream_theta",
        "skeleton_vector_upstream_phi",
        "skeleton_vector_upstream_theta",
    ]

# None = no whitelist; every feature not deleted above is kept.
features_to_keep = None


device = cpu


# Step 1: Loading the Graph Data

In [11]:
from pathlib import Path
import time
def load_data(
    gnn_task,
    data_file,#"df_cell_type_fine.pbz2"
    data_path = Path("./data/m65_full/"),
    data_df = None,
    label_name = None,
    graph_label = "cell_type_fine_label",
    dense_adj = False,
    directed = False,
    features_to_remove = [
        "mesh_volume",
        "apical_label",
        "basal_label",
    ],
    with_skeleton=True,
    device = "cpu",

    #for the standardization
    df_standardization = None,

    cell_type_map = None,

    processed_data_folder_name = None,

    max_nodes = 300,

    #--------- processing the dataset ----
    clean_prior_dataset = False,
    data_source = None,
    verbose = True,

    return_cell_type_map = False,

    features_to_keep = None,
    ):
    """
    Purpose: Build (or reload) the PyTorch Geometric dataset used by the GNN models.

    Steps:
    1) Load the graph dataframe (from ``data_path / data_file`` unless ``data_df`` is given).
    2) Compute per-feature mean/std unless ``df_standardization`` is given.
    3) Build the label -> integer map unless ``cell_type_map`` is given (``None`` -> 0).
    4) Convert every graph into a ``torch_geometric.data.Data`` and collate/save them
       through an ``InMemoryDataset`` rooted at the processed-data folder.

    Parameters
    ----------
    gnn_task : str
        Dataframe column holding, per row, a list of gnn_info dicts; the first entry of
        each non-empty list is used.
    data_file : str
        File name (inside ``data_path``) of the compressed dataframe.
    data_path : str or Path
        Directory holding ``data_file`` and the processed-dataset folder.
    data_df : pandas.DataFrame, optional
        Already-loaded dataframe; skips reading ``data_file``.
    label_name : str or list, optional
        Column(s) excluded when computing the standardization statistics.
    graph_label : str
        Column holding the graph-level label.
    dense_adj : bool
        Make graphs undirected, pad them with ``ToDense(max_nodes)`` and drop graphs
        with more than ``max_nodes`` nodes.
    directed : bool
        Keep the edges as given (no ``ToUndirected``); ignored when ``dense_adj`` is True.
    features_to_remove : list of str
        Node features dropped from ``x``. Never mutated.
    with_skeleton : bool
        If False, the four skeleton-vector angle features are dropped as well.
    device : str
        Currently unused.
    df_standardization : pandas.DataFrame, optional
        Rows ``norm_mean``/``norm_std`` (falls back to the first two rows by position).
    cell_type_map : dict, optional
        Label -> integer map; computed from ``graph_label`` when not given.
    processed_data_folder_name : str, optional
        Sub-folder of ``data_path`` for the processed dataset; derived from the task
        name and mode when not given.
    max_nodes : int
        Node limit, only used in ``dense_adj`` mode.
    clean_prior_dataset : bool
        Remove the processed folder first (best effort) so the dataset is rebuilt.
    data_source : str, optional
        Stored on every ``Data`` object as ``data_source``.
    verbose : bool
        Print progress and timings.
    return_cell_type_map : bool
        Also return the label map that was used.
    features_to_keep : list of str, optional
        If given, only these node features are kept (after removal).

    Returns
    -------
    dataset, or (dataset, cell_type_map) when ``return_cell_type_map`` is True.
    """
    skeleton_features = [
        "skeleton_vector_downstream_phi",
        "skeleton_vector_downstream_theta",
        "skeleton_vector_upstream_phi",
        "skeleton_vector_upstream_theta",
    ]
    if with_skeleton:
        gnn_task_name = f"{gnn_task}_with_skeleton"
        # copy so the shared default list can never be mutated through this name
        features_to_delete = None if features_to_remove is None else list(features_to_remove)
    else:
        gnn_task_name = f"{gnn_task}"
        features_to_delete = list(features_to_remove or []) + skeleton_features

    data_path = Path(data_path)
    if processed_data_folder_name is not None:
        processed_data_folder = data_path / Path(f"{processed_data_folder_name}")
    elif dense_adj:
        processed_data_folder = data_path / Path(f"{gnn_task_name}")
    elif directed:
        processed_data_folder = data_path / Path(f"{gnn_task_name}_directed")
    else:
        processed_data_folder = data_path / Path(f"{gnn_task_name}_no_dense")

    # Defined unconditionally: CellTypeDataset.raw_file_names refers to it even when
    # the dataframe is passed in directly.
    data_filepath = data_path / Path(data_file) if data_file is not None else None

    #1) Load the data
    st = time.time()
    if verbose:
        print(f"Starting to load data")

    if data_df is None:
        data_df = su.decompress_pickle(data_filepath)

    if verbose:
        print(f"Finished loading data: {time.time() - st}")

    #2) Getting the means and standard deviations if not already computed
    if df_standardization is None:
        if verbose:
            print(f"Started calculating normalization")
        # Rows without graphs are skipped, mirroring the guard in CellTypeDataset.process
        # (k[0] would raise IndexError on an empty list).
        all_batch_df = pd.concat([
            nxio.feature_df_from_gnn_info(k[0], return_data_labels_split = False)
            for k in data_df[gnn_task].to_list()
            if len(k) > 0
        ])

        if label_name is not None:
            label_cols = nu.convert_to_array_like(label_name)
            all_batch_df = all_batch_df[[c for c in all_batch_df.columns if c not in label_cols]]

        # will use these to normalize the data
        col_means = all_batch_df.mean(axis=0).to_numpy()
        col_stds = all_batch_df.std(axis=0).to_numpy()
        df_standardization = pd.DataFrame(
            np.array([col_means, col_stds]),
            index=["norm_mean", "norm_std"],
            columns=all_batch_df.columns,
        )

        if verbose:
            print(f"Finished calculating normalization: {time.time() - st}")

    # Prefer the labelled rows; fall back to position when the index labels are missing
    # (presumably a table whose index was not preserved when saved/reloaded).
    try:
        col_means = df_standardization.loc["norm_mean", :].to_numpy()
    except (KeyError, TypeError):
        col_means = df_standardization.iloc[0, :].to_numpy()

    try:
        col_stds = df_standardization.loc["norm_std", :].to_numpy()
    except (KeyError, TypeError):
        col_stds = df_standardization.iloc[1, :].to_numpy()

    #3) Label map: labels -> 1..K, unlabelled (None) -> 0
    if cell_type_map is None:
        # "col == col" is False only for NaN, so this keeps labelled rows
        total_labels = np.unique(
            data_df.query(f"{graph_label}=={graph_label}")[graph_label].to_numpy())
        cell_type_map = {k: i + 1 for i, k in enumerate(total_labels)}
        cell_type_map[None] = 0

    # --------- Functions for loading custom dataset -----
    def pytorch_data_from_gnn_info(
        gnn_info,
        y = None,
        verbose = False,
        normalize = True,
        features_to_delete=None,
        features_to_keep = None,
        data_name = None,
        data_source = None,
        ):
        """
        Purpose: Convert one gnn_info dict into a torch_geometric ``Data`` object.

        Pseudocode:
        1) Build the edge list from the adjacency matrix (transposed; presumably the
           helper returns n_edges x 2, giving the 2 x n_edges layout PyG expects).
        2) Split node features / label; an explicit ``y`` overrides the stored label.
        3) Encode the label with ``cell_type_map`` (any non-string label -> key None).
        4) Standardize features with the enclosing col_means / col_stds.
        5) Drop ``features_to_delete`` and, if given, keep only ``features_to_keep``.
        """
        edgelist = torch.tensor(xu.edgelist_from_adjacency_matrix(
            array = gnn_info["adjacency"],
            verbose = False,
        ).T, dtype=torch.long)

        x, y_raw = nxio.feature_df_from_gnn_info(
            gnn_info,
            return_data_labels_split = True)
        if y is None:
            y = y_raw

        # isinstance (not type ==) so numpy string labels are not mistaken for "no label"
        if not isinstance(y, str):
            y = None

        y_int = np.array(cell_type_map[y]).reshape(1, -1)

        if normalize:
            x = (x - col_means) / col_stds

        # --- keeping or not keeping certain features
        gnn_features = gnn_info["features"]

        keep_idx = np.arange(len(gnn_features))
        if features_to_delete is not None:
            # dtype=int so an empty selection stays a valid column index
            curr_idx = np.array([i for i, k in enumerate(gnn_features)
                                 if k not in features_to_delete], dtype=int)
            keep_idx = np.intersect1d(keep_idx, curr_idx)
            if verbose:
                print(f"keep_idx AFTER DELETE= {keep_idx}")
        if features_to_keep is not None:
            curr_idx = np.array([i for i, k in enumerate(gnn_features)
                                 if k in features_to_keep], dtype=int)
            keep_idx = np.intersect1d(keep_idx, curr_idx)
            if verbose:
                print(f"keep_idx AFTER KEEP = {keep_idx}")

        x = x[:, keep_idx]

        x = torch.tensor(x, dtype=torch.float)
        y = torch.tensor(y_int, dtype=torch.long)

        # every graph must carry exactly one label, shaped (1, 1)
        if len(y) > 1:
            raise Exception(f"y = {y}")

        if y.shape[0] != 1 or y.shape[1] != 1:
            raise Exception(f"y = {y}")

        if verbose:
            print(f"x.shape = {x.shape},y.shape ={y.shape}")

        data_dict = dict(x=x, y=y, edge_index=edgelist)
        if data_name is not None:
            data_dict["data_name"] = data_name
        if data_source is not None:
            data_dict["data_source"] = data_source

        return Data(**data_dict)

    class CellTypeDataset(InMemoryDataset):
        """InMemoryDataset that converts ``data_df`` into Data objects and stores them in ``data.pt``."""
        def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
            super().__init__(root, transform, pre_transform, pre_filter)
            self.data, self.slices = torch.load(self.processed_paths[0])

        @property
        def raw_file_names(self):
            return [] if data_filepath is None else [str(data_filepath.absolute())]

        @property
        def processed_file_names(self):
            return ['data.pt']

        def process(self):
            data_list = []
            for k, y, segment_id, split_index in tqdm(zip(
                data_df[gnn_task].to_list(),
                data_df[graph_label].to_list(),
                data_df["segment_id"],
                data_df["split_index"])):

                # rows without a graph are skipped
                if len(k) > 0:
                    data_list.append(pytorch_data_from_gnn_info(
                        k[0],
                        y=y,
                        features_to_delete=features_to_delete,
                        features_to_keep = features_to_keep,
                        data_name = f"{segment_id}_{split_index}",
                        data_source = data_source,
                        verbose = False))

            if self.pre_filter is not None:
                data_list_final = []
                for data in data_list:
                    # graphs the filter cannot evaluate are dropped (best effort)
                    try:
                        if self.pre_filter(data):
                            data_list_final.append(data)
                    except Exception:
                        continue
                data_list = data_list_final

            for j, d in enumerate(data_list):
                if d.y.shape[0] != 1 or d.y.shape[1] != 1:
                    raise Exception(f"{j}")

            if self.pre_transform is not None:
                data_list_final = []
                for j, data in enumerate(data_list):
                    # NOTE: any failure here -- including the label-shape check --
                    # drops that graph rather than aborting the whole build.
                    try:
                        curr_t = self.pre_transform(data)
                        if curr_t.y.shape[0] != 1 or curr_t.y.shape[1] != 1:
                            raise Exception(f"{j}, data = {curr_t}")
                        data_list_final.append(curr_t)
                    except Exception:
                        continue
                data_list = data_list_final

            for j, d in enumerate(data_list):
                if d.y.shape[0] != 1 or d.y.shape[1] != 1:
                    raise Exception(f"{j}, data = {d}")

            data, slices = self.collate(data_list)
            torch.save((data, slices), self.processed_paths[0])

    # --- creating the folder for the dataset --
    if clean_prior_dataset:
        try:
            su.rm_dir(processed_data_folder)
        except Exception:
            # best effort: presumably the folder does not exist yet
            pass

    processed_data_folder.mkdir(parents=True, exist_ok=True)

    # a) Processing filters
    class MyFilter(object):
        """Keep only graphs with at most ``max_nodes`` nodes so ToDense can pad them."""
        def __call__(self, data):
            return data.num_nodes <= max_nodes

    pre_filter = None
    if dense_adj:
        transform_list = [
            transforms.ToUndirected(),
            T.ToDense(max_nodes),
        ]
        pre_filter = MyFilter()
    elif directed:
        transform_list = []
    else:
        transform_list = [transforms.ToUndirected()]

    transform_norm = transforms.Compose(transform_list)

    # b) Creating the Dataset
    dataset = CellTypeDataset(
        processed_data_folder.absolute(),
        pre_transform = transform_norm,
        pre_filter = pre_filter,
    )

    if return_cell_type_map:
        return dataset, cell_type_map
    return dataset

# --- Running the preprocessing --

In [12]:
# Optional precomputed artifacts; None means load_data computes them. To reuse saved ones:
#   df_standardization = pu.csv_to_df("../data/cell_type_normalization_df.csv")
#   cell_type_map = su.decompress_pickle("../data/cell_type_map")
df_standardization = None
cell_type_map = None  # no trailing comma: `None,` would make this the tuple (None,)

In [None]:
gnn_task = "cell_type_fine"
label_name = None
graph_label = "cell_type_fine_label"
data_file = "df_cell_type_fine.pbz2"
dense_adj = False
directed = False

data_path = Path("../data/m65_full/")
data_source = "m65"
processed_data_folder_name = f"{gnn_task}_{data_source}"

# Also return the label map: the h01 cell extends it and the embedding cell decodes
# predictions with it, so it must be the dict load_data actually used.
m65_dataset, cell_type_map = load_data(
    gnn_task = gnn_task,
    data_file = data_file,
    data_path = data_path,
    label_name = label_name,
    graph_label = graph_label,
    dense_adj = dense_adj,
    directed = directed,
    features_to_remove = [
        "mesh_volume",
        "apical_label",
        "basal_label",
    ],
    with_skeleton=with_skeleton,
    device = "cpu",

    #for the standardization
    df_standardization = df_standardization,
    cell_type_map = cell_type_map,

    processed_data_folder_name = processed_data_folder_name,

    max_nodes = 300,
    clean_prior_dataset = True,
    data_source = data_source,
    return_cell_type_map = True,
)

In [None]:
# Spot-check the node-feature matrix of one processed m65 graph (requires at least 1001 graphs).
m65_dataset[1000].x

# -- a) creating the h01 dataset

In [None]:
# h01 uses two labels that m65 does not; register them without clobbering an existing id.
h01_extra_labels = {"Unsure E": 32, "Unsure I": 33}
clashing_ids = (
    {v for k, v in cell_type_map.items() if k not in h01_extra_labels}
    & set(h01_extra_labels.values())
)
assert not clashing_ids, f"label ids {clashing_ids} already used in cell_type_map"
cell_type_map.update(h01_extra_labels)
cell_type_map

In [None]:
# ---- h01 configuration ----
data_source = "h01"
data_path = Path("../data/h01_full")
data_file = "df_cell_type_fine_h01.pbz2"
gnn_task = "cell_type_fine"
graph_label = "cell_type_fine_label"
label_name = None
dense_adj, directed = False, False
processed_data_folder_name = f"{gnn_task}_h01"

# h01 is standardized with a precomputed table instead of its own statistics.
# NOTE(review): this path differs from the "../data/..." one noted in the overrides cell — confirm.
df_standardization = pu.csv_to_df("../cell_type_normalization_df.csv")

h01_dataset = load_data(
    gnn_task=gnn_task,
    data_file=data_file,
    data_path=data_path,
    graph_label=graph_label,
    dense_adj=False,
    directed=False,
    df_standardization=df_standardization,
    cell_type_map=cell_type_map,
    processed_data_folder_name=processed_data_folder_name,
    clean_prior_dataset=True,
    data_source=data_source,
)

In [None]:
# Model input/output sizes are taken from the m65 dataset.
dataset_num_node_features, dataset_num_classes = (
    m65_dataset.num_node_features,
    m65_dataset.num_classes,
)

# Loading the Model

In [None]:
# Launch TensorBoard on the GCNFlat log directory.
# NOTE(review): absolute container path; adjust if the repo is mounted elsewhere.
%load_ext tensorboard
%tensorboard --logdir /pytorch_tools/Applications/Cell_Types_GNN/tensorboard/GCNFlat --bind_all

In [None]:
import geometric_models as gm

In [None]:
import general_utils as gu
# Shared architecture defaults.
architecture_kwargs_global = {
    "n_hidden_channels": 32,
    "global_pool_type": "mean",
    "n_layers": 2,
}

# Settings for this run (presumably override the defaults when merged — see gu.merge_dicts).
architecture_kwargs_curr = {
    "n_hidden_channels": 64,
    "global_pool_type": "mean",
    "n_layers": 2,
}

architecture_kwargs = gu.merge_dicts([architecture_kwargs_global, architecture_kwargs_curr])
architecture_kwargs

In [None]:
model_name = "GCNFlat"
checkpoint_dir = Path("../model_checkpoints") / model_name

# Checkpoint names encode every architecture kwarg plus training settings; the skeleton
# flag comes from the config cell rather than being hard-coded.
arch_suffix = "_".join(f"{k}_{v}" for k, v in architecture_kwargs.items())
winning_name = f"{model_name}_{arch_suffix}_lr_0.01_with_skeleton_{with_skeleton}"

epoch = 95
winning_dir = checkpoint_dir / f"{winning_name}_checkpoints"
winning_filepath = winning_dir / f"{winning_name}_epoch_{epoch}"
winning_filepath.exists()

In [None]:
model = getattr(gm, model_name)(
    dataset_num_node_features=dataset_num_node_features,
    dataset_num_classes=dataset_num_classes,
    **architecture_kwargs,
)

# Map tensors to CPU: the model is never moved to a GPU in this notebook and inference
# runs on "cpu", so a checkpoint saved from a CUDA run must be remapped on load.
checkpoint = torch.load(winning_filepath, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

model.eval()

In [None]:
# Inspect the loaded weights: one list of parameter tensors, wrapped for display.
[[*model.parameters()]]

# Running the Embeddings

In [None]:
# m65 graphs first, then h01 graphs.
dataset = m65_dataset + h01_dataset
for _name, _ds in [("m65_dataset", m65_dataset), ("h01_dataset", h01_dataset), ("dataset", dataset)]:
    print(f"{_name} = {len(_ds)}")

In [None]:
batch_size = 64

# No shuffling: rows of the resulting embeddings follow the dataset order.
all_data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
all_data_loader

In [None]:
device = "cpu"
model.eval()
embeddings = []
labels = []
data_names = []
data_sources = []

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for data in tqdm(all_data_loader):
        data = data.to(device)
        if model_name == "DiffPool":
            out, gnn_loss, cluster_loss = model(data)
        elif model_name == "TreeLSTM":
            n = data.x.shape[0]
            # zero initial hidden/cell states, created on the same device as the batch
            h = torch.zeros((n, architecture_kwargs["n_hidden_channels"]), device=device)
            c = torch.zeros((n, architecture_kwargs["n_hidden_channels"]), device=device)
            out = model(
                data,
                h = h,
                c = c,
                embeddings = data.x,
            )
        else:
            out = model(data)

        out_array = out.detach().cpu().numpy()
        out_labels = data.y.cpu().numpy().reshape(-1)

        embeddings.append(out_array)
        labels.append(out_labels)
        data_names.append(data.data_name)
        data_sources.append(data.data_source)

embeddings = np.vstack(embeddings)
labels = np.hstack(labels)
data_names = np.hstack(data_names)
data_sources = np.hstack(data_sources)

In [None]:
import general_utils as gu
import pandas_utils as pu
# ctu must be in scope before the e/i helpers below are applied; without this import
# the cell fails with NameError on a fresh kernel.
import cell_type_utils as ctu

embedding_df = pd.DataFrame(embeddings)
embedding_df["cell_type"] = labels
embedding_df["cell_type_predicted"] = np.argmax(embeddings, axis=1)
embedding_df["data_name"] = data_names
embedding_df["data_source"] = data_sources

# Invert the label encoding; the None key (label 0) decodes to "Unknown".
decoder_map = {v: ("Unknown" if k is None else k) for k, v in cell_type_map.items()}


def _e_i_from_column(row, column):
    """Map the fine cell type in ``row[column]`` through ``ctu.e_i_label_from_cell_type_fine``.

    None is returned unchanged. NOTE(review): decoded labels include "Unknown", which is
    forwarded to the helper as-is — confirm it handles that value.
    """
    ct = row[column]
    if ct is None:
        return ct
    return ctu.e_i_label_from_cell_type_fine(ct)


def e_i_label(row):
    """E/I label for the annotated cell type."""
    return _e_i_from_column(row, "cell_type")


def e_i_label_predicted(row):
    """E/I label for the predicted cell type."""
    return _e_i_from_column(row, "cell_type_predicted")


embedding_df["cell_type"] = pu.new_column_from_dict_mapping(
    embedding_df,
    decoder_map,
    column_name = "cell_type"
)
embedding_df["e_i"] = pu.new_column_from_row_function(
    embedding_df,
    e_i_label,
)

embedding_df["cell_type_predicted"] = pu.new_column_from_dict_mapping(
    embedding_df,
    decoder_map,
    column_name = "cell_type_predicted"
)
embedding_df["e_i_predicted"] = pu.new_column_from_row_function(
    embedding_df,
    e_i_label_predicted,
)


In [None]:
# Embedding dimensions plus decoded labels, E/I labels, and provenance columns.
embedding_df

# Step 5: Doing the Predictions

In [None]:
import string_utils as stru
# Embedding dimensions are the integer-named columns created by pd.DataFrame(embeddings);
# the metadata columns have string names. bool is excluded because it subclasses int.
embed_cols = [
    k for k in embedding_df.columns
    if isinstance(k, (int, np.integer)) and not isinstance(k, bool)
]
np.array(embed_cols)

# Step 6: Plotting the Embeddings

In [None]:
"""
Get the mask of m65 vs h01
extract the data names
Collect the X value and the y values

-> there might be certain masks want to apply
"""

In [None]:
import datajoint_utils as du
import cell_type_utils as ctu

In [None]:
%matplotlib inline
import visualizations_ml as vml
# NOTE(review): n_components is not used by the UMAP cells below (they pass
# n_components=2 through `kwargs`); confirm whether this is still needed.
n_components = 3
import dimensionality_reduction_ml as dru
import pandas_ml as pdml


In [None]:
# Drop rows labelled "Unknown" or exactly "Unsure" (the h01 "Unsure E"/"Unsure I" rows are kept).
embedding_df_known = (
    embedding_df[~embedding_df["cell_type"].isin(["Unknown", "Unsure"])]
    .reset_index(drop=True)
)
embedding_df_known

# a) UMAP on embedding (0.5 min dist)

In [None]:
# Dimensionality-reduction method and its keyword arguments.
method = "UMAP"
kwargs = {"n_components": 2, "min_dist": 0.5}

In [None]:
# Restrict the projection to one dataset; set data_source = None to use both.
data_source = "h01"

df_input = (
    embedding_df_known
    if data_source is None
    else embedding_df_known.query("data_source == @data_source").reset_index(drop=True)
)

In [None]:
# Project the embedding dimensions with the configured method.
X_trans = dru.dimensionality_reduction_by_method(
    X=df_input[embed_cols].to_numpy().astype("float"),
    method=method,
    **kwargs,
)

In [None]:
trans_cols = [f"{method}_{k}" for k in range(X_trans.shape[1])]
# Drop projection columns from an earlier run of this cell so re-running it replaces
# them instead of appending duplicate columns.
df_input = pd.concat(
    [
        df_input.drop(columns=trans_cols, errors="ignore"),
        pd.DataFrame(X_trans, columns=trans_cols),
    ],
    axis=1,
)
df_input

In [None]:
# Alias (not a copy) used by the plotting cell.
df_plot = df_input
df_plot

In [None]:
# Scatter of the projection, coloured and text-labelled by annotated fine cell type.
vml.plot_df_scatter_classification(
    X=df_plot[trans_cols].to_numpy().astype("float"),
    y=df_plot["cell_type"].to_numpy(),
    ndim=len(trans_cols),
    target_to_color=ctu.cell_type_fine_color_map,
    use_labels_as_text_to_plot=True,
    title=method,
)

# Sampling the Space

In [None]:
from dataInterfaceH01 import data_interface as hdju_h01

In [None]:
# h01 graphs the model predicts as Martinotti.
data_source = "h01"
embedding_df_known.query("data_source == @data_source and cell_type_predicted == 'Martinotti'")

In [None]:
embedding_df_known[embedding_df_known["data_source"] == "h01"]

In [None]:
# h01 "Unsure E" rows whose e_i label disagrees with e_i_predicted
embedding_df_known.query(
    "data_source == 'h01' and e_i != e_i_predicted and cell_type == 'Unsure E'"
)

In [None]:
from importlib import reload
import trimesh_utils as tu
import human_utils as hu
# Pick up edits to trimesh_utils without restarting the kernel.
tu = reload(tu)

# Visualizing Wrong ones

In [None]:
"""
Excitatory cells that are misclassified are mostly the aspiny ones

"""

In [None]:
# Plot the axon/dendrite skeletons of one h01 example for manual inspection.
hdju_h01.plot_axon_dendrite_skeletons("46836446896_0")

# Finding the 5 closest cells

In [None]:
# m65 graphs annotated as BC; the neighbour search below runs within this subset.
data_source = "m65"
data_source_df = embedding_df_known.query("data_source == @data_source and cell_type == 'BC'")
data_source_df

In [None]:
# NOTE(review): the trailing comment calls this an IT cell, but data_source_df above is
# filtered to cell_type == 'BC', so this lookup may return no rows — confirm the node.
node_name = "864691134884769786_0" # IT cell

X = data_source_df[embed_cols]
# Embedding vector(s) for the chosen node within the filtered subset.
data_point = data_source_df.query(f"data_name=='{node_name}'")[embed_cols].to_numpy()
data_point

In [None]:
from dataInterfaceMinnie65 import data_interface as hdju_m65
import gnn_embedding_utils as gnneu

# Query node for the neighbour search (previously tried: "864691136925825354_0").
node_name = "864691135272313361_0"

gnneu.closest_neighbors_in_embedding_df(
    df=data_source_df,
    data_name=node_name,
    n_neighbors=5,
    verbose=True,
    plot=True,
    plotting_func=hdju_m65.plot_axon_dendrite_skeletons,
)