In [1]:
%load_ext autoreload
# reload edited modules (such as the local src.* helpers imported in the next cell) before each cell runs
%autoreload 2

In [2]:
from collections import defaultdict
import os

import numpy as np
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt

from sklearn.preprocessing import LabelEncoder

from icecream import ic

# local files
from src.embeddings.embeddings import get_embeddings
from src.util.data_handling.data_loader import save_as_pickle, load_dataset, make_dir
from src.util.data_handling.string_generator import ALPHABETS
from src.data.edit_distance import edit_distance

INFO: Using numpy backend


# Load Data

In [15]:
# initial values: which dataset / trained model run to visualise
data_name = 'ibd'
model_class = 'cnn'
distance = 'hyperbolic'
embedding_size = 2

# every model-dependent artefact shares the same "<model>_<distance>_<size>" file stem
run_tag = f'{model_class}_{distance}_{embedding_size}'

# define paths
ihmp_data_path = f'../data/interim/ihmp/{data_name}_data.csv'
metadata_path = f'../data/interim/ihmp/{data_name}_metadata.csv'
otu_embedding_path = f'../data/processed/otu_embeddings/{data_name}/{run_tag}_otu_embeddings.csv'
mixture_embedding_path = f'../data/processed/mixture_embeddings/{data_name}/{run_tag}_mixture_embeddings.csv'

auxillary_data_path = '../data/interim/greengenes/auxillary_data.pickle'
edit_distance_path = f'../data/interim/ihmp/edit_distance/{data_name}/{run_tag}_edit_distance.csv'

otu_embeddings_plot_path = f'../reports/embeddings/{data_name}/{run_tag}_otu_embeddings_on_poincare_disk.html'

# load the four input tables, each indexed by its identifier column
data, metadata, otu_embeddings, mixture_embeddings = (
    pd.read_csv(csv_path, index_col=id_column)
    for csv_path, id_column in [
        (ihmp_data_path, 'Sample'),
        (metadata_path, 'Sample'),
        (otu_embedding_path, 'OTU'),
        (mixture_embedding_path, 'Sample'),
    ]
)

# the embedding CSV headers are read as strings ("0", "1", ...); convert the column labels to int
otu_embeddings.columns = otu_embeddings.columns.map(int)
mixture_embeddings.columns = mixture_embeddings.columns.map(int)

In [16]:
# abundance table from ihmp_data_path: one row per Sample, one column per OTU id
data

Unnamed: 0_level_0,1000269,1008348,1009894,1012376,1017181,1017413,1019823,1019878,102222,1023075,...,964363,968675,968954,971907,975306,976470,979707,988375,988932,999046
Sample,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
CSM5FZ3N,0.000003,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.000022,0.000041,0.000000,0.000451,0.000691,0.000000,0.000000,0.000011,0.000000
CSM5FZ3X,0.000003,0.000006,0.0,0.0,0.000012,0.000000,0.0,0.000009,0.000000,0.0,...,0.0,0.000062,0.000042,0.000021,0.000665,0.000009,0.000000,0.000128,0.000006,0.000003
CSM5FZ3Z,0.000000,0.000000,0.0,0.0,0.000000,0.000012,0.0,0.000012,0.000000,0.0,...,0.0,0.000061,0.000004,0.000000,0.000210,0.000000,0.000000,0.000000,0.000020,0.000000
CSM5FZ44,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
CSM5FZ46,0.000000,0.000000,0.0,0.0,0.000000,0.000005,0.0,0.000000,0.000000,0.0,...,0.0,0.000020,0.000010,0.000005,0.000243,0.000000,0.000000,0.000000,0.000010,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
MSM5LLIO,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000017,0.000005,0.0,...,0.0,0.000031,0.000014,0.000010,0.001298,0.000010,0.000000,0.001417,0.000005,0.000000
MSM5LLIQ,0.000162,0.000000,0.0,0.0,0.000019,0.000011,0.0,0.000623,0.000000,0.0,...,0.0,0.002586,0.000686,0.001035,0.000233,0.000025,0.000000,0.000000,0.000206,0.000008
MSM5LLIS,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
MSM5ZOJY,0.000031,0.000000,0.0,0.0,0.000017,0.000021,0.0,0.001939,0.000000,0.0,...,0.0,0.003953,0.000159,0.002219,0.005125,0.000291,0.000000,0.000000,0.000025,0.000002


In [17]:
# per-sample metadata (index: Sample); each column is later used as a colouring feature
metadata

