In [None]:
from table_info import BiomTable, CSVTable
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display
import test_linear_clustering
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
import pandas as pd
from plot_definitions import is_ms, imsms_plots, imsms_qualitative_class
from coexclusion import counts_to_presence_absence, counts_to_dynamic_presence_absence, presence_absence_to_contingency, pairwise_eval
import os

from test_linear_clustering import recursive_subspacer, iterative_clustering, calc_projected, _fit_a_linear_subspace, _draw_subspace_line, _get_color
from linear_subspace_clustering import linear_subspace_clustering, calc_subspace_bases
from cluster_manager import ClusterManager, ClusterState, RecursiveClusterer, OUTLIER_CLUSTER_ID
import numpy as np
import matplotlib
import json
import scipy
import data_transform as transformer
import math

In [None]:
# --- Metadata and saved-transform locations ---------------------------------
WOLTKA_METADATA_PATH = "./woltka_metadata.tsv"
DATA_TRANSFORM_PATH = "./results/jupyter_interactive_foo.json"

# --- Reference vector tables (BIOM) ------------------------------------------
ALL_SPECIES_VECS = "./dataset/biom/species_vectors.biom"
SHARED_SPECIES_VECS = "./dataset/biom/species_vectors_shared.biom"
AKKER_TABLE = "./dataset/biom/akkermansia_foobar.biom"

# --- Count table(s) to analyse ------------------------------------------------
# Either a single path or a list of paths; a list is loaded table by table and
# the samples are concatenated.  Swap in one of the alternatives to change dataset.
# BIOM_TABLE="./dataset/biom/imsms-combined-none.biom"
# BIOM_TABLE="./dataset/biom/finrisk-combined-none.biom"
# BIOM_TABLE="./dataset/biom/sol_public_99006-none.biom"
# BIOM_TABLE="./dataset/biom/bacteroides_isolates.biom"
# BIOM_TABLE="./dataset/biom/staph_aureus.biom"
# BIOM_TABLE="./dataset/biom/caitriona_matrix_tubes_none.biom"
# BIOM_TABLE="./dataset/biom/sahar_asd.biom"
BIOM_TABLE = [
    "./dataset/biom/Celeste_Prep_1_1428_samples.biom",
    "./dataset/biom/Celeste_Prep_2_672_samples.biom",
    "./dataset/biom/Celeste_Prep_3_936_samples.biom",
    "./dataset/biom/Celeste_Prep_4_792_samples.biom",
]

# --- Analysis knobs -----------------------------------------------------------
# A sample is kept for a genus only if its summed count over that genus'
# reference genomes is at least this value.
MIN_GENUS_COUNT = 500
# When True, every sample (row) is rescaled to a total of 10000 before analysis.
CONSTANT_SUM_SCALE = False

In [None]:
# Genome metadata (one row per reference genome: '#genome', 'genus', 'species', ...).
woltka_meta_table = CSVTable(WOLTKA_METADATA_PATH, delimiter="\t")
woltka_meta_df = woltka_meta_table.load_dataframe()

# BIOM_TABLE may be a single path or a list of paths.
if not isinstance(BIOM_TABLE, list):
    BIOM_TABLE = [BIOM_TABLE]

# Stack the samples (rows) of every table; genomes absent from a table get 0.
all_dfs = []
for biom_table in BIOM_TABLE:
    bt = BiomTable(biom_table)
    all_dfs.append(bt.load_dataframe())

df = pd.concat(all_dfs).fillna(0)

print(df.sum())
# all_df = None
all_species_vecs = BiomTable(ALL_SPECIES_VECS)
all_df = all_species_vecs.load_dataframe()

# Add synthetic rows that average several genome-level vectors under one label.
# NOTE(review): presumably these genomes belong to the same species — confirm.
all_df.loc["intestini",:] = (all_df.loc["G000431295",:] + all_df.loc["G000230275",:]) / 2
all_df.loc["fermentans",:] = (all_df.loc["G000025305",:] + all_df.loc["G900107075",:] + all_df.loc["G900115425",:]) / 3
all_df.loc["CAG:196",:] = (all_df.loc["G000433235",:] + all_df.loc["G001917235",:]) / 2

shared_species_vecs = BiomTable(SHARED_SPECIES_VECS)
shared_df = shared_species_vecs.load_dataframe()

MS_TARGET = None
# if any("imsms" in path for path in BIOM_TABLE):
#     df['target'] = df.index.map(is_ms)
#     MS_TARGET = df['target']
#     df = df.drop(['target'], axis=1)

# Saved bases keyed by genus.  Start empty when the file is missing, unreadable
# or not valid JSON (json.JSONDecodeError is a ValueError); other errors, such
# as KeyboardInterrupt, are no longer swallowed.
try:
    with open(DATA_TRANSFORM_PATH) as infile:
        data_transform = json.load(infile)
except (OSError, ValueError) as err:
    print("Couldn't open data transform, starting new data transform:", err)
    data_transform = {}

In [None]:
# Constant-sum scaled copy: every sample (row) rescaled so its counts total 10000.
df_sum = df.sum(axis=1)
cst = df.div(df_sum, axis=0).mul(10000)

# Optionally analyse the scaled table instead of the raw counts.
df = cst if CONSTANT_SUM_SCALE else df

In [None]:
def list_genera():
    """Return ['all'] followed by the genera that are worth clustering, sorted by name.

    A genus qualifies when more than one of its reference genomes is a column
    of ``df`` and at least 10 samples have a combined count of at least
    ``MIN_GENUS_COUNT`` over that genus' genomes.
    """
    present = woltka_meta_df[woltka_meta_df['#genome'].isin(df.columns)]
    genome_counts = present['genus'].value_counts()
    candidates = sorted(genome_counts[genome_counts > 1].index.astype(str).tolist())

    def _has_enough_samples(genus):
        # Count the samples whose summed genus counts pass the threshold.
        genus_counts = df[list_woltka_refs(genus)['#genome']]
        passing = genus_counts.sum(axis=1) >= MIN_GENUS_COUNT
        return passing.sum() >= 10

    return ['all'] + [genus for genus in candidates if _has_enough_samples(genus)]

def list_woltka_refs(genus=None):
    """List the reference genomes present in ``df``, most abundant first.

    Args:
        genus: restrict to genomes of this genus; None lists every genome in
            ``woltka_meta_df`` that is also a column of ``df``.

    Returns:
        DataFrame with columns ['total', '#genome', 'species'], where 'total'
        is the column sum of that genome in ``df``.  Rows are sorted by total,
        then genome id, both descending.
    """
    woltka_included = woltka_meta_df[woltka_meta_df['#genome'].isin(df.columns)]
    if genus is None:
        refs = woltka_included
    else:
        refs = woltka_included[woltka_included['genus'] == genus]

    col_sums = df[refs['#genome']].sum()
    col_sums.name = 'total'
    refs = refs.join(col_sums, on='#genome').reset_index()

    return refs.sort_values(["total", "#genome"], ascending=False)[['total', '#genome', 'species']]

