# Interactive collision statistics and collision tree plots

This notebook uses JupyterDash to generate an interactive dashboard of all of the final planets and their collision statistics and histories. Clicking on a planet on the scatterplot will generate a plot of that planet's collision tree and the histogram of the collisions it experienced. Hovering over a point will tell you the planet's final mass, core-mass fraction (CMF), semi-major axis, eccentricity and inclination, as well as its run parameters, etc. 

In [3]:
from jupyter_dash import JupyterDash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import pandas as pd

#import colltree functions 
import util
import colltree
import find_coll_history
import coll_history_stats

## Define inputs

In [4]:
#directory information
base_dir = "../../../../Planetform-data/excited-disks/"
dirtable = "good_plotting_dirs.csv"

#collision table file name
vtab = ''

#name of stats files
numcoll_file = 'excited-numcoll-giant-newest-update-cmfav-nonadditive-nominmass-wallpls-wstability-final.csv'
collhist_file = 'excited-disks-collhist-giant-embs-vtab-nonadditive-nominmass-wallpls-final.csv'
collhist_small_file = ''

#info to generate stats files
cparam = 'giant'
minemb = 'embryo'

#options for dropdown menu
available_indicators = ['mass','a','e','inc','cmf','cmf_av']
param_indicators = ['ecc_init','inc_init','dloss','slope']
available_indicators_2 = ['mass','a','cmf']

In [5]:
try:
    #read in pre-existing stats files
    numcoll_all = pd.read_csv(base_dir+numcoll_file)
    collhists = pd.read_csv(base_dir+collhist_file)
except:
    #generate the stats files
    collhists = find_coll_history.get_collhist(base_dir,dirtable,vtab,cparam,minemb,collhist_file,collhist_small_file,True)
    numcoll_all = coll_history_stats.get_numcoll(base_dir,dirtable,cparam,minemb,numcoll_file,True)
    
#convert to useful values
mass_names = ['tmass','pmass','LRMass','SLRMass']
for name in mass_names:
    collhists[name] = collhists[name].apply(util.munit_to_mearth)

## Build the JupyterDash app

In [6]:
# Build App
app = JupyterDash(__name__)

#set up layout of the app
app.layout = html.Div([
    html.H2("Collision statistics",
           style={'textAlign': 'center'}),
    html.Div([
    html.Div([ html.Label([ 'y-axis:',
        dcc.Dropdown(
            id='crossfilter-yaxis-column',
            options=[{'label': i, 'value': i} for i in available_indicators],
            value='cmf_av'
        )])
    ],style={'width': '25%','float': 'left','padding-left': '150px'}),
    html.Div([ html.Label([ "colour scheme:",
        dcc.Dropdown(
            id='colour-param',
            options=[{'label': i, 'value': i} for i in param_indicators],
            value='dloss'
        )])
    ],style={'width': '25%', 'float': 'right','padding-right': '150px', 'padding-left': '75px'}),
    ],className='row'),
    
    html.Div([
        html.Div([
        dcc.Graph(
            id='collision',
        )]),
        html.Div([
        dcc.Dropdown(
            id='crossfilter-yaxis-collhist',
            options=[{'label': i, 'value': i} for i in available_indicators_2],
            value = 'mass'),
        ],style={'width': '25%', 'padding-left': '150px'}),
        dcc.Graph(
            id='planet-coll-history'
        ),
        dcc.Graph(
            id='planet-ctype-hist'
        )
    ]),
    
    
])

@app.callback(
    Output('collision', 'figure'),
    [Input('crossfilter-yaxis-column', 'value'),
     Input('colour-param', 'value')],
    )
def update_graph(yaxis_column_name,colour_name):
    """plot the collision statistics"""
    
    #change which parameter size corresponds to depending on y-axis
    if yaxis_column_name == 'mass':
        size_name = 'cmf_av'
    elif yaxis_column_name == 'cmf_av':
        size_name = 'mass'
    else:
        size_name = 'mass'
        
    fig = px.scatter(numcoll_all,x='tot_giant',y=yaxis_column_name,color=colour_name,size=size_name,
                     hover_name='dirs',hover_data=['pid','mass','cmf_av','a','e','inc','stability'],custom_data=['dirs','pid','mass','cmf_av'])
    fig.update_layout(margin={'l': 20, 'b': 15, 't': 20, 'r': 30})
    
    
    #fig.update_yaxes(title=yaxis_column_name)
    fig.update_xaxes(title='number of giant collisions')
    
    return(fig)

