In [26]:
import torch
import numpy as np
from functools import partial
import sys
from einops import rearrange

sys.path.append("../../../..") #four levels up from script's current location (Instead of modifying sys.path dynamically, consider using absolute paths or configuring the PYTHONPATH environment variable.)

from ConditionalDiffusionGeneration.src.guided_diffusion.unet import create_model
from ConditionalDiffusionGeneration.src.guided_diffusion.condition_methods import get_conditioning_method
from ConditionalDiffusionGeneration.src.guided_diffusion.measurements import get_noise, get_operator
from ConditionalDiffusionGeneration.src.guided_diffusion.gaussian_diffusion import create_sampler
from ConditionalNeuralField.cnf.inference_function import ReconstructFrame, decoder

if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"

device = torch.device(dev)
print(f"Running on device: {device}")

torch.manual_seed(42)
np.random.seed(42)

# Conditioning Data Loader
#no_of_sensors = 10  # 1, 10, 100, 1000
#true_measurement = torch.from_numpy(np.load(f'input_case3/random_sensor/{no_of_sensors}/measures.npy')).to(device)
#true_measurement = true_measurement[:, :, :2]  # Keep only u, v
#print(f"Updated true_measurement.shape = {true_measurement.shape}")  # Should be [384, 10, 2]
#print(f"DEBUG: Original true_measurement.shape = {true_measurement.shape}")

#print(true_measurement[0])  # Print first time step to inspect values

# Load trained unconditional model
u_net_model = create_model(
    image_size=256,
    num_channels=128,
    num_res_blocks=2,
    channel_mult="",
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="32,16,8",
    model_path='./input_case3/diff_model/ema_0.9999_340000.pt'
)

u_net_model.to(device)
u_net_model.eval()

noiser = get_noise(sigma=0.0, name='gaussian')

mask = torch.ones_like(true_measurement, device=device)

# Sampler
sampler = create_sampler(
    sampler='ddpm',
    steps=1000,
    noise_schedule="cosine",
    model_mean_type="epsilon",
    model_var_type="fixed_large",
    dynamic_threshold=False,
    clip_denoised=True,
    rescale_timesteps=False,
    timestep_respacing=""
)

def measurement_cond_fn(x_t, t, measurement, noisy_measurement=None, x_prev=None, x_0_hat=None):
    """Dummy conditioning function for unconditional sampling."""
    print(f"DEBUG: measurement_cond_fn called with t={t}")  # Debugging
    return torch.zeros_like(x_t, device=device), None  # Return zero tensor and None for distance

print(f"DEBUG: Checking function signature for measurement_cond_fn: {measurement_cond_fn.__code__.co_varnames}")

def sample_fn_unconditional(x_start):
    return sampler.p_sample_loop(
        model=u_net_model,
        x_start=x_start,
        measurement=torch.zeros_like(x_start, device=device),  # Dummy measurement tensor
        measurement_cond_fn=measurement_cond_fn,  # Updated function with all expected arguments
        record=False,
        save_root=None
    )

# Generate Unconditional Samples
no_of_samples = 10  # Number of realizations
time_length = 256   # Adjust for 2D periodic hill case
latent_size = 256   # Ensure compatibility with 2D setup

x_start = torch.randn(no_of_samples, 1, time_length, latent_size, device=device)  # Random latent noise

# Use fixed unconditional sample function
samples = [sample_fn_unconditional(x_start[i:i+1]) for i in range(x_start.shape[0])]

gen_latents = torch.cat(samples)
print(f"Generated latents shape: {gen_latents.shape}")

Running on device: cpu
DEBUG: Checking function signature for measurement_cond_fn: ('x_t', 't', 'measurement', 'noisy_measurement', 'x_prev', 'x_0_hat')


  0%|          | 0/1000 [00:00<?, ?it/s]

TypeError: measurement_cond_fn() missing 1 required positional argument: 't'