In [14]:
from table_info import BiomTable, CSVTable
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display
import test_linear_clustering
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
import pandas as pd

from test_linear_clustering import recursive_subspacer, iterative_clustering, calc_projected, _fit_a_linear_subspace, _draw_subspace_line, _get_color
from linear_subspace_clustering import linear_subspace_clustering
from cluster_manager import ClusterManager, ClusterState, RecursiveClusterer
import numpy as np
import matplotlib
import json

In [2]:
# Input locations, relative to the notebook's working directory.
WOLTKA_METADATA_PATH = './woltka_metadata.tsv'
BIOM_TABLE = './dataset/biom/combined-none.biom'
# JSON file the Save button writes per-genus subspace bases to (and that is read back on start-up).
DATA_TRANSFORM_PATH = './results/jupyter_interactive.json'

# Minimum summed count over a genus's genomes for a row (sample) to be kept.
MIN_GENUS_COUNT = 500

In [3]:
woltka_meta_table = CSVTable(WOLTKA_METADATA_PATH, delimiter="\t")
woltka_meta_df = woltka_meta_table.load_dataframe()

bt = BiomTable(BIOM_TABLE)
df = bt.load_dataframe()

# Previously saved per-genus subspace bases. A missing or unreadable file means
# starting from an empty transform (the Save button writes this file).
try:
    with open(DATA_TRANSFORM_PATH) as infile:
        data_transform = json.load(infile)
except (OSError, ValueError) as exc:  # ValueError covers json.JSONDecodeError
    print(f"Couldn't open data transform ({exc}), starting new data transform")
    data_transform = {}

In [4]:
def list_genera(min_count=None, min_samples=10):
    """Return, sorted, the genera worth offering for clustering.

    A genus qualifies when:
      * more than one of its reference genomes appears as a column of ``df``, and
      * at least ``min_samples`` rows of ``df`` have a summed count over those
        genomes of at least ``min_count``.

    Parameters
    ----------
    min_count : int, optional
        Per-row count threshold; defaults to ``MIN_GENUS_COUNT`` read at call time.
    min_samples : int
        Minimum number of qualifying rows (default 10).
    """
    if min_count is None:
        min_count = MIN_GENUS_COUNT

    # Metadata restricted to genomes that are actually present in the table;
    # computed once instead of once per genus.
    woltka_included = woltka_meta_df[woltka_meta_df['#genome'].isin(df.columns)]
    vcs = woltka_included['genus'].value_counts()
    genera = sorted(vcs[vcs > 1].index.astype(str).tolist())

    filt_genera = []
    for genus in genera:
        genomes = woltka_included.loc[woltka_included['genus'] == genus, '#genome']
        per_row_total = df[genomes].sum(axis=1)
        if (per_row_total >= min_count).sum() >= min_samples:
            filt_genera.append(genus)
    return filt_genera

def list_woltka_refs(genus):
    """Reference genomes of ``genus`` that appear as columns of ``df``.

    Returns a DataFrame with columns ``['total', '#genome', 'species']``, where
    ``total`` is the column sum of that genome in ``df``. Rows are ordered by
    ``total`` then ``#genome``, both descending.
    """
    present = woltka_meta_df['#genome'].isin(df.columns)
    genus_refs = woltka_meta_df.loc[present]
    genus_refs = genus_refs.loc[genus_refs['genus'] == genus]

    totals = df[genus_refs['#genome']].sum().rename('total')

    ranked = (
        genus_refs.join(totals, on='#genome')
        .reset_index()
        .sort_values(["total", "#genome"], ascending=False)
    )
    return ranked[['total', '#genome', 'species']]

In [12]:
def calc_subspace_bases(filtered_df, cluster_assignments, cluster_dims):
    """Fit one linear subspace per cluster label, skipping label -1.

    filtered_df : DataFrame with one column per genome (rows presumably samples).
    cluster_assignments : array of per-row cluster labels.
    cluster_dims : mapping label -> subspace dimension to fit.

    Returns a dict label -> result of ``_fit_a_linear_subspace`` on the
    (genomes x member-rows) matrix, inserted in ascending label order.
    """
    genome_by_row = filtered_df.T.to_numpy()
    bases = {}
    for label in np.unique(cluster_assignments):
        if label == -1:
            continue
        members = np.flatnonzero(cluster_assignments == label)
        bases[label] = _fit_a_linear_subspace(genome_by_row[:, members], cluster_dims[label])
    return bases

