In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import random
from math import exp
from pathlib import Path
from random import randint

import numpy as np
import torch
from torch import optim, nn
from tqdm import tqdm

from load_dataset import load_dataset, plot_image
from auto_encoder2 import PAutoE

# Use the GPU when one is available; fall back to CPU so the notebook still
# executes (slowly) on machines without CUDA instead of crashing here.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook draws from (`random.randint` and torch tensors)
# so that Restart & Run All is reproducible. torch.manual_seed seeds all devices.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): later cells index this as pimages[source, image, channel, ...];
# the exact shape is determined by load_dataset — confirm.
pimages = load_dataset().to(device)

In [2]:
# Linear noise schedule over T diffusion steps.
T = 1000
# NOTE(review): the original literal 10e-4 is exactly 1e-3; Ho et al. (2020) start the
# linear schedule at 1e-4. Kept at 1e-3 because the per-timestep checkpoints saved
# under ./M were presumably trained with this schedule — confirm before changing.
beta = torch.linspace(1e-3, 0.02, T, device=device)

alpha = 1 - beta

# alpha_[i] = alpha[0] * alpha[1] * ... * alpha[i] (the cumulative product, often
# written alpha-bar). A single cumprod scan replaces the O(T^2) loop of slice products;
# results agree up to float32 rounding.
alpha_ = torch.cumprod(alpha, dim=0)

alpha_[-1]

tensor(2.5652e-05, device='cuda:0')

In [None]:
# Inspect the dimensions of the stack of images belonging to source 0.
pimages[0].size()

In [None]:
# Scratch: draw a random ordering of the indices 0..2 and show it as a plain list.
a = torch.randperm(3, dtype=torch.int64)
[int(v) for v in a]

In [None]:
# Preview the forward (noising) process: take one image, shuffle its channels,
# and jump straight to timestep t with the closed form
#   xt = sqrt(alpha_[t]) * x0 + sqrt(1 - alpha_[t]) * z,   z ~ N(0, I).
t = -1  # last index of the schedule, i.e. the most heavily noised step

a = torch.randperm(3)
# NOTE(review): assumes pimages[0, 0] is a (C, H, W) image with C == 3, so `a`
# reorders its channels — confirm against load_dataset.
x0 = pimages[0, 0, a]

# randn_like samples directly on x0's device with x0's dtype, avoiding a CPU
# allocation followed by a host-to-device copy.
z = torch.randn_like(x0)
xt = x0 * torch.sqrt(alpha_[t]) + z * torch.sqrt(1 - alpha_[t])

plot_image(x0)
plot_image(xt)

In [None]:
# Train one noise-prediction network per diffusion timestep t (epsilon-prediction
# MSE objective) and save each as a TorchScript file ./M/model_{t}.pt.
START_T = 597        # resume point; checkpoints for earlier t came from previous runs
N_STEPS = 15001      # optimizer steps per timestep
BATCH_SIZE = 8
N_IMAGES = 385       # NOTE(review): presumably the number of images per source (pimages.shape[1]) — confirm
CKPT_DIR = Path("./M")
CKPT_DIR.mkdir(parents=True, exist_ok=True)  # saving fails if the directory does not exist

for t in range(START_T, T):

    model = PAutoE(3, 3).to(device)

    loss_func = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # The forward-process coefficients are fixed for this t: compute them once,
    # not on every optimizer step.
    sqrt_ab = torch.sqrt(alpha_[t])
    sqrt_one_minus_ab = torch.sqrt(1 - alpha_[t])

    pbar = tqdm(range(N_STEPS), miniters=15)

    def closure():
        """Run one forward/backward pass on a fresh noised batch; return the loss."""
        optimizer.zero_grad()

        source = randint(0, 2)                          # source index 0, 1 or 2 (bounds inclusive)
        batch = torch.randperm(N_IMAGES)[:BATCH_SIZE]   # BATCH_SIZE distinct image indices
        color = torch.randperm(3)                       # random channel order (augmentation)

        # pimages[source, batch] is (BATCH_SIZE, C, ...); permute the channel axis.
        # Indexing it with `[color]` directly would select 3 rows of the batch axis
        # instead, shrinking the batch to 3 samples and never shuffling channels.
        x0 = pimages[source, batch][:, color]
        e = torch.randn(x0.shape, device=device)

        x_nois = sqrt_ab * x0 + sqrt_one_minus_ab * e
        e_pred = model(x_nois)

        # Loss between the predicted and the true noise (input, target order).
        loss = loss_func(e_pred, e)
        loss.backward()

        pbar.set_description(f"t = {t}, {loss.item():.8f}", refresh=False)

        return loss

    model.train()
    for _ in pbar:
        optimizer.step(closure)

    script = torch.jit.script(model)
    script.save(str(CKPT_DIR / f"model_{t}.pt"))

