In [None]:
import torch
# Select the compute device: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using device: {device}')
if not torch.cuda.is_available():
    print("No GPU available. Switching to CPU.")
else:
    # Make GPU 0 the default CUDA device; use another index on multi-GPU machines.
    torch.cuda.set_device(0)

In [None]:
import numpy as np
# Load the raw array from disk; the path is relative to the notebook's working directory.
complex_data = np.load('../layer_array.npy')

In [None]:
# Convert the loaded NumPy array to a torch tensor (torch.Tensor uses the default floating dtype).
complex_data = torch.Tensor(complex_data)
def transform(data):
    """Rescale every row of a 2-D tensor linearly onto [-1, 1].

    Args:
        data: 2-D tensor; the extrema are taken along dimension 1 (per row).

    Returns:
        A tuple ``(scaled, row_min, row_max)`` where ``scaled`` has the shape of
        ``data`` and ``row_min`` / ``row_max`` are the ``(values, indices)``
        results of the row-wise min / max reductions.
    """
    row_min = data.min(dim=1)
    row_max = data.max(dim=1)
    offset = row_min.values.unsqueeze(1)
    span = (row_max.values - row_min.values).unsqueeze(1)
    # Map [min, max] -> [0, 1] first, then stretch to [-1, 1].
    scaled = 2 * ((data - offset) / span) - 1
    return scaled, row_min, row_max

# Rescale only the first two rows; the per-row extrema are returned alongside the scaled data.
data_transformed, min_, max_ = transform(complex_data[:2, :])

In [None]:
import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F

#Positional or Fourier for time embedding
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of scalar timesteps.

    For ``dim // 2`` geometrically spaced frequencies between 1 and 1/10000 the
    output holds all sines followed by all cosines of ``time * frequency``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return a ``(len(time), 2 * (dim // 2))`` tensor for a 1-D ``time`` tensor."""
        n_freqs = self.dim // 2
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=time.device) * -log_step)
        # Outer product: one row per timestep, one column per frequency.
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
    
class GaussianFourierProjection(nn.Module):
    """Random Fourier features for encoding time steps.

    ``embed_dim // 2`` frequencies are drawn once from a zero-mean Gaussian
    with standard deviation ``scale`` and stored as a non-trainable parameter.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Fixed random frequencies: part of the module state, excluded from optimisation.
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Return ``[sin(2*pi*W*x), cos(2*pi*W*x)]`` concatenated on the last axis."""
        angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

In [None]:
#different scheduler

