## Local Effects

One of the more important aspects of random forest nodes, and by extension of node clusters, is that they describe what we call "local effects".

While a conventional linear regression describes a single linear relationship between a feature and a target that holds across the entire dataset, a node in a random forest is frequently the child of another node, and is therefore trained only on the subset of samples that reached it. Consequently, a relationship it describes between a feature and a target may hold across the entire dataset, or it may hold only conditionally on the splits made by the node's parents.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc

import sys
# sys.path.append('/localscratch/bbrener1/rusty_forest_v3/src')
sys.path.append('../src')
import tree_reader as tr 
import lumberjack

import pickle 

data_location = "../data/aging_brain/"

# Use context managers so the pickle files are closed promptly instead of
# leaking open handles until garbage collection.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load trusted, locally produced pickles here.
with open(data_location + "aging_brain_young.pickle", mode='rb') as young_file:
    young = pickle.load(young_file)
with open(data_location + "aging_brain_old.pickle", mode='rb') as old_file:
    old = pickle.load(old_file)

# Load the previously trained forest (project-local tree_reader API) and echo
# the arguments it was trained with.
forest = tr.Forest.load(data_location + 'cv_forest_trimmed_extra')
forest.arguments

In [None]:
# Boolean mask over young's variables marking the features the forest predicts.
# The mask is sized from the data rather than a hard-coded 2000, so it always
# matches young.var_names.
var_position = {}
for position, name in enumerate(young.var_names):
    # setdefault keeps the first occurrence, matching list.index semantics.
    var_position.setdefault(name, position)

filtered_feature_mask = np.zeros(len(young.var_names), dtype=bool)

for feature in forest.output_features:
    # Raises KeyError if a forest feature is absent from young.var_names.
    filtered_feature_mask[var_position[feature]] = True

young_filtered = young[:, filtered_feature_mask]
young_filtered.shape

In [None]:
# forest.reset_split_clusters()

# forest.interpret_splits(
#     depth=8,
#     mode='additive_mean',
#     metric='cosine',
#     pca=100,
#     relatives=True,
#     k=50,
#     resolution=2,
# )

# print(len(forest.split_clusters))

In [None]:
# forest.maximum_spanning_tree(mode='samples')

In [None]:
# forest.html_tree_summary(n=10)

In [None]:
# forest.backup(data_location + "full_clustering")

In [None]:
# Goal: check whether some local associations differ so dramatically from the
# global ones that a PCA-based analysis cannot recapture them.
# Step one is to fit a PCA and measure how much variance it leaves unexplained.
# NOTE(review): the PCA is fit on every column of young.X, while later cells
# index model.components_ with forest.truth_dictionary feature indices —
# confirm these two index spaces coincide.

from sklearn.decomposition import PCA

model = PCA(n_components=40).fit(young.X)
transformed = model.transform(young.X)  # per-cell scores on the 40 PCs
recovered = model.inverse_transform(transformed)  # reconstruction from those PCs

# Baseline error: squared deviation of each value from its column mean.
centered = young.X - np.mean(young.X, axis=0)
null_squared_residual = np.power(centered, 2)

# Error left after reconstructing the data from the retained components.
recovered_residual = young.X - recovered
recovered_squared_residual = np.power(recovered_residual, 2)

# Row-wise (per-sample) totals and the fraction of error the PCA leaves behind.
null_per_sample = np.sum(null_squared_residual, axis=1)
pca_recovered_per_sample = np.sum(recovered_squared_residual, axis=1)
pca_recovered_fraction_per_sample = pca_recovered_per_sample / null_per_sample

null_total = np.sum(null_squared_residual)
recovered_total = np.sum(recovered_squared_residual)
print(null_total)
print(recovered_total)

print(f"Remaining variance:{(recovered_total / null_total)}")

In [None]:
# Compare the weight that each PC assigns to two features of interest.
# NOTE(review): the indices come from the forest's feature dictionary but are
# used on model.components_, which was fit on young.X — confirm column orders match.

f1 = "Tmem119"
f2 = "Cd74"

