In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class DifferentiablePolygonRenderer(nn.Module):
    def __init__(self, target_image, num_polygons=50, vertices_per_polygon=3, softness=1.0):
        """
        target_image: a torch.Tensor of shape (3,H,W) with pixel values in [0,1]
        num_polygons: number of polygons in the approximation
        vertices_per_polygon: how many vertices per polygon
        canvas_size: tuple, (H,W) dimensions of rendered image
        softness: controls the blur of the edge (larger gives softer edges)
        """
        super().__init__()
        self.target_image = target_image # should be float32 tensor (3, H, W)
        self.num_polygons = num_polygons
        self.vertices_per_polygon = vertices_per_polygon
        self.canvas_height, self.canvas_width = target_image.shape[1:]
        self.softness = softness
        self.frames = [] # to store uint8 frame images

        # Initialize polygon vertices: (num_polygons, vertices_per_polygon, 2)
        # Coordinates in normalized space [0,1] (later multiplied by canvas size)
        self.vertices = nn.Parameter(torch.rand(num_polygons, vertices_per_polygon, 2))
        # Colors for each polygon: (num_polygons, 3), each channel in [0,1]
        self.colors = nn.Parameter(torch.rand(num_polygons, 3))
        # Alpha (opacity) for each polygon: (num_polygons,)
        self.alpha = nn.Parameter(torch.rand(num_polygons))  # we'll use it for blending

    def forward(self):
        # Render the current polygons into an image
        rendered = self.render_polygons()
        # Compute loss (MSE) between rendered image and target image
        loss = F.mse_loss(rendered, self.target_image)
        # Convert rendered image to uint8 (0-255) and store in frames for visualization
        with torch.no_grad():
            # Clamp, and bring to CPU numpy format (3,H,W) with uint8 values
            img_uint8 = (rendered.clamp(0,1)*255).byte().cpu().numpy()
            # Transpose to H,W,3 if needed
            img_uint8 = np.transpose(img_uint8, (1,2,0))
            self.frames.append(img_uint8)
        return loss

    def render_polygons(self):
        """
        Render all polygons onto a canvas using a differentiable soft rasterizer.
        This implementation uses a soft mask for each polygon and blends over a white background.
        """
        device = self.vertices.device

        # Create a canvas grid
        y = torch.linspace(0, 1, self.canvas_height, device=device).view(-1, 1).expand(self.canvas_height, self.canvas_width)
        x = torch.linspace(0, 1, self.canvas_width, device=device).view(1, -1).expand(self.canvas_height, self.canvas_width)
        grid = torch.stack([x, y], dim=-1)  # shape (H, W, 2)

        # Prepare canvas for RGBA accumulation: start with white background.
        canvas = torch.ones(3, self.canvas_height, self.canvas_width, device=device)

        # For each polygon, compute soft mask and blend onto canvas
        # We can vectorize the computation over the polygons.

        # First, transform vertices to image coordinates
        poly_vertices = self.vertices.clone()  # shape: (num_polygons, vertices_per_polygon, 2)
        # They are normalized coordinates [0, 1]. No conversion needed since grid runs 0-1.

        # Compute soft mask for each polygon: 
        # We use the idea of inside-outside function: for each edge of the polygon, compute the distance of each grid pixel to the edge.
        # Then combine these distances with a sigmoid to produce a soft mask.
        # We assume the polygon is convex for simplicity.
        num_poly = self.num_polygons
        mask = torch.ones(num_poly, self.canvas_height, self.canvas_width, device=device)

        # vectorizing over vertices/edges
        for j in range(self.vertices_per_polygon):
            # current vertex index and next vertex index (with wrap-around)
            v0 = poly_vertices[:, j, :]  # (num_poly, 2)
            v1 = poly_vertices[:, (j+1) % self.vertices_per_polygon, :]  # (num_poly, 2)
            # edge vector from v0 to v1
            edge = v1 - v0  # (num_poly, 2)
            # vector from v0 to each grid point: grid shape (H,W,2) -> unsqueeze polygon dim: (num_poly, H, W, 2)
            vec = grid.unsqueeze(0) - v0.unsqueeze(1).unsqueeze(2)  # (num_poly, H, W, 2)
            # Compute perp dot product to determine which side of the line the pixel is on.
            # perp = (edge_y, -edge_x) for each edge
            perp = torch.stack([edge[:,1], -edge[:,0]], dim=1)  # (num_poly, 2)
            # Dot product for each polygon and each pixel:
            dot = (vec * perp.unsqueeze(1).unsqueeze(2)).sum(dim=-1)  # (num_poly, H, W)
            # Soft mask: for pixels inside, dot should be >= 0; use sigmoid for differentiability
            edge_mask = torch.sigmoid(dot * self.softness)
            # Combine: pixel is inside polygon if it is inside all edges
            mask = mask * edge_mask  # (num_poly, H, W)

        # Now, blend each polygon on canvas.
        # We assume polygons are independent and apply alpha compositing.
        # For each polygon, its contribution is: color * alpha * mask
        # And the remaining canvas is: canvas*(1 - alpha*mask)
        for i in range(num_poly):
            p_mask = mask[i]  # shape (H, W)
            a = torch.sigmoid(self.alpha[i])  # ensure in (0,1)
            c = self.colors[i].view(3, 1, 1)  # shape (3,1,1)
            # Alpha blending: new_pixel = p_mask*a*c + (1-p_mask*a)*old_pixel
            canvas = p_mask * a * c + (1 - p_mask * a) * canvas

        return canvas


