In [1]:
from diffusion import forward_diffusion_sample
import torch
from diffusion import get_time_embedding
from diffusion import model_architecture
from diffusion import reverse_diffusion_sample

# Forward diffusion test

In [2]:
# Configuration for the diffusion smoke tests below.
seed = 0  # fixed so x_0 (and, presumably, noise the library draws from torch's global RNG) is reproducible
torch.manual_seed(seed)

num_diffusion_steps = 1000
# Linear beta schedule from 1e-4 to 0.02, one entry per diffusion step.
betas = torch.linspace(0.0001, 0.02, num_diffusion_steps)

batch_size = 4
dim = 10

# Toy "clean" data: batch_size samples with dim standard-normal features each.
x_0 = torch.randn(batch_size, dim)

In [3]:
# Display the schedule's shape: one beta per diffusion step.
betas.size()

torch.Size([1000])

In [4]:
# Display the clean batch's shape: (batch_size, dim).
x_0.size()

torch.Size([4, 10])

In [5]:
timestep = 50
# Guard the tunable constant. A negative value would silently wrap around if the library
# indexes the beta schedule with it, and a value >= num_diffusion_steps would only fail
# later with an opaque indexing error inside the library.
assert 0 <= timestep < num_diffusion_steps, f"timestep must be in [0, {num_diffusion_steps})"
# Tensor shaped (batch_size,) with timestep repeated batch_size times
timestep_tensor = torch.full((batch_size,), timestep, dtype=torch.long)

In [6]:
# Display the timestep tensor's shape: one timestep per batch element.
timestep_tensor.size()

torch.Size([4])

In [7]:
x_t, noise = forward_diffusion_sample(x_0, timestep_tensor, betas)

# Turn the visual check into a real one: the noisy sample must match the closed form
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
#   alpha_bar_t = prod_{s <= t} (1 - beta_s),
# which reproduces the previously recorded outputs of this cell.
assert x_t.shape == x_0.shape and noise.shape == x_0.shape
alpha_bar_t = torch.cumprod(1.0 - betas, dim=0)[timestep_tensor].unsqueeze(-1)  # (batch_size, 1)
expected_x_t = alpha_bar_t.sqrt() * x_0 + (1.0 - alpha_bar_t).sqrt() * noise
assert torch.allclose(x_t, expected_x_t, atol=1e-5), "x_t does not match q(x_t | x_0)"

In [8]:

# Show the clean input, its noised version, and the noise that was mixed in.
for heading, values in (
    ("Clean sample (x_0):", x_0),
    ("\nNoisy sample (x_t):", x_t),
    ("\nNoise added:", noise),
):
    print(heading)
    print(values)

Clean sample (x_0):
tensor([[ 1.0312,  0.8105,  1.6585,  0.5953,  0.3595, -0.7537,  1.5352,  1.9974,
          2.3449, -0.1453],
        [ 0.6224,  0.1751, -0.1036, -0.4103, -0.0133,  0.4811,  1.3800,  0.7278,
          0.0162,  0.1797],
        [-0.9411, -0.1383, -1.4601,  0.8752, -0.9297,  1.0874,  0.6565,  0.3901,
         -0.8836,  0.0482],
        [ 1.7072,  1.1462, -0.7023, -0.3768, -0.8702, -0.3389,  1.5028, -0.3549,
         -0.3975, -0.0086]])

Noisy sample (x_t):
tensor([[ 0.9129,  0.9456,  1.9236,  0.6603,  0.0957, -0.8940,  1.2359,  1.8455,
          2.4707, -0.1314],
        [ 0.4904,  0.2617, -0.1398, -0.3541,  0.0190,  0.6403,  1.2285,  0.5831,
          0.2228,  0.3880],
        [-0.7446, -0.3469, -1.2333,  0.9391, -1.0112,  1.4118,  0.7305,  0.3503,
         -0.8563,  0.0370],
        [ 1.6883,  1.1925, -0.6067, -0.5078, -0.9327, -0.2396,  1.2374, -0.2426,
         -0.5686, -0.3173]])

