# DDPM forward and reverse process implementation
Paper: https://arxiv.org/abs/2006.11239

## Forward Process
Fowrad process adss noise to the data


In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import time
from data.dataset import BEVFeaturesDataset, PaddDataset
from torch.utils.data import DataLoader, Dataset
import wandb
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [4]:
!CUDA_LAUNCH_BLOCKING=1

In [5]:
import math
import torch
from torch import nn
from inspect import isfunction

class UNet(nn.Module):
    def __init__(
        self,
        in_channel=6,
        out_channel=3,
        inner_channel=32,
        norm_groups=32,
        channel_mults=(1, 2, 4, 8, 8),
        attn_res=(8),
        res_blocks=3,
        dropout=0,
        with_noise_level_emb=True,
        image_size=128,
        eps=1e-5
    ):
        super().__init__()

        if with_noise_level_emb:
            noise_level_channel = inner_channel
            self.noise_level_mlp = nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel)
            )
        else:
            noise_level_channel = None
            self.noise_level_mlp = None
        
        self.image_size = image_size
        self.in_channel = in_channel
        self.out_channel = out_channel

        num_mults = len(channel_mults)
        pre_channel = inner_channel
        feat_channels = [pre_channel]
        now_res = image_size
        downs = [nn.Conv2d(in_channel, inner_channel,
                           kernel_size=3, padding=1)]
        for ind in range(num_mults):
            is_last = (ind == num_mults - 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks):
                downs.append(ResnetBlocWithAttn(
                    pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, eps=eps))
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            if not is_last:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res//2
        self.downs = nn.ModuleList(downs)

        self.mid = nn.ModuleList([
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                               dropout=dropout, with_attn=True, eps=eps),
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                               dropout=dropout, with_attn=False, eps=eps)
        ])

        ups = []
        for ind in reversed(range(num_mults)):
            is_last = (ind < 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks+1):
                ups.append(ResnetBlocWithAttn(
                    pre_channel+feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                        dropout=dropout, with_attn=use_attn, eps=eps))
                pre_channel = channel_mult
            if not is_last:
                ups.append(Upsample(pre_channel))
                now_res = now_res*2

        self.ups = nn.ModuleList(ups)

        self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups, eps=eps)

    def forward(self, x, time):
        t = self.noise_level_mlp(time) if exists(
            self.noise_level_mlp) else None

        feats = []
        for layer in self.downs:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t)
            else:
                x = layer(x)
            feats.append(x)

        for layer in self.mid:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t)
            else:
                x = layer(x)

        for layer in self.ups:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(torch.cat((x, feats.pop()), dim=1), t)
            else:
                x = layer(x)

        return self.final_conv(x)


# PositionalEncoding Source： https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py
class PositionalEncoding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        count = self.dim // 2
        step = torch.arange(count, dtype=noise_level.dtype, device=noise_level.device) / count
        encoding = noise_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
        encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
        return encoding


class FeatureWiseAffine(nn.Module):
    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super(FeatureWiseAffine, self).__init__()
        self.use_affine_level = use_affine_level
        self.noise_func = nn.Sequential(
            nn.Linear(in_channels, out_channels*(1+self.use_affine_level))
        )

    def forward(self, x, noise_embed):
        batch = x.shape[0]
        if self.use_affine_level:
            gamma, beta = self.noise_func(noise_embed).view(batch, -1, 1, 1).chunk(2, dim=1)
            x = (1 + gamma) * x + beta
        else:
            x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1)
        return x


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class Upsample(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, 3, padding=1)

    def forward(self, x):
        return self.conv(self.up(x))


class Downsample(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


# building block modules


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=32, dropout=0, eps=1e-5):
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(groups, dim, eps=eps),
            Swish(),
            nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
            nn.Conv2d(dim, dim_out, 3, padding=1)
        )

    def forward(self, x):
        return self.block(x)


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32, eps=1e-5):
        super().__init__()
        self.noise_func = FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level)

        self.block1 = Block(dim, dim_out, groups=norm_groups, eps=eps)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout, eps=eps)
        self.res_conv = nn.Conv2d(
            dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb):
        b, c, h, w = x.shape
        h = self.block1(x)
        h = self.noise_func(h, time_emb)
        h = self.block2(h)
        return h + self.res_conv(x)


class SelfAttention(nn.Module):
    def __init__(self, in_channel, n_head=1, norm_groups=32, eps=1e-5):
        super().__init__()

        self.n_head = n_head

        self.norm = nn.GroupNorm(norm_groups, in_channel, eps=eps)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input):
        batch, channel, height, width = input.shape
        n_head = self.n_head
        head_dim = channel // n_head

        norm = self.norm(input)
        qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
        query, key, value = qkv.chunk(3, dim=2)  # bhdyx

        attn = torch.einsum("bnchw, bncyx -> bnhwyx", query, key).contiguous() / math.sqrt(channel)
        attn = attn.view(batch, n_head, height, width, -1)
        attn = torch.softmax(attn, -1)
        attn = attn.view(batch, n_head, height, width, height, width)

        out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
        out = self.out(out.view(batch, channel, height, width))

        return out + input


class ResnetBlocWithAttn(nn.Module):
    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False, eps=1e-5):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(
            dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout, eps=eps)
        if with_attn:
            self.attn = SelfAttention(dim_out, norm_groups=norm_groups, eps=eps)

    def forward(self, x, time_emb):
        x = self.res_block(x, time_emb)
        if(self.with_attn):
            x = self.attn(x)
        return x


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

In [6]:
from numpy import mean, var
from functools import partial


def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
    betas = linear_end * np.ones(n_timestep, dtype=np.float64)
    warmup_time = int(n_timestep * warmup_frac)
    betas[:warmup_time] = np.linspace(
        linear_start, linear_end, warmup_time, dtype=np.float64)
    return betas

def make_beta_schedule(schedule, n_timestep, linear_start=1e-6, linear_end=1e-2, cosine_s=8e-3):
    """
    Create a beta schedule that is a function of the number of diffusion steps.
    Return:
        betas: a numpy array of shape (n_timestep,) that defines the beta schedule
    """
    if schedule == 'quad':
        betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
                            n_timestep) ** 2
    elif schedule == 'linear':
        betas = np.linspace(linear_start, linear_end,
                            n_timestep)
    elif schedule == 'warmup10':
        betas = _warmup_beta(linear_start, linear_end,
                             n_timestep, 0.1)
    elif schedule == 'warmup50':
        betas = _warmup_beta(linear_start, linear_end,
                             n_timestep, 0.5)
    elif schedule == 'const':
        betas = linear_end * np.ones(n_timestep, dtype=np.float64)
    elif schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1. / np.linspace(n_timestep,
                                 1, n_timestep, dtype=np.float64)
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) /
            n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * math.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = betas.clamp(max=0.999)
    else:
        raise NotImplementedError(schedule)
    return torch.from_numpy(betas) if betas.type == np.ndarray else betas


