In [28]:
import torch
import yaml
from utils.builder import build_vae_model
from models.condition import ImageEncoder
from models.mar import MAREncoderDecoder
from thsolver.config import parse_args
import sys 
import matplotlib.pyplot as plt
import ocnn
from PIL import Image
import os
import copy
from ognn.octreed import OctreeD
from utils import utils
torch.cuda.set_device(7)

In [29]:
log_path = '/mnt/sdc/wangrh/workspace/OctAR-solver/logs/sketch/airplane_p1024_d8'

sys.argv = ['']  # Reset sys.argv
sys.argv.extend(['--config', log_path + '/all_configs.yaml'])
flags = parse_args(backup=False)
flags

CfgNode({'BASE': ['configs/shapenet_octar.yaml'], 'SOLVER': CfgNode({'alias': '', 'gpu': (0, 1, 2, 3), 'run': 'generate', 'logdir': 'logs/airplane/mar_p1024_d8_image', 'ckpt': '', 'ckpt_num': 10, 'type': 'adamw', 'weight_decay': 0.01, 'clip_grad': -1.0, 'max_epoch': 400, 'warmup_epoch': 20, 'warmup_init': 0.001, 'eval_epoch': 1, 'eval_step': -1, 'test_every_epoch': 1, 'log_per_iter': 50, 'best_val': 'min:loss', 'zero_grad_to_none': False, 'use_amp': True, 'lr_type': 'constant', 'lr': 5e-05, 'lr_min': 0.0001, 'gamma': 0.1, 'milestones': (120, 180), 'lr_power': 0.9, 'port': 10001, 'progress_bar': True, 'rand_seed': 0, 'empty_cache': 50, 'expand_ckpt': False, 'step_size': (160, 240), 'resolution': 256, 'save_sdf': False, 'sdf_scale': 0.9}), 'DATA': CfgNode({'train': CfgNode({'name': 'shapenet_vae', 'disable': False, 'pin_memory': True, 'depth': 8, 'full_depth': 3, 'orient_normal': '', 'distort': False, 'scale': 0.0, 'uniform': False, 'jitter': 0.0, 'interval': (1, 1, 1), 'angle': (180, 18

In [30]:
device = 'cuda'

model = MAREncoderDecoder(vqvae_config=flags.MODEL.VQVAE, **flags.MODEL.GPT)
vqvae = build_vae_model(flags.MODEL.VQVAE)
sketch_encoder = ImageEncoder(flags.MODEL.GPT.condition_encoder)

In [31]:
vqvae_checkpoint = torch.load(flags.MODEL.vqvae_ckpt, weights_only=True, map_location="cpu")
vqvae.load_state_dict(vqvae_checkpoint)
print("Load VQVAE from", flags.MODEL.vqvae_ckpt)

Load VQVAE from saved_ckpt/vqvae_huge_im5_bsq32_200epoch.pth


In [32]:
ar_checkpoint = log_path + '/best_model.pth'
model_checkpoint = torch.load(ar_checkpoint, map_location="cpu")
model.load_state_dict(model_checkpoint)
print("Load MAR from", ar_checkpoint)

Load MAR from /mnt/sdc/wangrh/workspace/OctAR-solver/logs/sketch/airplane_p1024_d8/best_model.pth


In [33]:
model = model.to(device)
vqvae = vqvae.to(device)
sketch_encoder = sketch_encoder.to(device)

In [34]:
depth = flags.DATA.test.depth
full_depth = flags.DATA.test.full_depth
depth_stop = flags.MODEL.depth_stop

In [35]:
def generate_by_sketch(sketch_path, sketchname='default'):
    sketch = Image.open(sketch_path)

    with torch.no_grad():
        condition = sketch_encoder(sketch, device=device)
        octree_out = ocnn.octree.init_octree(
            depth=depth,
            full_depth=full_depth,
            batch_size=1,
            device=device,
        )
        with torch.autocast('cuda', enabled=flags.SOLVER.use_amp):
            octree_out, vq_code = model.generate(
                octree=octree_out,
                depth_low=full_depth,
                depth_high=depth_stop,
                vqvae=vqvae,
                condition=condition,
                cfg_scale=None
            )

    export_path = f'results-inference/{sketchname}'

    index = 'output'

    # Export octrees
    for d in range(full_depth+1, depth_stop+1):
        utils.export_octree(octree_out, d, os.path.join(
            log_path ,export_path), index=f'octree_{d}')

    # Decode the mesh
    for d in range(depth_stop, depth):
        split_zero_d = torch.zeros(
            octree_out.nnum[d], device=octree_out.device).long()
        octree_out.octree_split(split_zero_d, d)
        octree_out.octree_grow(d + 1)
    doctree_out = OctreeD(octree_out)
    with torch.no_grad():
        output = vqvae.decode_code(
            vq_code, depth_stop, doctree_out,
            copy.deepcopy(doctree_out), update_octree=True)

    # extract the mesh
    utils.create_mesh(
        output['neural_mpu'],
        os.path.join(log_path, export_path, f"output.obj"),
        size=flags.SOLVER.resolution,
        level=0.002, clean=True,
        bbmin=-flags.SOLVER.sdf_scale,
        bbmax=flags.SOLVER.sdf_scale,
        mesh_scale=flags.DATA.test.points_scale,
        save_sdf=flags.SOLVER.save_sdf)

    # Save the sketch image
    sketch.save(os.path.join(log_path, export_path, f"input.png"))

In [None]:
sketch_paths = [f'/mnt/sdc/wangrh/workspace/OctAR-solver/logs/sketch/airplane_p1024_d8/results/images/{i}.png' for i in range(1, 50)]

for i, sketch_path in enumerate(sketch_paths):
    print(f"Generating {i+1}/{len(sketch_paths)}")
    try:
        generate_by_sketch(sketch_path, f"sketch_{i+1}")
    except:
        pass
    print("Done")

Generating 1/49
Done
Generating 2/49
Done
Generating 3/49
Done
Generating 4/49
Done
Generating 5/49


100%|██████████| 64/64 [00:04<00:00, 15.24it/s]
100%|██████████| 128/128 [00:08<00:00, 15.02it/s]
100%|██████████| 128/128 [00:13<00:00,  9.19it/s]
100%|██████████| 256/256 [00:28<00:00,  9.11it/s]
100%|██████████| 1/1 [00:00<00:00, 83.92it/s]
100%|██████████| 1/1 [00:00<00:00, 39.99it/s]
100%|██████████| 1/1 [00:00<00:00, 12.84it/s]


Done
Generating 6/49


100%|██████████| 64/64 [00:04<00:00, 15.38it/s]
100%|██████████| 128/128 [00:08<00:00, 14.79it/s]
100%|██████████| 128/128 [00:09<00:00, 14.21it/s]
100%|██████████| 256/256 [00:23<00:00, 10.89it/s]
100%|██████████| 1/1 [00:00<00:00, 68.04it/s]
100%|██████████| 1/1 [00:00<00:00, 39.97it/s]
100%|██████████| 1/1 [00:00<00:00, 12.80it/s]


Done
Generating 7/49


100%|██████████| 64/64 [00:04<00:00, 15.38it/s]
100%|██████████| 128/128 [00:08<00:00, 15.02it/s]
100%|██████████| 128/128 [00:08<00:00, 14.35it/s]
100%|██████████| 256/256 [00:23<00:00, 10.88it/s]
100%|██████████| 1/1 [00:00<00:00, 96.43it/s]
100%|██████████| 1/1 [00:00<00:00, 43.26it/s]
100%|██████████| 1/1 [00:00<00:00, 13.48it/s]


Done
Generating 8/49


100%|██████████| 64/64 [00:04<00:00, 15.44it/s]
100%|██████████| 128/128 [00:08<00:00, 14.84it/s]
100%|██████████| 128/128 [00:08<00:00, 14.24it/s]
100%|██████████| 256/256 [00:24<00:00, 10.59it/s]
100%|██████████| 1/1 [00:00<00:00, 81.59it/s]
100%|██████████| 1/1 [00:00<00:00, 36.75it/s]
100%|██████████| 1/1 [00:00<00:00, 12.39it/s]


Done
Generating 9/49
Done
Generating 10/49


100%|██████████| 64/64 [00:04<00:00, 15.44it/s]
100%|██████████| 128/128 [00:08<00:00, 14.84it/s]
100%|██████████| 128/128 [00:08<00:00, 14.49it/s]
100%|██████████| 256/256 [00:23<00:00, 10.95it/s]
100%|██████████| 1/1 [00:00<00:00, 56.91it/s]
100%|██████████| 1/1 [00:00<00:00, 42.37it/s]
100%|██████████| 1/1 [00:00<00:00, 12.78it/s]


Done
Generating 11/49


100%|██████████| 64/64 [00:04<00:00, 15.49it/s]
100%|██████████| 128/128 [00:08<00:00, 14.81it/s]
100%|██████████| 128/128 [00:09<00:00, 14.11it/s]
100%|██████████| 256/256 [00:24<00:00, 10.55it/s]
100%|██████████| 1/1 [00:00<00:00, 95.73it/s]
100%|██████████| 1/1 [00:00<00:00, 37.99it/s]
100%|██████████| 1/1 [00:00<00:00, 12.89it/s]


Done
Generating 12/49


100%|██████████| 64/64 [00:04<00:00, 15.35it/s]
100%|██████████| 128/128 [00:08<00:00, 15.04it/s]
100%|██████████| 128/128 [00:08<00:00, 14.35it/s]
100%|██████████| 256/256 [00:23<00:00, 10.73it/s]
100%|██████████| 1/1 [00:00<00:00, 97.53it/s]
100%|██████████| 1/1 [00:00<00:00, 42.96it/s]
100%|██████████| 1/1 [00:00<00:00, 15.06it/s]


Done
Generating 13/49


100%|██████████| 64/64 [00:04<00:00, 15.29it/s]
100%|██████████| 128/128 [00:08<00:00, 14.62it/s]
100%|██████████| 128/128 [00:09<00:00, 14.10it/s]
100%|██████████| 256/256 [00:25<00:00,  9.87it/s]
100%|██████████| 1/1 [00:00<00:00, 88.36it/s]
100%|██████████| 1/1 [00:00<00:00, 38.81it/s]
100%|██████████| 1/1 [00:00<00:00,  9.62it/s]


Done
Generating 14/49
Done
Generating 15/49


100%|██████████| 64/64 [00:04<00:00, 15.23it/s]
100%|██████████| 128/128 [00:08<00:00, 14.69it/s]
100%|██████████| 128/128 [00:09<00:00, 13.04it/s]
 42%|████▏     | 108/256 [00:15<00:21,  6.81it/s]


Done
Generating 16/49


100%|██████████| 64/64 [00:04<00:00, 15.30it/s]
 91%|█████████ | 116/128 [00:07<00:00, 14.79it/s]