In [83]:
import os
import torch
import numpy as np
import re
from pytorch3d.io import load_obj,load_objs_as_meshes
from pytorch3d.renderer import (
    PerspectiveCameras,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    HardPhongShader,
    PointLights,
    look_at_view_transform,
    TexturesVertex,
)

from pytorch3d.structures import join_meshes_as_scene, Meshes
from PIL import Image
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from pytorch3d.renderer import TexturesVertex


In [84]:
def create_checkerboard_texture(verts, device, frequency=8):
    """
    Create a black and white checkerboard texture pattern.
    :param verts: Vertices of the mesh
    :param device: Device for the tensor (CPU or GPU)
    :param frequency: Frequency of the checkerboard pattern
    :return: Checkerboard texture tensor
    """
    # Extract X and Z coordinates and normalize them
    xz_coords = verts[:, [0, 2]]  # Use the X and Z coordinates for the checkerboard pattern
    xz_coords = xz_coords - xz_coords.min(dim=0, keepdim=True)[0]  # Shift to start from 0
    xz_coords = xz_coords / xz_coords.max(dim=0, keepdim=True)[0]  # Scale to [0, 1]

    # Create the checkerboard pattern
    pattern = (torch.floor(xz_coords[:, 0] * frequency) + torch.floor(xz_coords[:, 1] * frequency)) % 2
    checkerboard_texture = torch.where(
        pattern.unsqueeze(1) == 0,
        torch.tensor([1.0, 1.0, 1.0], device=device).unsqueeze(0),  # White color
        torch.tensor([0.0, 0.0, 0.0], device=device).unsqueeze(0)   # Black color
    ).expand(-1, 3)

    return checkerboard_texture

def ensure_vertex_texture(mesh):
    if not isinstance(mesh.textures, TexturesVertex):
        verts = mesh.verts_list()[0]
        checkerboard_texture = create_checkerboard_texture(verts, device)
        mesh.textures = TexturesVertex(verts_features=[checkerboard_texture])
    return mesh


In [85]:
class MeshLoader:
    def __init__(self, device, obj_folder):
        self.device = device
        self.obj_folder = obj_folder

    def load_mesh(self, filename):
        # Load the mesh with material information
        obj_path = os.path.join(self.obj_folder, filename)
        verts, faces, aux = load_obj(obj_path, device=self.device)

        faces_idx = faces.verts_idx

        # Create per-vertex colors based on Kd from mtl file
        verts_rgb = torch.ones_like(verts)

        if aux.material_colors:
            for mat_name, mat_props in aux.material_colors.items():
                if 'diffuse_color' in mat_props:
                    # Use 'diffuse_color' for vertex colors
                    diffuse_color = mat_props['diffuse_color']
                    verts_rgb = diffuse_color.unsqueeze(0).expand(verts.shape)
                    break  # Use the first material with 'diffuse_color'

        textures = TexturesVertex(verts_features=[verts_rgb])

        mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures)
        return mesh


In [86]:
def parse_position_line(line):
    # Line format: Ball 'SphereX' Position: (x, y, z)
    match = re.match(r"Ball '(\w+)' Position: \(([-\d\.]+), ([-\d\.]+), ([-\d\.]+)\)", line)
    if match:
        x = float(match.group(2))
        y = float(match.group(3))
        z = float(match.group(4))
        return [x, y, z]
    else:
        raise ValueError(f"Line format incorrect: {line}")


In [87]:
# Function to parse motion data (from Motion Example txt)
def parse_motion_file(filename):
    with open(filename, 'r') as f:
        lines = f.readlines()

    frames = []
    i = 0
    while i < len(lines):
        if i + 1 >= len(lines):
            break  # Avoid index error
        line1 = lines[i].strip()
        line2 = lines[i + 1].strip()
        pos1 = parse_position_line(line1)
        pos2 = parse_position_line(line2)
        frames.append({'Sphere1': pos1, 'Sphere3': pos2})
        i += 2
    return frames