def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule from https://arxiv.org/abs/2102.09672.

    Args:
        timesteps: number of diffusion steps (length of the returned tensor).
        s: small offset applied inside the cosine.

    Returns:
        1-D tensor of ``timesteps`` betas, clipped to [0.0001, 0.9999].
    """
    n_points = timesteps + 1
    grid = torch.linspace(0, timesteps, n_points)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cum_alpha = torch.cos(phase) ** 2
    # Normalise so that the cumulative product starts at exactly 1.
    cum_alpha = cum_alpha / cum_alpha[0]
    step_betas = 1 - (cum_alpha[1:] / cum_alpha[:-1])
    return torch.clip(step_betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    """Return ``timesteps`` betas spaced linearly from 0.0001 to 0.02."""
    first_beta, last_beta = 0.0001, 0.02
    return torch.linspace(first_beta, last_beta, timesteps)

def quadratic_beta_schedule(timesteps):
    """Return ``timesteps`` betas from 0.0001 to 0.02, linear in sqrt(beta)."""
    first_beta, last_beta = 0.0001, 0.02
    root_grid = torch.linspace(first_beta**0.5, last_beta**0.5, timesteps)
    return root_grid ** 2

def sigmoid_beta_schedule(timesteps):
    """Return ``timesteps`` betas following a sigmoid ramp between 0.0001 and 0.02.

    The sigmoid is evaluated on a linear grid over [-6, 6] and then rescaled
    to the beta range.
    """
    first_beta, last_beta = 0.0001, 0.02
    ramp = torch.sigmoid(torch.linspace(-6, 6, timesteps))
    return ramp * (last_beta - first_beta) + first_beta

def sigmoid_beta_schedule1(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
    """Sigmoid schedule proposed in https://arxiv.org/abs/2212.11972 (Figure 8).

    The source reports it works better for images larger than 64x64 when used
    during training.

    Args:
        timesteps: number of diffusion steps.
        start, end: range over which the sigmoid argument is swept.
        tau: temperature dividing the sigmoid argument.
        clamp_min: accepted for interface compatibility; not used by the body.

    Returns:
        1-D float64 tensor of ``timesteps`` betas, clipped to [0, 0.999].
    """
    n_points = timesteps + 1
    unit_t = torch.linspace(0, timesteps, n_points, dtype = torch.float64) / timesteps
    sig_start = torch.tensor(start / tau).sigmoid()
    sig_end = torch.tensor(end / tau).sigmoid()
    swept = ((unit_t * (end - start) + start) / tau).sigmoid()
    cum_alpha = (-swept + sig_end) / (sig_end - sig_start)
    cum_alpha = cum_alpha / cum_alpha[0]
    step_betas = 1 - (cum_alpha[1:] / cum_alpha[:-1])
    return torch.clip(step_betas, 0, 0.999)

def logsnr_schedule_cosine(t, logsnr_min=-20., logsnr_max=20.):
    """Cosine log-SNR schedule: map a time in [0, 1] to a log signal-to-noise ratio.

    ``t = 0`` maps to ``logsnr_max`` and ``t = 1`` maps to ``logsnr_min``.

    Args:
        t: scalar or tensor of times; converted to a float32 tensor.
        logsnr_min: log-SNR reached at ``t = 1``.
        logsnr_max: log-SNR reached at ``t = 0``.

    Returns:
        float32 tensor with the shape of ``t``.
    """
    # Implemented with torch/math: the rest of the notebook is PyTorch and `tf` is never imported.
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    t = torch.as_tensor(t, dtype=torch.float32)
    return -2. * torch.log(torch.tan(a * t + b))


def inv_logsnr_schedule_cosine(logsnr, logsnr_min=-20., logsnr_max=20.):
    """Inverse of ``logsnr_schedule_cosine``: map a log-SNR back to its time in [0, 1].

    Args:
        logsnr: scalar or tensor of log-SNR values; converted to a float32 tensor.
        logsnr_min, logsnr_max: the same bounds that were used for the forward schedule.

    Returns:
        float32 tensor with the shape of ``logsnr``.
    """
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    logsnr = torch.as_tensor(logsnr, dtype=torch.float32)
    return torch.atan(torch.exp(-0.5 * logsnr)) / a - b / a


def get_logsnr_alpha_sigma(time):
    """Return ``(logsnr, alpha, sigma)`` for the cosine log-SNR schedule.

    ``alpha = sqrt(sigmoid(logsnr))`` and ``sigma = sqrt(sigmoid(-logsnr))``, so
    ``alpha**2 + sigma**2 == 1``.
    """
    # Module-level function: call the schedule directly (there is no `self` here).
    logsnr = logsnr_schedule_cosine(time)
    alpha = torch.sqrt(torch.sigmoid(logsnr))
    sigma = torch.sqrt(torch.sigmoid(-logsnr))

    return logsnr, alpha, sigma


# Simple scalar noise schedule, i.e. gamma(t) in the vdm paper:
# gamma(t) = abs(w) * t + b
class NoiseSchedule(nn.Module):
    """Learnable scalar noise schedule ``gamma(t) = abs(w) * t + b``.

    Written as a regular PyTorch module: parameters are created in
    ``__init__`` and the computation lives in ``forward``, so the instance
    works with optimisers, ``state_dict`` and ``.to(device)``.

    Args:
        init_gamma_0: initial value of ``gamma(0)`` (initial bias ``b``).
        init_gamma_1: initial value of ``gamma(1)``; the initial slope ``w`` is
            ``init_gamma_1 - init_gamma_0``.
            NOTE(review): the defaults are assumed from the VDM reference
            configuration -- confirm they suit this experiment.
    """

    def __init__(self, init_gamma_0=-13.3, init_gamma_1=5.0):
        super().__init__()
        init_bias = init_gamma_0
        init_scale = init_gamma_1 - init_gamma_0
        self.w = nn.Parameter(torch.full((1,), float(init_scale)))
        self.b = nn.Parameter(torch.full((1,), float(init_bias)))

    def forward(self, t):
        """Evaluate ``gamma(t)``; ``abs`` keeps the schedule monotone in ``t``."""
        return abs(self.w) * t + self.b

In [None]:
# Stack (up to) the first 128*64 samples of the two rescaled rows into a two-row float tensor.
dataset1 = torch.Tensor(
    np.vstack((data_transformed[0, :128 * 64], data_transformed[1, :128 * 64]))
).float()

In [None]:
timesteps = 200

# Swap in another *_beta_schedule function here to change the noise schedule.
betas = cosine_beta_schedule(timesteps)
alphas = 1 - betas
alphas_ = alphas.cumprod(dim=0)  # running product of the per-step alphas
variance = 1 - alphas_
sd = variance.sqrt()


def forward_process(x_start, timestep, noise=None):
    """Diffuse the data (t == 0 means diffused for 1 step).

    A single noise draw is shared by every step, so entry ``n + 1`` of the
    result is ``sqrt(alphas_[n]) * x_start + sd[n] * noise``.

    Args:
        x_start: clean data tensor.
        timestep: number of diffusion steps to generate; must not exceed the
            length of the module-level ``alphas_`` / ``sd`` schedules.
        noise: optional noise tensor broadcastable to ``x_start``. When
            omitted, standard normal noise matching ``x_start`` is sampled.

    Returns:
        List of ``timestep + 1`` tensors, starting with ``x_start`` itself.
    """
    # Honour a caller-supplied noise tensor; otherwise sample one matching x_start's dtype/device.
    noise_at_t = torch.randn_like(x_start) if noise is None else noise
    x_seq = [x_start]
    for n in range(timestep):
        x_seq.append(x_start.mul(torch.sqrt(alphas_[n])) + noise_at_t.mul(sd[n]))
    return x_seq

# Diffuse over the whole schedule. Using `timesteps` (rather than a repeated literal) keeps the
# call in range of alphas_/sd if the schedule length above is changed.
x_seq = forward_process(dataset1, timesteps)
print(len(x_seq))


In [None]:
# Scatter the two data rows against each other at diffusion steps 0, 5, ..., 25 (cosine scheduler).
fig, axs = plt.subplots(1, 6, figsize=(28, 3))
for i, ax in enumerate(axs):
    step = i * 5
    ax.scatter(x_seq[step][0], x_seq[step][1], s=10)
    ax.set_axis_off()
    # Raw string: '\m' is not a valid escape sequence in a normal string literal.
    ax.set_title(r'$q(\mathbf{x}_{' + str(step) + '})$')

In [None]:
# Scatter the two data rows against each other at ten evenly spaced diffusion steps.
fig, axs = plt.subplots(1, 10, figsize=(28, 3))
# Iterate over every axis so that the last panel is not left empty.
for i, ax in enumerate(axs):
    step = int((i / 10.0) * timesteps)
    ax.scatter(x_seq[step][0], x_seq[step][1], s=10)
    ax.set_axis_off()
    # Raw string: '\m' is not a valid escape sequence in a normal string literal.
    ax.set_title(r'$q(\mathbf{x}_{' + str(step) + '})$')

In [None]:
#some reparameterization trick
import torch.nn.functional as F

# Cumulative alphas shifted right by one step, with 1.0 padded in for the step before the first.
alphas_prev_ = F.pad(alphas_[:-1], [1, 0], "constant", 1.0)
# NOTE(review): the first entry divides by (1 - 1.0) = 0, and the ratio here is the reciprocal of
# the one used for `posterior_variance` later in this cell -- confirm which form is intended.
sigma_squared_q_t = (1 - alphas) * (1 - alphas_) / (1 - alphas_prev_)
# The same quantity assembled in log space ...
log_sigma_squared_q_t = torch.log(1-alphas) + torch.log(1-alphas_) - torch.log(1-alphas_prev_)
# ... and exponentiated back.
sigma_squared_q_t_corrected = torch.exp(log_sigma_squared_q_t)

# how to add noise to the data
def get_noisy(batch, timestep):
    """Noise ``batch`` to diffusion step ``timestep`` in a single shot.

    Returns:
        ``(noisy_batch, noise)`` where ``noise`` is the standard normal draw and
        ``noisy_batch = sqrt(alphas_[timestep]) * batch + sd[timestep] * noise``.
    """
    eps = torch.normal(0, std=1, size=batch.size())
    scaled_signal = batch.mul(torch.sqrt(alphas_[timestep]))
    return scaled_signal + eps.mul(sd[timestep]), eps

def recover_original(batch, timestep, noise):
    """Invert ``get_noisy``: recover the clean data from a noisy batch and its noise.

    Since ``x_t = sqrt(alphas_[t]) * x_0 + sd[t] * noise``, the clean data is
    ``x_0 = (x_t - sd[t] * noise) / sqrt(alphas_[t])``.

    Args:
        batch: noisy data at step ``timestep``.
        timestep: index into the module-level ``alphas_`` / ``sd`` schedules.
        noise: the noise tensor that was mixed into ``batch``.

    Returns:
        Tensor with the shape of ``batch`` holding the reconstructed clean data.
    """
    # Divide by sqrt(alpha-bar): that is the factor the forward process multiplied the data by.
    true_data = (batch.sub(noise.mul(sd[timestep]))).div(torch.sqrt(alphas_[timestep]))
    return true_data
    
# Noise the first two rows to step 20 and scatter one row against the other.
added_noise_at_t, noise = get_noisy(data_transformed[:2, :], 20)
plt.scatter(added_noise_at_t[0], added_noise_at_t[1])

# Variance of the posterior q(x_{t-1} | x_t, x_0); entry 0 is zero because alphas_prev_[0] == 1.
posterior_variance = (betas) * (1 - alphas_prev_) / (1 - alphas_)

## https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L196
## https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L78
# Entry 0 is replaced by entry 1 before taking the log, which avoids log(0).
log_posterior_variance = torch.log(torch.hstack([posterior_variance[1], posterior_variance[1:]]))

# NOTE(review): an earlier comment here mentioned a multiplication by 1/2, but no such factor
# appears in this cell; the line below only exponentiates the clipped log-variance.
posterior_variance_corrected = torch.exp(log_posterior_variance)

In [None]:
# Plot the cumulative alpha product over the diffusion steps.
plt.plot(alphas_)

In [None]:
# Precompute an 8-dimensional sinusoidal embedding for every timestep index 0 .. timesteps - 1.
sinusoidalPositionEmbeddings = SinusoidalPositionEmbeddings(8)
# Lookup table with one row per timestep.
position_embeddings = sinusoidalPositionEmbeddings(torch.arange(0, timesteps))

## Denoising diffusion probabilistic models (DDPM)

In a recent article, Ho et al. [ [ 1 ] ](#ref1) built upon the idea of diffusion models and proposed several enhancements that improve the quality of the results. First, they proposed the following parameterization of the mean function
$$
\mathbf{\mu}_{\theta}(\mathbf{x}_{t}, t) = \frac{1}{\sqrt{\alpha_{t}}} \left( \mathbf{x}_{t} - \frac{\beta_{t}}{\sqrt{1 - \bar{\alpha}_{t}}} \mathbf{\epsilon}_{\theta} (\mathbf{x}_{t}, t) \right)
$$

Note that the model is now trained to directly output a form of _noise_ function, which is used in the sampling process. Furthermore, the authors suggest using a fixed variance function instead

$$
\mathbf{x}_{t-1} = \frac{1}{\sqrt{\alpha_{t}}} \left( \mathbf{x}_{t} - \frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \mathbf{\epsilon}_{\theta}(\mathbf{x}_{t}, t) \right) + \sigma_{t}\mathbf{z}
$$

This leads to a new sampling procedure for the reverse process as follows (we also quickly redefine the model to output the correct dimensionality).

In [None]:
class simplemodel(nn.Module):
    """Small MLP that predicts the noise for 2-D points, optionally conditioned on a timestep.

    The main network is an hourglass of fully connected layers
    (2 -> h -> h/4 -> h/8 -> h/4 -> h -> 2) with SiLU activations between them.
    Timestep conditioning looks up a row of the module-level
    ``position_embeddings`` table, projects it to 2 dimensions and adds it to
    the input.

    Args:
        hidden_units: width ``h`` of the widest hidden layers.
    """

    def __init__(self, hidden_units=128):
        super().__init__()
        # Projects an 8-dimensional timestep embedding onto the 2-dimensional input space.
        self.position_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(8, 2),
        )

        wide = int(hidden_units)
        medium = int(hidden_units/4)
        narrow = int(hidden_units/8)
        widths = [2, wide, medium, narrow, medium, wide, 2]
        stack = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(n_in, n_out, bias=True))
            stack.append(nn.SiLU())
        stack.pop()  # the output layer has no activation
        self.mlp = nn.Sequential(*stack)

    def forward(self, x, timestep=None):
        """Run the MLP on ``x``; when ``timestep`` is given, add its projected embedding first."""
        if timestep is None:
            return self.mlp(x)
        # Look the embedding up in the precomputed table instead of regenerating it each call.
        embedding = position_embeddings[timestep.long()]
        # The projected embedding is added to (not concatenated with) the input.
        conditioned = x.add(self.position_mlp(embedding))
        return self.mlp(conditioned)

In [None]:
# Recompute the schedule-derived coefficients used by the training loss and the sampler.
alphas = 1 - betas
alphas_prod = torch.cumprod(alphas, 0)
alphas_bar_sqrt = torch.sqrt(alphas_prod)  # factor applied to the clean data x_0
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)  # factor applied to the noise

In [None]:
def _batch_get_variance(self, t, prev_t):
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[ torch.clip(prev_t, min=0) ]
        alpha_prod_t_prev[ prev_t < 0 ] = torch.tensor(1.0)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

In [None]:
n_steps = 200
def noise_estimation_loss(model, x_0):
    """Mean squared error between the true noise and the model's noise prediction.

    Timesteps are drawn in antithetic pairs (``t`` and ``n_steps - t - 1``),
    the batch is noised to those steps in closed form, and the model is asked
    to predict the noise that was added.

    Args:
        model: callable ``model(x_t, t)`` returning a tensor shaped like ``x_t``.
        x_0: batch of clean data; the first dimension is the batch dimension.

    Returns:
        Scalar loss tensor.
    """
    n_samples = x_0.shape[0]
    dev = x_0.device  # sample timesteps and noise on the data's device

    half = torch.randint(0, n_steps, size=(n_samples // 2 + 1,), device=dev)
    t = torch.cat([half, n_steps - half - 1], dim=0)[:n_samples].long()

    signal_scale = extract(alphas_bar_sqrt, t, x_0)
    noise_scale = extract(one_minus_alphas_bar_sqrt, t, x_0)
    eps = torch.randn_like(x_0, device=dev)

    noisy = x_0 * signal_scale + eps * noise_scale
    return (eps - model(noisy, t)).square().mean()

In [None]:
#from ddim
class EMA(object):
    """Exponential moving average of a module's trainable parameters.

    ``shadow`` maps parameter names to their averaged values. ``mu`` is the
    weight kept from the previous average on every update.
    """

    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}

    def register(self, module):
        """Initialise the shadow copy from the module's current trainable parameters."""
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue
            self.shadow[name] = param.data.clone()

    def update(self, module):
        """Blend the module's current trainable parameters into the shadow copy."""
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue
            averaged = self.shadow[name]
            averaged.data = (1. - self.mu) * param.data + self.mu * averaged.data

    def ema(self, module):
        """Overwrite the module's trainable parameters with the shadow values."""
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue
            param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        """Return a new module of the same type carrying the averaged weights.

        Relies on ``module.config`` (including ``module.config.device``) to
        rebuild the module.
        """
        clone = type(module)(module.config).to(module.config.device)
        clone.load_state_dict(module.state_dict())
        self.ema(clone)
        return clone

    def state_dict(self):
        """Return the shadow dictionary itself (not a copy)."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Replace the shadow dictionary with ``state_dict``."""
        self.shadow = state_dict

        
        
