In [None]:
# TODO: PCA partitions

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc

import pickle 

data_location = "../data/aging_brain/"


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards.

    NOTE: pickle can execute arbitrary code while loading; only use it on trusted,
    locally produced files.
    """
    with open(path, mode='rb') as handle:
        return pickle.load(handle)


young = _load_pickle(data_location + "aging_brain_young.pickle")
old = _load_pickle(data_location + "aging_brain_old.pickle")

filtered = _load_pickle(data_location + "aging_brain_filtered.pickle")

# One row per cell of `filtered`, one boolean indicator column per mouse
# (presumably the first 8 columns are young mice and the rest old -- TODO confirm).
batch_encoding = np.loadtxt(data_location + 'aging_batch_encoding.tsv')
batch_encoding = batch_encoding.astype(dtype=bool)

# `filtered` is expected to contain the young cells first, followed by the old cells.
# Derive the cell count instead of hard-coding it, and fail early if the layout differs.
n_cells = filtered.shape[0]
assert young.shape[0] + old.shape[0] == n_cells, "expected filtered = young cells followed by old cells"
assert batch_encoding.shape[0] == n_cells, "batch encoding must have one row per cell"

young_mask = np.zeros(n_cells,dtype=bool)
old_mask = np.zeros(n_cells,dtype=bool)

young_mask[:young.shape[0]] = True
old_mask[young.shape[0]:] = True

In [None]:
# sc.pp.neighbors(filtered)
# sc.tl.umap(filtered)
# sc.tl.louvain(filtered)
# sc.pl.umap(filtered,color='louvain')


In [None]:
# Enumerate every way of splitting the 8 mice into two groups of 4.
# Each split is an 8-bit mask: 1 = mouse belongs to the partition.
codes = np.arange(256).astype(dtype='uint8')
bit_table = np.unpackbits(codes).reshape((256, 8))  # big-endian: column 0 is the MSB

# A split and its complement describe the same pair of groups. Codes below 128
# have MSB == 0 (mouse 0 never in the partition), which keeps exactly one of each pair.
# Row order (ascending code) is preserved, since partition i maps to sub_forest_{i}.
lower_half = bit_table[:128]
partitions = lower_half[lower_half.sum(axis=1) == 4]

partitions.shape

In [None]:
# Now we must select features that are reasonably predictable within the dataset

In [None]:
import sys
# sys.path.append('/localscratch/bbrener1/rusty_forest_v3/src')
# Make the project's forest modules importable. The guard stops re-runs of this cell
# from appending the same entry to sys.path over and over.
if '../src' not in sys.path:
    sys.path.append('../src')
import tree_reader as tr
import lumberjack


# First forest, fit on all young cells. Its per-feature residuals are used below to
# select features that are reasonably predictable.
# The header must name the columns of the matrix being fit, so take it from `young`
# itself (the original passed filtered.var_names, which only matches if both objects
# carry identical genes in identical order).
selection_forest = lumberjack.fit(
    young.X,
    header=young.var_names,
    trees=100,
    braids=2,
    ifs=700,
    ofs=700,
    ss=500,
    depth=8,
    leaves=10,
    sfr=0,
    norm='l1',
    reduce_input='true',
    reduce_output='false'
)

In [None]:
# Turn on the forest's cache before predicting (tree_reader API; presumably keeps
# computed encodings/predictions in memory for reuse -- TODO confirm).
selection_forest.set_cache(True)

In [None]:
# In-sample prediction: run the forest on the same young matrix it was trained on.
# The residuals of this prediction drive the feature selection below.
selection_forest.self_prediction = selection_forest.predict(young.X)


In [None]:
# Residuals of the in-sample prediction. They are reduced over axis 0 below and compared
# per feature with young.X, so presumably the shape is (cells, features) -- TODO confirm.
feature_residuals = selection_forest.self_prediction.residuals()


In [None]:
# Fraction of variance left unexplained (FVU) per feature:
#   sum of squared forest residuals / sum of squared deviations from the feature mean.
# Both sums carry a +1 pseudocount, so constant features do not divide by zero.
feature_means = np.mean(young.X,axis=0)
feature_null = np.sum(np.power(young.X - feature_means,2),axis=0) + 1

forest_feature_error = np.sum(np.power(feature_residuals,2),axis=0) + 1
forest_feature_remaining = forest_feature_error/feature_null


In [None]:
# Keep features whose forest FVU is below this threshold, i.e. the selection forest
# explains more than half of their variance.
FVU_THRESHOLD = .5

plt.figure()
plt.title("Per-Feature Fraction of Variance Unexplained")
plt.hist(forest_feature_remaining,bins=50)
plt.axvline(FVU_THRESHOLD,linestyle='--',color='red',label=f"Threshold ({FVU_THRESHOLD})")
plt.xlabel("Fraction of variance unexplained")
plt.ylabel("Number of features")
plt.legend()
plt.show()

filtered_feature_mask = forest_feature_remaining < FVU_THRESHOLD

print(np.sum(filtered_feature_mask))

In [None]:
# Rebuild the feature mask from the output features of an already-trained
# cross-validated forest, so the mask matches that forest exactly.
# `cv_forest` is only created further down the notebook. On a fresh kernel it does not
# exist yet, so keep the threshold-based mask from the previous cell instead of failing
# with a NameError.
if 'cv_forest' in globals():
    # name -> first column position (same result as list.index, but O(1) per lookup)
    feature_position = {}
    for position, name in enumerate(young.var_names):
        feature_position.setdefault(name, position)

    # One flag per column of `young` (the mask is used as young[:, mask] below).
    filtered_feature_mask = np.zeros(young.shape[1],dtype=bool)
    for feature in cv_forest.output_features:
        filtered_feature_mask[feature_position[feature]] = True
else:
    print("cv_forest not defined yet; keeping the threshold-based feature mask")

In [None]:
# Now we must take the predictable features and train cross-validated forests on them

In [None]:
# Restrict the young cells to the predictable features selected above, and show the resulting shape.
young_filtered = young[:,filtered_feature_mask]
young_filtered.shape

In [None]:
# Cross-validated forest, trained on the young cells restricted to the predictable features.
cv_forest_params = dict(
    trees=100,
    braids=2,
    ifs=150,
    ofs=150,
    ss=500,
    depth=8,
    leaves=10,
    sfr=0,
    norm='l1',
    reduce_input='true',
    reduce_output='false',
)
cv_forest = lumberjack.fit(
    young_filtered.X,
    header=young_filtered.var_names,
    **cv_forest_params
)

In [None]:
# Enable caching (tree_reader API; exact semantics are not visible here). Then save the
# trained forest under data_location so later cells can reload it with tr.Forest.load.
cv_forest.set_cache(True)
cv_forest.backup(data_location + "cv_forest_trimmed_extra")

In [None]:
# Checkpoint: reload the trained cross-validated forest from disk, so the analysis below
# can resume without retraining.
# NOTE: this still depends on variables from earlier cells (data_location, filtered,
# young_mask, old, filtered_feature_mask, ...), so it is not a full fresh-kernel restart point.
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc

import sys
# sys.path.append('/localscratch/bbrener1/rusty_forest_v3/src')
if '../src' not in sys.path:  # avoid duplicate sys.path entries when the cell is re-run
    sys.path.append('../src')
import tree_reader as tr
import lumberjack

cv_forest = tr.Forest.load(data_location + 'cv_forest_trimmed_extra')


In [None]:
# Discard any previous split clustering, then cluster the forest's splits again with
# these options (passed straight through to tree_reader).
split_clustering_options = {
    'k': 10,
    'pca': False,
    'depth': 8,
    'metric': "cosine",
    'mode': 'additive_mean',
    'relatives': True,
}
cv_forest.reset_split_clusters()
cv_forest.interpret_splits(**split_clustering_options)

In [None]:
# Build the forest's maximum spanning tree in 'samples' mode (tree_reader API; what is
# spanned and how similarity is measured are defined there -- TODO confirm).
cv_forest.maximum_spanning_tree(mode='samples')

In [None]:
# Attach 2-D embedding coordinates for the young cells to the forest (presumably used by
# html_tree_summary; they are also plotted directly further down).
# NOTE: despite the attribute name, these are UMAP coordinates (obsm['X_umap']), not tSNE.
cv_forest.tsne_coordinates = filtered.obsm['X_umap'][young_mask]
cv_forest.html_tree_summary(n=10)

In [None]:
# In-sample prediction of the forest's own output matrix.
# The node/sample encoding's return value is discarded, so the call is presumably made
# to compute and cache the encoding on the prediction -- TODO confirm.
cv_forest.self_prediction = cv_forest.predict(cv_forest.output)
cv_forest.self_prediction.node_sample_encoding()

In [None]:
# Summary report of the in-sample prediction (tree_reader API).
cv_forest.self_prediction.prediction_report()

In [None]:
# Predict the old cells using only the feature columns the cross-validated forest was trained on.
old_features = old.X[:,filtered_feature_mask]
cv_forest.old_prediction = cv_forest.predict(old_features)
cv_forest.old_prediction.prediction_report()


In [None]:

# for i,partition in enumerate(partitions):
#     partition = np.array(partition).astype(dtype=bool)
#     partition_mask = np.any(batch_encoding[young_mask,:8][:,partition],axis=1)
#     sample_indices = np.arange(young_filtered.X.shape[0])[partition_mask]
#     sub_forest = cv_forest.derive_samples(sample_indices)
#     sub_forest.backup(f"sub_forest_{i}")
#     del(sub_forest)

# for i,partition in list(enumerate(partitions)):
#     partition = np.array(partition).astype(dtype=bool)
#     partition_mask = np.any(batch_encoding[young_mask,:8][:,partition],axis=1)
#     sub_forest = tr.Forest.load(f"sub_forest_{i}")
#     sub_forest.self_prediction = sub_forest.predict(young.X[partition_mask])
#     sub_forest.self_prediction.node_sample_encoding()
#     sub_forest.self_prediction.node_sample_r2()
#     sub_forest.test_prediction = sub_forest.predict(young.X[~partition_mask])
#     sub_forest.test_prediction.node_sample_encoding()
#     sub_forest.test_prediction.node_sample_r2()
#     sub_forest.backup(f"sub_forest_{i}_cached")
#     del(sub_forest)

# Number of split clusters ("factors") analysed; factor 0 is skipped below.
# Assumed to equal len(cv_forest.split_clusters) -- TODO confirm.
N_FACTORS = 39
# Number of cached young/young sub-forests (one per mouse partition) evaluated.
N_BOOTSTRAPS = 20

# Per factor (rows) and bootstrap (columns):
#   determination_spread       -- held-out FVU minus training FVU (= training COD - held-out COD)
#   self_determination_spread  -- COD on the sub-forest's own training cells
#   other_determination_spread -- COD on the held-out young cells
determination_spread = np.zeros((N_FACTORS,N_BOOTSTRAPS))
self_determination_spread = np.zeros((N_FACTORS,N_BOOTSTRAPS))
other_determination_spread = np.zeros((N_FACTORS,N_BOOTSTRAPS))

# Iterate over exactly as many partitions as there are columns. The original loop took
# partitions[:21] and indexed column 20 of these 20-column arrays (IndexError).
for i,_partition in enumerate(partitions[:N_BOOTSTRAPS]):
    print("+++++++++++++++++++")
    print(i)
    print("+++++++++++++++++++")
    sub_forest = tr.Forest.load(data_location + "restricted_sub_forest/" + f"sub_forest_{i}_cached")
    for factor_index in range(1,N_FACTORS):
        factor = sub_forest.split_clusters[factor_index]
        # Local names chosen so the global `self_fvu` list built later is not clobbered.
        train_fvu,held_out_fvu,_ = sub_forest.self_prediction.compare_factor_fvu(sub_forest.test_prediction,factor)
        self_determination_spread[factor_index,i] = 1-train_fvu
        other_determination_spread[factor_index,i] = 1-held_out_fvu
        determination_spread[factor_index,i] = held_out_fvu - train_fvu
    del(sub_forest)

In [None]:
# Per-factor summaries of the young/young bootstrap Δ values.
mean_spread = np.mean(determination_spread,axis=1)
spread_variance = np.var(determination_spread,axis=1)
mean_sort = np.argsort(mean_spread)
var_sort = np.argsort(spread_variance)


def _plot_bootstrap_spread(order, title):
    """Plot every bootstrap Δ COD per factor (blue) and the per-factor mean (red).

    Factors are laid out left to right in the given ``order``; tick labels show factor indices.
    """
    n_factors, n_bootstraps = determination_spread.shape
    plt.figure()
    plt.title(title)
    for position, factor_index in enumerate(order):
        plt.scatter(np.full(n_bootstraps, position),determination_spread[factor_index],s=2,c='blue')
    plt.scatter(np.arange(n_factors),mean_spread[order],c='red')
    plt.xticks(np.arange(n_factors),order,rotation=90)
    plt.xlabel("Factor")
    plt.ylabel("Δ COD (training vs held-out young)")
    plt.show()


_plot_bootstrap_spread(var_sort, "Young vs Young Δ COD, Sorted by Bootstrap Variance")
# Only young/young bootstraps appear here; the old title also mentioned "Young vs Old".
_plot_bootstrap_spread(mean_sort, "Mean Prediction Error Young vs Young, Sorted by Mean")

print(list(enumerate(np.sqrt(spread_variance))))

In [None]:
# For every factor except index 0, compare the young in-sample prediction with the old-cell
# prediction. The return value is discarded, so this cell only shows whatever
# compare_factor_fvu prints/plots itself. The cell further down repeats the same calls
# (with plot=True) and keeps the results.
for factor in cv_forest.split_clusters[1:]:
    print("+++++++++++++++++++++")
    print(f"Factor {factor.name()}")
    print("+++++++++++++++++++++")
    cv_forest.self_prediction.compare_factor_fvu(cv_forest.old_prediction, factor)

In [None]:
# Number of split clusters (factors), including index 0, which the loops in this notebook skip.
len(cv_forest.split_clusters)

In [None]:
# Fraction of variance unexplained (FVU) per factor: young in-sample prediction vs old cells.
# Index 0 is skipped in the loop and gets FVU = 1 (COD = 0) as a placeholder, so that
# list position == factor index.
self_fvu = [1.,]
old_fvu = [1.,]

for factor in cv_forest.split_clusters[1:]:
    print("++++++++++++++++++++")
    print(factor.name())
    print("++++++++++++++++++++")
    young_result, old_result, _ = cv_forest.self_prediction.compare_factor_fvu(
        cv_forest.old_prediction, factor, plot=True)
    self_fvu.append(young_result)
    old_fvu.append(old_result)

# Coefficient of determination (COD) = 1 - FVU.
self_cod = 1 - np.array(self_fvu)
old_cod = 1 - np.array(old_fvu)


In [None]:

# Per-factor drop in COD from the young training cells to the old cells.
old_delta = self_cod - old_cod
delta_sort = np.argsort(old_delta)

print(list(zip(delta_sort,old_delta[delta_sort])))

# Trained (young) COD vs observed (old) COD; every point is labelled with its factor index.
plt.figure()
plt.title("Coefficient of Determination Ratio")
plt.plot([0,.5],[0,.5],label="Ideal Fit", color='red')
for factor_index, (trained, observed) in enumerate(zip(self_cod, old_cod)):
    plt.text(trained+.005, observed-.01, str(factor_index), fontsize=5)
plt.scatter(self_cod,old_cod,label="Old Mice",s=2,color='blue')
plt.legend()
plt.xlabel("Trained COD")
plt.ylabel("Observed COD")
plt.show()

In [None]:


# Compact version of the trained-vs-observed COD scatter, with fixed axis limits and
# only the hand-picked factors labelled (green).
old_delta = self_cod - old_cod
delta_sort = np.argsort(old_delta)

# Factors highlighted on the plot (chosen by hand).
selected_labels = set([7,27,14,1,22])

print(list(zip(delta_sort,old_delta[delta_sort])))

plt.figure(figsize=(4,3))
plt.title("Coefficient of Determination Ratio, Young vs Old")
plt.plot([0,.5],[0,.5],label="Ideal Fit", color='red')
for i,(p1,p2) in enumerate(zip(self_cod,old_cod)):
    if i in selected_labels:
        plt.text(p1+.003,p2-.005,str(i),fontsize=10)
        plt.scatter([p1,],[p2,],color='green',s=10)
plt.scatter([],[],color='green',label='Selected Factors')  # empty scatter only for the legend entry
plt.scatter(self_cod,old_cod,s=2)
plt.legend()
plt.xlabel("Trained COD")
plt.ylabel("Observed COD")
plt.xlim(0,.4)
plt.ylim(0,.4)
plt.show()

In [None]:
# Hand-picked factors marked as significant on the plot.
selected_labels = {7, 27, 14, 1, 22}

# Bootstrap young/young Δ COD per factor (blue), factors ordered by mean Δ (mean_sort).
plt.figure(figsize=(5,4))
plt.title("Mean Prediction Discrepancy,\n Young vs Young and Young vs Old")
for position, factor_index in enumerate(mean_sort):
    plt.scatter(np.full(20, position),determination_spread[factor_index],s=2,c='blue')

# Significant factors get a black star at the lowest bootstrap value.
spread_min = determination_spread.min()
for position, factor_index in enumerate(mean_sort):
    if factor_index in selected_labels:
        plt.scatter([position,],[spread_min,],marker="*",color='black')
plt.scatter([],[],color='black',label='Significant')
plt.scatter([],[],color='blue',label="Bootstrapped Young vs Young Δ COD")  # empty scatter, legend entry only
plt.scatter(np.arange(39),mean_spread[mean_sort],c='red',label="Mean Δ COD Young vs Young")
plt.scatter(np.arange(39),old_delta[mean_sort],c='green',label="Mean Δ COD Young vs Old")
plt.xticks(np.arange(39),mean_sort,rotation=90,fontsize=8)
plt.xlabel("Factor")
plt.ylabel("Difference in Coefficient of Determination")
plt.legend()
plt.show()


In [None]:
# Held-out young COD per factor, averaged over the young/young bootstraps.
mean_determination = np.mean(other_determination_spread,axis=1)
# Separate name so the `mean_sort` ordering used by the Δ COD plots above is not
# overwritten (re-running those cells after this one used to change their ordering).
determination_sort = np.argsort(mean_determination)
# Star markers sit at the lowest value on THIS plot's COD scale. The old code reused
# `spread_min`, which is on the Δ COD scale of a different figure.
star_y = min(np.min(other_determination_spread), np.min(old_cod))

plt.figure(figsize=(5,4))
plt.title("Factor COD")
for position, factor_index in enumerate(determination_sort):
    plt.scatter(np.full(20, position),other_determination_spread[factor_index],s=1,c='blue')

for position, factor_index in enumerate(determination_sort):
    if factor_index in selected_labels:
        plt.scatter([position,],[star_y,],marker="*",color='black')
plt.scatter([],[],color='black',label='Significant')

plt.scatter(np.arange(39),mean_determination[determination_sort],c='red',label="Young vs Young (Mean)",alpha=.5)
plt.scatter(np.arange(39),old_cod[determination_sort],label="Young Vs Old",c='green',alpha=.5)
plt.scatter([],[],color='blue',label="Bootstraps")
plt.xticks(np.arange(39),determination_sort,rotation=90,fontsize=8)
plt.legend()
plt.ylabel("COD")
plt.xlabel("Factor")
plt.show()


In [None]:
# Inspect one factor: mean, min and max of its held-out young COD across bootstraps,
# followed by its COD on the old cells.
factor_of_interest = 22
bootstrap_cods = other_determination_spread[factor_of_interest]
for value in (mean_determination[factor_of_interest],
              np.min(bootstrap_cods),
              np.max(bootstrap_cods),
              old_cod[factor_of_interest]):
    print(value)

In [None]:
# |Δ COD| of young vs old, in units of the young/young bootstrap standard deviation.
# Factor 0 has zero bootstrap variance (its row is never filled) and Δ = 0, giving 0/0.
# Silence that warning and leave NaN there; factor 0 is excluded via [1:] everywhere below.
# NOTE(review): this treats 0 as the null mean; the bootstrap mean `mean_spread` is not
# subtracted -- confirm this is the intended statistic.
with np.errstate(divide='ignore', invalid='ignore'):
    delta_z = np.abs(old_delta/np.sqrt(spread_variance))
z_sort = np.argsort(delta_z[1:])
# +1 converts positions in the [1:] slice back to factor indices.
z_sort + 1

In [None]:
# Z scores of factors 1..38 in ascending order (matches the factor order of z_sort + 1).
print(delta_z[1:][z_sort])

In [None]:
# Factors annotated on the plot (chosen by hand).
selected_labels = set([7,27,16,1,30,14,22,29])

plt.figure(figsize=(5,4))
plt.title("Δ COD vs Z Score")
plt.scatter(old_delta[1:],delta_z[1:],s=10)
for i,(p1,p2) in enumerate(zip(old_delta[1:],delta_z[1:])):
    # i indexes the [1:] slice, so the factor index is i+1.
    if i+1 in selected_labels:
        plt.text(p1+.003,p2-.005,str(i+1),fontsize=10)
plt.ylabel("Z Score")
plt.xlabel("Δ COD")
# |z| > 3.3 corresponds to a two-sided p-value below .001 (the old label read "P > .001").
plt.plot([-.14,.14,],[3.3,3.3],'--',label='Significance: \n P < .001',color='red')
# Effect-size cut-offs at ±.02 in absolute COD units. The old labels read ".02%" and
# gave the negative line the positive value.
plt.plot([.02,.02,],[0,16],'--',color='green',label='Δ COD > .02')
plt.plot([-.02,-.02,],[0,16],'--',color='lightgray',label='Δ COD < -.02')
plt.xlim(-.14,.14)
plt.legend()
plt.show()

In [None]:
try:
    from matplotlib.colors import TwoSlopeNorm  # matplotlib >= 3.2
except ImportError:
    # Older matplotlib only provides the same norm under its original (later removed) name.
    from matplotlib.colors import DivergingNorm as TwoSlopeNorm

factor = cv_forest.split_clusters[7]
sister_scores = factor.sister_scores()

plt.figure(figsize=(3,2.5))
plt.title(
    f"Distribution of Samples \nIn {factor.name()} (Red) vs Its Sisters (Blue)")
# Diverging colour map centred on a sister score of 0.
plt.scatter(*cv_forest.tsne_coordinates.T, s=1,
            alpha=.6, c=sister_scores, norm=TwoSlopeNorm(vcenter=0), cmap='bwr')
plt.colorbar(label="Sister Score")
# cv_forest.tsne_coordinates holds filtered.obsm['X_umap'], so these axes are UMAP, not tSNE.
plt.ylabel("UMAP Coordinates (AU)")
plt.xlabel("UMAP Coordinates (AU)")
plt.show()

In [None]:
# Per factor (rows) and young/young bootstrap (columns): the fractions reported by
# compare_factor_fractions for the sub-forest's training cells (self) and its held-out
# young cells (other). Row 0 (skipped factor) stays zero.
self_fraction_spread = np.zeros((39,20))
other_fraction_spread = np.zeros((39,20))

for i,_partition in enumerate(partitions[:20]):
    print("+++++++++++++++++++")
    print(i)
    print("+++++++++++++++++++")
    sub_forest = tr.Forest.load(data_location + "restricted_sub_forest/" + f"sub_forest_{i}_cached")
    for factor_index in range(1,39):
        factor = sub_forest.split_clusters[factor_index]
        train_fraction,held_out_fraction,_ = sub_forest.self_prediction.compare_factor_fractions(sub_forest.test_prediction,factor)
        self_fraction_spread[factor_index,i] = train_fraction
        other_fraction_spread[factor_index,i] = held_out_fraction
    del(sub_forest)

In [None]:
# Old-cell fraction per factor (index 0 is a placeholder). This is normally built in the
# cell below. Build it here when it does not exist yet, so a top-to-bottom run on a fresh
# kernel does not hit a NameError.
if 'old_fractions' not in globals():
    old_fractions = [0,]
    for factor in cv_forest.split_clusters[1:]:
        _, old_fraction, _ = cv_forest.self_prediction.compare_factor_fractions(cv_forest.old_prediction,factor)
        old_fractions.append(old_fraction)

# Held-out young fraction per factor, averaged over bootstraps. The ordering is reused by
# the per-mouse ratio plot below.
mean_fraction = np.mean(other_fraction_spread,axis=1)
fraction_sort = np.argsort(mean_fraction)

plt.figure(figsize=(5,4))
plt.title("Factor Sample Fraction")  # was "Factor COD", but fractions are plotted here
for position, factor_index in enumerate(fraction_sort):
    plt.scatter(np.full(20, position),other_fraction_spread[factor_index],s=1,c='blue')

plt.scatter(np.arange(39),mean_fraction[fraction_sort],c='red',label="Young vs Young (Mean)",alpha=.5)
plt.scatter([],[],color='blue',label="Bootstraps")
plt.scatter(np.arange(39),np.array(old_fractions)[fraction_sort],label="Old Fraction")
plt.xticks(np.arange(39),fraction_sort,rotation=90,fontsize=8)
plt.legend()
plt.ylabel("Fraction")
plt.xlabel("Factor")
plt.show()


In [None]:
# Old-cell fraction per factor, taken from compare_factor_fractions (second element).
# Index 0 is a placeholder, so list position == factor index.
old_fractions = [0,] + [
    cv_forest.self_prediction.compare_factor_fractions(cv_forest.old_prediction,factor)[1]
    for factor in cv_forest.split_clusters[1:]
]

In [None]:
# For each factor (rows; row 0 = placeholder, left at 0) and mouse (columns):
# the mean over the factor's nodes of
#     (mouse's cells in node + 1) / (mouse's cells in the node's parent + 1).
# The +1 pseudocounts avoid division by zero for empty parents.
mouse_ratios = np.zeros((39,8))

# Young mice: the first 8 batch-encoding columns, restricted to young cells.
# The encodings depend only on the factor, so compute them once per factor (not once per
# mouse x factor) and slice them by mouse.
young_mice = batch_encoding[young_mask].T[:8]
for j,factor in enumerate(cv_forest.split_clusters[1:]):
    node_encoding_all = cv_forest.node_sample_encoding(factor.nodes)
    parent_encoding_all = cv_forest.node_sample_encoding(factor.parents())
    for i,mouse in enumerate(young_mice):
        node_pop = np.sum(node_encoding_all[mouse],axis=0)
        parent_pop = np.sum(parent_encoding_all[mouse],axis=0)
        ratio = (node_pop+1)/(parent_pop+1)
        mouse_ratios[j+1,i] = np.mean(ratio)

old_mouse_ratios = np.zeros((39,8))

# The old-cell encoding takes no arguments, so fetch it once for all factors and mice.
# It is indexed by node index and then transposed, i.e. rows are nodes, columns old cells.
old_node_encoding = cv_forest.old_prediction.node_sample_encoding()
old_mice = batch_encoding[old_mask].T[8:]
for j,factor in enumerate(cv_forest.split_clusters[1:]):
    # Nodes without a parent cannot contribute a node/parent ratio.
    nodes = [n for n in factor.nodes if n.parent is not None]
    node_indices = [n.index for n in nodes]
    parent_indices = [n.parent.index for n in nodes]
    node_encoding_all = old_node_encoding[node_indices].T
    parent_encoding_all = old_node_encoding[parent_indices].T
    for i,mouse in enumerate(old_mice):
        node_pop = np.sum(node_encoding_all[mouse],axis=0)
        parent_pop = np.sum(parent_encoding_all[mouse],axis=0)
        ratio = (node_pop+1)/(parent_pop+1)
        old_mouse_ratios[j+1,i] = np.mean(ratio)


In [None]:
# Per-mouse node/parent ratios per factor: young mice in blue (shifted left), old mice in
# red (shifted right). Factors are ordered by fraction_sort from the fraction plot above.
plt.figure()
plt.title("Per-Mouse Node/Parent Ratio, Young vs Old")
for mouse in mouse_ratios.T:
    plt.scatter(np.arange(39)-.1,mouse[fraction_sort],s=2,c='blue')
for mouse in old_mouse_ratios.T:
    plt.scatter(np.arange(39)+.1,mouse[fraction_sort],s=2,c='red')
plt.scatter([],[],c='blue',label="Young")
plt.scatter([],[],c='red',label='Old')
# Light dashed separators between adjacent factors.
for i in range(39):
    plt.plot([i+.5,i+.5],[0,1],"--",linewidth=.5,color='lightgray')
plt.xticks(np.arange(39),labels=fraction_sort,rotation=90)
plt.xlabel("Factor")
plt.ylabel("Mean (node + 1) / (parent + 1)")
plt.legend()
plt.show()


In [None]:
from scipy.stats import mannwhitneyu

# Mann-Whitney U test per factor (1..38): young vs old per-mouse ratios (8 vs 8 mice).
# `alternative` is explicit because SciPy's default changed in 1.7; older versions
# silently returned a one-sided p-value, so results depended on the installed version.
ratio_mwus = []

for i in range(1,39):
    mwu = mannwhitneyu(mouse_ratios[i],old_mouse_ratios[i],alternative='two-sided')
    print("++++++++++++++++++")
    print(i)
    print(mwu)
    print("++++++++++++++++++")
    ratio_mwus.append(mwu)

In [None]:
# Factors ordered by ascending Mann-Whitney p-value (element 1 of each result).
# ratio_mwus[k] belongs to factor k+1, hence the +1 to get factor indices.
mwu_sort = np.argsort([rm[1] for rm in ratio_mwus]) + 1

# There is one test per factor 1..38, so iterate over mwu_sort itself. The original
# range(39) ran one past its end and raised IndexError.
for factor_index in mwu_sort:
    print(factor_index)
    print(ratio_mwus[factor_index-1])