In [56]:
import os
import sys
import torch as t
from torch import Tensor
import einops
from ipywidgets import interact
import plotly.express as px

from ipywidgets import interact
from pathlib import Path
from IPython.display import display
from jaxtyping import Float, Int, Bool, Shaped, jaxtype
import typeguard

# Make sure exercises are in the path
chapter = r"chapter0_fundamentals"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part1_ray_tracing"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
from part1_ray_tracing.utils import render_lines_with_plotly, setup_widget_fig_ray, setup_widget_fig_triangle
import part1_ray_tracing.tests as tests

MAIN = __name__ == "__main__"

ImportError: cannot import name 'jaxtype' from 'jaxtyping' (/Users/kortukov/miniconda3/envs/arena-env/lib/python3.11/site-packages/jaxtyping/__init__.py)

In [4]:
def make_rays_1d(num_pixels: int, y_limit: float) -> t.Tensor:
    '''
    num_pixels: The number of pixels in the y dimension. Since there is one ray per pixel, this is also the number of rays.
    y_limit: At x=1, the rays should extend from -y_limit to +y_limit, inclusive of both endpoints.

    Returns: shape (num_pixels, num_points=2, num_dim=3) where the num_points dimension contains (origin, direction) and the num_dim dimension contains xyz.

    Example of make_rays_1d(9, 1.0): [
        [[0, 0, 0], [1, -1.0, 0]],
        [[0, 0, 0], [1, -0.75, 0]],
        [[0, 0, 0], [1, -0.5, 0]],
        ...
        [[0, 0, 0], [1, 0.75, 0]],
        [[0, 0, 0], [1, 1, 0]],
    ]
    '''
    origin = t.Tensor([0, 0, 0])
    middle = t.tensor([1, 1, 0])
    pair = t.stack([origin, middle])

    rays = einops.repeat(pair, 'p c -> num_rays p c', num_rays=num_pixels).clone()

    t.linspace(-y_limit, y_limit, num_pixels, out=rays[:, 1, 1])

    return rays

    

rays1d = make_rays_1d(9, 10.0)

fig = render_lines_with_plotly(rays1d)

In [5]:
fig = setup_widget_fig_ray()
display(fig)

@interact
def response(seed=(0, 10, 1), v=(-2.0, 2.0, 0.01)):
    t.manual_seed(seed)
    L_1, L_2 = t.rand(2, 2)
    P = lambda v: L_1 + v * (L_2 - L_1)
    x, y = zip(P(-2), P(2))
    with fig.batch_update(): 
        fig.data[0].update({"x": x, "y": y}) 
        fig.data[1].update({"x": [L_1[0], L_2[0]], "y": [L_1[1], L_2[1]]}) 
        fig.data[2].update({"x": [P(v)[0]], "y": [P(v)[1]]})

FigureWidget({
    'data': [{'type': 'scatter', 'uid': 'a444a574-b8c3-4a5e-a9b4-ee933ba474ad', 'x': [], 'y': []},
             {'marker': {'size': 12},
              'mode': 'markers',
              'type': 'scatter',
              'uid': 'd1182889-e2ce-4d3e-9c49-5f7a483eadad',
              'x': [],
              'y': []},
             {'marker': {'size': 12, 'symbol': 'x'},
              'mode': 'markers',
              'type': 'scatter',
              'uid': 'ae9ed712-4fe0-40fa-bc8d-4fe0236fd058',
              'x': [],
              'y': []}],
    'layout': {'height': 500,
               'showlegend': False,
               'template': '...',
               'width': 600,
               'xaxis': {'range': [-1.5, 2.5]},
               'yaxis': {'range': [-1.5, 2.5]}}
})

