# 🧪 Test JEPA3D Encoder

This notebook shows how to import and run your `JEPA3DEncoderWrapper` wrapper from  
`src/models/encoders/jepa3d_backbone.py`.

The function `create_featurized_scene_dict` creates a suitable randomized input for the pretrained JEPA3D Encoder.

Since the current VM is not compatible for using Flash Attention, all default arguments enable_flash in the `point_transfomer.py` file are assumed to be False (instead of True, like before).

In [3]:
import sys
from pathlib import Path

# Add the src directory to the Python path
sys.path.append(str(Path().resolve().parent / 'src'))

import torch
import numpy as np

from models.encoders.jepa3d_wrapper import JEPA3DEncoderWrapper
from ext.jepa3d.models.encoder_3djepa import Encoder3DJEPA

print("✅ Both imports resolve!")


✅ Both imports resolve!




In [4]:
def create_featurized_scene_dict(num_points=30000, device='cpu', model=None):
    """
    Create a featurized scene dictionary that matches the expected input format
    for the 3D-JEPA encoder based on the forward method.
    
    Args:
        num_points: Number of points in the scene
        device: Device to create tensors on
        model: The model instance to check expected dimensions
    
    Returns:
        dict: featurized_scene_dict with all required keys
    """
    
    # Create XYZ coordinates - simulate a room-like environment
    xyz = torch.zeros(num_points, 3, device=device)
    
    # Generate points that look like a room with furniture
    # Floor points (z near 0)
    floor_mask = torch.rand(num_points, device=device) < 0.3
    xyz[floor_mask, :2] = torch.rand((floor_mask.sum(), 2), device=device) * 4 - 2  # -2 to 2 meters
    xyz[floor_mask, 2] = torch.rand(floor_mask.sum(), device=device) * 0.1  # 0 to 0.1 meters
    
    # Ceiling points (z near 3)
    ceiling_mask = (~floor_mask) & (torch.rand(num_points, device=device) < 0.2)
    xyz[ceiling_mask, :2] = torch.rand((ceiling_mask.sum(), 2), device=device) * 4 - 2
    xyz[ceiling_mask, 2] = 2.8 + torch.rand(ceiling_mask.sum(), device=device) * 0.2  # 2.8 to 3.0 meters
    
    # Wall and furniture points (middle z values)
    remaining_mask = ~floor_mask & ~ceiling_mask
    xyz[remaining_mask, :2] = torch.rand((remaining_mask.sum(), 2), device=device) * 4 - 2
    xyz[remaining_mask, 2] = torch.rand(remaining_mask.sum(), device=device) * 2.8  # 0 to 2.8 meters

    print(xyz.shape)
    
    # Create RGB colors (0-1 range - model will multiply by 255)
    rgb = torch.rand(num_points, 3, device=device)
    
    # Floor tends to be darker/brownish
    rgb[floor_mask] = torch.tensor([0.4, 0.3, 0.2], device=device) + torch.rand((floor_mask.sum(), 3), device=device) * 0.3
    
    # Ceiling tends to be white/light
    rgb[ceiling_mask] = torch.tensor([0.8, 0.8, 0.8], device=device) + torch.rand((ceiling_mask.sum(), 3), device=device) * 0.2
    
    # Clamp RGB to [0, 1]
    rgb = torch.clamp(rgb, 0, 1)
    
    # Determine feature dimensions based on model if available
    if model is not None and hasattr(model, 'input_feat_dim'):
        total_feat_dim = model.input_feat_dim
        print(f"Model expects {total_feat_dim} total feature dimensions")
        
        # Split roughly evenly between CLIP and DINO
        clip_feat_dim = total_feat_dim // 2
        dino_feat_dim = total_feat_dim - clip_feat_dim
        print(f"Using CLIP: {clip_feat_dim}, DINO: {dino_feat_dim}")
    else:
        # The error showed we need 896 total dimensions
        # Current attempt: 512 + 384 = 896
        total_feat_dim = 896
        clip_feat_dim = 512  
        dino_feat_dim = 384  
        print(f"Using default dimensions - CLIP: {clip_feat_dim}, DINO: {dino_feat_dim}, Total: {total_feat_dim}")
    
    # Create CLIP and DINO features with correct dimensions
    features_clip = torch.randn(num_points, clip_feat_dim, device=device) * 0.1
    features_dino = torch.randn(num_points, dino_feat_dim, device=device) * 0.1
    
    # Create the featurized scene dictionary
    featurized_scene_dict = {
        "features_clip": features_clip,      # Shape: (num_points, clip_feat_dim)
        "features_dino": features_dino,      # Shape: (num_points, dino_feat_dim)
        "rgb": rgb,                          # Shape: (num_points, 3) in [0,1] range
        "points": xyz,                       # Shape: (num_points, 3)
    }
    
    return featurized_scene_dict

