# Camera pose 0, 0

In [1]:
import plotly.graph_objects as go
import numpy as np
import torch
import plotly.express as px

# Ray Casting

In [2]:
# 0.
def to_pytorch(tensor, return_type=False):
    ''' Converts input tensor to pytorch.

    NumPy arrays (including ndarray subclasses) are converted to float32
    torch tensors; torch tensors are passed through. In both cases a copy is
    returned (``clone``), so callers may modify the result in place without
    touching the input.

    Args:
        tensor (tensor): Numpy or Pytorch tensor
        return_type (bool): whether to return input type

    Returns:
        torch.Tensor, or ``(torch.Tensor, bool)`` when ``return_type`` is
        True, where the bool tells whether the input was a NumPy array.
    '''
    # isinstance (instead of an exact type comparison) also accepts ndarray
    # subclasses, which previously fell through to ``.clone()`` and crashed.
    is_numpy = isinstance(tensor, np.ndarray)
    if is_numpy:
        tensor = torch.from_numpy(tensor).float()
    # from_numpy shares memory (and .float() is a no-op for float32), so clone.
    tensor = tensor.clone()
    if return_type:
        return tensor, is_numpy
    return tensor

# 1. get camera intrinsic
def get_camera_mat(fov=49.13, invert=True):
    '''Return the 4x4 pinhole intrinsic matrix (batched as 1 x 4 x 4).

    Pixel coordinates are taken to span [-1, 1], i.e. a sensor of width 2.
    From fov = 2 * arctan(sensor / (2 * focal)) this gives
    focal = 1 / tan(fov / 2).

    Args:
        fov (float): field of view in degrees.
        invert (bool): return the inverse matrix instead (image -> camera).

    Returns:
        torch.Tensor of shape (1, 4, 4).
    '''
    focal = np.float32(1. / np.tan(0.5 * fov * np.pi / 180.))
    # diag(focal, focal, 1, 1); entries kept with the same Python/NumPy types
    # as a hand-written 4x4 literal so the inferred dtype is unchanged.
    diagonal = [focal, focal, 1, 1.]
    intrinsic = torch.tensor([
        [diagonal[row] if row == col else 0. for col in range(4)]
        for row in range(4)
    ]).reshape(1, 4, 4)
    return torch.inverse(intrinsic) if invert else intrinsic

# 2. get camera position with camera pose (theta & phi)
def to_sphere(u, v):
    '''Map (u, v) in [0, 1] x [0, 1] to a point on the unit sphere.

    u sets the azimuth theta = 2*pi*u; v sets the polar angle
    phi = arccos(1 - 2v), so cos(phi) is linear in v (v = 0 is the +z pole,
    v = 1 the -z pole).

    Returns:
        np.ndarray whose last axis (size 3) holds (x, y, z).
    '''
    azimuth = 2 * np.pi * u
    polar = np.arccos(1 - 2 * v)
    sin_polar = np.sin(polar)
    xyz = (sin_polar * np.cos(azimuth),
           sin_polar * np.sin(azimuth),
           np.cos(polar))
    return np.stack(xyz, axis=-1)

