In [None]:
#%matplotlib widget

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#%pip install lab_black
#%load_ext lab_black
%cd ..

In [None]:
from pathlib import Path

import torch as th
import torch.nn.functional as F
import numpy as np
import yaml
from easydict import EasyDict

from src.utils import instantiate_from_config, get_device
from src.utils.vis import save_sdf_as_mesh

In [None]:
# Turn autograd off globally for the notebook session.
th.set_grad_enabled(False)
# Pick the compute device through the project helper; the bare name below displays it.
device = get_device()
device

# Load Pretrained Models

In [None]:
# Config (YAML) and checkpoint paths for the two models used below.
# Paths are relative, so they depend on the working directory set by `%cd ..` earlier.
# gen32_*: first-stage model sampled at 32^3; sr64_*: second-stage model sampled at 64^3.
gen32_args_path = "config/gen32/chair.yaml"
gen32_ckpt_path = "results/gen32/chair.pth"
sr64_args_path = "config/sr32_64/chair.yaml"
sr64_ckpt_path = "results/sr32_64/chair.pth"

In [None]:
# Parse both YAML configs; wrapping them in EasyDict allows attribute access (e.g. args1.model).
with open(gen32_args_path) as f:
    args1 = EasyDict(yaml.safe_load(f))
with open(sr64_args_path) as f:
    args2 = EasyDict(yaml.safe_load(f))

In [None]:
# Build the first-stage model from its config and restore the weights stored under "model_ema".
model1 = instantiate_from_config(args1.model)
# NOTE(review): th.load unpickles the file — only load checkpoints from a trusted source.
ckpt = th.load(gen32_ckpt_path, map_location=device)
model1.load_state_dict(ckpt["model_ema"])
model1 = model1.to(device)
model1.eval()
# Cell output: the training flag, as a quick check after eval().
model1.training

In [None]:
# Build the second-stage (super-resolution) model and restore the weights stored under "model".
# NOTE(review): the first-stage model is restored from "model_ema" while this one uses "model"
# — confirm the difference is intentional.
model2 = instantiate_from_config(args2.model)
# NOTE(review): th.load unpickles the file — only load checkpoints from a trusted source.
ckpt = th.load(sr64_ckpt_path, map_location=device)
model2.load_state_dict(ckpt["model"])
model2 = model2.to(device)
model2.eval()
# Cell output: the training flag, as a quick check after eval().
model2.training

In [None]:
# Instantiate one sampler per model from the `ddpm.valid` entry of its config.
ddpm_sampler1 = instantiate_from_config(args1.ddpm.valid, device=device)
ddpm_sampler2 = instantiate_from_config(args2.ddpm.valid, device=device)

# Move both samplers to the selected device.
ddpm_sampler1, ddpm_sampler2 = ddpm_sampler1.to(device), ddpm_sampler2.to(device)

In [None]:
# Preprocessors provide the standardize()/destandardize() calls used around every sampling step below.
preprocessor1 = instantiate_from_config(args1.preprocessor, device=device)
preprocessor2 = instantiate_from_config(args2.preprocessor, device=device)

# Generate Low-Resolution ($32^3$)

Generates one low-resolution sample (the batch dimension of the sampling shape below is 1; increase it to draw more).

In [None]:
from src.utils.utils import seed_everything
# Seed the random number generators through the project helper before any sampling is done.
seed_everything(40)

In [None]:
# Draw one single-channel 32x32x32 sample with the DDIM sampler (requested shape: (1, 1, 32, 32, 32)).
out1 = ddpm_sampler1.sample_ddim(model1, (1, 1, 32, 32, 32), show_pbar=True)

In [None]:
# Undo the preprocessor's standardization; from here on `out1` holds the destandardized sample.
# NOTE(review): this rebinds `out1`, so re-running the cell without re-sampling destandardizes twice.
out1 = preprocessor1.destandardize(out1)
out1.shape

In [None]:
# TODO: with mps I get same samples across batches
#th.allclose(*out1[:2]), th.allclose(*out1[1:3]), th.allclose(*out1[0:3:2])

