In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DDPM(nn.Module):
    """Minimal DDPM-style noise predictor built from three 3x3 convolutions.

    Args:
        num_timesteps: Number of diffusion steps; used to scale the noise
            that `compute_loss` adds to its input.
    """

    def __init__(self, num_timesteps):
        super().__init__()
        self.num_timesteps = num_timesteps
        self.model = self.build_model()

    def build_model(self):
        """Return the conv stack (3 -> 64 -> 64 -> 3 channels, ReLU in between).

        Every convolution uses a 3x3 kernel with padding 1.
        """
        hidden_channels = 64
        layers = [
            nn.Conv2d(3, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, 3, kernel_size=3, padding=1),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, t):
        """Predict the noise for `x`.

        `t` is accepted but not used: the network is not conditioned on the
        timestep.
        """
        return self.model(x)

    def compute_loss(self, x, t):
        """Return the MSE between freshly sampled noise and the predicted noise.

        Args:
            x: Input batch; the noise is drawn with the same shape as `x`.
            t: Timesteps, reshaped to (-1, 1, 1, 1) so they broadcast over `x`.
        """
        eps = torch.randn_like(x)
        t = t.view(-1, 1, 1, 1)  # reshape t to (batch_size, 1, 1, 1)
        # Noise magnitude grows with the timestep: sqrt(t / num_timesteps).
        noise_scale = torch.sqrt(t / self.num_timesteps)
        eps_hat = self.forward(x + eps * noise_scale, t)
        return F.mse_loss(eps_hat, eps)

# Usage example: build the model and evaluate the loss on one random batch.
model = DDPM(num_timesteps=1000)
x = torch.randn(8, 3, 32, 32)  # eight 32x32 RGB images
# One random timestep per image, drawn from [0, num_timesteps).
t = torch.randint(low=0, high=model.num_timesteps, size=(x.size(0),))
loss = model.compute_loss(x, t)
print(loss)

tensor(1.0140, grad_fn=<MseLossBackward0>)


In [2]:
%%timeit -n 10 -r 10
batch_size = 8
x = torch.randn(batch_size, 3, 32, 32)  # 8개의 32x32 RGB 이미지
t = torch.randint(0, 1000, (batch_size,))  # 각 이미지에 대한 랜덤 타임스텝
loss = model.compute_loss(x, t)

2.89 ms ± 314 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DDIM(nn.Module):
    """Small convolutional noise predictor used for the DDIM experiments.

    The constructor stores `num_timesteps` and builds the network returned by
    `build_model`.
    """

    def __init__(self, num_timesteps):
        super().__init__()
        self.num_timesteps = num_timesteps
        self.model = self.build_model()

    def build_model(self):
        """Create the 3 -> 64 -> 64 -> 3 convolutional stack with ReLU activations."""
        conv_in = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        conv_mid = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        conv_out = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1)
        return nn.Sequential(conv_in, nn.ReLU(), conv_mid, nn.ReLU(), conv_out)

    def forward(self, x, t):
        """Run the conv stack on `x`; the timestep argument `t` is ignored."""
        return self.model(x)

    def compute_loss(self, x, t):
        """Noise-prediction loss.

        Samples noise shaped like `x`, adds it to `x` scaled by
        sqrt(t / num_timesteps), and returns the mean squared error between
        the network output and the sampled noise.
        """
        noise = torch.randn_like(x)
        t = t.view(-1, 1, 1, 1)  # reshape t to (batch_size, 1, 1, 1)
        noise_scale = (t / self.num_timesteps).sqrt()
        return F.mse_loss(self.forward(x + noise * noise_scale, t), noise)

# Usage example: instantiate the 50-step model and print the loss for one random batch.
model = DDIM(num_timesteps=50)
x = torch.randn((8, 3, 32, 32))  # eight 32x32 RGB images
t = torch.randint(50, (len(x),))  # a random timestep in [0, 50) for each image
loss = model.compute_loss(x, t)
print(loss)

tensor(1.0251, grad_fn=<MseLossBackward0>)


In [4]:
%%timeit -n 10 -r 10
batch_size = 8
x = torch.randn(batch_size, 3, 32, 32)  # 8개의 32x32 RGB 이미지
t = torch.randint(0, 50, (batch_size,))  # 각 이미지에 대한 랜덤 타임스텝
loss = model.compute_loss(x, t)