# 3. get camera coordinate system assuming it points to the center of the sphere
def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5,
            to_pytorch=True):
    '''Build camera-to-world rotation matrices for cameras looking at ``at``.

    The camera z-axis points from ``at`` towards ``eye``, x = up x z and
    y = z x x. Each axis is normalised per camera.

    Args:
        eye (array-like or CPU torch.Tensor): camera positions, reshaped to
            (N, 3).
        at (array-like): target point, 3 values.
        up (array-like): world up vector, 3 values.
        eps (float): lower bound on each norm, avoiding division by zero
            for degenerate (e.g. up-parallel) configurations.
        to_pytorch (bool): return a float32 torch tensor instead of a NumPy
            array.

    Returns:
        (N, 3, 3) matrices whose columns are the camera x, y and z axes
        expressed in world coordinates.
    '''
    at = np.asarray(at, dtype=float).reshape(1, 3)
    up = np.asarray(up, dtype=float).reshape(1, 3)
    # asarray also accepts (CPU) torch tensors such as the ``loc`` passed in
    # by giraffe, so all arithmetic below is plain NumPy float64.
    eye = np.asarray(eye, dtype=float).reshape(-1, 3)
    up = up.repeat(eye.shape[0] // up.shape[0], axis=0)

    def _normalize(vec):
        # Per-row normalisation. The previous np.max over the stacked array
        # produced one scalar for the whole batch, so with N > 1 cameras
        # every row was divided by the largest norm.
        norm = np.linalg.norm(vec, axis=1, keepdims=True)
        return vec / np.maximum(norm, eps)

    z_axis = _normalize(eye - at)
    x_axis = _normalize(np.cross(up, z_axis))
    y_axis = _normalize(np.cross(z_axis, x_axis))

    # Stack the axes as columns: r_mat[n, :, k] is axis k of camera n.
    r_mat = np.stack((x_axis, y_axis, z_axis), axis=2)

    if to_pytorch:
        r_mat = torch.tensor(r_mat).float()

    return r_mat

# 5. arrange a 2D grid of pixel coordinates (a depth of 1 is assigned later, in step 6)
def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.),
                  subsample_to=None, invert_y_axis=False):
    ''' Arranges pixels for given resolution in range image_range.

    The function returns the unscaled pixel locations as integers and the
    scaled float values. Points are ordered with the x (width) index varying
    slowest, i.e. flat index = x * h + y.

    Args:
        resolution (tuple): image resolution (h, w)
        batch_size (int): batch size
        image_range (tuple): range of output points (default [-1, 1]); the
            first pixel maps to image_range[0], the last to image_range[1]
        subsample_to (int): if integer and > 0, the points are randomly
            subsampled to this value
        invert_y_axis (bool): mirror y (both the scaled value, about the
            centre of image_range, and the integer location)

    Returns:
        pixel_locations (LongTensor): batch_size x N x 2 integer (x, y)
        pixel_scaled (FloatTensor): batch_size x N x 2 scaled (x, y)
    '''
    h, w = resolution
    n_points = h * w

    # Integer (x, y) grid, in the same order as torch.meshgrid(arange(w),
    # arange(h)) with 'ij' indexing, but without meshgrid's indexing warning.
    xs = torch.arange(0, w).repeat_interleave(h)
    ys = torch.arange(0, h).repeat(w)
    pixel_locations = torch.stack([xs, ys], dim=-1).long().view(
        1, -1, 2).repeat(batch_size, 1, 1)
    pixel_scaled = pixel_locations.clone().float()

    # Map [0, w-1] / [0, h-1] linearly onto [lo, hi]. The old code subtracted
    # scale/2 instead of adding lo, which is only right for symmetric ranges.
    # max(..., 1) avoids a 0/0 (NaN) for single-pixel dimensions.
    lo, hi = image_range
    scale = hi - lo
    pixel_scaled[:, :, 0] = lo + scale * pixel_scaled[:, :, 0] / max(w - 1, 1)
    pixel_scaled[:, :, 1] = lo + scale * pixel_scaled[:, :, 1] / max(h - 1, 1)

    # Subsample points if subsample_to is not None and > 0
    if (subsample_to is not None and subsample_to > 0 and
            subsample_to < n_points):
        idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,),
                               replace=False)
        pixel_scaled = pixel_scaled[:, idx]
        pixel_locations = pixel_locations[:, idx]

    if invert_y_axis:
        # Reflect about the centre of the range; for (-1, 1) this is the
        # plain sign flip the function has always applied.
        pixel_scaled[..., -1] = (lo + hi) - pixel_scaled[..., -1]
        pixel_locations[..., -1] = (h - 1) - pixel_locations[..., -1]

    return pixel_locations, pixel_scaled

