In [None]:
# Optional: uncomment to switch matplotlib to the interactive ipympl ("widget") backend.
#%matplotlib widget

In [None]:
# Make CUDA kernel launches synchronous so errors surface at the offending call.
# `!export ...` ran in a throw-away subshell, so the variable never reached this kernel;
# set it on the kernel's own environment instead. It must be set before CUDA is
# initialised (torch is only imported further down) to take effect.
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
# Automatically reload edited project modules (e.g. `src.*`) before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
#%pip install lab_black
#%load_ext lab_black
# Move one directory up so that the `src.*` imports and the relative `config/` and `results/`
# paths used below resolve.
# NOTE(review): not idempotent — re-running this cell moves up one more directory.
%cd ..

In [None]:
from pathlib import Path

import torch as th
import torch.nn.functional as F
import numpy as np
import yaml
from easydict import EasyDict

from src.utils import instantiate_from_config, get_device
from src.utils.vis import save_sdf_as_mesh, plot_sdfs

from tqdm import tqdm

In [None]:
# Pick the compute device; uncomment the override below to force CPU execution.
device = get_device()
# device = 'cpu'
# Inference notebook: autograd is switched off globally; cells that need gradients
# re-enable it locally with `th.enable_grad()`.
th.set_grad_enabled(False)
device

In [None]:
# Shape category shared by the 32^3 generator and the 32->64 super-resolution model.
# Change it in one place to evaluate another category's configs/checkpoints.
SHAPE_CATEGORY = "chair"

gen32_args_path = f"config/gen32/{SHAPE_CATEGORY}.yaml"
gen32_ckpt_path = f"results/gen32/{SHAPE_CATEGORY}.pth"
sr64_args_path = f"config/sr32_64/{SHAPE_CATEGORY}.yaml"
sr64_ckpt_path = f"results/sr32_64/{SHAPE_CATEGORY}.pth"

In [None]:
# Parse both YAML experiment configs; EasyDict gives attribute access (args1.model, args2.ddpm, ...).
args1 = EasyDict(yaml.safe_load(Path(gen32_args_path).read_text()))
args2 = EasyDict(yaml.safe_load(Path(sr64_args_path).read_text()))

In [None]:
# 32^3 generator: build it from the config and load the checkpoint's `model_ema` weights.
ckpt = th.load(gen32_ckpt_path, map_location=device)
model1 = instantiate_from_config(args1.model)
model1.load_state_dict(ckpt["model_ema"])
model1 = model1.to(device).eval()
model1.training

In [None]:
# 32->64 super-resolution model: build it from the config and load the checkpoint's `model` weights.
ckpt = th.load(sr64_ckpt_path, map_location=device)
model2 = instantiate_from_config(args2.model)
model2.load_state_dict(ckpt["model"])
model2 = model2.to(device).eval()
model2.training

In [None]:
# Validation-time diffusion samplers: (1) for the 32^3 generator, (2) for the 32->64 SR model.
ddpm_sampler1, ddpm_sampler2 = (
    instantiate_from_config(cfg.ddpm.valid, device=device).to(device)
    for cfg in (args1, args2)
)

In [None]:
# (De)standardization helpers matching each model's training data.
preprocessor1, preprocessor2 = (
    instantiate_from_config(cfg.preprocessor, device=device)
    for cfg in (args1, args2)
)

# Generate Low-Resolution ($32^3$)

Generates low-resolution samples: batches of 2 for the sampler comparisons and batches of 5 for the inversion tests below.

In [None]:
from diffusers import DDIMScheduler, DDIMInverseScheduler

# Project `model_mean_type` -> diffusers `prediction_type`.
prediction_type_map = {"x_0": "sample", "eps": "epsilon"}

# Forward and inverse DDIM schedulers configured identically. The betas are copied from the
# project's sampler (`trained_betas`) instead of being rebuilt from linear_start/linear_end,
# so that both implementations are meant to share one noise schedule.
ddim_scheduler, ddim_inverse_scheduler = (
    scheduler_cls(
        num_train_timesteps=args1.ddpm.valid.params.schedule_kwargs.n_timestep,
        trained_betas=ddpm_sampler1.betas.cpu(),
        beta_schedule=args1.ddpm.valid.params.schedule_kwargs.schedule,
        prediction_type=prediction_type_map[args1.ddpm.valid.params.model_mean_type],
        set_alpha_to_one=False,
    )
    for scheduler_cls in (DDIMScheduler, DDIMInverseScheduler)
)
# TODO: test whether more inference steps (e.g. n_timestep - 1) lead to better results.
ddim_scheduler.set_timesteps(num_inference_steps=args1.ddpm.valid.params.schedule_kwargs.ddim_S, device=device)
ddim_inverse_scheduler.set_timesteps(num_inference_steps=args1.ddpm.valid.params.schedule_kwargs.ddim_S, device=device)


def diffusers_sample(
        noise_scheduler, 
        model, 
        shape=None, 
        x_t=None, 
        from_t_idx=0,
        to_t_idx=None, 
        cond=None, 
        return_intermediates=False, 
        plot_debug=False, 
        log_every_t=5,
        device="cuda",
        ddpm_indexing=True
    ):
    """Sample from the diffusion model with a diffusers DDIM scheduler.

    Args:
        noise_scheduler: diffusers ``DDIMScheduler`` on which ``set_timesteps`` was already
            called; its (descending) ``timesteps`` define the DDIM schedule.
        model: denoiser called as ``model(x, noise_level_per_sample, c=cond)``; its output is
            passed to ``noise_scheduler.step``.
        shape: shape of the Gaussian noise to start from. Mutually exclusive with ``x_t``.
        x_t: latent to start from (e.g. ``diffusers_inverse_sample(...).prev_sample``).
        from_t_idx, to_t_idx: window ``noise_scheduler.timesteps[from_t_idx:to_t_idx]``
            (negative indices allowed, ``to_t_idx == 0`` means "until the end"). The loop runs
            over this window minus its last timestep.
        cond: conditioning forwarded to ``model`` as ``c``.
        return_intermediates: return the list of intermediates instead of only the last one.
        plot_debug: print the noise level and plot the current estimate at every logged step.
        log_every_t: logging period in loop iterations (the last iteration is always logged).
        device: device of the initial noise and of the noise-level tables.
        ddpm_indexing: if True the model receives ``sqrt(alphas_cumprod[t])`` (training-schedule
            indexing); otherwise ``sqrt(alphas_cumprod[t_prev + 1])`` where ``t_prev`` is the DDIM
            timestep preceding ``t`` (the first timestep is its own predecessor).

    Returns:
        The scheduler output of the last step (read ``.prev_sample``). With
        ``return_intermediates=True``, a list whose first element is the starting tensor and
        whose other elements are scheduler outputs. If the window yields no step, the starting
        tensor itself is returned.
    """
    assert (shape is not None) != (x_t is not None), "Either shape or x_t must be provided, but not both." 

    # Initial noise. Drawing `shape` directly yields the same values as the former
    # `randn(shape[:1] + shape[2:]).unsqueeze(1)` when shape[1] == 1 (same number of elements,
    # drawn in the same order) but no longer silently drops channels when shape[1] > 1.
    samples = th.randn(shape, device=device) if shape is not None else x_t
    shape = samples.shape
    if to_t_idx == 0:
        to_t_idx = None  # Until last value
    all_timesteps = noise_scheduler.timesteps
    total_steps = len(all_timesteps)

    # prev_timesteps[k] is the DDIM timestep preceding the k-th timestep of the window
    # all_timesteps[from_t_idx:to_t_idx]; the very first timestep is its own predecessor.
    prev_timesteps = th.cat((all_timesteps[[0]], all_timesteps[:-1]))[from_t_idx:to_t_idx].to(device)

    alphas_cumprod = noise_scheduler.alphas_cumprod
    # DDPM indexing: training schedule with a leading 1 (as the model was trained), looked up at
    # t + 1, i.e. sqrt(alphas_cumprod[t]).
    sqrt_alphas_cumprod_prev_ddpm = th.sqrt(th.cat((alphas_cumprod.new_ones(1), alphas_cumprod[:-1]))).to(device)
    # DDIM indexing: one entry per loop iteration, sqrt(alphas_cumprod[prev_timestep + 1])
    # (+1 shift as in the LDM implementation).
    sqrt_alphas_cumprod_prev = th.sqrt(alphas_cumprod.to(device)[prev_timesteps + 1])

    # Position of the window's first timestep in the DDIM schedule (used for titles only).
    if from_t_idx is not None:
        idx_shift = from_t_idx if from_t_idx >= 0 else (total_steps + from_t_idx)
    else:
        idx_shift = 0

    # Timesteps actually stepped: the window without its last timestep.
    step_timesteps = prev_timesteps[1:]

    def _plot_step(step_output, t, index, noise_level):
        # Debug output: print the model's noise level and plot the current estimate.
        print(f'noise_level: {noise_level[0].item() if noise_level.numel() > 1 else noise_level.item() }')
        plot_sdfs(step_output.prev_sample, title=f"DeNoising - t={t}/{noise_scheduler.config.num_train_timesteps-1} (DDIM: t={index}/{noise_scheduler.num_inference_steps-1})")

    intermediates = [samples]
    with th.no_grad():
        for i, t in enumerate(tqdm(step_timesteps)):
            index = total_steps - i - 1 - idx_shift  # DDIM step number of t (display only)
            if ddpm_indexing:
                noise_level = sqrt_alphas_cumprod_prev_ddpm[t+1]
            else:
                # Indexed by loop position: the former `index`-based lookup ran past the end of
                # the table whenever `to_t_idx` shortened the window.
                noise_level = sqrt_alphas_cumprod_prev[[i]]

            pred = model(
                samples, noise_level * th.ones(shape[0], device=device), c=cond
            )
            step_output = noise_scheduler.step(pred, t, samples)
            is_last = i == len(step_timesteps) - 1
            if is_last or i % log_every_t == 0:
                # The final step is always kept; other logged steps only on request.
                if is_last or return_intermediates:
                    intermediates.append(step_output)
                if plot_debug:
                    _plot_step(step_output, t, index, noise_level)
            samples = step_output.prev_sample
    if return_intermediates:
        return intermediates
    return intermediates[-1]