#alternative implementation for hyperparameter
#Exponential Moving Average it's a technique used to make results better and more stable training. 
#It works by keeping a copy of the model weights of the previous iteration and updating the current iteration weights by a factor of (1-beta). 
class EMA1:
    """Exponential moving average that maintains a separate averaged model.

    ``beta`` is the weight kept from the averaged model on every update;
    ``step`` counts how many times ``step_ema`` has been called.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        """Blend each parameter of ``current_model`` into ``ma_model``, pairwise in order."""
        pairs = zip(current_model.parameters(), ma_model.parameters())
        for live, averaged in pairs:
            averaged.data = self.update_average(averaged.data, live.data)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` when there is no old value."""
        return new if old is None else old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance one step: copy weights during warm-up, average afterwards."""
        if self.step >= step_start_ema:
            self.update_model_average(ema_model, model)
        else:
            # Warm-up: keep the averaged model identical to the live one.
            self.reset_parameters(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Copy the full state of ``model`` into ``ema_model``."""
        ema_model.load_state_dict(model.state_dict())



In [None]:
def extract(input, t, x):
    """Pick per-sample coefficients from a 1-D schedule and shape them for broadcasting.

    Args:
        input: 1-D tensor of schedule values.
        t: 1-D integer tensor of indices into ``input``.
        x: object whose ``shape`` determines how many trailing singleton axes are added.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with as many dimensions as ``x``.
    """
    picked = input.gather(0, t.to(input.device))
    singleton_axes = [1] * (len(x.shape) - 1)
    return picked.reshape(t.shape[0], *singleton_axes)