def _draw_subspace_plane(ax, u, v, x, y, maxx, maxy, maxz, label, min_label, max_label):
    """Draw span{u, v} on the 3-D axes ``ax``.

    ``u`` and ``v`` are a 2-D subspace's basis vectors restricted to the three
    plotted genome columns. The full basis vectors are orthogonal, but these
    restrictions need not be: they may be (nearly) zero or (nearly) parallel
    (e.g. u, v = (0, 0, 1, -1), (0, 0, 1, 1) restricted to the first three
    coordinates). In those cases a line, or nothing, is drawn instead.
    ``x`` and ``y`` are the grid coordinates over which the plane is sampled.
    """
    npu = np.array(u, dtype=float)
    npv = np.array(v, dtype=float)
    lenu = np.linalg.norm(npu)
    lenv = np.linalg.norm(npv)
    if lenu < 0.01 and lenv < 0.01:
        # Both restrictions are essentially 0: nothing visible to draw.
        return
    if lenu < 0.01:
        _draw_subspace_line(ax, v, maxx, maxy, maxz, label, min_label, max_label)
        return
    if lenv < 0.01:
        _draw_subspace_line(ax, u, maxx, maxy, maxz, label, min_label, max_label)
        return

    npu = npu / lenu
    npv = npv / lenv
    dotproduct = np.dot(npu, npv)
    if abs(dotproduct) > 0.9:
        # Nearly parallel: the plane degenerates to (roughly) their common line.
        if dotproduct < 0:
            npu = -npu
        _draw_subspace_line(ax, (npu + npv) / 2, maxx, maxy, maxz, label, min_label, max_label)
        return

    # Plane through the origin: n0*x + n1*y + n2*z = 0.
    normal = np.cross(npu, npv)
    if abs(normal[2]) > 1e-6:
        # Solve for z: z = -n0/n2 * x - n1/n2 * y.
        X, Y = np.meshgrid(x, y)
        Z = -normal[0] / normal[2] * X - normal[1] / normal[2] * Y
    else:
        # Vertical plane (contains the z direction): not a graph over (x, y),
        # so solve for y (or x) and sweep z over [0, maxz] instead of dividing by ~0.
        zs = np.linspace(0, maxz, len(x))
        if abs(normal[1]) >= abs(normal[0]):
            X, Z = np.meshgrid(x, zs)
            Y = -normal[0] / normal[1] * X
        else:
            Y, Z = np.meshgrid(y, zs)
            X = -normal[1] / normal[0] * Y
    rgba = _get_color(label, min_label, max_label)
    ax.plot_surface(X, Y, Z, color=(rgba[0], rgba[1], rgba[2], 0.25))


