# In [None]:
import itertools
from multiprocessing import Pool
import warnings

# Small offset; prepare_computations nudges over_point off the surface by this.
EPSILON = 1e-6
# Edge length, in pixels, of the square tiles render_multi hands to workers.
BLOCK_SIZE = 2

class World(object):
    """A scene: the lights illuminating it and the objects it contains.

    ``contains`` defaults to a new empty list for each instance. A default
    built once at definition time would be shared by every World created
    without it, so appending to one world's objects would leak into others.
    """

    def __init__(self, lights, contains=None):
        # Light sources; shade_hit iterates over this, world() passes None.
        self.lights = lights
        # Objects in the scene; a caller-supplied list is stored as-is.
        self.contains = [] if contains is None else contains

def world():
    """Return an empty World with no lights and no objects.

    >>> w = world()
    >>> w.lights is None and len(w.contains) == 0
    True
    """
    empty = World(lights=None)
    return empty
    
def default_world():
    """Build the reference scene: one white point light and two nested spheres.

    The outer sphere has a greenish material (diffuse 0.7, specular 0.2);
    the inner sphere is a unit sphere scaled by 0.5.

    >>> light = point_light(point(-10, 10, -10), color(1,1,1))
    >>> s1 = sphere()
    >>> s1.material.color = color(0.8,1.0,0.6)
    >>> s1.material.diffuse = 0.7
    >>> s1.material.specular = 0.2
    >>> s2 = sphere()
    >>> s2.transform = scaling(0.5,0.5,0.5)
    >>> w = default_world()
    >>> w.lights[0].position == light.position
    array([ True,  True,  True,  True])

    >>> w.lights[0].intensity == light.intensity
    array([ True,  True,  True])
    >>> len(w.contains) == 2
    True
    >>> w.contains[0].material.color == color(0.8,1.0,0.6)  
    array([ True,  True,  True])
    >>> w.contains[0].material.diffuse == 0.7 and w.contains[0].material.specular == 0.2
    True
    >>> w.contains[1].transform == scaling(0.5,0.5,0.5)
    array([[ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True]])
    """
    lamp = point_light(point(-10, 10, -10), color(1, 1, 1))

    outer = sphere()
    outer_mat = outer.material
    outer_mat.color = color(0.8, 1.0, 0.6)
    outer_mat.diffuse = 0.7
    outer_mat.specular = 0.2

    inner = sphere()
    inner.transform = scaling(0.5, 0.5, 0.5)

    return World([lamp], [outer, inner])

def intersect_world(world, r):
    """Intersect ray ``r`` with every object in ``world``.

    All hits from ``intersect`` are gathered and passed to ``intersections``;
    the doctest shows the result ordered by ``t``.

    >>> w = default_world()
    >>> r = ray(point(0,0,-5), vector(0,0,1))
    >>> xs = intersect_world(w, r)
    >>> len(xs) == 4 and xs[0].t == 4 and xs[1].t == 4.5 and xs[2].t == 5.5 and xs[3].t == 6
    True
    """
    hits = [x for shape in world.contains for x in intersect(shape, r)]
    return intersections(*hits)

class Computations(object):
    """Precomputed state for shading one ray/object intersection.

    Instances are filled in by ``prepare_computations``. ``over_point``
    starts as None and is assigned there.
    """

    def __init__(self, t, obj, point, eyev, normalv, inside):
        # Distance along the ray and the object that was hit.
        self.t, self.object = t, obj
        # Hit point, vector toward the eye, and surface normal.
        self.point, self.eyev, self.normalv = point, eyev, normalv
        # True when the hit was found to be on the inside of the object.
        self.inside = inside
        # Offset copy of the hit point, set later by prepare_computations.
        self.over_point = None
        