def p_sample(model, x, t):
    """Draw ``x_{t-1}`` from ``x_t`` with one reverse (denoising) DDPM step.

    Args:
        model: noise predictor called as ``model(x, t)`` with ``t`` a 1-element tensor.
        x: current noisy sample ``x_t``.
        t: integer timestep index (0 is the final denoising step).

    Returns:
        Tensor shaped like ``x``. For ``t == 0`` this is the posterior mean
        without added noise, so the final sample is not re-noised.
    """
    is_last_step = int(t) == 0
    t = torch.tensor([t])
    alpha_t = extract(alphas, t, x)
    # Factor applied to the model output
    eps_factor = (1 - alpha_t) / extract(one_minus_alphas_bar_sqrt, t, x)
    # Model output
    eps_theta = model(x, t)
    # Posterior mean
    mean = (1 / alpha_t.sqrt()) * (x - (eps_factor * eps_theta))
    if is_last_step:
        # DDPM sampling uses z = 0 on the last step; adding noise here would corrupt the result.
        return mean
    # Generate z
    z = torch.randn_like(x)
    # Fixed sigma
    sigma_t = extract(betas, t, x).sqrt()
    return mean + sigma_t * z

def p_sample_loop(model, shape):
    """Run the full reverse process starting from pure Gaussian noise.

    Args:
        model: noise predictor forwarded to ``p_sample``.
        shape: shape of the samples to generate.

    Returns:
        List of ``n_steps + 1`` tensors: the initial noise followed by the
        sample after every reverse step (last entry is the final sample).
    """
    cur_x = torch.randn(shape)
    x_seq = [cur_x]
    # Sampling is inference only: without no_grad every step would extend an autograd
    # graph through all model calls and keep it alive via x_seq.
    with torch.no_grad():
        for i in reversed(range(n_steps)):
            cur_x = p_sample(model, cur_x, i)
            x_seq.append(cur_x)
    return x_seq

