## Existence Of Node Clusters

Here we demonstrate that in a random forest trained on a dataset, the nodes can be reasonably organized into clusters.

First, we must train or load a forest:

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc

import sys
# sys.path.append('/localscratch/bbrener1/rusty_forest_v3/src')
sys.path.append('../src')
import tree_reader as tr 
import lumberjack

data_location = "../data/aging_brain/"

forest = tr.Forest.load(data_location + 'full_clustering')
forest.arguments


In [None]:
len(forest.output_features)

A Random Forest is a collection of decision trees, and a decision tree is a collection of individual decision points, commonly known as "Nodes".

To understand Random Forests and Decision Trees, it is important to understand how Nodes work. Each individual node is a (very weak) regressor, i.e. each Node makes a prediction based on a rule like "If Gene 1 has expression > 10, Gene 2 will have expression < 5", or "If a house is < 5 miles from a school, it will cost > $100,000". A very important property of each node, however, is that it can also have children, which are other nodes. When a node makes a prediction like "If Gene 1 has expression > 10 then Gene 2 has expression < 5", it can pass all the samples for which Gene 1 > 10 to one of its children, and all the samples for which Gene 1 <= 10 to the other child. After that, each of its children can make a different prediction, which results in compound rules.
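As a rough illustration of the paragraph above, here is a minimal sketch of a single node acting as a regressor. The gene values and the threshold are made-up toy numbers, not taken from the dataset used in this notebook, and the code does not use the `tree_reader` API.

```python
import numpy as np

# Toy data (hypothetical): 6 samples, expression of two genes.
gene_1 = np.array([12.0, 15.0, 11.0, 3.0, 2.0, 8.0])
gene_2 = np.array([1.0, 2.0, 3.0, 9.0, 8.0, 7.0])

# The node's rule: split the samples on "Gene 1 > 10".
threshold = 10.0
right_mask = gene_1 > threshold   # samples passed to one child
left_mask = ~right_mask           # samples passed to the other child

# Each child predicts the mean of the target among the samples it received.
right_prediction = gene_2[right_mask].mean()
left_prediction = gene_2[left_mask].mean()

print(right_prediction, left_prediction)
```

Each child could then apply its own rule to the samples it received, which is how compound rules arise.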

This is how a decision tree is formed. A decision tree with a depth of 2 might contain a rule like "If Gene 1 > 10 AND Gene 3 > 10, THEN Gene 2 and Gene 4 are both < 2", which would represent one of its "Leaf" nodes. Leaf nodes are nodes with no children.

Individual decision trees, then, are somewhat weak predictors, but they're better than individual nodes. In order to improve the performance of decision trees, we can construct a Random Forest. To construct a random forest, we train many decision trees, each on a bootstrap of the dataset.

If many decision trees are combined and their predictions averaged together, you have a Random Forest, which is a pretty good kind of regressor. 
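The bootstrap-and-average idea can be sketched in a few lines. This is only an illustration under simplifying assumptions: the data are synthetic, each "tree" is a depth-1 stump, and the helper names `fit_stump` and `predict_stump` are hypothetical rather than part of the forest implementation used in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy regression problem (hypothetical): y is a noisy step function of x.
x = rng.uniform(0.0, 1.0, size=200)
y = (x > 0.5).astype(float) + rng.normal(0.0, 0.3, size=200)

def fit_stump(x, y):
    """Fit a depth-1 tree: pick the threshold minimizing squared error."""
    best = None
    # Skip the largest value so the right-hand side is never empty.
    for t in np.unique(x)[:-1]:
        left, right = y[x <= t], y[x > t]
        sse = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
        if best is None or sse < best[0]:
            best = (sse, t, left.mean(), right.mean())
    return best[1:]

def predict_stump(stump, x):
    t, left_mean, right_mean = stump
    return np.where(x <= t, left_mean, right_mean)

# Train each stump on a bootstrap (sampling with replacement) of the data.
stumps = []
for _ in range(25):
    idx = rng.integers(0, len(x), size=len(x))
    stumps.append(fit_stump(x[idx], y[idx]))

# The forest prediction is the average of the individual predictions.
x_new = np.array([0.1, 0.9])
forest_prediction = np.mean([predict_stump(s, x_new) for s in stumps], axis=0)
print(forest_prediction)
```

Because every stump sees a slightly different resampling of the data, their thresholds and means differ a little, and averaging smooths out those differences.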

A practical demonstration might help:

In [None]:
forest.reset_split_clusters()
forest.interpret_splits(depth=5,mode='additive_mean',neighborhood_fraction=.1,metric='cosine',pca=100,relatives=True,k=100,resolution=2)

So now that we know that random forests are collections of ordered nodes, we can examine a more interesting question: do certain nodes occur repeatedly in the forest, despite operating on bootstrapped samples? 

To examine this question, we must first understand the different ways of describing a node. I think there are generally three helpful ways of looking at a node:

* **Node Sample Encoding**: A binary vector with one entry per sample under consideration. A 0 (false) means the sample is absent from the node. A 1 (true) means the sample is present in the node.

* **Node Mean Encoding**: A float vector with one entry per target under consideration. Each value is the mean of that target across all samples in this node. This is the node's prediction for samples that occur in it.

* **Node Additive Encoding**: A float vector with one entry per target under consideration. Each value is THE DIFFERENCE between the mean value for that target in THIS NODE and the mean value for that target IN THE PARENT of this node. For root nodes, which have no parents, the additive encoding is simply the mean value across the entire dataset (as if the mean of a hypothetical parent had been 0). This encoding represents the marginal effect of each node.
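To make the three encodings concrete, here is a small sketch that computes each of them by hand for one node. The target matrix and the sample memberships are invented toy values, and this is not how `forest.node_representation` is implemented; it only mirrors the definitions above.

```python
import numpy as np

# Toy data (hypothetical): 5 samples x 2 targets.
targets = np.array([
    [1.0, 10.0],
    [3.0, 12.0],
    [5.0, 14.0],
    [7.0, 16.0],
    [9.0, 18.0],
])

# Which samples fall into the parent node and into one of its children.
parent_samples = np.array([True, True, True, True, False])
node_samples = np.array([True, True, False, False, False])

# Node Sample Encoding: one 0/1 entry per sample.
sample_encoding = node_samples.astype(int)

# Node Mean Encoding: mean of each target over the samples in the node.
mean_encoding = targets[node_samples].mean(axis=0)

# Node Additive Encoding: node mean minus parent mean, per target.
parent_mean = targets[parent_samples].mean(axis=0)
additive_encoding = mean_encoding - parent_mean

print(sample_encoding)    # [1 1 0 0 0]
print(mean_encoding)      # [ 2. 11.]
print(additive_encoding)  # [-2. -2.]
```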

We should examine if there are any common patterns that appear if we encode many nodes from a forest using each of these representations:

In [None]:
# Here we plot the sample representations of nodes. 
# This generates a set of figures demonstrating the existence of node clusters

from sklearn.decomposition import PCA

sample_encoding = forest.node_representation(forest.nodes(depth=5,root=False),mode='sample')
reduced_sample = PCA(n_components=100).fit_transform(sample_encoding.T)
reduced_node = PCA(n_components=100).fit_transform(sample_encoding)

print(sample_encoding.shape)
print(reduced_sample.shape)
print(reduced_node.shape)

from scipy.cluster.hierarchy import linkage,dendrogram

sample_agglomeration = dendrogram(linkage(reduced_sample, metric='cosine', method='average'), no_plot=True)['leaves']
node_agglomeration = dendrogram(linkage(reduced_node, metric='cosine', method='average'), no_plot=True)['leaves']

plt.figure()
plt.title("Figure 1: Sample Presence in Node (Two-Way Agglomerated)")
plt.imshow(sample_encoding[node_agglomeration].T[sample_agglomeration].T,cmap='binary',aspect='auto',interpolation='none')
plt.xlabel("Samples")
plt.ylabel("Nodes")
plt.colorbar()
plt.tight_layout()
plt.show()

# And here we sort the nodes after they have been clustered (more on the clustering procedure in a bit)

node_cluster_sort = np.argsort([n.split_cluster for n in forest.nodes(depth=5,root=False)])

plt.figure()
plt.title("Figure S1: Sample Presence in Node (Clustered)")
plt.imshow(sample_encoding[node_cluster_sort].T[sample_agglomeration].T,cmap='binary',aspect='auto',interpolation='none')
plt.xlabel("Samples")
plt.ylabel("Nodes")
plt.colorbar()
plt.tight_layout()
plt.show()


In [None]:
from sklearn.decomposition import PCA

sample_encoding = forest.node_representation(forest.nodes(depth=5,root=False),mode='sister')
reduced_sample = PCA(n_components=100).fit_transform(sample_encoding.T)
reduced_node = PCA(n_components=100).fit_transform(sample_encoding)

print(sample_encoding.shape)
print(reduced_sample.shape)
print(reduced_node.shape)

from scipy.cluster.hierarchy import linkage,dendrogram

sample_agglomeration = dendrogram(linkage(reduced_sample, metric='cosine', method='average'), no_plot=True)['leaves']
node_agglomeration = dendrogram(linkage(reduced_node, metric='cosine', method='average'), no_plot=True)['leaves']

cluster_node_sort = np.argsort([n.split_cluster for n in forest.nodes(depth=5,root=False)])

plt.figure()
plt.title("Figure 1SC: Sample Presence in Node vs Sister (Two-Way Agglomerated)")
plt.imshow(sample_encoding[node_agglomeration].T[sample_agglomeration].T,cmap='bwr',aspect='auto',interpolation='none')
plt.xlabel("Samples")
plt.ylabel("Nodes")
plt.colorbar()
plt.tight_layout()
plt.show()

plt.figure()
plt.title("Figure 1SC: Sample Presence in Node vs Sister (Clustered By Gain)")
plt.imshow(sample_encoding[cluster_node_sort].T[sample_agglomeration].T,cmap='bwr',aspect='auto',interpolation='none')
plt.xlabel("Samples")
plt.ylabel("Nodes")
plt.colorbar()
plt.tight_layout()
plt.show()


In [None]:
# Here we construct and agglomerate the additive gain representation


feature_encoding = forest.node_representation(forest.nodes(depth=5,root=False),mode='additive_mean')
reduced_feature = PCA(n_components=100).fit_transform(feature_encoding.T)
reduced_node = PCA(n_components=100).fit_transform(feature_encoding)

feature_agglomeration = dendrogram(linkage(reduced_feature, metric='cosine', method='average'), no_plot=True)['leaves']
node_agglomeration = dendrogram(linkage(reduced_node, metric='cosine', method='average'), no_plot=True)['leaves']


In [None]:
# Here we plot the additive gain representation 

print(feature_encoding.shape)

plt.figure()
plt.title("Figure S2 a: Target Gain in Node (Double-Agglomerated)")
plt.imshow(feature_encoding[node_agglomeration].T[feature_agglomeration].T,cmap='bwr',interpolation='none',aspect='auto',vmin=-2,vmax=2)
plt.xlabel("Features")
plt.ylabel("Nodes")
plt.colorbar(label="Node Target Mean - Parent Target Mean")
plt.tight_layout()
plt.show()

plt.figure()
plt.title("Figure S2 b: Target Gain in Node (Clustered)")
plt.imshow(feature_encoding[node_cluster_sort].T[feature_agglomeration].T,cmap='bwr',interpolation='none',aspect='auto',vmin=-2,vmax=2)
plt.xlabel("Features")
plt.ylabel("Nodes")
plt.colorbar(label="Node Target Mean - Parent Target Mean")
plt.tight_layout()
plt.show()

Finally, we can look at silhouette scores for various node encodings to get a feel for whether we are clustering the nodes adequately and whether the clusters meaningfully exist.
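For reference, the silhouette score of a point is (b - a) / max(a, b), where a is its mean distance to the other members of its own cluster and b is its mean distance to the members of the nearest other cluster. The sketch below computes this by hand with cosine distance on made-up toy encodings. It is only meant to show what the score measures; the cells that follow use `sklearn.metrics.silhouette_samples`.

```python
import numpy as np

# Toy encodings (hypothetical): 4 nodes x 2 features, in two clusters.
encodings = np.array([[1.0, 0.1], [1.0, 0.2], [0.1, 1.0], [0.2, 1.0]])
labels = np.array([0, 0, 1, 1])

# Pairwise cosine distance: 1 - cosine similarity.
unit = encodings / np.linalg.norm(encodings, axis=1, keepdims=True)
distance = 1.0 - unit @ unit.T

scores = np.zeros(len(labels))
for i, label in enumerate(labels):
    same = labels == label
    same[i] = False  # exclude the point itself from its own cluster
    a = distance[i, same].mean()
    b = min(
        distance[i, labels == other].mean()
        for other in set(labels)
        if other != label
    )
    scores[i] = (b - a) / max(a, b)

print(scores)
```

Scores near 1 indicate a node sits well inside its cluster, scores near 0 indicate it lies between clusters, and negative scores suggest it may be assigned to the wrong cluster.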

In [None]:
# Silhouette Plots For Node Clusters 

from sklearn.metrics import silhouette_samples, silhouette_score

node_labels = np.array([n.split_cluster for n in forest.nodes(depth=5,root=False)])

# silhouette_scores = silhouette_samples(reduced_node,node_labels,metric='cosine')
silhouette_scores = silhouette_samples(feature_encoding,node_labels,metric='cosine')
# silhouette_scores = silhouette_samples(sample_encoding,node_labels,metric='cosine')

sorted_silhouette = np.zeros(silhouette_scores.shape)
sorted_colors = np.zeros(silhouette_scores.shape)

current_index = 0
next_index = 0
for i in sorted(set(node_labels)):
    mask = node_labels == i
    selected_values = sorted(silhouette_scores[mask])    
    next_index = current_index + np.sum(mask)
    sorted_silhouette[current_index:next_index] = selected_values
    sorted_colors[current_index:next_index] = i
    current_index = next_index

In [None]:
import matplotlib.cm as cm

plt.figure()
plt.title("Silhouette Plots For Nodes Clustered By Gain")
for i,node in enumerate(sorted_silhouette):
    plt.plot([0,node],[i,i],color=cm.nipy_spectral(sorted_colors[i] / len(forest.split_clusters)))
# plt.scatter(sorted_silhouette,np.arange(len(sorted_silhouette)),s=1)
plt.plot([0,0],[0,len(sorted_silhouette)],color='red')
plt.xlabel("Silhouette Score")
plt.ylabel("Nodes")
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc

import pickle 

data_location = "../data/aging_brain/"

with open(data_location + "aging_brain_young.pickle", mode='rb') as f:
    young = pickle.load(f)
with open(data_location + "aging_brain_old.pickle", mode='rb') as f:
    old = pickle.load(f)
with open(data_location + "aging_brain_filtered.pickle", mode='rb') as f:
    filtered = pickle.load(f)

batch_encoding = np.loadtxt(data_location + 'aging_batch_encoding.tsv')
batch_encoding = batch_encoding.astype(dtype=bool)

young_mask = np.zeros(37069,dtype=bool)
old_mask = np.zeros(37069,dtype=bool)

young_mask[:young.shape[0]] = True
old_mask[young.shape[0]:] = True

In [None]:
forest.maximum_spanning_tree(mode='samples')

In [None]:
forest.html_tree_summary(n=3)

In [None]:
trans =  forest.split_cluster_transition_matrix(depth=10)

plt.figure()
plt.title("Node Cluster Transition Frequency")
plt.imshow(trans[:-1],cmap='binary',interpolation='none')
plt.xlabel("Destination")
plt.ylabel("Origin")
plt.colorbar(label="Frequency")
plt.show()

In [None]:
np.sum(trans[18])

In [None]:
plt.figure()
plt.bar(np.arange(trans.shape[0]),trans[:,37])
plt.show()

In [None]:
np.sum(trans[:,37])

In [None]:
print(trans[45,37])
print(trans[23,37])
print(trans[37,37])
# print(trans[34,9])
# print(trans[34,24])

In [None]:
990 - 677 

In [None]:
99/313