feature_lookup = forest.truth_dictionary.feature_dictionary
f1_index = feature_lookup[f1]
f2_index = feature_lookup[f2]

# One weight per PC for each of the two features.
f1_loadings, f2_loadings = model.components_[:, f1_index], model.components_[:, f2_index]

# plt.figure()
# plt.scatter(forest.output[:,f1_index],forest.output[:,f2_index],s=1)
# plt.show()

plt.figure()
plt.title(f"PC Loadings of {f1} vs {f2}")
plt.scatter(f1_loadings, f2_loadings)
plt.xlabel(f"{f1} weight")
plt.ylabel(f"{f2} weight")
# Reference line through the origin: the two endpoints give slope -0.55.
reference_x = [.2, -.2]
reference_y = np.array([-.2, .2]) * .55
plt.plot(reference_x, reference_y, color='red', label="Slope of -.55")
plt.legend()
plt.show()

for i, pc in enumerate(model.components_):
    print(f"PC:{i}, {f1}:{pc[f1_index]}, {f2}:{pc[f2_index]}")

In [None]:
# Color every cell by its score on each PC, to see where in the embedding the PC
# is active (a hint as to whether it marks a particular cell type).
# `transformed` is model.transform(young.X), so each column of it (each row of
# transformed.T) holds per-cell PC scores, not loadings.

for i, pc in enumerate(transformed.T):
    plt.figure()
    plt.title(f"PC {i} Scores")
    # Symmetric color limits so that a score of zero maps to white in 'bwr'.
    ab_max = np.max(np.abs(pc))
    plt.scatter(*forest.tsne_coordinates.T, c=pc, s=3, alpha=.4, cmap='bwr', vmin=-ab_max, vmax=ab_max)
    # NOTE(review): the labels say UMAP while the coordinates come from
    # forest.tsne_coordinates — confirm which embedding this attribute holds.
    plt.xlabel("UMAP Embedding, (AU)")
    plt.ylabel("UMAP Embedding, (AU)")
    plt.colorbar()
    plt.show()

In [None]:
# For each factor, report the feature pairs with an especially large discrepancy
# between their local correlation (within the factor) and their global correlation.

# Neither of these depends on the factor, so fetch them once rather than per iteration.
features = forest.output_features
global_correlations = forest.global_correlations()

for factor in forest.split_clusters:
    print("=====================================")
    print(factor.name())
    print("=====================================")
    fi_pairs = factor.most_local_correlations()
    f_names = [(features[i], features[j]) for (i, j) in fi_pairs]
    local_correlations = factor.local_correlations()
    # (local, global) correlation for each reported pair, side by side.
    discrepancy = [(local_correlations[i, j], global_correlations[i, j]) for (i, j) in fi_pairs]
    print(f_names)
    print(discrepancy)

In [None]:
# Gather the single most "local" feature pair from each factor, then tabulate the
# local correlation of every gathered pair within every factor
# (rows: pairs, columns: factors).
interesting_pairs = []

for factor in forest.split_clusters:
    interesting_pairs.extend(factor.most_local_correlations(n=1))

# Every feature index appearing in any pair; local correlation matrices are
# requested over exactly these indices, in this order.
uniques = list(set([y for x in interesting_pairs for y in x]))
# Position of each feature inside `uniques`, built once (O(1) lookups instead of
# a linear list.index search inside the nested loop).
unique_position = {feature: position for position, feature in enumerate(uniques)}

factor_correlation_table = np.zeros((len(interesting_pairs), len(forest.split_clusters)))

for i, factor in enumerate(forest.split_clusters):
    local_correlations = factor.local_correlations(indices=uniques)
    for j, (f1, f2) in enumerate(interesting_pairs):
        factor_correlation_table[j, i] = local_correlations[unique_position[f1], unique_position[f2]]

plt.figure()
plt.imshow(factor_correlation_table, interpolation='none', aspect='auto', cmap='bwr', vmin=-1, vmax=1)
plt.xlabel("Factor")
plt.ylabel("Interesting pair")
plt.colorbar()
plt.show()