Unnamed: 0_level_0,Participant,Sample Collection Date,Visit Number,Hospital,Age,Diagnosis,HBI,Sex,Race,Fecalcal,SCCAI
Sample,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
CSM5FZ3N,C3001,2014-03-14,4,Cedars-Sinai,43.0,CD,4.0,Female,White,193.89,0.0
CSM5FZ3X,C3002,2014-05-13,5,Cedars-Sinai,76.0,CD,7.0,Female,White,71.48,0.0
CSM5FZ3Z,C3002,2014-05-28,6,Cedars-Sinai,76.0,CD,8.0,Female,White,156.73,0.0
CSM5FZ44,C3002,2014-06-24,8,Cedars-Sinai,76.0,CD,7.0,Female,White,54.33,0.0
CSM5FZ46,C3002,2014-07-08,9,Cedars-Sinai,76.0,CD,6.0,Female,White,54.74,0.0
...,...,...,...,...,...,...,...,...,...,...,...
MSM5LLIO,M2021,2014-06-17,11,MGH,26.0,CD,2.0,Male,White,89.32,0.0
MSM5LLIQ,M2026,2014-04-16,4,MGH,21.0,UC,0.0,Female,White,224.07,7.0
MSM5LLIS,M2027,2014-05-02,4,MGH,41.0,CD,0.0,Male,Other,194.74,0.0
MSM5ZOJY,M2014,2014-04-22,9,MGH,30.0,CD,1.0,Male,White,219.23,0.0


In [18]:
# OTU embedding coordinates (index: OTU id, integer columns: one per embedding dimension)
otu_embeddings

Unnamed: 0_level_0,0,1
OTU,Unnamed: 1_level_1,Unnamed: 2_level_1
1000269,-0.971315,-0.233556
1008348,-0.188370,0.981080
1009894,-0.938192,0.343216
1012376,-0.713432,0.699297
1017181,-0.995324,-0.085623
...,...,...
976470,-0.775837,-0.629347
979707,-0.535303,0.843476
988375,-0.504512,0.862246
988932,-0.979092,-0.198445


In [19]:
# mixture (per-sample) embedding coordinates (index: Sample, integer columns: one per dimension)
mixture_embeddings

Unnamed: 0_level_0,0,1
Sample,Unnamed: 1_level_1,Unnamed: 2_level_1
CSM5FZ3N,0.000000,0.000000
CSM5FZ3X,0.000000,0.000000
CSM5FZ3Z,-0.019282,0.780946
CSM5FZ44,0.000000,0.000000
CSM5FZ46,0.000000,0.000000
...,...,...
MSM5LLIO,-0.187492,0.394580
MSM5LLIQ,-0.359186,0.083998
MSM5LLIS,-0.237867,0.063742
MSM5ZOJY,-0.379982,-0.431924


In [20]:
# remove nans: samples without a mixture embedding are placed at (0, 0)
mixture_embeddings = mixture_embeddings.fillna(0)

In [21]:
# every metadata column becomes a colouring feature: name -> list of per-sample values
feature_names = list(metadata.columns)
features = {}
for name in feature_names:
    features[name] = metadata[name].to_list()

# Compute Edit Distance

In [22]:
def get_edit_distance_matrix(otu_ids, id_to_str_seq, n_thread=16):
    """Compute pairwise edit distances between the sequences of the given OTUs.

    Args:
        otu_ids: OTU identifiers, in the row/column order wanted for the result.
        id_to_str_seq: mapping looked up with ``str(otu_id)`` that yields the sequence
            passed to ``edit_distance``.
        n_thread: number of threads forwarded to ``edit_distance`` (default 16).

    Returns:
        pandas.DataFrame wrapping the ``edit_distance`` result, with both index and
        columns labelled by ``otu_ids``.

    Raises:
        KeyError: if any OTU id has no entry in ``id_to_str_seq``. This is checked
            before the (expensive) distance computation starts.
    """
    # report every missing id at once instead of failing on the first one
    missing = [_id for _id in otu_ids if str(_id) not in id_to_str_seq]
    if missing:
        raise KeyError('no sequence for {} OTU id(s), e.g. {}'.format(len(missing), missing[:5]))

    str_seqs = [id_to_str_seq[str(_id)] for _id in otu_ids]
    edit_distance_matrix = edit_distance(str_seqs, n_thread=n_thread)
    return pd.DataFrame(edit_distance_matrix, index=otu_ids, columns=otu_ids)

In [23]:
# first element of the auxiliary dataset: sequences looked up by str(OTU id).
# NOTE(review): the file is a .pickle; presumably load_dataset unpickles it, so only load trusted files.
id_to_str_seq, _, _, _ = load_dataset(auxillary_data_path)
otu_ids = data.columns.to_list()

# Pairwise edit distances are expensive, so they are cached as CSV and reused when present.
if not os.path.exists(edit_distance_path):
    edit_distance_df = get_edit_distance_matrix(otu_ids, id_to_str_seq)
    edit_distance_df.to_csv(make_dir(edit_distance_path))