In [None]:
def plot_scatter3(filtered_df, cluster_manager, title, c1, c2, c3, target_cluster=None, subplot_ax=None, conic_bases=None, genus=None):
    """Draw a 3-D scatter of columns c1/c2/c3 of ``filtered_df`` with fitted subspaces.

    Points are coloured by ``cluster_manager.cluster_state.clusters``.  When the
    cluster state has clusters, subspace bases are recomputed with
    ``calc_subspace_bases``; otherwise bases saved for the genus in the
    module-level ``data_transform`` dict are used, if present.  1-D and >2-D
    bases are drawn as a line along their first vector.  2-D bases are drawn
    as a plane through the origin, or as a line when their projection onto
    these three axes is degenerate.

    Args:
        filtered_df: samples x genomes DataFrame (not modified).
        cluster_manager: object whose ``cluster_state`` provides ``clusters``,
            ``cluster_dims``, ``num_clusters()`` and ``get_cluster_counts()``.
        title: axes title.
        c1, c2, c3: column labels used as the x, y and z axes.
        target_cluster: if given, only that cluster's points are drawn.  When
            clusters exist, their projections are drawn as well.
        subplot_ax: existing 3-D axes to draw into.  When None, a new figure is
            created and shown with ``plt.show()``.
        conic_bases: optional mapping label -> array whose rows follow
            ``filtered_df.columns``.  Each column is drawn as an extra line.
        genus: key for the saved-bases lookup in ``data_transform``.  Defaults
            to ``state['genus']``, the genus selected in the widget.
    """
    from io import StringIO

    filtered_df = filtered_df.copy()
    cs = cluster_manager.cluster_state
    cluster_counts = cs.get_cluster_counts()

    # --- Subspace bases: recomputed from the clustering, or loaded from disk.
    subspace_bases = None
    if cs.num_clusters() > 0:
        subspace_bases = calc_subspace_bases(filtered_df.T.to_numpy(), cs.clusters, cs.cluster_dims)
        subspace_bases = {x: pd.DataFrame(subspace_bases[x], index=filtered_df.columns) for x in subspace_bases}
        all_proj = calc_projected(filtered_df, cs.clusters, subspace_bases)
    else:
        if genus is None:
            genus = state.get('genus')
        if genus in data_transform:
            saved_surfaces = data_transform[genus]
            # Surfaces are DataFrame.to_json() strings (as written by _save).
            # StringIO avoids the deprecated literal-JSON input of pd.read_json.
            subspace_bases = {i: pd.read_json(StringIO(saved_surfaces[i])) for i in range(len(saved_surfaces))}
            if len(subspace_bases) == 0:
                subspace_bases = None

    if conic_bases is not None:
        conic_bases = {x: pd.DataFrame(conic_bases[x], index=filtered_df.columns) for x in conic_bases}

    if subplot_ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
    else:
        ax = subplot_ax

    # Colour range handed to _get_color.  The outlier label -1 is only part of
    # the range when at least one point carries it.  The range is needed
    # whenever a target cluster is drawn, even if there are no bases.
    min_label = -1
    if -1 not in cluster_counts or cluster_counts[-1] == 0:
        min_label = 0
    if subspace_bases is not None:
        max_label = max(subspace_bases)
    else:
        max_label = max(cs.cluster_dims, default=min_label)

    maxx = max(filtered_df[c1])
    maxy = max(filtered_df[c2])
    maxz = max(filtered_df[c3])

    col1_index = filtered_df.columns.get_loc(c1)
    col2_index = filtered_df.columns.get_loc(c2)
    col3_index = filtered_df.columns.get_loc(c3)

    if target_cluster is not None:
        # Zoom the axes to the target cluster's points.
        pidx = np.where(cs.clusters == target_cluster)[0]
        cdf = filtered_df.iloc[pidx]
        maxx = max(cdf[c1])
        maxy = max(cdf[c2])
        maxz = max(cdf[c3])

    # --- Draw subspaces.
    if subspace_bases is not None:
        x = np.linspace(0, maxx, 10)
        y = np.linspace(0, maxy, 10)

        if conic_bases is not None:
            to_draw = conic_bases
            if target_cluster is not None:
                to_draw = [target_cluster]

            for label in to_draw:
                dim = conic_bases[label].shape[1]
                for u_idx in range(dim):
                    u = conic_bases[label].iloc[:,u_idx].loc[[c1,c2,c3]]
                    _draw_subspace_line(ax, u, maxx, maxy, maxz, label, min_label, max_label)

        for label in subspace_bases:
            basis = subspace_bases[label]
            dim = basis.shape[1]
            if dim == 1 or dim > 2:
                # unclear how to draw 3+ d surfaces, so just draw primary axis
                u = basis.iloc[:,0].loc[[c1,c2,c3]]
                _draw_subspace_line(ax, u, maxx, maxy, maxz, label, min_label, max_label)
            elif dim == 2:
                # The basis vectors are orthogonal, but their projections onto
                # these three axes may not be.  They can even be parallel (draw
                # a line, not a plane) or ~0 (nothing to draw).  For example,
                # take u,v = 0,0,1,-1 and 0,0,1,1 projected to the first three
                # dimensions.
                u = basis.iloc[:,0].loc[[c1,c2,c3]]
                v = basis.iloc[:,1].loc[[c1,c2,c3]]
                _draw_subspace_line(ax, u, maxx, maxy, maxz, label, min_label, max_label)
                _draw_subspace_line(ax, v, maxx, maxy, maxz, label, min_label, max_label)

                lenu = np.linalg.norm(u)
                lenv = np.linalg.norm(v)
                if lenu < 0.01 or lenv < 0.01:
                    # At least one projected vector is ~0: there is no plane, and
                    # the remaining line (if any) was already drawn above.
                    continue

                npu = np.array(u) / lenu
                npv = np.array(v) / lenv
                dotproduct = np.dot(npu, npv)
                if dotproduct > 0.9 or dotproduct < -0.9:
                    # Vectors are basically parallel: draw their mean direction.
                    if dotproduct < 0:
                        npu = -npu
                    _draw_subspace_line(ax, (npu + npv) / 2, maxx, maxy, maxz, label, min_label, max_label)
                    continue

                # Normal n gives x*n0 + y*n1 + z*n2 = 0, i.e.
                # z = -n0/n2 * x - n1/n2 * y.
                normal = np.cross(u.T, v.T)
                if abs(normal[2]) < 1e-12:
                    # The plane contains the z axis, so z = f(x, y) is undefined;
                    # only the basis lines drawn above are shown.
                    continue
                X, Y = np.meshgrid(x, y)
                Z = -normal[0] / normal[2] * X - normal[1] / normal[2] * Y
                rgba = _get_color(label, min_label, max_label)
                ax.plot_surface(X, Y, Z, color=(rgba[0], rgba[1], rgba[2], 0.25))

    # --- Draw points.
    if target_cluster is None:
        if MS_TARGET is None:
            ax.scatter(filtered_df[c1], filtered_df[c2], filtered_df[c3], c=cs.clusters, cmap="Set1")
        else:
            filtered_df['target'] = MS_TARGET.loc[filtered_df.index]
            filtered_df['color'] = cs.clusters

            f_on = filtered_df[filtered_df['target'] == True]
            f_off = filtered_df[filtered_df['target'] == False]
            ax.scatter(f_on[c1], f_on[c2], f_on[c3], c=f_on['color'], marker='X', cmap="Set1")
            ax.scatter(f_off[c1], f_off[c2], f_off[c3], c=f_off['color'], marker='o', cmap="Set1")
        # if cs.num_clusters() > 0:
        #     projected = all_proj
        #     ax.scatter(projected[col1_index], projected[col2_index], projected[col3_index], c=cs.clusters, marker='x', cmap='Set1')
    else:
        pidx = np.where(cs.clusters == target_cluster)[0]
        cdf = filtered_df.iloc[pidx]
        rgba = _get_color(target_cluster, min_label, max_label)
        ax.scatter(cdf[c1], cdf[c2], cdf[c3], c=[rgba])
        if cs.num_clusters() > 0:
            projected = all_proj[:,pidx]
            ax.scatter(projected[col1_index], projected[col2_index], projected[col3_index], c=[rgba], marker='x')

    if subspace_bases is not None:
        patches = []
        for label in cs.cluster_dims:
            rgba = _get_color(label, min_label, max_label)
            patch = matplotlib.patches.Patch(
                color=rgba,
                label=str(label) +
                      ": dim=" + str(cs.cluster_dims[label]) +
                      "#fit=" + str(cluster_counts.get(label, 0)))
            patches.append(patch)

    # Debug overlay of reference species vectors, scaled into the plot box.
    # Disabled (empty list) unless one of the alternatives is uncommented.
    # _vec_srcs = [(shared_df, 'r'), (all_df, 'b')]
    # _vec_srcs = [(all_df, 'b')]
    _vec_srcs = []
    for _vec_src, color in _vec_srcs:
        if _vec_src is not None:
            print("PLOT ALL_DF")
            if c1 in _vec_src.columns and c2 in _vec_src.columns and c3 in _vec_src.columns:
                # Argh duplicates.
                col_list = [c1]
                if c2 not in col_list:
                    col_list.append(c2)
                if c3 not in col_list:
                    col_list.append(c3)

                print("IM PLOTTING AHHH")
                likely_candidates = _vec_src[_vec_src[col_list].sum(axis=1) > 5000][col_list]
                print(likely_candidates)
                for index, row in likely_candidates.iterrows():
                    scale_factor = float('inf')
                    if row[c1] != 0:
                        scale_factor = min(scale_factor, maxx/row[c1])
                    if row[c2] != 0:
                        scale_factor = min(scale_factor, maxy/row[c2])
                    if row[c3] != 0:
                        scale_factor = min(scale_factor, maxz/row[c3])
                    if scale_factor == float('inf'):
                        scale_factor = 1
                    ax.scatter(row[c1] * scale_factor,
                               row[c2] * scale_factor,
                               row[c3] * scale_factor, c=color, marker='+', s=500)
                    ax.text(row[c1] * scale_factor, row[c2] * scale_factor, row[c3] * scale_factor, index, color='red')

    # Title/legend go to `ax` itself; plt.title/plt.legend would target the
    # current axes, which need not be `subplot_ax`.
    ax.set_title(title)
    xseries = woltka_meta_df[woltka_meta_df["#genome"] == c1]["species"]
    yseries = woltka_meta_df[woltka_meta_df["#genome"] == c2]["species"]
    zseries = woltka_meta_df[woltka_meta_df["#genome"] == c3]["species"]

    if len(xseries) > 0 and len(yseries) > 0 and len(zseries) > 0:
        # str(): species may be missing (NaN) in the metadata.
        xlabel = "\n" + str(xseries.iloc[0]) + "\n" + str(c1)
        ylabel = "\n" + str(yseries.iloc[0]) + "\n" + str(c2)
        zlabel = "\n" + str(zseries.iloc[0]) + "\n" + str(c3)
    else:
        xlabel = str(c1)
        ylabel = str(c2)
        zlabel = str(c3)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel(zlabel)
    ax.set_xlim([0, maxx * 1.5 + 100])
    ax.set_ylim([0, maxy * 1.5 + 100])
    ax.set_zlim([0, maxz * 1.5 + 100])

    if subspace_bases is not None:
        ax.legend(handles=patches)

    if subplot_ax is None:
        plt.show()

