In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from PIL import Image, ImageDraw


class SpaceFillingCurve(nn.Module):
    """Learnable 2-D polyline optimised to be long without coming close to itself.

    The curve is ``num_points`` control points stored as an ``nn.Parameter`` of
    shape ``(num_points, 2)`` holding ``(x, y)`` pixel coordinates. Calling the
    module returns a scalar loss (negative total length plus a weighted
    near-contact penalty). As a side effect, it appends a rendered
    ``(canvas_height, canvas_width)`` uint8 frame of the current curve to
    ``self.frames``.

    Args:
        num_points: Number of control points on the polyline.
        canvas_width: Width in pixels used for the initial layout and for frames.
        canvas_height: Height in pixels used for the initial layout and for frames.
        pen_width: Line width (pixels) used when rendering frames.
        intersection_thresh: Distance below which a pair of non-adjacent
            segments contributes to the penalty.
        device: Device on which the initial control points are created.
        penalty_weight: Multiplier applied to the summed near-contact penalty
            in the loss. The default 1000.0 is the previously hard-coded value
            and may need tuning.
    """

    def __init__(
        self,
        num_points=100,
        canvas_width=256,
        canvas_height=256,
        pen_width=2,
        intersection_thresh=4.0,
        device="cpu",
        penalty_weight=1000.0,
    ):
        super().__init__()
        self.num_points = num_points
        self.canvas_width = canvas_width
        self.canvas_height = canvas_height
        self.pen_width = pen_width
        self.intersection_thresh = intersection_thresh
        self.penalty_weight = penalty_weight
        self.device = device

        # One rendered uint8 frame is appended here per forward() call.
        self.frames = []

        # Start as a horizontal line through the vertical centre, with x spanning
        # [margin, width - margin]. The tiny random jitter presumably breaks the
        # exact collinearity of the initial points.
        margin = 20
        xs = torch.linspace(margin, canvas_width - margin, num_points, device=device)
        xs += torch.randn_like(xs) * 0.01
        ys = torch.full_like(xs, canvas_height // 2, device=device)
        ys += torch.randn_like(ys) * 0.01
        # Learnable control points, shape (num_points, 2).
        self.points = nn.Parameter(torch.stack([xs, ys], dim=1))

    def forward(self):
        """Compute the loss for the current curve and record a rendered frame.

        Returns:
            Scalar tensor ``-total_length + penalty_weight * penalty``.

            ``penalty`` sums ``relu(intersection_thresh - d)`` over every segment
            pair ``(i, j)`` with ``j >= i + 2``. Here ``d`` is the distance from the
            midpoint of segment ``j`` to segment ``i``. This is a cheap
            differentiable proxy for self-intersection, not an exact test: two
            segments can still cross where the crossing is far from the
            midpoint of segment ``j``.
        """
        pts = self.points
        total_length, penalty = self._length_and_penalty(pts)
        loss = -total_length + penalty * self.penalty_weight
        self._record_frame(pts)
        return loss

    def _length_and_penalty(self, pts):
        """Return ``(total polyline length, summed near-contact penalty)`` for ``pts``."""
        seg_vecs = pts[1:] - pts[:-1]                      # (S, 2), S = number of segments
        seg_lengths = torch.linalg.norm(seg_vecs, dim=1)   # (S,)
        total_length = seg_lengths.sum()

        num_segs = seg_vecs.shape[0]
        # Every pair (i, j) with j >= i + 2. This skips a segment itself and its
        # successor, which share an endpoint by construction. With fewer than
        # three segments the index set is empty and the penalty is 0.
        ii, jj = torch.triu_indices(num_segs, num_segs, offset=2, device=pts.device)

        p0 = pts[ii]                           # start of segment i, (P, 2)
        seg_vec1 = seg_vecs[ii]                # direction of segment i, (P, 2)
        seg_len1 = seg_lengths[ii] + 1e-6      # eps guards against division by zero
        mid_q = (pts[jj] + pts[jj + 1]) / 2.0  # midpoint of segment j, (P, 2)

        # Project the midpoint onto segment i. Clamping t to [0, 1] gives the
        # distance to the segment rather than to the infinite line.
        t = ((mid_q - p0) * seg_vec1).sum(dim=1) / (seg_len1 * seg_len1)
        t = t.clamp(0.0, 1.0)
        closest_on_seg1 = p0 + t.unsqueeze(1) * seg_vec1
        dist = torch.linalg.norm(mid_q - closest_on_seg1, dim=1)

        # Hinge penalty: only pairs closer than the threshold contribute.
        penalty = F.relu(self.intersection_thresh - dist).sum()
        return total_length, penalty

    def _record_frame(self, pts):
        """Draw the polyline in white on a black 'L' image and append it to ``self.frames``."""
        image = Image.new('L', (self.canvas_width, self.canvas_height), color=0)
        # ImageDraw.line accepts a flat [x0, y0, x1, y1, ...] sequence. The float
        # coordinates are passed as-is, with no int conversion.
        flat_xy: list[float] = pts.detach().cpu().numpy().ravel().tolist()  # type:ignore
        ImageDraw.Draw(image).line(flat_xy, fill=255, width=self.pen_width)
        self.frames.append(np.array(image, dtype=np.uint8))

In [36]:
# Optimise the curve with Adam on (-length + weighted near-contact penalty).
# Each model() call also appends a rendered frame to model.frames.
SEED = 0
NUM_STEPS = 300
LOG_EVERY = 5
LEARNING_RATE = 1

# The control-point initialisation draws from torch.randn_like, so seed first
# to make the run reproducible.
torch.manual_seed(SEED)

device = 'cpu'
model = SpaceFillingCurve(num_points=50, canvas_width=256, canvas_height=256, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# `step` rather than `iter`, to avoid shadowing the builtin.
for step in range(NUM_STEPS):
    optimizer.zero_grad()
    loss = model()
    loss.backward()
    optimizer.step()
    if step % LOG_EVERY == 0:
        print(f"Iteration {step}, Loss: {loss.item()}")



Iteration 0, Loss: -216.00448608398438
Iteration 5, Loss: -454.9916687011719
Iteration 10, Loss: -740.27880859375
Iteration 15, Loss: -937.346923828125
Iteration 20, Loss: -843.263671875
Iteration 25, Loss: -1209.4986572265625
Iteration 30, Loss: -1294.6658935546875
Iteration 35, Loss: -1349.624755859375
Iteration 40, Loss: -1386.7459716796875
Iteration 45, Loss: -1467.969482421875
Iteration 50, Loss: -1528.9031982421875
Iteration 55, Loss: -1592.05908203125
Iteration 60, Loss: -1657.259765625
Iteration 65, Loss: -1724.446533203125
Iteration 70, Loss: -1778.10546875
Iteration 75, Loss: -1832.050048828125
Iteration 80, Loss: -1889.0452880859375
Iteration 85, Loss: -1948.328125
Iteration 90, Loss: -2009.464111328125
Iteration 95, Loss: -2072.1640625
Iteration 100, Loss: -2136.204345703125
Iteration 105, Loss: -2198.819091796875
Iteration 110, Loss: -2262.763671875
Iteration 115, Loss: -2202.33447265625
Iteration 120, Loss: -1740.2322998046875
Iteration 125, Loss: -2460.8642578125
Iterati

In [37]:
# Turn the frames recorded by model() during optimisation into output named 'fill'.
from myai.video import render_frames

recorded_frames = model.frames
render_frames('fill', recorded_frames)