from scipy.cluster.hierarchy import linkage, dendrogram

# factor_agglomeration = dendrogram(linkage(factor_correlation_table, metric='cosine', method='average'), no_plot=True)['leaves']

# plt.figure()
# plt.imshow(factor_correlation_table.T[factor_agglomeration].T,interpolation='none',aspect='auto',cmap='bwr',vmin=-1,vmax=1)
# plt.colorbar()
# plt.show()

# Row index of the table -> (feature index, feature index) pair it represents.
print([(x, y) for x, y in enumerate(interesting_pairs)])

In [None]:
# Inspect one feature pair whose local correlation differs between two factors.
# print() is required: a notebook only echoes a cell's final expression, so a
# bare expression here would be silently discarded.
print(forest.output_features[717], forest.output_features[1639])

print(forest.split_clusters[23].local_correlations(indices=[717, 1639]))
print(forest.split_clusters[20].local_correlations(indices=[717, 1639]))

# cluster 23, Rarres2 (717), Meg3 (1639)

In [None]:
# Check the naive linear fit between two features across all cells
# (i.e. a simple global correlation), for contrast with the filtered fits.

from scipy.stats import linregress

f1 = "Ctsd"
f2 = "H2-Ab1"

f1_index = forest.truth_dictionary.feature_dictionary[f1]
f2_index = forest.truth_dictionary.feature_dictionary[f2]

f1_values = forest.output[:, f1_index]
f2_values = forest.output[:, f2_index]

# linregress's third value is Pearson's r; it is squared below to report R².
slope, intercept, r_fit, _, _ = linregress(f1_values, f2_values)

# Draw the fitted line across the observed range of f1 rather than a fixed 0..6.
x_line = np.array([np.min(f1_values), np.max(f1_values)])

plt.figure(figsize=(3, 2.5))
plt.title(f"Linear Fit, {f1}, {f2}, Naive")
plt.scatter(f1_values, f2_values, s=3)
plt.plot(x_line, intercept + (x_line * slope), c='red', label=f"Slope:{np.around(slope,3)},R2:{np.around(r_fit**2,3)}")
plt.legend()
plt.xlabel(f"{f1}")
plt.ylabel(f"{f2}")
plt.show()

In [None]:
# Keep only cells with a high OR low sister score for a particular factor, then
# linearly regress two genes within those cells to check for a "local" association.


from scipy.stats import linregress

factor = forest.split_clusters[34]
factor_threshold = .2
sister_scores = factor.sister_scores()
# Both tails: |score| > threshold, matching the two threshold lines drawn below.
factor_mask = np.abs(sister_scores) > factor_threshold

plt.figure()
plt.title(f"Sister scores, {factor.name()}")
plt.hist(sister_scores, bins=50)
# axvline spans the whole y-range whatever the bin heights are, without
# stretching the y-axis below zero.
plt.axvline(factor_threshold, color='red')
plt.axvline(-factor_threshold, color='red', label="Sister score threshold")
plt.legend()
plt.show()

f1 = "Tmem119"
f2 = "Cd74"

f1_index = forest.truth_dictionary.feature_dictionary[f1]
f2_index = forest.truth_dictionary.feature_dictionary[f2]

f1_values = forest.output[:, f1_index][factor_mask]
f2_values = forest.output[:, f2_index][factor_mask]

# linregress's third value is Pearson's r; it is squared below to report R².
slope, intercept, r_fit, _, _ = linregress(f1_values, f2_values)

# Draw the fitted line across the observed range of the filtered f1 values.
x_line = np.array([np.min(f1_values), np.max(f1_values)])

plt.figure()
plt.title(f"Linear Fit, {f1}, {f2}, Factor {factor.name()}, Filtered")
plt.scatter(f1_values, f2_values, s=3)
plt.plot(x_line, intercept + (x_line * slope), c='red', label=f"Slope:{np.around(slope,3)},R2:{np.around(r_fit**2,3)}")
plt.xlabel(f"{f1}")
plt.ylabel(f"{f2}")
plt.legend()
plt.show()



