In [13]:
import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots
import plotly.colors as colors

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display

import pickle as pkl
from datetime import datetime

def ll2cart(r, lat, lon):
    """Map latitude/longitude given in degrees onto a sphere of radius ``r``.

    Works element-wise, so scalars and broadcast-compatible numpy arrays are
    both accepted.

    Returns a tuple ``(x, y, z)``:
      * x points toward (lat=0, lon=0),
      * y points toward (lat=0, lon=90),
      * z points toward the north pole (lat=90).
    """
    phi = np.deg2rad(lat)      # latitude in radians
    lam = np.deg2rad(lon)      # longitude in radians
    cos_phi = np.cos(phi)

    return (r * np.cos(lam) * cos_phi,
            r * np.sin(lam) * cos_phi,
            r * np.sin(phi))

# Load data
# NOTE(review): pickle.load can execute arbitrary code -- only load these
# files from a trusted source.
def _load_pickle(path):
    """Load one pickled object from ``path``, closing the file afterwards."""
    with open(path, 'rb') as fh:
        return pkl.load(fh)

pred = _load_pickle('data/BayNNE_pred.pkl')
unc = _load_pickle('data/BayNNE_std.pkl')
lat = _load_pickle('data/BayNNE_lat.pkl')
lon = _load_pickle('data/BayNNE_lon.pkl')

# Make time axis: one timestamp on the 15th of every month, Jan 1980 - Dec 2010
time = np.array([datetime(year, mon, 15)
                 for year in range(1980, 2011)
                 for mon in range(1, 13)])

# Reshape preds and uncertainty to (time, lat, lon)
pred = pred.reshape([len(time), len(lat), len(lon)])
unc = unc.reshape([len(time), len(lat), len(lon)])

# Close the longitude seam: repeat the first longitude at the end so the
# surface wraps all the way round the globe. The appended coordinate must be
# lon[0] because the data for that column (below) is a copy of column 0; a
# hard-coded 0 only closes the seam when lon[0] happens to be 0.
# NOTE(review): grid points are still cell centres (not shifted to cell
# edges) and latitude is not extended to the poles, so small gaps may remain
# near the poles.
lon = np.hstack([lon, [lon[0]]])

# Convert the grid to cartesian coordinates on the unit sphere
lon_grid, lat_grid = np.meshgrid(lon, lat)
x, y, z = ll2cart(1, lat_grid, lon_grid)

def _wrap_lon(values):
    """Return ``values`` as float64 with longitude column 0 appended at the end of the last axis."""
    return np.concatenate([values, values[..., :1]], axis=-1).astype(np.float64)

# Extend values in lon to match the extension of the lon coordinate above
unc_data = _wrap_lon(unc)
pred_data = _wrap_lon(pred)
assert pred_data.shape == unc_data.shape == (len(time), len(lat), len(lon))

# Global limits so the colorbars stay fixed as the time slider moves
pred_min = pred_data.min()
pred_max = pred_data.max()
unc_min = unc_data.min()
unc_max = unc_data.max()

pred_name = 'toz'
unc_name = 'uncertainty'
units = 'DU'

# Coastlines: three sequences (x, y, z), one entry per coastline line trace
xc, yc, zc = _load_pickle('data/coastlines.pkl')

In [14]:
# Initial viewpoint shared by both globe scenes
camera = {'eye': {'x': 0.75, 'y': 0.75, 'z': 0.75}}

def pred_srfc_trace(data):
    """Build the globe surface trace coloured by the prediction field.

    ``data`` becomes ``surfacecolor`` on the module-level unit-sphere grid
    (``x``, ``y``, ``z``). The colour range is pinned to the global
    prediction min/max so colours are comparable across time slices.
    """
    colorbar_cfg = dict(thickness=25,
                        len=0.5,
                        title='{} ({})'.format(pred_name, units),
                        x=0, y=0.75)
    return go.Surface(x=x, y=y, z=z,
                      surfacecolor=data,
                      cmin=pred_min, cmax=pred_max,
                      hoverinfo='skip',
                      colorbar=colorbar_cfg)

def unc_srfc_trace(data):
    """Build the globe surface trace coloured by the uncertainty field.

    ``data`` becomes ``surfacecolor`` on the module-level unit-sphere grid
    (``x``, ``y``, ``z``). The colour range is pinned to the global
    uncertainty min/max, and the Viridis colourscale is used to distinguish
    this globe from the prediction globe.
    """
    colorbar_cfg = dict(thickness=25,
                        len=0.5,
                        title='{} ({})'.format(unc_name, units),
                        x=0.5, y=0.75)
    return go.Surface(x=x, y=y, z=z,
                      surfacecolor=data,
                      cmin=unc_min, cmax=unc_max,
                      hoverinfo='skip',
                      colorbar=colorbar_cfg,
                      colorscale=colors.sequential.Viridis)