2.57 ms ± 228 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class DDPM(nn.Module):
    """Toy DDPM-style noise predictor with an iterative sampling loop.

    Args:
        num_timesteps: Number of diffusion steps. It scales the noise added in
            `compute_loss` and sets how many iterations `ddpm_sampling` runs.
    """

    def __init__(self, num_timesteps):
        super().__init__()
        self.num_timesteps = num_timesteps
        self.model = self.build_model()

    def build_model(self):
        """Return the conv stack: 3 -> 64 -> 64 -> 3 channels, 3x3 kernels, padding 1."""
        model = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1)
        )
        return model

    def forward(self, x, t):
        """Predict the noise for `x`.

        `t` is accepted but not used: the network is not conditioned on the
        timestep.
        """
        return self.model(x)

    def compute_loss(self, x, t):
        """Return the MSE between sampled noise and the network's prediction.

        Args:
            x: Input batch; noise is drawn with the same shape.
            t: Timesteps; reshaped to (-1, 1, 1, 1) and cast to float so the
                scale sqrt(t / num_timesteps) broadcasts over `x`.
        """
        noise = torch.randn_like(x)
        t = t.view(-1, 1, 1, 1).float()
        noisy_x = x + noise * torch.sqrt(t / self.num_timesteps)
        predicted_noise = self.forward(noisy_x, t)
        return F.mse_loss(predicted_noise, noise)

    @torch.no_grad()
    def ddpm_sampling(self, x_T):
        """Iteratively transform `x_T` over timesteps num_timesteps-1 .. 1.

        Runs under `torch.no_grad()`: sampling is inference-only, and without
        it autograd would record every forward pass of the loop and the
        returned tensor would carry that whole graph.

        Args:
            x_T: Starting tensor for the loop.

        Returns:
            A tensor with the same shape as `x_T`, detached from autograd.
        """
        # NOTE(review): the update below uses alpha_t = 1 - t / num_timesteps and
        # appears to be a simplified rule rather than the exact DDPM posterior
        # step; the arithmetic is kept as it was -- confirm whether intended.
        x = x_T
        for t in reversed(range(1, self.num_timesteps)):
            t_tensor = torch.tensor([t], device=x.device).view(-1, 1, 1, 1).float()
            epsilon = self.forward(x, t_tensor)
            alpha_t = 1 - t / self.num_timesteps
            alpha_t = torch.tensor(alpha_t, device=x.device).view(-1, 1, 1, 1)
            # Remove the scaled predicted noise, then inject fresh Gaussian noise.
            x = (x - epsilon * torch.sqrt(1 - alpha_t) / torch.sqrt(alpha_t)) + torch.randn_like(x) * torch.sqrt(1 - alpha_t)
        return x

# Usage example: time one full DDPM sampling run.
model_ddpm = DDPM(num_timesteps=1000)
x_T_ddpm = torch.randn(8, 3, 32, 32)  # start from eight 32x32 RGB images
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
sampled_images_ddpm = model_ddpm.ddpm_sampling(x_T_ddpm)
ddpm_time = time.perf_counter() - start_time
print(f"DDPM Sampling Time: {ddpm_time:.4f} seconds")
print(sampled_images_ddpm.shape)


DDPM Sampling Time: 2.9006 seconds
torch.Size([8, 3, 32, 32])


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class DDIM(nn.Module):
    """Toy DDIM-style noise predictor with an iterative sampling loop.

    Args:
        num_timesteps: Number of diffusion steps. It scales the noise added in
            `compute_loss` and sets how many iterations `ddim_sampling` runs.
        eta: Stochasticity of the sampler. 0 (the default) gives a
            deterministic update; larger values add per-step Gaussian noise
            with standard deviation sigma_t(eta).
    """

    def __init__(self, num_timesteps, eta=0.):
        super().__init__()
        self.num_timesteps = num_timesteps
        self.eta = eta
        self.model = self.build_model()

    def build_model(self):
        """Return the conv stack: 3 -> 64 -> 64 -> 3 channels, 3x3 kernels, padding 1."""
        model = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1)
        )
        return model

    def forward(self, x, t):
        """Predict the noise for `x`.

        `t` is accepted but not used: the network is not conditioned on the
        timestep.
        """
        return self.model(x)

    def compute_loss(self, x, t):
        """Return the MSE between sampled noise and the network's prediction.

        Args:
            x: Input batch; noise is drawn with the same shape.
            t: Timesteps; reshaped to (-1, 1, 1, 1) and cast to float so the
                scale sqrt(t / num_timesteps) broadcasts over `x`.
        """
        noise = torch.randn_like(x)
        t = t.view(-1, 1, 1, 1).float()
        noisy_x = x + noise * torch.sqrt(t / self.num_timesteps)
        predicted_noise = self.forward(noisy_x, t)
        return F.mse_loss(predicted_noise, noise)

    @torch.no_grad()
    def ddim_sampling(self, x_T):
        """Iteratively transform `x_T` over timesteps num_timesteps-1 .. 1.

        Each step uses alpha_t = 1 - t / num_timesteps and
        alpha_{t-1} = 1 - (t - 1) / num_timesteps, and applies

            x0_pred = (x - sqrt(1 - alpha_t) * eps) / sqrt(alpha_t)
            sigma_t = eta * sqrt((1 - alpha_{t-1}) / (1 - alpha_t))
                          * sqrt(1 - alpha_t / alpha_{t-1})
            x       = sqrt(alpha_{t-1}) * x0_pred
                      + sqrt(1 - alpha_{t-1} - sigma_t^2) * eps
                      + sigma_t * z,   z ~ N(0, I)

        With eta == 0, sigma_t is 0 and the update is deterministic. The
        earlier form subtracted `eta` itself under the square root, which
        goes negative (NaN) for any eta > 0 once alpha_{t-1} reaches 1, and it
        never added the sigma_t * z term.

        Runs under `torch.no_grad()` because sampling is inference-only.

        Args:
            x_T: Starting tensor for the loop.

        Returns:
            A tensor with the same shape as `x_T`, detached from autograd.
        """
        x = x_T
        for t in reversed(range(1, self.num_timesteps)):
            t_tensor = torch.tensor([t], device=x.device).view(-1, 1, 1, 1).float()
            epsilon = self.forward(x, t_tensor)
            alpha_t = torch.tensor(1 - t / self.num_timesteps, device=x.device)
            alpha_t_1 = torch.tensor(1 - (t - 1) / self.num_timesteps, device=x.device)

            # Per-step noise level; exactly zero when eta == 0.
            sigma_t = self.eta * torch.sqrt((1 - alpha_t_1) / (1 - alpha_t)) * torch.sqrt(1 - alpha_t / alpha_t_1)

            x0_pred = (x - torch.sqrt(1 - alpha_t) * epsilon) / torch.sqrt(alpha_t)
            # Clamp guards against a tiny negative value from rounding (or eta > 1).
            direction = torch.sqrt(torch.clamp(1 - alpha_t_1 - sigma_t ** 2, min=0.0)) * epsilon
            x = x0_pred * torch.sqrt(alpha_t_1) + direction

            if self.eta != 0:
                # Only draw noise when it is used, so eta == 0 consumes no RNG state.
                x = x + sigma_t * torch.randn_like(x)
        return x