Noise added:
tensor([[-0.5921,  0.8497,  1.6738,  0.4271, -1.4904, -0.8752, -1.5927,

# Time embedding test

In [9]:
embedding_dim = 32  # arbitrary test width; assumed to be even (sin/cos halves below) — confirm
test = get_time_embedding(timestep_tensor, embedding_dim)

# One embedding row per batch element.
assert test.shape == (timestep_tensor.shape[0], embedding_dim)
# Closed-form check of the layout seen in the recorded output of this cell:
# [sin(t * f_i) for i < half] ++ [cos(t * f_i) for i < half], f_i = 10000 ** (-i / half).
half = embedding_dim // 2
freqs = 10000.0 ** (-torch.arange(half, dtype=torch.float32) / half)
angles = timestep_tensor.float().unsqueeze(-1) * freqs  # (batch_size, half)
expected_embedding = torch.cat([angles.sin(), angles.cos()], dim=-1)
assert torch.allclose(test, expected_embedding, atol=1e-4), "unexpected time-embedding values"

In [10]:
# Print the embedding tensor followed by its shape, each on its own line.
print(test, test.shape, sep="\n")

tensor([[-0.2624,  0.1566, -0.1032,  0.5084, -0.9589,  0.3239,  0.9999,  0.7765,
          0.4794,  0.2775,  0.1575,  0.0888,  0.0500,  0.0281,  0.0158,  0.0089,
          0.9650, -0.9877, -0.9947, -0.8611,  0.2837, -0.9461, -0.0103,  0.6301,
          0.8776,  0.9607,  0.9875,  0.9960,  0.9988,  0.9996,  0.9999,  1.0000],
        [-0.2624,  0.1566, -0.1032,  0.5084, -0.9589,  0.3239,  0.9999,  0.7765,
          0.4794,  0.2775,  0.1575,  0.0888,  0.0500,  0.0281,  0.0158,  0.0089,
          0.9650, -0.9877, -0.9947, -0.8611,  0.2837, -0.9461, -0.0103,  0.6301,
          0.8776,  0.9607,  0.9875,  0.9960,  0.9988,  0.9996,  0.9999,  1.0000],
        [-0.2624,  0.1566, -0.1032,  0.5084, -0.9589,  0.3239,  0.9999,  0.7765,
          0.4794,  0.2775,  0.1575,  0.0888,  0.0500,  0.0281,  0.0158,  0.0089,
          0.9650, -0.9877, -0.9947, -0.8611,  0.2837, -0.9461, -0.0103,  0.6301,
          0.8776,  0.9607,  0.9875,  0.9960,  0.9988,  0.9996,  0.9999,  1.0000],
        [-0.2624,  0.1566

# Reverse diffusion test (single step $x_t \to x_{t-1}$)

In [12]:
# One reverse (denoising) step x_t -> x_{t-1}. This is inference only, so run it under
# no_grad to avoid building an autograd graph through the model.
# NOTE(review): argument order (x, betas, t, model) differs from
# forward_diffusion_sample(x, t, betas) — confirm against the library's signature.
with torch.no_grad():
    x_t_minus_1 = reverse_diffusion_sample(x_t, betas, timestep_tensor, model_architecture)

In [15]:
# x_t is the *noisy* input to the reverse step; the original clean sample is x_0.
# Print each label on its own line, followed by the tensor, so rows stay aligned.
print("Noisy input sample (x_t):")
print(x_t)
print("\nNoise added:")
print(noise)
print("\nReversed sample (x_t-1):")
print(x_t_minus_1)

 Original sample (x_t): tensor([[ 0.9129,  0.9456,  1.9236,  0.6603,  0.0957, -0.8940,  1.2359,  1.8455,
          2.4707, -0.1314],
        [ 0.4904,  0.2617, -0.1398, -0.3541,  0.0190,  0.6403,  1.2285,  0.5831,
          0.2228,  0.3880],
        [-0.7446, -0.3469, -1.2333,  0.9391, -1.0112,  1.4118,  0.7305,  0.3503,
         -0.8563,  0.0370],
        [ 1.6883,  1.1925, -0.6067, -0.5078, -0.9327, -0.2396,  1.2374, -0.2426,
         -0.5686, -0.3173]])
 Noise added: tensor([[-0.5921,  0.8497,  1.6738,  0.4271, -1.4904, -0.8752, -1.5927, -0.7021,
          0.9303,  0.0672],
        [-0.7069,  0.5148, -0.2175,  0.2886,  0.1850,  0.9603, -0.7536, -0.7710,
          1.1935,  1.2176],
        [ 1.0512, -1.2152,  1.1807,  0.4453, -0.5517,  1.9664,  0.4839, -0.1953,
          0.0798, -0.0605],
        [ 0.0398,  0.3673,  0.4900, -0.7888, -0.4366,  0.5434, -1.3997,  0.6173,
         -1.0215, -1.7815]])
 Reversed sample (x_t-1): tensor([[ 0.9119,  0.9473,  1.9252,  0.6607,  0.0961, -0.8930,