In [1]:
import glob
import os
import warnings

import h5py
import numpy as np
import torch

In [2]:
# Notebook-wide settings, assembled one section at a time.
CONFIG = {}

# Hash-grid encoding parameters. The values equal the keyword defaults of
# extract_hash_encoding_structure.
CONFIG['hash_encoding'] = {
    'num_levels': 16,          # number of levels in the hash encoding
    'level_dim': 2,            # dimension of the encoding at each level
    'input_dim': 3,            # input dimension (3 for 3D)
    'log2_hashmap_size': 19,   # log2 of the maximum hash table size
    'base_resolution': 16,     # base resolution of the grid
}

# MLP parameters.
CONFIG['mlp'] = {
    'num_layers': 3,   # Number of layers in geometric MLP
    'hidden_dim': 64,  # Hidden dimension size
}

def load_torch_weights(file_path):
    """Read a checkpoint and return the object stored under its 'model' key.

    The checkpoint is deserialized with tensors mapped onto the CPU. Any
    exception raised while loading, or while looking up 'model', is reported
    on stdout and swallowed.

    Args:
        file_path: Location of the checkpoint, handed straight to torch.load.

    Returns:
        The value of ``checkpoint['model']``, or ``None`` if anything failed.
    """
    # NOTE(review): torch.load deserializes via pickle (how restrictive it is
    # depends on the installed PyTorch's `weights_only` default), so only use
    # this on checkpoints from a trusted source.
    try:
        checkpoint = torch.load(file_path, map_location='cpu')
        model_state = checkpoint['model']
    except Exception as err:
        print(f"Error loading file {file_path}: {err}")
        return None
    return model_state
        
def extract_hash_encoding_structure(model_weights, num_levels=16, level_dim=2, input_dim=3, log2_hashmap_size=19, base_resolution=16, desired_resolution=2048):
    """
    Split the flat hash-grid embedding table into one slice per resolution level.

    Level sizes are recomputed from the encoder hyper-parameters: each level's
    resolution grows geometrically from ``base_resolution`` towards
    ``desired_resolution``, its row count is ``resolution ** input_dim`` capped
    at ``2 ** log2_hashmap_size`` and rounded up to a multiple of 8, and the
    levels are laid out back to back in the table.

    Args:
        model_weights (dict): The loaded model weights dictionary; must contain
            the key '_orig_mod.grid_encoder.embeddings'.
        num_levels (int): Number of levels in hash encoding
        level_dim (int): Dimension of encoding at each level. Only recorded in
            'global_info'; it is not used for slicing.
        input_dim (int): Input dimension (typically 3 for 3D)
        log2_hashmap_size (int): Log2 of maximum hash table size
        base_resolution (int): Base resolution of the grid
        desired_resolution (int): Resolution targeted at the finest level.
            Defaults to 2048, the value that was previously hard-coded.

    Returns:
        dict: One 'level_{i}' entry per level with keys 'resolution',
        'num_params', 'weights', 'weights_shape' and 'scale', plus a
        'global_info' entry with 'total_params', 'embedding_dim',
        'base_resolution', 'max_resolution' and 'per_level_scale'.

    Warns:
        UserWarning: if the summed level sizes differ from the number of rows
        in the embedding table, i.e. the hyper-parameters do not describe this
        checkpoint and some slices are truncated or incomplete.
    """
    # Extract hash encoding embeddings
    embeddings = model_weights['_orig_mod.grid_encoder.embeddings']

    # Calculate per-level parameters
    max_params = 2 ** log2_hashmap_size
    if num_levels == 1:
        # A single level has no growth step; avoid dividing by (num_levels - 1) == 0.
        per_level_scale = np.float64(1.0)
    else:
        per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

    # Initialize structure to store weights
    hash_structure = {}
    offset = 0

    for level in range(num_levels):
        level_scale = per_level_scale ** level

        # Calculate resolution at this level
        resolution = int(np.ceil(base_resolution * level_scale))

        # Calculate number of parameters for this level
        params_in_level = min(max_params, (resolution) ** input_dim)
        params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible by 8

        # Extract weights for this level
        level_weights = embeddings[offset:offset + params_in_level]

        # Store level information
        hash_structure[f'level_{level}'] = {
            'resolution': resolution,
            'num_params': params_in_level,
            'weights': level_weights,
            'weights_shape': level_weights.shape,
            'scale': level_scale
        }

        offset += params_in_level

    # Slicing past the end of the table does not raise, so a layout mismatch
    # would otherwise go unnoticed.
    total_rows = embeddings.shape[0]
    if offset != total_rows:
        warnings.warn(
            f"Hash-encoding layout expects {offset} embedding rows but the "
            f"checkpoint table has {total_rows}; check num_levels, input_dim, "
            f"log2_hashmap_size, base_resolution and desired_resolution."
        )

    # Add global information
    hash_structure['global_info'] = {
        'total_params': offset,
        'embedding_dim': level_dim,
        'base_resolution': base_resolution,
        'max_resolution': int(np.ceil(base_resolution * (per_level_scale ** (num_levels-1)))),
        'per_level_scale': per_level_scale
    }

    return hash_structure