def prepare_computations(intersection, r):
    """Precompute the values needed to shade ``intersection`` hit by ray ``r``.

    When the normal points away from the eye (dot product < 0), the hit is
    flagged as inside and the normal is flipped. ``over_point`` is the hit
    point moved EPSILON along that (possibly flipped) normal; shade_hit uses
    it as the origin for shadow tests.

    >>> r = ray(point(0,0,-5), vector(0,0,1))
    >>> shape = sphere()
    >>> i = intersection(4, shape)
    >>> comps = prepare_computations(i, r)
    >>> comps.t == i.t
    True
    >>> comps.object == i.object
    True
    >>> comps.point == point(0,0,-1)
    array([ True,  True,  True,  True])

    >>> comps.eyev == vector(0,0,-1)
    array([ True,  True,  True,  True])

    >>> comps.normalv == vector(0,0,-1)
    array([ True,  True,  True,  True])

    >>> r = ray(point(0,0,-5), vector(0,0,1))
    >>> shape = sphere()
    >>> i = intersection(4, shape)
    >>> comps = prepare_computations(i, r)
    >>> comps.inside
    False

    >>> r = ray(point(0,0,0), vector(0,0,1))
    >>> shape = sphere()
    >>> i = intersection(1, shape)
    >>> comps = prepare_computations(i, r)
    >>> comps.inside
    True
    >>> comps.point == point(0,0,1)
    array([ True,  True,  True,  True])

    >>> comps.eyev == vector(0,0,-1)
    array([ True,  True,  True,  True])

    >>> comps.normalv == vector(0,0,-1)
    array([ True,  True,  True,  True])

    >>> r = ray(point(0,0,-5), vector(0,0,1))
    >>> shape = sphere()
    >>> shape.transform = translation(0,0,1)
    >>> i = intersection(5, shape)
    >>> comps = prepare_computations(i, r)
    >>> comps.over_point[2] < -EPSILON/2
    True
    >>> comps.point[2] > comps.over_point[2]
    True
    """
    hit_t = intersection.t
    shape = intersection.object
    where = position(r, hit_t)
    eye = -r.direction
    normal = normal_at(shape, where)

    # Keep inside as a plain Python bool (not the comparison result).
    inside = False
    if dot(normal, eye) < 0:
        inside = True
        normal = -normal

    comps = Computations(hit_t, shape, where, eye, normal, inside)
    comps.over_point = where + normal * EPSILON
    return comps

def shade_hit(world, comps):
    """Return the color at a precomputed hit, summed over every light.

    Each light contributes the result of ``lighting``, with a shadow test
    (``is_shadowed``) cast from ``comps.over_point`` toward that light. The
    accumulation starts from integer 0, so a world with no lights yields 0.

    >>> w = default_world()
    >>> r = ray(point(0,0,-5), vector(0,0,1))
    >>> shape = w.contains[0]
    >>> i = intersection(4, shape)
    >>> comps = prepare_computations(i, r)
    >>> c = shade_hit(w, comps)
    >>> np.isclose(c,color(0.38066, 0.47583, 0.28549589))
    array([ True,  True,  True])

    >>> w = default_world()
    >>> w.lights[0] = point_light(point(0,0.25,0), color(1,1,1))
    >>> r = ray(point(0,0,0), vector(0,0,1))
    >>> shape = w.contains[1]
    >>> i = intersection(0.5, shape)
    >>> comps = prepare_computations(i,r)
    >>> c = shade_hit(w, comps)
    >>> np.isclose(c,color(0.90498, 0.90498, 0.90498))
    array([ True,  True,  True])

    >>> w = world()
    >>> w.lights = [point_light(point(0,0,-1), color(1,1,1))]
    >>> s1 = sphere()
    >>> s2 = sphere()
    >>> s2.transform = translation(0,0,10)
    >>> w.contains = [s1,s2]
    >>> r = ray(point(0,0,5), vector(0,0,1))
    >>> i = intersection(4, s2)
    >>> comps = prepare_computations(i, r)
    >>> c = shade_hit(w, comps)
    >>> w.contains = []
    >>> np.isclose(c, color(0.1,0.1,0.1))
    array([ True,  True,  True])

    """
    total = 0
    for lamp in world.lights:
        in_shadow = is_shadowed(world, comps.over_point, lamp)
        total = total + lighting(comps.object.material,
                                 comps.object,
                                 lamp,
                                 comps.point,
                                 comps.eyev,
                                 comps.normalv,
                                 in_shadow)
    return total


