In [1]:
import sys
sys.path.append('../utils')
from eval import load_diffusion, compute_base_grid
import time
from train_flow_matching import *
import mesh_tools as mt

In [2]:
# Grid resolution forwarded as `base_res` when the base grid is computed.
BASE_RES = 16
# Batch size forwarded as `eval_batch_size` to the generation pipeline.
BATCH_SIZE = 10
# Torch device string (CUDA GPU).
device = 'cuda'
# Directory forwarded as `src` to `load_diffusion` to locate the trained models.
SRC = '../../experiments_fm'

In [3]:
def p1_to_flow(model, XT, T):
    """Turn the model's output at (XT, T) into a flow (velocity) estimate.

    The model is evaluated on ``XT`` at per-row times ``T``; its output
    features are then overwritten in place with
    ``(output - XT) / (1 - T)``, i.e. the constant velocity that would carry
    ``XT`` to the model's output over the remaining time ``1 - T``.

    Args:
        model: callable ``model(XT, T)`` returning an object exposing
            ``.jdata`` and a writable ``.feature.jdata``.
        XT: current state; only ``XT.jdata`` is read.
        T: 1-D tensor of times, one entry per row of ``XT.jdata``.

    Returns:
        The model's output object, with ``feature.jdata`` replaced by the flow.

    Note:
        The expression divides by ``1 - T``, so ``T == 1`` yields inf/nan.
    """
    prediction = model(XT, T)
    displacement = prediction.jdata - XT.jdata
    # T is per-row; add a trailing axis so it broadcasts over feature columns.
    time_left = 1 - T[:, None]
    prediction.feature.jdata = displacement / time_left
    return prediction
    
@torch.no_grad()
def model_step(model, XT, t_start, t_end):
    """Advance ``XT`` from ``t_start`` to ``t_end`` with one midpoint step.

    Two flow evaluations are made through ``p1_to_flow``: one at
    ``(XT, t_start)`` to reach the half-way state, and one at that half-way
    state and time to take the full step starting again from ``XT``.

    Args:
        model: network forwarded to ``p1_to_flow``.
        XT: state at ``t_start``; ``XT.jdata`` is read but not modified here.
        t_start: scalar tensor, start time of the step.
        t_end: scalar tensor, end time of the step.

    Returns:
        The object returned by the second ``p1_to_flow`` call, whose
        ``feature.jdata`` holds the state at ``t_end``.
    """
    n_rows = len(XT.jdata)
    # Broadcast the scalar times to one entry per feature row.
    row_t0 = t_start.view(1).expand(n_rows)
    row_t1 = t_end.view(1).expand(n_rows)
    dt = row_t1 - row_t0
    half_dt = dt / 2.

    # Stage 1: flow at the start point, then half a step forward from XT.
    midpoint = p1_to_flow(model, XT, row_t0)
    midpoint.feature.jdata *= half_dt[:, None]
    midpoint.feature.jdata += XT.jdata

    # Stage 2: flow at the midpoint, then a full step forward from XT.
    stepped = p1_to_flow(model, midpoint, row_t0 + half_dt)
    stepped.feature.jdata *= dt[:, None]
    stepped.feature.jdata += XT.jdata

    return stepped

@torch.no_grad()
def sample(model, XB, n_steps):
    """Integrate ``XB`` from t=0 to t=1 using ``n_steps`` uniform time steps.

    The first ``n_steps - 1`` intervals are advanced with ``model_step``.
    The last interval is covered by evaluating the model once at the
    second-to-last time and returning its output directly, so no flow is
    ever evaluated at t=1.

    Args:
        model: network called as ``model(X, T)``; switched to eval mode here.
        XB: initial state; must expose ``.jdata`` as a torch tensor.
        n_steps: number of time intervals, must be >= 1.

    Returns:
        The model's output for the final interval.

    Raises:
        ValueError: if ``n_steps`` is smaller than 1.
    """
    if n_steps < 1:
        # Otherwise this would surface later as an IndexError on time_steps[-2].
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    # Build the schedule on the data's own device instead of relying on a
    # notebook-level global that may not match where XB actually lives.
    time_steps = torch.linspace(0, 1.0, n_steps + 1, device=XB.jdata.device)
    model.eval()
    for i in tqdm(range(n_steps - 1)):
        XB = model_step(model, XB, time_steps[i], time_steps[i + 1])
    # Final interval: take the model output at the second-to-last time as the result.
    XB = model(XB, time_steps[-2].view(1).expand(len(XB.jdata)))
    return XB


