In [None]:
# NOTE(review): this cell uses `Model`, whose defining cell appears further
# down, plus `noise_scheduler`, `x0` and `device`, which are not defined in the
# cells shown here — run those cells before this one.
model = Model().to(device)
# `t` below is created on `device` and used to index the scheduler's
# `alpha_bar` buffer, so the scheduler must live on the same device.
# (Assumes `noise_scheduler` is an nn.Module such as NoiseScheduler; Module.to
# moves registered buffers in place.)
noise_scheduler.to(device)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

loss_history = []
iterations = 1000
log_every = 100

for it in range(iterations):
    optimizer.zero_grad()
    # The training target: the exact noise the model must recover.
    noise = torch.randn_like(x0, device=device)
    # One timestep shared by the whole batch; shape (1, 1) matches the
    # model's Linear(1, ...) timestep input.
    t = torch.randint(0, noise_scheduler.steps, (1, 1), device=device)
    x_t = noise_scheduler.add_noise(x0, t, noise)
    pred_noise = model(x_t, t)
    loss = criterion(pred_noise, noise)
    # Read the scalar once: each .item() forces a device-to-host sync.
    loss_value = loss.item()
    loss_history.append(loss_value)
    if it % log_every == 0:
        print(f"Iteration {it}, Loss {loss_value}")
    loss.backward()
    optimizer.step()

In [None]:
class Model(torch.nn.Module):
    """Small convolutional network that predicts the noise in a noised image.

    Pipeline: conv (in_channels -> hidden_channels) -> ReLU -> add a learned
    per-channel embedding of the timestep -> conv (hidden_channels ->
    in_channels). Both convolutions use 3x3 kernels with padding=1, so the
    spatial size of the input is preserved.

    :param in_channels: channels of the input image and of the prediction
        (default 3, the originally hard-coded value)
    :param hidden_channels: width of the intermediate feature map
        (default 16, the originally hard-coded value)
    """

    def __init__(self, in_channels=3, hidden_channels=16):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.conv1 = torch.nn.Conv2d(in_channels, hidden_channels, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(hidden_channels, in_channels, 3, padding=1)
        # Maps one scalar timestep to one additive offset per hidden channel.
        self.linear = torch.nn.Linear(1, hidden_channels)

    def forward(self, x, t):
        """
        Predict the noise component of x at timestep(s) t
        :param x: image batch, (B, in_channels, H, W)
        :param t: integer timestep(s): a scalar, a (B,) vector, or a (B, 1) /
            (1, 1) tensor; a single timestep is broadcast over the batch
        :return: predicted noise, same shape as x
        """
        h = torch.nn.functional.relu(self.conv1(x))
        # Linear(1, ...) needs a trailing feature axis of size 1; flattening to
        # (N, 1) accepts scalar, (B,) and (B, 1) timesteps alike.
        t = t.to(torch.float32).reshape(-1, 1)
        emb = self.linear(t)
        # (N, hidden) -> (N, hidden, 1, 1) so the embedding broadcasts over H, W.
        h = h + emb.view(-1, self.hidden_channels, 1, 1)
        return self.conv2(h)

In [None]:
class NoiseScheduler(torch.nn.Module):
    """Forward-noising schedule with linearly spaced betas.

    Precomputes alpha_bar[t] = prod_{s=0..t} (1 - beta[s]) for
    beta = linspace(beta_start, beta_end, steps), and uses it to noise a clean
    image to any step in closed form.

    :param steps: number of steps; valid timesteps are 0 .. steps - 1
    :param beta_start: beta at step 0
    :param beta_end: beta at step steps - 1
    """

    def __init__(self, steps=24, beta_start=1e-4, beta_end=0.6):
        super().__init__()
        self.steps = steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        beta = torch.linspace(beta_start, beta_end, steps)
        alpha = 1. - beta
        alpha_bar = torch.cumprod(alpha, 0)

        # A buffer (not a Parameter): moved by .to(device) and stored in
        # state_dict, but never trained.
        self.register_buffer('alpha_bar', alpha_bar)

    def add_noise(self, x0, t, noise):
        """
        Noise an image to step t:
        sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise
        :param x0: initial image (batch first when t has one entry per sample)
        :param t: step number, 0 indexed (0 <= t < steps); an int, a scalar
            tensor, a (1, 1) tensor shared by the batch, or one step per sample
            as a (B,) or (B, 1) tensor
        :param noise: noise to add, same shape as x0
        :return: image with noise at step t
        """
        alpha_bar = self.alpha_bar[t]
        # Indexing keeps t's shape, so a (B,) t would broadcast against x0's
        # LAST axis. Lay the values out along the leading (batch) axis instead,
        # so each sample is scaled by its own alpha_bar.
        alpha_bar = alpha_bar.reshape(-1, *([1] * (x0.dim() - 1)))
        return torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * noise