In [None]:
def merge_similar_vectors(vectors, weights, same_thresh_degrees = 5):
    """Collapse nearly parallel column vectors into weighted-average representatives.

    Clustering can return several almost identical species vectors.  Keeping
    them all makes the later non-negative fit numerically unstable, so they
    are merged here.

    Args:
        vectors: 2-D array whose columns are the candidate vectors.  It is not
            modified, and integer input is accepted.
        weights: one non-negative weight per column (e.g. cluster sizes).
        same_thresh_degrees: two columns are "the same" when the angle between
            them is below this, after L2 normalisation and orientation to a
            non-negative sum.  Anti-parallel columns also count.  Sameness is
            transitive: a chain of pairwise-similar columns becomes one group.

    Returns:
        2-D array with one column per group, ordered by each group's smallest
        column index.  Each column is the weight-averaged unit vector of the
        group, L1-normalised; an all-zero result is left as zeros.
    """
    def _find_root(parent, i):
        # Union-find lookup; an index absent from `parent` is its own root.
        root = i
        while parent.get(root, root) != root:
            root = parent[root]
        return root

    # How to pick same_thresh:
    # 1 degree off: dot product .9998
    # 5 degrees off: dot product .996
    # 15 degrees off: dot product .966
    # x degrees off: dot product = cos(x * pi/180)
    dot_thresh = np.cos(same_thresh_degrees * math.pi / 180)

    # L2 normalize (float copy) and make every vector as positive as possible.
    vectors = np.array(vectors, dtype=float)
    for col in range(vectors.shape[1]):
        l2 = np.linalg.norm(vectors[:, col])
        if l2 > 0:
            vectors[:, col] /= l2
        if np.sum(vectors[:, col]) < 0:
            vectors[:, col] = -vectors[:, col]

    # Group nearly parallel/anti-parallel vectors.
    print("Basis Shape:", vectors.shape)
    n_vecs = vectors.shape[1]
    parent = {}
    for i in range(n_vecs):
        for j in range(i + 1, n_vecs):
            dp = np.dot(vectors[:, i], vectors[:, j])
            if dp > dot_thresh or dp < -dot_thresh:
                print("Bases: ", i, j, " are nearly identical")
                a = _find_root(parent, i)
                b = _find_root(parent, j)
                # Link the two ROOTS (smaller index wins) so groups formed by
                # earlier pairs are joined rather than split apart.
                parent[max(a, b)] = min(a, b)
                parent.setdefault(min(a, b), min(a, b))

    final_groups = {}
    for i in range(n_vecs):
        final_groups.setdefault(_find_root(parent, i), set()).add(i)

    print(final_groups)

    weights = np.asarray(weights, dtype=float)
    out_vecs = []
    for members in final_groups.values():
        idx = sorted(members)
        # Weighted average of the group's unit vectors...
        total_vec = vectors[:, idx] @ weights[idx]
        total_weight = weights[idx].sum()
        if total_weight > 0:
            total_vec = total_vec / total_weight
        # ...then L1 normalized.
        l1 = np.sum(np.abs(total_vec))
        if l1 > 0:
            total_vec = total_vec / l1
        out_vecs.append(total_vec)

    return np.stack(out_vecs, axis=-1)
 

In [None]:
# Genus selector; its options are computed once from the loaded table.
genus_widget = widgets.Dropdown(options=list_genera())

# Genome pickers for the x / y / z axes of the 3-D scatter (one Layout each).
axis1, axis2, axis3 = (widgets.Dropdown(layout=widgets.Layout(width='50%')) for _ in range(3))

# Action buttons and the shared output area.
plot, calculate, save, fix, plot_fix, plot_fix_resids = (
    widgets.Button(description=label)
    for label in ('Plot', 'Calculate', 'Save', 'Fix', 'Plot Fix', 'Plot Fix Resids')
)
output = widgets.Output()

# Mutable session state shared by the widget callbacks.
state = dict()

def filter_df(df, genus):
    """Select a genus' reference-genome columns of ``df`` and its well-covered samples.

    Args:
        df: samples x genomes table.  It may be a different table than the
            module-level one, e.g. another cohort.
        genus: genus name, or 'all' for every reference genome.

    Returns:
        The genus' columns that exist in ``df``, keeping only rows whose summed
        count is at least ``MIN_GENUS_COUNT``.  For 'all' with no reference
        genomes at all (e.g. a 16S table), a full copy of ``df`` is returned.
    """
    refs_df = list_woltka_refs() if genus == 'all' else list_woltka_refs(genus)
    genomes = refs_df['#genome'].tolist()
    if len(genomes) == 0 and genus == 'all':
        print("Probably 16S, falling back to raw table")
        return df.copy()

    # list_woltka_refs derives genomes from the module-level table; keep only
    # those that are also columns here, otherwise df[genomes] raises KeyError.
    present = set(df.columns)
    genomes = [genome for genome in genomes if genome in present]

    filtered_df = df[genomes]
    return filtered_df[filtered_df.sum(axis=1) >= MIN_GENUS_COUNT]


def update_genus(*args):
    """Refresh the axis dropdowns and the shared ``state`` for the selected genus.

    Registered as an observer of ``genus_widget``'s value (``*args`` receives
    the change event, which is ignored) and also called directly once.

    Side effects:
      * repopulates axis1/axis2/axis3 with "species(genome)" choices, in the
        order returned by ``list_woltka_refs`` (descending total count), or by
        column total for a 16S table, and preselects the first three genomes;
      * resets ``state`` to the new genus, its filtered table, a fresh
        ClusterManager and any saved bases ("breakdown");
      * enables ``calculate`` only with at least two reference genomes and
        more than 10 rows in the filtered table.
    """
    genus = genus_widget.value
    refs_df = list_woltka_refs() if genus == 'all' else list_woltka_refs(genus)
    genomes = refs_df['#genome'].tolist()
    species = refs_df['species'].tolist()

    if genus == 'all' and len(genomes) == 0:
        print("Probably a 16S table, falling back to sorted columns")
        genomes = list(df.sum().sort_values(ascending=False).index)
        species = list(genomes)

    # str(): species can be missing (NaN) in the metadata and 16S column labels
    # need not be strings; plain concatenation would raise TypeError.
    choices = [(str(sp) + "(" + str(genome) + ")", genome) for sp, genome in zip(species, genomes)]
    for axis in (axis1, axis2, axis3):
        axis.options = choices
    if genomes:
        # Preselect up to three genomes, repeating the last one if fewer exist.
        last = len(genomes) - 1
        axis1.value = genomes[0]
        axis2.value = genomes[min(1, last)]
        axis3.value = genomes[min(2, last)]

    filtered_df = filter_df(df, genus)

    state.clear()
    state["genus"] = genus
    state["filtered_df"] = filtered_df
    state["cluster_manager"] = ClusterManager(filtered_df, filtered_df.shape[1])
    state["breakdown"] = data_transform.get(genus)

    with output:
        output.clear_output()
        if len(genomes) < 2:
            print(genus, "Not enough reference genomes to do clustering\n")
            calculate.disabled = True
        elif len(filtered_df) <= 10:
            print(genus, "Not enough reads to do clustering\n")
            calculate.disabled = True
        else:
            calculate.disabled = False

# Keep axis dropdowns and `state` in sync with the genus selector, and
# initialise them for the default selection.
genus_widget.observe(update_genus, names='value')
update_genus()

# Render the controls (IPython's display shows each argument in turn).
display(genus_widget, axis1, axis2, axis3,
        plot, calculate, save, fix,
        output, plot_fix, plot_fix_resids)

def plot_click(button):
    """'Plot' callback: echo the chosen axes as a snippet, then draw the 3-D scatter.

    The printed line has the form ``"<genus>":["<x>", "<y>", "<z>"],`` and
    can be pasted into a genus -> axes mapping.
    """
    genus = state["genus"]
    x_col, y_col, z_col = axis1.value, axis2.value, axis3.value
    print('"' + genus + '":["' + x_col + '", "' + y_col + '", "' + z_col + '"],')
    plot_scatter3(state["filtered_df"], state.get("cluster_manager"), genus, x_col, y_col, z_col)
    
def plot_fix_click(button):
    """'Plot Fix' callback: scatter the first three SV columns of the fixed table.

    Does nothing until 'Fix' has stored ``state["fixed_df"]``.  With fewer
    than three columns, the last column is reused for the remaining axes.
    """
    fixed_df = state.get("fixed_df")
    if fixed_df is not None:
        last_col = fixed_df.shape[1] - 1
        x_col = fixed_df.columns[0]
        y_col = fixed_df.columns[min(1, last_col)]
        z_col = fixed_df.columns[min(2, last_col)]
        plot_scatter3(fixed_df, state.get("cluster_manager"), state["genus"], x_col, y_col, z_col)
    