# Usage example: time one full DDIM sampling run.
model_ddim = DDIM(num_timesteps=50)
x_T_ddim = torch.randn(8, 3, 32, 32)  # start from eight 32x32 RGB images
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
sampled_images_ddim = model_ddim.ddim_sampling(x_T_ddim)
ddim_time = time.perf_counter() - start_time
print(f"DDIM Sampling Time: {ddim_time:.4f} seconds")
print(sampled_images_ddim.shape)


DDIM Sampling Time: 0.1289 seconds
torch.Size([8, 3, 32, 32])


In [10]:
# Time both samplers on a batch of eight images.
# perf_counter() is used instead of time.time() because it is monotonic and
# high-resolution, which matters for the sub-second DDIM run.

# DDPM sampling
x_T_ddpm = torch.randn(8, 3, 32, 32)  # start from eight 32x32 RGB images
start_time = time.perf_counter()
sampled_images_ddpm = model_ddpm.ddpm_sampling(x_T_ddpm)
ddpm_time = time.perf_counter() - start_time
print(f"DDPM Sampling Time: {ddpm_time:.4f} seconds")

# DDIM sampling
x_T_ddim = torch.randn(8, 3, 32, 32)  # start from eight 32x32 RGB images
start_time = time.perf_counter()
sampled_images_ddim = model_ddim.ddim_sampling(x_T_ddim)
ddim_time = time.perf_counter() - start_time
print(f"DDIM Sampling Time: {ddim_time:.4f} seconds")

DDPM Sampling Time: 3.4525 seconds
torch.Size([8, 3, 32, 32])
DDIM Sampling Time: 0.0916 seconds
torch.Size([8, 3, 32, 32])


In [14]:
# Time both samplers on a batch of 100 images and report the speed-up.
# perf_counter() is used instead of time.time() because it is monotonic and
# high-resolution, so the ratio below is not distorted by clock adjustments.

# DDPM sampling
x_T_ddpm = torch.randn(100, 3, 32, 32)  # start from 100 32x32 RGB images
start_time = time.perf_counter()
sampled_images_ddpm = model_ddpm.ddpm_sampling(x_T_ddpm)
ddpm_time = time.perf_counter() - start_time
print(f"DDPM Sampling Time: {ddpm_time:.4f} seconds")

# DDIM sampling
x_T_ddim = torch.randn(100, 3, 32, 32)  # start from 100 32x32 RGB images
start_time = time.perf_counter()
sampled_images_ddim = model_ddim.ddim_sampling(x_T_ddim)
ddim_time = time.perf_counter() - start_time
print(f"DDIM Sampling Time: {ddim_time:.4f} seconds")

# The message below reports how many times faster DDIM was than DDPM.
print('='*80)
print(f'DDIM이 DDPM보다 {ddpm_time / ddim_time:.2f}배 빠릅니다')
print('='*80)

DDPM Sampling Time: 15.5073 seconds
DDIM Sampling Time: 0.4360 seconds
DDIM이 DDPM보다 35.57배 빠릅니다