def _collect_mlp_layers(model_weights, prefix, num_layers):
    """Gather '{prefix}.{i}.weight' (and optional '.bias') entries for i in range(num_layers)."""
    layers = {}
    for i in range(num_layers):
        weight_key = f'{prefix}.{i}.weight'
        if weight_key not in model_weights:
            # Indices without a weight entry are skipped.
            continue

        weight = model_weights[weight_key]
        entry = {'weights': weight, 'shape': weight.shape}

        bias_key = f'{prefix}.{i}.bias'
        if bias_key in model_weights:
            entry['bias'] = model_weights[bias_key]

        layers[f'layer_{i}'] = entry
    return layers


def extract_mlp_weights(model_weights, num_layers=None):
    """Extract geometric and view-dependent MLP weights from the model.

    Looks up '_orig_mod.grid_mlp.net.{i}' and '_orig_mod.view_mlp.net.{i}'
    for every index i below ``num_layers``; the same count is applied to both
    MLPs. Indices whose '.weight' key is absent are omitted from the result.

    Args:
        model_weights (dict): The loaded model weights dictionary.
        num_layers (int, optional): Number of layer indices to probe. When
            omitted, CONFIG['mlp']['num_layers'] is used.

    Returns:
        dict: {'geometry_mlp': {...}, 'view_mlp': {...}}, each mapping
        'layer_{i}' to a dict with 'weights', 'shape' and, when the checkpoint
        has one, 'bias'.
    """
    if num_layers is None:
        num_layers = CONFIG['mlp']['num_layers']

    return {
        'geometry_mlp': _collect_mlp_layers(model_weights, '_orig_mod.grid_mlp.net', num_layers),
        'view_mlp': _collect_mlp_layers(model_weights, '_orig_mod.view_mlp.net', num_layers),
    }

# Example usage: the cells below load a checkpoint, extract its layers, and write them to HDF5.


In [3]:
# Convert one checkpoint: <data_path>/final.pth -> <data_path>/final.h5
data_path = "../../ten_objs/ten_objs/shared_data/CarrotKhanStatue/base_000_000_000/checkpoints"
nerf = load_torch_weights(data_path + "/final.pth")
if nerf is None:
    # load_torch_weights has already printed the underlying error. Stop here
    # instead of letting the extractors fail on None with an unrelated TypeError.
    raise RuntimeError(f"Could not load checkpoint {data_path}/final.pth")

mlp_weights = extract_mlp_weights(nerf)
mrhe_by_layer = extract_hash_encoding_structure(nerf)

# 'global_info' has no "weights" entry, so it cannot be written as a dataset.
del mrhe_by_layer["global_info"]

# Merge the MLP layers into the same flat mapping, prefixed by their origin.
for key, value in mlp_weights["geometry_mlp"].items():
    mrhe_by_layer["geo_" + key] = value

for key, value in mlp_weights["view_mlp"].items():
    mrhe_by_layer["view_" + key] = value

# NOTE(review): only each entry's "weights" is written; a "bias" collected by
# extract_mlp_weights is not exported - confirm that is intended.
with h5py.File(data_path + "/final.h5", "w") as h5f:
    for name, tensor in mrhe_by_layer.items():
        h5f.create_dataset(name, data=tensor["weights"].numpy(), compression="gzip")

In [9]:
# Batch conversion: every <object>/<variant>/checkpoints/final.pth -> final.h5
objects = glob.glob("../../../../../../..//media/boz/408422C88422C070/ada_and_hal/*")
files = []
for obj in objects:
    # Match final.pth only. Any other *.pth in the folder cannot be handled by
    # this loop (it always reads and writes "final.*") and previously aborted
    # the whole run.
    files += glob.glob(obj + "/*/checkpoints/final.pth")

files.sort()

failed = []
for checkpoint_file in files:
    data_path = os.path.dirname(checkpoint_file)
    nerf = load_torch_weights(checkpoint_file)
    if nerf is None:
        # The loader already printed the reason; skip this checkpoint and
        # continue with the rest of the batch.
        failed.append(checkpoint_file)
        continue

    mlp_weights = extract_mlp_weights(nerf)
    mrhe_by_layer = extract_hash_encoding_structure(nerf)

    # 'global_info' has no "weights" entry, so it cannot be written as a dataset.
    del mrhe_by_layer["global_info"]

    for key, value in mlp_weights["geometry_mlp"].items():
        mrhe_by_layer["geo_" + key] = value

    for key, value in mlp_weights["view_mlp"].items():
        mrhe_by_layer["view_" + key] = value

    # NOTE(review): only each entry's "weights" is written; a "bias" collected
    # by extract_mlp_weights is not exported - confirm that is intended.
    print(data_path + "/final.h5")
    with h5py.File(data_path + "/final.h5", "w") as h5f:
        for name, tensor in mrhe_by_layer.items():
            h5f.create_dataset(name, data=tensor["weights"].numpy(), compression="gzip")
    print(data_path)

if failed:
    print(f"Skipped {len(failed)} checkpoint(s) that could not be loaded:")
    for checkpoint_file in failed:
        print("  " + checkpoint_file)


../../../../../../..//media/boz/408422C88422C070/ada_and_hal/1Story/base_000_000_000/checkpoints/final.h5
../../../../../../..//media/boz/408422C88422C070/ada_and_hal/1Story/base_000_000_000/checkpoints
../../../../../../..//media/boz/408422C88422C070/ada_and_hal/1Story/compound_090_000_090/checkpoints/final.h5
../../../../../../..//media/boz/408422C88422C070/ada_and_hal/1Story/compound_090_000_090/checkpoints
../../../../../../..//media/boz/408422C88422C070/ada_and_hal/1Story/x_180_000_000/checkpoints/final.h5


KeyboardInterrupt: 