class DenoiseDiffusion(nn.Module):
    def __init__(self, eps_model, beta_schedule, loss_fn=nn.L1Loss()):
        super().__init__()
        # Parameters for training
        self.loss_fn = loss_fn
        self.eps_model = eps_model
        self.beta_schedule = beta_schedule


        # Parameters for diffusion process         
    def set_new_noise_schedule(self, device=torch.device('cuda'), phase='train'):
        self.n_steps = self.beta_schedule[phase]['n_timestep']
        to_torch = partial(torch.as_tensor, dtype=torch.float32, device=device)

        betas = make_beta_schedule(**self.beta_schedule[phase])
        # self.betas = beta.type(dtype=torch.float32).to(self.eps_model.device)
        alphas = 1. - betas
        gammas = torch.cumprod(alphas, dim=0)
        sigmas = torch.sqrt(1.0 - torch.pow(alphas, 2))
        lambdas = torch.log(alphas / sigmas)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas", to_torch(alphas))
        self.register_buffer("gammas", to_torch(gammas))
        self.register_buffer("sigmas", to_torch(sigmas))
        self.register_buffer("lambdas", to_torch(lambdas))


    # def to(self, device):
        # self.eps_model = self.eps_model.to(device)

    def gather(self, tensor, t):
        """
        Gather the values of x at the time steps t.
        Makes it compatible with the shape of x0, which is (B, C, H, W).
        Args:
            tensor: a tensor of shape (n_steps,)
            t: a tensor of shape (B,)
        Return:
            a tensor of shape (B, 1, 1, 1) that contains the values of x at the time steps t
        """
        t = tensor.gather(-1, t)
        return t.reshape(-1, 1, 1, 1)
    
        # We need a function that samples the batch 
    def q_sample(self, y0, sample_gammas, noise=None):
        """
        Sample from q(yt|y0), reading same as sample xt at step t given x0.
        Other implementations also use function q_xt_x0 first but we can directly implement it here.
        Args:
            y0: the original data, shape (B, C, H, W)
            sample_gammas: the gamma values for sampling, shape (B,)
            noise: the noise, shape (B, C, H, W)
        Return:
            yt: the noisy data at time step t, shape (B, C, H, W)
        """
        eps = torch.randn_like(y0, device=y0.device) if noise is None else noise
        
        return (
            torch.sqrt(sample_gammas) * y0 + torch.sqrt(1 - sample_gammas) * eps
        )
    

    def forward(self, y0, y_cond=None):
        """
        Algorithm 1 in Denoising Diffusion Probalisitic Models

        Args:
            y0: the original data, shape (B, C, H, W)
        """
        b, *_ = y0.shape

        t = torch.randint(1, self.n_steps, (b,), device=y0.device, dtype=torch.long)
        # Select a random gamma for each sample in the batch, which is between gamma_t and gamma_t-1 of generated timesteps t. This is to make the training more stable and avoid overfitting to specific timesteps.
        gamma_t1 = self.gather(self.gammas, t - 1)
        gamma_t2 = self.gather(self.gammas, t)
        sample_gammas = (gamma_t2 - gamma_t1) * torch.rand((b, 1, 1, 1), device=y0.device) + gamma_t1
        sample_gammas = sample_gammas

        # Create the noise to compare it to the predicted noise, which is used for training the model. This is the noise added to the original data to get the noisy data at time step t.
        noise = torch.randn_like(y0, device=y0.device)
        y_noisy = self.q_sample(y0, sample_gammas, noise=noise)

        noise_hat = self.eps_model(torch.cat([y_noisy, y_cond], dim=1) if y_cond is not None else y_noisy, sample_gammas)

        loss = self.loss_fn(noise_hat, noise)
        return loss
    

    # Samplers

    @torch.no_grad()
    def ddpm_sampler(self, n_samples, y_cond=None, sample_inter=10, clip_denoised=True):
        """
        https://arxiv.org/pdf/2006.11239
        Implementation of algorithm 2. However, to keep sampling stable we calculate the start from noise, clamp it and use the posterior of equation 7 to calculate y_t-1.
        We use equation 15 to calculate y_0 (start from noise), then clamp it. Then we use equation 7 to calculate y_t-1 = mean + sigma * z
        """
        y = torch.randn(n_samples, self.eps_model.out_channel, self.eps_model.image_size, self.eps_model.image_size, device=y_cond.device)
        ret_arr = y.clone()
        for i in tqdm(reversed(range(self.n_steps)), desc='DDPM sampler', total=self.n_steps):
            z = torch.randn_like(y) if i > 1 else torch.zeros_like(y)   
            t_tensor = torch.full((n_samples,), i, device=y_cond.device, dtype=torch.long)

            gamma_t = self.gather(self.gammas, t_tensor)
            gamma_t_prev = self.gather(self.gammas, t_tensor - 1) if i > 0 else torch.ones_like(gamma_t)
            beta_t = self.gather(self.betas, t_tensor)
            alpha_t = self.gather(self.alphas, t_tensor)

            y_0_tilde = (y - torch.sqrt(1-gamma_t)*self.eps_model(torch.cat([y, y_cond], dim=1) if y_cond is not None else y, gamma_t)) / torch.sqrt(gamma_t) # predict start

            if clip_denoised:
                y_0_tilde = torch.clamp(y_0_tilde, -1., 1.)

            # eta = (y - torch.sqrt(alpha_t)*y_0_tilde) / torch.sqrt(1-alpha_t)

            # Calculate posterior mean and variance
            mean = (torch.sqrt(gamma_t_prev) * beta_t * y_0_tilde) / (1 - gamma_t) + (torch.sqrt(alpha_t) * (1 - gamma_t_prev) * y) / (1 - gamma_t)

            sigma = (1 - gamma_t_prev) * beta_t / (1 - gamma_t)

            y = mean + sigma * z

            if i & sample_inter ==0:
                ret_arr = torch.cat((ret_arr, y), dim=0)

        return y, ret_arr
    

    @torch.no_grad()
    def ddim_sampler(self, y_cond=None, noise=None, sample_inter=10, steps=50, clip_denoised=True, eta=0.0):
        """
        DDIM sampler from https://arxiv.org/abs/2010.02502
        With eta=0, it becomes a deterministic sampler, which is the one we will use in this implementation. With eta>0, it becomes a stochastic sampler, which is similar to the DDPM sampler but with different noise scale. 
        """
        y = torch.randn_like(y_cond, device=y_cond.device) if noise is None else noise
        ret_arr = y.clone()
        step_size = self.n_steps // steps

        for i in tqdm(range(steps), desc='DDIM sampling loop timestep', total=steps):
            t = self.n_steps - i * step_size
            t_tensor = torch.full((y_cond.shape[0],), t, dtype=torch.long, device=y_cond.device)

            gamma = self.gather(self.gammas, t_tensor - 1)
            
            # Make sure that when t_tensor - step_size - 1 is negative, we use gamma_prev = 1, which means that we are at the final step and we should not add any noise.
            gamma_prev = self.gather(self.gammas, torch.clamp(t_tensor - step_size - 1, min=0)) if (t_tensor - step_size - 1 >= 0).any() else torch.ones_like(gamma)
            noise_pred = self.eps_model(torch.cat([y, y_cond], dim=1) if y_cond is not None else y, gamma)

            y0_pred = (y - torch.sqrt(1 - gamma) * noise_pred) / torch.sqrt(gamma)
            
            # Clamp prediction to stablize sampling
            if clip_denoised:
                y0_pred = torch.clamp(y0_pred, -1., 1.)

            sigma_t = eta * torch.sqrt((1 - gamma_prev) / (1 - gamma)) * torch.sqrt(1-gamma / gamma_prev)

            dir_yt = torch.sqrt(1 - gamma_prev - torch.pow(sigma_t, 2)) * noise_pred
            
            
            y = torch.sqrt(gamma_prev) * y0_pred + dir_yt + sigma_t * torch.randn_like(y)

        return y, ret_arr
    

    @torch.no_grad()
    def dpm_solver_multi_step_sampler(self, n_samples, y_cond=None, sample_inter=10, steps=10, clip_denoised=True):
        """
        Implement multistep from https://arxiv.org/pdf/2211.01095
        """
        step_size = self.n_steps // steps


        yT = torch.randn(n_samples, self.eps_model.out_channel, self.eps_model.image_size, self.eps_model.image_size, device=y_cond.device)
        ret_arr = yT.clone()

        ytilde = yT


        t_0 = torch.full((n_samples,), self.n_steps - 1, device=y_cond.device, dtype=torch.long)
        t_1 = torch.full((n_samples,), self.n_steps - step_size - 1, device=y_cond.device, dtype=torch.long)
        t_2 = torch.full((n_samples,), self.n_steps - 2 * step_size - 1, device=y_cond.device, dtype=torch.long)
        # print(self.lambdas.shape)
        # print(t_1)
        # print(t_0)
        h_i_prev = self.gather(self.lambdas, t_1) - self.gather(self.lambdas, t_0)

        # Buffer P and Q for multi_step sampling. P = -2 and Q = -1 in timestepe space, which means that they are the data prediction at t_i-2 and t_i-1 respectively. We will update them in each step and use them to calculate the data prediction at t_i.
        P = self.data_prediction(ytilde, y_cond=y_cond, t=t_0) # y_theta_0
        if clip_denoised:
            P = torch.clamp(P, -1., 1.)
        y_tilde = (self.gather(self.sigmas, t_1)/self.gather(self.sigmas, t_0)) * ytilde - self.gather(self.alphas, t_1) * (torch.exp(-h_i_prev) - 1) * P
        Q = self.data_prediction(y_tilde, y_cond=y_cond, t=t_1) # y_theta_2
        if clip_denoised:
            Q = torch.clamp(Q, -1, 1.)
        

        for i in tqdm((range(2, steps)), desc='DPM-Solver++(2M) sampler', initial=2, total=steps):
            t_cur = self.n_steps - i * step_size - 1
            t_prev = self.n_steps - (i - 1) * step_size - 1

            t_prev_tensor = torch.full((n_samples,), t_prev, device=y_cond.device, dtype=torch.long)
            t_cur_tensor = torch.full((n_samples,), t_cur, device=y_cond.device, dtype=torch.long)

            h_i_cur = self.gather(self.lambdas, t_cur_tensor) - self.gather(self.lambdas, t_prev_tensor)
            r_i = h_i_prev / h_i_cur

            D_i = (1 + 1 / (2 * r_i)) * Q - 1 / (2 * r_i) * P

            y_tilde = (self.gather(self.sigmas, t_cur_tensor) / self.gather(self.sigmas, t_prev_tensor)) * y_tilde - self.gather(self.alphas, t_cur_tensor) * (torch.exp(-h_i_cur) - 1) * D_i
            
            h_i_prev = h_i_cur
            P = Q.clone()
            if i < self.n_steps:
                Q = self.data_prediction(y_tilde, y_cond=y_cond, t=t_cur_tensor)
                if clip_denoised:
                    Q = torch.clamp(Q, -1., 1.)

            if i & sample_inter == 0:
                ret_arr = torch.cat((ret_arr, y_tilde), dim=0)

        return y_tilde, ret_arr
    
    def data_prediction(self, yt, y_cond=None, t=None):
        gamma = self.gather(self.gammas, t).to(yt.device)
        noise_pred = self.eps_model(torch.cat([yt, y_cond], dim=1) if y_cond is not None else yt, gamma)
        y0_hat = (yt - torch.sqrt(1 - gamma) * noise_pred) / torch.sqrt(gamma)

        return y0_hat