In [88]:
def translate_mesh(mesh, translation):
    # mesh is a Meshes object
    # translation is a list or tensor of shape (3,)
    verts = mesh.verts_list()[0]
    faces = mesh.faces_list()[0]
    textures = mesh.textures
    verts_translated = verts + torch.tensor(translation, device=device)
    new_mesh = Meshes(verts=[verts_translated], faces=[faces], textures=textures)
    return new_mesh


In [89]:
# Load meshes for ball1, ball2, and ball3 (from Current Base Code)
base_folder_path="./data/noncharacters"
Motion_data= 'BallPositions_31'
mesh_loader = MeshLoader(device, obj_folder=base_folder_path)
ball1_mesh_original = mesh_loader.load_mesh("ball1.obj")
#ball2_mesh_original = mesh_loader.load_mesh("ball2.obj")  # Static mesh
ball3_mesh_original = mesh_loader.load_mesh("ball3.obj")


# Read motion data
motions = parse_motion_file(f"{base_folder_path}/{Motion_data}.txt")
num_frames = len(motions)

# Set up renderer settings (from Previous Render Code)
raster_settings = RasterizationSettings(
    image_size=[320, 512],
    blur_radius=0.0,
    faces_per_pixel=1,
)

# Set up lights
lights = PointLights(device=device, location=[[2.0, 2.0, -2.0]])

# Camera parameters (from Previous Render Code)

dist = 1.8
elev = 0.0
azim = 90.0
lookat_x = 0.0
lookat_y = 0.6
lookat_z = 0.6
at = torch.tensor([[lookat_x, lookat_y, lookat_z]], dtype=torch.float32)

# Output directory
output_dir = f'images/Balls/{Motion_data}_input'
os.makedirs(output_dir, exist_ok=True)


In [90]:
def generate_dense_grid(width, height, density=100):
    """
    Generate a dense grid of vertices for a mesh that lies flat on the X-Z plane, centered at the origin.
    :param width: Width of the grid
    :param height: Height of the grid
    :param density: Number of divisions along each axis
    :return: Dense grid vertices and faces
    """
    x = np.linspace(-width / 2, width / 2, density)
    z = np.linspace(-height / 2, height / 2, density)
    xz_grid = np.array(np.meshgrid(x, z)).reshape(2, -1).T

    # Add a y-component (set to 0 for a flat plane on the X-Z plane)
    y = np.zeros((xz_grid.shape[0], 1))
    xyz_grid = np.hstack((xz_grid[:, 0:1], y, xz_grid[:, 1:2]))  # [x, y, z] with y = 0

    # Create vertices
    verts = torch.tensor(xyz_grid, dtype=torch.float32)

    # Create faces (triangles) for the grid
    faces = []
    for i in range(density - 1):
        for j in range(density - 1):
            # Define two triangles for each square in the grid
            v0 = i * density + j
            v1 = v0 + 1
            v2 = v0 + density
            v3 = v2 + 1

            faces.append([v0, v2, v1])
            faces.append([v2, v3, v1])

    faces = torch.tensor(faces, dtype=torch.int64)

    return verts, faces