# 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) 
def image_points_to_world(image_points, camera_mat, world_mat, scale_mat=None,
                          invert=False, negative_depth=True):
    ''' Transforms points on image plane to world coordinates.

    In contrast to transform_to_world, no depth value is needed as points on
    the image plane have a fixed depth of 1 (-1 when negative_depth is set).

    Args:
        image_points (tensor): image points tensor of size B x N x 2
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        invert (bool): whether to invert matrices (default: False)
        negative_depth (bool): place the points at depth -1 instead of +1

    Returns:
        World-space points of size B x N x 3, as returned by
        transform_to_world.
    '''
    batch_size, n_pts, dim = image_points.shape
    assert(dim == 2)
    # Create the constant depth on the same device as the points, so the
    # element-wise multiply in transform_to_world does not mix CPU and GPU.
    device = image_points.device if torch.is_tensor(image_points) else None
    depth_value = -1. if negative_depth else 1.
    d_image = torch.full((batch_size, n_pts, 1), depth_value, device=device)
    return transform_to_world(image_points, d_image, camera_mat, world_mat,
                              scale_mat, invert=invert)

def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat=None,
                       invert=True, use_absolute_depth=True):
    ''' Transforms pixel positions p with given depth value d to world coordinates.

    Each pixel (x, y) with depth d becomes the homogeneous vector
    (x*|d|, y*|d|, d, 1), or (x*d, y*d, d, 1) when use_absolute_depth is
    False. It is then mapped with scale_mat @ world_mat @ camera_mat.

    Args:
        pixels (tensor): pixel tensor of size B x N x 2
        depth (tensor): depth tensor of size B x N x 1
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix; a batch of identities when None
        invert (bool): whether to invert matrices (default: true)
        use_absolute_depth (bool): scale x and y by |d| rather than d

    Returns:
        World points of size B x N x 3; a NumPy array if ``pixels`` was one.
    '''
    assert(pixels.shape[-1] == 2)

    # Convert to pytorch
    pixels, is_numpy = to_pytorch(pixels, True)
    depth = to_pytorch(depth)
    camera_mat = to_pytorch(camera_mat)
    world_mat = to_pytorch(world_mat)
    if scale_mat is None:
        # Identity per batch element, created on camera_mat's device and
        # dtype. A plain torch.eye(4) is always a CPU float32 tensor and
        # breaks the matrix product for GPU or float64 camera matrices.
        scale_mat = torch.eye(
            4, dtype=camera_mat.dtype, device=camera_mat.device
        ).unsqueeze(0).repeat(camera_mat.shape[0], 1, 1)
    else:
        scale_mat = to_pytorch(scale_mat)

    # Invert camera matrices
    if invert:
        camera_mat = torch.inverse(camera_mat)
        world_mat = torch.inverse(world_mat)
        scale_mat = torch.inverse(scale_mat)

    # B x N x 2 -> B x 2 x N, then append two rows of ones: rows are
    # (x, y, 1, 1) -- row 2 becomes the depth below, row 3 stays w = 1.
    pixels = pixels.permute(0, 2, 1)
    pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)

    # Project pixels into camera space
    depth_t = depth.permute(0, 2, 1)
    if use_absolute_depth:
        # x/y use |d| so that a negative depth only flips z, not x and y.
        pixels[:, :2] = pixels[:, :2] * depth_t.abs()
        pixels[:, 2:3] = pixels[:, 2:3] * depth_t
    else:
        pixels[:, :3] = pixels[:, :3] * depth_t

    # Transform pixels to world space
    p_world = scale_mat @ world_mat @ camera_mat @ pixels

    # Drop the homogeneous row and go back to B x N x 3
    p_world = p_world[:, :3].permute(0, 2, 1)

    if is_numpy:
        p_world = p_world.numpy()
    return p_world