def _plot_fix_resids(ax, genus, l1_resids, vlines=None):
    # sort the data in ascending order
    x = np.sort(l1_resids)

    # get the cdf values of y
    N = len(l1_resids)
    y = np.arange(N) / float(N)

    # plotting
    ax.set_xlabel('x-axis')
    ax.set_ylabel('y-axis')
    ax.set_title(genus + 'L1 Residuals CDF')
    ax.plot(x, y, marker='o')
    ax.axhline(y=0.8, color='gray', linestyle=':')
    ax.axhline(y=0.9, color='gray', linestyle=':')
    if vlines is not None:
        for vline in vlines:
            ax.axvline(x=vline, color='gray', linestyle=":")
        
def plot_fix_resids_click(button):
    """'Plot Fix Resids' callback: side-by-side residual CDFs (absolute | relative).

    Does nothing until 'Fix' has stored ``state["l1_resids"]``.  The relative
    panel marks 5% and 10% with vertical guides.
    """
    resids = state.get("l1_resids")
    if resids is None:
        return

    genus = state.get("genus")
    fig, (abs_ax, pct_ax) = plt.subplots(1, 2, sharey=True)
    _plot_fix_resids(abs_ax, genus, resids)
    _plot_fix_resids(pct_ax, genus, state.get("l1_resid_pcts"), vlines=[0.05, 0.10])
    plt.show()
    

def _calculate(to_cluster, genus):
    num_dims = 10 # Beware, this is completely dataset dependent!
    if genus != 'all':
        num_dims = to_cluster.shape[1]
    cm = ClusterManager(to_cluster, num_dims)
    rc = RecursiveClusterer()
    rc.run(cm)
    return {"cluster_manager": cm, "clusterer": rc}

def calculate_click(button):
    """'Calculate' callback: cluster the current genus and merge the results into ``state``."""
    state.update(_calculate(to_cluster=state["filtered_df"], genus=state["genus"]))
    
def _save(genus, filtered_df, cm, data_path):
    if cm == None:
        print("No clusters to save")
        return

    cs = cm.cluster_state
    subspace_bases = calc_subspace_bases(filtered_df.T.to_numpy(), cs.clusters, cs.cluster_dims)

    surfaces = []
    for key in subspace_bases:
        basis = subspace_bases[key]
        basis_df = pd.DataFrame(basis, index=filtered_df.columns)
        print(basis_df)
        surfaces.append(basis_df.to_json())
    
    data_transform[genus] = surfaces
    only_genus = {genus: surfaces}

    with open(data_path, 'w') as outfile:
        print("Saving to: " + data_path)
        json.dump(only_genus, outfile)
    
def save_click(button):
    """'Save' callback: persist the current genus' cluster bases to DATA_TRANSFORM_PATH."""
    _save(
        genus=state["genus"],
        filtered_df=state["filtered_df"],
        cm=state["cluster_manager"],
        data_path=DATA_TRANSFORM_PATH,
    )

def _fix(cm, filtered_df):
    if cm is None:
        print("No clusters to fix!")
        return {}
    cs = cm.cluster_state
    subspace_bases = calc_subspace_bases(filtered_df.T.to_numpy(), cs.clusters, cs.cluster_dims)
    cluster_counts = cs.get_cluster_counts()
    
    if len(subspace_bases) == 0:
        print("Could not identify species vectors")
        return {}
    # L1 normalize subspace bases.
    for c_id in subspace_bases:
        for col in range(subspace_bases[c_id].shape[1]):
            l1_len = np.sum(np.abs(subspace_bases[c_id][:,col]))
            if l1_len != 0:
                subspace_bases[c_id][:,col] = subspace_bases[c_id][:,col] / l1_len
            if np.sum(subspace_bases[c_id][:,col]) < 0:
                subspace_bases[c_id][:,col] = -subspace_bases[c_id][:,col]

    # TODO FIXME HACK:  Find a way to include SVs from all dims of higher dim clusters rather than just the first
    full_basis = np.stack([subspace_bases[c][:,0] for c in subspace_bases], axis=1)
    basis_counts = [cluster_counts[c] for c in subspace_bases]
    full_basis = merge_similar_vectors(full_basis, basis_counts, same_thresh_degrees=5)
    
    out_pts = []
    l1_resids = []
    l1_resid_pcts = []
    max_resid = 0
    max_resid_pct = 0
    for i in range(filtered_df.shape[0]):
        pt = filtered_df.iloc[i].T.to_numpy()
        output_pt, l2_resid = scipy.optimize.nnls(full_basis, pt)
        nnls_proj = np.matmul(full_basis, output_pt)
        l1_resid = np.sum(np.abs(pt - nnls_proj))

        output_pt = pd.DataFrame(output_pt).T
        output_pt.index = [filtered_df.index[i]]

        l1_resids.append(l1_resid)
        length = np.sum(np.abs(pt))
        l1_resid_pct = l1_resid / length
        l1_resid_pcts.append(l1_resid_pct)
        
            
        max_resid = max(max_resid, abs(l1_resid))
        if pt.sum() > 0: 
            max_resid_pct = max(max_resid_pct, abs(l1_resid)/pt.sum())

        out_pts.append(output_pt)
    
    output_df = pd.concat(out_pts)
    output_df = output_df.fillna(0)
    output_df.columns = list(["SV"+str(i) for i in range(len(output_df.columns))])
    
    return {"fixed_df": output_df, "l1_resids": l1_resids, "l1_resid_pcts": l1_resid_pcts}
    
def fix_click(button):
    """'Fix' callback: refit the current genus with _fix and merge the results into ``state``."""
    state.update(_fix(state.get("cluster_manager"), state["filtered_df"]))
    
# Connect each button to its callback.
for _button, _callback in (
    (plot, plot_click),
    (calculate, calculate_click),
    (save, save_click),
    (fix, fix_click),
    (plot_fix, plot_fix_click),
    (plot_fix_resids, plot_fix_resids_click),
):
    _button.on_click(_callback)
del _button, _callback

In [None]:
def bulk_calculate(prefix):
    """Cluster every eligible genus and save its bases under ./results/<prefix>/.

    For each genus from ``list_genera()`` except 'all': runs ``_calculate`` on
    the genus table and checks with ``_fix`` that species vectors can be
    built.  If so, writes the bases to ``./results/<prefix>/<genus>.json`` via
    ``_save``; otherwise prints SKIP.
    """
    out_dir = os.path.join(".", "results", prefix)
    # The per-genus files are opened directly; make sure their directory exists.
    os.makedirs(out_dir, exist_ok=True)

    for genus in list_genera():
        if genus == "all":
            continue
        print("Genus:", genus)
        filtered_df = filter_df(df, genus)

        cm = _calculate(filtered_df, genus)["cluster_manager"]

        ns = _fix(cm, filtered_df)
        if len(ns) == 0:
            # no species vectors
            print("SKIP")
            continue

        # Optional residual diagnostics:
        # fig, axs = plt.subplots(1, 2, sharey=True)
        # _plot_fix_resids(axs[0], genus, ns["l1_resids"])
        # _plot_fix_resids(axs[1], genus, ns["l1_resid_pcts"], vlines=[0.05, 0.10])
        # plt.show()

        _save(genus, filtered_df, cm, os.path.join(out_dir, genus + ".json"))
# bulk_calculate("foo")

In [None]:
def bulk_compare(prefix):
    """Score the saved per-genus bases in ./results/<prefix>/<genus>.json.

    For every genus from ``list_genera()`` except 'all', the saved bases are
    applied to the genus table and the samples are refit with ``_fix``.  The
    outcome is classified by the 80th percentile of the relative L1 residuals
    (x_80) and by the number of species vectors versus genomes:
      * failed          - no file, no species vectors, or no samples;
      * success         - x_80 < 0.05 and no more species vectors than genomes;
      * bad result      - x_80 >= 0.05;
      * increased dim   - x_80 < 0.05 but more species vectors than genomes.
    Prints per-genus details and a summary.  The summary includes the columns
    dropped as a percentage of all genome columns evaluated, failed genera
    included.
    """
    calculation_failed = 0
    calculation_success = 0
    calculation_bad_result = 0
    increased_dim = 0
    columns_evaluated = 0
    columns_dropped = 0

    for genus in list_genera():
        if genus == "all":
            continue
        print("Genus:", genus)
        file = os.path.join(".", "results", prefix, genus + ".json")
        filtered_df = filter_df(df, genus)

        initial_dim = filtered_df.shape[1]
        columns_evaluated += initial_dim

        if not os.path.exists(file):
            print("File does not exist, calculation failed.")
            calculation_failed += 1
            continue

        cm = ClusterManager.apply_bases_from_file(file, filtered_df)
        cm.finalize()

        ns = _fix(cm, filtered_df)
        if len(ns) == 0:
            # no species vectors
            print("No Species Vectors Were Calculated")
            calculation_failed += 1
            continue

        fixed_df = ns["fixed_df"]
        # Percentiles must be read from the SORTED residuals.
        sorted_pcts = np.sort(ns["l1_resid_pcts"])
        if len(sorted_pcts) == 0:
            print("No samples to evaluate, calculation failed.")
            calculation_failed += 1
            continue
        x_80 = sorted_pcts[int(.8 * len(sorted_pcts))]
        x_90 = sorted_pcts[int(.9 * len(sorted_pcts))]

        print(x_80, x_90)

        final_dim = fixed_df.shape[1]

        if x_80 < 0.05 and final_dim <= initial_dim:
            print("Success")
            calculation_success += 1
            print("Dropping", initial_dim - final_dim, "columns.")
            columns_dropped += initial_dim - final_dim
        elif x_80 >= 0.05:
            print("Bad Result")
            calculation_bad_result += 1
        elif final_dim > initial_dim:
            increased_dim += 1

        # Optional residual diagnostics:
        # fig, axs = plt.subplots(1, 2, sharey=True)
        # _plot_fix_resids(axs[0], genus, ns["l1_resids"])
        # _plot_fix_resids(axs[1], genus, ns["l1_resid_pcts"], vlines=[0.05, 0.10])
        # plt.show()

    print("Success", calculation_success)
    print("Fail", calculation_failed)
    print("Bad Result", calculation_bad_result)
    print("Bad Result (Increased Dim)", increased_dim)
    reduction_pct = int((columns_dropped / columns_evaluated) * 100) if columns_evaluated else 0
    print("\nTotal Dimensionality Reduction:", columns_dropped, str(reduction_pct) + "%")