In [None]:
from src.utils.vis import plot_sdfs
# Camera settings forwarded to plot_sdfs; the commented line is an alternative viewpoint.
#view_kwargs = {"azim": 57, "elev": 6, "roll": 0, "vertical_axis": "y"}
view_kwargs = {"azim": 30, "elev": 30, "roll": 0, "vertical_axis": "y"}
# One plot entry per element of the batch.
plot_sdfs(list(out1), view_kwargs=view_kwargs)

In [None]:
# save as an obj file
# for i, out in enumerate(out1):
#     save_sdf_as_mesh(f"gen32_{i}.obj", out, safe=True)

In [None]:
# Super-resolution stage: upsample the low-resolution sample to 64^3 (nearest neighbour)
# and standardize it so it can be used as the conditioning input of the second model.
lr_cond = F.interpolate(out1, (64, 64, 64), mode="nearest")
lr_cond = preprocessor2.standardize(lr_cond, 0)
# The sampler only sees a (x, t) denoiser; the conditioning volume is concatenated
# to the noisy input along the channel dimension before every model call.
out2 = ddpm_sampler2.sample_ddim(lambda x, t: model2(th.cat([lr_cond, x], 1), t), (out1.shape[0], 1, 64, 64, 64), show_pbar=True)

out2 = preprocessor2.destandardize(out2, 1)

#for i, out in enumerate(out2):
#    save_sdf_as_mesh(f"sr64_{i}.obj", out, safe=True)

# Title typo fixed ("origianal" -> "original").
plot_sdfs(list(out2), title="Super-resolution original samples")

In [None]:
# Test inversion
# Round trip: standardize the sample, invert it to a latent with invert_ddim, then
# run DDIM sampling again starting from that latent (passed as x_t).
out1_inv = ddpm_sampler1.sample_ddim(model1, x_t = ddpm_sampler1.invert_ddim(model1, preprocessor1.standardize(out1), show_pbar=True, debug_plot=True), show_pbar=True)

In [None]:
# Bring the reconstruction back to the same (destandardized) space as `out1`.
out1_inv = preprocessor1.destandardize(out1_inv)
out1_inv.shape

In [None]:
# save as an obj file
# for i, out in enumerate(out1_inv):
#     save_sdf_as_mesh(f"inv_gen32_{i}.obj", out, safe=True)

In [None]:
# compute norm difference
# Reconstruction error of the inversion round trip (norm over all elements of the difference).
th.norm(out1 - out1_inv)

In [None]:
# Show the original sample next to its reconstruction from the inverted latent.
plot_sdfs([out1, out1_inv], titles=["Original", "Predicted from inversion"])

In [None]:
# The reconstruction is not used again below; drop the reference.
del out1_inv

In [None]:
# Optimization 
def volume_estimates(sdfs, dx=1., dy=1., dz=1.):
    """Estimate the enclosed volume of every SDF in a batch.

    A voxel is counted when its SDF value is strictly negative. The indicator
    is computed with ``StraightThroughEstimator`` rather than a plain boolean
    mask so that the count can take part in backpropagation.

    Args:
        sdfs: Tensor of SDF values. The first dimension is kept; all remaining
            dimensions are summed over.
        dx, dy, dz: Scale factors (voxel edge lengths) multiplied onto the count.

    Returns:
        Tensor holding one volume estimate per entry of the first dimension.
    """
    occupancy = StraightThroughEstimator()(-sdfs)
    reduce_dims = list(range(1, sdfs.ndim))
    voxel_count = th.sum(occupancy, dim=reduce_dims)
    return voxel_count * dx * dy * dz

class STEFunction(th.autograd.Function):
    """Binary step function with a pass-through style gradient.

    forward: 1.0 where the input is strictly positive, 0.0 elsewhere.
    backward: returns the incoming gradient limited to [-1, 1]; the input
    values themselves play no role in the backward pass.
    """

    @staticmethod
    def forward(ctx, input):
        # Hard threshold at zero, converted from bool to float.
        is_positive = input > 0
        return is_positive.float()

    @staticmethod
    def backward(ctx, grad_output):
        # Explicit bounds equal to hardtanh's defaults.
        return F.hardtanh(grad_output, min_val=-1.0, max_val=1.0)