def plot_coll_history(pdir,pid,title,y_axis):
    """Plot collision history"""
    
    CollTree = colltree.get_CollTree(pdir,pid,collhists)
    
    #read in pl.maxcorecollisions
    comp = util.read_comp(base_dir,pdir,'pandas')

    #create a group for each collision
    plot_tab = CollTree.groupby('time')

    #create the figure
    fig = go.Figure()
    colours = ['#000000','#E775C6','#FA8013','#9764C0','#2EA21F','#D4272A','#1BBBD1',
               '#2277B0','#A3DB37']
    coll_groups = ['1','2','3','4','5','6','7','8','9']
    coll_names = ['super-catastrophic','catastrophic disruption','erosion','partial accretion',
                  "hit'n'spray","hit'n'run",'graze','merge',"graze'n'merge"]

    #iterate over each collision group
    for name, group in plot_tab:
        ctype = group['ctype'].iloc[0]
        
        fig.add_trace(go.Scatter(x=group['time'],y=group[y_axis],line_color=colours[int(ctype-1)],hovertext=group['body'],
                      legendgroup=coll_groups[int(ctype-1)],name=coll_names[int(ctype-1)]))
        l = len(group)   

        #if the LR has children, find them and connect them
        connect_1 = pd.DataFrame(columns=['time','body','id','mass','cmf','a','ctype','children','parents'])

        #connect to children
        if l == 3:
            connect_1.loc[0] = group.iloc[l-1].copy()
        elif l == 4:
            #there's a SLR
            connect_1.loc[0] = group.iloc[l-2].copy()
            
            
        #find children of this body (i.e. next collision that this body experiences)
        maskc = CollTree['body'] == connect_1['children'].loc[0][0]
        
        if len(CollTree[maskc]) > 0:
            coll_time = CollTree['time'][maskc].iloc[0]
            
            #add inbetween data points from pl.maxcorecompositions
            start_time = connect_1['time'].loc[0]
            maskt1 = comp['time'] > start_time
            maskt2 = comp['time'] < coll_time
            maskid = comp['iinit'] == connect_1['id'].loc[0]
            timesteps = comp[maskid&maskt1&maskt2]
            
            for i in range(0,len(timesteps)):
                #add in each row at a time
                connect_1.loc[i+1] = [timesteps['time'].iloc[i],'timestep',timesteps['iinit'].iloc[i],timesteps['mtot'].iloc[i],
                                      timesteps['cmf'].iloc[i],timesteps['a'].iloc[i],0,[None],[None]]

            #then add in the next collision that was found
            lc = len(connect_1)
            connect_1.loc[lc] = CollTree[maskc].iloc[0]
            
        fig.add_trace(go.Scatter(x=connect_1['time'],y=connect_1[y_axis],hovertext=connect_1['body'],
                                 mode="lines",line_color = colours[int(ctype-1)],line_dash='dash',showlegend=False))
        
        if l == 4:
            #plot SLR if it exists
            connect_2 = pd.DataFrame(columns=['time','body','id','mass','cmf','a','ctype','children','parents'])
            connect_2.loc[0] = group.iloc[l-1].copy()
            maskc2 = CollTree['body'] == connect_2['children'].loc[0][0]
            if len(CollTree[maskc2]) > 0:

                coll_time = CollTree['time'][maskc2].iloc[0]

                #add inbetween data points from pl.maxcorecompositions
                start_time = connect_2['time'].loc[0]
                maskt1 = comp['time'] > start_time
                maskt2 = comp['time'] < coll_time
                maskid = comp['iinit'] == connect_2['id'].loc[0]
                timesteps2 = comp[maskid&maskt1&maskt2]
                #print(timesteps2)

                for i in range(0,len(timesteps2)):
                    #add in each row at a time
                    connect_2.loc[i+1] = [timesteps2['time'].iloc[i],'timestep',timesteps2['iinit'].iloc[i],timesteps2['mtot'].iloc[i],
                                          timesteps2['cmf'].iloc[i],timesteps2['a'].iloc[i],0,[None],[None]]

                #then add in the next collision that was found
                lc = len(connect_2)
                connect_2.loc[lc] = CollTree[maskc2].iloc[0]
            
            fig.add_trace(go.Scatter(x=connect_2['time'],y=connect_2[y_axis],hovertext=connect_2['body'],
                                     mode="lines",line_dash='dash',line_color = colours[int(ctype-1)],showlegend=False))
            
            if len(CollTree[maskc]) == 0 and len(CollTree[maskc2]) == 0:
                tstart = group['time'].iloc[0]

            
        if l==3 and len(CollTree[maskc]) == 0:
            tstart = group['time'].iloc[0]
    
    
    #plot out to current planet
    mask = comp['time'] >= tstart
    maskid = comp['iinit'] == pid
    final = comp[mask&maskid]
    fig.add_trace(go.Scatter(x=final['time'],y=final[y_axis],hovertext=final['iinit'],
                             mode="lines",line_dash='dash',line_color=colours[int(ctype-1)],showlegend=False))
    #plot final
    lx = len(comp[maskid])
    t_final = comp[maskid]['time'].iloc[lx-1]
    mask_final = comp['time'] == t_final
    fig.add_trace(go.Scatter(x=comp[maskid&mask_final]['time'],y=comp[maskid&mask_final][y_axis],
                             mode="markers",marker=dict(color=colours[int(ctype-1)]),showlegend=False))

    
    #add labels
    fig.add_annotation(x=0, y=0.78, xanchor='left', yanchor='bottom',
                       xref='paper', yref='paper', showarrow=False, align='left',
                       bgcolor='rgba(255, 255, 255, 0.5)', text=title)

    fig.update_xaxes(type="log",title_text="Time (yr)")
    fig.update_yaxes(title_text=y_axis)
    fig.update_layout(margin={'l': 40, 'b': 20, 't': 20, 'r': 20})
    
    #makes sure that there is only one label for each collision type
    names = set()
    fig.for_each_trace(
        lambda trace:
            trace.update(showlegend=False)
            if (trace.name in names) else names.add(trace.name))
    
    
    return(fig)