def is_shadowed(world, point, light=None):
    """Report whether some object blocks ``point`` from ``light``.

    ``light`` defaults to the world's first light. A ray is cast from
    ``point`` toward the light. The point is shadowed when the hit selected
    by ``hit()`` exists and lies strictly closer than the light.

    >>> w = default_world()
    >>> p = point(0,10,0)
    >>> is_shadowed(w, p)
    False

    >>> w = default_world()
    >>> p = point(10,-10,10)
    >>> is_shadowed(w, p)
    True

    >>> w = default_world()
    >>> p = point(-20,20,-20)
    >>> is_shadowed(w, p)
    False

    >>> w = default_world()
    >>> p = point(-2,2,-2)
    >>> is_shadowed(w, p)
    False

    """
    source = world.lights[0] if light is None else light

    to_light = source.position - point
    dist = magnitude(to_light)
    shadow_ray = ray(point, normalize(to_light))
    nearest = hit(intersect_world(world, shadow_ray))

    if nearest is None:
        return False
    return nearest.t < dist

def color_at(world, ray):
    """Trace ``ray`` into ``world`` and return the shaded color it sees.

    Returns ``black`` when ``hit()`` finds nothing along the ray.

    >>> w = default_world()
    >>> r = ray(point(0,0,-5), vector(0,1,0))
    >>> c = color_at(w,r)
    >>> np.isclose(c, color(0,0,0))
    array([ True,  True,  True])

    >>> w = default_world()
    >>> r = ray(point(0,0,-5), vector(0,0,1))
    >>> c = color_at(w,r)
    >>> np.isclose(c, color(0.38066, 0.47583, 0.28549589))
    array([ True,  True,  True])

    >>> w = default_world()
    >>> outer = w.contains[0]
    >>> outer.material.ambient = 1.0
    >>> inner = w.contains[1]
    >>> inner.material.ambient = 1.0
    >>> r = ray(point(0,0,0.75), vector(0,0,-1))
    >>> c = color_at(w,r)
    >>> c == inner.material.color
    array([ True,  True,  True])
    """
    nearest = hit(intersect_world(world, ray))
    if nearest is not None:
        return shade_hit(world, prepare_computations(nearest, ray))
    return black


def view_transform(fr, to, up):
    """Return the matrix that orients the world as seen from ``fr`` toward ``to``.

    ``up`` is only an approximate up direction. The actual up row is
    recomputed as a cross product, so it is orthogonal to the view
    direction. The orientation is then applied after translating ``fr``
    to the origin.

    >>> fr = point(0,0,0)
    >>> to = point(0,0,-1)
    >>> up = vector(0,1,0)
    >>> t = view_transform(fr, to, up)
    >>> t == matrix4x4identity()
    array([[ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True]])

    >>> fr = point(0,0,0)
    >>> to = point(0,0,1)
    >>> up = vector(0,1,0)
    >>> t = view_transform(fr, to, up)
    >>> t == scaling(-1,1,-1)
    array([[ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True]])

    >>> fr = point(0,0,8)
    >>> to = point(0,0,0)
    >>> up = vector(0,1,0)
    >>> t = view_transform(fr, to, up)
    >>> t == translation(0,0,-8)
    array([[ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True]])

    >>> fr = point(1,3,2)
    >>> to = point(4,-2,8)
    >>> up = vector(1,1,0)
    >>> t = view_transform(fr, to, up)
    >>> t
    array([[-0.50709255,  0.50709255,  0.6761234 , -2.36643191],
           [ 0.76771593,  0.60609153,  0.12121831, -2.82842712],
           [-0.35856858,  0.5976143 , -0.71713717,  0.        ],
           [ 0.        ,  0.        ,  0.        ,  1.        ]])

    """
    fwd = normalize(to - fr)
    side = cross(fwd, normalize(up))
    cam_up = cross(side, fwd)

    rows = (side[0], side[1], side[2], 0,
            cam_up[0], cam_up[1], cam_up[2], 0,
            -fwd[0], -fwd[1], -fwd[2], 0,
            0, 0, 0, 1)
    orientation = matrix(*rows)
    eye_shift = translation(-fr[0], -fr[1], -fr[2])
    return matrix_multiply(orientation, eye_shift)

