In [23]:
from typing import Union, Tuple, Any, List
from abc import ABC, abstractmethod

import numpy as np
import matplotlib.pylab as plt

from dataclasses import dataclass
from shapely.geometry import Point, LineString, Polygon, box, asMultiPoint
from descartes import PolygonPatch 


In [62]:

@dataclass
class Geometry(ABC):
        
    center: Tuple[float, float, float]
    
    def __post_init__(self):
        self.contains_vectorized = np.vectorize(self.contains_pointwise, signature="(n),(n)->()")
        self.geometry_vectorized = np.vectorize(self.geometry_z, signature="()->()")

    @abstractmethod
    def geometry(self, x=None, y=None, z=None) -> list:
        """ returns shapely geoemtry at plane specified by one non None value of x,y,z """

    def inside(self, x, y, z) -> bool:
        """ is point (x,y,z) inside volume of Geometry? (note: here is slow (point by point) way, implement vectorized on subclasses. """
        insides = []
        for (_x, _y, _z) in zip(x, y, z):
            geo = self.geometry_z(_z)
            loc = Point(_x, _y)
            insides.append(geo.contains(loc))
        return np.array(insides)

    def contains_pointwise(self, xy_point, z_point):
        point = Point(xy_point)
        geo = self.geometry_vectorized(z_point)
        return geo.contains(point)

    def inside_vectorized(self, x, y, z) -> bool:
        xy_points = np.stack((x, y), axis=0)
        z_points = z[None, :]
        return self.contains_vectorized(xy_points, z_points)
        
    @staticmethod
    def pop_axis(coord: Tuple[Any, Any, Any], axis: int) -> Tuple[Any, Tuple[Any, Any]]:
        """separate coordinate at `axis` index from coordinates on the plane tangent to `axis`"""
        plane_vals = list(coord)
        axis_val = plane_vals.pop(axis)
        return axis_val, plane_vals

    def plot(self, x=None, y=None, z=None, ax=None, **patch_kwargs):
        """ plot structure """
        if ax is None:
            _, ax = plt.subplots()
        geo = self.geometry(x=x, y=y, z=z)
        for shape in geo:
            patch = PolygonPatch(shape, **patch_kwargs)
            ax.add_artist(patch)
        return ax

    @staticmethod
    def _parse_xyz_kwargs(**xyz) -> Tuple[int, float]:
        """ turns kwargs for plane specification like {x=None, y=None, z=None} into axis (0,1,2) and position along axis """
        xyz_filtered = {k : v for k, v in xyz.items() if v is not None}
        assert len(xyz_filtered) == 1, f"exatly one kwarg in [x,y,z] must be specified, given {xyz_filtered}."
        axis_label, position = list(xyz_filtered.items())[0]
        axis = 'xyz'.index(axis_label)
        return axis, position
    
@dataclass
class Sphere(Geometry):
    
    radius: float
    
    def _intersect_dist(self, position, z0):
        """ distance between points on circle at z=position where center of circle at z=z0 """
        dz = np.abs(z0 - position)
        return (dz <= self.radius) * 2 * np.sqrt(self.radius**2 - dz**2)

    def geometry(self, x=None, y=None, z=None):
        """ returns shapely geoemtry at plane specified by one non None value of x,y,z """
        axis, position = self._parse_xyz_kwargs(x=x, y=y, z=z)
        z0, (x0, y0) = self.pop_axis(self.center, axis=axis)
        intersect_dist = self._intersect_dist(position, z0)
        center = Point(x0, y0)
        return center.buffer(0.5 * intersect_dist)

    def geometry_z(self, zs):
        """ returns shapely geoemtry at plane specified by one non None value of x,y,z """
        (x0, y0, z0) = self.center
        intersect_dist = self._intersect_dist(zs, z0)
        center = Point(x0, y0)
        return center.buffer(0.5 * intersect_dist)

    def inside_numpy(self, x, y, z) -> bool:
        """returns True if (x,y,z) is inside of geometry"""
        x0, y0, z0 = self.center
        dist_x = np.abs(x - x0)
        dist_y = np.abs(y - y0)
        dist_z = np.abs(z - z0)
        return (dist_x ** 2 + dist_y ** 2 + dist_z ** 2) <= (self.radius ** 2)


In [63]:
# sphere
N = 100
xs = np.linspace(-1, 1, N)
ys = np.linspace(-1, 1, N)
zs = np.linspace(-1, 1, 1)
xx, yy, zz = np.meshgrid(xs, ys, zs, indexing='ij')
x = xx.flatten()
y = yy.flatten()
z = zz.flatten()

s = Sphere(center=(0,0,0), radius=1)

In [64]:
%%time
s1 = s.inside(x=x, y=y, z=z)

CPU times: user 475 ms, sys: 6.24 ms, total: 481 ms
Wall time: 484 ms


In [65]:
%%time
s1 = s.inside_vectorized(x=x, y=y, z=z)

AttributeError: 'numpy.ndarray' object has no attribute 'contains'

In [66]:
%%time
s2 = s.inside_numpy(x=x, y=y, z=z)

CPU times: user 1.02 ms, sys: 1.1 ms, total: 2.12 ms
Wall time: 1.19 ms


In [29]:
assert np.all(s1 == s2)