In [None]:
model = simplemodel()
#model = UNetWithTimestep(input_channels, output_channels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataset = dataset1.T  # one 2-D point per row

# Create EMA model
#ema = EMA(0.9)
#ema.register(model)

#loss_fn = nn.MSELoss()

# Batch size
batch_size = 64
# Per-batch loss values stored as plain floats, so no autograd graph is kept alive.
losses = []
for t in range(200):
    # Visit the data in a fresh random order every epoch.
    permutation = torch.randperm(dataset.size()[0])
    for i in range(0, dataset.size()[0], batch_size):
        # Retrieve current batch
        indices = permutation[i:i+batch_size]

        batch_x = dataset[indices]

        # Compute the loss.
        loss = noise_estimation_loss(model, batch_x)
        # Before the backward pass, zero all of the network gradients
        optimizer.zero_grad()
        # Backward pass: compute gradient of the loss with respect to parameters
        loss.backward()
        # Perform gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        # Calling the step function to update the parameters
        optimizer.step()
        # Update the exponential moving average
        #ema.update(model)
        losses.append(loss.item())
    # NOTE(review): this list is reset every epoch, so only the latest diagnostic output survives.
    model_output = []
    # Every 20 epochs: print the last batch loss and visualise samples from the reverse process.
    if (t % 20 == 0):
        print(loss)
        # Diagnostics only, so run the model without tracking gradients.
        with torch.no_grad():
            # NOTE(review): the epoch counter `t` is passed as the diffusion timestep here -- confirm intent.
            model_output.append(model(torch.randn(dataset.shape), torch.tensor(t)))
            x_seq = p_sample_loop(model, dataset.shape)
        print('{}_th iter pass'.format(t))
        fig, axs = plt.subplots(1, 10, figsize=(28, 3))
        for i in range(1, 11):
            cur_x = x_seq[i * 20].detach()
            axs[i-1].scatter(cur_x[:, 0], cur_x[:, 1], s=10)
            # Raw string: '\m' is not a valid escape sequence in a normal string literal.
            axs[i-1].set_title(r'$q(\mathbf{x}_{'+str(i*20)+'})$')