class Camera(object):
    """A pinhole camera that maps canvas pixels to rays.

    Attributes:
        hsize, vsize: canvas size in pixels.
        field_of_view: full view angle in radians (it is halved before
            ``np.tan``).
        transform: view matrix. It defaults to a new identity matrix per
            camera rather than one array shared by every default camera.
        half_width, half_height, pixel_size: derived once at construction
            from the values above.
    """

    def __init__(self, hsize, vsize, field_of_view, transform=None):
        self.hsize = hsize
        self.vsize = vsize
        self.field_of_view = field_of_view
        # Build the identity per instance; a default evaluated at definition
        # time would be one array shared (and mutable) across all cameras.
        self.transform = matrix4x4identity() if transform is None else transform
        self.half_width, self.half_height, self.pixel_size = self._compute_sizes()

    def _compute_sizes(self):
        """Return (half_width, half_height, pixel_size) of the canvas at z = -1."""
        half_view = np.tan(self.field_of_view / 2)
        # float() keeps integer pixel counts from doing integer division.
        aspect = float(self.hsize) / float(self.vsize)
        if aspect >= 1:
            # Landscape or square: the full view angle spans the width.
            half_width = half_view
            half_height = half_view / aspect
        else:
            # Portrait: the full view angle spans the height.
            half_width = half_view * aspect
            half_height = half_view
        return half_width, half_height, half_width * 2 / self.hsize

def camera(hsize, vsize, field_of_view):
    """Create a Camera of the given pixel size and view angle with the default transform.

    >>> hsize = 160
    >>> vsize = 120
    >>> field_of_view = np.pi / 2
    >>> c = camera(hsize, vsize, field_of_view)
    >>> c.hsize == 160 and c.vsize == 120 and c.field_of_view == np.pi / 2
    True
    >>> c.transform == matrix4x4identity()
    array([[ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True]])

    >>> c = camera(200, 125, np.pi / 2)    
    >>> np.isclose(c.pixel_size, 0.01)
    True

    >>> c = camera(125, 200, np.pi / 2)
    >>> np.isclose(c.pixel_size,0.01)
    True
    """
    cam = Camera(hsize, vsize, field_of_view)
    return cam

def ray_for_pixel(cam, px, py):
    """Return the ray from the camera through the center of pixel (px, py).

    The target point is on the canvas plane at z = -1 in camera space. Both
    it and the camera-space origin are mapped through the inverse of
    ``cam.transform``, which is computed once per call.

    >>> c = camera(201, 101, np.pi/2)
    >>> r = ray_for_pixel(c, 100, 50)
    >>> r.origin == point(0,0,0)
    array([ True,  True,  True,  True])

    >>> np.isclose(r.direction, vector(0,0,-1))
    array([ True,  True,  True,  True])

    >>> c = camera(201, 101, np.pi/2)
    >>> r = ray_for_pixel(c, 0, 0)
    >>> r.origin == point(0,0,0)
    array([ True,  True,  True,  True])

    >>> np.isclose(r.direction, vector(0.66519, 0.33259, -0.66851))
    array([ True,  True,  True,  True])

    >>> c = camera(201, 101, np.pi/2)
    >>> c.transform = matrix_multiply(rotation_y(np.pi/4), translation(0,-2,5))
    >>> r = ray_for_pixel(c, 100, 50)
    >>> np.isclose(r.origin, point(0,2,-5))
    array([ True,  True,  True,  True])

    >>> np.isclose(r.direction, vector(np.sqrt(2)/2, 0, -np.sqrt(2)/2))
    array([ True,  True,  True,  True])
    """
    # Distance from the canvas edge to the pixel's center (+0.5 pixel).
    xoffset = (px + 0.5) * cam.pixel_size
    yoffset = (py + 0.5) * cam.pixel_size
    # x and y decrease as px/py grow, so pixel (0, 0) sits at the
    # (+half_width, +half_height) corner of the canvas.
    world_x = cam.half_width - xoffset
    world_y = cam.half_height - yoffset

    # Invert once and reuse it for both points; inversion is the costly step.
    inv = inverse(cam.transform)
    pixel = matrix_multiply(inv, point(world_x, world_y, -1))
    origin = matrix_multiply(inv, point(0, 0, 0))
    direction = normalize(pixel - origin)

    return ray(origin, direction)

