# In [5]:
class Ray(object):
    """A ray: a starting point ``origin`` and a heading ``direction``.

    Both values are stored exactly as passed in; no copying or
    normalisation is performed.
    """

    def __init__(self, o, d):
        self.origin, self.direction = o, d

class Shape(object):
    """Base class for renderable shapes.

    Subclasses provide ``local_intersect(ray)`` and ``local_normal_at(point)``
    working in object space; this class converts between world space and
    object space using ``self.transform``.
    """

    def __init__(self):
        self.transform = matrix4x4identity()
        self.material = material()

    def intersect(self, ray_original):
        """Intersect a world-space ray with this shape.

        The ray is mapped into object space with the inverse of
        ``self.transform`` and handed to ``self.local_intersect``, whose
        result is returned unchanged.
        """
        local_ray = transform(ray_original, inverse(self.transform))
        return self.local_intersect(local_ray)

    def normal_at(self, world_point):
        """Return the normalized world-space surface normal at ``world_point``."""
        # Invert once: both the point transform and the normal transform need it.
        inv = inverse(self.transform)
        local_point = matrix_multiply(inv, world_point)
        local_normal = self.local_normal_at(local_point)
        # Normals are carried back with the transpose of the inverse.
        world_normal = matrix_multiply(transpose(inv), local_normal)
        # Force w back to 0 so the result is a direction; translation terms of
        # the transposed inverse can otherwise end up in this component.
        world_normal[3] = 0
        return normalize(world_normal)

class Sphere(Shape):
    """Sphere centred at ``self.origin`` with radius ``self.radius`` in object space.

    Defaults to the unit sphere at the origin; placement in the world comes
    from the ``transform`` inherited from Shape.
    """

    def __init__(self):
        Shape.__init__(self)
        self.origin = point(0,0,0)
        self.radius = 1

    def local_intersect(self, ray_local):
        """Intersect an object-space ray with the sphere.

        Solves the quadratic |o + t*d - centre|^2 = radius^2 for t.
        Returns ``[]`` on a miss, otherwise both roots (equal on a tangent
        hit) as Intersections sorted by ``t``.
        """
        sphere_to_ray = ray_local.origin - self.origin
        a = dot(ray_local.direction, ray_local.direction)
        b = 2 * dot(ray_local.direction, sphere_to_ray)
        # Honour the stored radius instead of assuming a unit sphere.
        c = dot(sphere_to_ray, sphere_to_ray) - self.radius ** 2
        discriminant = b ** 2 - 4 * a * c

        if discriminant < 0:
            return []
        root = np.sqrt(discriminant)
        return intersections(Intersection((-b - root) / (2 * a), self),
                             Intersection((-b + root) / (2 * a), self))

    def local_normal_at(self, local_point):
        """Return the (unnormalized) object-space normal: centre -> ``local_point``."""
        return local_point - self.origin


class Plane(Shape):
    """The object-space plane y = 0, whose normal is +y everywhere."""

    def __init__(self):
        Shape.__init__(self)
        self.normalv = vector(0,1,0)

    def local_intersect(self, ray_local):
        """Intersect an object-space ray with the plane y = 0.

        Rays whose y-direction is below EPSILON in magnitude (parallel to,
        or lying in, the plane) are reported as a miss (``[]``).
        """
        rise = ray_local.direction[1]
        if abs(rise) < EPSILON:
            return []
        distance = -ray_local.origin[1] / rise
        return intersections(Intersection(distance, self))

    def local_normal_at(self, local_point):
        """Return the stored constant normal; ``local_point`` is unused."""
        return self.normalv


class Intersection(object):
    """Record that a ray met ``obj`` at ray parameter ``t``."""

    def __init__(self, t, obj):
        self.t, self.object = t, obj

