In [1]:
import sys
import os

sys.path.append("../")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np

class DeformableGrid3D(nn.Module):
    def __init__(self, grid_size=(16, 16, 16), deformation_scale=1.0):
        super().__init__()
        self.grid_size = grid_size
        self.deformation_scale = deformation_scale
        
        x = torch.linspace(-1, 1, grid_size[0])
        y = torch.linspace(-1, 1, grid_size[1])
        z = torch.linspace(-1, 1, grid_size[2])
        grid_x, grid_y, grid_z = torch.meshgrid(x, y, z, indexing='ij')
        self.base_grid = torch.stack([grid_x, grid_y, grid_z], dim=-1)
        
        self.deformation = nn.Parameter(torch.zeros(*grid_size, 3))
        self.transform_matrix = nn.Parameter(torch.eye(3))
        self.transform_bias = nn.Parameter(torch.zeros(3))
        # Agent parameters
        self.agent_pos = np.array([grid_size[0]//2, grid_size[1]//2, grid_size[2]//2])
        self.move_speed = 1
        
        # Visualization parameters
        self.rotation = np.array([0.3, 0.3, 0.])
        self.translation = np.array([0., 0., -10.])
        self.scale = 1.0
        self.last_mouse = None
        self.window_name = 'Deformable Grid 3D'

    def move_agent(self, direction):
        """Move agent in one of six directions: up, down, left, right, forward, back"""
        direction_map = {
            'up': np.array([0, 1, 0]),
            'down': np.array([0, -1, 0]),
            'left': np.array([-1, 0, 0]),
            'right': np.array([1, 0, 0]),
            'forward': np.array([0, 0, 1]),
            'back': np.array([0, 0, -1])
        }
        
        if direction in direction_map:
            new_pos = self.agent_pos + direction_map[direction] * self.move_speed
            # Check bounds
            for i in range(3):
                new_pos[i] = np.clip(new_pos[i], 0, self.grid_size[i] - 1)
            self.agent_pos = new_pos
    
    def apply_affine_transform(self, points):
        """Apply the current affine transformation to points"""
        return torch.matmul(points, self.transform_matrix.T) + self.transform_bias

    def set_transform(self, matrix=None, bias=None):
        """Set the affine transformation parameters"""
        if matrix is not None:
            self.transform_matrix.data = matrix
        if bias is not None:
            self.transform_bias.data = bias

    def get_deformed_grid(self):
        """Get grid with both local deformation and global affine transform"""
        local_deform = self.base_grid + self.deformation_scale * self.deformation
        return self.apply_affine_transform(local_deform)
        
    def apply_force(self, position, force, radius=0.2):
        grid_pos = (position + 1) * torch.tensor(self.grid_size).float() / 2
        distances = torch.zeros_like(self.base_grid[..., 0])
        
        for i in range(self.grid_size[0]):
            for j in range(self.grid_size[1]):
                for k in range(self.grid_size[2]):
                    point = torch.tensor([i, j, k]).float()
                    distances[i, j, k] = torch.norm(point - grid_pos)
        
        mask = (distances < radius * max(self.grid_size)).float()
        mask = mask.unsqueeze(-1).repeat(1, 1, 1, 3)
        force_field = force * mask * (1 - distances.unsqueeze(-1) / (radius * max(self.grid_size)))
        self.deformation.data += force_field
   
    def apply_inflation(self, center, radius, strength=1.0, deflate=False):
        """
        Apply inflation/deflation at specified center point
        center: (x,y,z) coordinates in grid space
        radius: affect radius
        strength: deformation magnitude
        deflate: if True, creates inward deformation instead of outward
        """
        grid_points = self.base_grid
        
        # Calculate distances from center
        distances = torch.norm(grid_points - torch.tensor(center), dim=-1)
        
        # Create smooth falloff based on radius
        falloff = torch.exp(-distances**2 / (2 * radius**2))
        
        # Calculate displacement directions (normalized vectors from center)
        directions = grid_points - torch.tensor(center)
        directions = directions / (torch.norm(directions, dim=-1, keepdim=True) + 1e-6)
        
        # Apply displacement
        factor = -1 if deflate else 1
        displacement = factor * strength * falloff.unsqueeze(-1) * directions
        
        self.deformation.data += displacement

    def reset(self):
        self.deformation.data.zero_()
        self.transform_matrix.data = torch.eye(3)
        self.transform_bias.data.zero_()
        
    def mouse_callback(self, event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            self.last_mouse = (x, y)
        elif event == cv2.EVENT_MOUSEMOVE and flags & cv2.EVENT_FLAG_LBUTTON:
            if self.last_mouse is not None:
                dx = x - self.last_mouse[0]
                dy = y - self.last_mouse[1]
                self.rotation[1] += dx * 0.01
                self.rotation[0] += dy * 0.01
                self.last_mouse = (x, y)
        elif event == cv2.EVENT_MOUSEWHEEL:
            if flags > 0:
                self.scale *= 1.1
            else:
                self.scale /= 1.1
        elif event == cv2.EVENT_LBUTTONUP:
            self.last_mouse = None
            
    def project_points(self, points):
        if isinstance(points, torch.Tensor):
            points = points.detach().numpy()
        
        Rx = np.array([[1, 0, 0],
                      [0, np.cos(self.rotation[0]), -np.sin(self.rotation[0])],
                      [0, np.sin(self.rotation[0]), np.cos(self.rotation[0])]])
        
        Ry = np.array([[np.cos(self.rotation[1]), 0, np.sin(self.rotation[1])],
                      [0, 1, 0],
                      [-np.sin(self.rotation[1]), 0, np.cos(self.rotation[1])]])
        
        Rz = np.array([[np.cos(self.rotation[2]), -np.sin(self.rotation[2]), 0],
                      [np.sin(self.rotation[2]), np.cos(self.rotation[2]), 0],
                      [0, 0, 1]])
        
        R = Rz @ Ry @ Rx
        
        points = points @ R.T
        points = points * self.scale
        points = points + self.translation
        
        f = 1000
        points_2d = points[..., :2] * f / points[..., 2:3]
        
        return points_2d, points[..., 2]

    def render_interactive(self):
        cv2.namedWindow(self.window_name)
        cv2.setMouseCallback(self.window_name, self.mouse_callback)
        
        while True:
            frame = np.zeros((800, 800, 3), dtype=np.uint8)
            
            deformed = self.get_deformed_grid()
            points_2d, depths = self.project_points(deformed)
            
            points_2d = (points_2d * 200 + np.array([400, 400])).astype(int)
            
            # Draw grid lines with depth sorting
            lines = []
            for i in range(self.grid_size[0]):
                for j in range(self.grid_size[1]):
                    for k in range(self.grid_size[2]-1):
                        pt1 = points_2d[i, j, k]
                        pt2 = points_2d[i, j, k+1]
                        avg_depth = (depths[i, j, k] + depths[i, j, k+1]) / 2
                        lines.append((pt1, pt2, avg_depth))
            
            for i in range(self.grid_size[0]):
                for k in range(self.grid_size[2]):
                    for j in range(self.grid_size[1]-1):
                        pt1 = points_2d[i, j, k]
                        pt2 = points_2d[i, j+1, k]
                        avg_depth = (depths[i, j, k] + depths[i, j+1, k]) / 2
                        lines.append((pt1, pt2, avg_depth))
            
            for j in range(self.grid_size[1]):
                for k in range(self.grid_size[2]):
                    for i in range(self.grid_size[0]-1):
                        pt1 = points_2d[i, j, k]
                        pt2 = points_2d[i+1, j, k]
                        avg_depth = (depths[i, j, k] + depths[i+1, j, k]) / 2
                        lines.append((pt1, pt2, avg_depth))
            
            # Draw grid
            lines.sort(key=lambda x: x[2], reverse=True)
            for pt1, pt2, _ in lines:
                cv2.line(frame, tuple(pt1), tuple(pt2), (0, 255, 0), 1)
            
            # Draw agent
            agent_pos_grid = self.base_grid[
                int(self.agent_pos[0]),
                int(self.agent_pos[1]),
                int(self.agent_pos[2])
            ].numpy()
            agent_pos_2d, agent_depth = self.project_points(agent_pos_grid)
            agent_screen_pos = (agent_pos_2d * 200 + np.array([400, 400])).astype(int)
            cv2.circle(frame, tuple(agent_screen_pos), 5, (255, 0, 0), -1)
            
            # Controls
            cv2.putText(frame, "Mouse: rotate/zoom", (10, 30), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.putText(frame, "WASD: move in XY plane", (10, 60),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.putText(frame, "QE: move in Z axis", (10, 90),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.putText(frame, "ESC: exit", (10, 120),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            
            cv2.imshow(self.window_name, frame)
            
            key = cv2.waitKey(1) & 0xFF
            if key == 27:  # ESC
                break
            elif key == ord('w'):
                self.move_agent('up')
            elif key == ord('s'):
                self.move_agent('down')
            elif key == ord('a'):
                self.move_agent('left')
            elif key == ord('d'):
                self.move_agent('right')
            elif key == ord('q'):
                self.move_agent('back')
            elif key == ord('e'):
                self.move_agent('forward')

        
        cv2.destroyAllWindows()
    

In [3]:
grid = DeformableGrid3D()

# Create outward bulge at center
grid.apply_inflation(center=[0,0,0], radius=0.5, strength=0.3)

# Create inward dent
grid.apply_inflation(center=[0.5,0.5,0.5], radius=0.3, strength=0.2, deflate=True)

grid.render_interactive()

In [5]:
grid = DeformableGrid3D()

# Set a shear transformation
shear = torch.tensor([
    [1.0, 0.5, 0.0],
    [0.0, 2.0, 0.0],
    [0.0, 0.0, 1.0]
], dtype=torch.float32)
grid.set_transform(matrix=shear)

grid.render_interactive()

In [29]:
grid = DeformableGrid3D()
grid.apply_force(torch.tensor([0.0, 0.0, 0.0]), torch.tensor([0.1, 0.0, 0.0]))
grid.render_interactive()  # Left-click and drag to rotate, ESC to exit

In [38]:
from functools import wraps
from time import time
import numpy as np
import random

def timing(f):
    @wraps(f)
    def wrap(*args, **kw):
        ts = time()
        result = f(*args, **kw)
        te = time()
        print('func:%r args:[%r, %r] took: %2.4f sec' % \
          (f.__name__, args, kw, te-ts))
        return result
    return wrap

@timing
def time_env(env,num_steps):
    env.reset()
    
    for i in range(num_steps):
        action = random.choice(['up', 'down', 'left', 'right', 'forward', 'back'])
        grid.move_agent(action)


In [39]:
env = DeformableGrid3D()
time_env(env,1000000)

func:'time_env' args:[(DeformableGrid3D(), 1000000), {}] took: 22.6228 sec


In [30]:
import random
grid = DeformableGrid3D()

def f():
    for i in range(100000):
        action = random.choice(['up', 'down', 'left', 'right', 'forward', 'back'])
        grid.move_agent(action)

%timeit f()

2.27 s ± 32.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np

class DeformableGrid3D(nn.Module):
    def __init__(self, grid_size=(16, 16, 16), deformation_scale=1.0):
        super().__init__()
        self.grid_size = grid_size
        self.deformation_scale = deformation_scale

        x = torch.linspace(-1, 1, grid_size[0])
        y = torch.linspace(-1, 1, grid_size[1])
        z = torch.linspace(-1, 1, grid_size[2])
        grid_x, grid_y, grid_z = torch.meshgrid(x, y, z, indexing='ij')
        self.base_grid = torch.stack([grid_x, grid_y, grid_z], dim=-1)

        self.deformation = nn.Parameter(torch.zeros(*grid_size, 3))
        self.transform_matrix = nn.Parameter(torch.eye(3))
        self.transform_bias = nn.Parameter(torch.zeros(3))

        self.agent_pos = np.array([grid_size[0]//2, grid_size[1]//2, grid_size[2]//2])
        self.agent_direction = np.array([1.0, 0.0, 0.0])
        self.agent_up = np.array([0.0, 1.0, 0.0])

    def apply_inflation(self, center, radius, strength=1.0, deflate=False):
        grid_points = self.base_grid
        distances = torch.norm(grid_points - torch.tensor(center), dim=-1)
        falloff = torch.exp(-distances**2 / (2 * radius**2))
        directions = grid_points - torch.tensor(center)
        directions = directions / (torch.norm(directions, dim=-1, keepdim=True) + 1e-6)
        factor = -1 if deflate else 1
        displacement = factor * strength * falloff.unsqueeze(-1) * directions
        self.deformation.data += displacement

    def apply_affine_transform(self, points):
        return torch.matmul(points, self.transform_matrix.T) + self.transform_bias

    def get_deformed_grid(self):
        local_deform = self.base_grid + self.deformation_scale * self.deformation
        return self.apply_affine_transform(local_deform)

    def move_agent(self, direction):
        direction_map = {
            'up': np.array([0, 1, 0]),
            'down': np.array([0, -1, 0]),
            'left': np.array([-1, 0, 0]),
            'right': np.array([1, 0, 0]),
            'forward': np.array([0, 0, 1]),
            'back': np.array([0, 0, -1])
        }

        if direction in direction_map:
            new_pos = self.agent_pos + direction_map[direction]
            for i in range(3):
                new_pos[i] = np.clip(new_pos[i], 0, self.grid_size[i] - 1)
            self.agent_pos = new_pos

    def rotate_agent(self, angle_x=0, angle_y=0):
        Rx = np.array([[1, 0, 0],
                      [0, np.cos(angle_x), -np.sin(angle_x)],
                      [0, np.sin(angle_x), np.cos(angle_x)]])

        Ry = np.array([[np.cos(angle_y), 0, np.sin(angle_y)],
                      [0, 1, 0],
                      [-np.sin(angle_y), 0, np.cos(angle_y)]])

        R = Ry @ Rx
        self.agent_direction = R @ self.agent_direction
        self.agent_up = R @ self.agent_up

    def render(self):
        frame = np.zeros((800, 800, 3), dtype=np.uint8)
        points_2d = self.base_grid[..., :2].reshape(-1, 2).numpy()
        points_2d = (points_2d * 200 + np.array([400, 400])).astype(int)

        for point in points_2d:
            cv2.circle(frame, tuple(point), 2, (0, 255, 0), -1)

        agent_pos = self.base_grid[int(self.agent_pos[0]), int(self.agent_pos[1]), int(self.agent_pos[2])].numpy()
        agent_screen_pos = (agent_pos[:2] * 200 + np.array([400, 400])).astype(int)

        cv2.circle(frame, tuple(agent_screen_pos), 10, (255, 0, 0), -1)

        direction_end = agent_pos[:2] + self.agent_direction[:2] * 0.5
        direction_screen_pos = (direction_end * 200 + np.array([400, 400])).astype(int)
        cv2.line(frame, tuple(agent_screen_pos), tuple(direction_screen_pos), (255, 0, 0), 2)

        up_end = agent_pos[:2] + self.agent_up[:2] * 0.3
        up_screen_pos = (up_end * 200 + np.array([400, 400])).astype(int)
        cv2.line(frame, tuple(agent_screen_pos), tuple(up_screen_pos), (0, 255, 0), 2)

        cv2.imshow('Deformable Grid 3D', frame)
        cv2.waitKey(1)

    def render_interactive(self):
        cv2.namedWindow('Deformable Grid 3D Interactive')
        
        while True:
            frame = np.zeros((800, 800, 3), dtype=np.uint8)

            points_2d = self.base_grid[..., :2].reshape(-1, 2).numpy()
            points_2d = (points_2d * 200 + np.array([400, 400])).astype(int)

            for point in points_2d:
                cv2.circle(frame, tuple(point), 2, (0, 255, 0), -1)

            agent_pos = self.base_grid[int(self.agent_pos[0]), int(self.agent_pos[1]), int(self.agent_pos[2])].numpy()
            agent_screen_pos = (agent_pos[:2] * 200 + np.array([400, 400])).astype(int)

            cv2.circle(frame, tuple(agent_screen_pos), 10, (255, 0, 0), -1)

            direction_end = agent_pos[:2] + self.agent_direction[:2] * 0.5
            direction_screen_pos = (direction_end * 200 + np.array([400, 400])).astype(int)
            cv2.line(frame, tuple(agent_screen_pos), tuple(direction_screen_pos), (255, 0, 0), 2)

            up_end = agent_pos[:2] + self.agent_up[:2] * 0.3
            up_screen_pos = (up_end * 200 + np.array([400, 400])).astype(int)
            cv2.line(frame, tuple(agent_screen_pos), tuple(up_screen_pos), (0, 255, 0), 2)

            controls = [
                "WASD: move in XY plane",
                "QE: move in Z axis",
                "Arrow Keys: rotate agent",
                "ESC: Exit"
            ]

            for i, text in enumerate(controls):
                cv2.putText(frame, text, (10, 30 + i * 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

            cv2.imshow('Deformable Grid 3D Interactive', frame)

            key = cv2.waitKey(1) & 0xFF
            if key == 27:  # ESC
                break
            elif key == ord('w'):
                self.move_agent('up')
            elif key == ord('s'):
                self.move_agent('down')
            elif key == ord('a'):
                self.move_agent('left')
            elif key == ord('d'):
                self.move_agent('right')
            elif key == ord('q'):
                self.move_agent('back')
            elif key == ord('e'):
                self.move_agent('forward')
            elif key == 82:  # Up arrow
                self.rotate_agent(angle_x=-0.1)
            elif key == 84:  # Down arrow
                self.rotate_agent(angle_x=0.1)
            elif key == 81:  # Left arrow
                self.rotate_agent(angle_y=-0.1)
            elif key == 83:  # Right arrow
                self.rotate_agent(angle_y=0.1)

        cv2.destroyAllWindows()

# Example usage
grid = DeformableGrid3D()
grid.render_interactive()