def render(cam, world):
    """Render ``world`` through ``cam`` into a new canvas, one pixel at a time.

    Pixels are visited row by row (y outer, x inner).

    >>> w = default_world()
    >>> c = camera(11, 11, np.pi/2)
    >>> fr = point(0,0,-5)
    >>> to = point(0,0,0)
    >>> up = vector(0,1,0)
    >>> c.transform = view_transform(fr, to, up)
    >>> image = render(c,w)
    >>> np.isclose(pixel_at(image, 5, 5), color(0.38066119, 0.47582649, 0.28549589))
    array([ True,  True,  True])
    """
    image = canvas(cam.hsize, cam.vsize)
    for py, px in itertools.product(range(cam.vsize), range(cam.hsize)):
        write_pixel(image, px, py, color_at(world, ray_for_pixel(cam, px, py)))
    return image

# https://jonasteuwen.github.io/numpy/python/multiprocessing/2017/01/07/multiprocessing-numpy-array.html
# http://thousandfold.net/cz/2014/05/01/sharing-numpy-arrays-between-processes-using-multiprocessing-and-ctypes/
# Pool size 4 on raspberry pi 3b+
def render_multi_helper(args):
    """Render one BLOCK_SIZE x BLOCK_SIZE tile into the module-level ``image``.

    ``args`` is a ``(cam, world, window_x, window_y)`` tuple, where
    ``(window_x, window_y)`` is the tile's top-left pixel. The tile is clipped
    to the camera's canvas. Without clipping, edge tiles (when hsize/vsize are
    not multiples of BLOCK_SIZE) would trace and write pixels outside the
    image.

    NOTE(review): ``image`` is a global assigned in ``render_multi``. Worker
    processes only see it if they inherit the parent's memory (the "fork"
    start method), and writes only reach the parent if the canvas is backed
    by shared memory. Confirm both for the target platform.
    """
    cam, world, window_x, window_y = args
    x_stop = min(window_x + BLOCK_SIZE, cam.hsize)
    y_stop = min(window_y + BLOCK_SIZE, cam.vsize)

    for idx_x in range(window_x, x_stop):
        for idx_y in range(window_y, y_stop):
            r = ray_for_pixel(cam, idx_x, idx_y)
            c = color_at(world, r)
            write_pixel(image, idx_x, idx_y, c)

def render_multi(cam, world, num_threads=4):
    """Render ``world`` through ``cam`` in parallel with a process pool.

    The canvas is split into BLOCK_SIZE x BLOCK_SIZE tiles, one task per tile.
    The tiles are handled by ``num_threads`` worker processes (despite the
    name, these are ``multiprocessing.Pool`` processes, not threads). Returns
    ``np.ctypeslib.as_array(image.shared_arr)``.

    The pool is always shut down: it is closed and joined on success, and
    terminated then joined if rendering raises, so workers never leak.

    >>> w = default_world()
    >>> c = camera(11, 11, np.pi/2)
    >>> fr = point(0,0,-5)
    >>> to = point(0,0,0)
    >>> up = vector(0,1,0)
    >>> c.transform = view_transform(fr, to, up)
    >>> image = render_multi(c,w)
    >>> np.isclose(pixel_at(image, 5, 5), color(0.38066119, 0.47582649, 0.28549589))
    array([ True,  True,  True])
    """
    global image
    # render_multi_helper writes into this global. Presumably canvas() backs
    # it with shared memory (``shared_arr``) so workers' writes are visible
    # here. TODO confirm.
    image = canvas(cam.hsize, cam.vsize)
    window_idxs = [(cam, world, i, j) for i, j in
                   itertools.product(range(0, cam.hsize, BLOCK_SIZE),
                                     range(0, cam.vsize, BLOCK_SIZE))]

    pool = Pool(num_threads)
    try:
        pool.map(render_multi_helper, window_idxs)
    except BaseException:
        # Stop outstanding tiles immediately instead of waiting for them.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
    return np.ctypeslib.as_array(image.shared_arr)