def plane():
    """
    Create a default Plane (object-space y = 0, normal +y).

    >>> p = plane()
    >>> n1 = p.local_normal_at(point(0,0,0))
    >>> n2 = p.local_normal_at(point(10,0,-10))
    >>> n3 = p.local_normal_at(point(-5,0,150))
    >>> n1 == vector(0,1,0)
    array([ True,  True,  True,  True])

    >>> n2 == vector(0,1,0)
    array([ True,  True,  True,  True])

    >>> n3 == vector(0,1,0)
    array([ True,  True,  True,  True])

    >>> p = plane()
    >>> r = ray(point(0,10,0), vector(0,0,1))
    >>> xs = p.local_intersect(r)
    >>> len(xs) == 0
    True

    >>> p = plane()
    >>> r = ray(point(0,0,0), vector(0,0,1))
    >>> xs = p.local_intersect(r)
    >>> len(xs) == 0
    True

    >>> p = plane()
    >>> r = ray(point(0,1,0), vector(0,-1,0))
    >>> xs = p.local_intersect(r)
    >>> len(xs) == 1 and np.isclose(xs[0].t,1) and xs[0].object == p
    True

    >>> p = plane()
    >>> r = ray(point(0,-1,0), vector(0,1,0))
    >>> xs = p.local_intersect(r)
    >>> len(xs) == 1 and np.isclose(xs[0].t,1) and xs[0].object == p
    True
    """
    new_plane = Plane()
    return new_plane

def ray(o, d):
    """
    Build a Ray starting at ``o`` and heading along ``d``.

    >>> p = point(1,2,3)
    >>> d = vector(4,5,6)
    >>> r = ray(p,d)
    >>> r.origin == p
    array([ True,  True,  True,  True])

    >>> r.direction == d
    array([ True,  True,  True,  True])
    """
    return Ray(o=o, d=d)

def position(ray, t):
    """
    Return the point reached by moving ``t`` steps of ``ray.direction`` from ``ray.origin``.

    >>> r = ray(point(2,3,4), vector(1,0,0))
    >>> position(r, 0) == point(2,3,4)
    array([ True,  True,  True,  True])

    >>> position(r, 1) == point(3,3,4)
    array([ True,  True,  True,  True])

    >>> position(r,-1) == point(1,3,4)
    array([ True,  True,  True,  True])

    >>> position(r,2.5) == point(4.5,3,4)
    array([ True,  True,  True,  True])
    """
    displacement = ray.direction * t
    return ray.origin + displacement

def sphere():
    """
    Create a default Sphere (unit radius, centred at the origin, identity transform).

    >>> s = sphere()
    >>> s.transform == matrix4x4identity()
    array([[ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True]])
    """
    new_sphere = Sphere()
    return new_sphere

def intersect(shape, ray):
    """
    Intersect ``ray`` with ``shape``; returns whatever ``shape.intersect`` returns.

    >>> r = ray(point(0,0,-5), vector(0,0,1))
    >>> s = sphere()
    >>> xs = intersect(s,r)
    >>> len(xs) == 2
    True
    >>> xs[0].t == 4.0 and xs[1].t == 6.0
    True

    >>> r = ray(point(0,1,-5), vector(0,0,1))
    >>> s = sphere()
    >>> xs = intersect(s,r)
    >>> len(xs) == 2
    True
    >>> xs[0].t == 5.0 and xs[1].t == 5.0
    True

    >>> r = ray(point(0,2,-5), vector(0,0,1))
    >>> s = sphere()
    >>> xs = intersect(s,r)
    >>> len(xs) == 0
    True

    >>> r = ray(point(0,0,0), vector(0,0,1))
    >>> s = sphere()
    >>> xs = intersect(s,r)
    >>> len(xs) == 2
    True
    >>> xs[0].t == -1.0 and xs[1].t == 1.0
    True

    >>> r = ray(point(0,0,5), vector(0,0,1))
    >>> s = sphere()
    >>> xs = intersect(s,r)
    >>> len(xs) == 2
    True
    >>> xs[0].t == -6.0 and xs[1].t == -4.0
    True

    >>> r = ray(point(0,0,5), vector(0,0,1))
    >>> s = sphere()
    >>> xs = intersect(s,r)
    >>> len(xs) == 2
    True
    >>> id(xs[0].object) == id(s) and id(xs[1].object) == id(s)
    True

    >>> r = ray(point(0,0,-5), vector(0,0,1))
    >>> s = sphere()
    >>> s.transform = scaling(2,2,2)
    >>> xs = intersect(s,r)
    >>> len(xs) == 2 and xs[0].t == 3 and xs[1].t == 7
    True

    >>> r = ray(point(0,0,-5), vector(0,0,1))
    >>> s = sphere()
    >>> s.transform = translation(5,0,0)
    >>> xs = intersect(s,r)
    >>> len(xs) == 0
    True
    """
    found = shape.intersect(ray)
    return found