def diffusers_inverse_sample(
        inverse_noise_scheduler, 
        model, 
        x_t,
        from_t_idx=0,
        to_t_idx=None,
        cond=None, 
        return_intermediates=False, 
        log_every_t=5, 
        plot_debug=False,
        device="cuda", 
        ddpm_indexing=True
    ):
    """Invert a sample to noise with the diffusion model (DDIM inversion via diffusers).

    Args:
        inverse_noise_scheduler: diffusers ``DDIMInverseScheduler`` on which ``set_timesteps``
            was already called; the loop visits ``timesteps[from_t_idx:to_t_idx]``.
        model: denoiser called as ``model(x, noise_level_per_sample, c=cond)``; its output is
            passed to ``inverse_noise_scheduler.step``.
        x_t: sample to invert.
        from_t_idx, to_t_idx: window of the inverse schedule to run (negative indices allowed,
            ``to_t_idx == 0`` means "until the end").
        cond: conditioning forwarded to ``model`` as ``c``.
        return_intermediates: return the list of intermediates instead of only the last one.
        log_every_t: logging period in loop iterations (the last iteration is always logged).
        plot_debug: print the noise level and plot the current estimate at every logged step.
        device: device of the noise-level tables.
        ddpm_indexing: if True the model receives ``sqrt(alphas_cumprod[min(t + 2, N - 1)])``
            (training-schedule indexing); otherwise ``sqrt(alphas_cumprod[t_next + 1])`` where
            ``t_next`` is the inverse-schedule timestep following ``t`` (the last timestep is
            its own successor).

    Returns:
        The scheduler output of the last step (read ``.prev_sample``). With
        ``return_intermediates=True``, a list whose first element is ``x_t`` and whose other
        elements are scheduler outputs. If the window is empty, ``x_t`` itself is returned.
    """
    samples = x_t
    shape = samples.shape

    intermediates = [samples]
    if to_t_idx == 0:
        to_t_idx = None  # Until last value
    all_timesteps = inverse_noise_scheduler.timesteps
    total_steps = len(all_timesteps)

    # Timesteps visited by the loop (ascending for the inverse scheduler).
    timesteps = all_timesteps[from_t_idx:to_t_idx]
    # next_timesteps[k] is the timestep following timesteps[k] (the last one is repeated).
    # Shifting the full schedule and then applying the SAME window keeps it aligned with
    # `timesteps` for any from_t_idx / to_t_idx (the old construction gave an empty table
    # for to_t_idx=-1).
    next_timesteps = th.cat((all_timesteps[1:], all_timesteps[[-1]]))[from_t_idx:to_t_idx].to(device)

    alphas_cumprod = inverse_noise_scheduler.alphas_cumprod
    # DDIM indexing (as in ddpm_sampler.invert_ddim): one entry per loop iteration,
    # +1 shift as in the LDM implementation.
    sqrt_alphas_cumprod_next = th.sqrt(alphas_cumprod.to(device)[next_timesteps + 1])
    # DDPM indexing: training schedule shifted left by one (last value repeated), looked up at t + 1.
    sqrt_alphas_cumprod_next_ddpm = th.sqrt(th.cat((alphas_cumprod[1:], alphas_cumprod[[-1]]))).to(device)

    # Position of the window's first timestep in the DDIM schedule (used for titles only).
    if from_t_idx is not None:
        idx_shift = from_t_idx if from_t_idx >= 0 else (total_steps + from_t_idx)
    else:
        idx_shift = 0

    def _plot_step(step_output, t, index, noise_level):
        # Debug output: print the model's noise level and plot the current estimate.
        print(f'noise_level: {noise_level[0].item() if noise_level.numel() > 1 else noise_level.item() }')
        plot_sdfs(step_output.prev_sample, title=f"Noising - t={t}/{inverse_noise_scheduler.config.num_train_timesteps-1} (DDIM: t={index}/{inverse_noise_scheduler.num_inference_steps-1})")

    with th.no_grad():
        # TODO: consider skipping the last step; .step() currently also runs at the final timestep.
        for i, t in enumerate(tqdm(timesteps)):
            index = i + idx_shift  # DDIM step number of t (display only)
            if ddpm_indexing:
                noise_level = sqrt_alphas_cumprod_next_ddpm[t+1]
            else:
                # Indexed by loop position: the table already starts at `from_t_idx`, so the
                # former `i + idx_shift` lookup double-counted the offset (IndexError for
                # from_t_idx != 0).
                noise_level = sqrt_alphas_cumprod_next[[i]]

            pred = model(
                samples, noise_level * th.ones(shape[0], device=device), c=cond
            )
            step_output = inverse_noise_scheduler.step(pred, t, samples)
            is_last = i == len(timesteps) - 1
            if is_last or i % log_every_t == 0:
                # The final step is always kept; other logged steps only on request.
                if is_last or return_intermediates:
                    intermediates.append(step_output)
                if plot_debug:
                    _plot_step(step_output, t, index, noise_level)
            samples = step_output.prev_sample
    if return_intermediates:
        return intermediates
    return intermediates[-1]


In [None]:
from src.utils.utils import seed_everything
seed_everything(40)
# Sample two 32^3 shapes with the diffusers DDIM scheduler using DDIM-indexed noise levels
# (ddpm_indexing=False); they are kept in `out1_diffusers` for the round trips below.
out1 = diffusers_sample(ddim_scheduler, model1, shape=(2, 1, 32, 32, 32), device=device, ddpm_indexing=False).prev_sample
out1_diffusers = out1
plot_sdfs(list(out1), title="Diffusers sampling -- using ddpm1_sampler betas") 
out1.shape

In [None]:
# Side experiment for diffusers issue #10695 (linked in the next cell): list, for a fresh
# 10-step DDIMInverseScheduler, the neighbouring timestep computed as
# `timestep - num_train_timesteps // num_inference_steps` (same expression as in the
# commented-out code of `diffusers_inverse_sample`) and the proposed `+` variant.
# NOTE(review): torch/numpy are re-imported under their full names (`th`/`np` already exist)
# and `np` is unused in this cell.
import torch
import numpy as np
from diffusers import DDIMInverseScheduler

num_train_timesteps=1000
scheduler = DDIMInverseScheduler(num_train_timesteps=num_train_timesteps, timestep_spacing='leading', prediction_type="sample")

inference_step = 10
scheduler.set_timesteps(inference_step)