def create_line_trace(data):
    """Return the time-series line trace: ``data`` plotted against the module-level ``time`` axis."""
    return go.Scatter(x=time, y=data,
                      hoverinfo='skip',
                      showlegend=False,
                      visible=True)

# Axis settings that hide every decoration on the 3-D scenes (background,
# lines, grid, tick labels, ticks, titles, zero lines) so only the globe shows.
noaxis = {
    'showbackground': False,
    'showline': False,
    'showgrid': False,
    'showticklabels': False,
    'title': '',
    'ticks': '',
    'zeroline': False,
}

# Scene layout: the same hidden-axis settings on x/y/z and an equal aspect
# ratio so the sphere is not distorted.
scene_layout = {
    'xaxis': noaxis,
    'yaxis': noaxis,
    'zaxis': noaxis,
    'aspectratio': {'x': 1, 'y': 1, 'z': 1},
}

# One thin black 3-D line trace per coastline entry
coastline_traces = [
    go.Scatter3d(x=xc[i], y=yc[i], z=zc[i],
                 mode='lines',
                 showlegend=False,
                 hoverinfo='skip',
                 line={'color': 'black', 'width': 1})
    for i in range(len(xc))
]

In [15]:
fig = go.FigureWidget(make_subplots(2, 2,
                      specs=[[{"type": "scene"}, {"type": "scene"}],
                             [{"type": "xy", "colspan": 2}, None]]))

# Starting index shared by all three sliders. The initial traces below are
# drawn for this same time/lat/lon, so the figure matches the sliders before
# the first update() call.
init_idx = 1

# Positions of the fixed traces in fig.data, in the order they are added
# below. The coastline traces follow these, and add_plot() appends after them.
PRED_SURFACE, UNC_SURFACE, LINE, PRED_MARKER, UNC_MARKER, LINE_MARKER = range(6)

# Plot 2 surfaces and 1 line. add_trace/add_scatter* return the figure itself,
# so the plot_* names below are all aliases of `fig`.
plot_s_pred = fig.add_trace(pred_srfc_trace(pred_data[init_idx]), row=1, col=1)
plot_s_unc = fig.add_trace(unc_srfc_trace(unc_data[init_idx]), row=1, col=2)
plot_l = fig.add_trace(create_line_trace(pred_data[:, init_idx, init_idx]), row=2, col=1)

# Initial location markers, drawn at radius 1.002 so they sit just above the
# unit-sphere surfaces
x0, y0, z0 = ll2cart(1.002, lat[init_idx], lon[init_idx])
plot_s_pred_point = fig.add_scatter3d(x=[x0],
                                      y=[y0],
                                      z=[z0],
                                      showlegend=False,
                                      hoverinfo='skip',
                                      marker_symbol='cross',
                                      marker=dict(color='white'),
                                      row=1, col=1)

plot_s_unc_point = fig.add_scatter3d(x=[x0],
                                     y=[y0],
                                     z=[z0],
                                     showlegend=False,
                                     hoverinfo='skip',
                                     marker_symbol='cross',
                                     marker=dict(color='white'),
                                     row=1, col=2)

plot_l_point = fig.add_scatter(x=[time[init_idx]],
                               y=[pred_data[init_idx, init_idx, init_idx]],
                               showlegend=False,
                               hoverinfo='skip',
                               marker_symbol='cross-thin',
                               marker_size=12,
                               marker_line_width=2,
                               row=2, col=1)

# Name the line-plot y-axis, hide the axes on the globes, pin the time range
fig.update_yaxes(title_text="{} ({})".format(pred_name, units), row=2, col=1)
fig.update_scenes(scene_layout)
fig.update_layout(xaxis=dict(range=[time[0], time[-1]]))

# Add coastlines for both surfaces
for trace in coastline_traces:
    fig.add_trace(trace, row=1, col=1)
    fig.add_trace(trace, row=1, col=2)

fig.update_layout(paper_bgcolor="LightSteelBlue",
                  legend=dict(orientation='h',
                              yanchor="top",
                              y=-0.1,
                              xanchor="left",
                              x=0))

# Buttons for adding and clearing time series, plus a single Output area
add_plot_button = widgets.Button(description="Add time series")
clear_plot_button = widgets.Button(description="Clear plot")
output = widgets.Output()
display(add_plot_button, clear_plot_button, output)

# Grid *indices* of every (lat, lon) selected with the sliders; add_plot()
# uses the most recent entry. These must be indices, not coordinate values,
# because they are used to index pred_data/unc_data.
lat_idxs = [init_idx]
lon_idxs = [init_idx]
qual_colors = colors.qualitative.D3

# TODO: changing the time slider resets the globes' camera. The camera
# position probably needs to be remembered, e.g. via a constant layout
# `uirevision` -- needs checking with FigureWidget.
@interact(time_slice=widgets.SelectionSlider(options=time,
                                             index=init_idx),
          latitude=widgets.SelectionSlider(options=lat,
                                           index=init_idx),
          longitude=widgets.SelectionSlider(options=lon,
                                            index=init_idx))
