In [2]:
#import plotly.io as pio
import plotly
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd

In [3]:


def sphere3D(radius):
    '''
    Build a Plotly figure holding a wireframe (gridlines) of a 3D sphere.

    Adapted from
    https://community.plotly.com/t/adding-wireframe-around-a-sphere/37661/2

    Parameters
    ----------
    radius : float
        Radius of the sphere; the sphere is centred on the origin.

    Returns
    -------
    plotly.graph_objects.Figure
        A 700x700 figure with a single grey ``scatter3d`` line trace made of
        12 meridians followed by 10 parallels. Individual lines are separated
        inside the trace by ``None`` entries.
    '''
    theta = np.linspace(0, 2*np.pi, 120)  # azimuthal samples around the z axis
    phi = np.linspace(0, np.pi, 60)  # polar samples; increase for a smoother fit

    # Loop-invariant pieces, computed once.
    sin_phi = np.sin(phi)
    meridian_z = list(radius*np.cos(phi)) + [None]
    ring_x = radius*np.cos(theta)
    ring_y = radius*np.sin(theta)

    x = []
    y = []
    z = []

    # Meridians: every 10th azimuth (12 lines), each swept from pole to pole.
    for t in theta[::10]:
        # None marks the end of one meridian so Plotly does not join it to the next line.
        x.extend(list(radius*np.cos(t)*sin_phi) + [None])
        y.extend(list(radius*np.sin(t)*sin_phi) + [None])
        z.extend(meridian_z)

    # Parallels: every 6th polar angle (10 lines), each a full circle at constant z.
    for s in phi[::6]:
        # None marks the end of one parallel.
        x.extend(list(ring_x*np.sin(s)) + [None])
        y.extend(list(ring_y*np.sin(s)) + [None])
        # One z value per theta sample (tied to theta's length rather than a literal count).
        z.extend([radius*np.cos(s)]*theta.size + [None])

    fig = go.Figure()
    fig.add_scatter3d(x=x, y=y, z=z, line_width=2, line_color='rgb(160,160,160)',
                      showlegend=False, mode='lines')
    fig.update_layout(width=700, height=700)

    return fig

In [4]:
# Restore `w_0_clustered` from IPython's %store database. The variable must have
# been saved beforehand with `%store w_0_clustered` (e.g. from another notebook
# or kernel session); otherwise it stays undefined and the next cell fails.
%store -r w_0_clustered

In [5]:
# Wrap the restored data in a DataFrame; its columns are renamed to x/y/z in a later cell.
# NOTE(review): presumably one row per point with three coordinate columns — confirm.
df = pd.DataFrame(w_0_clustered)

In [10]:
# Restore `w_0_cluster_labels` from IPython's %store database (must have been
# saved earlier with `%store w_0_cluster_labels`).
# NOTE(review): this cell's execution count (In [10]) is out of sequence with its
# neighbours — verify the notebook still works under Restart Kernel -> Run All.
%store -r w_0_cluster_labels


In [6]:
# Name the three coordinate columns; the 3D scatter plot below refers to them by these names.
# Assigning three names requires `df` to have exactly three columns, so re-running this
# cell after the 'cluster labels' column has been added raises a length-mismatch error.
df.columns = ['x','y','z']


In [11]:
# Attach the cluster label of each row; this column is the colour key of the scatter plot.
# NOTE(review): presumably aligned row-for-row with `w_0_clustered` — confirm.
df['cluster labels'] = w_0_cluster_labels

In [19]:

# --- Data points: one marker per row of `df`, coloured by its cluster label ---
cluster_colors = {
    'ON': 'teal',
    'OFF': '#FF8C00',
    'AVA': 'red',
    'RME': 'blue',
    'SMDV': 'purple',
    'SMDD': 'yellow',
}
fig1 = px.scatter_3d(
    df,
    x='x',
    y='y',
    z='z',
    color='cluster labels',
    color_discrete_map=cluster_colors,
)
marker_style = dict(size=12, opacity=0.6, line=dict(color='black', width=2))
fig1.update_traces(marker=marker_style, selector=dict(mode='markers'))

# --- Reference geometry: wireframe sphere of radius 1 ---
fig2 = sphere3D(1)

# --- Optional overlay (disabled): a geodesic, i.e. a line through the points ---
# fig3 = px.line_3d(df, x="x", y="y", z="z")
# fig3.update_traces(marker=dict(color="#80b1d3"))

# Merge the traces of the individual figures into a single figure.
fig = go.Figure(data=fig1.data + fig2.data)  # + fig3.data)

# All three axes share the same ticks and fonts; only the axis title differs.
scene_axes = {
    axis_name: dict(
        tickvals=[-1, -0.5, 0, 0.5, 1],
        ticktext=['-1', '-0.5', '0', '0.5', '1'],
        title=dict(text=axis_title, font=dict(size=20)),
        tickfont=dict(size=15),
    )
    for axis_name, axis_title in [('xaxis', 'X'), ('yaxis', 'Y '), ('zaxis', 'Z ')]
}

fig.update_layout(
    width=800,
    height=700,
    title_text='',
    template='simple_white',
    # no padding around the scene: left, right, bottom, top
    margin=go.layout.Margin(l=0, r=0, b=0, t=0),
    scene=scene_axes,
)

fig.show()