# Before fix: The previous timestep can become negative.
previous_timesteps = torch.tensor([min(timestep - num_train_timesteps // inference_step, num_train_timesteps - 1) for timestep in scheduler.timesteps])
print('Previous Timestep\tCurrent timesteps')
for prev, cur in zip(previous_timesteps, scheduler.timesteps):
    print(f'{prev}\t\t\t{cur}')

# After fix
# NOTE(review): this column holds timestep + step (the larger neighbouring timestep), although
# the printed header still reads "Previous Timestep".
previous_timesteps = torch.tensor([min(timestep + num_train_timesteps // inference_step, num_train_timesteps - 1) for timestep in scheduler.timesteps])
print('Previous Timestep\tCurrent timesteps')
for prev, cur in zip(previous_timesteps, scheduler.timesteps):
    print(f'{prev}\t\t\t{cur}')

In [None]:
## TEST WITH https://github.com/huggingface/diffusers/issues/10695 modification: seems wrong (both with DDPM and DDIM indexing)

# plot_sdfs(diffusers_sample(
#     ddim_scheduler,
#     model1, 
#     x_t = diffusers_inverse_sample(
#         ddim_inverse_scheduler,
#         model1,
#         x_t=out1_diffusers,
#         device=device,
#         plot_debug=True,
#         log_every_t=20,
#         ddpm_indexing=True,
#     ).prev_sample,
#     device=device,
#     plot_debug=True,
#     log_every_t=20,
#     ddpm_indexing=True
# ).prev_sample)

In [None]:
# Full round trip with DDIM-indexed noise levels: invert `out1_diffusers` to noise with the
# inverse scheduler, then denoise the resulting latent back with the forward scheduler.
plot_sdfs(diffusers_sample(
    ddim_scheduler,
    model1, 
    x_t = diffusers_inverse_sample(
        ddim_inverse_scheduler,
        model1,
        x_t=out1_diffusers.to(device),
        device=device,
        plot_debug=True,
        log_every_t=20,
        ddpm_indexing=False,
    ).prev_sample,
    device=device,
    plot_debug=True,
    log_every_t=20,
    ddpm_indexing=False,
).prev_sample)

In [None]:
# Partial round trip (DDIM indexing): invert only the first 10 inverse-scheduler steps
# (to_t_idx=10), then denoise starting from the 10th-to-last timestep (from_t_idx=-10).
plot_sdfs(diffusers_sample(
    ddim_scheduler,
    model1, 
    x_t = diffusers_inverse_sample(
        ddim_inverse_scheduler,
        model1,
        x_t=out1_diffusers.to(device),
        device=device,
        plot_debug=True,
        log_every_t=5,
        ddpm_indexing=False,
        to_t_idx=10,
    ).prev_sample,
    device=device,
    plot_debug=True,
    log_every_t=5,
    ddpm_indexing=False,
    from_t_idx=-10,
).prev_sample)

In [None]:
# Re-sample with DDPM-indexed noise levels (same seed); this overwrites `out1` and `out1_diffusers`.
seed_everything(40)
out1 = diffusers_sample(ddim_scheduler, model1, shape=(2, 1, 32, 32, 32), device=device, ddpm_indexing=True).prev_sample
out1_diffusers = out1
plot_sdfs(list(out1), title="Diffusers sampling -- using ddpm1_sampler betas") 
out1.shape

In [None]:
# Sanity check: forward and inverse schedulers share the same alphas_cumprod.
th.equal(ddim_scheduler.alphas_cumprod, ddim_inverse_scheduler.alphas_cumprod)

In [None]:
# Sanity check: the diffusers betas equal the project sampler's betas.
th.equal(ddim_scheduler.betas.cpu(), ddpm_sampler1.betas.cpu())

In [None]:
# Do the cumulative products match as well? (plus the lengths of both tables)
th.equal(ddim_scheduler.alphas_cumprod.cpu(), ddpm_sampler1.alphas_cumprod.cpu()), len(ddim_scheduler.alphas_cumprod.cpu()), len(ddpm_sampler1.alphas_cumprod.cpu())

In [None]:
# Size of the alphas_cumprod mismatch: L2 norm and largest signed difference.
# NOTE(review): builtin `max` iterates the tensor element by element; `.max()` is the tensor-native form.
th.norm(ddim_scheduler.alphas_cumprod.cpu() - ddpm_sampler1.alphas_cumprod.cpu()), max(ddim_scheduler.alphas_cumprod.cpu() - ddpm_sampler1.alphas_cumprod.cpu())

In [None]:
# dtypes of both alphas_cumprod tables (checked as a possible source of the mismatch above).
ddim_scheduler.alphas_cumprod.cpu().dtype, ddpm_sampler1.alphas_cumprod.cpu().dtype

In [None]:
# Full round trip as above, now with DDPM-indexed noise levels.
plot_sdfs(diffusers_sample(
    ddim_scheduler,
    model1, 
    x_t = diffusers_inverse_sample(
        ddim_inverse_scheduler,
        model1,
        x_t=out1_diffusers.to(device),
        device=device,
        plot_debug=True,
        log_every_t=20,
        ddpm_indexing=True,
    ).prev_sample,
    device=device,
    plot_debug=True,
    log_every_t=20,
    ddpm_indexing=True,
).prev_sample)

In [None]:
# Partial round trip (first 10 inversion steps / last 10 denoising timesteps) with DDPM indexing.
plot_sdfs(diffusers_sample(
    ddim_scheduler,
    model1, 
    x_t = diffusers_inverse_sample(
        ddim_inverse_scheduler,
        model1,
        x_t=out1_diffusers.to(device),
        device=device,
        plot_debug=True,
        log_every_t=5,
        ddpm_indexing=True,
        to_t_idx=10,
    ).prev_sample,
    device=device,
    plot_debug=True,
    log_every_t=5,
    ddpm_indexing=True,
    from_t_idx=-10,
).prev_sample)

In [None]:
# # Training mode doesn't affect the inversion -- as expected
# model1.train()
# print(model1.training)
# seed_everything(40)
# out1 = ddpm_sampler1.sample_ddim(model1, shape=(2, 1, 32, 32, 32), show_pbar=True)
# plot_sdfs(list(out1))
# plot_sdfs(ddpm_sampler1.sample_ddim(
#     model1, 
#     x_t = ddpm_sampler1.invert_ddim(
#         model1,
#         x_t=out1,
#         show_pbar=True
#     ),
#     show_pbar=True
# ))


# model1.eval()
# print(model1.training)

In [None]:
# Reference: the project's own DDIM sampler, DDIM vs DDPM noise-level indexing, same seed.
from src.utils.utils import seed_everything
seed_everything(40)

In [None]:
# DDIM-indexed noise levels.
out1 = ddpm_sampler1.sample_ddim(model1, shape=(2, 1, 32, 32, 32), show_pbar=True, debug_plot=True, log_every_t=20, ddpm_indexing=False)
out1.shape

In [None]:
# DDPM-indexed noise levels (re-seeded so both runs start from the same noise).
seed_everything(40)
out1 = ddpm_sampler1.sample_ddim(model1, shape=(2, 1, 32, 32, 32), show_pbar=True, debug_plot=True, log_every_t=20, ddpm_indexing=True)
out1.shape

It seems that `ddpm_indexing=True` yields better samples, but let's investigate further.

In [None]:
# Round trip with the project's sampler (DDPM indexing): sample 5 shapes, invert them to noise
# and regenerate them from the inverted latents.
from src.utils.utils import seed_everything
seed_everything(40)
five_samples_ddpm_indexing = ddpm_sampler1.sample_ddim(model1, shape=(5, 1, 32, 32, 32), show_pbar=True, debug_plot=True, log_every_t=20, ddpm_indexing=True)
plot_sdfs(list(five_samples_ddpm_indexing), title="Samples generation (with DDPM indexing)")
plot_sdfs(list(
    ddpm_sampler1.sample_ddim(
        model1,
        x_t=ddpm_sampler1.invert_ddim(model1, x_t=five_samples_ddpm_indexing, show_pbar=True, debug_plot=True, log_every_t=20, ddpm_indexing=True),
        show_pbar=True,
        debug_plot=True, log_every_t=20,
        ddpm_indexing=True,
    )),
    title="Predicted samples from inversion (with DDPM indexing)")

In [None]:
# Same round trip with DDIM indexing.
# NOTE(review): the variable is still named `five_samples_ddpm_indexing` although these samples
# use ddpm_indexing=False.
from src.utils.utils import seed_everything
seed_everything(40)
five_samples_ddpm_indexing = ddpm_sampler1.sample_ddim(model1, shape=(5, 1, 32, 32, 32), show_pbar=True, debug_plot=True, log_every_t=20, ddpm_indexing=False)
plot_sdfs(list(five_samples_ddpm_indexing), title="Samples generation (with DDIM indexing)")
plot_sdfs(list(
    ddpm_sampler1.sample_ddim(
        model1,
        x_t=ddpm_sampler1.invert_ddim(model1, x_t=five_samples_ddpm_indexing, show_pbar=True, debug_plot=True, log_every_t=20, ddpm_indexing=False),
        show_pbar=True,
        debug_plot=True, log_every_t=20,
        ddpm_indexing=False,
    )),
    title="Predicted samples from inversion (with DDIM indexing)")

In [None]:
# Diffusers vs project sampler: `out1_diffusers` and `out1` were both drawn with seed 40 and
# ddpm_indexing=True in earlier cells.
th.norm(out1_diffusers - out1)

In [None]:
# Undo the preprocessor's standardization.
# NOTE(review): reassigns `out1` in place — re-running this cell destandardizes twice.
out1 = preprocessor1.destandardize(out1)
print(out1.mean(), out1.min(), out1.max(), out1.var())
out1.shape

In [None]:
# Standardize again (should recover the sampler output); `out1_std` feeds the inversion cells below.
out1_std = preprocessor1.standardize(out1)
print(out1_std.mean(), out1_std.min(), out1_std.max(), out1_std.var())
out1_std.shape

In [None]:
# Plot the destandardized samples from a fixed viewpoint (plot_sdfs is already imported at the top).
from src.utils.vis import plot_sdfs
view_kwargs = {"azim": 30, "elev": 30, "roll": 0, "vertical_axis": "y"}
plot_sdfs(list(out1), view_kwargs=view_kwargs)

In [None]:
# save as an obj file
# for i, out in enumerate(out1):
#     save_sdf_as_mesh(f"gen32_{i}.obj", out, safe=True)

In [None]:
# Super-resolve the 32^3 samples to 64^3. model2 is conditioned on the nearest-neighbour
# upsampled low-resolution grid (concatenated before the noisy input along dim 1), standardized
# with preprocessor2. Assumes `out1` holds the destandardized 32^3 samples from the cells above.
lr_cond = F.interpolate(out1, (64, 64, 64), mode="nearest")
lr_cond = preprocessor2.standardize(lr_cond, 0)
# Sample with the super-resolution model's own diffusion schedule (ddpm_sampler2, built from
# args2.ddpm.valid); this previously used ddpm_sampler1, i.e. the 32^3 generator's schedule.
out2 = ddpm_sampler2.sample_ddim(lambda x, t: model2(th.cat([lr_cond, x], 1), t), shape=(out1.shape[0], 1, 64, 64, 64), show_pbar=True)

out2 = preprocessor2.destandardize(out2, 1)

#for i, out in enumerate(out2):
#    save_sdf_as_mesh(f"sr64_{i}.obj", out, safe=True)

plot_sdfs(list(out2), title="Super-resolution original samples")

In [None]:
# Test inversion
# Invert the standardized samples `out1_std` with the project's DDIM inversion, then
# regenerate them from the resulting latent.
out1_inv = ddpm_sampler1.sample_ddim(
    model1, 
    x_t = ddpm_sampler1.invert_ddim(
        #ddim_inverse_scheduler,
        model1, 
        out1_std, 
        debug_plot=True, 
        log_every_t=10, 
        show_pbar=True,
        #device=device
    ), 
    return_intermediates=False, 
    debug_plot=True, 
    log_every_t=10, 
    show_pbar=True
)
#out1_inv = out1_invs[-1]
plot_sdfs(out1_inv)

In [None]:
# Undo the standardization of the regenerated samples.
# NOTE(review): reassigns `out1_inv` in place — re-running this cell destandardizes twice.
out1_inv = preprocessor1.destandardize(out1_inv)
out1_inv.shape

In [None]:
# save as an obj file
# for i, out in enumerate(out1_inv):
#     save_sdf_as_mesh(f"inv_gen32_{i}.obj", out, safe=True)

In [None]:
# compute norm difference
# between the original and the regenerated samples (both destandardized).
th.norm(out1 - out1_inv)

In [None]:
# Visual comparison of the originals and their reconstructions.
plot_sdfs([out1, out1_inv], titles=["Original", "Predicted from inversion"])

After the `model.eval()` update, the inversion always leads to the same (basic) shape. This might be related to https://github.com/CompVis/latent-diffusion/issues/136

In [None]:
from diffusers import DDIMScheduler, DDIMInverseScheduler

# Project `model_mean_type` -> diffusers `prediction_type`.
prediction_type_map = {"x_0": "sample", "eps": "epsilon"}

# Second variant (rebinds ddim_scheduler / ddim_inverse_scheduler): the betas are rebuilt by
# diffusers from linear_start / linear_end and the schedule name instead of being copied from
# the project's sampler.
ddim_scheduler, ddim_inverse_scheduler = (
    scheduler_cls(
        num_train_timesteps=args1.ddpm.valid.params.schedule_kwargs.n_timestep,
        beta_start=args1.ddpm.valid.params.schedule_kwargs.linear_start,
        beta_end=args1.ddpm.valid.params.schedule_kwargs.linear_end,
        beta_schedule=args1.ddpm.valid.params.schedule_kwargs.schedule,
        prediction_type=prediction_type_map[args1.ddpm.valid.params.model_mean_type],
        set_alpha_to_one=False,
    )
    for scheduler_cls in (DDIMScheduler, DDIMInverseScheduler)
)
ddim_scheduler.set_timesteps(num_inference_steps=args1.ddpm.valid.params.schedule_kwargs.ddim_S, device=device)
ddim_inverse_scheduler.set_timesteps(num_inference_steps=args1.ddpm.valid.params.schedule_kwargs.ddim_S, device=device)


def diffusers_sample(
        noise_scheduler, 
        model, 
        shape=None, 
        x_t=None, 
        from_t_idx=0,
        to_t_idx=None, 
        cond=None, 
        return_intermediates=False, 
        plot_debug=False, 
        log_every_t=5,
        device="cuda",
    ):
    """Sample from the diffusion model with a diffusers DDIM scheduler (simplified variant).

    Unlike the earlier definition, every timestep of the window is stepped and the model's
    noise level is always ``sqrt(alphas_cumprod[t])``.

    Args:
        noise_scheduler: diffusers ``DDIMScheduler`` on which ``set_timesteps`` was already called.
        model: denoiser called as ``model(x, noise_level_per_sample, c=cond)``.
        shape: shape of the Gaussian noise to start from. Mutually exclusive with ``x_t``.
        x_t: latent to start from.
        from_t_idx, to_t_idx: window ``noise_scheduler.timesteps[from_t_idx:to_t_idx]`` to run.
        cond: conditioning forwarded to ``model`` as ``c``.
        return_intermediates: return the list of intermediates instead of only the last one.
        plot_debug: plot the current estimate at every logged step.
        log_every_t: logging period in loop iterations (the last iteration is always logged).
        device: device of the initial noise, the timesteps and the noise-level table.

    Returns:
        The scheduler output of the last step (read ``.prev_sample``), or with
        ``return_intermediates=True`` a list starting with the initial tensor followed by
        scheduler outputs.
    """
    assert (shape is not None) != (x_t is not None), "Either shape or x_t must be provided, but not both." 

    # Initial noise: drawing `shape` directly gives the same values as the former
    # `randn(shape[:1] + shape[2:]).unsqueeze(1)` when shape[1] == 1, and keeps every channel
    # when shape[1] > 1.
    samples = th.randn(shape, device=device) if shape is not None else x_t
    shape = samples.shape

    timesteps = noise_scheduler.timesteps.to(device)[from_t_idx:to_t_idx]

    intermediates = [samples]
    # as done in original SDF code: prepend 1 (the model was trained with 1 at the first
    # position) and look up at t + 1, i.e. sqrt(alphas_cumprod[t]). The leading one is created
    # with alphas_cumprod's own dtype/device instead of an int64 CPU tensor.
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod_prev = th.sqrt(th.cat((alphas_cumprod.new_ones(1), alphas_cumprod))).to(device)
    with th.no_grad():
        for i, t in enumerate(tqdm(timesteps)):
            noise_level = sqrt_alphas_cumprod_prev[t+1] # t+1 as in ldm code 

            pred = model(
                samples, noise_level * th.ones(shape[0], device=device), c=cond
            )
            samples = noise_scheduler.step(pred, t, samples)
            if i == len(timesteps)-1:
                intermediates.append(samples)
                if plot_debug: 
                    plot_sdfs(samples.prev_sample, title=f"Denoising - timestep {t} / {timesteps[0]} (DDIM timesteps: {len(timesteps)} out of {noise_scheduler.config.num_train_timesteps})")
            elif i % log_every_t == 0:
                if return_intermediates:
                    intermediates.append(samples)
                if plot_debug: 
                    plot_sdfs(samples.prev_sample, title=f"Denoising - timestep {t} / {timesteps[0]} (DDIM timesteps: {len(timesteps)} out of {noise_scheduler.config.num_train_timesteps})")
            samples = samples.prev_sample
    if return_intermediates:
        return intermediates
    return intermediates[-1]

def diffusers_inverse_sample(
        inverse_noise_scheduler, 
        model, 
        x_t,
        from_t_idx=0,
        to_t_idx=None,
        cond=None, 
        return_intermediates=False, 
        log_every_t=5, 
        plot_debug=False,
        device="cuda", 
    ):
    """Invert a sample to noise with a diffusers ``DDIMInverseScheduler`` (simplified variant).

    Each iteration feeds the model the noise level ``sqrt(alphas_cumprod[t + 1])`` and lets the
    inverse scheduler take one step towards higher noise.

    Args:
        inverse_noise_scheduler: ``DDIMInverseScheduler`` with ``set_timesteps`` already called.
        model: denoiser called as ``model(x, noise_level_per_sample, c=cond)``.
        x_t: sample to invert.
        from_t_idx, to_t_idx: window of the inverse schedule to run.
        cond: conditioning forwarded to ``model`` as ``c``.
        return_intermediates: return the list of intermediates instead of only the last one.
        log_every_t: logging period in loop iterations (the last iteration is always logged).
        plot_debug: plot the current estimate at every logged step.
        device: device of the timesteps and of the per-sample noise-level vector.

    Returns:
        The scheduler output of the last step (read ``.prev_sample``), or with
        ``return_intermediates=True`` a list starting with ``x_t`` followed by scheduler outputs.
    """
    current = x_t
    batch_size = current.shape[0]
    intermediates = [current]

    window = inverse_noise_scheduler.timesteps.to(device)[from_t_idx:to_t_idx]
    last_step = len(window) - 1

    with th.no_grad():
        for step_idx, t in enumerate(tqdm(window)):
            # Noise level as in the LDM code (t + 1), read from the training schedule.
            noise_level = th.sqrt(inverse_noise_scheduler.alphas_cumprod[t+1])
            pred = model(current, noise_level * th.ones(batch_size, device=device), c=cond)
            step_output = inverse_noise_scheduler.step(pred, t, current)

            is_last = step_idx == last_step
            if is_last or step_idx % log_every_t == 0:
                # Always keep the final output; other logged steps only on request.
                if is_last or return_intermediates:
                    intermediates.append(step_output)
                if plot_debug:
                    plot_sdfs(step_output.prev_sample, title=f"Noising - timestep {t} / {window[-1]} (DDIM timesteps: {len(window)} out of {inverse_noise_scheduler.config.num_train_timesteps})")
            current = step_output.prev_sample

    return intermediates if return_intermediates else intermediates[-1]

# Sample two shapes with the simplified `diffusers_sample`, then run an inversion round trip.
# NOTE(review): the round trip inverts `out1_std` (project-sampler output from earlier cells),
# not the `out1_t` generated just above — confirm which one is intended.
seed_everything(40)
out1_t = diffusers_sample(ddim_scheduler, model1, shape=(2, 1, 32, 32, 32), device=device, plot_debug=True, log_every_t=10).prev_sample
plot_sdfs(list(out1_t))
plot_sdfs(list(diffusers_sample(
    ddim_scheduler, 
    model1,
    x_t=diffusers_inverse_sample(
        ddim_inverse_scheduler, 
        model1, 
        x_t=out1_std, 
        return_intermediates=False, 
        plot_debug=True, 
        log_every_t=10, 
        device=device).prev_sample,
    plot_debug=True,
    log_every_t=10,
    device=device
).prev_sample))


In [None]:
# Free the project-sampler inversion result.
del out1_inv

In [None]:
from src.models.diffusion import identity
from tqdm import tqdm
from src.utils.vis import plot_sdfs
    
def ddim_sample_noise_guidance(
    noise_scheduler,
    inverse_noise_scheduler,
    denoise_fn,
    x_0,
    from_t_optim_idx,
    obj_fn,
    obj_fn_args=None,
    tgt_noise_level = "t_optim",
    clip_denoised=True,
    denoise_kwargs=None,
    post_fn=identity,
    #return_intermediates=False,
    log_every_t=5,
    show_pbar=False,
    pbar_kwargs=None,
    opt_kwargs=None, 
    grad_clip_value=None,
    plot_debug=False,
):
    """Invert ``x_0`` part-way with DDIM, then denoise it back while nudging the latent with the
    gradient of an objective (gradient-guided DDIM).

    Procedure:
      1. invert ``x_0`` with ``invert_ddim(..., to_t_idx=from_t_optim_idx)``;
      2. for ``t_idx = from_t_optim_idx, ..., 1`` (step counter ``i``):
         a. build a target from the current latent ``x_t`` (see ``tgt_noise_level``) and
            evaluate ``obj_fn(target, **obj_fn_args)``;
         b. differentiate that loss w.r.t. ``x_t``;
         c. run one DDIM denoising step (``from_t_idx=-t_idx`` -> ``to_t_idx=-(t_idx - 1)``) and
            subtract ``decay_fn(i) * lr * grad`` from its result.

    NOTE(review): this function was ported from a sampler method and still referenced ``self``
    (and ``x_t = .invert_ddim(...)`` was a syntax error). The calls are routed to the scheduler
    arguments: ``inverse_noise_scheduler.invert_ddim`` for step 1 and
    ``noise_scheduler.sample_ddim`` / ``noise_scheduler.ddim_timesteps`` for the rest. Both
    objects must therefore expose the project's sampler API (e.g. pass ``ddpm_sampler1`` for
    both); plain diffusers schedulers do not have these methods — confirm the intended wiring.

    Args:
        noise_scheduler: object providing ``sample_ddim`` and ``ddim_timesteps`` (see note).
        inverse_noise_scheduler: object providing ``invert_ddim`` (see note).
        denoise_fn: denoising model forwarded to the sampler calls.
        x_0: samples to start from.
        from_t_optim_idx: number of DDIM steps to invert and then optimise over; the loop is
            ``range(from_t_optim_idx, 0, -1)``, so a positive int is expected.
        obj_fn: callable ``obj_fn(target, **obj_fn_args)`` returning the loss to minimise.
            Non-scalar losses are summed before differentiation.
        obj_fn_args: extra keyword arguments for ``obj_fn`` (default: none).
        tgt_noise_level: what the objective is evaluated on: ``"t_optim"`` (the latent itself),
            ``"zero"`` (``sample_ddim`` from the current step to the end; slow and
            memory-hungry) or ``"zero_pred"`` (last element returned by a single-step
            ``sample_ddim(..., return_intermediates=True)``).
        clip_denoised, denoise_kwargs, post_fn, log_every_t, show_pbar, pbar_kwargs: forwarded
            to every ``sample_ddim`` / ``invert_ddim`` call.
        opt_kwargs: ``{"lr": float, "decay_fn": callable(step) -> float}``; defaults to
            ``{"lr": 1e-2, "decay_fn": identity}``. ``"decay_fn"`` is optional.
        grad_clip_value: if not None, clamp the gradient element-wise to ``[-v, v]``.
        plot_debug: plot the inverted latent and, at each step, the target and the updated latent.

    Returns:
        The optimised latent after the last step, detached from any autograd graph.

    Raises:
        ValueError: if ``tgt_noise_level`` is not one of the supported values.
    """
    valid_noise_levels = ("t_optim", "zero", "zero_pred")
    if tgt_noise_level not in valid_noise_levels:
        # Validated up front (before the expensive inversion); the old message concatenated a
        # str with a list and raised TypeError instead.
        raise ValueError("Invalid noise level: available levels are " + ", ".join(valid_noise_levels) + ".")

    # None defaults avoid shared mutable default arguments.
    obj_fn_args = {} if obj_fn_args is None else obj_fn_args
    opt_kwargs = {"lr": 1e-2, "decay_fn": identity} if opt_kwargs is None else opt_kwargs
    lr = opt_kwargs["lr"]
    decay_fn = opt_kwargs.get("decay_fn", identity)

    ddim_args = {
        "clip_denoised": clip_denoised,
        "denoise_kwargs": {} if denoise_kwargs is None else denoise_kwargs,
        "post_fn": post_fn,
        #"return_intermediates": return_intermediates,
        "log_every_t": log_every_t,
        "show_pbar": show_pbar,
        "pbar_kwargs": {} if pbar_kwargs is None else pbar_kwargs,
    }

    # Get the latent at the defined noise level
    x_t = inverse_noise_scheduler.invert_ddim(denoise_fn, x_0, to_t_idx=from_t_optim_idx, requires_grad=False, **ddim_args)
    if plot_debug:
        plot_sdfs(x_t, title=f"x_t inverted to timestep {noise_scheduler.ddim_timesteps[from_t_optim_idx]}")

    # Denoise it with guidance
    idxs = range(from_t_optim_idx, 0, -1)
    with tqdm(idxs, desc="Latent optimization") as pbar_idxs:
        for i, t_idx in enumerate(pbar_idxs):
            # Fresh leaf each step so the gradient is taken w.r.t. the current latent only.
            x_t = x_t.detach().requires_grad_(True)
            # Autograd is disabled globally in this notebook, so both the target and the loss
            # must be built under enable_grad for a graph back to x_t to exist.
            with th.enable_grad():
                if tgt_noise_level == "t_optim":
                    tgt_pred = x_t
                elif tgt_noise_level == "zero":
                    # Too slow and computationally and memory intensive
                    tgt_pred = noise_scheduler.sample_ddim(denoise_fn, x_t=x_t, from_t_idx=-t_idx, requires_grad=True, **ddim_args)
                else:  # "zero_pred"
                    tgt_pred = noise_scheduler.sample_ddim(denoise_fn, x_t=x_t, from_t_idx=-(t_idx), to_t_idx=-(t_idx-1), requires_grad=True, return_intermediates=True, **ddim_args)[-1]
                loss_i = obj_fn(tgt_pred, **obj_fn_args) # tgt_pred is a function of x_t
            # .sum() is a no-op for scalar losses and allows per-sample losses.
            grad_t = th.autograd.grad(loss_i.sum(), x_t)[0] # grad of loss wrt to x_t
            if grad_clip_value is not None:
                grad_t = grad_t.clamp(-grad_clip_value, grad_clip_value)

            # One DDIM denoising step x_t -> x_{t-1}, then the guidance update.
            x_t = noise_scheduler.sample_ddim(denoise_fn, x_t=x_t.detach(), from_t_idx=-(t_idx), to_t_idx=-(t_idx-1), requires_grad=False, **ddim_args) # x_{t-1}
            x_t = x_t - decay_fn(i) * lr * grad_t

            pbar_idxs.set_postfix({"Loss (mean)": th.mean(loss_i).item()})

            if plot_debug:
                plot_sdfs([tgt_pred.detach(), x_t], title=f"Optimization step {i} at timestep {noise_scheduler.ddim_timesteps[t_idx]}", titles=[f"Target shape (target type: \"{tgt_noise_level}\")", "Optimized shape"])

    return x_t.detach()

In [None]:
# Optimization
def volume_estimates(sdfs, dx=1., dy=1., dz=1.):
    """Differentiable per-sample estimate of the volume enclosed by SDF grids.

    A voxel is counted as "inside" when its SDF value is strictly negative. The
    indicator ``-sdfs > 0`` goes through ``StraightThroughEstimator`` so that a
    (clipped) gradient can flow back to ``sdfs``. Counts are summed over every
    dim except the first and scaled by the voxel size ``dx * dy * dz``.

    Args:
        sdfs: batched SDF tensor; dim 0 is kept, all other dims are reduced.
        dx, dy, dz: voxel edge lengths.

    Returns:
        Tensor with one volume per entry along dim 0.
    """
    reduce_dims = list(range(1, sdfs.ndim))
    inside_indicator = StraightThroughEstimator()(-sdfs)
    voxel_counts = th.sum(inside_indicator, dim=reduce_dims)
    return voxel_counts * dx * dy * dz

class STEFunction(th.autograd.Function):
    """Binary step in the forward pass, clipped pass-through gradient in the backward pass.

    forward:  returns 1.0 where the input is strictly positive and 0.0 elsewhere
              (as float32, via ``.float()``).
    backward: returns the incoming gradient clamped to [-1, 1] by ``hardtanh``.
              Note this clamps ``grad_output`` itself. It does not zero the gradient
              where ``|input| > 1`` as the textbook hardtanh-STE would.
    """

    @staticmethod
    def forward(ctx, input):
        is_positive = input > 0
        return is_positive.float()

    @staticmethod
    def backward(ctx, grad_output):
        clipped = F.hardtanh(grad_output)
        return clipped

class StraightThroughEstimator(th.nn.Module):
    """Stateless ``nn.Module`` wrapper applying ``STEFunction`` to its input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return STEFunction.apply(x)
 

def volume_estimates_loss_fn(xs, target_volumes, max_volume=1., grad_var_reg_weight=0, tot_variation_reg_weight=0):
    """Scalar loss pulling the estimated volume of each SDF in ``xs`` towards a target.

    Args:
        xs: batch of SDF grids; per-sample volumes come from ``volume_estimates``
            (which reduces every dim except dim 0).
        target_volumes: per-sample target volumes, same shape as ``volume_estimates(xs)``.
        max_volume: constant both volumes are divided by before the MSE.
        grad_var_reg_weight: weight of a penalty on the per-sample variance of
            ``xs.grad``. Only usable when ``xs.grad`` is populated, i.e. ``xs`` is a
            leaf tensor on which ``backward`` has already been called.
        tot_variation_reg_weight: weight of the ``tot_variation`` smoothness penalty.

    Returns:
        A 0-dim tensor, usable directly with ``torch.autograd.grad`` / ``backward``.

    Raises:
        ValueError: if ``grad_var_reg_weight > 0`` while ``xs.grad`` is None (always the
            case for non-leaf tensors such as a sampler's differentiable prediction).
    """
    input_volumes = volume_estimates(xs) / max_volume
    target_volumes = target_volumes / max_volume
    loss = th.nn.MSELoss()(input_volumes, target_volumes)
    if grad_var_reg_weight > 0:
        if xs.grad is None:
            raise ValueError(
                "grad_var_reg_weight > 0 requires `xs.grad` to be populated, "
                "but it is None (non-leaf tensor or no backward pass yet)."
            )
        # th.var over the non-batch dims gives one value per sample; average it so the
        # loss stays a scalar (a vector loss breaks autograd.grad without grad_outputs).
        grad_var = th.var(xs.grad, dim=list(range(1, xs.ndim)))
        loss = loss + grad_var_reg_weight * grad_var.mean()
    if tot_variation_reg_weight > 0:
        loss = loss + tot_variation_reg_weight * tot_variation(xs)
    return loss

def tot_variation(sdfs, weight=1.):
    """Squared-difference total variation of a batch of grids.

    Sums the squared finite differences between neighbouring cells along every
    spatial dim, i.e. every dim from index 2 on (dims 0 and 1 are treated as batch
    and channel). For 5-D input this is exactly the x/y/z sum over dims 2, 3 and 4.

    Args:
        sdfs: tensor of shape (B, C, *spatial).
        weight: multiplier applied to the total.

    Returns:
        0-dim tensor ``weight * sum_d sum((diff along d)**2)``.
    """
    tv = sdfs.new_zeros(())
    for dim in range(2, sdfs.ndim):
        # th.diff(x, dim=d) == x[..., 1:, ...] - x[..., :-1, ...] along d
        tv = tv + th.diff(sdfs, dim=dim).pow(2).sum()
    return weight * tv


In [None]:
# Sanity check for `volume_estimates`: adding a constant `shift` to every SDF value
# moves the zero level set, so the counted "inside" region (SDF < 0) changes; a
# negative shift grows it. Second-row titles show the per-sample relative change
# (V_shifted - V) / V. Volumes here are measured on the destandardized `out1`,
# whereas the editing cells measure them on `out1_std`.
shift = -0.062
plot_sdfs(
    [out1, out1 + shift], 
    titles=[
        [f"Original \n $V={volume_estimates(out_i).item()}$" for out_i in out1], 
        [f"Shifting the sdfs by ${shift}$ \n $V={volume_estimates(out_i_shift).item()}$ (incr.: ${((volume_estimates(out_i_shift).item() - volume_estimates(out_i).item())/volume_estimates(out_i).item()):.2f}$)" for out_i_shift, out_i in zip(out1+shift, out1)], 
])

In [None]:
# Volume-guided edit of every sample in `out1_std`: target volume = (1 + 0.9) x the
# original. Guidance starts from DDIM index `t_optim_idx` = 15, with the loss taken
# at the same noise level (`tgt_noise_level="t_optim"`, per the figure title).
# Title volumes are computed on the standardized tensors; the plots show the
# destandardized SDFs.
target_volume_increment = 0.9
t_optim_idx = 15
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="t_optim",
    opt_kwargs={"lr":1e-2}, 

)

plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # optional third panel: absolute difference between original and edited
        #th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at the same timestep",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        #[f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Same experiment as the t_optim_idx = 15 cell, but guidance starts later in the
# chain (DDIM index 8). Every sample in `out1_std` gets target volume
# (1 + 0.9) x original, with the loss at the same noise level (`"t_optim"`).
target_volume_increment = 0.9
t_optim_idx = 8
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std, 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="t_optim",
    opt_kwargs={"lr":1e-2}, 

)

plot_sdfs(
    sdfs=[
        out1, 
        preprocessor1.destandardize(x_edited), 
        # optional third panel: absolute difference between original and edited
        #th.abs(out1- preprocessor1.destandardize(x_edited))
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at the same timestep",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
        [f"Edited ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
        #[f"Absolute difference between original and edited" for i in range(out1.shape[0])]
        ]
)

In [None]:
# Volume-guided edit of sample 0 (`out1_std[0:1]`): target volume = (1 + 1.27) x the
# original, guidance from DDIM index 7. With `tgt_noise_level="zero"` the loss is taken
# on the result of running DDIM onward from the current step (the title calls it the
# volume at t=0). `plot_debug=True` plots target and optimized shape at every step.
target_volume_increment = 1.27
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[0:1], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[0:1]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 
    plot_debug=True,
)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[0:1], 
        preprocessor1.destandardize(x_edited[0:1]), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1[0:1].shape[0])], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited)[i].item():.2f}, V_{{target}}= {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1[0:1].shape[0])],
        ]
)

In [None]:
# Short DDIM round trip on sample 0: invert up to DDIM index 5, then run `sample_ddim`
# from that latent with `from_t_idx=-5`. Every step of both passes is debug-plotted.
xxxx = ddpm_sampler1.invert_ddim(
    model1,
    x_t=out1_std[0:1],
    to_t_idx=5,
    debug_plot=True,
    log_every_t=1,
    show_pbar=True,
)
ddpm_sampler1.sample_ddim(
    model1,
    x_t=xxxx,
    from_t_idx=-5,
    debug_plot=True,
    log_every_t=1,
)

In [None]:
# DDIM inversion of sample 0 with `to_t_idx=None` (presumably all the way to noise),
# followed by DDIM sampling from the inverted latent; both passes log every 10 steps.
xx = ddpm_sampler1.invert_ddim(
    model1,
    x_t=out1_std[0:1],
    to_t_idx=None,
    debug_plot=True,
    log_every_t=10,
    show_pbar=True,
)

ddpm_sampler1.sample_ddim(
    model1,
    x_t=xx,
    debug_plot=True,
    log_every_t=10,
)

In [None]:
# Two independent standard-normal latents, each shaped like `out1[0:1]`, stacked
# along the batch dim and sampled with DDIM (debug plots every 10 steps).
xx = th.vstack([th.randn_like(out1[0:1]) for _ in range(2)])

ddpm_sampler1.sample_ddim(model1, x_t=xx, debug_plot=True, log_every_t=10)

In [None]:
from diffusers import DDIMScheduler, DDIMInverseScheduler

# NOTE(review): this repeats the scheduler setup already done in the
# "Generate Low-Resolution" section near the top of the notebook.

# Map the project's `model_mean_type` names onto diffusers' `prediction_type` values.
prediction_type_map = {
    "x_0": "sample", 
    "eps": "epsilon"
} 

# Forward DDIM scheduler built from the gen32 model's validation noise schedule.
# NOTE(review): `clip_sample` is left at the diffusers default (True, range 1.0), which
# clamps the predicted x_0 to [-1, 1] at every step — confirm that is intended for
# standardized SDFs (same applies to the inverse scheduler below).
ddim_scheduler = DDIMScheduler(
    num_train_timesteps=args1.ddpm.valid.params.schedule_kwargs.n_timestep,
    beta_start=args1.ddpm.valid.params.schedule_kwargs.linear_start,
    beta_end=args1.ddpm.valid.params.schedule_kwargs.linear_end,
    beta_schedule=args1.ddpm.valid.params.schedule_kwargs.schedule,
    prediction_type=prediction_type_map[args1.ddpm.valid.params.model_mean_type],
    #timestep_spacing="linspace",
    set_alpha_to_one=False
)
# Use the same number of DDIM steps as the project's sampler config (`ddim_S`).
ddim_scheduler.set_timesteps(num_inference_steps=args1.ddpm.valid.params.schedule_kwargs.ddim_S, device=device)

# Inverse DDIM scheduler with the identical schedule, used to map samples back to noise.
ddim_inverse_scheduler = DDIMInverseScheduler(
    num_train_timesteps=args1.ddpm.valid.params.schedule_kwargs.n_timestep,
    beta_start=args1.ddpm.valid.params.schedule_kwargs.linear_start,
    beta_end=args1.ddpm.valid.params.schedule_kwargs.linear_end,
    beta_schedule=args1.ddpm.valid.params.schedule_kwargs.schedule,
    prediction_type=prediction_type_map[args1.ddpm.valid.params.model_mean_type],
    #timestep_spacing="linspace",
    set_alpha_to_one=False
)
ddim_inverse_scheduler.set_timesteps(num_inference_steps=args1.ddpm.valid.params.schedule_kwargs.ddim_S, device=device)


def diffusers_sample(noise_scheduler, model, shape=None, x_t=None, device="cuda", cond_signal=None, cond=None, return_intermediates=False, log_every_t=5):
    """Run the reverse (denoising) loop of a diffusers scheduler with the project's model.

    Exactly one of ``shape`` or ``x_t`` must be given:
    ``shape`` starts from fresh Gaussian noise of that shape on ``device``, and
    ``x_t`` starts from the given latent.

    At each scheduler timestep ``t`` the model is called as
    ``model(x, level, cond_signal, cond)``. ``level`` is ``sqrt(alphas_cumprod_prev[t])``
    repeated over the batch, and ``alphas_cumprod_prev`` is the scheduler's
    ``alphas_cumprod`` with its first entry prepended.

    Args:
        noise_scheduler: diffusers-style scheduler (``timesteps``, ``alphas_cumprod``, ``step``).
        model: denoiser with the call signature above.
        return_intermediates: if True, return a list with the starting latent, the
            latent after every ``log_every_t``-th step, and the final sample.

    Returns:
        The final sample tensor, or the list of intermediates.
    """
    assert (shape is not None) != (x_t is not None), "Either shape or x_t must be provided, but not both." 
    from tqdm import tqdm

    x = x_t if shape is None else th.randn(shape, device=device)
    batch_size = x.shape[0]
    trajectory = [x]

    alphas_bar = noise_scheduler.alphas_cumprod
    noise_levels = th.sqrt(th.cat((alphas_bar[[0]], alphas_bar))).to(device)
    timesteps = noise_scheduler.timesteps
    last_step = len(timesteps) - 1

    with th.no_grad():
        for step_idx, t in enumerate(tqdm(timesteps)):
            level = noise_levels[t] * th.ones(batch_size, device=device)
            prediction = model(x, level, cond_signal, cond)
            x = noise_scheduler.step(prediction, t, x).prev_sample
            should_log = step_idx % log_every_t == 0 or step_idx == last_step
            if return_intermediates and should_log:
                trajectory.append(x)

    return trajectory if return_intermediates else x

def diffusers_inverse_sample(noise_scheduler, model, samples, device, cond_signal=None, cond=None, return_intermediates=False, log_every_t=5, plot_debug=False):
    """Invert a sample towards noise by stepping a (DDIM) inverse scheduler.

    At each scheduler timestep ``t`` the model is called as
    ``model(x, level, cond_signal, cond)``. ``level`` is ``sqrt(alphas_cumprod_next[t])``
    repeated over the batch, where ``alphas_cumprod_next`` is the scheduler's
    ``alphas_cumprod`` shifted left by one with the last entry repeated.
    NOTE(review): confirm this is the noise level the inverse scheduler's ``step``
    actually evaluates the model at.

    Args:
        noise_scheduler: diffusers-style inverse scheduler (``timesteps``,
            ``alphas_cumprod``, ``step``, ``config.num_train_timesteps``).
        model: denoiser with the call signature above.
        samples: starting (clean) batch.
        device: device the schedule tensors and noise levels are moved to.
        return_intermediates: if True, return a list with the starting batch, the
            batch after every ``log_every_t``-th step, and the final latent.
        plot_debug: if True, plot the current batch at every logged step.

    Returns:
        The inverted latent, or the list of intermediates.
    """
    from tqdm import tqdm

    intermediates = [samples]
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(device)
    timesteps = noise_scheduler.timesteps.to(device)
    sqrt_alphas_cumprod_next = th.sqrt(th.cat((alphas_cumprod[1:], alphas_cumprod[[-1]])))
    last_step = len(timesteps) - 1
    with th.no_grad():
        for i, t in enumerate(tqdm(timesteps)):
            pred = model(
                samples, sqrt_alphas_cumprod_next[t] * th.ones(samples.shape[0], device=device), cond_signal, cond
            )
            samples = noise_scheduler.step(pred, t, samples).prev_sample
            if i % log_every_t == 0 or i == last_step:
                if return_intermediates:
                    intermediates.append(samples)
                if plot_debug:
                    # This loop adds noise, so label the plot as inversion, not denoising.
                    plot_sdfs(samples, title=f"Inversion - timestep {t} / {timesteps[-1]} (DDIM timesteps: {len(timesteps)} out of {noise_scheduler.config.num_train_timesteps})")
    if return_intermediates:
        return intermediates
    return samples


In [None]:
# Draw one (1, 1, 32, 32, 32) sample with the diffusers DDIM scheduler, keeping the
# trajectory (every 3rd step plus the final sample), and plot it.
sample_interm = diffusers_sample(
    ddim_scheduler,
    model1,
    (1, 1, 32, 32, 32),
    device=device,
    return_intermediates=True,
    log_every_t=3,
)
plot_sdfs(sample_interm)

In [None]:
# Invert the final sample of the previous cell with the DDIM inverse scheduler,
# plotting every 3rd step (plot_debug) and keeping the trajectory.
inverse_interm = diffusers_inverse_sample(
    ddim_inverse_scheduler,
    model1,
    sample_interm[-1].to(device),
    device=device,
    return_intermediates=True,
    plot_debug=True,
    log_every_t=3,
)

In [None]:
# Sample again starting from the inverted latent (last entry of `inverse_interm`),
# presumably to check how well inversion + sampling reconstructs the original.
sample_interm_from_inverse = diffusers_sample(ddim_scheduler, model1, x_t=inverse_interm[-1], device=device, return_intermediates=True)
# Plot the reconstruction trajectory just computed (not the earlier `sample_interm`).
plot_sdfs(sample_interm_from_inverse)

In [None]:
# Volume-guided edit of sample 0: target volume = (1 + 1.27) x original, guidance from
# DDIM index 3 (later than the idx-7 variant), loss on the DDIM-completed sample
# (`tgt_noise_level="zero"`), with per-step debug plots.
target_volume_increment = 1.27
t_optim_idx = 3
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[0:1], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[0:1]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 
    plot_debug=True,
)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[0:1], 
        preprocessor1.destandardize(x_edited[0:1]), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1[0:1].shape[0])], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited)[i].item():.2f}, V_{{target}}= {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1[0:1].shape[0])],
        ]
)

In [None]:
# Volume-guided edit of sample 1 (`out1_std[1:2]`): target volume = (1 + 0.71) x
# original, guidance from DDIM index 7, loss on the DDIM-completed sample
# (`tgt_noise_level="zero"`), with per-step debug plots.
target_volume_increment = 0.71
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[1:2],  
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[1:2]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 
    plot_debug=True

)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[1:2], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[1:2]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[1:2], x_edited)],
    ]
)

In [None]:
# Volume-guided edit of sample 1: target volume = (1 + 0.71) x original, guidance from
# DDIM index 3, loss on the DDIM-completed sample (`tgt_noise_level="zero"`), with
# per-step debug plots.
target_volume_increment = 0.71
t_optim_idx = 3
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[1:2],  
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[1:2]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 
    plot_debug=True
)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[1:2], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[1:2]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[1:2], x_edited)],
    ]
)

In [None]:
# Volume-guided edit of sample 0 with a decaying step size. The guidance update is
# x_t - decay_fn(i) * lr * grad, so here step i uses lr * exp(-i) with lr = 1.
# `decay_fn` receives the step index `i`; `th.tensor` turns it into a tensor for `th.exp`.
# Target volume = (1 + 1.27) x original, guidance from DDIM index 7, loss on the
# DDIM-completed sample (`tgt_noise_level="zero"`).
target_volume_increment = 1.27
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[0:1], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[0:1]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1, "decay_fn": lambda x: th.exp(th.tensor(-x))}, 
    plot_debug=False
)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[0:1], 
        preprocessor1.destandardize(x_edited[0:1]), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1[0:1].shape[0])], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited)[i].item():.2f}, V_{{target}}= {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1[0:1].shape[0])],
        ]
)

In [None]:
# Volume-guided edit of sample 1 with a decaying step size (step i uses
# lr * exp(-i), lr = 1). Target volume = (1 + 0.71) x original, guidance from DDIM
# index 7, loss on the DDIM-completed sample (`tgt_noise_level="zero"`).
target_volume_increment = 0.71
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[1:2],  
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[1:2]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1, "decay_fn": lambda x: th.exp(th.tensor(-x))},  
    plot_debug=False,
)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[1:2], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[1:2]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[1:2], x_edited)],
    ]
)

In [None]:
# Volume-guided edit of sample 0 using the cheaper single-step DDIM prediction as the
# loss target (`tgt_noise_level="zero_pred"`; the title calls it the predicted x_0).
# Target volume = (1 + 1.27) x original, guidance from DDIM index 7.
target_volume_increment = 1.27
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[0:1], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[0:1]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero_pred",
    opt_kwargs={"lr":1e-1}, 
)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[0:1], 
        preprocessor1.destandardize(x_edited[0:1]), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume of the PREDICTED $x_0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1[0:1].shape[0])], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited)[i].item():.2f}, V_{{target}}= {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1[0:1].shape[0])],
        ]
)

In [None]:
# Volume-guided edit of sample 1 using the single-step DDIM prediction as the loss
# target (`tgt_noise_level="zero_pred"`, the predicted x_0 per the title).
# Target volume = (1 + 0.71) x original, guidance from DDIM index 7, with debug plots.
target_volume_increment = 0.71
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[1:2],  
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[1:2]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero_pred",
    opt_kwargs={"lr":1e-1}, 
    plot_debug=True
)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[1:2], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume of the PREDICTED $x_0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[1:2]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[1:2], x_edited)],
    ]
)

In [None]:
# Volume-guided edit of sample 0 using the single-step DDIM prediction as the loss
# target (`tgt_noise_level="zero_pred"`), which the other zero_pred cells label as the
# predicted x_0. The title states that; "the volume at $t=0$" belongs to the "zero"
# target level. Target volume = (1 + 1.27) x original, guidance from DDIM index 8.
target_volume_increment = 1.27
t_optim_idx = 8
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[0:1], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[0:1]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero_pred",
    opt_kwargs={"lr":1e-1}, 
    plot_debug=True
)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[0:1], 
        preprocessor1.destandardize(x_edited[0:1]), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume of the PREDICTED $x_0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1[0:1].shape[0])], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited)[i].item():.2f}, V_{{target}}= {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1[0:1].shape[0])],
        ]
)

In [None]:
# Volume-guided edit of sample 1 using the single-step DDIM prediction as the loss
# target (`tgt_noise_level="zero_pred"`), which the other zero_pred cells label as the
# predicted x_0. The title states that rather than "the volume at $t=0$".
# Target volume = (1 + 0.71) x original, guidance from DDIM index 8.
target_volume_increment = 0.71
t_optim_idx = 8
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[1:2],  
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[1:2]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero_pred",
    opt_kwargs={"lr":1e-1}, 
    plot_debug=True
)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[1:2], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume of the PREDICTED $x_0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[1:2]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[1:2], x_edited)],
    ]
)

**Done up to here** — the cells below have not been finalized yet.

In [None]:
# NOTE(review): identical to the earlier shift = -0.062 sanity-check cell.
# Adds a constant `shift` to every SDF value; a negative shift grows the counted
# "inside" region (SDF < 0). Second-row titles show the per-sample relative volume
# change (V_shifted - V) / V, measured on the destandardized `out1`.
shift = -0.062
plot_sdfs(
    [out1, out1 + shift], 
    titles=[
        [f"Original \n $V={volume_estimates(out_i).item()}$" for out_i in out1], 
        [f"Shifting the sdfs by ${shift}$ \n $V={volume_estimates(out_i_shift).item()}$ (incr.: ${((volume_estimates(out_i_shift).item() - volume_estimates(out_i).item())/volume_estimates(out_i).item()):.2f}$)" for out_i_shift, out_i in zip(out1+shift, out1)], 
])

In [None]:
# Volume-guided edit of sample 0 without debug plots: target volume = (1 + 1.27) x
# original, guidance from DDIM index 7, loss on the DDIM-completed sample
# (`tgt_noise_level="zero"`), fixed lr = 0.1.
target_volume_increment = 1.27
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[0:1], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[0:1]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 

)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[0:1], 
        preprocessor1.destandardize(x_edited[0:1]), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1[0:1].shape[0])], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited)[i].item():.2f}, V_{{target}}= {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1[0:1].shape[0])],
        ]
)

In [None]:
# Volume-guided edit of sample 1 without debug plots: target volume = (1 + 0.9) x
# original, guidance from DDIM index 7, loss on the DDIM-completed sample
# (`tgt_noise_level="zero"`), fixed lr = 0.1.
target_volume_increment = 0.9
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[1:2], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[1:2]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 

)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[1:2], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[1:2]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[1:2], x_edited)],
    ]
)

In [None]:
# Re-renders the comparison for sample 1 using whatever `x_edited`, `t_optim_idx` and
# `target_volume_increment` the most recently executed editing cell left behind
# (hidden-state dependency: run the intended editing cell first).
plot_sdfs(
    sdfs=[
        out1[1:2], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[1:2]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[1:2], x_edited)],
    ]
)

In [None]:
# Shift sanity check with shift = -0.059: a negative shift grows the counted
# "inside" region (SDF < 0). Second-row titles show the per-sample relative volume
# change (V_shifted - V) / V, measured on the destandardized `out1`.
shift = -0.059
plot_sdfs(
    [out1, out1 + shift], 
    titles=[
        [f"Original \n $V={volume_estimates(out_i).item()}$" for out_i in out1], 
        [f"Shifting the sdfs by ${shift}$ \n $V={volume_estimates(out_i_shift).item()}$ (incr.: ${((volume_estimates(out_i_shift).item() - volume_estimates(out_i).item())/volume_estimates(out_i).item()):.2f}$)" for out_i_shift, out_i in zip(out1+shift, out1)], 
])

In [None]:
# Shift sanity check with shift = +0.055: a positive shift shrinks the counted
# "inside" region (SDF < 0). Second-row titles show the per-sample relative volume
# change (V_shifted - V) / V, measured on the destandardized `out1`.
shift = +0.055
plot_sdfs(
    [out1, out1 + shift], 
    titles=[
        [f"Original \n $V={volume_estimates(out_i).item()}$" for out_i in out1], 
        [f"Shifting the sdfs by ${shift}$ \n $V={volume_estimates(out_i_shift).item()}$ (incr.: ${((volume_estimates(out_i_shift).item() - volume_estimates(out_i).item())/volume_estimates(out_i).item()):.2f}$)" for out_i_shift, out_i in zip(out1+shift, out1)], 
])

In [None]:
# Volume-reducing edit of sample 0: target volume = (1 - 0.57) x original, guidance
# from DDIM index 7, loss on the DDIM-completed sample (`tgt_noise_level="zero"`),
# fixed lr = 0.1.
target_volume_increment = -0.57
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[0:1], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[0:1]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 

)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[0:1], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[0:1]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[0:1], x_edited)],
    ]
)

In [None]:
# Volume-reducing edit of sample 1: target volume = (1 - 0.38) x original, guidance
# from DDIM index 7, loss on the DDIM-completed sample (`tgt_noise_level="zero"`),
# fixed lr = 0.1.
target_volume_increment = -0.38
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[1:2], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[1:2]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 

)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[1:2], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[1:2]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[1:2], x_edited)],
    ]
)

In [None]:
# Sample three fresh shapes with the project's DDIM sampler, keeping the intermediates.
intermediates_debug = ddpm_sampler1.sample_ddim(
    model1,
    (3, 1, 32, 32, 32),
    show_pbar=True,
    return_intermediates=True,
)

In [None]:
# Label the k-th intermediate with DDIM step k * log_every_t.
# NOTE(review): assumes sample_ddim's default logging interval is 5 steps — confirm.
log_every_t = 5
plot_sdfs(
    intermediates_debug,
    titles=[
        f"Predicted $x_0$ at DDIM timestep ${step}$"
        for step in range(0, len(intermediates_debug) * log_every_t, log_every_t)
    ],
)

In [None]:
# Volume-reducing edit of sample 1 starting earlier in the chain: target volume =
# (1 - 0.38) x original, guidance from DDIM index 10, loss on the DDIM-completed sample
# (`tgt_noise_level="zero"`), fixed lr = 0.1.
target_volume_increment = -0.38
t_optim_idx = 10
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[1:2], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[1:2]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 

)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[1:2], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[1:2]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[1:2], x_edited)],
    ]
)

In [None]:
# NOTE(review): same parameters as the earlier -0.38 / t_optim_idx = 7 cell for sample 1
# (a re-run). Target volume = (1 - 0.38) x original, guidance from DDIM index 7, loss
# on the DDIM-completed sample (`tgt_noise_level="zero"`).
target_volume_increment = -0.38
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[1:2], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[1:2]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 

)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[1:2], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[1:2]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[1:2], x_edited)],
    ]
)

In [None]:
# NOTE(review): same parameters as the earlier -0.57 / t_optim_idx = 7 cell for sample 0
# (a re-run). Target volume = (1 - 0.57) x original, guidance from DDIM index 7, loss
# on the DDIM-completed sample (`tgt_noise_level="zero"`).
target_volume_increment = -0.57
t_optim_idx = 7
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[0:1], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[0:1]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 

)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[0:1], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[0:1]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[0:1], x_edited)],
    ]
)

In [None]:
# Volume-reducing edit of sample 0 starting earlier in the chain: target volume =
# (1 - 0.57) x original, guidance from DDIM index 8, loss on the DDIM-completed sample
# (`tgt_noise_level="zero"`), fixed lr = 0.1.
target_volume_increment = -0.57
t_optim_idx = 8
x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1, 
    x_0=out1_std[0:1], 
    obj_fn=volume_estimates_loss_fn, 
    obj_fn_args={"target_volumes": volume_estimates(out1_std[0:1]) * (1+target_volume_increment)},
    from_t_optim_idx=t_optim_idx, 
    tgt_noise_level="zero",
    opt_kwargs={"lr":1e-1}, 

)

# Title volumes are computed on the standardized tensors; the panels show destandardized SDFs.
plot_sdfs(
    sdfs=[
        out1[0:1], 
        preprocessor1.destandardize(x_edited), 
    ], 
    title = f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {volume_estimates(out1_i_std).item():.2f}$)" for out1_i_std in out1_std[0:1]], 
        [f"Edited ($V_{{edit}} = {volume_estimates(x_edited_i).item():.2f}, V_{{target}}= {volume_estimates(out1_i_std).item() * (1+target_volume_increment):.2f})$" for (out1_i_std, x_edited_i) in zip(out1_std[0:1], x_edited)],
    ]
)

In [None]:
# Noise-guided volume editing of sample 0 from DDIM step index 8, using a larger
# learning rate together with an exponential `decay_fn`.
target_volume_increment = -0.57
t_optim_idx = 8

# Slice with 0:1 (not 0) to keep the batch dimension: the sampler and `volume_estimates`
# are called on batched tensors elsewhere in this notebook.
x_src_std = out1_std[0:1]
vol_orig = volume_estimates(x_src_std)
# Computed once and reused for the title, so the plotted target is exactly the one optimized.
vol_target = vol_orig * (1 + target_volume_increment)

x_edited = ddpm_sampler1.ddim_sample_noise_guidance(
    model1,
    x_0=x_src_std,
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": vol_target},
    from_t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
    # NOTE(review): th.exp only accepts tensors. Confirm the sampler calls decay_fn with a
    # tensor rather than a Python int/float step count, otherwise this raises TypeError.
    opt_kwargs={"lr": 1, "decay_fn": lambda x: th.exp(-x)},
)
vol_edited = volume_estimates(x_edited)

plot_sdfs(
    sdfs=[
        out1[0:1],
        preprocessor1.destandardize(x_edited),
    ],
    title=f"Optimization from the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {vol_orig[0].item():.2f}$)"],
        # The closing parenthesis belongs outside the math span, matching the "Original" label.
        [f"Edited ($V_{{edit}} = {vol_edited[0].item():.2f}, V_{{target}}= {vol_target[0].item():.2f}$)"],
    ],
)

In [None]:
# Latent optimization: optimize the DDIM latent at step index t_optim_idx so that every
# decoded sample's volume estimate moves towards (1 + target_volume_increment) * original.
target_volume_increment = -0.61
t_optim_idx = 5

# Evaluate the volume estimator once for the whole batch instead of once per title entry.
vol_orig = volume_estimates(out1_std)
# Reused for the titles, so the plotted targets are exactly the ones optimized.
vol_target = vol_orig * (1 + target_volume_increment)

# NOTE(review): unpacking `.values()` relies on the returned dict's insertion order
# (edited sample, latent before optimization, optimized latent). Confirm in the sampler.
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1,
    x_0=out1_std,
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": vol_target},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
    # max_opt_iters=50,
    loss_threshold=20,
    # opt_kwargs={"lr":1e-3, "weight_decay": 1e-2}, # AdamW default # Too slow with lr=1e-3
    opt_kwargs={"lr": 1e-2, "weight_decay": 1e-2},
).values()
vol_edited = volume_estimates(x_edited)

plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"])

plot_sdfs(
    sdfs=[
        out1,
        preprocessor1.destandardize(x_edited),
    ],
    title=f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {vol_orig[i].item():.2f}$)" for i in range(out1.shape[0])],
        # The closing parenthesis belongs outside the math span, matching the "Original" label.
        [f"Edited ($V_E = {vol_edited[i].item():.2f} - V_{{target}}: {vol_target[i].item():.2f}$)" for i in range(out1.shape[0])],
    ],
)

In [None]:
# Latent optimization at DDIM step index 3 with an extra "max_volume" objective argument.
# Also plots the absolute per-voxel difference between the original and edited SDFs.
target_volume_increment = -0.61
t_optim_idx = 3

# Evaluate the volume estimator once for the whole batch instead of once per title entry.
vol_orig = volume_estimates(out1_std)
# Reused for the titles, so the plotted targets are exactly the ones optimized.
vol_target = vol_orig * (1 + target_volume_increment)

# NOTE(review): unpacking `.values()` relies on the returned dict's insertion order
# (edited sample, latent before optimization, optimized latent). Confirm in the sampler.
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1,
    x_0=out1_std,
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": vol_target, "max_volume": 1},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
    # max_opt_iters=50,
    loss_threshold=20,
    opt_kwargs={"lr": 1e-2, "weight_decay": 1e-3},
).values()
vol_edited = volume_estimates(x_edited)
x_edited_destd = preprocessor1.destandardize(x_edited)

plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"])

plot_sdfs(
    sdfs=[out1, x_edited_destd, th.abs(out1 - x_edited_destd)],
    title=f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {vol_orig[i].item():.2f}$)" for i in range(out1.shape[0])],
        # The closing parenthesis belongs outside the math span, matching the "Original" label.
        [f"Edited ($V_E = {vol_edited[i].item():.2f} - V_{{target}}: {vol_target[i].item():.2f}$)" for i in range(out1.shape[0])],
        ["Absolute difference between original and edited" for _ in range(out1.shape[0])],
    ],
)

In [None]:
# Disabled cell (all code commented out). When uncommented it:
#   1. saves each edited 32^3 SDF as an .obj mesh,
#   2. nearest-upsamples the edited SDFs to 64^3, standardizes them, and uses them as the
#      low-res condition for model2's DDIM sampling (super-resolution),
#   3. saves the super-resolved SDFs as meshes and plots them next to `out2`.
# NOTE(review): `out2` is not defined in this part of the notebook. Presumably it holds the
# original SR samples from an earlier cell; confirm before re-enabling.
# NOTE(review): the "v40%incr_3steps" filename tag may not match the current settings
# (e.g. target_volume_increment = -0.61 is a decrease). Update it before exporting.
# for i, out in enumerate(preprocessor1.destandardize(x_edited)):
#     save_sdf_as_mesh(f"gen32_{i}_v40%incr_3steps.obj", out, safe=True)

# lr_cond_edit = F.interpolate(preprocessor1.destandardize(x_edited), (64, 64, 64), mode="nearest")
# lr_cond_edit = preprocessor2.standardize(lr_cond_edit, 0)
# x_edited_sr = ddpm_sampler2.sample_ddim(lambda x, t: model2(th.cat([lr_cond_edit, x], 1), t), (out1.shape[0], 1, 64, 64, 64), show_pbar=True)

# x_edited_sr = preprocessor2.destandardize(x_edited_sr, 1)

# for i, out in enumerate(x_edited_sr):
#     save_sdf_as_mesh(f"sr64_{i}_v40%incr_3steps.obj", out, safe=True)

# plot_sdfs(
#     sdfs=[out2, x_edited_sr], 
#     title = f"Super-resolution results of optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
#     titles=[
#         [f"Original (SR) ($V = {volume_estimates(out1_std)[i].item():.2f}$)" for i in range(out1.shape[0])], 
#         [f"Edited (SR) ($V_E = {volume_estimates(x_edited)[i].item():.2f} - V_{{target}}: {volume_estimates(out1_std)[i].item() * (1+target_volume_increment):.2f})$" for i in range(out1.shape[0])],
#         ]
# )

In [None]:
# Latent optimization at DDIM step index 6 with "max_volume" set and weight decay 1e-2.
target_volume_increment = -0.61
t_optim_idx = 6

# Evaluate the volume estimator once for the whole batch instead of once per title entry.
vol_orig = volume_estimates(out1_std)
# Reused for the titles, so the plotted targets are exactly the ones optimized.
vol_target = vol_orig * (1 + target_volume_increment)

# NOTE(review): unpacking `.values()` relies on the returned dict's insertion order
# (edited sample, latent before optimization, optimized latent). Confirm in the sampler.
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1,
    x_0=out1_std,
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": vol_target, "max_volume": 1},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
    # max_opt_iters=50,
    loss_threshold=20,
    opt_kwargs={"lr": 1e-2, "weight_decay": 1e-2},
).values()
vol_edited = volume_estimates(x_edited)

plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"])

plot_sdfs(
    sdfs=[
        out1,
        preprocessor1.destandardize(x_edited),
    ],
    title=f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {vol_orig[i].item():.2f}$)" for i in range(out1.shape[0])],
        # The closing parenthesis belongs outside the math span, matching the "Original" label.
        [f"Edited ($V_E = {vol_edited[i].item():.2f} - V_{{target}}: {vol_target[i].item():.2f}$)" for i in range(out1.shape[0])],
    ],
)

In [None]:
# Latent optimization at DDIM step index 5 with "max_volume" set and no weight decay.
target_volume_increment = -0.61
t_optim_idx = 5

# Evaluate the volume estimator once for the whole batch instead of once per title entry.
vol_orig = volume_estimates(out1_std)
# Reused for the titles, so the plotted targets are exactly the ones optimized.
vol_target = vol_orig * (1 + target_volume_increment)

# NOTE(review): unpacking `.values()` relies on the returned dict's insertion order
# (edited sample, latent before optimization, optimized latent). Confirm in the sampler.
x_edited, x_t, x_t_optim = ddpm_sampler1.ddim_sample_latent_optimization(
    model1,
    x_0=out1_std,
    obj_fn=volume_estimates_loss_fn,
    obj_fn_args={"target_volumes": vol_target, "max_volume": 1},
    t_optim_idx=t_optim_idx,
    tgt_noise_level="zero",
    # max_opt_iters=50,
    loss_threshold=20,
    opt_kwargs={"lr": 1e-2, "weight_decay": 0},
).values()
vol_edited = volume_estimates(x_edited)

plot_sdfs(sdfs=[x_t, x_t_optim], title=f"Latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$", titles=["Latent to be optimized", "Optimized latent"])

plot_sdfs(
    sdfs=[
        out1,
        preprocessor1.destandardize(x_edited),
    ],
    title=f"Optimization on the latent at $t={ddpm_sampler1.ddim_timesteps[t_optim_idx]}$ based on the volume at $t=0$",
    titles=[
        [f"Original ($V = {vol_orig[i].item():.2f}$)" for i in range(out1.shape[0])],
        # The closing parenthesis belongs outside the math span, matching the "Original" label.
        [f"Edited ($V_E = {vol_edited[i].item():.2f} - V_{{target}}: {vol_target[i].item():.2f}$)" for i in range(out1.shape[0])],
    ],
)