# 7. multiply the homogeneous origin (0, 0, 0, 1) by the intrinsic & extrinsic matrices to get the camera position (which we already obtained as loc)
def origin_to_world(n_points, camera_mat, world_mat, scale_mat=None,
                    invert=False):
    ''' Transforms origin (camera location) to world coordinates.

    Args:
        n_points (int): how often the transformed origin is repeated in the
            form (batch_size, n_points, 3)
        camera_mat (tensor): camera matrix (NumPy or torch)
        world_mat (tensor): world matrix (NumPy or torch)
        scale_mat (tensor): scale matrix; a batch of identities when None
        invert (bool): whether to invert the matrices (default: false)

    Returns:
        torch.Tensor of size (batch_size, n_points, 3), where batch_size is
        camera_mat.shape[0].
    '''
    # Convert first: the old code inverted before converting (so NumPy
    # matrices crashed with invert=True) and read camera_mat.device on a
    # possibly-NumPy input. Everything is placed on camera_mat's device.
    camera_mat = to_pytorch(camera_mat)
    batch_size = camera_mat.shape[0]
    device = camera_mat.device
    world_mat = to_pytorch(world_mat).to(device)
    if scale_mat is None:
        scale_mat = torch.eye(4, device=device).unsqueeze(0).repeat(
            batch_size, 1, 1)
    else:
        scale_mat = to_pytorch(scale_mat).to(device)

    # Origin (0, 0, 0, 1) in homogeneous coordinates, one column per point
    p = torch.zeros(batch_size, 4, n_points, device=device)
    p[:, -1] = 1.

    # Invert matrices
    if invert:
        camera_mat = torch.inverse(camera_mat)
        world_mat = torch.inverse(world_mat)
        scale_mat = torch.inverse(scale_mat)

    # Apply transformation
    p_world = scale_mat @ world_mat @ camera_mat @ p

    # Transform points back to 3D coordinates
    p_world = p_world[:, :3].permute(0, 2, 1)
    return p_world

In [3]:
def giraffe(u=1,
            v=0.5,
            r=2.713,
            depth_range=(0.5, 6.),
            n_ray_samples=16,
            resolution_vol=4,
            batch_size=1
            ):
    '''Cast a res x res grid of rays from a camera placed on a sphere.

    Args:
        u, v (float): camera position on the sphere (see to_sphere).
        r (float): sphere radius, i.e. camera distance from the origin.
        depth_range (sequence of 2 floats): first and last sample depth along
            each ray, in multiples of the (unnormalised) camera->pixel vector.
        n_ray_samples (int): number of samples per ray.
        resolution_vol (int): rays per image side (res * res rays in total).
        batch_size (int): number of cameras; all share the same pose.

    Returns:
        pixels_world (Tensor): batch_size x res*res x 3 pixel positions.
        camera_world (Tensor): batch_size x res*res x 3 camera origin,
            repeated per ray.
        world_mat (np.ndarray): batch_size x 4 x 4 extrinsic matrices.
        p_i (Tensor): batch_size x res*res x n_ray_samples x 3 samples.
    '''
    range_radius = [r, r]

    res = resolution_vol
    n_points = res * res

    # 1. camera intrinsics, one copy per batch element
    camera_mat = get_camera_mat().repeat(batch_size, 1, 1)

    # 2. camera position from the camera pose (u -> theta, v -> phi)
    loc = to_sphere(u, v)
    loc = torch.tensor(loc).float()
    # range_radius has zero width, so this always yields r (the random draw
    # is kept so the RNG stream is consumed exactly as before).
    radius = range_radius[0] + \
        torch.rand(batch_size) * (range_radius[1] - range_radius[0])
    loc = loc * radius.unsqueeze(-1)

    # 3. camera axes, assuming the camera looks at the centre of the sphere
    R = look_at(loc)

    # 4. The camera axes form the rotation; together with the camera
    #    location they give the camera extrinsic (one matrix per batch item;
    #    the old fixed (1, 4, 4) array broke batch_size > 1).
    RT = np.tile(np.eye(4), (batch_size, 1, 1))
    RT[:, :3, :3] = R
    RT[:, :3, -1] = loc
    world_mat = RT

    # 5. pixel grid with coordinates in [-1, 1] (depth is added in step 6)
    pixels = arange_pixels((res, res), batch_size, invert_y_axis=False)[1]
    # Flip y: arange_pixels maps row 0 to -1; after the flip row 0 maps to
    # +1. NOTE(review): presumably so the first image row is the top
    # (+y up in camera space) -- confirm against the renderer's convention.
    pixels[..., -1] *= -1.

    # 6. inverse intrinsic then extrinsic: pixels in world space (the image
    #    plane sits at depth -1 in camera space)
    pixels_world = image_points_to_world(pixels, camera_mat, world_mat)

    # 7. camera origin in world space (equal to loc), repeated for each ray
    camera_world = origin_to_world(n_points, camera_mat, world_mat)

    # 8. ray = pixel - camera origin (world space, not normalised)
    ray_vector = pixels_world - camera_world

    # 9. depths from closest to furthest, evenly spaced over depth_range
    di = depth_range[0] + \
        torch.linspace(0., 1., steps=n_ray_samples).reshape(1, 1, -1) * (
            depth_range[1] - depth_range[0])
    di = di.repeat(batch_size, n_points, 1)

    # 10. sample points along every ray: p_i = origin + d_i * ray
    p_i = camera_world.unsqueeze(-2).contiguous() + \
        di.unsqueeze(-1).contiguous() * ray_vector.unsqueeze(-2).contiguous()

    return pixels_world, camera_world, world_mat, p_i