# bulk_compare("foo")

In [None]:
def cluster_run_click(button):
    """Run the text command in ``cluster_command`` against the current cluster state.

    Commands (whitespace separated; ids and counts are integers):
      split <cluster_id> <pieces>         cm.get_split_cluster(...).apply()
      merge <id_a> <id_b> <final_dim>     cm.get_merge_clusters(...).apply()
      delete|del <cluster_id>             cm.get_delete_cluster(...).apply()
      reassign                            cm.get_reassign_nearest(...).apply()
      prevalence                          share of samples kept for the genus,
                                          plus pairwise_eval on presence/absence
      info [cluster_id]                   genome with the largest |loading| in
                                          each basis vector of the cluster(s)
      diff <id_a> <id_b>                  angle between the clusters' first
                                          basis vectors, and largest entry diff
      taxi                                not implemented
      load <path>                         replace state["cluster_manager"] with
                                          bases loaded from a file

    Structural edits call ``cm.finalize()`` afterwards.  Empty or malformed
    commands print a message instead of raising inside the widget callback.
    """
    tokens = cluster_command.value.split()
    if not tokens:
        print("Unsupported Command")
        return
    command = tokens[0]
    f_df = state["filtered_df"]
    cm = state["cluster_manager"]
    cs = cm.cluster_state

    def _int_args(count, usage):
        # tokens[1:count+1] as ints, or None (after printing usage) if missing/invalid.
        try:
            if len(tokens) < count + 1:
                raise ValueError("missing argument")
            return [int(tok) for tok in tokens[1:count + 1]]
        except ValueError:
            print("Usage:", usage)
            return None

    if command == 'split':
        args = _int_args(2, "split <cluster_id> <pieces>")
        if args is None:
            return
        cm.get_split_cluster(args[0], args[1]).apply()
        cm.finalize()
    elif command == 'merge':
        args = _int_args(3, "merge <cluster_id_a> <cluster_id_b> <final_dim>")
        if args is None:
            return
        cm.get_merge_clusters(args[0], args[1], args[2]).apply()
        cm.finalize()
    elif command in ('delete', 'del'):
        args = _int_args(1, "delete <cluster_id>")
        if args is None:
            return
        cm.get_delete_cluster(args[0]).apply()
        cm.finalize()
    elif command == 'reassign':
        cm.get_reassign_nearest(dim_penalty=10, outlier_thresh=0.10).apply()
        cm.finalize()
    elif command == 'prevalence':
        print(df.shape)
        print(f_df.shape)
        print("Approximate " + state["genus"] + " prevalence")
        print(str((f_df.shape[0] / df.shape[0]) * 100) + "%")
        threshold_df = counts_to_presence_absence(df[f_df.columns], 500)
        pairwise_eval(threshold_df)
    elif command == 'info':
        if len(tokens) >= 2:
            to_run = _int_args(1, "info [cluster_id]")
            if to_run is None:
                return
        else:
            to_run = range(cm.max_cluster_id())

        # The bases depend only on the current cluster state: compute them once.
        subspace_bases = calc_subspace_bases(f_df.T.to_numpy(), cs.clusters, cs.cluster_dims)
        for cluster_id in to_run:
            if cluster_id not in subspace_bases:
                print("No basis for cluster", cluster_id)
                continue
            basis_df = pd.DataFrame(subspace_bases[cluster_id], index=f_df.columns)

            # Genome with the largest absolute loading in each basis vector.
            subspace = basis_df.abs().idxmax()

            pdf = pd.DataFrame(data=subspace, columns=["#genome"])
            pdf = pdf.merge(woltka_meta_df, on="#genome")

            print("Approximate Name(s)", cluster_id)
            print(pdf[["#genome", "species"]])
    elif command == 'diff':
        args = _int_args(2, "diff <cluster_id_a> <cluster_id_b>")
        if args is None:
            return

        subspace_bases = calc_subspace_bases(f_df.T.to_numpy(), cs.clusters, cs.cluster_dims)
        b1 = subspace_bases[args[0]][:, 0]
        b2 = subspace_bases[args[1]][:, 0]

        # Orient both to a non-negative sum, then L2 normalize.
        if np.sum(b1) < 0:
            b1 = -b1
        if np.sum(b2) < 0:
            b2 = -b2
        b1 = b1 / np.linalg.norm(b1)
        b2 = b2 / np.linalg.norm(b2)

        dot = np.dot(b1, b2)
        # Rounding can push |dot| slightly above 1, where arccos returns NaN.
        theta = np.arccos(np.clip(dot, -1.0, 1.0))
        print("Dot Product:", dot, "Diff Angle Theta (Degrees):", theta * 180/math.pi)
        diffs = np.abs(b2 - b1)
        max_diff = np.argmax(diffs)
        print(f_df.columns[max_diff], diffs[max_diff], b1[max_diff], b2[max_diff])
    elif command == 'taxi':
        # TODO: unfinished.  The earlier draft looped over cs.cluster_dims and
        # meant to add an "AX<cluster_id>" column per 1-D cluster to a copy of
        # f_df.  It never assigned anything and raised KeyError/NameError.
        print("taxi: not implemented")
    elif command == "load":
        if len(tokens) < 2:
            print("Usage: load <path>")
            return
        # ./results/celeste_ecoli_many.json
        # TODO FIXME HACK:  Should I apply it to df or filtered df?
        cm = ClusterManager.apply_bases_from_file(tokens[1], f_df)
        cm.finalize()
        state["cluster_manager"] = cm
    else:
        print("Unsupported Command")

# Free-form command box for manual cluster edits, handled by cluster_run_click.
cluster_command = widgets.Text(placeholder='split 1 2', layout=widgets.Layout(width='80%'))
cluster_run = widgets.Button(description='Run Command')

display(cluster_command, cluster_run)

cluster_run.on_click(cluster_run_click)


In [None]:
# Inspect the cluster label assigned to each sample in the current widget state.
current_cluster_state = state["cluster_manager"].cluster_state
current_cluster_state.clusters

In [None]:
def bulk_plot(df, genlist=None, suptitle=None, min_sample_pass=None):
    """Draw 3D scatter plots for many genera, six panels (a 2x3 grid) per figure.

    For each genus, ``filter_df(df, genus)`` selects the rows to plot. The share of
    rows retained is printed alongside the genus' qualitative class. The subset is
    then drawn with ``plot_scatter3`` on the three columns named in
    ``imsms_plots[genus]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table that is filtered per genus; ``df.shape[0]`` is the denominator of the
        printed pass percentage.
    genlist : list of str, optional
        Genera to plot; defaults to ``list_genera()``. The entry ``'all'`` is skipped.
    suptitle : str, optional
        Title applied to every figure that is created.
    min_sample_pass : float, optional
        When given, genera whose pass percentage is below this value are not plotted.
        They are counted in the total printed at the end. ``None`` (the default)
        plots every genus.

    Returns
    -------
    None
    """
    panels_per_figure = 6  # must match the 2x3 add_subplot grid below

    if genlist is None:
        genlist = list_genera()
    print(len(genlist))

    i = 0  # 0-based slot of the next panel in the current figure
    fig = None
    skipped = 0  # genera dropped by the min_sample_pass filter
    for genus in genlist:
        if genus == 'all':
            continue

        filtered_df = filter_df(df, genus)
        # An empty input table would otherwise raise ZeroDivisionError here.
        if df.shape[0] > 0:
            sample_pass = round((filtered_df.shape[0] / df.shape[0]) * 100, 2)
        else:
            sample_pass = 0.0
        print(genus + ", " + str(sample_pass) + "%" + ", " + imsms_qualitative_class[genus][0])

        if min_sample_pass is not None and sample_pass < min_sample_pass:
            skipped += 1
            continue

        if i == 0:
            fig = plt.figure()
            if suptitle is not None:
                fig.suptitle(suptitle)
        ax = fig.add_subplot(2, 3, i + 1, projection='3d')
        i += 1

        # A qualitative class of "YES" is labelled co-exclusive; anything else co-occurring.
        if imsms_qualitative_class[genus][0] == "YES":
            classification = "\n(Co-Exclusive)"
        else:
            classification = "\n(Co-Occurring)"
        plot_scatter3(
            filtered_df,
            ClusterManager(filtered_df, filtered_df.shape[1]),
            genus + ": " + str(sample_pass) + "%" + classification,
            imsms_plots[genus][0],
            imsms_plots[genus][1],
            imsms_plots[genus][2],
            subplot_ax=ax
        )

        if i == panels_per_figure:
            i = 0
            plt.show()

    # Show the last figure if it was only partially filled. Previously only full
    # 6-panel figures were shown explicitly.
    if i != 0:
        plt.show()
    print(skipped)