Some other model

In [None]:
def DDIM(model_output,timesteps,num_trainstep,num_inferencestep,prediction_type = "epsilon", eta = 0.0,
         sample=None, alphas_cumprod=None):
    """One deterministic-part DDIM update (formula (12) of https://arxiv.org/pdf/2010.02502.pdf).

    Args:
        model_output: network output for ``sample`` at ``timesteps``.
        timesteps: current timestep index, or a 1-D tensor of indices (one per batch row).
        num_trainstep: number of timesteps the schedule was trained with.
        num_inferencestep: number of sampling steps; the stride between visited
            timesteps is ``num_trainstep // num_inferencestep``.
        prediction_type: meaning of ``model_output``: ``"epsilon"``, ``"sample"``
            or ``"v_prediction"``.
        eta: scales the returned standard deviation; 0 gives deterministic DDIM.
        sample: current noisy sample ``x_t`` (required).
        alphas_cumprod: 1-D tensor of cumulative alphas (required).

    Returns:
        ``(prev_sample, pred_sample_direction, std_dev_t)``. ``prev_sample`` does
        not include random noise; a caller using ``eta > 0`` adds
        ``std_dev_t * noise`` itself.

    Raises:
        ValueError: if ``sample`` / ``alphas_cumprod`` are missing or
            ``prediction_type`` is unknown.
    """
    if sample is None or alphas_cumprod is None:
        raise ValueError("DDIM requires `sample` (the current x_t) and `alphas_cumprod` to be provided")

    alphas_cumprod = torch.as_tensor(alphas_cumprod)
    t = torch.as_tensor(timesteps, dtype=torch.long, device=alphas_cumprod.device)
    prev_t = t - num_trainstep // num_inferencestep
    alpha_prod_t = alphas_cumprod[t]
    alpha_prod_t_prev = alphas_cumprod[torch.clip(prev_t, min=0)]
    # Out-of-place select: a masked in-place write could modify alphas_cumprod through a view.
    alpha_prod_t_prev = torch.where(prev_t < 0, torch.ones_like(alpha_prod_t_prev), alpha_prod_t_prev)

    if t.dim() > 0 and sample.dim() > 1:
        # Per-sample coefficients need trailing singleton axes to broadcast against `sample`.
        coeff_shape = (-1,) + (1,) * (sample.dim() - 1)
        alpha_prod_t = alpha_prod_t.reshape(coeff_shape)
        alpha_prod_t_prev = alpha_prod_t_prev.reshape(coeff_shape)

    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    if prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
            )

    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    std_dev_t = eta * variance ** (0.5)
    # "direction pointing to x_t" of formula (12)
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
    # x_{t-1} without the "random noise" term of formula (12)
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    return prev_sample, pred_sample_direction, std_dev_t