# Visualization

In [4]:
# draw sphere with radius r 
# also draw contours and vertical lines
def draw_sphere(r, sphere_colorscale, sphere_opacity):
    '''Create a new figure holding a sphere of radius r centred at the origin.

    The surface shows white horizontal contour lines (constant z, spacing
    r/10). It is overlaid with 20 white meridian line traces; the first and
    last coincide because the azimuth grid includes both 0 and 2*pi.
    '''
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)

    surface = go.Surface(
        x=r * np.outer(np.cos(azimuth), np.sin(polar)),
        y=r * np.outer(np.sin(azimuth), np.sin(polar)),
        z=r * np.outer(np.ones(np.size(azimuth)), np.cos(polar)),
        colorscale=sphere_colorscale,
        opacity=sphere_opacity,
        contours={
            'z': {'show': True, 'start': -r,
                  'end': r, 'size': r/10,
                  'color': 'white',
                  'width': 1}
        },
        showscale=False,
    )
    fig = go.Figure(data=[surface])

    # Meridians: one polyline per azimuth, running pole to pole.
    meridian_azimuth = np.linspace(0, 2 * np.pi, 20)
    mx = r * np.outer(np.cos(meridian_azimuth), np.sin(polar))
    my = r * np.outer(np.sin(meridian_azimuth), np.sin(polar))
    mz = r * np.outer(np.ones(np.size(meridian_azimuth)), np.cos(polar))
    for xs, ys, zs in zip(mx, my, mz):
        fig.add_scatter3d(x=xs, y=ys, z=zs,
                          line=dict(color='white', width=1),
                          mode='lines',
                          showlegend=False)

    return fig

# draw xyplane
def draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity, x_range = [-2, 2], y_range = [-2, 2]):
    '''Add the z = 0 plane over x_range x y_range to fig as a 100x100 surface.

    Returns the same figure, so calls can be chained.
    '''
    grid_size = 100
    xs = np.linspace(x_range[0], x_range[1], grid_size)
    ys = np.linspace(y_range[0], y_range[1], grid_size)
    flat = np.zeros(shape=(grid_size, grid_size))

    fig.add_surface(x=xs, y=ys, z=flat,
                    colorscale=xy_plane_colorscale,
                    opacity=xy_plane_opacity,
                    showscale=False)
    return fig
    

def draw_XYZworld(fig, world_axis_size):
    '''Draw the positive world X (red), Y (green) and Z (blue) axes.

    Each axis is a line from the origin to world_axis_size along that axis,
    labelled with its letter at the far end. Returns fig.
    '''
    for axis_index, (label, color) in enumerate(
            (("X", "red"), ("Y", "green"), ("Z", "blue"))):
        # Coordinates stay 0 on the other two axes.
        coords = [[0, 0], [0, 0], [0, 0]]
        coords[axis_index] = [0, world_axis_size]
        fig.add_scatter3d(x=coords[0], y=coords[1], z=coords[2],
                          line=dict(color=color, width=10),
                          mode='lines+text',
                          text=[None, label],
                          textposition='top center',
                          textfont=dict(color=color, size=18),
                          showlegend=False)
    return fig

