In [4]:
from safetensors.torch import save_file, load_file
import glob

import os
import torch
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

# Load file paths

In [None]:


# Collect every `<object>/<run>/checkpoints` directory under the data root.
DATA_DIR = "../data"

# A single two-level pattern replaces the original loop, which fed each matched
# object path back into glob.glob and so misinterpreted glob metacharacters
# ("[", "*", "?") in directory names. Files at the first level simply produce
# no matches, as before.
files = sorted(glob.glob(os.path.join(DATA_DIR, "*", "*", "checkpoints")))
len(files)

In [None]:

def process_file(data_path, num_layers=3):
    """Export the MLP weights and grid-encoder embeddings of one checkpoint.

    Reads ``<data_path>/final.pth`` (its ``'model'`` entry) and writes next to it:

    * ``mlp_raw.pth`` -- ``{"geometry_layers": [...], "view_layers": [...]}``,
      copies of the ``weight`` tensors of layers ``0..num_layers-1`` of
      ``grid_mlp`` and ``view_mlp`` respectively.
    * ``hash.bin`` -- ``grid_encoder.embeddings`` as raw float32 values in
      C order, with no header. It is written only when every value is finite.
      When it is skipped, any existing ``hash.bin`` is left untouched.

    Args:
        data_path: Directory containing ``final.pth``. A ``.../final.pth``
            path is also accepted; the suffix is stripped.
        num_layers: Number of layers to export from each MLP (default 3).

    Returns:
        True if ``hash.bin`` was written. False if it was skipped because the
        embeddings contain NaN or Inf.

    Raises:
        KeyError: If an expected key is missing from the checkpoint.
    """
    # Accept either the checkpoint directory or the checkpoint file itself.
    data_path = data_path.replace("/final.pth", "")

    # NOTE(review): torch.load unpickles the file -- only use on trusted checkpoints.
    weights = torch.load(os.path.join(data_path, "final.pth"), map_location='cpu')['model']

    # Keys carry a "_orig_mod." prefix, presumably from a torch.compile'd
    # module -- TODO confirm.
    geometry_layers = []
    view_layers = []
    for layer in range(num_layers):
        # detach().clone() is the documented replacement for torch.tensor(t),
        # which warns on tensor input. Both copy the data.
        geometry_layers.append(weights[f'_orig_mod.grid_mlp.net.{layer}.weight'].detach().clone())
        view_layers.append(weights[f'_orig_mod.view_mlp.net.{layer}.weight'].detach().clone())

    data = {
        "geometry_layers": geometry_layers,
        "view_layers": view_layers,
    }
    torch.save(data, os.path.join(data_path, "mlp_raw.pth"))

    embeddings = weights['_orig_mod.grid_encoder.embeddings'].detach()
    # Do not export non-finite embeddings. Report the skip to the caller
    # instead of silently passing.
    if not torch.isfinite(embeddings).all():
        return False

    # Convert in torch (not numpy) so dtypes numpy lacks, e.g. bfloat16, work.
    embeddings.to(torch.float32).numpy().tofile(os.path.join(data_path, "hash.bin"))
    return True


# Process every checkpoint directory in parallel and collect per-file failures.
# NOTE(review): worker processes must be able to locate `process_file`. Under the
# "spawn" start method (e.g. Windows/macOS defaults), functions defined in a
# notebook cell may not be importable by workers -- move it to a .py module if
# every task fails.
failures = {}
with ProcessPoolExecutor() as executor:
    future_to_path = {executor.submit(process_file, f): f for f in files}

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        # .result() re-raises any exception from the worker. Without it, failed
        # files would be dropped silently. Record the error and keep going.
        try:
            future.result()
        except Exception as exc:
            failures[future_to_path[future]] = exc

# Paths that failed, mapped to the exception raised (empty dict == all succeeded).
failures