In [None]:
# df_imsms = BiomTable("./dataset/biom/imsms-combined-none.biom").load_dataframe()
# df_finrisk = BiomTable("./dataset/biom/finrisk-combined-none.biom").load_dataframe()
# df_sol = BiomTable("./dataset/biom/sol_public_99006-none.biom").load_dataframe()

#bulk_plot(df)

# bulk_plot(df_imsms, ["Akkermansia", "Butyricicoccus", "Dialister", "Eggerthella", "Methanobrevibacter", "Ruminiclostridium"], "iMSMS")
# bulk_plot(df_finrisk, ["Akkermansia", "Butyricicoccus", "Dialister", "Eggerthella", "Methanobrevibacter", "Ruminiclostridium"], "FINRISK")
# bulk_plot(df_sol, ["Akkermansia", "Butyricicoccus", "Dialister", "Eggerthella", "Methanobrevibacter", "Ruminiclostridium"], "SOL")
# plt.show()

In [None]:
# Median over samples of each row's total count in the currently filtered table.
state["filtered_df"].sum(axis="columns").median()

In [None]:
# Push the interactively fitted clustering out to the full table `df`, then build a
# ClusterManager over the full data that carries the new per-row assignments.
cm = state["cluster_manager"]

# Assign every row of `df` a cluster id using the active clustering.
full_clusters = cm.apply_active_clustering(df, outlier_thresh=.1, dim_penalty=5)
print(full_clusters.value_counts())

# Same fitted cluster state, but with the cluster assignments replaced by the
# ones just computed for the full table.
full_df = df[cm.data_cols]
full_cm = ClusterManager(full_df, len(cm.data_cols))
full_cm.cluster_state = cm.cluster_state.copy()
full_cm.cluster_state.clusters = full_clusters.to_numpy()

# plot_scatter3(full_df, full_cm, "TEST", full_df.columns[0],full_df.columns[1],full_df.columns[2])

In [None]:
# Convert each cluster's linear subspace into a conic polyhedron, i.e. a set of
# generating vectors whose non-negative combinations cover the cluster's points.
# The data are noisy, so there is no exact construction; the idea is:
# - Points on a 1D linear subspace give an obvious basis vector directly.
# - Points on a 2D linear subspace map to a 1D line segment on the simplex; the two
#   endpoints of that segment give two basis vectors.
# - Points on a 3D linear subspace map to a 2D polygon on the simplex. The "best"
#   simplification of that polygon to a triangle gives three basis vectors. The
#   triangle with the largest area is a likely candidate, though other algorithms
#   could be tried.
# - In 4D the points form a 3D polyhedron on the simplex; the best simplification of
#   that to a tetrahedron gives the conic polyhedron.
# - In 5D they form a 4D polytope on the 5D simplex, and so on.
#
# NOTE: Finding the largest n-simplex in an (n-1)-dimensional convex hull is NP-hard
# and inapproximable:
# https://stackoverflow.com/questions/50049658/largest-simplex-in-convex-hull-of-points-in-n-dimensions
# so we fall back to brute force or a heuristic.
cm = state["cluster_manager"]
cs = cm.cluster_state
filtered_df = state["filtered_df"]
# One basis matrix per cluster; its column count is used below as the subspace dimension.
subspace_bases = calc_subspace_bases(filtered_df.T.to_numpy(), cs.clusters, cs.cluster_dims)
# Projected points; columns are selected per cluster with a boolean sample mask below,
# so presumably shape (n_features, n_samples) -- TODO confirm against calc_projected.
all_proj = calc_projected(filtered_df, cs.clusters, subspace_bases)

# cluster_id -> matrix whose columns generate the cluster's cone.
# NOTE(review): only 1D, 2D and 3D subspaces get an entry; clusters of dim > 3 are
# silently left out of conic_bases.
conic_bases = {}
for cluster_id in subspace_bases:
    if subspace_bases[cluster_id].shape[1] == 1:
        # 1D: the subspace basis vector itself (its sign is used as returned by
        # calc_subspace_bases; no sign flip is applied here).
        conic_bases[cluster_id] = subspace_bases[cluster_id]
    else:
        print("Cluster ID:", cluster_id, "Dim:", cs.cluster_dims[cluster_id])

        # Grab points in the cluster. Boolean-mask indexing returns a copy, so the
        # in-place normalization below does not modify all_proj.
        idx = cs.clusters == cluster_id
        cluster_proj = all_proj[:,idx]
        print(cluster_proj.shape)

        # Push them out to the simplex: divide each point by its (signed) sum, which is
        # its L1 norm when the entries are non-negative. All-zero-sum points are left as is.
        for i in range(cluster_proj.shape[1]):
            c_pt = cluster_proj[:,i]
            if c_pt.sum() != 0:
                c_pt = c_pt/c_pt.sum()
            cluster_proj[:,i] = c_pt

        # Coordinates of each point along the subspace basis, one row per point.
        M = subspace_bases[cluster_id]
        cluster_proj_low_d = np.matmul(M.T, cluster_proj).T

        # Since everything has been pushed to the simplex, qhull's convex hull would
        # complain that the points are too low-dimensional. Adding a point at the
        # origin (as the last row) gives it a full-dimensional point set.

        cluster_proj_low_d = np.vstack([cluster_proj_low_d, np.zeros(cluster_proj_low_d.shape[1])])
        print(cluster_proj_low_d.shape)
        hull = scipy.spatial.ConvexHull(cluster_proj_low_d)

        print(hull.vertices)
        # Drop the artificial origin (last row) from the hull vertices.
        hull_verts = hull.vertices[hull.vertices != cluster_proj_low_d.shape[0]-1]
        print(hull_verts)
        # Diagnostic plot of the first two low-d coordinates: hull vertices in red,
        # hull facets as black segments (a 2D shadow when dim is 3).
        plt.plot(cluster_proj_low_d[hull_verts,0], cluster_proj_low_d[hull_verts,1], 'ro')
        for simplex in hull.simplices:
            plt.plot(cluster_proj_low_d[simplex, 0], cluster_proj_low_d[simplex, 1], 'k-')
        plt.show()

        if subspace_bases[cluster_id].shape[1] == 2:
            # 2D: expect exactly two hull vertices besides the origin (the segment endpoints).
            if len(hull_verts) != 2:
                print("Dangit hull verts")
                print(hull_verts)
                print(all_proj[:,idx][:, hull_verts])
                raise Exception("Aww I thought it was working, check coplanar points in convex hull?")

            # Re-slice from all_proj to get the *unnormalized* projected points.
            idx = cs.clusters == cluster_id
            cluster_proj = all_proj[:,idx]

            # NOTE(review): 2D conic bases are unnormalized projected points, while the
            # 3D branch below uses the simplex-normalized cluster_proj. Confirm which is intended.
            conic_bases[cluster_id] = cluster_proj[:, hull_verts]
            print(conic_bases[cluster_id])
            print(subspace_bases[cluster_id])

        if subspace_bases[cluster_id].shape[1] == 3:
            # 3D: find the largest triangle among the hull vertices. This is at least
            # more plausible in 3D than in 4+. A possible optimization is described here:
            # https://stackoverflow.com/questions/1621364/how-to-find-largest-triangle-in-convex-hull-aside-from-brute-force-search/1621913#1621913
            # Triangle area via Heron's formula. area_sq is the squared area, so no sqrt
            # is needed for the comparison.
            # NOTE(review): this iterates over all ordered triples, including repeated
            # indices (degenerate, zero area), i.e. O(h^3) for h hull vertices. Its
            # result (largest_area_idx) is currently not used; see below.

            largest_area_idx = [-1,-1,-1]
            largest_area_sq = 0
            for i in hull_verts:
                for j in hull_verts:
                    for k in hull_verts:
                        pt_i = cluster_proj[:,i]
                        pt_j = cluster_proj[:,j]
                        pt_k = cluster_proj[:,k]
                        a = np.linalg.norm(pt_j - pt_i)
                        b = np.linalg.norm(pt_k - pt_j)
                        c = np.linalg.norm(pt_i - pt_k)
                        p = (a+b+c)/2
                        area_sq = p * (p-a) * (p-b) * (p-c)
                        if area_sq > largest_area_sq:
                            largest_area_idx = [i,j,k]
                            largest_area_sq = area_sq

            # For now every hull vertex (possibly more than three) is kept as a generator,
            # taken from the simplex-normalized cluster_proj.
            # conic_bases[cluster_id] = cluster_proj[:, largest_area_idx]
            conic_bases[cluster_id] = cluster_proj[:, hull_verts]
            print(conic_bases[cluster_id])
            print(subspace_bases[cluster_id])
            