# draw cam and cam coordinate system
def draw_cam_init(fig, world_mat, camera_axis_size, camera_color):
    '''Draw the camera position and its local x/y/z axes.

    Only world_mat[0] is used. Column 3 is drawn as the camera position, and
    columns 0, 1 and 2 as the camera Xc (red), Yc (green) and Zc (blue) axes.
    Each axis starts at the camera and has length camera_axis_size times the
    column's norm. Returns fig.
    '''
    cam = world_mat[0]
    origin = [cam[0, 3], cam[1, 3], cam[2, 3]]

    # camera centre
    fig.add_scatter3d(x=[origin[0]], y=[origin[1]], z=[origin[2]],
                      mode='markers',
                      marker=dict(color=camera_color, size=4,
                                  sizemode='diameter'),
                      showlegend=False)

    # camera axes, one line per matrix column
    for col, (label, color) in enumerate(
            (("Xc", "red"), ("Yc", "green"), ("Zc", "blue"))):
        tip = [cam[i, col] * camera_axis_size + origin[i] for i in range(3)]
        fig.add_scatter3d(x=[origin[0], tip[0]],
                          y=[origin[1], tip[1]],
                          z=[origin[2], tip[2]],
                          line=dict(color=color, width=10),
                          mode='lines+text',
                          text=[None, label],
                          textposition='top center',
                          textfont=dict(color=color, size=18),
                          showlegend=False)

    return fig

# draw all rays
def draw_all_rays(fig, p_i, ray_color):
    '''Draw every ray of the first batch element as a 3D polyline.

    p_i[0] is iterated ray by ray. Each ray's sample points, whose last axis
    holds (x, y, z), are joined into one line trace. Returns fig.
    '''
    for ray_samples in p_i[0]:
        fig.add_scatter3d(x=ray_samples[:, 0],
                          y=ray_samples[:, 1],
                          z=ray_samples[:, 2],
                          line=dict(color=ray_color, width=5),
                          mode='lines',
                          showlegend=False)
    return fig

# draw frustum cross-sections (outline through the four corner rays) at selected sample depths
def draw_ray_frus(fig, p_i, frustrum_color, frustrum_opacity, at=(0, -1)):
    '''Draw filled frustum cross-sections at selected sample depths.

    For every sample index in ``at``, the points of the four corner rays of
    the square ray grid are joined into a closed, filled outline.

    Args:
        fig: plotly figure to draw into.
        p_i: ray samples indexed [batch, ray, sample, xyz]; only batch 0 is
            drawn. The rays must form a square res x res grid. This assumes
            arange_pixels ordering (flat index = x * res + y).
        frustrum_color (str): line and fill colour.
        frustrum_opacity (float): fill opacity.
        at (sequence of int): sample indices to draw (negative allowed).

    Returns:
        fig.
    '''
    n_rays = p_i.shape[1]
    res = int(round(n_rays ** 0.5))
    if res * res != n_rays:
        raise ValueError(f"expected a square grid of rays, got {n_rays}")
    # Corner rays walked around the square and back to the start to close
    # the outline. The old hard-coded [0, 3, -1, -4, 0] only fit res == 4.
    corners = [0, res - 1, -1, -res, 0]

    for i in at:
        Xfrus = p_i[0, :, i, 0][corners]
        Yfrus = p_i[0, :, i, 1][corners]
        Zfrus = p_i[0, :, i, 2][corners]

        fig.add_scatter3d(x=Xfrus, y=Yfrus, z=Zfrus,
                          line=dict(color=frustrum_color, width=5),
                          mode='lines',
                          surfaceaxis=0,
                          surfacecolor=frustrum_color,
                          opacity=frustrum_opacity,
                          showlegend=False)

    return fig

In [5]:
from jupyter_dash import JupyterDash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output

# for colors
import matplotlib.colors as mcolors

The dash_core_components package is deprecated. Please replace
`import dash_core_components as dcc` with `from dash import dcc`
  import dash_core_components as dcc
The dash_html_components package is deprecated. Please replace
`import dash_html_components as html` with `from dash import html`
  import dash_html_components as html