EDM_sampler

In [None]:
def edm_sampler( self, x, E, sample_algo = 'euler', randn_like=torch.randn_like, num_steps=400, sigma_min=0.002, sigma_max=10, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,):
    """Sample with an EDM-style sampler (Euler steps, optional 2nd-order correction).

    Args:
        self: object exposing ``denoise(x, E, sigma)``.
        x: initial noise tensor; it is scaled by the largest noise level before sampling.
        E: conditioning value passed through to ``self.denoise`` unchanged.
        sample_algo: ``'edm'`` enables the 2nd-order correction; any other
            value performs plain Euler steps.
        randn_like: noise generator used for the temporary noise increase.
        num_steps: number of noise levels visited (a final level of 0 is appended).
        sigma_min, sigma_max, rho: parameters of the noise-level discretisation.
        S_churn, S_min, S_max, S_noise: parameters of the stochastic noise increase.

    Returns:
        The float32 sample after the last step.
    """
    # Adjust noise levels based on what's supported by the network.

    #sigma_min = max(sigma_min, net.sigma_min)
    #sigma_max = min(sigma_max, net.sigma_max)

    n_samples = x.shape[0]

    # Noise-level discretisation, with a final level of exactly zero appended.
    ramp = torch.arange(num_steps, dtype=torch.float32, device=x.device)
    sigmas = (sigma_max ** (1 / rho) + ramp / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    sigmas = torch.cat([torch.as_tensor(sigmas), torch.zeros_like(sigmas[:1])])

    # Main sampling loop over consecutive (current, next) noise levels.
    x_next = x.to(torch.float32) * sigmas[0]
    for i in range(num_steps):
        t_cur, t_next = sigmas[i], sigmas[i + 1]
        x_cur = x_next

        # Temporarily raise the noise level from t_cur to t_hat.
        gamma = 0
        if S_min <= t_cur <= S_max:
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1)
        t_hat = torch.as_tensor(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step from t_hat to t_next.
        sigma_hat_batch = torch.full((n_samples,), t_hat, device=x.device)
        denoised = self.denoise(x_hat, E, sigma_hat_batch).to(torch.float32)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # 2nd-order correction; skipped on the final step, where t_next is zero.
        if (sample_algo == 'edm') and (i < num_steps - 1):
            sigma_next_batch = torch.full((n_samples,), t_next, device=x.device)
            denoised = self.denoise(x_next, E, sigma_next_batch).to(torch.float32)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next

DPM-Solver: does it need the data in a special form? And can it properly be used for a density estimation task?

In [None]:
#dpm-solver
class NoiseScheduleVP:
    """Variance-preserving noise schedule wrapper used by DPM-Solver.

    Two schedule kinds are supported:

    * ``"discrete"``: built from ``betas`` or ``alphas_cumprod``; values between
      the discrete steps are obtained through ``interpolate_fn`` (expected to
      be defined elsewhere).
    * ``"linear"``: continuous-time linear schedule defined by
      ``continuous_beta_0`` and ``continuous_beta_1``.

    Time labels ``t`` are continuous and live in ``[0, T]`` with ``T = 1``.
    """

    def __init__(
        self,
        schedule="discrete",
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.0,
        dtype=torch.float32,
    ):
        """Build the schedule.

        Args:
            schedule: ``"discrete"`` or ``"linear"``.
            betas: 1-D tensor of betas (discrete schedule only).
            alphas_cumprod: 1-D tensor of cumulative alphas; used when ``betas`` is None.
            continuous_beta_0, continuous_beta_1: endpoints of the linear schedule.
            dtype: dtype of the stored discrete arrays.

        Raises:
            ValueError: if ``schedule`` is not one of the supported kinds.
        """

        if schedule not in ["discrete", "linear"]:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule)
            )

        self.schedule = schedule
        if schedule == "discrete":
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.T = 1.0
            self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1)).to(dtype=dtype)
            self.total_N = self.log_alpha_array.shape[1]
            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
        else:
            self.T = 1.0
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1

    # Static method: the body never uses the instance, and without this decorator the
    # call `self.numerical_clip_alpha(log_alphas)` would bind `self` to `log_alphas`.
    @staticmethod
    def numerical_clip_alpha(log_alphas, clipped_lambda=-5.1):
        """
        For some beta schedules such as the cosine schedule, the log-SNR has numerical issues.
        We clip the log-SNR near t=T within -5.1 to ensure the stability.
        Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE.

        Returns ``log_alphas`` with the trailing entries whose half-logSNR falls
        below ``clipped_lambda`` removed.
        """
        log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
        lambs = log_alphas - log_sigmas
        # lambs decreases with the step index, so flip it to get an ascending sequence to search.
        idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda)
        if idx > 0:
            log_alphas = log_alphas[:-idx]
        return log_alphas

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == "discrete":
            return interpolate_fn(
                t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
            ).reshape((-1))
        elif self.schedule == "linear":
            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == "linear":
            tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == "discrete":
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
            # Both arrays are flipped so that the x-coordinates passed to the interpolation ascend.
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]),
            )
            return t.reshape((-1,))

