<a href="https://colab.research.google.com/github/manmeet3591/python_class/blob/master/xarray_tutorial/three_d_plot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go


def get_arrow(axisname="x"):
    """Return the two plotly traces that draw one axis arrow: ``[shaft, tip]``.

    The shaft is a ``Scatter3d`` line and the tip is a single ``Cone``.
    Along *axisname*, the shaft spans ``data[axis]["min"]`` to
    ``data[axis]["max"]``, and the cone sits at the max point pointing in the
    positive direction. On the other two axes, both traces are pinned to
    ``data[axis]["mid"]`` with a zero direction component.

    Reads the module-level globals ``colorscale`` (a plotly colorscale; its
    first entry's colour is used for the shaft) and ``data`` (a mapping keyed
    by "x"/"y"/"z" with "min"/"max"/"mid" entries). Both must exist at call
    time.
    """
    shaft_color = colorscale[0][1]

    shaft_coords = {}
    tip_coords = {}
    for coord, component in zip(("x", "y", "z"), ("u", "v", "w")):
        bounds = data[coord]
        if coord != axisname:
            # Off-axis coordinates stay at the centre of the scene.
            shaft_coords[coord] = bounds["mid"], bounds["mid"]
            tip_coords[coord] = [bounds["mid"]]
            tip_coords[component] = [0]
            continue
        # Along the arrow's own axis: full-length shaft, cone at the far end.
        shaft_coords[coord] = bounds["min"], bounds["max"]
        tip_coords[coord] = [bounds["max"]]
        tip_coords[component] = [1]

    shaft = go.Scatter3d(
        marker=dict(size=1, color=shaft_color),
        line=dict(color=shaft_color, width=3),
        showlegend=False,  # no legend entry for decoration traces
        **shaft_coords,
    )
    tip = go.Cone(
        sizeref=0.1,
        autocolorscale=None,
        colorscale=colorscale,
        showscale=False,  # no colour bar for the arrowheads
        hovertext=axisname,
        **tip_coords,
    )
    return [shaft, tip]


def add_axis_arrows(fig):
    """Append the shaft and cone traces of the x, y and z arrows to *fig*.

    Traces are added in axis order (x, then y, then z). For each axis the
    order is the one returned by ``get_arrow``.
    """
    # Generator keeps this lazy: each axis's traces are added before the
    # next axis's arrow is built, exactly like a nested loop.
    arrow_traces = (
        trace for axis in ("x", "y", "z") for trace in get_arrow(axis)
    )
    for trace in arrow_traces:
        fig.add_trace(trace)


def get_annotation_for_ax(ax):
    """Build the scene annotation dict that labels axis *ax* with its name.

    The label sits 5% of the axis range short of ``data[ax]["max"]`` along
    its own axis, and at ``data[...]["mid"]`` on the other two. The x and y
    labels are additionally shifted 15 px to the right.

    Reads the module-level global ``data`` (keyed by "x"/"y"/"z", with
    "max"/"range"/"mid" entries), which must exist at call time.
    """
    label = dict(
        showarrow=False, text=ax, xanchor="left", font=dict(color="#1f1f1f")
    )
    for coord in ("x", "y", "z"):
        label[coord] = (
            data[ax]["max"] - data[ax]["range"] * 0.05
            if coord == ax
            else data[coord]["mid"]
        )

    if ax in {"x", "y"}:
        label["xshift"] = 15

    return label


def get_axis_names():
    """Return the x, y and z axis-label annotations, in that order."""
    return list(map(get_annotation_for_ax, ("x", "y", "z")))


def get_scene_axis(axisname="x"):
    """Return the plotly 3D scene-axis settings for *axisname*.

    The axis has no title, background or tick labels. Its only two ticks, at
    ``data[axisname]["min"]`` and ``["max"]``, draw grey grid lines that form
    a box around the plot. The range is clamped to the same two values.

    Reads the module-level global ``data``, which must exist at call time.
    """
    lower = data[axisname]["min"]
    upper = data[axisname]["max"]

    return dict(
        title="",  # drop the default axis label (x, y, z)
        showbackground=False,
        visible=True,
        showticklabels=False,  # hide numeric tick values
        showgrid=True,  # grid lines at the two ticks draw the box
        gridcolor="grey",  # box colour
        tickvals=[lower, upper],  # box limits
        range=[lower, upper],  # no extra grid lines outside the box
    )


# Sample curve: a unit circle traced at constant height z = 1.
# NOTE(review): the original cell first set z = t (a helix) and then
# overwrote it with ones; that overwrite is kept. Use ``z = t`` for the helix.
t = np.linspace(0, 10, 50)
x, y = np.cos(t), np.sin(t)
z = np.ones_like(t)


def _axis_bounds(values):
    """Summarise one coordinate array as the {"min", "max", "mid", "range"}
    dict that the arrow/annotation/scene-axis helpers read from ``data``."""
    lo = float(np.min(values))
    hi = float(np.max(values))
    if hi == lo:
        # A constant coordinate (z here) would collapse its axis to a point,
        # giving a zero-length arrow and an empty scene range; widen it.
        lo, hi = lo - 1.0, hi + 1.0
    return {"min": lo, "max": hi, "mid": (lo + hi) / 2, "range": hi - lo}


# The helper functions above read these two module-level names at call time.
# They were never defined in this notebook, so a fresh kernel raised NameError.
data = {"x": _axis_bounds(x), "y": _axis_bounds(y), "z": _axis_bounds(z)}
# Single-colour scale: get_arrow uses colorscale[0][1] for the arrow shafts and
# the whole scale for the cone tips; same colour as the axis labels.
colorscale = [[0, "#1f1f1f"], [1, "#1f1f1f"]]

fig = go.Figure(
    layout=dict(
        title="surface",
        autosize=True,
        width=700,
        height=500,
        margin=dict(l=20, r=20, b=25, t=25),
        scene=dict(
            xaxis=get_scene_axis("x"),
            yaxis=get_scene_axis("y"),
            zaxis=get_scene_axis("z"),
            annotations=get_axis_names(),
        ),
    ),
)

add_axis_arrows(fig)

fig.add_trace(go.Scatter3d(x=x, y=y, z=z))
fig.show()