def plot_scatter3(filtered_df, cluster_manager, title, c1, c2, c3, target_cluster=None):
    """3-D scatter of three genome columns, overlaid with fitted linear subspaces.

    Subspace bases come from the live clustering when ``cluster_manager`` has
    clusters; otherwise from the transform saved for the current genus
    (``data_transform[state['genus']]``), if there is one.

    Parameters
    ----------
    filtered_df : pandas.DataFrame
        One column per genome; rows are the points to scatter.
    cluster_manager : ClusterManager
        Supplies ``cluster_state`` (labels, per-cluster dims and counts).
    title : str
        Figure title.
    c1, c2, c3 : str
        Columns plotted on the x, y and z axes.
    target_cluster : int, optional
        If given, only rows with this cluster label are scattered and the axis
        limits are fitted to them.
    """
    import io

    cs = cluster_manager.cluster_state
    cluster_counts = cs.get_cluster_counts()
    axes_cols = [c1, c2, c3]

    all_proj = None
    if cs.num_clusters() > 0:
        fitted = calc_subspace_bases(filtered_df, cs.clusters, cs.cluster_dims)
        subspace_bases = {
            label: pd.DataFrame(basis, index=filtered_df.columns)
            for label, basis in fitted.items()
        }
        all_proj = calc_projected(filtered_df, cs.clusters, subspace_bases)
    elif state.get('genus') in data_transform:
        saved = data_transform[state['genus']]
        # Wrap in StringIO: literal JSON strings to read_json are deprecated (pandas >= 2.1).
        subspace_bases = {i: pd.read_json(io.StringIO(js)) for i, js in enumerate(saved)}
    else:
        subspace_bases = None
    if not subspace_bases:
        # Nothing fitted/saved (e.g. only label -1, or an empty saved list);
        # treating {} as "no bases" avoids max({}) raising below.
        subspace_bases = None

    # Colour range shared by surfaces, target-cluster points and legend patches.
    # Computed unconditionally so target_cluster also works without bases.
    min_label = -1
    if -1 not in cluster_counts or cluster_counts[-1] == 0:
        min_label = 0
    if subspace_bases is not None:
        max_label = max(subspace_bases)
    else:
        labels = np.asarray(cs.clusters)
        max_label = int(labels.max()) if labels.size else min_label

    if target_cluster is not None:
        pidx = np.where(cs.clusters == target_cluster)[0]
        shown_df = filtered_df.iloc[pidx]
    else:
        pidx = None
        shown_df = filtered_df
    maxx, maxy, maxz = (max(shown_df[c]) for c in axes_cols)
    col_idx = [filtered_df.columns.get_loc(c) for c in axes_cols]

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    if subspace_bases is not None:
        x = np.linspace(0, maxx, 10)
        y = np.linspace(0, maxy, 10)
        for label, basis in subspace_bases.items():
            dim = basis.shape[1]
            if dim == 1:
                u = basis.iloc[:, 0].loc[axes_cols]
                _draw_subspace_line(ax, u, maxx, maxy, maxz, label, min_label, max_label)
            elif dim == 2:
                u = basis.iloc[:, 0].loc[axes_cols]
                v = basis.iloc[:, 1].loc[axes_cols]
                _draw_subspace_plane(ax, u, v, x, y, maxx, maxy, maxz, label, min_label, max_label)
            # Subspaces with dim > 2 are not drawn in three axes.

    if target_cluster is None:
        ax.scatter(filtered_df[c1], filtered_df[c2], filtered_df[c3], c=cs.clusters, cmap="Set1")
        if all_proj is not None:
            ax.scatter(all_proj[col_idx[0]], all_proj[col_idx[1]], all_proj[col_idx[2]],
                       c=cs.clusters, marker='x', cmap='Set1')
    else:
        rgba = _get_color(target_cluster, min_label, max_label)
        ax.scatter(shown_df[c1], shown_df[c2], shown_df[c3], c=[rgba])
        if all_proj is not None:
            projected = all_proj[:, pidx]
            ax.scatter(projected[col_idx[0]], projected[col_idx[1]], projected[col_idx[2]],
                       c=[rgba], marker='x')

    ax.set_title(title)
    ax.set_xlabel(c1)
    ax.set_ylabel(c2)
    ax.set_zlabel(c3)
    # Axis ranges extend past the data maximum (x1.5 + 100).
    ax.set_xlim([0, maxx * 1.5 + 100])
    ax.set_ylim([0, maxy * 1.5 + 100])
    ax.set_zlim([0, maxz * 1.5 + 100])

    if subspace_bases is not None:
        patches = [
            matplotlib.patches.Patch(
                color=_get_color(label, min_label, max_label),
                label=f"{label}: dim={dim} #fit={cluster_counts[label]}",
            )
            for label, dim in cs.cluster_dims.items()
        ]
        ax.legend(handles=patches)
    plt.show()

In [11]:
genus_widget = widgets.Dropdown(options=list_genera())

# One genome dropdown per plotted axis (x, y, z), each with its own Layout.
axis1, axis2, axis3 = (widgets.Dropdown(layout=widgets.Layout(width='50%')) for _ in range(3))

