In [1]:
from pathlib import Path

import numpy as np
import plotly.graph_objects as go

# Shape Generator

In [15]:
class ShapesGenerator:
    """Sample parametric surfaces on a regular grid, save them and plot them.

    Every ``generate_*`` method evaluates a surface on a
    ``number_of_points x number_of_points`` parameter grid and passes the
    coordinates to ``generate_shape``. That method normalises them, removes
    duplicated points, writes them to ``data/<name>.txt`` and shows a Plotly
    scatter plot.
    """

    def __init__(self, number_of_points=64):
        # Samples per parameter axis; a surface has number_of_points**2 grid
        # points before duplicate removal.
        self.number_of_points = number_of_points

    def generate_cylinder(self, r=1, h=1, deformation=None):
        """Cylinder of radius ``r`` spanning ``z`` in ``[-h, h]``.

        If ``deformation`` is truthy:
        - only half of the cylinder (theta in [0, pi]) is sampled;
        - ``z`` is stretched by ``deformation * sinh(z)``;
        - the points are saved as ``data/dcylinder.txt``.
        """
        theta = np.linspace(0, (1 if deformation else 2) * np.pi, self.number_of_points)
        z = np.linspace(-h, h, self.number_of_points)
        theta, z = np.meshgrid(theta, z)
        x = r * np.cos(theta)
        y = r * np.sin(theta)
        if deformation:
            z = z + deformation * np.sinh(z)
        return self.generate_shape([x, y, z], f'data/{"dcylinder" if deformation else "cylinder"}.txt')

    def generate_torus(self, R=3, r=1, deformation=None):
        """Torus with major radius ``R`` and tube radius ``r``.

        If ``deformation`` is truthy:
        - only half of the ring (phi in [0, pi]) is sampled;
        - the tube is bent vertically by ``deformation * sin(2 phi)``;
        - the points are saved as ``data/worm.txt``.
        """
        theta = np.linspace(0, 2 * np.pi, self.number_of_points)
        phi = np.linspace(0, (1.0 if deformation else 2) * np.pi, self.number_of_points)
        theta, phi = np.meshgrid(theta, phi)
        x = (R + r * np.cos(theta)) * np.cos(phi)
        y = (R + r * np.cos(theta)) * np.sin(phi)
        z = r * np.sin(theta)
        if deformation:
            # Add the wave instead of overwriting z. Overwriting would flatten
            # the tube's circular cross-section into a ribbon. Adding matches
            # how the other generators apply their deformation.
            z = z + deformation * np.sin(2.0 * phi)
        return self.generate_shape([x, y, z], f'data/{"worm" if deformation else "torus"}.txt')

    def generate_sphere(self, r=1, deformation=None):
        """Sphere of radius ``r``.

        If ``deformation`` is truthy, ``z`` is offset by
        ``deformation * sin(theta)`` and the points are saved as
        ``data/spheroid.txt``.
        """
        theta = np.linspace(0, 2 * np.pi, self.number_of_points)
        phi = np.linspace(0, np.pi, self.number_of_points)
        theta, phi = np.meshgrid(theta, phi)
        x = r * np.cos(theta) * np.sin(phi)
        y = r * np.sin(theta) * np.sin(phi)
        z = r * np.cos(phi)
        if deformation:
            z = z + deformation * np.sin(theta)
        return self.generate_shape([x, y, z], f'data/{"spheroid" if deformation else "sphere"}.txt')

    def generate_pseudo_sphere(self, deformation=None):
        """Pseudosphere sampled for ``u, v`` in ``[-pi, pi]``.

        If ``deformation`` is truthy, ``z`` is offset by
        ``deformation * sin(u)`` and the points are saved as
        ``data/dpsphere.txt``.
        """
        # Create a grid of u and v values
        u = np.linspace(-np.pi, np.pi, self.number_of_points)
        v = np.linspace(-np.pi, np.pi, self.number_of_points)
        u, v = np.meshgrid(u, v)
        # Compute x, y, and z
        x = 1. / np.cosh(u) * np.cos(v)
        y = 1. / np.cosh(u) * np.sin(v)
        z = u - np.tanh(u)
        if deformation:
            z = z + deformation * np.sin(u)
        return self.generate_shape([x, y, z], f'data/{"dpsphere" if deformation else "psphere"}.txt')

    def generate_beltrami(self):
        """Surface built from the (xi, theta) parametrisation below.

        Both parameters start at 1e-3 rather than 0. The points are saved
        as ``data/beltrami.txt``.
        """
        # Asymptotical coordinates
        xi = np.linspace(1e-3, np.pi, self.number_of_points)
        theta = np.linspace(1e-3, 2 * np.pi, self.number_of_points)
        # Meshgrid
        xi, theta = np.meshgrid(xi, theta)
        # Beltrami coordinates
        omega = theta - xi * np.sin(theta)
        sigma = xi * np.cos(theta)
        # Radiovector
        x = np.sin(theta) / np.cosh(omega) * np.cos(sigma)
        y = np.sin(theta) / np.cosh(omega) * np.sin(sigma)
        z = 0.5 * xi + np.cos(theta) + np.sin(theta) * np.tanh(omega)
        return self.generate_shape([x, y, z], 'data/beltrami.txt')

    def generate_shape(self, coords, file_path):
        """Normalise, de-duplicate, save and plot a sampled surface.

        Parameters
        ----------
        coords : sequence of three arrays
            The x, y and z coordinates, all with the same shape
            (an N x N grid for the built-in generators).
        file_path : str or path-like
            Destination text file. Missing parent directories are created.

        Returns
        -------
        numpy.ndarray
            Array of shape (M, 3) with the unique points. The rows are sorted
            lexicographically by ``np.unique``, and the points are scaled so
            that the largest absolute coordinate is 1.
        """
        # Stack (x, y, z) along a trailing axis and flatten to one row per
        # point. The row order matches the previous transpose(1, 2, 0) layout.
        points = np.stack([np.asarray(c, dtype=float) for c in coords], axis=-1).reshape(-1, 3)
        # Scale by the largest *absolute* coordinate so every point lies in
        # the unit cube. A plain np.max would ignore large negative
        # coordinates and could be 0.
        scale = np.max(np.abs(points)) if points.size else 0.0
        if scale > 0:
            points = points / scale
        # Snap near-zero coordinates to exactly 0. Seam points that differ
        # only by round-off (e.g. sin(2*pi) ~ -2.4e-16) then compare equal.
        points[np.abs(points) < 1e-3] = 0.0
        # Remove duplicated points, e.g. the seam where an angle sweeps a
        # full 2*pi. np.unique also sorts the rows.
        points = np.unique(points, axis=0)
        # Make sure the output directory (e.g. data/) exists before saving.
        Path(file_path).parent.mkdir(parents=True, exist_ok=True)
        np.savetxt(file_path, points)
        self.plot_points(points)
        return points

    @staticmethod
    def plot_points(points):
        """Show ``points`` (shape (M, 3)) as a Plotly 3D scatter coloured by z.

        Axes ticks, grid and labels are hidden. The view box is fixed to
        [-1.5, 1.5] x [-1.5, 1.5] x [-1.5, 1.6].
        """
        # Dictionaries of axis style options (RGB channels must be 0-255)
        axes_layout = dict(
                backgroundcolor="rgb(255, 255, 255)",
                gridcolor="white",
                showbackground=True,
                zerolinecolor="white",
            )
        axes_options = dict(showticklabels=False, showgrid=False, zeroline=False, showline=False)
        # Create figure
        fig = go.Figure(
            data=[
                go.Scatter3d(x=points[:, 0], y=points[:, 1], z=points[:, 2], mode='markers')
                ]
            )
        # Set figure size and camera position
        fig.update_layout(width=400, height=400, scene_camera_eye=dict(x=1.5, y=1, z=0.8))
        fig.update_traces(marker=dict(size=20, color=points[:, 2], colorscale='sunset', opacity=1))
        # Remove axes ticks
        fig.update_layout(scene=dict(xaxis=axes_options, yaxis=axes_options, zaxis=axes_options))
        # Remove axes labels
        fig.update_layout(scene=dict(xaxis_title='', yaxis_title='', zaxis_title=''))
        # Change box and background color
        fig.update_layout(scene=dict(xaxis=axes_layout, yaxis=axes_layout, zaxis=axes_layout))
        # Avoid the figure being cropped by removing margins and fixing ranges
        fig.update_layout(
            margin=dict(l=0, r=0, b=0, t=0),
            scene=dict(
                xaxis=dict(range=[-1.5, 1.5]),
                yaxis=dict(range=[-1.5, 1.5]),
                zaxis=dict(range=[-1.5, 1.6]),
        ))
        fig.show()

In [16]:
# Sample the Beltrami surface on a 64 x 64 parameter grid.
# This also writes data/beltrami.txt and displays the Plotly figure.
generator = ShapesGenerator(number_of_points=64)
X = generator.generate_beltrami()