interactive(children=(IntSlider(value=5, description='seed', max=10), FloatSlider(value=0.0, description='v', …

In [6]:
segments = t.tensor([
    [[1.0, -12.0, 0.0], [1, -6.0, 0.0]], 
    [[0.5, 0.1, 0.0], [0.5, 1.15, 0.0]], 
    [[2, 12.0, 0.0], [2, 21.0, 0.0]]
])

In [7]:
segments[0] # intersects with ray to -10 and to -7.5

tensor([[  1., -12.,   0.],
        [  1.,  -6.,   0.]])

In [8]:
segments[1] # intersects with nothing

tensor([[0.5000, 0.1000, 0.0000],
        [0.5000, 1.1500, 0.0000]])

In [9]:
segments[2] # intersects with y = 10 and y=7.5 

tensor([[ 2., 12.,  0.],
        [ 2., 21.,  0.]])

In [10]:
rays1d * 2

tensor([[[  0.,   0.,   0.],
         [  2., -20.,   0.]],

        [[  0.,   0.,   0.],
         [  2., -15.,   0.]],

        [[  0.,   0.,   0.],
         [  2., -10.,   0.]],

        [[  0.,   0.,   0.],
         [  2.,  -5.,   0.]],

        [[  0.,   0.,   0.],
         [  2.,   0.,   0.]],

        [[  0.,   0.,   0.],
         [  2.,   5.,   0.]],

        [[  0.,   0.,   0.],
         [  2.,  10.,   0.]],

        [[  0.,   0.,   0.],
         [  2.,  15.,   0.]],

        [[  0.,   0.,   0.],
         [  2.,  20.,   0.]]])

In [11]:
render_lines_with_plotly(t.concatenate([rays1d * 2, segments], dim=0))

In [12]:
@jaxtyped(typechecker=typeguard.typechecked)
def intersect_ray_1d(ray: Float[t.Tensor, "n=2 d=3"], segment: Float[t.Tensor, "n d"]) -> bool:
    '''
    ray: shape (n_points=2, n_dim=3)  # O, D points
    segment: shape (n_points=2, n_dim=3)  # L_1, L_2 points

    Return True if the ray intersects the segment.
    '''
    O = ray[0, 0:2]
    D = ray[1, 0:2] # shape (2,)
    L1 = segment[0, 0:2]
    L2 = segment[1, 0:2]

    mat = t.stack([D, L1 - L2], dim=1)

    rhs = L1 - O
    try:
        uv = t.linalg.solve(mat, rhs)
    except Exception as e:
        return False
    u, v = uv[0].item(), uv[1].item()
    return (u >= 0) and (0 <= v <= 1)




tests.test_intersect_ray_1d(intersect_ray_1d)
tests.test_intersect_ray_1d_special_case(intersect_ray_1d)

All tests in `test_intersect_ray_1d` passed!
All tests in `test_intersect_ray_1d_special_case` passed!



As of jaxtyping version 0.2.24, jaxtyping now prefers the syntax
```
from jaxtyping import jaxtyped
# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker

@jaxtyped(typechecker=typechecker)
def foo(...):
```
and the old double-decorator syntax
```
@jaxtyped
@typechecker
def foo(...):
```
should no longer be used. (It will continue to work as it did before, but the new approach will produce more readable error messages.)
In particular note that `typechecker` must be passed via keyword argument; the following is not valid:
```
@jaxtyped(typechecker)
def foo(...):
```




In [13]:
t.linalg.solve(2*t.eye(3), t.Tensor([2,3,4]))

tensor([1.0000, 1.5000, 2.0000])

In [14]:
def intersect_rays_1d(rays: Float[Tensor, "nrays 2 3"], segments: Float[Tensor, "nsegments 2 3"]) -> Bool[Tensor, "nrays"]:
    '''
    For each ray, return True if it intersects any segment.
    '''
    nrays = rays.shape[0]
    nsegments = segments.shape[0]

    Os = rays[:, 0, :2] # (nrays, dim)
    Ds = rays[:, 1, :2]
    L1s = segments[:, 0, :2] # (nsegments, dim)
    L2s = segments[:, 1, :2]

    # Repeat each ray for all segments - one batch row for one ray
    batch_Ds = einops.repeat(Ds, 'nrays dim -> nrays nsegments dim', nsegments=nsegments)
    batch_Os = einops.repeat(Os, 'nrays dim -> nrays nsegments dim', nsegments=nsegments)

    # Repeat each segment for all rays one batch column for one ray
    batch_L1s = einops.repeat(L1s, 'nsegments dim -> nrays nsegments dim', nrays=nrays)
    batch_L2s = einops.repeat(L2s, 'nsegments dim -> nrays nsegments dim', nrays=nrays)

    mat_columns = [batch_Ds, batch_L1s - batch_L2s]
    mats = einops.rearrange(mat_columns, 'cols nrays nsegments dim -> nrays nsegments dim cols')

    rhss = batch_L1s - batch_Os # shape (nrays, nsegments, 2dim)

    # solutions = t.linalg.solve(mats, rhss) # shape(nrays, 2dim)

    # To handle non-invertible matrices
    solutions, residuals, rank, sing_vals = t.linalg.lstsq(mats, rhss)
    # Solution of linear system possible only if full rank.
    # Otherwise no solution - means ray and segment parallel.
    full_rank_indices = rank == 2
    assert full_rank_indices.shape == (nrays, nsegments)


    us = solutions[..., 0]
    vs = solutions[..., 1]

    intersect_on_ray = us >= 0
    intersect_on_segment = (0 <= vs) & (vs <= 1)
    # shape (nrays, nsegments)
    intersection = intersect_on_ray & intersect_on_segment & full_rank_indices 

    ray_intersects_with_any = intersection.any(dim=1)
    return ray_intersects_with_any




tests.test_intersect_rays_1d(intersect_rays_1d)
tests.test_intersect_rays_1d_special_case(intersect_rays_1d)

All tests in `test_intersect_rays_1d` passed!
All tests in `test_intersect_rays_1d_special_case` passed!


In [15]:
@jaxtyped(typechecker=typeguard.typechecked)
def make_rays_2d(num_pixels_y: int, num_pixels_z: int, y_limit: float, z_limit: float) -> Float[t.Tensor, "nrays 2 3"]:
    '''
    num_pixels_y: The number of pixels in the y dimension
    num_pixels_z: The number of pixels in the z dimension

    y_limit: At x=1, the rays should extend from -y_limit to +y_limit, inclusive of both.
    z_limit: At x=1, the rays should extend from -z_limit to +z_limit, inclusive of both.

    Returns: shape (num_rays=num_pixels_y * num_pixels_z, num_points=2, num_dims=3).
    '''
    nrays = num_pixels_y * num_pixels_z
    y_space = t.linspace(-y_limit, y_limit, num_pixels_y)
    z_space = t.linspace(-z_limit, z_limit, num_pixels_z)
    rays = t.zeros(nrays, 2, 3)
    rays[:, 1, 0] = 1
    rays[:, 1, 1] = einops.repeat(y_space, 'y -> (y z)', z=num_pixels_z)
    rays[:, 1, 2] = einops.repeat(z_space, 'z -> (y z)', y=num_pixels_y)
    return rays 


rays_2d = make_rays_2d(10, 10, 0.3, 0.3)
render_lines_with_plotly(rays_2d)

In [16]:
one_triangle = t.tensor([[0, 0, 0], [3, 0.5, 0], [2, 3, 0]])
A, B, C = one_triangle
x, y, z = one_triangle.T

fig = setup_widget_fig_triangle(x, y, z)

@interact(u=(-0.5, 1.5, 0.01), v=(-0.5, 1.5, 0.01))
def response(u=0.0, v=0.0):
    P = A + u * (B - A) + v * (C - A)
    fig.data[2].update({"x": [P[0]], "y": [P[1]]})

display(fig)

interactive(children=(FloatSlider(value=0.0, description='u', max=1.5, min=-0.5, step=0.01), FloatSlider(value…

FigureWidget({
    'data': [{'marker': {'size': 12},
              'mode': 'markers+text',
              'text': [A, B, C],
              'textfont': {'size': 18},
              'textposition': 'middle left',
              'type': 'scatter',
              'uid': '2e1c4ed7-f3ac-45d9-94c7-d7c9dd26e8ef',
              'x': array([0., 3., 2.], dtype=float32),
              'y': array([0. , 0.5, 3. ], dtype=float32)},
             {'mode': 'lines',
              'type': 'scatter',
              'uid': '1cf504c0-3fb5-4315-a5a3-7b898a5bf27e',
              'x': [0.0, 3.0, 2.0, 0.0],
              'y': [0.0, 0.5, 3.0, 0.0]},
             {'marker': {'size': 12, 'symbol': 'x'},
              'mode': 'markers',
              'type': 'scatter',
              'uid': 'a5220c38-629c-433b-a3f2-dc6e57e6ccb1',
              'x': [0.0],
              'y': [0.0]}],
    'layout': {'height': 600,
               'showlegend': False,
               'template': '...',
               'title': {'text': 'Barycen

In [17]:
Point = Float[Tensor, "points=3"]

@jaxtyped
@typeguard.typechecked
def triangle_ray_intersects(A: Point, B: Point, C: Point, O: Point, D: Point) -> bool:
    '''
    A: shape (3,), one vertex of the triangle
    B: shape (3,), second vertex of the triangle
    C: shape (3,), third vertex of the triangle
    O: shape (3,), origin point
    D: shape (3,), direction point

    Return True if the ray and the triangle intersect.
    '''
    lhs = t.stack([-D, (B-A), (C-A)], dim=1)
    rhs = O - A
    try:
        suv = t.linalg.solve(lhs, rhs)
    except Exception as e:
        return False
    s, u, v = suv

    return ((u >= 0) & (v >= 0) & (u + v <= 1)).item()


tests.test_triangle_ray_intersects(triangle_ray_intersects)

All tests in `test_triangle_ray_intersects` passed!



As of jaxtyping version 0.2.24, jaxtyping now prefers the syntax
```
from jaxtyping import jaxtyped
# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker

@jaxtyped(typechecker=typechecker)
def foo(...):
```
and the old double-decorator syntax
```
@jaxtyped
@typechecker
def foo(...):
```
should no longer be used. (It will continue to work as it did before, but the new approach will produce more readable error messages.)
In particular note that `typechecker` must be passed via keyword argument; the following is not valid:
```
@jaxtyped(typechecker)
def foo(...):
```




In [18]:
a = t.Tensor([[1,2,3],[3,4,5]])
b = a.view((3,2))

In [19]:
a.storage().data_ptr(), b.storage().data_ptr()


TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()



(13717803776, 13717803776)

In [20]:
x = t.zeros(1024*1024*1024)
y = x[0].detach()

In [21]:
del x

In [22]:
y._base is None

True

In [23]:
@jaxtyped(typechecker=typeguard.typechecked)
def raytrace_triangle(
    rays: Float[Tensor, "nrays rayPoints=2 dims=3"],
    triangle: Float[Tensor, "trianglePoints=3 dims=3"]
) -> Bool[Tensor, "nrays"]:
    '''
    For each ray, return True if the triangle intersects that ray.
    '''
    nrays = rays.shape[0]
    # (nrays, dims)
    Os, Ds = rays.unbind(dim=1)

    triangles = einops.repeat(triangle, 'points dims -> nrays points dims', nrays=nrays)
    As, Bs, Cs = triangles.unbind(dim=1)


    lhs_list = [-Ds, (Bs - As), (Cs - As)]
    lhs = einops.rearrange(lhs_list, 'cols nrays dim -> nrays dim cols')

    rhs = Os - As # (nrays, dims)

    dets: Float[Tensor, "nrays"] = t.linalg.det(lhs)
    is_singular = dets.abs() < 1e-8
    lhs[is_singular] = t.eye(3, 3)


    suv = t.linalg.solve(lhs, rhs)
    s, u, v = suv.unbind(dim=1)

    intersect_on_ray = s >= 0
    intersect_on_segment = (0 <= u) & (0 <= v) & (v + u <= 1)

    intersection: Bool[Tensor, "nrays"]= intersect_on_ray & intersect_on_segment & ~is_singular
    return intersection




A = t.tensor([1, 0.0, -0.5])
B = t.tensor([1, -0.5, 0.0])
C = t.tensor([1, 0.5, 0.5])
num_pixels_y = num_pixels_z = 30
y_limit = z_limit = 0.5

# Plot triangle & rays
test_triangle = t.stack([A, B, C], dim=0)
rays2d = make_rays_2d(num_pixels_y, num_pixels_z, y_limit, z_limit)
triangle_lines = t.stack([A, B, C, A, B, C], dim=0).reshape(-1, 2, 3)
render_lines_with_plotly(rays2d, triangle_lines)

# Calculate and display intersections
intersects = raytrace_triangle(rays2d, test_triangle)
img = intersects.reshape(num_pixels_y, num_pixels_z).int()
imshow(img, origin="lower", width=600, title="Triangle (as intersected by rays)")

In [37]:
def raytrace_triangle_with_bug(
    rays: Float[Tensor, "nrays rayPoints=2 dims=3"],
    triangle: Float[Tensor, "trianglePoints=3 dims=3"]
) -> Bool[Tensor, "nrays"]:
    '''
    For each ray, return True if the triangle intersects that ray.
    '''
    NR = rays.size(0)

    A, B, C = einops.repeat(triangle, "pts dims -> pts NR dims", NR=NR)

    O, D = rays.unbind(1)

    mat = t.stack([- D, B - A, C - A], dim=-1)

    dets = t.linalg.det(mat)
    is_singular = dets.abs() < 1e-8
    mat[is_singular] = t.eye(3)

    vec = O - A

    sol = t.linalg.solve(mat, vec)
    s, u, v = sol.unbind(dim=-1)

    return ((u >= 0) & (v >= 0) & (u + v <= 1) & ~is_singular)


intersects = raytrace_triangle_with_bug(rays2d, test_triangle)
img = intersects.reshape(num_pixels_y, num_pixels_z).int()
imshow(img, origin="lower", width=600, title="Triangle (as intersected by rays)")

In [38]:
with open(section_dir / "pikachu.pt", "rb") as f:
    triangles = t.load(f)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [63]:
@jaxtyped(typechecker=typeguard.typechecked)
def raytrace_mesh(
    rays: Float[Tensor, "nrays rayPoints=2 dims=3"],
    triangles: Float[Tensor, "ntriangles trianglePoints=3 dims=3"]
) -> Float[Tensor, "nrays"]:
    '''
    For each ray, return the distance to the closest intersecting triangle, or infinity.
    '''
    nrays = rays.size(0)
    ntriangles = triangles.size(0)
    rays = einops.repeat(rays, "nrays rpoints dims -> nrays ntriangles rpoints dims", ntriangles=ntriangles)
    # (nrays, ntriangles, dims)
    Os, Ds = rays.unbind(dim=2)

    triangles = einops.repeat(triangles, 'ntriangles tpoints dims -> nrays ntriangles tpoints dims', nrays=nrays)
    As, Bs, Cs = triangles.unbind(dim=2)

    tria_min_x = triangles[:, :, :, 0].min(2).values
    distances = tria_min_x - rays[:, :, 0, 0] 


    lhs_list = [-Ds, (Bs - As), (Cs - As)]
    lhs = einops.rearrange(lhs_list, 'cols nrays ntria dim -> nrays ntria dim cols')
    lhs = einops.repeat(lhs, 'a b c d -> 4 a b c d')

    rhs = Os - As # (nrays, ntriangles, dims)

    dets: Float[Tensor, "nrays ntria"] = t.linalg.det(lhs)
    is_singular = dets.abs() < 1e-8
    lhs[is_singular] = t.eye(3, 3)


    suv = t.linalg.solve(lhs, rhs)
    s, u, v = suv.unbind(dim=2)

    intersect_on_ray = s >= 0
    intersect_on_segment = (0 <= u) & (0 <= v) & (v + u <= 1)

    intersection: Bool[Tensor, "nrays ntria"]= intersect_on_ray & intersect_on_segment & ~is_singular
    
    distances[~intersection] = float("inf")
    min_distances = distances.min(dim=1).values


    return min_distances


num_pixels_y = 120
num_pixels_z = 120
y_limit = z_limit = 1

rays = make_rays_2d(num_pixels_y, num_pixels_z, y_limit, z_limit)
rays[:, 0] = t.tensor([-2, 0.0, 0.0])
dists = raytrace_mesh(rays, triangles)
intersects = t.isfinite(dists).view(num_pixels_y, num_pixels_z)
dists_square = dists.view(num_pixels_y, num_pixels_z)
img = t.stack([intersects, dists_square], dim=0)

fig = px.imshow(img, facet_col=0, origin="lower", color_continuous_scale="magma", width=1000)
fig.update_layout(coloraxis_showscale=False)
for i, text in enumerate(["Intersects", "Distance"]): 
    fig.layout.annotations[i]['text'] = text
fig.show()


Use of index_put_ on expanded tensors is deprecated. Please clone() the tensor before performing this operation. This also applies to advanced indexing e.g. tensor[indices] = tensor (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorAdvancedIndexing.cpp:719.)



RuntimeError: linalg.solve: Incompatible shapes of A and B for the equation AX = B (3x3 and 412x3)