In [5]:
# import the pretrained model
model_3djepa = Encoder3DJEPA.from_pretrained("facebook/3d-jepa")

featurized_scene_dict = create_featurized_scene_dict(
    num_points=1024, 
    model=model_3djepa,
    device=torch.device('cuda')
)

# Load and move model to CUDA
model_3djepa = Encoder3DJEPA.from_pretrained("facebook/3d-jepa")
model_3djepa = model_3djepa.cuda()

# Explicitly move zero_token to CUDA
if hasattr(model_3djepa, 'zero_token'):
    model_3djepa.zero_token = model_3djepa.zero_token.cuda()

# Check device
print(f"zero_token device: {model_3djepa.zero_token.device}")
# Check if you have CUDA
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")

print(f"GPU: {torch.cuda.get_device_name()}")
print(f"CUDA Capability: {torch.cuda.get_device_capability()}")

# Ampere (supports FlashAttention): capability >= (8, 0)
# Since our current VM doesn't support, flash attention was disabled in point_transformer_v3.py
capability = torch.cuda.get_device_capability()
is_ampere_or_newer = capability[0] >= 8
print(f"Supports FlashAttention: {is_ampere_or_newer}")


output = model_3djepa(featurized_scene_dict)


torch.Size([1024, 3])
Model expects 1536 total feature dimensions
Using CLIP: 768, DINO: 768
zero_token device: cuda:0
CUDA available: True
CUDA device count: 1
GPU: Tesla T4
CUDA Capability: (7, 5)
Supports FlashAttention: False
[DEBUG] features.shape = torch.Size([1, 1024, 1536])
[DEBUG] zero_token.shape = torch.Size([1536])


In [5]:
print(vars(model_3djepa).keys()) 

dict_keys(['_hub_mixin_config', 'voxel_size', 'input_feat_dim', 'embed_dim', 'training', '_parameters', '_buffers', '_non_persistent_buffers_set', '_backward_pre_hooks', '_backward_hooks', '_is_full_backward_hook', '_forward_hooks', '_forward_hooks_with_kwargs', '_forward_hooks_always_called', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_state_dict_hooks', '_state_dict_pre_hooks', '_load_state_dict_pre_hooks', '_load_state_dict_post_hooks', '_modules', 'num_features'])


In [None]:
# Analysis of the output od the 3Djepa encoder
print(f"Model output keys: {output.keys()}")
if 'features' in output:
    print(f"Features shape: {output['features'].shape}")
if 'points' in output:
    print(f"Points shape: {output['points'].shape}")
    
print("Success! Generated embeddings from synthetic featurized scene.")

# Print some statistics about the output features
if 'features' in output:
    features = output['features']
    print(f"\nFeature statistics:")
    print(f"  Mean: {features.mean().item():.4f}")
    print(f"  Std: {features.std().item():.4f}")
    print(f"  Min: {features.min().item():.4f}")
    print(f"  Max: {features.max().item():.4f}")