In [1]:
from __future__ import annotations

from typing import List
import sys

import torch
import torch.nn.functional as f
import numpy as np

import matplotlib.pyplot as plt

# Unbounded upper limit for ray-parameter intervals. A true IEEE infinity states the
# intent directly (sys.float_info.max is merely the largest *finite* double).
infinity = float("inf")

In [2]:
# Pick the compute device: CUDA first, then Apple-silicon MPS, otherwise the CPU.
# `torch.backends.mps.is_available()` is the documented MPS availability check and
# exists on every torch version that ships MPS support.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(device)

In [3]:
def blend(color1: torch.Tensor, color2: torch.Tensor, t_map: torch.Tensor) -> torch.Tensor:
    """Linearly interpolate between two colors.

    Returns ``color1`` where ``t_map`` is 0 and ``color2`` where it is 1; inputs are
    combined with ordinary tensor broadcasting.
    """
    weight_first = 1.0 - t_map
    return weight_first * color1 + t_map * color2


def dot(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    """Dot product over the last axis.

    The inputs are broadcast against each other. The reduced axis is kept with size 1
    so the result broadcasts back against vector-valued maps (e.g. in ``torch.where``).
    """
    return torch.sum(v1 * v2, dim=-1, keepdim=True)


def normalize_color(color: torch.Tensor, samples_per_pixel: int):
    """Turn a per-pixel sum of ``samples_per_pixel`` samples into their average.

    The sum is multiplied by the reciprocal sample count.
    """
    return color * (1.0 / samples_per_pixel)

In [4]:
class Interval:
    """Range of ray parameters between ``min_map`` and ``max_map``, stored as tensors.

    Each bound may be given as a tensor or as a plain Python number. A tensor is kept
    as-is, so it can be a per-ray map that broadcasts against the queried values. A
    number is wrapped with ``torch.tensor`` on ``device``.
    """

    def __init__(self, min_map: torch.Tensor, max_map: torch.Tensor, device: torch.device):
        self.min_map = self._as_tensor(min_map, device)
        self.max_map = self._as_tensor(max_map, device)

    @staticmethod
    def _as_tensor(bound, device: torch.device) -> torch.Tensor:
        """Return ``bound`` unchanged if it is already a tensor, else wrap it on ``device``."""
        if isinstance(bound, torch.Tensor):
            return bound
        return torch.tensor(bound, device=device)

    def size(self) -> torch.Tensor:
        """Length of the interval, ``max_map - min_map``."""
        return self.max_map - self.min_map

    def contains(self, x_map: torch.Tensor) -> torch.Tensor:
        """Boolean mask: True where ``min_map <= x_map <= max_map`` (closed interval)."""
        return torch.logical_and(self.min_map <= x_map, x_map <= self.max_map)

    def surrounds(self, x_map: torch.Tensor) -> torch.Tensor:
        """Boolean mask: True where ``min_map < x_map < max_map`` (open interval)."""
        return torch.logical_and(self.min_map < x_map, x_map < self.max_map)

    def clamp(self, x_map: torch.Tensor) -> torch.Tensor:
        """Clip ``x_map`` element-wise into ``[min_map, max_map]``."""
        return torch.clamp(x_map, self.min_map, self.max_map)

In [5]:
class Ray:
    """A bundle of rays that all start at ``origin``.

    ``direction_map`` is normalized along its last axis when the ray is built. Its
    first two dimensions are exposed as ``height`` and ``width``.
    """

    def __init__(self, origin: torch.Tensor, direction_map: torch.Tensor) -> None:
        self.origin = origin
        unit_directions = f.normalize(direction_map, dim=-1)
        self.direction_map = unit_directions
        self.height, self.width = unit_directions.shape[:2]

    def at(self, t_map: torch.Tensor) -> torch.Tensor:
        """Point reached after travelling ``t_map`` along each unit direction."""
        displacement = t_map * self.direction_map
        return self.origin + displacement

    def get_device(self) -> torch.device:
        """Device on which the ray origin tensor lives."""
        origin_tensor = self.origin
        return origin_tensor.device

In [6]:
class HitRecord:
    """Per-ray intersection results, held as broadcastable tensor maps.

    Attributes (all tensors):
        point_map: intersection points.
        normal_map: surface normals; ``set_face_normal`` orients them against the ray.
        t_map: ray parameter at the intersection.
        valid_map: boolean mask, True where the ray actually hit something.
    """

    def __init__(
        self,
        point_map: torch.Tensor,
        normal_map: torch.Tensor,
        t_map: torch.Tensor,
        valid_map: torch.Tensor,
    ) -> None:
        self.point_map, self.normal_map = point_map, normal_map
        self.t_map, self.valid_map = t_map, valid_map

    def set_face_normal(self, ray: Ray, outward_normal_map: torch.Tensor) -> None:
        """Store normals that always point against the incoming ray direction."""
        faces_ray = dot(ray.direction_map, outward_normal_map) < 0.0
        flipped_normal_map = -outward_normal_map
        self.normal_map = torch.where(faces_ray, outward_normal_map, flipped_normal_map)

    def update(self, other: HitRecord) -> None:
        """Overwrite every field of this record wherever ``other.valid_map`` is True."""
        take_other = other.valid_map
        for field_name in ("point_map", "normal_map", "t_map", "valid_map"):
            merged = torch.where(take_other, getattr(other, field_name), getattr(self, field_name))
            setattr(self, field_name, merged)

In [7]:
class Hittable:
    """Interface for scene objects that rays can intersect.

    Subclasses override ``hit`` and return a HitRecord describing which rays hit
    the object within the parameter range ``ray_t``.
    """

    def hit(self, ray: Ray, ray_t: Interval) -> HitRecord:
        """Intersect ``ray`` with this object; concrete subclasses must override it."""
        raise NotImplementedError()

In [8]:
class Sphere(Hittable):
    """Sphere defined by a center point and a radius."""

    def __init__(self, center: torch.Tensor, radius: float) -> None:
        """
        Args:
            center: sphere center; presumably a (3,) tensor on the render device (confirm).
            radius: sphere radius in world units.
        """
        self.center = center
        self.radius = radius

    def hit(self, ray: Ray, ray_t: Interval) -> HitRecord:
        """Intersect every ray in ``ray`` with this sphere.

        Solves ``|origin + t * direction - center|^2 = radius^2`` for ``t`` using the
        half-b form of the quadratic formula. The nearer root is used when it lies
        strictly inside ``ray_t``; otherwise the farther root is used.

        Returns:
            HitRecord whose ``valid_map`` is True where real roots exist and at least one
            of them lies strictly inside ``ray_t``. Normals are oriented against the ray.
        """
        # Vector from the sphere center to the ray origin.
        dir_center_to_origin = ray.origin - self.center

        a_map = dot(ray.direction_map, ray.direction_map)
        half_b_map = dot(dir_center_to_origin, ray.direction_map)
        c_map = dot(dir_center_to_origin, dir_center_to_origin) - self.radius**2.0

        discriminant_map = half_b_map**2 - a_map * c_map
        cond_discriminant = discriminant_map >= 0.0

        # Zero out negative discriminants before the sqrt so missed rays give finite
        # (meaningless) roots instead of NaN; they are masked out via valid_map below.
        safe_discriminant_map = torch.where(cond_discriminant, discriminant_map, 0.0)
        sqrt_d_map = torch.sqrt(safe_discriminant_map)

        # Find the nearest root that lies in the acceptable range
        # (t_map1 <= t_map2 whenever a_map > 0).
        t_map1 = (-half_b_map - sqrt_d_map) / a_map
        cond1 = ray_t.surrounds(t_map1)
        t_map2 = (-half_b_map + sqrt_d_map) / a_map
        cond2 = ray_t.surrounds(t_map2)

        valid_map = torch.logical_and(cond_discriminant, torch.logical_or(cond1, cond2))

        t_map = torch.where(cond1, t_map1, t_map2)

        point_map = ray.at(t_map)
        # Dividing by the radius yields a unit-length outward normal for points on the surface.
        outward_normal_map = (point_map - self.center) / self.radius
        hit_record = HitRecord(
            point_map=point_map,
            normal_map=outward_normal_map,
            t_map=t_map,
            valid_map=valid_map,
        )
        hit_record.set_face_normal(ray=ray, outward_normal_map=outward_normal_map)
        return hit_record

In [9]:
class HittableList(Hittable):
    """Ordered collection of hittables that is intersected as a single object."""

    def __init__(self) -> None:
        self.objects: List[Hittable] = []

    def clear(self) -> None:
        """Remove every object from the collection."""
        self.objects.clear()

    def add(self, object: Hittable) -> None:
        """Append ``object`` to the collection."""
        self.objects.append(object)

    def hit(self, ray: Ray, ray_t: Interval) -> HitRecord:
        """Return, for every ray, the nearest valid hit across all objects.

        Each object is queried with an upper bound shrunk to the closest hit found so
        far. Any hit it reports is therefore nearer and overwrites the running record.
        With no objects, a record in which no ray hits anything is returned (instead of
        ``None``), so callers can always read ``valid_map``.
        """
        if not self.objects:
            return self._miss_record(ray, ray_t)

        device = ray.get_device()
        record = None
        closest_so_far_map = ray_t.max_map

        for obj in self.objects:
            tmp_record = obj.hit(
                ray=ray, ray_t=Interval(ray_t.min_map, closest_so_far_map, device)
            )
            closest_so_far_map = torch.where(
                tmp_record.valid_map, tmp_record.t_map, closest_so_far_map
            )
            if record is None:
                record = tmp_record
            else:
                record.update(tmp_record)

        return record

    @staticmethod
    def _miss_record(ray: Ray, ray_t: Interval) -> HitRecord:
        """Build a HitRecord shaped like ``ray``'s output in which every ray misses."""
        device = ray.get_device()
        # Same leading dims as the directions, with a size-1 last axis (like dot()).
        scalar_shape = tuple(ray.direction_map.shape[:-1]) + (1,)
        zero_vectors = torch.zeros_like(ray.direction_map)
        return HitRecord(
            point_map=zero_vectors,
            normal_map=zero_vectors.clone(),
            t_map=torch.zeros(scalar_shape, device=device) + ray_t.max_map,
            valid_map=torch.zeros(scalar_shape, dtype=torch.bool, device=device),
        )

In [10]:
class Camera:
    """Pinhole camera at the origin looking down the -z axis.

    The viewport is a ``viewport_width`` x ``viewport_height`` rectangle at distance
    ``focal_length`` in front of the origin. Rays are generated through pixel
    coordinates measured from the viewport's lower-left corner.

    Args:
        image_width: image width in pixels (must be > 0).
        image_height: image height in pixels (must be > 0).
        device: device on which the camera vectors are allocated.
        viewport_height: viewport height in world units (default 2.0).
        focal_length: distance from the origin to the viewport (default 1.0).

    Raises:
        ValueError: if either image dimension is not positive.
    """

    def __init__(
        self,
        image_width: int,
        image_height: int,
        device: torch.device,
        viewport_height: float = 2.0,
        focal_length: float = 1.0,
    ) -> None:
        if image_width <= 0 or image_height <= 0:
            raise ValueError(
                f"image dimensions must be positive, got {image_width}x{image_height}"
            )

        # Derive the viewport aspect from the actual image size so pixels stay square.
        # A fixed 16:9 ratio would stretch any image that is not exactly 16:9.
        self.aspect_ratio = image_width / image_height
        self.viewport_height = viewport_height
        self.viewport_width = self.aspect_ratio * self.viewport_height
        self.focal_length = focal_length

        self.origin = torch.tensor([0.0, 0.0, 0.0], device=device)
        self.horizontal_vec = torch.tensor([self.viewport_width, 0.0, 0.0], device=device)
        self.vertical_vec = torch.tensor([0.0, self.viewport_height, 0.0], device=device)
        self.frontal_vec = torch.tensor([0.0, 0.0, self.focal_length], device=device)
        self.lower_left_corner = (
            self.origin
            - self.horizontal_vec / 2.0
            - self.vertical_vec / 2.0
            - self.frontal_vec
        )

        # World-space size of one pixel along each viewport axis.
        self.pixel_delta_u = self.horizontal_vec / image_width
        self.pixel_delta_v = self.vertical_vec / image_height

    def get_ray(self, u_map: torch.Tensor, v_map: torch.Tensor) -> Ray:
        """Build rays through the given pixel coordinates.

        Args:
            u_map: horizontal coordinates in pixels from the lower-left corner.
            v_map: vertical coordinates in pixels from the lower-left corner.
                Both must broadcast against the (3,) pixel-delta vectors, e.g. by
                having a trailing size-1 axis.
        """
        return Ray(
            origin=self.origin,
            direction_map=(
                self.lower_left_corner
                + u_map * self.pixel_delta_u
                + v_map * self.pixel_delta_v
            )
            - self.origin,
        )

In [11]:
def ray_color(ray: Ray, world: Hittable) -> torch.Tensor:
    """Shade every ray: normal-based colors on hits, a vertical sky gradient on misses.

    Args:
        ray: batch of rays whose ``direction_map`` holds xyz on its last axis.
        world: scene intersected over the parameter range ``(0, infinity)``.

    Returns:
        RGB colors, one per ray, with the same leading dimensions as
        ``ray.direction_map``.
    """
    device = ray.get_device()
    record = world.hit(ray=ray, ray_t=Interval(0.0, infinity, device=device))

    # Map normal components from [-1, 1] to [0, 1] for visualization.
    world_color = 0.5 * (record.normal_map + 1.0)

    # Blend from white (direction y = -1) to light blue (direction y = +1).
    t_map = 0.5 * (ray.direction_map[..., 1:2] + 1.0)
    color1 = torch.tensor([1.0, 1.0, 1.0], device=device)
    color2 = torch.tensor([0.5, 0.7, 1.0], device=device)
    background_color = blend(color1=color1, color2=color2, t_map=t_map)

    return torch.where(record.valid_map, world_color, background_color)

In [12]:
# image
aspect_ratio = 16.0 / 9.0
image_width = 400
# int() truncates; keep at least one row so very small widths still give a valid image.
image_height = max(1, int(image_width / aspect_ratio))
# Seed torch's RNG so the jittered anti-aliasing samples (and hence the image) are reproducible.
seed = 0
torch.manual_seed(seed)
samples_per_pixel = 100

In [13]:
# world: a small sphere (r = 0.5) centred at (0, 0, -1), plus a huge sphere (r = 100)
# whose top at y = -0.5 touches the small sphere's bottom and serves as the ground.
world = HittableList()
for sphere_center, sphere_radius in (
    ([0.0, 0.0, -1.0], 0.5),
    ([0.0, -100.5, -1.0], 100.0),
):
    world.add(Sphere(center=torch.tensor(sphere_center, device=device), radius=sphere_radius))

In [14]:
# Camera whose pixel grid matches the configured image resolution.
camera = Camera(image_width, image_height, device=device)

In [15]:
def rand_uniform(low: float, high: float, size) -> torch.Tensor:
    """Draw a tensor of shape ``size`` uniformly from ``[low, high)``.

    The samples are allocated on the notebook-level ``device``.
    """
    unit_samples = torch.rand(size, device=device)
    span = high - low
    return low + span * unit_samples

In [None]:
%%time
with torch.no_grad():
    # Pixel-center coordinates, measured in pixels from the lower-left image corner.
    us = torch.arange(start=0, end=image_width, dtype=torch.float32, device=device) + 0.5
    vs = torch.arange(start=0, end=image_height, dtype=torch.float32, device=device) + 0.5
    # "xy" indexing gives (image_height, image_width) grids; the two trailing singleton
    # axes let them broadcast against the (H, W, samples, 1) jitter below.
    u_map, v_map = torch.meshgrid(us, vs, indexing="xy")
    u_map = u_map[..., None, None]
    v_map = v_map[..., None, None]

    # Anti-aliasing: jitter each sample uniformly inside its own pixel. The grid already
    # sits at pixel centers (i + 0.5), so the offset must be centered on zero, i.e. drawn
    # from [-0.5, 0.5). An offset in [0, 1) would shift every sample half a pixel toward
    # the upper right.
    jitter_size = [image_height, image_width, samples_per_pixel, 1]
    u_jitter = rand_uniform(low=-0.5, high=0.5, size=jitter_size)
    v_jitter = rand_uniform(low=-0.5, high=0.5, size=jitter_size)

    u_map = u_map + u_jitter
    v_map = v_map + v_jitter
    ray = camera.get_ray(u_map=u_map, v_map=v_map)

    # Shade every sample, then average over the sample axis to get one color per pixel.
    color_map = ray_color(ray=ray, world=world)
    color_map = torch.sum(color_map, dim=2)
    image = normalize_color(color_map, samples_per_pixel=samples_per_pixel)

In [24]:
# Move the image to host memory for matplotlib. An unconditional `.cpu()` is a no-op
# for CPU tensors and copies from any accelerator (any CUDA index, MPS). Testing
# `get_device() == 0` would only catch the first CUDA device.
# Shading values are meant to lie in [0, 1]; clamp so float rounding cannot push them
# just outside the range imshow expects for float RGB data.
image_np = image.clamp(0.0, 1.0).cpu().numpy()

In [None]:
# Full render; origin="lower" puts image row 0 at the bottom, matching the camera's v axis.
fig, ax = plt.subplots()
ax.imshow(image_np, origin="lower")
ax.set_title(f"Render: {image_width}x{image_height}, {samples_per_pixel} samples/pixel")
ax.set_xlabel("x (pixels)")
ax.set_ylabel("y (pixels)")
plt.show()

In [None]:
# Zoom in on a small window of the render to inspect per-pixel detail (e.g. anti-aliased
# edges). `extent` keeps the tick labels in full-image pixel coordinates.
crop_rows = slice(60, 90)
crop_cols = slice(135, 165)
fig, ax = plt.subplots()
ax.imshow(
    image_np[crop_rows, crop_cols],
    origin="lower",
    extent=(
        crop_cols.start - 0.5,
        crop_cols.stop - 0.5,
        crop_rows.start - 0.5,
        crop_rows.stop - 0.5,
    ),
)
ax.set_title(
    f"Crop: rows {crop_rows.start}-{crop_rows.stop - 1}, "
    f"cols {crop_cols.start}-{crop_cols.stop - 1}"
)
ax.set_xlabel("x (pixels)")
ax.set_ylabel("y (pixels)")
plt.show()