else:
    edit_distance_df = pd.read_csv(edit_distance_path, index_col=0)
    # read_csv parses numeric OTU ids in the index as int but keeps the header labels as str.
    # Use str for both, matching a freshly computed matrix.
    edit_distance_df.index = edit_distance_df.index.astype(str)
    edit_distance_df.columns = edit_distance_df.columns.astype(str)
    # Reorder to the current OTU order of `data` so rows/columns line up with the embeddings.
    # A KeyError here means the cache was built for a different OTU set; delete it to recompute.
    edit_distance_df = edit_distance_df.loc[otu_ids, otu_ids]

edit_distance_df

Unnamed: 0,1000269,1008348,1009894,1012376,1017181,1017413,1019823,1019878,102222,1023075,...,964363,968675,968954,971907,975306,976470,979707,988375,988932,999046
1000269,0,349,322,393,315,325,305,304,347,283,...,286,341,80,333,292,356,365,377,310,332
1008348,349,0,348,236,379,374,378,390,511,382,...,375,402,355,382,371,457,255,266,368,385
1009894,322,348,0,390,332,243,338,345,448,332,...,355,258,322,248,346,407,366,368,332,246
1012376,393,236,390,0,411,407,418,432,538,425,...,396,437,390,418,397,496,309,323,408,421
1017181,315,379,332,411,0,352,313,335,451,331,...,321,348,320,338,327,397,396,390,327,345
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
976470,356,457,407,496,397,421,342,400,373,386,...,335,431,347,413,269,0,458,474,348,421
979707,365,255,366,309,396,387,384,404,511,380,...,361,395,368,390,385,458,0,273,382,382
988375,377,266,368,323,390,382,390,423,514,395,...,388,405,376,385,374,474,273,0,406,392
988932,310,368,332,408,327,349,290,316,447,300,...,275,354,299,357,276,348,382,406,0,349


# Plot OTU Embeddings on Poincare Disk

In [42]:
def plot_otu_embeddings(x, y, otu_ids, edit_distance_matrix, title='OTU Embeddings on Poincare Disk', subset=200):
    """Plot OTU embeddings on the Poincare disk, coloured by edit distance to a chosen OTU.

    A dropdown selects the reference OTU. Every point is then recoloured by its edit
    distance to that OTU, and a small ring marks the reference point.

    Args:
        x, y: per-OTU embedding coordinates (sequences of equal length).
        otu_ids: OTU identifiers in the same order as ``x``/``y``. They label the
            dropdown options and the hover text.
        edit_distance_matrix: square 2-D array-like of pairwise edit distances, with
            rows/columns in the same order as ``x``/``y``.
        title: figure title.
        subset: if truthy, only the window ``[subset, 2 * subset)`` of OTUs is plotted,
            which keeps the dropdown and the plot readable. Pass ``0``/``None`` to plot
            every OTU.

    Returns:
        plotly.graph_objects.Figure

    Raises:
        ValueError: if no OTUs remain after subsetting, or if the inputs disagree in size.
    """
    # positional containers so that index 0 always means "first plotted OTU"
    x = list(x)
    y = list(y)
    otu_ids = list(otu_ids)
    edit_distance_matrix = np.asarray(edit_distance_matrix)

    # create a subset of the data for a clearer, simpler visualization
    if subset:
        window = slice(subset, 2 * subset)
        x, y, otu_ids = x[window], y[window], otu_ids[window]
        edit_distance_matrix = edit_distance_matrix[window, window]

    n_otus = len(x)
    if n_otus == 0:
        raise ValueError('no OTUs left to plot (subset={})'.format(subset))
    if len(y) != n_otus or len(otu_ids) != n_otus or edit_distance_matrix.shape != (n_otus, n_otus):
        raise ValueError(
            'x, y, otu_ids and edit_distance_matrix must describe the same OTUs; got lengths '
            '{}, {}, {} and matrix shape {}'.format(n_otus, len(y), len(otu_ids), edit_distance_matrix.shape)
        )

    # the reference OTU whose distances colour the plot initially
    first_otu_idx = 0

    # only the reference OTU's row of distances is needed for the initial colouring
    df = pd.DataFrame({
        'x': x,
        'y': y,
        'OTU': [str(otu_id) for otu_id in otu_ids],
        'Edit Distance': edit_distance_matrix[first_otu_idx],
    })

    # plot embeddings; px uses the colour column name as the colour bar title
    fig = px.scatter(df, x='x', y='y', color='Edit Distance', hover_name='OTU')

    # Poincare disk boundary, drawn slightly outside the unit circle
    circle_margin = 0.02
    poincare_circumference = dict(
        type='circle', xref='x', yref='y',
        x0=-1 - circle_margin, y0=-1 - circle_margin, x1=1 + circle_margin, y1=1 + circle_margin,
        line=dict(color='Black', width=4),
    )
    radius = 0.05

    def highlight_circle(i):
        # small ring around OTU i marking it as the current reference
        return dict(
            type='circle', xref='x', yref='y',
            x0=x[i] - radius, y0=y[i] - radius, x1=x[i] + radius, y1=y[i] + radius,
            line=dict(color='Black'),
        )

    fig.add_shape(poincare_circumference)
    fig.add_shape(highlight_circle(first_otu_idx))

    # One dropdown option per OTU: recolour by that OTU's distance row and move the ring.
    # method='update' takes [trace_update, layout_update]; a third element would be read
    # as trace indices. The colour bar keeps its 'Edit Distance' title from px.scatter.
    menu_options = [
        dict(
            label=str(otu_ids[i]),
            method='update',
            args=[{'marker.color': [edit_distance_matrix[i]]},
                  {'shapes': [highlight_circle(i), poincare_circumference]}],
        )
        for i in range(n_otus)
    ]

    fig.update_layout(
        # dropdown menu and its label
        updatemenus=[
            dict(
                buttons=menu_options,
                direction='down',
                pad={'r': 10, 't': 10},
                showactive=True,
                x=0.11,
                xanchor='left',
                y=1.1,
                yanchor='top',
            ),
        ],
        annotations=[
            dict(text="OTU ID", showarrow=False,
                 x=0, y=1.07, xref='paper', yref='paper', align="left")
        ],
        # figure size
        autosize=False, width=1000, height=1000,
        # figure title, legend title, and axis titles
        xaxis_title_text='Dimension 1', yaxis_title_text='Dimension 2',
        legend_title_text='Edit Distance',
        title={'text': title, 'xanchor': 'center', 'x': 0.57},
    )

    # fix the view to the disk and remove the grid
    axis_margin = 0.05
    fig.update_xaxes(range=[-1 - axis_margin, 1 + axis_margin], showgrid=False)
    fig.update_yaxes(range=[-1 - axis_margin, 1 + axis_margin], showgrid=False)

    return fig