def plot_ctype_hist(pdir,pid):
    """Plots a histogram of the collision types that this planet experienced."""
    ctypes = [1,2,3,4,5,6,7,8,9]
    cnames = ['c1','c2','c3','c4','c5','c6','c7','c8','c9']
    
    m_dir = numcoll_all['dirs'] == pdir
    m_pid = numcoll_all['pid'] == pid
    numcoll_masked = numcoll_all[m_dir&m_pid]
    
    new_tab = pd.DataFrame(columns=('ctype','num'))
    for i in range(0,len(ctypes)):
        new_tab.loc[i] = [ctypes[i],numcoll_masked[cnames[i]].iloc[0]]
    fig = px.bar(new_tab,x='ctype',y='num')
    return(fig)
    
@app.callback(
    [Output('planet-coll-history', 'figure'),
     Output('planet-ctype-hist','figure')],
    [Input('collision', 'clickData'),
    Input('crossfilter-yaxis-collhist', 'value')])
def update_coll_history(clickData,yaxis):
    """get information from scatter plot to plot two subplots (collision history and frequency histogram)""" 
    info = clickData['points'][0]['customdata']
    pdir = info[0] #directory
    pid = info[1] #planet id
    mass = round(info[2],3) #planet mass
    cmf = round(info[3],4) #planet cmf
    title = '<b>{}</b><br>{}<br>Mass: {} M_E<br>CMF: {}'.format(pdir, pid, mass, cmf) 
    return [plot_coll_history(pdir,pid,title,yaxis),
            plot_ctype_hist(pdir,pid)]

In [7]:
app.run_server(mode='jupyterlab',debug=False)

 * Running on http://127.0.0.1:8050/ (Press CTRL+C to quit)
127.0.0.1 - - [12/Sep/2022 16:05:11] "[37mGET /_alive_b33e5e8e-b0d4-4ccf-8d47-142cf60eaf41 HTTP/1.1[0m" 200 -
127.0.0.1 - - [12/Sep/2022 16:05:15] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [12/Sep/2022 16:05:15] "[37mGET /assets/dash-stylesheet.css?m=1660329003.9780881 HTTP/1.1[0m" 200 -
127.0.0.1 - - [12/Sep/2022 16:05:23] "[37mGET /_dash-layout HTTP/1.1[0m" 200 -
127.0.0.1 - - [12/Sep/2022 16:05:23] "[37mGET /_dash-dependencies HTTP/1.1[0m" 200 -
127.0.0.1 - - [12/Sep/2022 16:05:23] "[37mGET /_favicon.ico HTTP/1.1[0m" 200 -
127.0.0.1 - - [12/Sep/2022 16:05:23] "[35m[1mPOST /_dash-update-component HTTP/1.1[0m" 500 -
127.0.0.1 - - [12/Sep/2022 16:05:26] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