In [27]:
from visualbench.utils import to_float_hw3_tensor
from myai.transforms import znormalize, normalize

target = znormalize(to_float_hw3_tensor("/var/mnt/ssd/Файлы/Изображения/Сохраненное/sanic.jpg").moveaxis(-1, 0)[:,::3,::3])

# Instantiate model
model = DifferentiablePolygonRenderer(target, num_polygons=50, vertices_per_polygon=3, softness=50.0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# For demonstration purposes, run a few iterations
num_iterations = 300
for i in range(num_iterations):
    optimizer.zero_grad()
    loss = model()
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
        print(f"Iteration {i+1}, Loss: {loss.item()}")

# After training, you can visualize the final rendered image or any intermediate frames.
# For example, to display the last frame:
try:
    from PIL import Image
    final_frame = model.frames[-1]
    im = Image.fromarray(final_frame)
    im.show()
except ImportError:
    print("PIL is not installed; cannot display final frame.")

# Optionally, save the frames as an animation using your preferred tool.

Iteration 1, Loss: 1.5427018404006958
Iteration 11, Loss: 0.35043835639953613
Iteration 21, Loss: 0.28057435154914856
Iteration 31, Loss: 0.21504658460617065
Iteration 41, Loss: 0.20676743984222412
Iteration 51, Loss: 0.19344347715377808
Iteration 61, Loss: 0.18707914650440216
Iteration 71, Loss: 0.18119540810585022
Iteration 81, Loss: 0.17480488121509552
Iteration 91, Loss: 0.1724458485841751
Iteration 101, Loss: 0.1666036695241928
Iteration 111, Loss: 0.16317059099674225
Iteration 121, Loss: 0.16104786098003387
Iteration 131, Loss: 0.17182639241218567
Iteration 141, Loss: 0.15784022212028503
Iteration 151, Loss: 0.14826609194278717
Iteration 161, Loss: 0.15001191198825836
Iteration 171, Loss: 0.16519232094287872
Iteration 181, Loss: 0.14698506891727448
Iteration 191, Loss: 0.14287249743938446
Iteration 201, Loss: 0.13931120932102203
Iteration 211, Loss: 0.13787956535816193
Iteration 221, Loss: 0.13854244351387024
Iteration 231, Loss: 0.13744333386421204
Iteration 241, Loss: 0.1347547

In [28]:
from myai.video import render_frames
render_frames('poly', model.frames)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PolygonOptimizer(nn.Module):
    def __init__(self, target_image, num_polygons, num_vertices, tau=0.1, beta=10.0):
        super().__init__()
        self.register_buffer('target_image', target_image)
        self.num_polygons = num_polygons
        self.num_vertices = num_vertices
        self.tau = tau
        self.beta = beta
        
        # Assuming target_image is (3, H, W)
        _, self.H, self.W = target_image.shape
        
        # Initialize vertices parameters in logit space (will be passed through sigmoid)
        self.vertices_logit = nn.Parameter(torch.randn(num_polygons, num_vertices, 2))
        
        # Initialize color parameters in logit space (will be passed through sigmoid)
        self.colors_logit = nn.Parameter(torch.randn(num_polygons, 4))
        
        self.frames = []
    
    def forward(self):
        # Render the current image
        current_image = self.render_polygons()
        
        # Compute loss against target image
        loss = F.mse_loss(current_image, self.target_image)
        
        # Convert to uint8 and append to frames (detach and no grad)
        with torch.no_grad():
            current_image_uint8 = (current_image.clamp(0.0, 1.0) * 255).byte()
            self.frames.append(current_image_uint8.cpu())
        
        return loss
    
    def render_polygons(self):
        # Get vertices in [0, 1] with sigmoid
        vertices = torch.sigmoid(self.vertices_logit)  # (N, V, 2)
        N, V, _ = vertices.shape
        
        # Create pixel grid in [0, 1]
        y_coords = torch.linspace(0, 1, self.H, device=vertices.device)
        x_coords = torch.linspace(0, 1, self.W, device=vertices.device)
        grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing='ij')
        pixels = torch.stack([grid_x, grid_y], dim=-1)  # (H, W, 2)
        
        # Reshape for broadcasting
        pixels = pixels.unsqueeze(2).unsqueeze(2)  # (H, W, 1, 1, 2)
        vertices = vertices.unsqueeze(0).unsqueeze(0)  # (1, 1, N, V, 2)
        
        # Compute edges (v1, v2)
        v1 = vertices
        v2 = torch.roll(vertices, shifts=-1, dims=3)
        edge_vec = v2 - v1  # (1, 1, N, V, 2)
        
        # Compute edge lengths
        edge_length = torch.norm(edge_vec, dim=-1, keepdim=True) + 1e-8  # (1, 1, N, V, 1)
        
        # Compute cross product for each edge and pixel
        dx = pixels[..., 0] - v1[..., 0]  # (H, W, N, V)
        dy = pixels[..., 1] - v1[..., 1]
        cross = dx * edge_vec[..., 1] - dy * edge_vec[..., 0]  # (H, W, N, V)
        signed_distance = cross / edge_length.squeeze(-1)  # (H, W, N, V)
        
        # Compute smooth min distance per polygon (tau is temperature)
        smooth_min_dist = -self.tau * torch.logsumexp(-signed_distance / self.tau, dim=-1)  # (H, W, N)
        
        # Compute mask using sigmoid (beta controls sharpness)
        mask = torch.sigmoid(smooth_min_dist * self.beta)  # (H, W, N)
        
        # Get colors (RGBA) and apply sigmoid
        colors = torch.sigmoid(self.colors_logit)  # (N, 4)
        rgb = colors[:, :3]  # (N, 3)
        a = colors[:, 3]  # (N,)
        
        # Compute alpha (mask * a) and premultiplied RGB
        alpha = mask * a.view(1, 1, N)  # (H, W, N)
        premultiplied_rgb = rgb.view(1, 1, N, 3) * alpha.unsqueeze(-1)  # (H, W, N, 3)
        
        # Compute transmittance for compositing
        one_minus_alpha = 1 - alpha  # (H, W, N)
        transmittance = torch.cumprod(one_minus_alpha, dim=2)  # (H, W, N)
        transmittance = torch.roll(transmittance, shifts=1, dims=2)
        transmittance[:, :, 0] = 1.0  # First polygon has full transmittance
        
        # Compute final image
        contribution = premultiplied_rgb * transmittance.unsqueeze(-1)  # (H, W, N, 3)
        final_image = torch.sum(contribution, dim=2)  # (H, W, 3)
        
        # Permute to (3, H, W) and clamp
        final_image = final_image.permute(2, 0, 1).clamp(0.0, 1.0)
        
        return final_image

