In [1]:
import torch
import yaml
from utils.builder import build_vae_model
from models.condition import TextEncoder 
from models.mar import MAREncoderDecoder
from thsolver.config import parse_args
import sys 
import ocnn
import os
import copy
from ognn.octreed import OctreeD
from utils import utils
torch.cuda.set_device(3)

In [2]:
log_path = 'logs/objaverse/mar_text_bvflip5_expand1152'

sys.argv = ['']  # Reset sys.argv
sys.argv.extend(['--config', log_path + '/all_configs.yaml'])
flags = parse_args(backup=False)
flags

CfgNode({'BASE': ['configs/shapenet_octar.yaml'], 'SOLVER': CfgNode({'alias': '', 'gpu': (0, 1, 2, 3), 'run': 'generate', 'logdir': 'logs/objaverse/mar_text_bvflip5_expand1152', 'ckpt': '', 'ckpt_num': 10, 'type': 'adamw', 'weight_decay': 0.01, 'clip_grad': -1.0, 'max_epoch': 400, 'warmup_epoch': 20, 'warmup_init': 0.001, 'eval_epoch': 1, 'eval_step': -1, 'test_every_epoch': 1, 'log_per_iter': 50, 'best_val': 'min:loss', 'zero_grad_to_none': False, 'use_amp': True, 'lr_type': 'constant', 'lr': 1e-05, 'lr_min': 0.0001, 'gamma': 0.1, 'milestones': (120, 180), 'lr_power': 0.9, 'port': 20001, 'progress_bar': True, 'rand_seed': 0, 'empty_cache': 50, 'expand_ckpt': False, 'step_size': (160, 240), 'resolution': 256, 'save_sdf': False, 'sdf_scale': 0.9}), 'DATA': CfgNode({'train': CfgNode({'name': 'objaverse', 'disable': False, 'pin_memory': True, 'depth': 8, 'full_depth': 3, 'orient_normal': '', 'distort': False, 'scale': 0.0, 'uniform': False, 'jitter': 0.0, 'interval': (1, 1, 1), 'angle': (

In [3]:
device = 'cuda'

model = MAREncoderDecoder(vqvae_config=flags.MODEL.VQVAE, **flags.MODEL.GPT)
vqvae = build_vae_model(flags.MODEL.VQVAE)
text_encoder = TextEncoder(flags.MODEL.GPT.condition_encoder)

In [4]:
vqvae_checkpoint = torch.load(flags.MODEL.vqvae_ckpt, weights_only=True, map_location="cpu")
vqvae.load_state_dict(vqvae_checkpoint)
print("Load VQVAE from", flags.MODEL.vqvae_ckpt)

Load VQVAE from saved_ckpt/vqvae_objv_huge_bsq32_flip0.5.pth


In [5]:
ar_checkpoint = os.path.join(log_path, 'checkpoints/00018.model.pth')
model_checkpoint = torch.load(ar_checkpoint, map_location="cpu")
model.load_state_dict(model_checkpoint)
print("Load MAR from", ar_checkpoint)

  model_checkpoint = torch.load(ar_checkpoint, map_location="cpu")


Load MAR from logs/objaverse/mar_text_bvflip5_expand1152/checkpoints/00018.model.pth


In [6]:
model = model.to(device)
vqvae = vqvae.to(device)
text_encoder = text_encoder.to(device)

In [8]:
text = '3D palm tree model.'
export_path = f'results-inference/{text}'
os.makedirs(os.path.join(log_path, export_path))
# Save the text:
with open(os.path.join(log_path, export_path, f"input.txt"), "w") as f:
    f.write(text + '\n')

depth = flags.DATA.test.depth
full_depth = flags.DATA.test.full_depth
depth_stop = flags.MODEL.depth_stop
model.num_iters = [64, 128, 128, 256]
num_gen = 4

for i in range(num_gen):
    with torch.no_grad():
        condition = text_encoder(text, device=device)
        octree_out = ocnn.octree.init_octree(
            depth=depth,
            full_depth=full_depth,
            batch_size=1,
            device=device,
        )
        with torch.autocast('cuda', enabled=flags.SOLVER.use_amp):
            octree_out, vq_code = model.generate(
                octree=octree_out,
                depth_low=full_depth,
                depth_high=depth_stop,
                vqvae=vqvae,
                condition=condition,
                cfg_scale=3.0,
            )
    # Export octrees
    for d in range(full_depth+1, depth_stop+1):
        utils.export_octree(octree_out, d, os.path.join(
            log_path, export_path), index=f'octree_{d}')

    # Decode the mesh
    for d in range(depth_stop, depth):
        split_zero_d = torch.zeros(
            octree_out.nnum[d], device=octree_out.device).long()
        octree_out.octree_split(split_zero_d, d)
        octree_out.octree_grow(d + 1)
    doctree_out = OctreeD(octree_out)
    with torch.no_grad():
        output = vqvae.decode_code(
            vq_code, depth_stop, doctree_out,
            copy.deepcopy(doctree_out), update_octree=True)

    # extract the mesh
    utils.create_mesh(
        output['neural_mpu'],
        os.path.join(log_path, export_path, f"output.obj"),
        size=flags.SOLVER.resolution,
        level=0.002, clean=True,
        bbmin=-flags.SOLVER.sdf_scale,
        bbmax=flags.SOLVER.sdf_scale,
        mesh_scale=flags.DATA.test.points_scale,
        save_sdf=flags.SOLVER.save_sdf)
    
    

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 64/64 [00:10<00:00,  5.98it/s]
100%|██████████| 128/128 [00:22<00:00,  5.81it/s]
100%|██████████| 128/128 [01:26<00:00,  1.48it/s]
  2%|▏         | 6/256 [00:20<14:17,  3.43s/it]


KeyboardInterrupt: 