In [43]:
# The embeddings, the edit-distance matrix and `data` come from separate CSVs, so align all of
# them to the OTU column order of `data` instead of assuming their rows already match.
# read_csv parses numeric OTU ids in an index as int but keeps header labels as str, hence the
# astype(str) before the label lookups. A KeyError here means one of the files is stale.
otu_ids = data.columns.to_list()
otu_xy = otu_embeddings.set_axis(otu_embeddings.index.astype(str), axis=0).loc[otu_ids]
edit_distance_by_id = edit_distance_df.set_axis(edit_distance_df.index.astype(str), axis=0)
edit_distance_by_id = edit_distance_by_id.set_axis(edit_distance_by_id.columns.astype(str), axis=1)
edit_distance_matrix = edit_distance_by_id.loc[otu_ids, otu_ids].to_numpy()
x = otu_xy[0].to_list()
y = otu_xy[1].to_list()

title = 'OTU Embeddings on Poincare Disk:<br>{} Data | {} {} {}'.format(data_name.upper(), model_class.upper(), distance, embedding_size)

otu_fig = plot_otu_embeddings(x, y, otu_ids, edit_distance_matrix, title=title)

otu_fig.write_html(make_dir(otu_embeddings_plot_path))
otu_fig.show()

# Plot Mixture Embeddings on Poincare Disk