plot = widgets.Button(description='Plot')
calculate = widgets.Button(description='Calculate')
save = widgets.Button(description='Save')
output = widgets.Output()

# Per-genus session state (genus, filtered table, cluster manager, ...) shared by the callbacks.
state = {}

def update_genus(*args):
    """Reset the session for the genus currently selected in ``genus_widget``.

    Repopulates the three axis dropdowns with the genus's genomes (in the order
    returned by ``list_woltka_refs``: total count, descending) and keeps only rows of
    ``df`` whose summed count over those genomes is >= ``MIN_GENUS_COUNT``. It then
    replaces ``state`` with a fresh ClusterManager for that table and enables
    Calculate only when there is enough data to cluster.

    ``*args`` absorbs the change event passed by ``observe``.
    """
    genus = genus_widget.value
    refs_df = list_woltka_refs(genus)
    genomes = refs_df['#genome'].tolist()
    # "species(genome)" labels; the f-string also copes with a missing (NaN) species.
    choices = [(f"{sp}({gid})", gid) for sp, gid in zip(refs_df['species'].tolist(), genomes)]

    axes = (axis1, axis2, axis3)
    for axis in axes:
        axis.options = choices
    if genomes:
        # Default to the first three genomes, reusing the last one when there are fewer.
        for i, axis in enumerate(axes):
            axis.value = genomes[min(i, len(genomes) - 1)]

    filtered_df = df[genomes]
    filtered_df = filtered_df[filtered_df.sum(axis=1) >= MIN_GENUS_COUNT]

    state.clear()
    state["genus"] = genus
    state["filtered_df"] = filtered_df
    state["cluster_manager"] = ClusterManager(filtered_df, filtered_df.shape[1])
    state["breakdown"] = data_transform.get(genus)

    with output:
        output.clear_output()
        if len(genomes) < 2:
            print(genus, "Not enough reference genomes to do clustering\n")
            calculate.disabled = True
        elif len(filtered_df) < 10:
            # Same bound as list_genera(), which offers a genus once it has >= 10 such rows.
            print(genus, "Not enough reads to do clustering\n")
            calculate.disabled = True
        else:
            calculate.disabled = False

# Re-run update_genus on every genus change, and once now for the initial selection.
genus_widget.observe(update_genus, names='value')
update_genus()

# Render the controls top to bottom, with the message/plot area last.
display(genus_widget, axis1, axis2, axis3, plot, calculate, save, output)

def plot_click(button):
    """Plot the current genus in 3-D using the genomes chosen in the axis dropdowns.

    Rendered inside ``output``, replacing its previous content. Per the ipywidgets
    docs, figures, prints and tracebacks produced in a button callback are not
    shown in the notebook unless captured by an Output widget.
    """
    with output:
        output.clear_output(wait=True)
        plot_scatter3(
            state["filtered_df"],
            state.get("cluster_manager"),
            state["genus"],
            axis1.value,
            axis2.value,
            axis3.value,
        )

def calculate_click(button):
    """Re-run recursive clustering on the current genus's filtered table.

    A fresh ClusterManager replaces any previous (possibly hand-edited)
    clustering. The RecursiveClusterer used is kept in ``state["clusterer"]``.
    Runs inside ``output`` so progress messages and exceptions are visible
    instead of being lost in the widget callback.
    """
    with output:
        print("Clustering", state["genus"], "...")
        to_cluster = state["filtered_df"]
        state["cluster_manager"] = ClusterManager(to_cluster, to_cluster.shape[1])
        rc = RecursiveClusterer()
        rc.run(state["cluster_manager"])
        state["clusterer"] = rc
        print("Done:", state["cluster_manager"].cluster_state.num_clusters(), "clusters")
    