# Preview the conic bases on three columns of the filtered table. Column positions
# are clamped to the last column so narrow tables still plot.
c1, c2, c3 = (min(k, len(filtered_df.columns) - 1) for k in (3, 1, 4))
plot_scatter3(
    filtered_df,
    cm,
    "CONIC TEST",
    filtered_df.columns[c1],
    filtered_df.columns[c2],
    filtered_df.columns[c3],
    target_cluster=7,
    conic_bases=conic_bases,
)

In [None]:
# Projecting onto a linear subspace would use  Y = (M^T M)^(-1) M^T W.
# To project onto a conic polyhedron (non-negative combinations of the basis), use
# non-negative least squares instead:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.nnls.html
# (The conic polyhedra themselves still need to be computed; for now the per-cluster
# bases below are plain subspace bases.)
cm = state["cluster_manager"]
cs = cm.cluster_state
filtered_df = state["filtered_df"]

# Which basis to use for the full transform. TODO: make this an enum.
USE_REFERENCE_BASIS = True
USE_CALCULATED_BASIS = False

# The per-cluster bases are used later in this cell in BOTH modes (per-cluster solvers).
# Compute them here. Previously the reference-basis branch depended on a
# `subspace_bases` left over from an earlier cell, which a previous run of this cell
# may already have mutated in place.
subspace_bases = calc_subspace_bases(filtered_df.T.to_numpy(), cs.clusters, cs.cluster_dims)

if USE_REFERENCE_BASIS:
    # NOTE(review): `all_df` is defined outside this part of the notebook. Dividing by
    # all_df.sum(axis=0) makes each *column* of all_df sum to 1. After .T those
    # columns become the ROWS of full_basis, so the basis vectors (columns of
    # full_basis) are not themselves L1-normalized unless all_df is oriented
    # differently than its use in the solvers suggests. Confirm the intended axis.
    full_basis = (all_df.copy() / all_df.sum(axis=0)).T.to_numpy()
elif USE_CALCULATED_BASIS:
    for c_id in subspace_bases:
        for col in range(subspace_bases[c_id].shape[1]):
            # To preserve read counts in the transformed space, basis vectors must be
            # L1 normalized. Dividing by the signed sum also flips vectors that point in
            # an overall negative direction.
            l1_len = np.sum(subspace_bases[c_id][:, col])
            if l1_len != 0:
                subspace_bases[c_id][:, col] = subspace_bases[c_id][:, col] / l1_len

    print(subspace_bases)
    # Global column order = cluster iteration order, then column order within a cluster.
    full_basis = np.concatenate([subspace_bases[c] for c in subspace_bases], axis=1)

print(full_basis)

# Find nearly parallel / anti-parallel basis vectors. Keeping both members of such a
# pair makes the least-squares systems below badly conditioned (extreme numerical
# instability), so each later duplicate is collapsed onto the earliest equivalent vector.
SAMENESS_THRESH = 0.9
print("Basis Shape:", full_basis.shape)

# Compare directions by cosine similarity. The columns are L1-normalized (or raw), so
# plain dot products are not on a [-1, 1] scale and the threshold would be meaningless.
# All-zero columns get a zero unit vector and therefore never match anything.
basis_l2_norms = np.linalg.norm(full_basis, axis=0)
unit_basis = np.divide(
    full_basis,
    basis_l2_norms,
    out=np.zeros(full_basis.shape, dtype=float),
    where=basis_l2_norms != 0,
)
cosine_sim = unit_basis.T @ unit_basis

# sames: index of a duplicate column -> index of the column it collapses onto.
sames = {}
n_basis = full_basis.shape[1]
for i in range(n_basis):
    for j in range(i + 1, n_basis):
        # Only later vectors (j > i) are collapsed, for both parallel and anti-parallel
        # pairs. The previous test `j > i and dp > T or dp < -T` bound as
        # `(j > i and dp > T) or dp < -T`, which also matched anti-parallel pairs with j < i.
        if abs(cosine_sim[i, j]) > SAMENESS_THRESH:
            print("Bases: ", i, j, " are nearly identical")
            # Map j to the root of i (never to another duplicate) and keep the earliest
            # mapping, so 3+ mutually similar vectors all collapse onto the first one.
            sames.setdefault(j, sames.get(i, i))

# A weighted average of near-identical vectors might be better, but for now keep the
# first one. Vectors earlier in the list are trusted more because the clusters are
# ordered by lower dimension and more samples.
for same_key, root in sames.items():
    full_basis[:, same_key] = full_basis[:, root]

# Mirror the collapsed columns back into the per-cluster bases. Basis vectors are
# numbered globally in cluster iteration order (which matches full_basis's column order
# in calculated-basis mode), so a running counter maps (cluster, column) to that index.
# NOTE(review): in reference-basis mode the `sames` indices refer to reference columns,
# not to these subspace vectors -- confirm this step is meaningful there.
vec_index = 0
for cluster_id in subspace_bases:
    cluster_basis = subspace_bases[cluster_id]  # alias: writes update subspace_bases in place
    for col in range(cluster_basis.shape[1]):
        if vec_index in sames:
            cluster_basis[:, col] = full_basis[:, sames[vec_index]]
        vec_index += 1

# It is unclear how to handle outlier points, which are either unknown species vectors
# or rare co-occurrences of identified ones. A full basis of all identified species
# vectors accounts for co-occurrence. However, the per-subspace basis vectors are not
# orthogonal to each other, which can give underconstrained solutions and numerical
# instability. Applying Gram-Schmidt orthonormalization (or similar) in the order we
# want to assign weight may break ties cleanly; the subspaces are sorted by dimension
# and number of points, so the most common come first. Most of the instability
# probably comes from having multiple estimates of the same species vector (a 1D
# vector overlapping a 2D surface, or two intersecting 2D surfaces).

# Outliers are solved against the full basis with collapsed duplicates removed.
keep_cols = [c for c in range(full_basis.shape[1]) if c not in sames]
M = full_basis[:, keep_cols]
cols = [f"ax{c}" for c in keep_cols]

Ms = {OUTLIER_CLUSTER_ID: M}
solver_cols = {OUTLIER_CLUSTER_ID: cols}

# Each cluster is solved against its own subspace basis. Output columns are named by
# the global vector index, redirected to the surviving root for collapsed duplicates,
# so equivalent vectors from different clusters land in the same output column.
tot_vecs = 0
for cluster_id, M_i in subspace_bases.items():
    Ms[cluster_id] = M_i
    n_cluster_vecs = M_i.shape[1]
    solver_cols[cluster_id] = [
        "ax" + str(sames.get(v, v)) for v in range(tot_vecs, tot_vecs + n_cluster_vecs)
    ]
    tot_vecs += n_cluster_vecs

# Transform the full table `df`. Each row gets a cluster from the active clustering and
# is then decomposed with non-negative least squares against that cluster's basis
# (Ms / solver_cols). Outlier rows use the whole de-duplicated basis.
# TODO: the per-cluster bases are still subspace bases; switch to non-orthogonal conic
# polyhedron bases once those are computed.
to_transform = df
full_clusters = cm.apply_active_clustering(to_transform, dim_penalty=5, outlier_thresh=.1)
print(full_clusters.value_counts())

full_df = to_transform[cm.data_cols]
full_cm = ClusterManager(full_df, len(cm.data_cols))
full_cm.cluster_state = cm.cluster_state.copy()
full_cm.cluster_state.clusters = full_clusters.to_numpy()

out_pts = []
max_resid = 0
max_resid_pct = 0  # previously never initialized -> NameError on the first non-empty row
for i in range(to_transform.shape[0]):
    # full_df is exactly to_transform[cm.data_cols]; reuse it rather than re-slicing per row.
    pt = full_df.iloc[i].to_numpy()
    cluster_assignment = full_clusters.iloc[i]

    # The linear-subspace projection `np.matmul(solvers[cluster_assignment], pt)` was
    # removed: `solvers` is not defined anywhere in this part of the notebook (NameError)
    # and its result was never used.
    output_pt_nnls, l2_resid = scipy.optimize.nnls(Ms[cluster_assignment], pt)
    nnls_proj = np.matmul(Ms[cluster_assignment], output_pt_nnls)
    # Signed difference between the row total and the reconstructed total, i.e. the
    # counts the decomposition did not account for. Not a true L1 norm despite the name.
    l1_resid = np.sum(pt - nnls_proj)

    output_pt = pd.DataFrame(
        [output_pt_nnls],
        index=[to_transform.index[i]],
        columns=solver_cols[cluster_assignment],
    )
    max_resid = max(max_resid, abs(l1_resid))
    if pt.sum() > 0:
        max_resid_pct = max(max_resid_pct, abs(l1_resid) / pt.sum())

    out_pts.append(output_pt)