In [4]:
def generate_input(generated_X, sparse_fm):
    """Build the starting state for the next level from the previous level's output.

    ``generated_X`` and its trilinear upsampling are passed through
    ``sparse_fm.model_upsampler`` (in eval mode, without gradients). The
    result's features are then replaced by ``sparse_fm.add_x0_noise(result)``.

    Args:
        generated_X: previous level's output; must provide
            ``trilinear_upsample()``.
        sparse_fm: object providing ``model_upsampler`` and ``add_x0_noise``.

    Returns:
        The upsampler output with ``feature.jdata`` overwritten by the value
        returned from ``add_x0_noise``.
    """
    with torch.no_grad():
        sparse_fm.model_upsampler.eval()
        input_X = sparse_fm.model_upsampler(
            generated_X, generated_X.trilinear_upsample()).detach()
        # A per-row zero "times" tensor used to be built here but was never
        # consumed by anything, so it has been removed.
        input_X.feature.jdata = sparse_fm.add_x0_noise(input_X)
        return input_X

def generate_level(generated_X, i, example_mesh_name, src, n_steps):
    """Generate level ``i`` from the previous level's output.

    Loads the level-``i`` model with ``load_diffusion``, builds the starting
    state with ``generate_input`` and integrates it with ``sample``.

    Args:
        generated_X: output of the previous level.
        i: level index passed to ``load_diffusion``.
        example_mesh_name: example name passed to ``load_diffusion``.
        src: source directory passed to ``load_diffusion``.
        n_steps: number of integration steps passed to ``sample``.

    Returns:
        ``DiffusionTensor.from_vdb(result).remove_mask()`` for the sampled result.
    """
    sparse_fm = load_diffusion(example_mesh_name, i, src)
    sparse_fm.eval()
    new_XT = generate_input(generated_X, sparse_fm)
    generated_X = sample(sparse_fm.model, new_XT, n_steps)
    return DiffusionTensor.from_vdb(generated_X).remove_mask()


In [5]:
def compute_all_generations(example_mesh_name, src, base_res, max_level=3, eval_batch_size=10, features=10, n_steps=100, X0G=None):
    """Run generation for levels ``0 .. max_level`` and collect every level's output.

    Level 0 is sampled from ``torch.randn`` features placed on a base grid;
    each later level is produced by ``generate_level`` from the previous one.

    Args:
        example_mesh_name: example name passed to ``load_diffusion`` and
            ``compute_base_grid``.
        src: source directory passed to ``load_diffusion``.
        base_res: resolution passed to ``compute_base_grid``.
        max_level: index of the last level to generate (inclusive).
        eval_batch_size: batch size passed to ``compute_base_grid``.
        features: feature size passed to ``grid_to_VDB`` as ``[features]``.
        n_steps: number of integration steps per level.
        X0G: optional precomputed base grid; computed when ``None``.

    Returns:
        A list with ``max_level + 1`` entries, one generated tensor per level,
        ordered from level 0 upward.
    """
    generated_Xs = []
    sparse_fm = load_diffusion(example_mesh_name, 0, src)
    sparse_fm.eval()
    if X0G is None:
        X0G = compute_base_grid(example_mesh_name, eval_batch_size, base_res)
    # Level 0 starts from random-normal features on the base grid.
    X = grid_to_VDB(X0G, torch.randn, [features])
    generated_X = sample(sparse_fm.model, X, n_steps)
    generated_X = DiffusionTensor.from_vdb(generated_X).remove_mask()

    generated_Xs.append(generated_X)

    # Each subsequent level is conditioned on the level generated just before it.
    for i in range(1, max_level + 1):
        generated_X = generate_level(
            generated_X, i, example_mesh_name, src, n_steps)
        generated_Xs.append(generated_X)
    return generated_Xs

In [6]:
# Generate levels 0..3 of the "canyon" example with 10 integration steps per level.
with torch.no_grad():
    GX = compute_all_generations(
        "canyon",
        SRC,
        BASE_RES,
        max_level=3,
        eval_batch_size=BATCH_SIZE,
        n_steps=10,
    )

100%|██████████| 9/9 [00:00<00:00, 50.17it/s]
100%|██████████| 9/9 [00:00<00:00, 108.31it/s]
100%|██████████| 9/9 [00:00<00:00, 76.41it/s]
100%|██████████| 9/9 [00:00<00:00, 19.37it/s]


In [8]:
# GX[-1] is the last entry returned by compute_all_generations, i.e. the output
# of the final level that was generated.
tens = GX[-1]

# Visualise the entries at index 0 and 1 of that output
# (presumably the first two batch samples — TODO confirm).
for ind in range(2):
    # Re-wrap the indexed grid/feature pair as a standalone DiffusionTensor.
    disp_d = DiffusionTensor(tens.grid[ind], tens.feature[ind]).remove_mask()
    # NOTE(review): .12 looks like a point-size/scale argument for the
    # coloured point-cloud view — confirm against colored_PC.
    disp_d.get_global().colored_PC(.12)

    # Alternative mesh rendering coloured from feature columns 6:9 (disabled):
    # c=(disp_d.feature.jdata[:, 6:9].cpu().detach().numpy()+2)/4
    # plot(*grid_to_mesh(disp_d.grid, colors=c), shading={'wireframe':True})

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(4.7743320…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-7.227063…