# Subplots

In [4]:
# Import the library
import plotly as py

In [5]:
# Unlike the single-plot figures seen so far, a figure with subplots is created
# first (this fixes its grid of panels); traces are added to it afterwards.
from plotly.subplots import make_subplots

fig = make_subplots(
    rows=1,
    cols=2,
    shared_yaxes=True,  # link the y axes of the two panels
    subplot_titles=('f(x) = sin(x)', 'f(x) = cos(x)'),
)

## Adding traces with the `.add_` methods

In [6]:
# With the figure in place, each trace is sent to its subplot through the
# row/col arguments of the .add_ methods. Every .add_ call returns the figure,
# so the calls can be chained and the cell ends up displaying it.
import numpy as np

x = np.linspace(0, 2 * np.pi)

(
    fig
    .add_scatter(x=x, y=np.sin(x), name='sin(x)', row=1, col=1)
    .add_scatter(x=x, y=np.cos(x), name='cos(x)', row=1, col=2)
)

## Changing the figure afterwards with the `.update_` methods

In [7]:
# Once the figure exists and its traces have been added, any aspect of it can
# be changed afterwards by chaining the .update_ methods, each of which returns
# the figure. Some of the most common use cases:
(
    fig
    # figure-wide settings
    .update_layout(title='Trigonometric functions', showlegend=False, width=900, height=500)
    # no row/col given: every x axis gets the title
    .update_xaxes(title='x')
    # only the y axis of the first column gets a title
    .update_yaxes(title='f(x)', col=1)
    # no row/col given: every trace is affected
    .update_traces(mode='lines+markers')
)

In [8]:
# These methods return the modified fig, but they also modify the fig object
# itself (in place), so simply displaying fig shows all the changes made above.
fig

In [9]:
# .update_traces can be restricted to the subplots of a given row and/or col:
# here only the trace in the first column switches to markers only.
fig.update_traces(mode='markers', col=1)

In [10]:
# Traces can also be picked with a selector: only the traces whose properties
# match it are updated. After the previous cell, only the trace in column 2
# (cos(x)) still has mode='lines+markers', so only it turns green.
fig.update_traces(
    marker={'color': 'green'},
    selector={'mode': 'lines+markers'},
)

## Exercise

Using the example cluster data, loaded with:

In [11]:
# Example cluster data, read straight from GitHub. The exercise below uses its
# columns Xi, Yi (initial values), Xf, Yf (final values) and color.
import pandas as pd

table = pd.read_csv(
    'https://raw.githubusercontent.com/chumo/Data2Serve/master/transition_clusters.csv'
)

Build the following figure:

![](images/fig_03.1.png)

In [12]:
 from plotly.subplots import make_subplots

fig = (
    make_subplots(
        rows=2, 
        cols=1, 
        subplot_titles=('Initial values', 'Final values'), 
        shared_xaxes=True,
        vertical_spacing=0.05,
    )
    .add_scatter(
        x=table.Xi,
        y=table.Yi, 
    #     marker=dict(color=table.color),
        name='Initial',
        row=1,
        col=1,
    )
    .add_scatter(
        x=table.Xf,
        y=table.Yf,
        hovertext=table.color, 
    #     marker=dict(color=table.color),
        name='Final',
#         showlegend=False,
        row=2,
        col=1,
    )
    .update_layout(
        title = 'Random clusters',
        xaxis2_range=[0,500],
        width = 600,
        height = 900,
        showlegend = False,
    )
    .update_xaxes(
        title='X', 
    #     range=[0, 500],
        row=2,
    )
   
.update_traces(
        mode='markers',
        marker_color=table.color,
    )
)

fig