In [None]:
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from diffusers.schedulers.scheduling_utils import SchedulerOutput

class ParaDPMSolverMultistepScheduler(DPMSolverMultistepScheduler):
    """DPM-Solver multistep scheduler extended with a batched update over many timesteps."""

    # NOTE(review): the method header was missing, which left `self` undefined and a `return`
    # at class level (a SyntaxError). The name and argument order below are assumed from the
    # names the body uses -- confirm against the intended caller.
    def batch_step_no_noise(self, model_output, timesteps, sample):
        """Apply one solver update to a batch whose rows sit at different timesteps.

        Args:
            model_output: network outputs, one row per entry of ``timesteps``.
            timesteps: 1-D tensor of timesteps, one per batch row.
            sample: current samples, one row per entry of ``timesteps``.

        Returns:
            Tensor holding the updated sample for every batch row.
        """
        self.lambda_t = self.lambda_t.to(model_output.device)
        self.alpha_t = self.alpha_t.to(model_output.device)
        self.sigma_t = self.sigma_t.to(model_output.device)

        t = timesteps
        # Locate every requested timestep in the scheduler's own timestep list.
        matches = (self.timesteps[None, :] == t[:, None])
        edgecases = ~matches.any(dim=1)
        step_index = torch.argmax(matches.int(), dim=1)
        step_index[edgecases] = len(self.timesteps) - 1 # if no match, then set to len(self.timesteps) - 1
        edgecases = (step_index == len(self.timesteps) - 1)

        # The previous timestep is the next list entry; rows already at the end go to 0.
        prev_t = self.timesteps[ torch.clip(step_index+1, max=len(self.timesteps) - 1) ]
        prev_t[edgecases] = 0

        # Add trailing singleton axes so the timesteps broadcast against model_output.
        t = t.view(-1, *([1]*(model_output.ndim - 1)))
        prev_t = prev_t.view(-1, *([1]*(model_output.ndim - 1)))

        model_output = self.convert_model_output(model_output, t, sample)
        model_output = model_output.clamp(-1, 1) # important

        if self.config.solver_order == 1 or len(t) == 1:
            prev_sample = self.dpm_solver_first_order_update(model_output, t, prev_t, sample)
        elif self.config.solver_order == 2 or len(t) == 2:
            # first element in batch must do first_order_update
            prev_sample1 = self.dpm_solver_first_order_update(model_output[:1], t[:1], prev_t[:1], sample[:1])

            model_outputs_list = [model_output[:-1], model_output[1:]]
            timestep_list = [t[:-1], t[1:]]
            prev_sample2 = self.multistep_dpm_solver_second_order_update(
                model_outputs_list, timestep_list, prev_t[1:], sample[1:]
            )

            prev_sample = torch.cat([prev_sample1, prev_sample2], dim=0)
        else:
            # first element in batch must do first_order_update
            prev_sample1 = self.dpm_solver_first_order_update(model_output[:1], t[:1], prev_t[:1], sample[:1])

            # second element in batch must do second_order update
            model_outputs_list = [model_output[:1], model_output[1:2]]
            timestep_list = [t[:1], t[1:2]]
            prev_sample2 = self.multistep_dpm_solver_second_order_update(
                model_outputs_list, timestep_list, prev_t[1:2], sample[1:2]
            )

            model_outputs_list = [model_output[:-2], model_output[1:-1], model_output[2:]]
            timestep_list = [t[:-2], t[1:-1], t[2:]]
            prev_sample3 = self.multistep_dpm_solver_third_order_update(
                model_outputs_list, timestep_list, prev_t[2:], sample[2:]
            )

            prev_sample = torch.cat([prev_sample1, prev_sample2, prev_sample3], dim=0)

        # doing this otherwise set_timesteps throws an error
        # if worried about efficiency, can override the set_timesteps function
        self.lambda_t = self.lambda_t.to('cpu')

        return prev_sample