In [7]:
beta_schedule = dict(
    train=dict(
        schedule='cosine',
        n_timestep=2000,
        cosine_s=8e-3,
    ),
    test=dict(
        schedule='linear',
        n_timestep=1000,
        linear_start=1e-5,
        linear_end=1e-1,
    )
)

model_config = dict(
    in_channel=512,
    out_channel=256,
    inner_channel=128,
    norm_groups=32,
    channel_mults=(1, 2, 4, 8),
    attn_res=(25,),
    res_blocks=2,
    dropout=0,
    with_noise_level_emb=True,
    image_size=200,
    eps=1e-5
)

hyperparameters = dict(
    model_config=model_config,
    beta_schedule=beta_schedule,
    batch_size=6,
)

Unet = UNet(**model_config)
diffusion = DenoiseDiffusion(Unet, beta_schedule)
diffusion.set_new_noise_schedule(phase='train')

In [10]:
diffusion.to('cuda')
diffusion.ddim_sampler(y_cond=torch.randn(6, 256, 200, 200, device='cuda'), steps=50)

DDIM sampling loop timestep:   0%|          | 0/50 [00:00<?, ?it/s]

DDIM sampling loop timestep: 100%|██████████| 50/50 [00:21<00:00,  2.30it/s]


(tensor([[[[ 0.9052,  0.6056,  0.7539,  ...,  0.7089,  0.9348, -0.7608],
           [-0.4291,  0.1744,  0.9985,  ..., -0.5786,  0.6089,  0.0615],
           [-0.9999,  0.4726, -0.9195,  ..., -0.4503, -0.9907,  0.6652],
           ...,
           [ 0.2648, -0.5884, -0.0404,  ...,  0.4837, -0.5868,  0.0095],
           [ 1.0000,  0.6826, -0.7567,  ..., -0.8508, -0.3490,  0.8485],
           [-0.7334, -0.8856,  1.0000,  ..., -0.6242,  0.2219,  1.0000]],
 
          [[ 0.9595, -0.9917, -0.8975,  ...,  0.4510, -0.6107, -0.7388],
           [-0.5226,  0.1377,  0.1038,  ...,  0.6675, -0.5343, -0.9889],
           [-0.5597,  1.0000, -0.1788,  ..., -0.2632,  0.4166, -1.0000],
           ...,
           [ 0.9991,  0.0395, -0.5515,  ...,  0.8632,  0.7626,  0.5206],
           [ 0.8184, -0.4611, -0.9966,  ...,  0.9069,  0.4638,  0.6909],
           [-0.8151,  0.6977,  0.9526,  ...,  0.5655, -0.9994,  0.9947]],
 
          [[ 0.9260, -0.3743, -0.3316,  ..., -0.9889,  0.9998,  0.6551],
           [ 

In [6]:
from torchinfo import summary

summary(Unet, input_size=[(hyperparameters['batch_size'], model_config['in_channel'], model_config['image_size'], model_config['image_size']), (hyperparameters['batch_size'], 1)])

Layer (type:depth-idx)                             Output Shape              Param #
UNet                                               [6, 256, 200, 200]        --
├─Sequential: 1-1                                  [6, 1, 128]               --
│    └─PositionalEncoding: 2-1                     [6, 1, 128]               --
│    └─Linear: 2-2                                 [6, 1, 512]               66,048
│    └─Swish: 2-3                                  [6, 1, 512]               --
│    └─Linear: 2-4                                 [6, 1, 128]               65,664
├─ModuleList: 1-2                                  --                        --
│    └─Conv2d: 2-5                                 [6, 128, 200, 200]        589,952
│    └─ResnetBlocWithAttn: 2-6                     [6, 128, 200, 200]        --
│    │    └─ResnetBlock: 3-1                       [6, 128, 200, 200]        312,192
│    └─ResnetBlocWithAttn: 2-7                     [6, 128, 200, 200]        --
│    │    └─Resne

In [10]:
size = 64
square_size = 16
tensor = torch.zeros(1, 1, size, size)
start = (size - square_size) // 2
end = start + square_size
tensor[:, :, start:end, start:end] = 1.0

samples = tensor.repeat(100, 1, 1, 1)

In [None]:
diffusion.set_new_noise_schedule(phase='test')
sampled = diffusion.dpm_solver_multi_step_sampler(1, steps=50, clip_denoised=True)
y, ret_arr = sampled
print(y)
print(f"Max: {torch.amax(y)}, Min: {torch.amin(y)}")
plt.imshow(y.cpu().numpy()[0, 0], cmap='viridis')
plt.colorbar()

In [7]:
from torch.utils.data import DataLoader

def load_data():
    # Load the saved data
    dataset = BEVFeaturesDataset(root_dir='/home/mingdayang/FeatureBridgeMapping/data/bev_features', transform=None)

    return dataset

def create_splits(dataset, train_split=0.8):
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [int(train_split * len(dataset)), len(dataset) - int(train_split * len(dataset))])

    return train_dataset, test_dataset

def make_loader(batch_size, dataset):

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        # Let's check out what we've created

    return dataloader

dataset = load_data()
train_dataset, test_dataset = create_splits(dataset)
train_loader = make_loader(hyperparameters['batch_size'], train_dataset)
test_loader = make_loader(hyperparameters['batch_size'], test_dataset)

In [8]:
# simple training loop using the existing `tensor` as toy data
device = 'cuda'
diffusion.to(device)
diffusion.eps_model.train()

optimizer = torch.optim.Adam(diffusion.eps_model.parameters(), lr=1e-4)

epochs = 300
val_interval = 10

with wandb.init(project="diffusion_test", config=hyperparameters):
    wandb.watch(diffusion.eps_model, log="all", log_freq=10)
    for epoch in range(epochs):
        diffusion.set_new_noise_schedule(phase='train')
        train_loss = 0
        for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
            X, y = batch['img_bev_embed'], batch['pts_bev_embed']
            X = X.to(device)
            y = y.to(device)
            loss = diffusion.forward(y0=y, y_cond=X)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            print(f"Batch Loss: {loss.item()}")
        
        avg_train_loss = train_loss / len(train_loader)
        wandb.log({'train_loss': avg_train_loss}, step=epoch)

    wandb.finish()


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmln9d4[0m ([33mmln9d4-tu-delft[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/300:   9%|▉         | 1/11 [00:02<00:21,  2.17s/it]

Batch Loss: 0.8441227078437805


Epoch 1/300:  18%|█▊        | 2/11 [00:04<00:17,  1.97s/it]

Batch Loss: 0.8360666632652283


Epoch 1/300:  27%|██▋       | 3/11 [00:05<00:15,  1.91s/it]

Batch Loss: 0.8235496282577515


Epoch 1/300:  36%|███▋      | 4/11 [00:07<00:13,  1.87s/it]

Batch Loss: 0.8175937533378601


Epoch 1/300:  45%|████▌     | 5/11 [00:09<00:11,  1.86s/it]

Batch Loss: 0.8118245601654053


Epoch 1/300:  55%|█████▍    | 6/11 [00:11<00:09,  1.90s/it]

Batch Loss: 0.8065184950828552


Epoch 1/300:  64%|██████▎   | 7/11 [00:13<00:07,  1.88s/it]

Batch Loss: 0.8035463690757751


Epoch 1/300:  73%|███████▎  | 8/11 [00:15<00:05,  1.89s/it]

Batch Loss: 0.8023022413253784


Epoch 1/300:  82%|████████▏ | 9/11 [00:17<00:03,  1.88s/it]

Batch Loss: 0.8011170625686646


Epoch 1/300:  91%|█████████ | 10/11 [00:20<00:02,  2.48s/it]

Batch Loss: 0.8000731468200684


Epoch 1/300: 100%|██████████| 11/11 [00:22<00:00,  2.03s/it]


Batch Loss: 0.7993554472923279


Epoch 2/300:   9%|▉         | 1/11 [00:01<00:18,  1.87s/it]

Batch Loss: 0.7987856268882751


Epoch 2/300:  18%|█▊        | 2/11 [00:03<00:16,  1.87s/it]

Batch Loss: 0.7984997630119324


Epoch 2/300:  27%|██▋       | 3/11 [00:05<00:15,  1.89s/it]

Batch Loss: 0.7984712719917297


Epoch 2/300:  36%|███▋      | 4/11 [00:07<00:13,  1.88s/it]

Batch Loss: 0.7982845306396484


Epoch 2/300:  45%|████▌     | 5/11 [00:09<00:11,  1.87s/it]

Batch Loss: 0.7976917028427124


Epoch 2/300:  55%|█████▍    | 6/11 [00:11<00:09,  1.88s/it]

Batch Loss: 0.797479510307312


Epoch 2/300:  64%|██████▎   | 7/11 [00:13<00:07,  1.86s/it]

Batch Loss: 0.7973124384880066


Epoch 2/300:  73%|███████▎  | 8/11 [00:14<00:05,  1.86s/it]

Batch Loss: 0.7970888614654541


Epoch 2/300:  82%|████████▏ | 9/11 [00:18<00:04,  2.48s/it]

Batch Loss: 0.7966421842575073


Epoch 2/300:  91%|█████████ | 10/11 [00:20<00:02,  2.30s/it]

Batch Loss: 0.7970740795135498


Epoch 2/300: 100%|██████████| 11/11 [00:22<00:00,  2.00s/it]


Batch Loss: 0.7967941761016846


Epoch 3/300:   9%|▉         | 1/11 [00:01<00:18,  1.86s/it]

Batch Loss: 0.796898365020752


Epoch 3/300:  18%|█▊        | 2/11 [00:03<00:17,  1.90s/it]

Batch Loss: 0.7962692975997925


Epoch 3/300:  27%|██▋       | 3/11 [00:05<00:15,  1.88s/it]

Batch Loss: 0.7959099411964417


Epoch 3/300:  36%|███▋      | 4/11 [00:07<00:13,  1.89s/it]

Batch Loss: 0.7960195541381836


Epoch 3/300:  45%|████▌     | 5/11 [00:09<00:11,  1.90s/it]

Batch Loss: 0.7965874075889587


Epoch 3/300:  55%|█████▍    | 6/11 [00:11<00:09,  1.91s/it]

Batch Loss: 0.796532928943634


Epoch 3/300:  64%|██████▎   | 7/11 [00:13<00:07,  1.87s/it]

Batch Loss: 0.7959973812103271


Epoch 3/300:  73%|███████▎  | 8/11 [00:17<00:07,  2.53s/it]

Batch Loss: 0.7970651388168335


Epoch 3/300:  82%|████████▏ | 9/11 [00:19<00:04,  2.33s/it]

Batch Loss: 0.7953234910964966


Epoch 3/300:  91%|█████████ | 10/11 [00:20<00:02,  2.20s/it]

Batch Loss: 0.7957515716552734


Epoch 3/300: 100%|██████████| 11/11 [00:22<00:00,  2.03s/it]


Batch Loss: 0.7950561046600342


Epoch 4/300:   9%|▉         | 1/11 [00:01<00:19,  1.93s/it]

Batch Loss: 0.7952281832695007


Epoch 4/300:  18%|█▊        | 2/11 [00:03<00:17,  1.92s/it]

Batch Loss: 0.7944297790527344


Epoch 4/300:  27%|██▋       | 3/11 [00:05<00:15,  1.89s/it]

Batch Loss: 0.794302225112915


Epoch 4/300:  36%|███▋      | 4/11 [00:07<00:13,  1.92s/it]

Batch Loss: 0.7951439023017883


Epoch 4/300:  45%|████▌     | 5/11 [00:09<00:11,  1.93s/it]

Batch Loss: 0.7937213778495789


Epoch 4/300:  55%|█████▍    | 6/11 [00:11<00:09,  1.90s/it]

Batch Loss: 0.794418215751648


Epoch 4/300:  64%|██████▎   | 7/11 [00:15<00:10,  2.62s/it]

Batch Loss: 0.7933404445648193


Epoch 4/300:  73%|███████▎  | 8/11 [00:17<00:07,  2.38s/it]

Batch Loss: 0.7926390767097473


Epoch 4/300:  82%|████████▏ | 9/11 [00:19<00:04,  2.22s/it]

Batch Loss: 0.7943829298019409


Epoch 4/300:  91%|█████████ | 10/11 [00:21<00:02,  2.16s/it]

Batch Loss: 0.7937290668487549


Epoch 4/300: 100%|██████████| 11/11 [00:22<00:00,  2.06s/it]


Batch Loss: 0.7923153042793274


Epoch 5/300:   9%|▉         | 1/11 [00:01<00:18,  1.83s/it]

Batch Loss: 0.7924906611442566


Epoch 5/300:  18%|█▊        | 2/11 [00:03<00:16,  1.89s/it]

Batch Loss: 0.7912731766700745


Epoch 5/300:  27%|██▋       | 3/11 [00:05<00:15,  1.93s/it]

Batch Loss: 0.7920214533805847


Epoch 5/300:  36%|███▋      | 4/11 [00:07<00:13,  1.92s/it]

Batch Loss: 0.793274462223053


Epoch 5/300:  45%|████▌     | 5/11 [00:09<00:11,  1.92s/it]

Batch Loss: 0.7902165651321411


Epoch 5/300:  55%|█████▍    | 6/11 [00:13<00:13,  2.64s/it]

Batch Loss: 0.7904682159423828


Epoch 5/300:  64%|██████▎   | 7/11 [00:15<00:09,  2.41s/it]

Batch Loss: 0.7896623611450195


Epoch 5/300:  73%|███████▎  | 8/11 [00:17<00:06,  2.26s/it]

Batch Loss: 0.7902094125747681


Epoch 5/300:  82%|████████▏ | 9/11 [00:19<00:04,  2.17s/it]

Batch Loss: 0.7881224155426025


Epoch 5/300:  91%|█████████ | 10/11 [00:21<00:02,  2.08s/it]

Batch Loss: 0.7874919176101685


Epoch 5/300: 100%|██████████| 11/11 [00:22<00:00,  2.06s/it]


Batch Loss: 0.7898820638656616


Epoch 6/300:   9%|▉         | 1/11 [00:01<00:19,  1.96s/it]

Batch Loss: 0.7864083647727966


Epoch 6/300:  18%|█▊        | 2/11 [00:03<00:17,  1.97s/it]

Batch Loss: 0.7863025069236755


Epoch 6/300:  27%|██▋       | 3/11 [00:05<00:15,  1.93s/it]

Batch Loss: 0.7857045531272888


Epoch 6/300:  36%|███▋      | 4/11 [00:07<00:13,  1.90s/it]

Batch Loss: 0.786055326461792


Epoch 6/300:  45%|████▌     | 5/11 [00:11<00:16,  2.70s/it]

Batch Loss: 0.7828835844993591


Epoch 6/300:  55%|█████▍    | 6/11 [00:13<00:12,  2.42s/it]

Batch Loss: 0.783279299736023


Epoch 6/300:  64%|██████▎   | 7/11 [00:15<00:09,  2.26s/it]

Batch Loss: 0.7823695540428162


Epoch 6/300:  73%|███████▎  | 8/11 [00:17<00:06,  2.14s/it]

Batch Loss: 0.7796109318733215


Epoch 6/300:  82%|████████▏ | 9/11 [00:19<00:04,  2.05s/it]

Batch Loss: 0.7772848606109619


Epoch 6/300:  91%|█████████ | 10/11 [00:21<00:02,  2.00s/it]

Batch Loss: 0.7803148031234741


Epoch 6/300: 100%|██████████| 11/11 [00:22<00:00,  2.06s/it]


Batch Loss: 0.7775654196739197


Epoch 7/300:   9%|▉         | 1/11 [00:01<00:19,  1.91s/it]

Batch Loss: 0.7787734270095825


Epoch 7/300:  18%|█▊        | 2/11 [00:03<00:17,  1.91s/it]

Batch Loss: 0.7769889235496521


Epoch 7/300:  27%|██▋       | 3/11 [00:05<00:15,  1.98s/it]

Batch Loss: 0.7741247415542603


Epoch 7/300:  36%|███▋      | 4/11 [00:09<00:18,  2.65s/it]

Batch Loss: 0.7732566595077515


Epoch 7/300:  45%|████▌     | 5/11 [00:11<00:14,  2.42s/it]

Batch Loss: 0.778141975402832


Epoch 7/300:  55%|█████▍    | 6/11 [00:13<00:11,  2.27s/it]

Batch Loss: 0.7739723324775696


Epoch 7/300:  64%|██████▎   | 7/11 [00:15<00:08,  2.16s/it]

Batch Loss: 0.770585834980011


Epoch 7/300:  73%|███████▎  | 8/11 [00:17<00:06,  2.09s/it]

Batch Loss: 0.7681195139884949


Epoch 7/300:  82%|████████▏ | 9/11 [00:19<00:04,  2.05s/it]

Batch Loss: 0.7769440412521362


Epoch 7/300:  91%|█████████ | 10/11 [00:21<00:01,  1.99s/it]

Batch Loss: 0.765667736530304


Epoch 7/300: 100%|██████████| 11/11 [00:22<00:00,  2.05s/it]


Batch Loss: 0.7651867866516113


Epoch 8/300:   9%|▉         | 1/11 [00:01<00:18,  1.89s/it]

Batch Loss: 0.7700921893119812


Epoch 8/300:  18%|█▊        | 2/11 [00:03<00:17,  1.96s/it]

Batch Loss: 0.7636364698410034


Epoch 8/300:  27%|██▋       | 3/11 [00:08<00:23,  2.95s/it]

Batch Loss: 0.7721006870269775


Epoch 8/300:  36%|███▋      | 4/11 [00:10<00:18,  2.57s/it]

Batch Loss: 0.7659868597984314


Epoch 8/300:  45%|████▌     | 5/11 [00:11<00:14,  2.34s/it]

Batch Loss: 0.759876012802124


Epoch 8/300:  55%|█████▍    | 6/11 [00:13<00:10,  2.19s/it]

Batch Loss: 0.7772618532180786


Epoch 8/300:  64%|██████▎   | 7/11 [00:15<00:08,  2.12s/it]

Batch Loss: 0.773729681968689


Epoch 8/300:  73%|███████▎  | 8/11 [00:17<00:06,  2.08s/it]

Batch Loss: 0.7673633098602295


Epoch 8/300:  82%|████████▏ | 9/11 [00:19<00:04,  2.04s/it]

Batch Loss: 0.7566608190536499


Epoch 8/300:  91%|█████████ | 10/11 [00:21<00:02,  2.01s/it]

Batch Loss: 0.7590855360031128


Epoch 8/300: 100%|██████████| 11/11 [00:23<00:00,  2.10s/it]


Batch Loss: 0.7674962878227234


Epoch 9/300:   9%|▉         | 1/11 [00:01<00:18,  1.87s/it]

Batch Loss: 0.7566249370574951


Epoch 9/300:  18%|█▊        | 2/11 [00:06<00:29,  3.31s/it]

Batch Loss: 0.7626208662986755


Epoch 9/300:  27%|██▋       | 3/11 [00:08<00:21,  2.71s/it]

Batch Loss: 0.7545785307884216


Epoch 9/300:  36%|███▋      | 4/11 [00:10<00:16,  2.36s/it]

Batch Loss: 0.7577865719795227


Epoch 9/300:  45%|████▌     | 5/11 [00:11<00:13,  2.19s/it]

Batch Loss: 0.750921368598938


Epoch 9/300:  55%|█████▍    | 6/11 [00:13<00:10,  2.14s/it]

Batch Loss: 0.7588316202163696


Epoch 9/300:  64%|██████▎   | 7/11 [00:15<00:08,  2.04s/it]

Batch Loss: 0.7528528571128845


Epoch 9/300:  73%|███████▎  | 8/11 [00:17<00:05,  1.98s/it]

Batch Loss: 0.7548947930335999


Epoch 9/300:  82%|████████▏ | 9/11 [00:19<00:03,  1.99s/it]

Batch Loss: 0.7560904026031494


Epoch 9/300:  91%|█████████ | 10/11 [00:21<00:01,  1.96s/it]

Batch Loss: 0.753282368183136


Epoch 9/300: 100%|██████████| 11/11 [00:22<00:00,  2.08s/it]


Batch Loss: 0.7482677698135376


Epoch 10/300:   9%|▉         | 1/11 [00:03<00:37,  3.78s/it]

Batch Loss: 0.7634565234184265


Epoch 10/300:  18%|█▊        | 2/11 [00:05<00:24,  2.67s/it]

Batch Loss: 0.7477062344551086


Epoch 10/300:  27%|██▋       | 3/11 [00:07<00:18,  2.32s/it]

Batch Loss: 0.7437124848365784


Epoch 10/300:  36%|███▋      | 4/11 [00:09<00:15,  2.16s/it]

Batch Loss: 0.7516869902610779


Epoch 10/300:  45%|████▌     | 5/11 [00:11<00:12,  2.05s/it]

Batch Loss: 0.7434467673301697


Epoch 10/300:  55%|█████▍    | 6/11 [00:13<00:09,  1.99s/it]

Batch Loss: 0.7491068840026855


Epoch 10/300:  64%|██████▎   | 7/11 [00:15<00:07,  1.96s/it]

Batch Loss: 0.747697114944458


Epoch 10/300:  73%|███████▎  | 8/11 [00:17<00:05,  1.96s/it]

Batch Loss: 0.7410244345664978


Epoch 10/300:  82%|████████▏ | 9/11 [00:18<00:03,  1.93s/it]

Batch Loss: 0.7403913140296936


Epoch 10/300:  91%|█████████ | 10/11 [00:20<00:01,  1.91s/it]

Batch Loss: 0.748776912689209


Epoch 10/300: 100%|██████████| 11/11 [00:24<00:00,  2.20s/it]


Batch Loss: 0.740047037601471


Epoch 11/300:   9%|▉         | 1/11 [00:01<00:18,  1.83s/it]

Batch Loss: 0.7390502691268921


Epoch 11/300:  18%|█▊        | 2/11 [00:03<00:17,  1.90s/it]

Batch Loss: 0.7471584677696228


Epoch 11/300:  27%|██▋       | 3/11 [00:05<00:15,  1.95s/it]

Batch Loss: 0.7475237250328064


Epoch 11/300:  36%|███▋      | 4/11 [00:07<00:13,  1.92s/it]

Batch Loss: 0.7471215128898621


Epoch 11/300:  45%|████▌     | 5/11 [00:09<00:11,  1.94s/it]

Batch Loss: 0.7548812627792358


Epoch 11/300:  55%|█████▍    | 6/11 [00:11<00:09,  1.96s/it]

Batch Loss: 0.7377633452415466


Epoch 11/300:  64%|██████▎   | 7/11 [00:13<00:07,  1.95s/it]

Batch Loss: 0.7387232780456543


Epoch 11/300:  73%|███████▎  | 8/11 [00:15<00:05,  1.95s/it]

Batch Loss: 0.7351105809211731


Epoch 11/300:  82%|████████▏ | 9/11 [00:17<00:03,  1.95s/it]

Batch Loss: 0.737082839012146


Epoch 11/300:  91%|█████████ | 10/11 [00:21<00:02,  2.57s/it]

Batch Loss: 0.7460026144981384


Epoch 11/300: 100%|██████████| 11/11 [00:22<00:00,  2.07s/it]


Batch Loss: 0.7400280833244324


Epoch 12/300:   9%|▉         | 1/11 [00:01<00:19,  1.96s/it]

Batch Loss: 0.7423542141914368


Epoch 12/300:  18%|█▊        | 2/11 [00:03<00:17,  1.95s/it]

Batch Loss: 0.7399508953094482


Epoch 12/300:  27%|██▋       | 3/11 [00:05<00:15,  1.93s/it]

Batch Loss: 0.7402819991111755


Epoch 12/300:  36%|███▋      | 4/11 [00:07<00:13,  1.93s/it]

Batch Loss: 0.7449744939804077


Epoch 12/300:  45%|████▌     | 5/11 [00:09<00:11,  1.90s/it]

Batch Loss: 0.7430785298347473


Epoch 12/300:  55%|█████▍    | 6/11 [00:11<00:09,  1.87s/it]

Batch Loss: 0.7369444370269775


Epoch 12/300:  64%|██████▎   | 7/11 [00:13<00:07,  1.87s/it]

Batch Loss: 0.7312026023864746


Epoch 12/300:  73%|███████▎  | 8/11 [00:15<00:05,  1.94s/it]

Batch Loss: 0.7300217747688293


Epoch 12/300:  82%|████████▏ | 9/11 [00:19<00:05,  2.55s/it]

Batch Loss: 0.7430335879325867


Epoch 12/300:  91%|█████████ | 10/11 [00:21<00:02,  2.38s/it]

Batch Loss: 0.7342144250869751


Epoch 12/300: 100%|██████████| 11/11 [00:22<00:00,  2.06s/it]


Batch Loss: 0.7294101119041443


Epoch 13/300:   9%|▉         | 1/11 [00:01<00:19,  1.94s/it]

Batch Loss: 0.739787757396698


Epoch 13/300:  18%|█▊        | 2/11 [00:03<00:17,  1.93s/it]

Batch Loss: 0.7302761077880859


Epoch 13/300:  27%|██▋       | 3/11 [00:05<00:15,  1.99s/it]

Batch Loss: 0.7332197427749634


Epoch 13/300:  36%|███▋      | 4/11 [00:07<00:13,  1.95s/it]

Batch Loss: 0.7367726564407349


Epoch 13/300:  45%|████▌     | 5/11 [00:09<00:11,  1.93s/it]

Batch Loss: 0.7410221695899963


Epoch 13/300:  55%|█████▍    | 6/11 [00:11<00:09,  1.97s/it]

Batch Loss: 0.7465898394584656


Epoch 13/300:  64%|██████▎   | 7/11 [00:13<00:07,  1.94s/it]

Batch Loss: 0.7376255393028259


Epoch 13/300:  73%|███████▎  | 8/11 [00:17<00:07,  2.66s/it]

Batch Loss: 0.7358484268188477


Epoch 13/300:  82%|████████▏ | 9/11 [00:19<00:04,  2.42s/it]

Batch Loss: 0.747573971748352


Epoch 13/300:  91%|█████████ | 10/11 [00:21<00:02,  2.25s/it]

Batch Loss: 0.7358379364013672


Epoch 13/300: 100%|██████████| 11/11 [00:22<00:00,  2.09s/it]


Batch Loss: 0.7365199327468872


Epoch 14/300:   9%|▉         | 1/11 [00:02<00:20,  2.03s/it]

Batch Loss: 0.743800938129425


Epoch 14/300:  18%|█▊        | 2/11 [00:03<00:17,  1.92s/it]

Batch Loss: 0.7272002100944519


Epoch 14/300:  27%|██▋       | 3/11 [00:05<00:15,  1.89s/it]

Batch Loss: 0.740547776222229


Epoch 14/300:  36%|███▋      | 4/11 [00:07<00:13,  1.92s/it]

Batch Loss: 0.7315204739570618


Epoch 14/300:  45%|████▌     | 5/11 [00:09<00:11,  1.91s/it]

Batch Loss: 0.7431174516677856


Epoch 14/300:  55%|█████▍    | 6/11 [00:11<00:09,  1.91s/it]

Batch Loss: 0.7307971715927124


Epoch 14/300:  64%|██████▎   | 7/11 [00:15<00:10,  2.62s/it]

Batch Loss: 0.7271736860275269


Epoch 14/300:  73%|███████▎  | 8/11 [00:17<00:07,  2.39s/it]

Batch Loss: 0.7256747484207153


Epoch 14/300:  82%|████████▏ | 9/11 [00:19<00:04,  2.24s/it]

Batch Loss: 0.7245981693267822


Epoch 14/300:  91%|█████████ | 10/11 [00:21<00:02,  2.16s/it]

Batch Loss: 0.7358826994895935


Epoch 14/300: 100%|██████████| 11/11 [00:22<00:00,  2.07s/it]


Batch Loss: 0.7303423285484314


Epoch 15/300:   9%|▉         | 1/11 [00:01<00:18,  1.89s/it]

Batch Loss: 0.7311898469924927


Epoch 15/300:  18%|█▊        | 2/11 [00:03<00:17,  1.89s/it]

Batch Loss: 0.7385170459747314


Epoch 15/300:  27%|██▋       | 3/11 [00:05<00:15,  1.93s/it]

Batch Loss: 0.7260569334030151


Epoch 15/300:  36%|███▋      | 4/11 [00:07<00:13,  1.94s/it]

Batch Loss: 0.7328760027885437


Epoch 15/300:  45%|████▌     | 5/11 [00:09<00:11,  1.91s/it]

Batch Loss: 0.7343663573265076


Epoch 15/300:  55%|█████▍    | 6/11 [00:13<00:12,  2.58s/it]

Batch Loss: 0.7358404994010925


Epoch 15/300:  64%|██████▎   | 7/11 [00:15<00:09,  2.37s/it]

Batch Loss: 0.7243000268936157


Epoch 15/300:  73%|███████▎  | 8/11 [00:17<00:06,  2.22s/it]

Batch Loss: 0.7285226583480835


Epoch 15/300:  82%|████████▏ | 9/11 [00:19<00:04,  2.13s/it]

Batch Loss: 0.7240411639213562


Epoch 15/300:  91%|█████████ | 10/11 [00:21<00:02,  2.07s/it]

Batch Loss: 0.7221397161483765


Epoch 15/300: 100%|██████████| 11/11 [00:22<00:00,  2.05s/it]


Batch Loss: 0.7195029854774475


Epoch 16/300:   9%|▉         | 1/11 [00:01<00:19,  1.92s/it]

Batch Loss: 0.7199113965034485


Epoch 16/300:  18%|█▊        | 2/11 [00:03<00:16,  1.87s/it]

Batch Loss: 0.7204291820526123


Epoch 16/300:  27%|██▋       | 3/11 [00:05<00:14,  1.87s/it]

Batch Loss: 0.7197888493537903


Epoch 16/300:  36%|███▋      | 4/11 [00:07<00:13,  1.91s/it]

Batch Loss: 0.7212901711463928


Epoch 16/300:  45%|████▌     | 5/11 [00:11<00:15,  2.62s/it]

Batch Loss: 0.7377653121948242


Epoch 16/300:  55%|█████▍    | 6/11 [00:13<00:11,  2.38s/it]

Batch Loss: 0.7194247245788574


Epoch 16/300:  64%|██████▎   | 7/11 [00:15<00:08,  2.21s/it]

Batch Loss: 0.7220149040222168


Epoch 16/300:  73%|███████▎  | 8/11 [00:17<00:06,  2.10s/it]

Batch Loss: 0.7208381295204163


Epoch 16/300:  82%|████████▏ | 9/11 [00:19<00:04,  2.04s/it]

Batch Loss: 0.72283935546875


Epoch 16/300:  91%|█████████ | 10/11 [00:20<00:01,  1.99s/it]

Batch Loss: 0.7423010468482971


Epoch 16/300: 100%|██████████| 11/11 [00:22<00:00,  2.02s/it]


Batch Loss: 0.7186816930770874


Epoch 17/300:   9%|▉         | 1/11 [00:01<00:19,  1.96s/it]

Batch Loss: 0.7267043590545654


Epoch 17/300:  18%|█▊        | 2/11 [00:03<00:17,  1.96s/it]

Batch Loss: 0.7212738394737244


Epoch 17/300:  27%|██▋       | 3/11 [00:05<00:15,  1.96s/it]

Batch Loss: 0.7296324372291565


Epoch 17/300:  36%|███▋      | 4/11 [00:09<00:18,  2.61s/it]

Batch Loss: 0.7213701605796814


Epoch 17/300:  45%|████▌     | 5/11 [00:11<00:14,  2.36s/it]

Batch Loss: 0.7157129049301147


Epoch 17/300:  55%|█████▍    | 6/11 [00:13<00:10,  2.20s/it]

Batch Loss: 0.7158142924308777


Epoch 17/300:  64%|██████▎   | 7/11 [00:15<00:08,  2.09s/it]

Batch Loss: 0.7244541049003601


Epoch 17/300:  73%|███████▎  | 8/11 [00:17<00:06,  2.02s/it]

Batch Loss: 0.7200593948364258


Epoch 17/300:  82%|████████▏ | 9/11 [00:18<00:03,  2.00s/it]

Batch Loss: 0.7150096893310547


Epoch 17/300:  91%|█████████ | 10/11 [00:20<00:01,  1.97s/it]

Batch Loss: 0.7274829745292664


Epoch 17/300: 100%|██████████| 11/11 [00:22<00:00,  2.02s/it]


Batch Loss: 0.7328810095787048


Epoch 18/300:   9%|▉         | 1/11 [00:01<00:19,  1.94s/it]

Batch Loss: 0.7266672253608704


Epoch 18/300:  18%|█▊        | 2/11 [00:03<00:17,  1.95s/it]

Batch Loss: 0.7179184556007385


Epoch 18/300:  27%|██▋       | 3/11 [00:08<00:23,  2.94s/it]

Batch Loss: 0.7231310606002808


Epoch 18/300:  36%|███▋      | 4/11 [00:09<00:17,  2.51s/it]

Batch Loss: 0.7144427299499512


Epoch 18/300:  45%|████▌     | 5/11 [00:11<00:13,  2.31s/it]

Batch Loss: 0.7198529839515686


Epoch 18/300:  55%|█████▍    | 6/11 [00:13<00:10,  2.17s/it]

Batch Loss: 0.7142902612686157


Epoch 18/300:  64%|██████▎   | 7/11 [00:15<00:08,  2.08s/it]

Batch Loss: 0.7144299745559692


Epoch 18/300:  73%|███████▎  | 8/11 [00:17<00:06,  2.02s/it]

Batch Loss: 0.7257034182548523


Epoch 18/300:  82%|████████▏ | 9/11 [00:19<00:03,  1.99s/it]

Batch Loss: 0.7109796404838562


Epoch 18/300:  91%|█████████ | 10/11 [00:21<00:01,  1.96s/it]

Batch Loss: 0.7126175761222839


Epoch 18/300: 100%|██████████| 11/11 [00:22<00:00,  2.07s/it]


Batch Loss: 0.712571382522583


Epoch 19/300:   9%|▉         | 1/11 [00:01<00:19,  1.97s/it]

Batch Loss: 0.713473379611969


Epoch 19/300:  18%|█▊        | 2/11 [00:05<00:27,  3.06s/it]

Batch Loss: 0.7211145758628845


Epoch 19/300:  27%|██▋       | 3/11 [00:07<00:20,  2.54s/it]

Batch Loss: 0.7124940156936646


Epoch 19/300:  36%|███▋      | 4/11 [00:09<00:15,  2.28s/it]

Batch Loss: 0.7212578058242798


Epoch 19/300:  45%|████▌     | 5/11 [00:11<00:12,  2.15s/it]

Batch Loss: 0.7241895794868469


Epoch 19/300:  55%|█████▍    | 6/11 [00:13<00:10,  2.07s/it]

Batch Loss: 0.7288814783096313


Epoch 19/300:  64%|██████▎   | 7/11 [00:15<00:08,  2.00s/it]

Batch Loss: 0.7109537720680237


Epoch 19/300:  73%|███████▎  | 8/11 [00:17<00:05,  1.96s/it]

Batch Loss: 0.7107380032539368


Epoch 19/300:  82%|████████▏ | 9/11 [00:19<00:03,  1.94s/it]

Batch Loss: 0.7103511095046997


Epoch 19/300:  91%|█████████ | 10/11 [00:20<00:01,  1.92s/it]

Batch Loss: 0.7137489914894104


Epoch 19/300: 100%|██████████| 11/11 [00:22<00:00,  2.02s/it]


Batch Loss: 0.7150452733039856


Epoch 20/300:   0%|          | 0/11 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/tmp/ipykernel_2964214/603521227.py", line 20, in <module>
    loss = diffusion.forward(y0=y, y_cond=X)
  File "/tmp/ipykernel_2964214/913358632.py", line 133, in forward
    noise_hat = self.eps_model(torch.cat([y_noisy, y_cond], dim=1) if y_cond is not None else y_noisy, sample_gammas)
  File "/home/mingdayang/miniconda3/envs/unibev/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1123, in _call_impl
    hook_result = hook(self, input, result)
  File "/home/mingdayang/miniconda3/envs/unibev/lib/python3.7/site-packages/wandb/integration/torch/wandb_torch.py", line 114, in <lambda>
    mod, inp, outp, log_track_params
  File "/home/mingdayang/miniconda3/envs/unibev/lib/python3.7/site-packages/wandb/integration/torch/wandb_torch.py", line 108, in parameter_log_hook
    self.log_tensor_stats(data.cpu(), "parameters/" + prefix + name)
KeyboardInterrupt


BrokenPipeError: [Errno 32] Broken pipe

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x72415c0d95d0>> (for post_run_cell):


BrokenPipeError: [Errno 32] Broken pipe

In [None]:
wandb.finish()

In [None]:
checkpoint = torch.load('checkpoints/diffusion_model_square.pth')

Unet=UNet(**checkpoint['model_config'])
Unet.load_state_dict(checkpoint['model_state_dict'])
diffusion = DenoiseDiffusion(Unet, checkpoint['beta_schedule'], 'cuda:0')

diffusion.set_new_noise_schedule(phase='test')
print(diffusion.eps_model.in_channel)

In [None]:

diffusion.set_new_noise_schedule(phase='test')
y_ddpm, _ = diffusion.ddpm_sampler(1)
y_ddim, _ = diffusion.ddim_sampler(1, steps=10, eta=0.0)



In [None]:
start = time.time()
y, _ = diffusion.ddim_sampler(1, steps=30, clip_denoised=True)
end = time.time()
print(y)
print(f"Max: {torch.amax(y)}, Min: {torch.amin(y)}")
print(f"Execution time: {end - start}")
plt.imshow(y.cpu().numpy()[0, 0], cmap='viridis')
plt.colorbar()
plt.show()

In [None]:
start = time.time()
y, _ = diffusion.ddim_sampler(1, steps=50, clip_denoised=True)
end = time.time()
print(y)
print(f"Max: {torch.amax(y)}, Min: {torch.amin(y)}")
print(f"Execution time: {end - start}")
plt.imshow(y.cpu().numpy()[0, 0], cmap='viridis')
plt.colorbar()
plt.show()

In [None]:
y_ddpm, _ = diffusion.ddpm_sampler(1)
y_ddim, _ = diffusion.ddim_sampler(1, steps=10)
fig, ax = plt.subplots(1, 2)

im0 = ax[0].imshow(y_ddpm.cpu().numpy()[0, 0], cmap='viridis')
im1 = ax[1].imshow(y_ddim.cpu().numpy()[0, 0], cmap='viridis')

fig.colorbar(im0, ax=ax[0])
fig.colorbar(im1, ax=ax[1])

plt.show()

In [None]:
def save_model(model, model_config, beta_schedule):
    save_temp = {}
    save_temp['model_state_dict'] = model.eps_model.state_dict()
    save_temp['model_config'] = model_config
    save_temp['beta_schedule'] = beta_schedule
    torch.save(save_temp, 'checkpoints/temp.pth')