In [None]:
def plot_mixture_embeddings(x, y, features, title='Poincare Disk'):
    """Plot mixture (per-sample) embeddings on the Poincare disk, coloured by metadata.

    One group of traces is built per feature with plotly express. A dropdown toggles
    which group is visible, so the same points can be recoloured by any feature.

    Args:
        x, y: per-sample embedding coordinates (array-like, equal length).
        features: mapping feature name -> list of per-sample values, in the same sample
            order as ``x``/``y``. The first key is the feature shown initially.
        title: base figure title; ": <feature>" is appended for the active feature.

    Returns:
        plotly.graph_objects.Figure

    Raises:
        ValueError: if ``x`` and ``y`` differ in length or ``features`` is empty.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if len(x) != len(y):
        raise ValueError('x and y must have the same length, got {} and {}'.format(len(x), len(y)))
    if not features:
        raise ValueError('features must contain at least one feature')
    feature_names = list(features.keys())
    first_feature_name = feature_names[0]

    # one row per sample: the coordinates passed in, plus every metadata feature
    df = pd.DataFrame({'x': x, 'y': y} | features)

    fig = go.Figure()

    # set figure size
    fig.update_layout(autosize=False, width=700, height=700)

    # Poincare disk boundary, drawn slightly outside the unit circle
    circle_margin = 0.02
    fig.add_shape(dict(
        type="circle",
        xref="x", yref="y",
        x0=-1 - circle_margin, y0=-1 - circle_margin, x1=1 + circle_margin, y1=1 + circle_margin,
        line=dict(color="Black", width=4),
    ))

    # Build every feature's traces up front; only the first feature starts visible.
    # A discrete feature yields one px trace per category, a numeric one a single
    # colour-scaled trace, so a feature may own several traces.
    feature_to_traces = defaultdict(list)
    for f_idx, feature_name in enumerate(feature_names):
        feature_fig = px.scatter(df, x='x', y='y', color=feature_name)
        for trace in feature_fig.data:
            # add the px trace unchanged (it may be Scatter or Scattergl), only toggling visibility
            trace.visible = f_idx == 0
            fig.add_trace(trace)
            feature_to_traces[feature_name].append(len(fig.data) - 1)
    n_traces = len(fig.data)

    # one button per feature: show that feature's traces, hide all others, retitle
    buttons = []
    for feature_name in feature_names:
        visible = [False] * n_traces
        for idx in feature_to_traces[feature_name]:
            visible[idx] = True
        buttons.append(dict(
            label=feature_name,
            method='update',
            args=[{'visible': visible},
                  {'title.text': '{}: {}'.format(title, feature_name),
                   'legend.title.text': feature_name,
                   'coloraxis.colorbar.title.text': feature_name}],
        ))

    fig.update_layout(
        # dropdown menu and its label
        updatemenus=[
            dict(
                buttons=buttons,
                showactive=True,
                x=0.11,
                xanchor="left",
                y=1.087,
                yanchor="top",
            ),
        ],
        annotations=[
            dict(
                text="Color ", showarrow=False,
                x=0, y=1.07, xref='paper', yref='paper', align="left",
            )
        ],
        # titles for the initially visible feature
        xaxis_title_text='Dimension 1', yaxis_title_text='Dimension 2',
        legend_title_text=first_feature_name,
        coloraxis_colorbar_title_text=first_feature_name,
        title_text='{}: {}'.format(title, first_feature_name),
    )

    # set limits
    axis_margin = 0.05
    fig.update_xaxes(range=[-1 - axis_margin, 1 + axis_margin])
    fig.update_yaxes(range=[-1 - axis_margin, 1 + axis_margin])

    return fig

In [None]:
# `features` follows the row order of `metadata`; put the embeddings in the same sample order.
# The embedding columns are integer labels (0, 1); a DataFrame has no [:, 0] positional indexing.
mixture_xy = mixture_embeddings.loc[metadata.index]
mixture_fig = plot_mixture_embeddings(mixture_xy[0].to_numpy(), mixture_xy[1].to_numpy(), features)

mixture_fig.show()

# Plot Mixture and OTU Embeddings on Poincare Disk

In [None]:
def _otu_marker_sizes(otu_table, n_otus, min_size=4.0, max_size=20.0):
    """Marker sizes for the OTU points, growing with each OTU's mean abundance.

    Args:
        otu_table: samples x OTUs table (DataFrame or 2-D array) with one column per OTU,
            or None for a constant size.
        n_otus: number of OTU points that will be drawn.
        min_size, max_size: marker-size range the abundances are mapped onto.

    Returns:
        numpy array of ``n_otus`` marker sizes.

    Raises:
        ValueError: if ``otu_table`` does not have exactly ``n_otus`` columns.
    """
    if otu_table is None:
        return np.full(n_otus, min_size)
    table = np.asarray(otu_table, dtype=float)
    if table.ndim != 2 or table.shape[1] != n_otus:
        raise ValueError('otu_table must have one column per OTU ({} expected), got shape {}'
                         .format(n_otus, table.shape))
    if n_otus == 0:
        return np.empty(0)
    # sqrt compresses the range so a few dominant OTUs do not dwarf all the others
    scaled = np.sqrt(np.clip(table.mean(axis=0), 0, None))
    span = scaled.max() - scaled.min()
    if not np.isfinite(span) or span == 0:
        return np.full(n_otus, (min_size + max_size) / 2)
    return min_size + (scaled - scaled.min()) / span * (max_size - min_size)


def plot_embeddings(mixture_x, mixture_y, otu_x, otu_y, otu_table, features, title='Poincare Disk'):
    """Overlay OTU embeddings and mixture (sample) embeddings on one Poincare disk.

    OTUs are drawn as a single, always-visible black trace (trace 0). Marker sizes grow
    with each OTU's mean abundance in ``otu_table``. The mixture embeddings get one trace
    group per metadata feature, and a dropdown switches which feature colours the samples.

    Args:
        mixture_x, mixture_y: per-sample coordinates (array-like, equal length).
        otu_x, otu_y: per-OTU coordinates (array-like, equal length).
        otu_table: samples x OTUs table (DataFrame or 2-D array). Its columns must be in
            the same order as ``otu_x``/``otu_y``. May be None for constant-size markers.
            DataFrame column labels are used as OTU hover text.
        features: mapping feature name -> list of per-sample values, in the same sample
            order as ``mixture_x``/``mixture_y``. The first key is shown initially.
        title: base figure title; ": <feature>" is appended for the active feature.

    Returns:
        plotly.graph_objects.Figure

    Raises:
        ValueError: on mismatched coordinate lengths, an empty ``features`` mapping, or an
            ``otu_table`` whose column count differs from the number of OTUs.
    """
    mixture_x = np.asarray(mixture_x)
    mixture_y = np.asarray(mixture_y)
    otu_x = np.asarray(otu_x)
    otu_y = np.asarray(otu_y)
    if len(mixture_x) != len(mixture_y):
        raise ValueError('mixture_x and mixture_y must have the same length')
    if len(otu_x) != len(otu_y):
        raise ValueError('otu_x and otu_y must have the same length')
    if not features:
        raise ValueError('features must contain at least one feature')
    feature_names = list(features.keys())
    first_feature_name = feature_names[0]

    # one row per sample: the mixture coordinates passed in, plus every metadata feature
    df = pd.DataFrame({'x': mixture_x, 'y': mixture_y} | features)

    fig = go.Figure()

    # set figure size
    fig.update_layout(autosize=False, width=700, height=700)

    # Poincare disk boundary, drawn slightly outside the unit circle
    circle_margin = 0.02
    fig.add_shape(dict(
        type="circle",
        xref="x", yref="y",
        x0=-1 - circle_margin, y0=-1 - circle_margin, x1=1 + circle_margin, y1=1 + circle_margin,
        line=dict(color="Black", width=4),
    ))

    # OTU embeddings: a single trace at index 0, shared by (and visible in) every feature view
    otu_labels = [str(label) for label in otu_table.columns] if hasattr(otu_table, 'columns') else None
    otu_trace = go.Scatter(
        x=otu_x, y=otu_y,
        mode='markers',
        name='OTUs',
        hovertext=otu_labels,
        marker=dict(size=_otu_marker_sizes(otu_table, len(otu_x)), color='black'),
        visible=True,
    )
    fig.add_trace(otu_trace)

    # mixture embeddings: one px trace group per feature, only the first feature visible
    feature_to_traces = defaultdict(list)
    for f_idx, feature_name in enumerate(feature_names):
        feature_fig = px.scatter(df, x='x', y='y', color=feature_name)
        for trace in feature_fig.data:
            # add the px trace unchanged (it may be Scatter or Scattergl), only toggling visibility
            trace.visible = f_idx == 0
            fig.add_trace(trace)
            feature_to_traces[feature_name].append(len(fig.data) - 1)
    n_traces = len(fig.data)

    # one button per feature: OTU trace + that feature's traces visible, everything else hidden
    buttons = []
    for feature_name in feature_names:
        visible = [False] * n_traces
        visible[0] = True
        for idx in feature_to_traces[feature_name]:
            visible[idx] = True
        buttons.append(dict(
            label=feature_name,
            method='update',
            args=[{'visible': visible},
                  {'title.text': '{}: {}'.format(title, feature_name),
                   'legend.title.text': feature_name,
                   'coloraxis.colorbar.title.text': feature_name}],
        ))

    fig.update_layout(
        # dropdown menu and its label
        updatemenus=[
            dict(
                buttons=buttons,
                showactive=True,
                x=0.11,
                xanchor="left",
                y=1.087,
                yanchor="top",
            ),
        ],
        annotations=[
            dict(
                text="Color ", showarrow=False,
                x=0, y=1.07, xref='paper', yref='paper', align="left",
            )
        ],
        # titles for the initially visible feature
        xaxis_title_text='Dimension 1', yaxis_title_text='Dimension 2',
        legend_title_text=first_feature_name,
        coloraxis_colorbar_title_text=first_feature_name,
        title_text='{}: {}'.format(title, first_feature_name),
    )

    # set limits
    axis_margin = 0.05
    fig.update_xaxes(range=[-1 - axis_margin, 1 + axis_margin])
    fig.update_yaxes(range=[-1 - axis_margin, 1 + axis_margin])

    return fig

In [None]:
# Pair each coordinate array with the order it is plotted against. Samples follow `metadata`,
# which `features` was built from. OTUs follow the columns of `data`, the abundance table used
# for the OTU marker sizes. OTU ids are int in the embedding index but str in `data` headers.
mixture_xy = mixture_embeddings.loc[metadata.index]
otu_xy = otu_embeddings.set_axis(otu_embeddings.index.astype(str), axis=0).loc[data.columns]

mixture_fig = plot_embeddings(
    mixture_xy[0].to_numpy(), mixture_xy[1].to_numpy(),
    otu_xy[0].to_numpy(), otu_xy[1].to_numpy(),
    data,
    features,
    )

mixture_fig.show()

In [None]:
# def plot_embeddings(x, y, otu_ids, features, edit_distance_matrix, subset=200):
    
#     # create subset of data for clearer, simpler visualization
#     x = x[subset:2*subset]
#     y = y[subset:2*subset]
#     otu_ids = otu_ids[subset:2*subset]
#     edit_distance_matrix = edit_distance_matrix[subset:2*subset, subset:2*subset]
        
#     # initialize dataframe. Rename the 0th OTU Edit Distance so that the colorbar is labeled 'Edit Distance'
#     first_otu_idx = 0
#     df = pd.DataFrame({'x': x, 'y': y} | {i: edit_distance_matrix[i] for i in range(len(edit_distance_matrix))})
#     df = df.rename(columns={first_otu_idx: 'Edit Distance'})
    
#     # plot embeddings
#     fig = px.scatter(df,
#                      x='x', y='y',
#                      color='Edit Distance'
#                      )
    
#     # plot Poincare circumfrence
#     offset = 0.02
#     poincare_circumfrence = dict(
#         type="circle",
#         xref="x", yref="y",
#         x0=-1 - offset, y0=-1 - offset, x1=1 + offset, y1=1 + offset,
#         line=dict(color="Black", width=4)
#     )
#     fig.add_shape(poincare_circumfrence)
    
#     # make circle around initial OTU
#     radius = 0.05
#     circle = dict(type="circle",
#         xref="x", yref="y",
#         x0=x[first_otu_idx] + radius, y0=y[first_otu_idx] + radius,
#         x1=x[first_otu_idx] - radius, y1=y[first_otu_idx] - radius,
#         line=dict(color="Black"))
#     fig.add_shape(circle)
    
#     # # create options for dropdown menu
#     # menu_options = []
#     # for i in range(len(x)):
        
#     #     circle = dict(type="circle",
#     #         xref="x", yref="y",
#     #         x0=x[i] + radius, y0=y[i] + radius,
#     #         x1=x[i] - radius, y1=y[i] - radius,
#     #         line=dict(color="Black"))
        
#     #     menu_options.append(
#     #         dict(
#     #             label=otu_ids[i],
#     #             method='update',
#     #             args=[{'marker.size': [edit_distance_matrix[i]]},
#     #                   {'shapes': [circle] + [poincare_circumfrence]},
#     #                   {"colorbar":{"title":{"text":"data_a_title"}}}
#     #                 ]
#     #             )
#     #         )
#      # create options for dropdown menu
#     menu_options = []
#     for feature_names, feature_values in features.items():
#         menu_options.append(
#             dict(
#                 label=feature_names,
#                 method='update',
#                 args=[{'marker.size': [feature_values]},
#                       {'shapes': [circle] + [poincare_circumfrence]},
#                     ]
#                 )
#             )
    
#     # create dropdown menu
#     fig.update_layout(
#         updatemenus=[
#             dict(
#                 buttons=menu_options,
#                 direction='down',
#                 pad={'r': 10, 't': 10},
#                 showactive=True,
#                 x=0.11,
#                 xanchor='left',
#                 y=1.1,
#                 yanchor='top'
#             ),
#         ]
#     )
    
#     # Add annotation
#     fig.update_layout(
#         annotations=[
#             dict(text="OTU ID", showarrow=False,
#             x=0, y=1.07, xref='paper', yref='paper', align="left")
#         ]
#     )
    
#     # set figure size
#     fig.update_layout(
#         autosize=False, width=700, height=700
#     )

#     # set limits and remove grid
#     offset = 0.05
#     fig.update_xaxes(range=[-1 - offset, 1 + offset], showgrid=False)
#     fig.update_yaxes(range=[-1 - offset, 1 + offset], showgrid=False)
    
#     # add figure title, legened title, and axis titles
#     title='OTU Embeddings on Poincare Disk'
#     fig.update_layout(
#         xaxis_title_text='Dimension 1', yaxis_title_text='Dimension 2', 
#         legend_title_text='Edit Distance', 
#         title={'text': title, 'xanchor': 'center', 'x':0.57}
#         )

#     return fig

In [None]:
# This cell was written for the commented-out draft signature
# plot_embeddings(x, y, otu_ids, features, edit_distance_matrix), which does not exist.
# `plot_embeddings` resolves to the active definition:
# (mixture_x, mixture_y, otu_x, otu_y, otu_table, features, title=...).
# Call it with those arguments so the notebook survives Restart & Run All.
fig = plot_embeddings(
    mixture_embeddings[0].to_numpy(), mixture_embeddings[1].to_numpy(),
    otu_embeddings[0].to_numpy(), otu_embeddings[1].to_numpy(),
    data,
    features,
    title='Mixture and OTU Embeddings on Poincare Disk',
)
fig.show()

# Old Plots

In [None]:
# Source: https://9to5tutorial.com/drawing-a-poincare-disc-in-python
# Draws, for each level n, 2*2**n + 1 circles of radius tan(pi / 2**n) centred at
# distance 1 / cos(pi / 2**n) from the origin. It then draws 9 yellow lines through the
# origin and finally the unit circle (the disk boundary).

import numpy as np
import matplotlib.pyplot as plt

angles = np.linspace(0, 2 * np.pi, 100)
palette = ["r", "g", "b", "c", "m", "y"]

fig, ax = plt.subplots(figsize=(8, 8))

for level in range(6):
    divisions = 2 ** level
    alpha = np.pi / divisions
    radius = np.tan(alpha)
    for phi in np.linspace(0, 2 * np.pi, 2 * divisions + 1):
        circle_x = np.cos(angles) * radius + np.cos(phi) / np.cos(alpha)
        circle_y = np.sin(angles) * radius + np.sin(phi) / np.cos(alpha)
        # level - 2 is negative for levels 0 and 1, so those wrap to the last palette entries
        ax.plot(circle_x, circle_y, lw=0.5, color=palette[level - 2])

line_params = np.linspace(-2, 2, 100)
for phi in np.linspace(0, 2 * np.pi, 9):
    ax.plot(line_params * np.cos(phi), line_params * np.sin(phi), lw=0.5, color='y')

ax.plot(np.cos(angles), np.sin(angles), color='black')
ax.set(xlim=(-1, 1), ylim=(-1, 1))

fig.show()

In [None]:
def contains_str(lst):
    """Return True if at least one element of ``lst`` is a ``str``, else False."""
    return any(isinstance(item, str) for item in lst)


def plot(x, y, features):
    """Scatter points on the Poincare disk, recoloured by a feature chosen from a dropdown.

    Legacy helper. Features containing any string values are label-encoded to integer
    codes (sklearn ``LabelEncoder``) so they can drive the continuous 'Viridis' scale.

    Args:
        x, y: point coordinates.
        features: mapping feature name -> per-point values (same order as ``x``/``y``).
            The first key is the feature shown initially.

    Returns:
        plotly.graph_objects.Figure

    Raises:
        ValueError: if ``features`` is empty.
    """
    if not features:
        raise ValueError('features must contain at least one feature')

    # encode every feature once; the initial trace and the dropdown both reuse these values
    encoded = {}
    for feature_name, feature_values in features.items():
        if contains_str(feature_values):
            feature_values = LabelEncoder().fit_transform(feature_values)
        encoded[feature_name] = feature_values
    first_feature_name = next(iter(encoded))

    # Create the Plotly figure
    fig = go.Figure()

    # set figure size
    fig.update_layout(
        autosize=False, width=700, height=700
    )

    # Add scatter plot, initially coloured by the first feature
    scatter_trace = go.Scatter(
        x=x,
        y=y,
        mode='markers',
        marker=dict(
            size=10,
            color=encoded[first_feature_name],
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title=dict(text=first_feature_name)),
        )
    )
    fig.add_trace(scatter_trace)

    # add unit circle (disk boundary)
    fig.add_shape(
        type="circle",
        xref="x", yref="y",
        x0=-1, y0=-1, x1=1, y1=1,
        line_color="black",
    )

    # one dropdown option per feature: swap the marker colours and relabel the colour bar
    menu_options = [
        {'label': feature_name,
         'method': 'update',
         'args': [{'marker.color': [values], 'marker.colorbar.title.text': feature_name}]}
        for feature_name, values in encoded.items()
    ]

    # create dropdown menu
    fig.update_layout(
        updatemenus=[
            dict(
                buttons=menu_options,
                direction='down',
                pad={'r': 10, 't': 10},
                showactive=True,
                x=0.11,
                xanchor='left',
                y=1.1,
                yanchor='top'
            ),
        ]
    )

    # Add annotation
    fig.update_layout(
        annotations=[
            dict(text="Color by", showarrow=False,
            x=0, y=1.07, xref='paper', yref='paper', align="left")
        ]
    )

    # set limits
    offset = 0.05
    fig.update_xaxes(range=[-1 - offset, 1 + offset])
    fig.update_yaxes(range=[-1 - offset, 1 + offset])

    return fig