print("MAX RESIDUAL: ", max_resid)
print("MAX RESIDUAL FRACTION: ", max_resid_pct)

# Rows from different clusters have different columns; missing coordinates are 0.
output_df = pd.concat(out_pts).fillna(0)
print(output_df)

# Spot-check one known sample; skipped when it is absent (e.g. on other datasets).
DEBUG_SAMPLE_ID = "S.71801.0073.4.7.17"
if DEBUG_SAMPLE_ID in output_df.index:
    print(output_df.loc[DEBUG_SAMPLE_ID])

output_cm = ClusterManager(output_df, output_df.shape[1])
output_cm.cluster_state = cm.cluster_state.copy()
output_cm.cluster_state.clusters = full_clusters.to_numpy()
# Presumably outlier rows span every output column (they are solved against the whole
# de-duplicated basis). Keyed by OUTLIER_CLUSTER_ID (was a literal -1) to match Ms/solver_cols.
output_cm.cluster_state.cluster_dims[OUTLIER_CLUSTER_ID] = output_df.shape[1]

# NOTE(review): output_cm is built for output_df, but the plot below still passes
# full_cm (built over the original feature columns). Confirm which is intended.
c1 = min(0, len(output_df.columns)-1)
c2 = min(1, len(output_df.columns)-1)
c3 = min(2, len(output_df.columns)-1)
plot_scatter3(output_df, full_cm, "TEST", output_df.columns[c1],output_df.columns[c2],output_df.columns[c3])


In [None]:
print(output_df.head())

# Second view of the transformed output: axis 0 against axes 3 and 4, with positions
# clamped to the last column so narrow outputs still plot.
c1, c2, c3 = (min(k, len(output_df.columns) - 1) for k in (0, 3, 4))
plot_scatter3(output_df, full_cm, "TEST", *(output_df.columns[c] for c in (c1, c2, c3)))


In [None]:
# Toy check of scipy's ConvexHull: (2, 1) lies strictly inside the triangle formed by
# the other three points, so only the triangle's three edges are drawn.
pts = np.array([
    [1, 1],
    [5, 5],
    [3, 0],
    [2, 1],
])
hull = scipy.spatial.ConvexHull(pts)

xs, ys = pts[:, 0], pts[:, 1]
plt.plot(xs, ys, 'o')
for edge in hull.simplices:
    plt.plot(xs[edge], ys[edge], 'k-')
plt.show()



In [None]:
# Redraw the 3D scatter for the current genus using the three axis-selector widgets.
# The cluster manager may be absent (None) if no clustering has been run yet.
filtered_df = state["filtered_df"]
selected_axes = (axis1.value, axis2.value, axis3.value)
plot_scatter3(filtered_df, state.get("cluster_manager"), state["genus"], *selected_axes)

In [None]:
import io

# Load a saved data transform. "all" holds one JSON-serialized DataFrame per subspace.
# Their columns are renumbered consecutively (0, 1, 2, ...) across subspaces and then
# concatenated side by side into a single basis frame.
with open("./results/celeste_ecoli_many.json") as infile:
    data_transform = json.load(infile)

all_bases = []
counter = 0
for subspace_json in data_transform["all"]:
    # pandas >= 2.1 deprecates passing literal JSON text to read_json; wrapping it in a
    # file-like object works on every pandas version.
    conic_basis = pd.read_json(io.StringIO(subspace_json))
    conic_basis.columns = range(counter, counter + conic_basis.shape[1])
    counter += conic_basis.shape[1]
    all_bases.append(conic_basis)
conic_basis = pd.concat(all_bases, axis=1)


# Toy example, kept for reference (two basis vectors over three features):
# foo_basis = pd.DataFrame([[100/600, 200/700],[100/600, 0],[400/600, 500/700]])
# foo_df = pd.DataFrame([[100, 400],[100, 0],[400, 1000]])
# print(transformer.transform(foo_df, foo_basis))

# conic_basis = transformer.L1_normalize(conic_basis)
# print(conic_basis.loc[["G000183345", "G000026345","G000026325", "G000299455", "G000008865", "G001283625", "G000759795", "G001941055", "G000009065"],:])

# Decompose every sample of `df` onto the loaded basis.
transformed = transformer.transform(df, conic_basis)

# Colour each sample by whether the second "."-separated field of its id starts with "4".
colors = [
    "m" if sample_id.split(".")[1].startswith("4") else "b"
    for sample_id in transformed.index
]

plt.scatter(transformed[0], transformed[1], c=colors)
plt.show()

# Residual distribution over all samples, then a zoom on the large residuals.
transformed.hist("L1_resid")
plt.show()
large_resid = transformed["L1_resid"] > 50000
transformed[large_resid].hist("L1_resid")
plt.show()

# List the 100 worst-reconstructed samples in full.
transformed = transformed.sort_values("L1_resid", ascending=False)
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    print(transformed[["L1_resid", "WorstAxis"]].head(n=100))

# Raw counts of the 200 worst-reconstructed samples, for inspection with plot_scatter3.
foo_df = df.loc[transformed.index[:200]]

In [None]:
# Display whichever ClusterManager was most recently bound to `cm` by an earlier cell.
cm

In [None]:
# Extra junk for plotting info about the bacteroides samples
# bacteroides_sample_info = CSVTable("./dataset/biom/14360_20211220-082840.txt", delimiter="\t")
# bacteroides_sample_info = bacteroides_sample_info.load_dataframe()
# bacteroides_sample_info.index = bacteroides_sample_info["sample_name"]

# bacteroides_sample_contamination = CSVTable("./dataset/biom/bfragilis_divisions.txt", delimiter="\t")
# bacteroides_sample_contamination = bacteroides_sample_contamination.load_dataframe()

# sample_division = bacteroides_sample_contamination[["Division", "Sample"]]
# tube_ids = sample_division["Sample"].map(lambda x: x.split("_")[0])
# sample_division.index = tube_ids
# sample_division.index.name = "tube_id"

# real_filtered = state["filtered_df"]
# real_cm = state["cluster_manager"]
# try:
#     fake_clusters = []
#     requested_samples = []
#     for sample_id in state["filtered_df"].index:
#         tube_id = bacteroides_sample_info.loc[sample_id, "tube_id"]
#         if str(tube_id) in sample_division.index:
#             division = int(sample_division.loc[str(tube_id), "Division"]) - 1
#             requested_samples.append(sample_id)
#         else:
#             division = -1
#         fake_clusters.append(division)

#     print(fake_clusters)
#     # state["filtered_df"] = state["filtered_df"].loc[requested_samples]    
#     fake_cm = ClusterManager(state["filtered_df"], 1)
        
#     fake_cm.cluster_state = ClusterState(fake_cm, np.array(fake_clusters), {-1:1, 0:1, 1:1})
#     state["cluster_manager"] = fake_cm
#     plot_click(None)
# finally:
#     pass
#     # state["filtered_df"] = real_filtered
#     # state["cluster_manager"] = real_cm

In [None]:
from sklearn.decomposition import NMF
from detnmf.detnmf import run_detnmf

# Earlier experiment with sklearn's NMF, kept for reference:
# X = df.to_numpy()
# model = NMF(n_components=1)
# W = model.fit_transform(X)
# H = model.components_
# print(df.columns[np.argmax(H)])

X = df.to_numpy()
# Toy input for debugging:
# X = np.array([[0, 0], [100, 100], [200, 200], [300, 300], [400, 400], [300, 100], [600, 200], [900, 300], [1200, 400], [1500, 500]])

n_components = 2
# NOTE(review): the meaning of the trailing positional arguments (1000, 2500) depends on
# run_detnmf's signature, which is not visible here. Give them named constants once confirmed.
W, H = run_detnmf(X, n_components, 1000, 2500)

print("Data")
print(X.shape)
print(X.sum())
print(X)
print("Coefficients")
print(W)
print(W.shape)
print("Components")
print(H)
print(H.shape)
print(H.sum(axis=1))

# Column of df with the largest weight in each component (rows of H are components).
for component in range(n_components):
    big_val = np.argmax(H[component, :])
    print(df.columns[big_val])

# Cosine similarity between every pair of components; values near 1 mean the
# factorization found (nearly) the same direction twice. Previously only rows 0 and 1
# were compared, which failed for n_components=1 and ignored any further components.
# All-zero rows get a zero unit vector instead of dividing by zero.
h_norms = np.linalg.norm(H, axis=1, keepdims=True)
qq = np.divide(H, h_norms, out=np.zeros(H.shape, dtype=float), where=h_norms != 0)
for first in range(qq.shape[0]):
    for second in range(first + 1, qq.shape[0]):
        print(np.dot(qq[first, :], qq[second, :]))

# Save in the {"all": [<DataFrame as JSON>, ...]} layout used by the saved data
# transforms: rows are df's columns, one column per component.
basis = H.T
basis_df = pd.DataFrame(basis, index=df.columns)
to_save = {"all": [basis_df.to_json()]}

with open("./results/detnmf.json", 'w') as outfile:
    json.dump(to_save, outfile)