In [None]:
# Keep only cells with a high OR low sister score for a particular factor, then
# linearly regress two genes within those cells to check for a "local" association.


from scipy.stats import linregress

factor = forest.split_clusters[41]
factor_threshold = .05
sister_scores = factor.sister_scores()
# Both tails: |score| > threshold, matching the two threshold lines drawn below.
factor_mask = np.abs(sister_scores) > factor_threshold

plt.figure()
plt.title(f"Sister scores, {factor.name()}")
plt.hist(sister_scores, bins=50)
# axvline spans the whole y-range whatever the bin heights are, without
# stretching the y-axis below zero.
plt.axvline(factor_threshold, color='red')
plt.axvline(-factor_threshold, color='red', label="Sister score threshold")
plt.legend()
plt.show()

f1 = "Tmem119"
f2 = "Cd74"

f1_index = forest.truth_dictionary.feature_dictionary[f1]
f2_index = forest.truth_dictionary.feature_dictionary[f2]

f1_values = forest.output[:, f1_index][factor_mask]
f2_values = forest.output[:, f2_index][factor_mask]

# linregress's third value is Pearson's r; it is squared below to report R².
slope, intercept, r_fit, _, _ = linregress(f1_values, f2_values)

# Draw the fitted line across the observed range of the filtered f1 values.
x_line = np.array([np.min(f1_values), np.max(f1_values)])

plt.figure()
plt.title(f"Linear Fit, {f1}, {f2}, Factor {factor.name()}, Filtered")
plt.scatter(f1_values, f2_values, s=3)
plt.plot(x_line, intercept + (x_line * slope), c='red', label=f"Slope:{np.around(slope,3)},R2:{np.around(r_fit**2,3)}")
plt.xlabel(f"{f1}")
plt.ylabel(f"{f2}")
plt.legend()
plt.show()



In [None]:
# Find the most heavily weighted genes for one PC, plus the rank of two features
# of interest, to see whether those features make up an important part of the
# variance the PC captures.
pc = 8
top_n = 20

f1 = "Tmem119"
f2 = "Cd74"

f1_index = forest.truth_dictionary.feature_dictionary[f1]
f2_index = forest.truth_dictionary.feature_dictionary[f2]

weights = model.components_[pc]

# Ascending order of absolute weight; reversed to get largest-first.
weight_sort = np.argsort(np.abs(weights))
descending = weight_sort[::-1]
# Exactly top_n entries (the old [:-20:-1] slice yielded only 19).
top = descending[:top_n]

print(list(forest.output_features[top]))
print(list(weights[top]))

# 1-based rank by absolute weight (1 = most heavily weighted).
descending_list = list(descending)
print(f"{f1}: {descending_list.index(f1_index) + 1}")
print(f"{f2}: {descending_list.index(f2_index) + 1}")

print(weights[f1_index])
print(weights[f2_index])

In [None]:
# For every factor, map raw and log sister scores onto the embedding side by side.
# Titles carry the factor name so the figures produced by this loop can be told apart.
for s_c in forest.split_clusters:
    scores = s_c.sister_scores()
    log_scores = s_c.log_sister_scores(prior=10)

    for label, values in (("Regular", scores), ("Log", log_scores)):
        # Symmetric limits so that zero maps to white in the diverging 'bwr' map.
        abmax = np.max(np.abs(values))

        plt.figure()
        plt.title(f"{label}, {s_c.name()}")
        plt.scatter(*forest.tsne_coordinates.T, c=values, cmap='bwr', s=1, vmin=-abmax, vmax=abmax)
        plt.colorbar()
        plt.show()

In [None]:
# Distributions of sample, sister and log-sister scores for a single factor.
factor = forest.split_clusters[34]

samples = factor.sample_scores()
sisters = factor.sister_scores()
log_sisters = factor.log_sister_scores()

plt.figure(figsize=(3, 2.5))
plt.title(f"Distribution of Sample Scores In {factor.name()}")
plt.hist(samples, bins=50)
plt.ylabel("Frequency")
plt.xlabel("Sample Scores")
plt.show()

