In [9]:
import torch
from main_pipeline import BodyReconstructionPipeline
import config


IMAGE_PATH = "/home/khater/pose-check/tom.jpg"
GENDER = "male"
OUTPUT_DIR = "output"

# Initialize and run pipeline
pipeline = BodyReconstructionPipeline(output_dir=OUTPUT_DIR)

results = pipeline.run(
    image_path=IMAGE_PATH,
    gender=GENDER,
    enable_visualization=True
)

Starting 3D Body Reconstruction Pipeline

[1/5] Estimating depth and camera FOV...
[97m[INFO ] using MLP layer as FFN[0m
[97m[INFO ] Processed Images Done taking 0.0817408561706543 seconds. Shape:  torch.Size([1, 3, 336, 504])[0m
[97m[INFO ] Model Forward Pass Done. Time: 0.10876703262329102 seconds[0m
[97m[INFO ] Conversion to Prediction Done. Time: 0.0037078857421875 seconds[0m
[97m[INFO ] Export Results Done. Time: 0.0009083747863769531 seconds[0m
########### Using fov estimator: MoGe2...
✓ Depth map shape: (336, 504)
✓ Camera intrinsics estimated

[2/5] Detecting 2D pose keypoints...

image 1/1 /home/khater/pose-check/tom.jpg: 448x640 1 person, 8.7ms
Speed: 4.8ms preprocess, 8.7ms inference, 4.6ms postprocess per image at shape (1, 3, 448, 640)
✓ Detected 1 person(s)

[3/5] Generating segmentation masks...
✓ Generated 1 mask(s)

[4/5] Creating 3D point cloud...
✓ Point cloud created: 169344 points

Cleaning up GPU memory...
✓ GPU memory freed

[5/5] Fitting SMPL-X model..

In [None]:
fitter = pipeline.fitter

depth_map = results["depth_map"]
cam_intrinsics = results["cam_intrinsics"]
keypoints_2d = results["keypoints_2d"]
processed_image = results["processed_image"]
masks = results["masks"]
point_cloud_array = results["point_cloud"]


transl = fitter.fitted_params['transl'].clone().requires_grad_(True)
global_orient = fitter.fitted_params['global_orient'].clone().requires_grad_(True)
body_pose = fitter.fitted_params['body_pose'].clone().requires_grad_(True)
betas = fitter.fitted_params['betas'].clone().requires_grad_(True)

left_hand_pose = fitter.fitted_params['left_hand_pose']
right_hand_pose = fitter.fitted_params['right_hand_pose']
expression = fitter.fitted_params['expression']

optimizer = torch.optim.Adam(
    [transl, global_orient, body_pose, betas],
    lr=0.005
)


output = fitter.smplx_model(
    betas=betas,
    global_orient=global_orient,
    body_pose=body_pose,
    transl=transl,
    left_hand_pose=left_hand_pose,
    right_hand_pose=right_hand_pose,
    expression=expression,
    return_verts=True,
    return_faces=True
)



In [19]:
verts = output.vertices
faces = fitter.smplx_model.faces

In [20]:
import torch
import torch.nn.functional as F

class DepthNormalLoss:
    def _compute_vertex_normals(self, vertices, faces):
        """
        Compute vertex normals using weighted face normals.
        vertices: [V, 3]
        faces: [F, 3]
        """
        # 1. Compute Face Normals
        v0 = vertices[faces[:, 0]]
        v1 = vertices[faces[:, 1]]
        v2 = vertices[faces[:, 2]]
        
        edge1 = v1 - v0
        edge2 = v2 - v0
        # Cross product for face normal
        face_normals = torch.cross(edge1, edge2, dim=1) # [F, 3]
        
        # 2. Scatter to Vertices (Simple uniform weighting)
        # Note: Ideally area-weighted, but uniform is faster/differentiable
        vertex_normals = torch.zeros_like(vertices)
        
        # Add face normal to each vertex in the face
        for i in range(3):
            vertex_normals.index_add_(0, faces[:, i], face_normals)
            
        # 3. Normalize
        vertex_normals = F.normalize(vertex_normals, dim=1, eps=1e-6)
        return vertex_normals

    def _depth_to_normals(self, depth, intrinsics):
        """
        Compute normal map from depth map using back-projection and gradients.
        depth: [H, W]
        intrinsics: [3, 3]
        """
        H, W = depth.shape
        fx, fy = intrinsics[0, 0], intrinsics[1, 1]
        cx, cy = intrinsics[0, 2], intrinsics[1, 2]

        # Create grid
        y, x = torch.meshgrid(torch.arange(H, device=depth.device), 
                              torch.arange(W, device=depth.device), indexing='ij')
        
        # Back-project to 3D (approximate for gradients)
        # X = (x - cx) * Z / fx
        # Y = (y - cy) * Z / fy
        X = (x - cx) * depth / fx
        Y = (y - cy) * depth / fy
        XYZ = torch.stack([X, Y, depth], dim=-1) # [H, W, 3]

        # Compute gradients (central difference)
        # Pad to handle borders
        padded = F.pad(XYZ, (0, 0, 1, 1, 1, 1), mode='replicate')
        
        # d/dy
        v_up   = padded[:-2, 1:-1, :]
        v_down = padded[2:, 1:-1, :]
        dy = v_down - v_up

        # d/dx
        v_left  = padded[1:-1, :-2, :]
        v_right = padded[1:-1, 2:, :]
        dx = v_right - v_left

        # Cross product: dx x dy gives normal
        cross = torch.cross(dx, dy, dim=-1)
        
        # Normalize (and flip sign if pointing away, usually Z should be positive)
        # In standard camera, normals point towards camera (-Z) or away (+Z)? 
        # Standard: Surface normal usually points opposite to viewing direction.
        normals = F.normalize(cross, dim=-1)
        return normals

    def compute_depth_loss(
            self,
            vertices: torch.Tensor,        # [V, 3]
            faces: torch.Tensor,           # [F, 3] (ADDED)
            cam_intrinsics: torch.Tensor,  # [3, 3]
            depth_map: torch.Tensor,       # [H, W]
            mask: torch.Tensor,            # [H, W]
            max_verts: int = 3000,
            it: int = 0,
            normal_weight: float = 0.1     # (ADDED)
        ) -> torch.Tensor:
            
            device = vertices.device
            H, W = depth_map.shape

            # ----------------------------------
            # 0. Pre-compute Normals
            # ----------------------------------
            # Source Normals (Vertex)
            v_normals = self._compute_vertex_normals(vertices, faces)
            
            # Target Normals (from Ground Truth Depth)
            # You might want to cache this if depth_map is static
            gt_normals_map = self._depth_to_normals(depth_map, cam_intrinsics) # [H, W, 3]

            # ----------------------------------
            # 1. Project vertices
            # ----------------------------------
            points = self._project_points(vertices, cam_intrinsics)
            u = points[:, 0].long()
            v = points[:, 1].long()
            z = vertices[:, 2]

            # ----------------------------------
            # 2. Filter: Image bounds
            # ----------------------------------
            inside = (u >= 0) & (u < W) & (v >= 0) & (v < H)
            u, v, z = u[inside], v[inside], z[inside]
            # Also filter normals
            v_normals_subset = v_normals[inside]

            if z.numel() == 0:
                return torch.zeros((), device=device)

            # ----------------------------------
            # 3. Filter: Mask
            # ----------------------------------
            linear_idx = v * W + u
            mask_flat = mask.reshape(-1)
            mask_values = mask_flat[linear_idx]
            mask_valid = mask_values > 0.5

            u, v, z = u[mask_valid], v[mask_valid], z[mask_valid]
            v_normals_subset = v_normals_subset[mask_valid] # Keep normals in sync

            if z.numel() == 0:
                return torch.zeros((), device=device)

            # ----------------------------------
            # 4. Z-buffer (Min Z per pixel)
            # ----------------------------------
            pixel_ids = v * W + u
            unique_ids, inverse = torch.unique(pixel_ids, return_inverse=True)

            # Compute min Z
            z_min = torch.full((unique_ids.shape[0],), float('inf'), device=device)
            z_min.scatter_reduce_(0, inverse, z, reduce='amin')
            
            # ----------------------------------
            # 5. Identify "Winning" Fragments
            # ----------------------------------
            # We need to find which vertices actually contributed to z_min 
            # to pick the correct source normals.
            
            # Broadcast min values back to original list size
            z_min_expanded = z_min[inverse]
            
            # Create a mask for vertices that are visible (closest)
            # Use a small epsilon for float comparison
            is_closest = torch.abs(z - z_min_expanded) < 1e-5
            
            # Filter our attributes to just the visible fragments
            # Note: This might still contain duplicates if two vertices have 
            # exact same depth at same pixel, but typically rare.
            visible_z = z[is_closest]
            visible_u = u[is_closest]
            visible_v = v[is_closest]
            visible_src_normals = v_normals_subset[is_closest]

            # ----------------------------------
            # 6. Compute Losses
            # ----------------------------------
            
            # A. Depth Loss
            # Lookup GT depth at the visible pixels
            depth_gt = depth_map[visible_v, visible_u]
            loss_depth = torch.mean((visible_z - depth_gt) ** 2)

            # B. Normal Loss
            # Lookup GT normals at the visible pixels
            target_normals = gt_normals_map[visible_v, visible_u] # [N, 3]
            
            # Cosine Similarity Loss: 1 - cos(theta)
            # Dot product of normalized vectors
            dot_prod = torch.sum(visible_src_normals * target_normals, dim=1)
            
            # Clamp for stability in case of numerical noise > 1.0
            dot_prod = torch.clamp(dot_prod, -1.0, 1.0)
            
            # We want vectors to align, so maximize dot product -> minimize (1 - dot)
            loss_normal = torch.mean(1.0 - dot_prod)

            # ----------------------------------
            # Debug Visualization
            # ----------------------------------
            if it % 50 == 0:
                print(f"Depth Loss: {loss_depth.item():.4f}, Normal Loss: {loss_normal.item():.4f}")
                # ... (existing visualization code) ...

            return loss_depth + (normal_weight * loss_normal)
    

    def _project_points(
        self,
        points_3d: torch.Tensor,
        cam_intrinsics: torch.Tensor
    ) -> torch.Tensor:
        """
        Project 3D points to 2D using camera intrinsics.
        
        Args:
            points_3d: [N, 3] 3D points
            cam_intrinsics: [3, 3] camera intrinsics
            
        Returns:
            points_2d: [N, 2] projected 2D points
        """
        fx = cam_intrinsics[0, 0]
        fy = cam_intrinsics[1, 1]
        cx = cam_intrinsics[0, 2]
        cy = cam_intrinsics[1, 2]
        
        # Perspective projection
        x_2d = fx * points_3d[:, 0] / (points_3d[:, 2] + 1e-6) + cx
        y_2d = -fy * points_3d[:, 1] / (points_3d[:, 2] + 1e-6) + cy
        
        return torch.stack([x_2d, y_2d], dim=1)

In [23]:
print("Computing depth loss...")
print(f" verts shape: {verts.shape}, cam_intrinsics shape: {cam_intrinsics[0].shape}, depth_map shape: {depth_map.shape}, mask shape: {masks[0].shape}")
# turn depth map to torch tensor
loss_fn = DepthNormalLoss()

# optimizer = torch.optim.Adam([verts], lr=0.01)

for it in range(1000):
    optimizer.zero_grad()
    
    # Calculate loss
    loss = loss_fn.compute_depth_loss(
        vertices=verts,
        faces=faces,                  # <--- NEW: Must pass faces
        cam_intrinsics=cam_intrinsics[0],
        depth_map=depth_map,
        mask=masks[0],
        it=it,
        normal_weight=0.1             # <--- NEW: Control normal influence
    )

Computing depth loss...
 verts shape: torch.Size([1, 10475, 3]), cam_intrinsics shape: torch.Size([3, 3]), depth_map shape: (336, 504), mask shape: torch.Size([1, 336, 504])


AcceleratorError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [8]:
print(depth)

tensor(1.5794, device='cuda:0', grad_fn=<MeanBackward0>)


In [4]:

vertices = verts
mask = masks[0]
max_verts = 3000
eps = 1e-6
device = vertices.device
H, W = depth_map.shape

# Subsample vertices
V = vertices.shape[0]
if V > max_verts:
    idx = torch.randperm(V, device=device)[:max_verts]
    verts = vertices[idx]
else:
    verts = vertices

# Keep vertices in front of camera
verts = verts[verts[:, 2] > eps]
# if verts.numel() == 0:
#     return torch.zeros((), device=device)
if cam_intrinsics.ndim == 3:
    cam_intrinsics = cam_intrinsics[0]


fx, fy = cam_intrinsics[0, 0], cam_intrinsics[1, 1]
cx, cy = cam_intrinsics[0, 2], cam_intrinsics[1, 2]

u = fx * verts[:, 0] / (verts[:, 2] + eps) + cx
v = -fy * verts[:, 1] / (verts[:, 2] + eps) + cy

u = u.long()
v = v.long()

# Image bounds
valid = (
    (u >= 0) & (u < W) &
    (v >= 0) & (v < H)
)

u = u[valid]
v = v[valid]
z = verts[valid, 2]

if z.numel() == 0:
    print("No valid vertices after image bounds check.")
u = u.long()
v = v.long()
# print(f" shape of mask: {mask.shape}")


print(f" shape of u: {u.shape}, v: {v.shape}, z: {z.shape}")
# Apply segmentation mask



 shape of u: torch.Size([3000]), v: torch.Size([3000]), z: torch.Size([3000])


In [5]:
## mask is a binary mask of shape [H,W]
## for each (u,v) coordinate, check if mask[v,u] is True

# First, verify mask shape and clamp indices to be safe
print(f"mask shape: {mask.shape}, expected: [{H}, {W}]")
print(f"u range: [{u.min()}, {u.max()}], v range: [{v.min()}, {v.max()}]")

# Clamp indices to valid range (defensive programming)
v_clamped = torch.clamp(v, 0, H - 1)
u_clamped = torch.clamp(u, 0, W - 1)

mask_valid = mask[v_clamped, u_clamped]

mask shape: torch.Size([1, 336, 504]), expected: [336, 504]
u range: [200, 298], v range: [13, 302]


In [None]:
import torch
import torch.nn.functional as F

class DepthNormalLoss:
    def _compute_vertex_normals(self, vertices, faces):
        """
        Compute vertex normals using weighted face normals.
        vertices: [V, 3]
        faces: [F, 3]
        """
        # 1. Compute Face Normals
        v0 = vertices[faces[:, 0]]
        v1 = vertices[faces[:, 1]]
        v2 = vertices[faces[:, 2]]
        
        edge1 = v1 - v0
        edge2 = v2 - v0
        # Cross product for face normal
        face_normals = torch.cross(edge1, edge2, dim=1) # [F, 3]
        
        # 2. Scatter to Vertices (Simple uniform weighting)
        # Note: Ideally area-weighted, but uniform is faster/differentiable
        vertex_normals = torch.zeros_like(vertices)
        
        # Add face normal to each vertex in the face
        for i in range(3):
            vertex_normals.index_add_(0, faces[:, i], face_normals)
            
        # 3. Normalize
        vertex_normals = F.normalize(vertex_normals, dim=1, eps=1e-6)
        return vertex_normals

    def _depth_to_normals(self, depth, intrinsics):
        """
        Compute normal map from depth map using back-projection and gradients.
        depth: [H, W]
        intrinsics: [3, 3]
        """
        H, W = depth.shape
        fx, fy = intrinsics[0, 0], intrinsics[1, 1]
        cx, cy = intrinsics[0, 2], intrinsics[1, 2]

        # Create grid
        y, x = torch.meshgrid(torch.arange(H, device=depth.device), 
                              torch.arange(W, device=depth.device), indexing='ij')
        
        # Back-project to 3D (approximate for gradients)
        # X = (x - cx) * Z / fx
        # Y = (y - cy) * Z / fy
        X = (x - cx) * depth / fx
        Y = (y - cy) * depth / fy
        XYZ = torch.stack([X, Y, depth], dim=-1) # [H, W, 3]

        # Compute gradients (central difference)
        # Pad to handle borders
        padded = F.pad(XYZ, (0, 0, 1, 1, 1, 1), mode='replicate')
        
        # d/dy
        v_up   = padded[:-2, 1:-1, :]
        v_down = padded[2:, 1:-1, :]
        dy = v_down - v_up

        # d/dx
        v_left  = padded[1:-1, :-2, :]
        v_right = padded[1:-1, 2:, :]
        dx = v_right - v_left

        # Cross product: dx x dy gives normal
        cross = torch.cross(dx, dy, dim=-1)
        
        # Normalize (and flip sign if pointing away, usually Z should be positive)
        # In standard camera, normals point towards camera (-Z) or away (+Z)? 
        # Standard: Surface normal usually points opposite to viewing direction.
        normals = F.normalize(cross, dim=-1)
        return normals

    def compute_depth_loss(
            self,
            vertices: torch.Tensor,        # [V, 3]
            faces: torch.Tensor,           # [F, 3] (ADDED)
            cam_intrinsics: torch.Tensor,  # [3, 3]
            depth_map: torch.Tensor,       # [H, W]
            mask: torch.Tensor,            # [H, W]
            max_verts: int = 3000,
            it: int = 0,
            normal_weight: float = 0.1     # (ADDED)
        ) -> torch.Tensor:
            
            device = vertices.device
            H, W = depth_map.shape

            # ----------------------------------
            # 0. Pre-compute Normals
            # ----------------------------------
            # Source Normals (Vertex)
            v_normals = self._compute_vertex_normals(vertices, faces)
            
            # Target Normals (from Ground Truth Depth)
            # You might want to cache this if depth_map is static
            gt_normals_map = self._depth_to_normals(depth_map, cam_intrinsics) # [H, W, 3]

            # ----------------------------------
            # 1. Project vertices
            # ----------------------------------
            points = self._project_points(vertices, cam_intrinsics)
            u = points[:, 0].long()
            v = points[:, 1].long()
            z = vertices[:, 2]

            # ----------------------------------
            # 2. Filter: Image bounds
            # ----------------------------------
            inside = (u >= 0) & (u < W) & (v >= 0) & (v < H)
            u, v, z = u[inside], v[inside], z[inside]
            # Also filter normals
            v_normals_subset = v_normals[inside]

            if z.numel() == 0:
                return torch.zeros((), device=device)

            # ----------------------------------
            # 3. Filter: Mask
            # ----------------------------------
            linear_idx = v * W + u
            mask_flat = mask.reshape(-1)
            mask_values = mask_flat[linear_idx]
            mask_valid = mask_values > 0.5

            u, v, z = u[mask_valid], v[mask_valid], z[mask_valid]
            v_normals_subset = v_normals_subset[mask_valid] # Keep normals in sync

            if z.numel() == 0:
                return torch.zeros((), device=device)

            # ----------------------------------
            # 4. Z-buffer (Min Z per pixel)
            # ----------------------------------
            pixel_ids = v * W + u
            unique_ids, inverse = torch.unique(pixel_ids, return_inverse=True)

            # Compute min Z
            z_min = torch.full((unique_ids.shape[0],), float('inf'), device=device)
            z_min.scatter_reduce_(0, inverse, z, reduce='amin')
            
            # ----------------------------------
            # 5. Identify "Winning" Fragments
            # ----------------------------------
            # We need to find which vertices actually contributed to z_min 
            # to pick the correct source normals.
            
            # Broadcast min values back to original list size
            z_min_expanded = z_min[inverse]
            
            # Create a mask for vertices that are visible (closest)
            # Use a small epsilon for float comparison
            is_closest = torch.abs(z - z_min_expanded) < 1e-5
            
            # Filter our attributes to just the visible fragments
            # Note: This might still contain duplicates if two vertices have 
            # exact same depth at same pixel, but typically rare.
            visible_z = z[is_closest]
            visible_u = u[is_closest]
            visible_v = v[is_closest]
            visible_src_normals = v_normals_subset[is_closest]

            # ----------------------------------
            # 6. Compute Losses
            # ----------------------------------
            
            # A. Depth Loss
            # Lookup GT depth at the visible pixels
            depth_gt = depth_map[visible_v, visible_u]
            loss_depth = torch.mean((visible_z - depth_gt) ** 2)

            # B. Normal Loss
            # Lookup GT normals at the visible pixels
            target_normals = gt_normals_map[visible_v, visible_u] # [N, 3]
            
            # Cosine Similarity Loss: 1 - cos(theta)
            # Dot product of normalized vectors
            dot_prod = torch.sum(visible_src_normals * target_normals, dim=1)
            
            # Clamp for stability in case of numerical noise > 1.0
            dot_prod = torch.clamp(dot_prod, -1.0, 1.0)
            
            # We want vectors to align, so maximize dot product -> minimize (1 - dot)
            loss_normal = torch.mean(1.0 - dot_prod)

            # ----------------------------------
            # Debug Visualization
            # ----------------------------------
            if it % 50 == 0:
                print(f"Depth Loss: {loss_depth.item():.4f}, Normal Loss: {loss_normal.item():.4f}")
                # ... (existing visualization code) ...

            return loss_depth + (normal_weight * loss_normal)

AcceleratorError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