def intersection(t, obj):
    """
    Build an Intersection of ``obj`` at ray parameter ``t``.

    >>> s = sphere()
    >>> i = intersection(3.5, s)
    >>> i.t == 3.5 and id(s) == id(i.object)
    True
    """
    return Intersection(t=t, obj=obj)

def intersections(*args):
    """
    Collect the given intersections into a new list sorted by ``t`` (ascending).

    The sort is stable, so entries with equal ``t`` keep their argument
    order. With no arguments an empty list is returned.

    >>> s = sphere()
    >>> i1 = intersection(1,s)
    >>> i2 = intersection(2,s)
    >>> xs = intersections(i1,i2)
    >>> len(xs) == 2 and xs[0].t == 1 and xs[1].t == 2
    True
    """
    # sorted() accepts any iterable and always returns a fresh list.
    return sorted(args, key=lambda i: i.t)

def hit(intersections):
    """
    Return the intersection with the smallest positive ``t``, or None.

    Entries with ``t <= 0`` are ignored. The input does not need to be
    sorted; when several entries share the smallest positive ``t``, the
    first one in iteration order is returned.

    >>> s = sphere()
    >>> i1 = intersection(1,s)
    >>> i2 = intersection(2,s)
    >>> xs = intersections(i1,i2)
    >>> i = hit(xs)
    >>> i == i1
    True

    >>> s = sphere()
    >>> i1 = intersection(-1,s)
    >>> i2 = intersection(1,s)
    >>> xs = intersections(i1,i2)
    >>> i = hit(xs)
    >>> i == i2
    True

    >>> s = sphere()
    >>> i1 = intersection(-2,s)
    >>> i2 = intersection(-1,s)
    >>> xs = intersections(i1,i2)
    >>> i = hit(xs)
    >>> i is None
    True

    >>> s = sphere()
    >>> i1 = intersection(5,s)
    >>> i2 = intersection(7,s)
    >>> i3 = intersection(-3,s)
    >>> i4 = intersection(2,s)
    >>> xs = intersections(i1,i2,i3,i4)
    >>> i = hit(xs)
    >>> i == i4
    True

    >>> s = sphere()
    >>> i1 = intersection(5,s)
    >>> i2 = intersection(2,s)
    >>> hit([i1, i2]) == i2
    True
    """
    best = None
    for i in intersections:
        # Strict '<' keeps the earliest entry on ties, matching the old
        # first-match behaviour for already-sorted input.
        if i.t > 0 and (best is None or i.t < best.t):
            best = i
    return best

def transform(r, matrix):
    """
    Return a new ray whose origin and direction are those of ``r`` multiplied by ``matrix``.

    >>> r = ray(point(1,2,3), vector(0,1,0))
    >>> m = translation(3,4,5)
    >>> r2 = transform(r,m)
    >>> r2.origin == point(4,6,8)
    array([ True,  True,  True,  True])

    >>> r2.direction == vector(0,1,0)
    array([ True,  True,  True,  True])

    >>> r = ray(point(1,2,3), vector(0,1,0))
    >>> m = scaling(2,3,4)
    >>> r2 = transform(r,m)
    >>> r2.origin == point(2,6,12)
    array([ True,  True,  True,  True])

    >>> r2.direction == vector(0,3,0)
    array([ True,  True,  True,  True])
    """
    new_origin = matrix_multiply(matrix, r.origin)
    new_direction = matrix_multiply(matrix, r.direction)
    return ray(new_origin, new_direction)