plt.figure(figsize=(3, 2.5))
plt.title(f"Distribution of Sister Scores In {factor.name()}")
plt.hist(sisters, bins=50)
plt.ylabel("Frequency")
plt.xlabel("Sister Scores")
plt.show()

# Previously unlabeled; now titled and labeled like the two histograms above.
plt.figure(figsize=(3, 2.5))
plt.title(f"Distribution of Log Sister Scores In {factor.name()}")
plt.hist(log_sisters, bins=50)
plt.ylabel("Frequency")
plt.xlabel("Log Sister Scores")
plt.show()




In [None]:
# Test whether a particular factor over-expresses a gene of interest
# (used as a statistical test for cell-type identity, e.g. "is factor 34 immune cells?").
from scipy.stats import ttest_ind

factor = forest.split_clusters[34]
factor_threshold = .2
# One-sided selection: cells scoring high on this factor vs all other cells.
mask = factor.sister_scores() > factor_threshold

# NOTE(review): the mouse gene symbol for CD45 is usually Ptprc — confirm "Cd45"
# is a key in the feature dictionary (the lookup below raises KeyError otherwise).
feature = "Cd45"

# Look up `feature` itself so the test runs on the gene named in the printout.
f_index = forest.truth_dictionary.feature_dictionary[feature]

# Welch's t-test (unequal variances) on young.X values for this column.
# NOTE(review): f_index comes from the forest's dictionary but indexes young.X —
# confirm both use the same column order.
test = ttest_ind(young.X[mask][:, f_index], young.X[~mask][:, f_index], equal_var=False)

print(f"{feature} in {factor.name()} vs all other: {test}")


In [None]:
# Test whether factor 34 over-expresses C1qa relative to all other cells.
from scipy.stats import ttest_ind

factor = forest.split_clusters[34]
factor_threshold = .2
# One-sided selection: cells scoring high on this factor vs all other cells.
mask = factor.sister_scores() > factor_threshold

feature = "C1qa"

# Look up `feature` itself so the test runs on the gene named in the printout.
f_index = forest.truth_dictionary.feature_dictionary[feature]

# Welch's t-test (unequal variances) on young.X values for this column.
# NOTE(review): f_index comes from the forest's dictionary but indexes young.X —
# confirm both use the same column order.
test = ttest_ind(young.X[mask][:, f_index], young.X[~mask][:, f_index], equal_var=False)

print(f"{feature} in {factor.name()} vs all other: {test}")


In [None]:
# Color each cell in the embedding by the forest's output for one feature.
feature = "Tmem119"

f_index = forest.truth_dictionary.feature_dictionary[feature]
feature_values = forest.output[:, f_index]
x_coords, y_coords = forest.tsne_coordinates.T

plt.figure()
plt.scatter(x_coords, y_coords, c=feature_values, s=1)
plt.colorbar()
plt.show()

In [None]:
# Scatter one feature's forest output against a factor's sister scores, cell by cell.
feature = "Cd74"

f_index = forest.truth_dictionary.feature_dictionary[feature]
sister_axis = forest.split_clusters[24].sister_scores()
expression_axis = forest.output[:, f_index]

plt.figure()
plt.scatter(sister_axis, expression_axis, s=1)
plt.show()

In [None]:
# Bar chart of Cd45 mean expression in the remaining cells vs node cluster 34.
# NOTE(review): the values are hard-coded, presumably copied from an earlier
# analysis — confirm their source and what the error bars represent before reuse.
group_positions = [0, 1]
group_labels = ["Rest", "NC 34"]
group_means = [0.0030044015, 0.3989004]
group_errors = [0.0006448814299982911, 0.017244088969228192]

plt.figure(figsize=(3, 2.5))
plt.title("Cd45 Mean Expression")
plt.bar(group_positions, group_means, yerr=group_errors, width=.5, tick_label=group_labels)
plt.ylabel("Mean Expression (Log TPM)")
plt.show()