def save_click(button):
    """Persist the current genus's subspace bases to ``DATA_TRANSFORM_PATH``.

    ``data_transform[genus]`` becomes a list, in ascending cluster-label order,
    of each basis as a DataFrame JSON string indexed by genome. The whole
    ``data_transform`` dict is then rewritten. Nothing is saved when there are no
    fitted clusters, so an earlier saved transform is not replaced by ``[]``.
    """
    import os

    with output:
        genus = state["genus"]
        filtered_df = state["filtered_df"]
        cs = state["cluster_manager"].cluster_state

        subspace_bases = calc_subspace_bases(filtered_df, cs.clusters, cs.cluster_dims)
        if not subspace_bases:
            print("Nothing to save for", genus, "- no fitted clusters")
            return

        data_transform[genus] = [
            pd.DataFrame(basis, index=filtered_df.columns).to_json()
            for basis in subspace_bases.values()
        ]

        print("Saving to: " + DATA_TRANSFORM_PATH)
        directory = os.path.dirname(DATA_TRANSFORM_PATH)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Write to a temporary file and swap it in, so a failure mid-dump
        # cannot leave a truncated transform file behind.
        tmp_path = DATA_TRANSFORM_PATH + ".tmp"
        with open(tmp_path, 'w') as outfile:
            json.dump(data_transform, outfile)
        os.replace(tmp_path, DATA_TRANSFORM_PATH)

# Attach each button to its callback.
for _button, _callback in ((plot, plot_click), (calculate, calculate_click), (save, save_click)):
    _button.on_click(_callback)
del _button, _callback

()
Built Choices
Filtered DF
Created State


Dropdown(options=('Acetivibrio', 'Acetobacter', 'Acidaminococcus', 'Acinetobacter', 'Actinomyces', 'Akkermansi…

Dropdown(layout=Layout(width='50%'), options=(('Acetivibrio ethanolgignens(G001461035)', 'G001461035'), ('Acet…

Dropdown(index=1, layout=Layout(width='50%'), options=(('Acetivibrio ethanolgignens(G001461035)', 'G001461035'…

Dropdown(index=1, layout=Layout(width='50%'), options=(('Acetivibrio ethanolgignens(G001461035)', 'G001461035'…

Button(description='Plot', style=ButtonStyle())

Button(description='Calculate', style=ButtonStyle())

Button(description='Save', style=ButtonStyle())

Output()

In [15]:
def cluster_run_click(button):
    """Apply the manual cluster edit typed into ``cluster_command``.

    Supported commands (extra trailing tokens are ignored):
      ``split <cluster_id> <pieces>``  -> ClusterManager.get_split_cluster(cluster_id, pieces)
      ``delete <cluster_id>`` / ``del <cluster_id>`` -> ClusterManager.get_delete_cluster(cluster_id)
    The returned edit is applied to ``state["cluster_manager"]``, which is then
    finalized. Empty, unknown or malformed commands print a message instead
    of raising inside the widget callback.
    """
    # Number of integer arguments each command needs.
    required_args = {'split': 2, 'delete': 1, 'del': 1}

    tokens = cluster_command.value.split()
    if not tokens or tokens[0] not in required_args:
        print("Unsupported Command")
        return
    verb = tokens[0]
    needed = required_args[verb]
    try:
        numbers = [int(tok) for tok in tokens[1:1 + needed]]
    except ValueError:
        numbers = []
    if len(numbers) < needed:
        print("Usage: split <cluster_id> <pieces> | delete <cluster_id>")
        return

    manager = state["cluster_manager"]
    if verb == 'split':
        manager.get_split_cluster(numbers[0], numbers[1]).apply()
    else:
        manager.get_delete_cluster(numbers[0]).apply()
    manager.finalize()

# Free-text command box ("split <id> <pieces>", "delete <id>") plus its Run button.
cluster_command = widgets.Text(placeholder='split 1 2', layout=widgets.Layout(width='80%'))
cluster_run = widgets.Button(description='Run Command')
cluster_run.on_click(cluster_run_click)

display(cluster_command, cluster_run)


Text(value='', layout=Layout(width='80%'), placeholder='split 1 2')

Button(description='Run Command', style=ButtonStyle())

In [8]:
# Subspace dimension per cluster label for the current genus's clustering.
current_cluster_state = state["cluster_manager"].cluster_state
current_cluster_state.cluster_dims

{-1: 2}