class StraightThroughEstimator(th.nn.Module):
    """Module wrapper that applies ``STEFunction`` to its input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Delegate to the custom autograd function.
        return STEFunction.apply(x)
 

def volume_estimates_loss_fn(xs, target_volumes, max_volume=1., grad_var_reg_weight=0, tot_variation_reg_weight=0):
    """Loss that pulls the estimated volume of ``xs`` towards ``target_volumes``.

    The base term is the mean squared error between ``volume_estimates(xs)``
    and ``target_volumes``, both divided by ``max_volume``. Two optional
    regularizers are added only when their weight is non-zero:

    * ``grad_var_reg_weight``: per-sample variance of ``xs.grad`` (over all
      dimensions but the first).
    * ``tot_variation_reg_weight``: ``tot_variation(xs)``.

    Previously both regularizers were always evaluated. With the default zero
    weights this raised a TypeError whenever ``xs.grad`` was ``None`` and let
    NaN/Inf values in the gradient leak into the loss through ``0 * nan``.

    Args:
        xs: Batch of SDFs, first dimension is the batch.
        target_volumes: Desired volume per batch element, in the same units as
            ``volume_estimates``.
        max_volume: Normalization constant applied to both volumes.
        grad_var_reg_weight: Weight of the gradient-variance term. When
            non-zero, ``xs.grad`` must already be populated.
        tot_variation_reg_weight: Weight of the smoothness term.

    Returns:
        A scalar tensor when the gradient-variance term is disabled; with that
        term enabled the result has one entry per batch element (the scalar
        terms broadcast onto the per-sample variance), as before.
    """
    input_volumes = volume_estimates(xs) / max_volume
    target_volumes = target_volumes / max_volume
    loss = F.mse_loss(input_volumes, target_volumes)

    if grad_var_reg_weight != 0:
        # NOTE(review): xs.grad is normally detached from the autograd graph, so this
        # term changes the loss value but likely contributes no gradient — confirm intent.
        loss = loss + grad_var_reg_weight * th.var(xs.grad, dim=list(range(1, xs.ndim)))
    if tot_variation_reg_weight != 0:
        loss = loss + tot_variation_reg_weight * tot_variation(xs)
    return loss


def tot_variation(sdfs):
    """Smoothness penalty over the three spatial axes of a 5-D tensor.

    For each of the dimensions 2, 3 and 4 the difference between neighbouring
    entries is squared and summed over the whole tensor (batch and channel
    included); the three sums are added. Note that the differences are
    squared, not taken in absolute value.

    Args:
        sdfs: Tensor indexed as (batch, channel, x, y, z).

    Returns:
        Scalar tensor with the summed squared neighbour differences.
    """
    def _squared_step_sum(axis):
        # Build index tuples selecting "all but first" and "all but last" along `axis`.
        upper = [slice(None)] * sdfs.ndim
        lower = [slice(None)] * sdfs.ndim
        upper[axis] = slice(1, None)
        lower[axis] = slice(None, -1)
        step = sdfs[tuple(upper)] - sdfs[tuple(lower)]
        return step.pow(2).sum()

    return _squared_step_sum(2) + _squared_step_sum(3) + _squared_step_sum(4)


In [None]:
# Standardized copy of the sample; this is the x_0 handed to the latent-optimization calls below.
out1_std = preprocessor1.standardize(out1)

In [None]:
# Summary statistics of the standardized sample (mean, min, max, variance).
out1_std.mean(), out1_std.min(), out1_std.max(), out1_std.var()

In [None]:
# Same statistics for the destandardized sample, for comparison.
out1.mean(), out1.min(), out1.max(), out1.var()

In [None]:
# Baseline edit: add a constant offset to every SDF value and report the resulting volume.
# A negative offset pushes more values below zero, i.e. more voxels are counted as inside.
# Titles show the volume estimate before/after and the relative change.
shift = -0.062
plot_sdfs(
    [out1, out1 + shift], 
    titles=[
        [f"Original \n $V={volume_estimates(out_i).item()}$" for out_i in out1], 
        [f"Shifting the sdfs by ${shift}$ \n $V={volume_estimates(out_i_shift).item()}$ (incr.: ${((volume_estimates(out_i_shift).item() - volume_estimates(out_i).item())/volume_estimates(out_i).item()):.2f}$)" for out_i_shift, out_i in zip(out1+shift, out1)], 
])

In [None]:
# Volume-editing run: optimize the latent at DDIM step index 8, with the volume objective
# evaluated at that same noise level (tgt_noise_level="t_optim"); target = 1.7x the volume
# estimate of the standardized sample, no weight decay.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = 0.7
t_optim_idx = 8
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment), "max_volume":1},
    t_optim_idx=t_optim_idx, 
    tgt_noise_level="t_optim",
    opt_kwargs={"weight_decay": 0},
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample. The volumes in the titles are computed on the standardized
# tensors (out1_std / x_edited), while the plotted volumes are destandardized.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        #th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        #[f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-editing run: same setup as the t_optim_idx=8 run but optimizing the latent at
# DDIM step index 5 (objective at the "t_optim" noise level, target = 1.7x, no weight decay).
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = 0.7
t_optim_idx = 5
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment), "max_volume":1},
    t_optim_idx=t_optim_idx, 
    tgt_noise_level="t_optim",
    opt_kwargs={"weight_decay": 0},
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        #th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        #[f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-editing run: optimize the latent at DDIM step index 15, objective at the
# "t_optim" noise level, target = 1.7x the current volume estimate, no weight decay.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = 0.7
t_optim_idx = 15
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment), "max_volume":1},
    t_optim_idx=t_optim_idx, 
    tgt_noise_level="t_optim",
    opt_kwargs={"weight_decay": 0},
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        #th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        #[f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-editing run: optimize the latent at DDIM step index 3 with the objective
# evaluated on the t=0 prediction (tgt_noise_level="zero"); target = 1.4x the current
# volume estimate, lr=1e-2, no weight decay, loss_threshold=20.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = 0.4
t_optim_idx = 3
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment)},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    opt_kwargs={"lr":1e-2, "weight_decay": 0},
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-editing run: latent at DDIM step index 3, objective on the t=0 prediction,
# target = 1.7x, with the gradient-variance regularizer enabled (weight 10);
# lr=1e-2, no weight decay, loss_threshold=20.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = 0.7
t_optim_idx = 3
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment), "grad_var_reg_weight": 10},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    opt_kwargs={"lr":1e-2, "weight_decay": 0.},
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-editing run: latent at DDIM step index 3, objective on the t=0 prediction,
# target = 1.7x, gradient-variance regularizer weight raised to 1000;
# lr=1e-2, no weight decay, loss_threshold=50.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = 0.7
t_optim_idx = 3
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment), "grad_var_reg_weight": 1000},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=50,
    opt_kwargs={"lr":1e-2, "weight_decay": 0.},
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-editing run: latent at DDIM step index 3, objective on the t=0 prediction,
# target = 1.7x, with the tot_variation smoothness regularizer enabled (weight 1);
# lr=1e-2, no weight decay, loss_threshold=50.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = 0.7
t_optim_idx = 3
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment), "tot_variation_reg_weight": 1},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=50,
    opt_kwargs={"lr":1e-2, "weight_decay": 0.},
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-editing run: latent at DDIM step index 3, objective on the t=0 prediction,
# target = 1.7x, no extra regularizer in the loss; optimizer weight decay 1e-3, lr=1e-2,
# loss_threshold=20.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = 0.7
t_optim_idx = 3
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment)},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    opt_kwargs={"lr":1e-2, "weight_decay": 1e-3},
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-editing run: latent at DDIM step index 3, objective on the t=0 prediction,
# target = 1.7x; optimizer weight decay 1e-2 with lr=1e-2 (the commented-out variant
# with lr=1e-3 was noted as too slow), loss_threshold=20.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = 0.7
t_optim_idx = 3
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment)},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    #opt_kwargs={"lr":1e-3, "weight_decay": 1e-2}, # AdamW default # Too slow with lr=1e-3
    opt_kwargs={"lr":1e-2, "weight_decay": 1e-2}, 
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-editing run: latent at DDIM step index 3, objective on the t=0 prediction,
# target = 1.7x; smaller learning rate (1e-3), no weight decay, and grad_clip_value=2.
# passed to the sampler; loss_threshold=20.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = 0.7
t_optim_idx = 3
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment)},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    opt_kwargs={"lr":1e-3, "weight_decay": 0.},
    grad_clip_value=2.,
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

Note: the run above (`lr=1e-3` with `grad_clip_value=2.`) is too slow.

In [None]:
# Value range and moments of the standardized sample (min, max, mean, variance).
out1_std.min(), out1_std.max(), out1_std.mean(), out1_std.var()

In [None]:
# Same for the destandardized sample; useful for choosing the constant shifts tried below.
out1.min(), out1.max(), out1.mean(), out1.var()

In [None]:
# Baseline edit with a constant offset of -0.049 added to every SDF value.
# Titles show the volume estimate before/after the shift and the relative change.
shift = -0.049
plot_sdfs(
    [out1, out1 + shift], 
    titles=[
        [f"Original \n $V={volume_estimates(out_i).item()}$" for out_i in out1], 
        [f"Shifting the sdfs by ${shift}$ \n $V={volume_estimates(out_i_shift).item()}$ (incr.: ${((volume_estimates(out_i_shift).item() - volume_estimates(out_i).item())/volume_estimates(out_i).item()):.2f}$)" for out_i_shift, out_i in zip(out1+shift, out1)], 
])

In [None]:
# Volume-editing run with a larger target: 2.3x the current volume estimate.
# Latent at DDIM step index 3, objective on the t=0 prediction, lr=1e-2,
# weight decay 1e-2, loss_threshold=20.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = 1.3
t_optim_idx = 3
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment)},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    #opt_kwargs={"lr":1e-3, "weight_decay": 1e-2}, # AdamW default # Too slow with lr=1e-3
    opt_kwargs={"lr":1e-2, "weight_decay": 1e-2}, 
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Baseline edit with a constant offset of -0.0621 added to every SDF value.
# Titles show the volume estimate before/after the shift and the relative change.
shift = -0.0621
plot_sdfs(
    [out1, out1 + shift], 
    titles=[
        [f"Original \n $V={volume_estimates(out_i).item()}$" for out_i in out1], 
        [f"Shifting the sdfs by ${shift}$ \n $V={volume_estimates(out_i_shift).item()}$ (incr.: ${((volume_estimates(out_i_shift).item() - volume_estimates(out_i).item())/volume_estimates(out_i).item()):.2f}$)" for out_i_shift, out_i in zip(out1+shift, out1)], 
])

In [None]:
# Baseline edit in the opposite direction: a positive offset raises every SDF value,
# so fewer voxels fall below zero and the volume estimate shrinks.
shift = +0.055
plot_sdfs(
    [out1, out1 + shift], 
    titles=[
        [f"Original \n $V={volume_estimates(out_i).item()}$" for out_i in out1], 
        [f"Shifting the sdfs by ${shift}$ \n $V={volume_estimates(out_i_shift).item()}$ (incr.: ${((volume_estimates(out_i_shift).item() - volume_estimates(out_i).item())/volume_estimates(out_i).item()):.2f}$)" for out_i_shift, out_i in zip(out1+shift, out1)], 
])

In [None]:
# Volume-reduction run: target = 0.39x the current volume estimate (increment -0.61).
# Latent at DDIM step index 3, objective on the t=0 prediction, lr=1e-2,
# weight decay 1e-4, loss_threshold=20.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = -0.61
t_optim_idx = 3
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment)},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    #opt_kwargs={"lr":1e-3, "weight_decay": 1e-2}, # AdamW default # Too slow with lr=1e-3
    opt_kwargs={"lr":1e-2, "weight_decay": 1e-4}, 
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-reduction run (target = 0.39x) optimizing the latent at DDIM step index 4;
# objective on the t=0 prediction, lr=1e-2, weight decay 1e-2, loss_threshold=20.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = -0.61
t_optim_idx = 4
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment)},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    #opt_kwargs={"lr":1e-3, "weight_decay": 1e-2}, # AdamW default # Too slow with lr=1e-3
    opt_kwargs={"lr":1e-2, "weight_decay": 1e-2}, 
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-reduction run (target = 0.39x) optimizing the latent at DDIM step index 5;
# objective on the t=0 prediction, lr=1e-2, weight decay 1e-2, loss_threshold=20.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = -0.61
t_optim_idx = 5
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment)},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    #opt_kwargs={"lr":1e-3, "weight_decay": 1e-2}, # AdamW default # Too slow with lr=1e-3
    opt_kwargs={"lr":1e-2, "weight_decay": 1e-2}, 
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-reduction run (target = 0.39x) at DDIM step index 3; objective on the t=0
# prediction, lr=1e-2, weight decay 1e-3, loss_threshold=20. Unlike the neighbouring
# runs, the comparison plot also shows the absolute difference original vs. edited.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = -0.61
t_optim_idx = 3
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment), "max_volume":1},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    opt_kwargs={"lr":1e-2, "weight_decay": 1e-3},
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original, edited, and their element-wise absolute difference (all destandardized);
# title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[out1, preprocessor1.destandardize(x_edited), th.abs(out1- preprocessor1.destandardize(x_edited))], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# for i, out in enumerate(preprocessor1.destandardize(x_edited)):
#     save_sdf_as_mesh(f"gen32_{i}_v40%incr_3steps.obj", out, safe=True)

# lr_cond_edit = F.interpolate(preprocessor1.destandardize(x_edited), (64, 64, 64), mode="nearest")
# lr_cond_edit = preprocessor2.standardize(lr_cond_edit, 0)
# x_edited_sr = ddpm_sampler2.sample_ddim(lambda x, t: model2(th.cat([lr_cond_edit, x], 1), t), (out1.shape[0], 1, 64, 64, 64), show_pbar=True)

# x_edited_sr = preprocessor2.destandardize(x_edited_sr, 1)

# for i, out in enumerate(x_edited_sr):
#     save_sdf_as_mesh(f"sr64_{i}_v40%incr_3steps.obj", out, safe=True)

# plot_sdfs(
#     sdfs=[out2, x_edited_sr], 
#     title = f"Super-resolution results of optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
#     titles=[
#         [f"Original (SR) ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
#         [f"Edited (SR) ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
#         ]
# )

In [None]:
# Volume-reduction run (target = 0.39x) optimizing the latent at DDIM step index 6;
# objective on the t=0 prediction, lr=1e-2, weight decay 1e-2, loss_threshold=20.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = -0.61
t_optim_idx = 6
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment), "max_volume":1},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    opt_kwargs={"lr":1e-2, "weight_decay": 1e-2},
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-reduction run (target = 0.39x) optimizing the latent at DDIM step index 5
# without weight decay; objective on the t=0 prediction, lr=1e-2, loss_threshold=20.
# NOTE(review): unpacking `.values()` assumes the returned mapping is ordered as
# (edited sample, initial latent, optimized latent) — confirm against the sampler.
target_volume_increment = -0.61
t_optim_idx = 5
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment), "max_volume":1},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
#    max_opt_iters=50,
    loss_threshold=20,
    opt_kwargs={"lr":1e-2, "weight_decay": 0},
).values()

# Latent before and after optimization.
plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"]) 

# Original vs. edited sample; title volumes are computed on the standardized tensors.
plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        # [f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)