In [None]:
import torch
import torch.nn as nn
import math
import numpy as np
from tqdm import tqdm
from keras.datasets.mnist import load_data
import matplotlib as plt
from unet import UNet
from diffusion_model import DiffusionModel

import imageio


In [None]:
# Fetch the MNIST train/test splits; the label arrays are kept for completeness.
(trainX, trainy), (testX, testy) = load_data()

# Cast the images to float32 and divide by 255 (assumes the raw pixel values
# span 0-255, so the result lies in [0, 1] -- TODO confirm against the loader).
trainX = np.asarray(trainX, dtype=np.float32) / 255.0
testX = np.asarray(testX, dtype=np.float32) / 255.0

def sample_batch(batch_size, device):
    """Draw a random batch of training images resized to 32x32.

    Takes the first ``batch_size`` entries of a random permutation of the row
    indices of the module-level ``trainX`` array (so images within a batch are
    distinct), inserts a channel axis at position 1, moves the batch to
    ``device`` and resizes it with ``interpolate``'s default mode.

    Args:
        batch_size: Number of images to draw. If it exceeds the number of rows
            in ``trainX`` the batch simply contains every row once.
        device: Device the batch is moved to before resizing.

    Returns:
        A tensor of shape ``(batch_size, 1, 32, 32)``, assuming ``trainX`` is
        laid out as ``(N, H, W)`` -- TODO confirm against the data cell.
    """
    indices = torch.randperm(trainX.shape[0])[:batch_size]
    # Index with a NumPy array rather than the tensor itself: NumPy can treat a
    # one-element integer tensor as a scalar index, which would silently drop
    # the batch axis when batch_size == 1.
    batch = trainX[indices.numpy()]
    data = torch.from_numpy(batch).unsqueeze(1).to(device)
    return torch.nn.functional.interpolate(data, size=32)

In [None]:
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 64

# To start from freshly initialised weights instead of the checkpoint:
#model = UNet().to(device)

# map_location lets a checkpoint that was saved on a GPU be restored on
# whichever device was selected above.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Recent PyTorch releases default to weights_only=True, which
# may reject a fully pickled module -- confirm against the installed version.
model = torch.load("model_paper2_epoch_3999", map_location=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

# 1000 is passed as the first DiffusionModel argument (presumably the number
# of diffusion steps, read as `self.T` by `sampling` -- TODO confirm).
diffusion_model = DiffusionModel(1000, model, device)

In [None]:
@torch.no_grad()
def sampling(self, n_samples=1, image_channel=1, image_size=(32, 32), use_tqdm=True):
    """Run the reverse diffusion loop and return every intermediate sample.

    Starts from Gaussian noise and walks the steps ``self.T, ..., 1``. At each
    step the noise predicted by ``self.function_approximator`` is removed and,
    on every step except the last, fresh Gaussian noise scaled by
    ``sqrt(beta_t)`` is added (DDPM ancestral sampling).

    Args:
        self: Object exposing ``T``, ``device``, ``alpha``, ``alpha_bar``,
            ``beta`` and ``function_approximator``. The three schedules are
            indexed with ``step - 1``; the approximator is called as
            ``function_approximator(x, step_index)`` with a long tensor of
            shape ``(n_samples,)``.
        n_samples: Number of images generated in parallel.
        image_channel: Number of channels of each image.
        image_size: ``(height, width)`` of each image.
        use_tqdm: Wrap the step loop in a ``tqdm`` progress bar when true.

    Returns:
        A list of ``self.T + 1`` tensors of shape
        ``(n_samples, image_channel, height, width)``: the initial noise
        followed by the sample after each reverse step. The final result is
        the last element.
    """
    x = torch.randn((n_samples, image_channel, image_size[0], image_size[1]), device=self.device)
    all_x = [x]

    progress_bar = tqdm if use_tqdm else lambda steps: steps
    for step in progress_bar(range(self.T, 0, -1)):
        # Steps are 1-indexed, so the final iteration is step == 1. No noise is
        # injected there; the old `t == 0` test could never be true for this
        # range, so the returned sample was always re-noised.
        if step > 1:
            z = torch.randn_like(x, device=self.device)
        else:
            z = torch.zeros_like(x, device=self.device)

        # One (identical) timestep per sample; schedules are 0-indexed.
        t = torch.full((n_samples,), step, dtype=torch.long, device=self.device)
        step_index = t - 1

        # Append three trailing axes so the per-sample coefficients broadcast
        # over the channel and spatial dimensions of x.
        alpha_t = self.alpha[step_index][..., None, None, None]
        alpha_bar_t = self.alpha_bar[step_index][..., None, None, None]
        beta_t = self.beta[step_index][..., None, None, None]
        eps_theta = self.function_approximator(x, step_index)

        noise_scale = (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)
        mean = 1 / torch.sqrt(alpha_t) * (x - noise_scale * eps_theta)
        sigma = torch.sqrt(beta_t)

        x = mean + sigma * z
        all_x.append(x)

    return all_x

In [None]:
# Generate 10 samples in parallel. `diffusion_model` is passed explicitly as
# `self` because `sampling` is a free function defined in this notebook.
# `imgs` is the list returned by `sampling`: one tensor per stored step, with
# the final samples in the last element.
imgs = sampling(diffusion_model, n_samples=10)

In [None]:
# The notebook-level `import matplotlib as plt` binds the bare package, which
# has no `imshow`; bind the pyplot interface here so the call below resolves.
import matplotlib.pyplot as plt

idx = 0  # which of the generated samples to display
T = -1   # which stored step to display; -1 selects the last entry of `imgs`

# Clamp to [0, 1], move to the CPU and drop the leading channel axis so a 2-D
# array is handed to imshow. The keyword is `cmap` (it was misspelled `cmpa`).
plt.imshow(imgs[T][idx].clip(0,1).cpu().numpy().squeeze(0), cmap='gray')

In [None]:
# One 2-D float frame per stored sampling step for sample `idx` (set in the
# plotting cell above), clamped to [0, 1].
imgs_np = [e[idx].squeeze(0).cpu().numpy().clip(0,1) for e in imgs]

# Convert to 8-bit explicitly before writing so imageio does not have to do an
# implicit lossy float -> uint8 conversion (and warn about it) for every frame.
frames_uint8 = [np.round(frame * 255).astype(np.uint8) for frame in imgs_np]
imageio.mimsave("movie.mp4", frames_uint8)