t = 597, 0.00414827: 100%|████████████████| 15001/15001 [02:49<00:00, 88.45it/s]
t = 598, 0.00332553: 100%|████████████████| 15001/15001 [02:47<00:00, 89.69it/s]
t = 599, 0.00472171: 100%|████████████████| 15001/15001 [02:47<00:00, 89.42it/s]
t = 600, 0.00285755: 100%|████████████████| 15001/15001 [02:48<00:00, 89.21it/s]
t = 601, 0.00448557: 100%|████████████████| 15001/15001 [02:48<00:00, 88.95it/s]
t = 602, 0.00301361: 100%|████████████████| 15001/15001 [02:48<00:00, 89.00it/s]
t = 603, 0.00347396: 100%|████████████████| 15001/15001 [02:48<00:00, 88.90it/s]
t = 604, 0.00463837: 100%|████████████████| 15001/15001 [02:48<00:00, 89.05it/s]
t = 605, 0.00382890: 100%|████████████████| 15001/15001 [02:48<00:00, 88.96it/s]
t = 606, 0.00493825: 100%|████████████████| 15001/15001 [02:48<00:00, 88.99it/s]
t = 607, 0.00393389: 100%|████████████████| 15001/15001 [02:48<00:00, 89.09it/s]
t = 608, 0.00395970: 100%|████████████████| 15001/15001 [02:48<00:00, 89.25it/s]
t = 609, 0.00272561: 100%|██

In [None]:
from UNet import UNet

In [None]:
# Instantiate the UNet and move it onto the configured device.
# NOTE(review): the meaning of the positional arguments (3, 3, 8) is defined in UNet.py — confirm.
model = UNet(3, 3, 8)
model = model.to(device)

In [None]:
def noise_images(x, t):
    """Add noise to images at instant t.

    Samples the forward noising process in closed form:
        x_t = sqrt(alpha_[t]) * x + sqrt(1 - alpha_[t]) * noise,   noise ~ N(0, I)
    using the module-level cumulative schedule ``alpha_``.

    Args:
        x: batch of clean inputs; dimension 0 is the batch dimension. Any rank is
           accepted (previously only 4-D inputs broadcast correctly).
        t: timestep indices into ``alpha_`` — either one per sample (shape ``(B,)``)
           or a single index (int or 0-d tensor) shared by the whole batch.

    Returns:
        ``(x_t, noise)``: the noised batch (same shape as ``x``) and the Gaussian
        noise that was added (same shape, dtype and device as ``x``).
    """
    # (B,) -> (B, 1, ..., 1) so each sample's coefficient broadcasts over all of its
    # non-batch dimensions regardless of x's rank; a scalar t becomes (1, 1, ..., 1).
    coef_shape = (-1,) + (1,) * (x.dim() - 1)
    a = torch.sqrt(alpha_[t]).reshape(coef_shape)
    b = torch.sqrt(1 - alpha_[t]).reshape(coef_shape)
    noise = torch.randn_like(x)
    return a * x + b * noise, noise

batch_size = 8

# Scratch: one forward pass of the UNet on a freshly noised batch.
# NOTE(review): `high` is exclusive, so t is drawn only from [1, 12] although the
# schedule has T steps — presumably a debugging range; confirm before training on it.
t = torch.randint(low=1, high=13, size=(batch_size,), device=device)

source = randint(0, 2)  # source index 0, 1 or 2 (randint bounds are inclusive)
batch = torch.randperm(385)[:batch_size]  # batch_size distinct image indices
x0 = pimages[source][batch]

x_t, noise = noise_images(x0, t)

# Presumably UNet.forward(x, t) takes the timesteps as conditioning — confirm in UNet.py.
model(x_t, t)

In [None]:
# Scratch: check torch.randint's range (high is exclusive, so values fall in [1, 12]).
torch.randint(low=1, high=13, size=(3,))

In [None]:
# Scratch: display whatever value is currently bound to `t`.
t

In [None]:
# Scratch: floor-division check (evaluates to 1).
3 //2