In [6]:
app = JupyterDash(__name__)

# Dropdown options, built once and shared by the pickers. Iterating the
# CSS4_COLORS dict yields the colour names directly (the hex values were
# never used).
CSS4_COLOR_OPTIONS = [{'label': c, 'value': c} for c in mcolors.CSS4_COLORS]
COLORSCALE_OPTIONS = [{'label': c, 'value': c}
                      for c in px.colors.named_colorscales()]

app.layout = html.Div([
    html.H1("Ray casting visualization"),
    dcc.Graph(id='graph'),

    html.Div([
        html.Div([

            # changes to setting and ray casted
            html.Label(["u (theta)",
                dcc.Slider(
                    id='u-slider',
                    min=0, max=1,
                    value=0.00,
                    marks={str(val): str(val) for val in [0.00, 0.25, 0.50, 0.75]},
                    step=0.01, tooltip={'always_visible': True}
                ), ]),
            html.Label(["v (phi)",
                dcc.Slider(
                    id='v-slider',
                    min=0, max=1,
                    value=0.25,
                    marks={str(val): str(val) for val in [0.00, 0.25, 0.50, 0.75]},
                    step=0.01, tooltip={'always_visible': True}
                ), ]),
            html.Label(["r (sphere radius)",
                dcc.Slider(
                    id='r-slider',
                    min=0, max=5,
                    value=2.713,
                    marks={str(val): str(val) for val in [0.000, 1.000, 2.713, 5.000]},
                    step=0.001, tooltip={'always_visible': True}
                ), ]),
            html.Label(["depth range",
                dcc.RangeSlider(
                    id='depth-range-slider',
                    min=0, max=10,
                    value=[0.5, 6],
                    marks={str(val): str(val) for val in [0.5 * i for i in range(21)]},
                    step=0.1, tooltip={'always_visible': True}
                ), ])
        ], style={'width': '48%', 'display': 'inline-block'}),

        html.Div([
            # changes to visual appearance

            # axis scale
            html.Div([
                html.Label(["world axis size",
                    html.Div([
                        dcc.Input(id='world-axis-size-input',
                                  value=1.5,
                                  type='number')
                    ]),
                ], style={'width': '48%', 'display': 'inline-block'}),
                html.Label(["camera axis size",
                    html.Div([
                        dcc.Input(id='camera-axis-size-input',
                                  value=0.3,
                                  type='number')
                    ]),
                ], style={'width': '48%', 'float': 'right', 'display': 'inline-block'}),
            ]),

            # color
            html.Div([
                html.Label(["camera color",
                    html.Div([
                        dcc.Dropdown(id='camera-color-input',
                                     clearable=False,
                                     value='yellow',
                                     options=CSS4_COLOR_OPTIONS)
                    ])
                ], style={'width': '32%', 'float': 'left', 'display': 'inline-block'}),

                html.Label(["ray color",
                    html.Div([
                        dcc.Dropdown(id='ray-color-input',
                                     clearable=False,
                                     value='yellow',
                                     options=CSS4_COLOR_OPTIONS)
                    ])
                ], style={'width': '34%', 'float': 'left', 'display': 'inline-block'}),

                html.Label(["frustrum color",
                    html.Div([
                        dcc.Dropdown(id='frustrum-color-input',
                                     clearable=False,
                                     value='orange',
                                     options=CSS4_COLOR_OPTIONS)
                    ])
                ], style={'width': '32%', 'float': 'left', 'display': 'inline-block'}),
            ]),

            # colorscale
            html.Div([
                html.Label(["sphere colorscale",
                    html.Div([
                        dcc.Dropdown(id='sphere-colorscale-input',
                                     clearable=False,
                                     value='greys',
                                     options=COLORSCALE_OPTIONS)
                    ])
                ], style={'width': '48%', 'display': 'inline-block'}),

                html.Label(["xy-plane colorscale",
                    html.Div([
                        dcc.Dropdown(id='xy-plane-colorscale-input',
                                     clearable=False,
                                     value='greys',
                                     options=COLORSCALE_OPTIONS)
                    ])
                ], style={'width': '48%', 'float': 'right', 'display': 'inline-block'}),
            ]),

            # opacity
            html.Div([
                html.Label(["sphere opacity",
                    html.Div([
                        dcc.Input(id='sphere-opacity-input',
                                  value=0.2,
                                  type='number')
                    ])
                ], style={'width': '32%', 'float': 'left', 'display': 'inline-block'}),

                html.Label(["xy-plane opacity",
                    html.Div([
                        dcc.Input(id='xy-plane-opacity-input',
                                  value=0.8,
                                  type='number')
                    ])
                ], style={'width': '34%', 'float': 'left', 'display': 'inline-block'}),

                html.Label(["frustrum opacity",
                    html.Div([
                        dcc.Input(id='frustrum-opacity-input',
                                  value=0.3,
                                  type='number')
                    ])
                ], style={'width': '32%', 'float': 'left', 'display': 'inline-block'}),
            ]),

        ], style={'width': '48%', 'float': 'right', 'display': 'inline-block'}),

    ]),

])