In [91]:
for frame_idx, positions in enumerate(tqdm(motions)):
    # Get positions
    pos1 = positions['Sphere1']
    pos3 = positions['Sphere3']
    

    # Translate meshes according to motion data
    ball1_mesh = translate_mesh(ball1_mesh_original, pos1)
    #ball2_mesh = ball2_mesh_original  # Static mesh
    ball3_mesh = translate_mesh(ball3_mesh_original, pos3)
    # Combine meshes into a scene
    #plane = load_objs_as_meshes(["./data/checkerboard/checkerboard.obj"], device=device)
    verts, faces = generate_dense_grid(width=10, height=10, density=200)
    # Create the plane mesh with the dense grid
    plane_mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)])
    
    #ball1_mesh = ensure_vertex_texture(ball1_mesh_original)
    #ball3_mesh = ensure_vertex_texture(ball3_mesh_original)
    plane = ensure_vertex_texture(plane_mesh)

    scene_mesh = join_meshes_as_scene([ball1_mesh, ball3_mesh,plane_mesh])


    # Views and corresponding camera azimuth adjustments
    views = {
        'front': azim,
        'back': azim - 179,
        'left': azim - 5,
        'right': azim + 5,
    }
    
    if (frame_idx>60) and (frame_idx%30 != 0):
        continue
    
    for view_name, azimuth in views.items():
        # Adjust look-at point if necessary (from Previous Render Code)
        if view_name == 'left':
            at_view = at + torch.tensor([[0, 0, 0]], dtype=torch.float32)
        elif view_name == 'right':
            at_view = at - torch.tensor([[0, 0, 0]], dtype=torch.float32)
        else:
            at_view = at
        
        # Set up camera
        R, T = look_at_view_transform(dist=dist, elev=elev, azim=azimuth, at=at_view)
        cameras = PerspectiveCameras(device=device, R=R, T=T)

        # Set up renderer
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
        )
        renderer.rasterizer.cameras.image_size =[320, 512] 
        # Render scene
        images = renderer(scene_mesh)
        image = images[0, ..., :3].cpu().numpy()
        image = (image * 255).astype(np.uint8)
        image = Image.fromarray(image)

        # Save image
        
        pos1 = torch.tensor(pos1)
        pos3 = torch.tensor(pos3)
        position = torch.stack((pos1, pos3)).to('cuda')
        screen_pos = renderer.rasterizer.cameras.transform_points_screen(position)

        if view_name=='front':
            filename = os.path.join(output_dir, f'frame_{frame_idx:04d}.png')
            position_name = f"{output_dir}/frame_{frame_idx:04d}.npy"
            #print(screen_pos)
        else:
            filename = os.path.join(output_dir, f'{view_name}_{frame_idx:04d}.png')
            position_name = f"{output_dir}/{view_name}_{frame_idx:04d}.npy"
        
        image.save(filename)        
        np.save(position_name, screen_pos.cpu().numpy())



  pos1 = torch.tensor(pos1)
  pos3 = torch.tensor(pos3)
100%|██████████| 241/241 [00:32<00:00,  7.38it/s]


In [92]:
torch.cuda.empty_cache()

In [93]:
"""output = widgets.interactive_output(draw_image, {"frame_idx": frame_idx,
                                                 "azimuth": azimuth,
                                                 "elevation": elevation,
                                                 "distance": distance,
                                                 "lookat_x": lookat_x,
                                                 "lookat_y": lookat_y,
                                                 "lookat_z": lookat_z})

save_img_button = widgets.Button(description="Save Image")
save_img_button.on_click(save_image)

save_vid_button = widgets.Button(description="Save Video")
save_vid_button.on_click(save_video)

reset_params_button = widgets.Button(description="Reset")
reset_params_button.on_click(reset_value)

control_display = widgets.VBox([reset_params_button, cam_display, at_display,
                                widgets.HBox([save_img_button, save_vid_button]), #, save_cam_button]),
                                frame_idx])
display(widgets.HBox([control_display, output]))"""

'output = widgets.interactive_output(draw_image, {"frame_idx": frame_idx,\n                                                 "azimuth": azimuth,\n                                                 "elevation": elevation,\n                                                 "distance": distance,\n                                                 "lookat_x": lookat_x,\n                                                 "lookat_y": lookat_y,\n                                                 "lookat_z": lookat_z})\n\nsave_img_button = widgets.Button(description="Save Image")\nsave_img_button.on_click(save_image)\n\nsave_vid_button = widgets.Button(description="Save Video")\nsave_vid_button.on_click(save_video)\n\nreset_params_button = widgets.Button(description="Reset")\nreset_params_button.on_click(reset_value)\n\ncontrol_display = widgets.VBox([reset_params_button, cam_display, at_display,\n                                widgets.HBox([save_img_button, save_vid_button]), #, save_cam_button]),\n  