In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from torch import optim, nn
from math import exp

import numpy as np

from tqdm import tqdm
from random import randint

from load_dataset import load_dataset, plot_image
from auto_encoder2 import PAutoE

device = "cuda"

pimages = load_dataset().to(device)

In [2]:
T = 150
beta = torch.linspace(1e-4, 0.1, T, device=device)

alpha = 1 - beta

alpha_ = torch.zeros(T, device=device)

for i in range(T):
    alpha_[i] = torch.prod(alpha[0:i+1])

alpha_

tensor([9.9990e-01, 9.9913e-01, 9.9769e-01, 9.9558e-01, 9.9281e-01, 9.8939e-01,
        9.8531e-01, 9.8058e-01, 9.7523e-01, 9.6924e-01, 9.6265e-01, 9.5545e-01,
        9.4767e-01, 9.3932e-01, 9.3040e-01, 9.2095e-01, 9.1098e-01, 9.0051e-01,
        8.8955e-01, 8.7813e-01, 8.6627e-01, 8.5398e-01, 8.4130e-01, 8.2824e-01,
        8.1483e-01, 8.0109e-01, 7.8705e-01, 7.7272e-01, 7.5814e-01, 7.4332e-01,
        7.2830e-01, 7.1309e-01, 6.9772e-01, 6.8221e-01, 6.6659e-01, 6.5088e-01,
        6.3510e-01, 6.1929e-01, 6.0345e-01, 5.8761e-01, 5.7179e-01, 5.5601e-01,
        5.4030e-01, 5.2467e-01, 5.0914e-01, 4.9373e-01, 4.7845e-01, 4.6333e-01,
        4.4837e-01, 4.3359e-01, 4.1901e-01, 4.0464e-01, 3.9050e-01, 3.7658e-01,
        3.6291e-01, 3.4949e-01, 3.3633e-01, 3.2345e-01, 3.1084e-01, 2.9851e-01,
        2.8647e-01, 2.7473e-01, 2.6328e-01, 2.5213e-01, 2.4129e-01, 2.3075e-01,
        2.2051e-01, 2.1059e-01, 2.0096e-01, 1.9165e-01, 1.8263e-01, 1.7392e-01,
        1.6551e-01, 1.5739e-01, 1.4957e-

In [None]:
pimages[0].shape

In [None]:
a = torch.randperm(3)
a.tolist()

In [None]:
t = -1

a = torch.randperm(3)
x0 = pimages[0, 0, a]

z = torch.randn(x0.shape).to(device)
xt = x0 * torch.sqrt(alpha_[t]) + z * torch.sqrt(1-alpha_[t])

plot_image(x0)
plot_image(xt)

In [3]:
for t in range(T):
    
    model = PAutoE(3, 3).to(device)

    loss_func = nn.MSELoss()    
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
        
    pbar = tqdm(range(15001), miniters=15)
    def closure():
        
        optimizer.zero_grad()

        source = randint(0, 2)
        batch = torch.randperm(385)[:8]
        color = torch.randperm(3)
        
        x0 = pimages[source, batch][color]
        e = torch.randn(x0.shape, device=device)
        
        x_nois = torch.sqrt(alpha_[t]) * x0 + torch.sqrt(1-alpha_[t]) * e
        e_pred = model(x_nois)

        # Calculate the loss between the result and the noise
        loss = loss_func(e, e_pred)
        loss.backward()
        
        pbar.set_description(f"t = {t}, %.8f" % loss, refresh=False)
        
        return loss
            
    model.train()
    for j in pbar:
        optimizer.step(closure)

    script = torch.jit.script(model)
    script.save(f"./model_{t}.pt")

t = 0, 0.36052319: 100%|██████████████████| 15001/15001 [02:48<00:00, 88.92it/s]
t = 1, 0.25219414: 100%|██████████████████| 15001/15001 [02:47<00:00, 89.44it/s]
t = 2, 0.14929168: 100%|██████████████████| 15001/15001 [02:48<00:00, 89.14it/s]
t = 3, 0.13708077: 100%|██████████████████| 15001/15001 [02:48<00:00, 88.94it/s]
t = 4, 0.09507463: 100%|██████████████████| 15001/15001 [02:48<00:00, 88.82it/s]
t = 5, 0.14011958: 100%|██████████████████| 15001/15001 [02:48<00:00, 88.87it/s]
t = 6, 0.06966732: 100%|██████████████████| 15001/15001 [02:48<00:00, 88.84it/s]
t = 7, 0.15669838: 100%|██████████████████| 15001/15001 [02:48<00:00, 88.96it/s]
t = 8, 0.08744872: 100%|██████████████████| 15001/15001 [02:49<00:00, 88.74it/s]
t = 9, 0.15322866: 100%|██████████████████| 15001/15001 [02:48<00:00, 88.92it/s]
t = 10, 0.12385450: 100%|█████████████████| 15001/15001 [02:48<00:00, 88.89it/s]
t = 11, 0.08335673: 100%|█████████████████| 15001/15001 [02:48<00:00, 89.01it/s]
t = 12, 0.07491414: 100%|███

In [None]:
from UNet import UNet

In [None]:
model = UNet(3,3,8).to(device)

In [None]:
def noise_images(x, t):
    "Add noise to images at instant t"
    a = torch.sqrt(alpha_[t])[:, None, None, None]
    b = torch.sqrt(1 - alpha_[t])[:, None, None, None]
    Ɛ = torch.randn_like(x)
    return a * x + b * Ɛ, Ɛ

batch_size = 8

t = torch.randint(low=1, high=13, size=(batch_size,), device=device)

source = randint(0, 2)
batch = torch.randperm(385)[:batch_size]
x0 = pimages[source][batch]

x_t, noise = noise_images(x0, t)

model(x_t, t)

In [None]:
torch.randint(low=1, high=13, size=(3,))

In [None]:
t

In [None]:
3 //2