@app.callback(
    Output('graph', 'figure'),
    Input("u-slider", "value"),
    Input("v-slider", "value"),

    Input("r-slider", "value"),
    Input("depth-range-slider", "value"),

    Input("world-axis-size-input", "value"),
    Input("camera-axis-size-input", "value"),

    Input("camera-color-input", "value"),
    Input("ray-color-input", "value"),
    Input("frustrum-color-input", "value"),

    Input('sphere-colorscale-input', "value"),
    Input('xy-plane-colorscale-input', "value"),

    Input("sphere-opacity-input", "value"),
    Input("xy-plane-opacity-input", "value"),
    Input("frustrum-opacity-input", "value"),
)
def update_figure(u, v,
                  r, depth_range,
                  world_axis_size, camera_axis_size,
                  camera_color, ray_color, frustrum_color,
                  sphere_colorscale, xy_plane_colorscale,
                  sphere_opacity, xy_plane_opacity, frustrum_opacity
                  ):
    '''Rebuild the whole figure whenever any control changes.

    Draws, in order: the sphere, the xy-plane (spanning +/- the far depth),
    the world axes, the camera with its axes, every ray, and the frustum
    outlines at the first, middle and last ray sample.
    '''
    # dcc.Input(type='number') reports None while a field is empty or holds
    # an invalid number; keep the current figure instead of crashing (e.g.
    # None * float in the camera-axis drawing).
    from dash.exceptions import PreventUpdate
    numeric_inputs = (world_axis_size, camera_axis_size,
                      sphere_opacity, xy_plane_opacity, frustrum_opacity)
    if any(value is None for value in numeric_inputs):
        raise PreventUpdate

    # sphere
    fig = draw_sphere(r=r, sphere_colorscale=sphere_colorscale, sphere_opacity=sphere_opacity)

    # draw axes in proportion to the proportion of their ranges
    fig.update_layout(scene_aspectmode='data')

    # xy plane spanning +/- the far end of the depth range
    far = depth_range[1]
    fig = draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity,
                       x_range=[-far, far], y_range=[-far, far])

    # show world coordinate system (X, Y, Z positive direction)
    fig = draw_XYZworld(fig, world_axis_size=world_axis_size)

    pixels_world, camera_world, world_mat, p_i = giraffe(u=u, v=v, r=r, depth_range=depth_range)

    # draw camera at init (with its coordinate system)
    fig = draw_cam_init(fig, world_mat,
                        camera_axis_size=camera_axis_size, camera_color=camera_color)

    # draw all rays
    fig = draw_all_rays(fig, p_i, ray_color=ray_color)

    # frustum outlines at the first, middle and last sample along the rays;
    # the middle index is derived from the sample count (8 for 16 samples)
    n_samples = p_i.shape[2]
    fig = draw_ray_frus(fig, p_i, frustrum_color=frustrum_color, frustrum_opacity=frustrum_opacity,
                        at=[0, n_samples // 2, -1])

    return fig

# Start the Dash server; mode='inline' embeds the app in this notebook's output cell.
app.run_server(mode='inline')