def update(time_slice, latitude, longitude):
    """Redraw the globes, markers and line plot for the selected time and location."""
    with fig.batch_update():
        # Map slider values back to array indices. argmin returns the nearest
        # grid point (the first one on ties, e.g. the duplicated wrap longitude)
        # instead of relying on a hard-coded 0.1-degree match tolerance.
        time_idx = np.where(time == time_slice)[0][0]
        lat_idx = int(np.argmin(np.abs(lat - latitude)))
        lon_idx = int(np.argmin(np.abs(lon - longitude)))
        lat_idxs.append(lat_idx)
        lon_idxs.append(lon_idx)

        # Globe colours for the selected time
        fig.data[PRED_SURFACE].surfacecolor = pred_data[time_idx]
        fig.data[UNC_SURFACE].surfacecolor = unc_data[time_idx]

        # Move the location marker on both globes
        xm, ym, zm = ll2cart(1.002, latitude, longitude)
        for marker_idx in (PRED_MARKER, UNC_MARKER):
            fig.data[marker_idx].x = [xm]
            fig.data[marker_idx].y = [ym]
            fig.data[marker_idx].z = [zm]

        # Time series at the selected location
        fig.data[LINE].y = pred_data[:, lat_idx, lon_idx]

        # TODO: add shading for uncertainty

        # Highlight the selected time on the line plot
        fig.data[LINE_MARKER].x = [time_slice]
        fig.data[LINE_MARKER].y = [pred_data[time_idx, lat_idx, lon_idx]]

# Everything currently in fig.data (surfaces, markers, line plot, coastlines)
# is permanent; traces appended after this point belong to added time series.
traces_to_keep = len(fig.data)

# Traces add_plot() appends per series: the line, two globe markers, and the
# lower and upper uncertainty bounds. Used to derive each series' colour.
TRACES_PER_SERIES = 5

def add_plot(b):
    """Button callback: pin the currently selected location as a new series.

    Adds the prediction time series for the most recent slider selection, a
    matching marker on both globes, and a pred +/- unc shaded band, all in
    one colour from ``qual_colors``.
    """
    lat_idx, lon_idx = lat_idxs[-1], lon_idxs[-1]

    # Colour by series number; cycle through the palette instead of running
    # off its end (which raised IndexError).
    series_no = (len(fig.data) - traces_to_keep) // TRACES_PER_SERIES
    color = qual_colors[series_no % len(qual_colors)]

    series = pred_data[:, lat_idx, lon_idx]
    spread = unc_data[:, lat_idx, lon_idx]

    fig.add_scatter(x=time,
                    y=series,
                    name='lat:{:.2f}, lon:{:.2f}'.format(lat[lat_idx], lon[lon_idx]),
                    mode='lines',
                    hoverinfo='skip',
                    line=dict(color=color))

    # Same marker on both globes, just above the unit-sphere surface
    xp, yp, zp = ll2cart(1.002, lat[lat_idx], lon[lon_idx])
    for col in (1, 2):
        fig.add_scatter3d(x=[xp],
                          y=[yp],
                          z=[zp],
                          showlegend=False,
                          mode='markers',
                          hoverinfo='skip',
                          marker=dict(color=color,
                                      size=12),
                          marker_symbol='cross', row=1, col=col)

    # Uncertainty band: an invisible lower bound, then an upper bound filled
    # 'tonexty' down to it. The two traces must stay adjacent in fig.data.
    fig.add_trace(go.Scatter(
        x=time,
        y=series - spread,
        mode='lines',
        line_width=0,
        hoverinfo='skip',
        showlegend=False
        ))
    fig.add_trace(go.Scatter(
        x=time,
        y=series + spread,
        fill='tonexty',
        hoverinfo='skip',
        showlegend=False,
        mode='lines',
        line_width=0,
        line=dict(color=color)))

add_plot_button.on_click(add_plot)

def clear_plot(b):
    """Button callback: remove every added time series and its globe markers.

    Drops all traces appended after the first ``traces_to_keep``. It also
    trims the selection history to its latest entry, so a following "Add time
    series" click still uses the current slider position.
    """
    fig.data = fig.data[:traces_to_keep]
    # Mutate the shared lists in place: plain assignment here would only bind
    # new local names and leave the module-level lists untouched.
    del lat_idxs[:-1]
    del lon_idxs[:-1]
clear_plot_button.on_click(clear_plot)

# Apply the shared starting viewpoint to both globe scenes, then show the figure
for globe_col in (1, 2):
    fig.update_scenes(dict(camera=camera), row=1, col=globe_col)

fig

Button(description='Add time series', style=ButtonStyle())

Output()

Button(description='Clear plot', style=ButtonStyle())

Output()

interactive(children=(SelectionSlider(description='time_slice', index=1, options=(datetime.datetime(1980, 1, 1…

FigureWidget({
    'data': [{'cmax': 0.9439714550971985,
              'cmin': -1.0151327848434448,
          …