In [7]:
from visualbench.utils import to_float_hw3_tensor
from myai.transforms import znormalize, normalize

target = znormalize(to_float_hw3_tensor("/var/mnt/ssd/Файлы/Изображения/Сохраненное/sanic.jpg").moveaxis(-1, 0)[:,::2,::2])

# Instantiate model
model = PolygonOptimizer(target, num_polygons=50, num_vertices=3, beta=200)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# For demonstration purposes, run a few iterations
num_iterations = 300
for i in range(num_iterations):
    optimizer.zero_grad()
    loss = model()
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
        print(f"Iteration {i+1}, Loss: {loss.item()}")

from myai.video import render_frames
render_frames('poly2', model.frames)

Iteration 1, Loss: 0.9926592111587524
Iteration 11, Loss: 0.868086576461792
Iteration 21, Loss: 0.7758312225341797
Iteration 31, Loss: 0.733272135257721
Iteration 41, Loss: 0.7123379707336426
Iteration 51, Loss: 0.6996484994888306
Iteration 61, Loss: 0.6917396187782288
Iteration 71, Loss: 0.684929609298706
Iteration 81, Loss: 0.6799998879432678
Iteration 91, Loss: 0.6778824925422668
Iteration 101, Loss: 0.6761957406997681
Iteration 111, Loss: 0.6750072836875916
Iteration 121, Loss: 0.6742349863052368
Iteration 131, Loss: 0.6737663149833679
Iteration 141, Loss: 0.673454225063324
Iteration 151, Loss: 0.6731721758842468
Iteration 161, Loss: 0.6729599833488464
Iteration 171, Loss: 0.6727607846260071
Iteration 181, Loss: 0.6725674271583557
Iteration 191, Loss: 0.6723613739013672
Iteration 201, Loss: 0.6720983982086182
Iteration 211, Loss: 0.6717106699943542
Iteration 221, Loss: 0.671209990978241
Iteration 231, Loss: 0.6710312962532043
Iteration 241, Loss: 0.6709325909614563
Iteration 251, L