In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from keras.datasets.mnist import load_data
from tqdm import tqdm

from diffusion_model import DiffusionModel
from unet import UNet


In [13]:
# MNIST digits scaled from uint8 [0, 255] to float32 [0, 1].
(trainX, trainy), (testX, testy) = load_data()
trainX = np.float32(trainX) / 255.
testX = np.float32(testX) / 255.

def sample_batch(batch_size, device):
    """Draw a random batch of MNIST training digits upsampled to 32x32.

    Args:
        batch_size: number of images to draw (any value >= 1).
        device: torch device the batch is moved to.

    Returns:
        Float tensor of shape (batch_size, 1, 32, 32) with values in [0, 1]
        (nearest-neighbour upsampling, the ``interpolate`` default).
    """
    # Index with a NumPy array, not a torch tensor: NumPy treats a one-element
    # LongTensor as a scalar index via __index__, which drops the batch axis and
    # made batch_size=1 fail.
    indices = torch.randperm(trainX.shape[0])[:batch_size].numpy()
    data = torch.from_numpy(trainX[indices]).unsqueeze(1).to(device)
    return torch.nn.functional.interpolate(data, 32)

In [14]:
# Smoke test: a freshly initialised UNet maps a (B, 1, 32, 32) batch plus one
# integer timestep per image to a tensor of the same shape. Dedicated names keep
# this throw-away network from shadowing the training `model`.
smoke_t = (torch.rand(10) * 10).long()
smoke_img = torch.randn((10, 1, 32, 32))
with torch.no_grad():  # shape check only; no autograd graph needed
    smoke_out = UNet()(smoke_img, smoke_t)
assert smoke_out.shape == smoke_img.shape, smoke_out.shape
smoke_out.shape

torch.Size([10, 1, 32, 32])

In [None]:
# Use the best available accelerator instead of hard-coding a single backend.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
batch_size = 64
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
# First argument: presumably the number of diffusion steps T — TODO confirm
# against DiffusionModel.__init__.
diffusion_model = DiffusionModel(1000, model, device)



In [None]:
# Training loop: each iteration calls diffusion_model.training_step once
# (presumably one optimizer step on one batch), so "epoch" counts iterations,
# not passes over the dataset.
n_iterations = 4000
loss_plot_every = 100    # refresh the loss-curve PNGs this often
sample_every = 1000      # draw and save a sample grid this often
nb_images = 81           # samples per grid (9 x 9)


def save_loss_curves(losses, crop=1000):
    """Save the full loss history and its last ``crop`` values as two PNG files."""
    for values, path in ((losses, 'training_loss.png'),
                         (losses[-crop:], 'training_loss_cropped.png')):
        fig, ax = plt.subplots()
        ax.plot(values)
        fig.savefig(path)
        plt.close(fig)


def save_sample_grid(samples, n_images, path, n_cols=9):
    """Plot the first ``n_images`` samples on a grid, save it to ``path``, then display it."""
    n_rows = math.ceil(n_images / n_cols)
    fig = plt.figure(figsize=(17, 17))
    for i in range(n_images):
        ax = fig.add_subplot(n_rows, n_cols, 1 + i)
        ax.axis('off')
        ax.imshow(samples[i].squeeze(0).clip(0, 1).data.cpu().numpy(), cmap='gray')
    # Save BEFORE showing: with the notebook inline backend plt.show() closes the
    # figure, so a savefig called afterwards writes a new, empty figure.
    fig.savefig(path)
    plt.show()
    plt.close(fig)


training_loss = []
for epoch in tqdm(range(n_iterations)):
    loss = diffusion_model.training_step(batch_size, optimizer)
    # NOTE(review): if training_step returns a tensor still attached to the
    # autograd graph, storing it keeps every graph alive — confirm it is a float.
    training_loss.append(loss)

    if epoch % loss_plot_every == 0:
        save_loss_curves(training_loss)
    if epoch % sample_every == 0:
        samples = diffusion_model.sampling(n_samples=nb_images, use_tqdm=True)
        save_sample_grid(samples, nb_images, f"sample_epoch_{epoch}.png")

# Whole-module checkpoint (reloaded later with torch.load). Module.cpu() moves
# the model in place, so move it back to the training device afterwards.
torch.save(model.cpu(), f'model_paper2_epoch_{n_iterations - 1}')
model.to(device)


In [None]:
# Reload MNIST (uint8 [0, 255] -> float32 [0, 1]); presumably repeated so the
# inference section below can run without re-executing the training section.
(trainX, trainy), (testX, testy) = load_data()
trainX = np.float32(trainX) / 255.
testX = np.float32(testX) / 255.

def sample_batch(batch_size, device):
    """Draw a random batch of MNIST training digits upsampled to 32x32.

    Args:
        batch_size: number of images to draw (any value >= 1).
        device: torch device the batch is moved to.

    Returns:
        Float tensor of shape (batch_size, 1, 32, 32) with values in [0, 1]
        (nearest-neighbour upsampling, the ``interpolate`` default).
    """
    # Index with a NumPy array, not a torch tensor: NumPy treats a one-element
    # LongTensor as a scalar index via __index__, which drops the batch axis.
    # That, not interpolate, is why batch_size=1 used to fail.
    indices = torch.randperm(trainX.shape[0])[:batch_size].numpy()
    data = torch.from_numpy(trainX[indices]).unsqueeze(1).to(device)
    return torch.nn.functional.interpolate(data, 32)

In [None]:
# Use the best available accelerator instead of hard-coding a single backend.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
batch_size = 64
# The checkpoint is a whole pickled nn.Module (written by torch.save(model.cpu(), ...)),
# so weights_only=False is required on torch >= 2.6, where the default became True.
# Unpickling can execute arbitrary code: only load checkpoints you produced yourself.
model = torch.load("model_paper2_epoch_3999", map_location=device, weights_only=False).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
# First argument: presumably the number of diffusion steps T — TODO confirm.
diffusion_model = DiffusionModel(1000, model, device)

In [None]:
def forward(diffusion_model, T, x0):
    """Run ``T`` steps of the DDPM forward (noising) Markov chain from ``x0``.

    Loop index ``t`` produces the next state as
    ``sqrt(1 - beta[t]) * x + sqrt(beta[t]) * eps`` with ``eps ~ N(0, I)``.
    The ``sqrt(1 - beta)`` shrink makes the chain variance-preserving; without
    it the variance would grow by ``beta[t]`` at every step.

    Args:
        diffusion_model: any object exposing an indexable ``beta`` schedule
            with at least ``T`` entries.
        T: number of noising steps.
        x0: starting tensor.

    Returns:
        List ``[x0, x_1, ..., x_T]`` of length ``T + 1``.
    """
    trajectory = [x0]
    x = x0
    for t in range(T):
        beta_t = torch.as_tensor(diffusion_model.beta[t], dtype=x.dtype, device=x.device)
        x = torch.sqrt(1 - beta_t) * x + torch.sqrt(beta_t) * torch.randn_like(x)
        trajectory.append(x)
    return trajectory


In [None]:
# Noise a fresh batch for 500 forward steps and show the last (most noised) state.
# forward returns the whole trajectory [x0, x_1, ..., x_500], so take [-1].
x0_demo = sample_batch(2, device)
x_forward = forward(diffusion_model, 500, x0_demo)
xT = x_forward[-1]
plt.imshow(xT[0].squeeze(0).clip(0, 1).data.cpu().numpy(), cmap='gray')

In [None]:
@torch.no_grad()
def inpainting(diffusion_model, x0, T, device, mask):
    """Regenerate the masked pixels of ``x0`` with the reverse diffusion chain.

    The whole image is first forward-noised to step ``T``. Each reverse step
    produces:

    - the masked pixels, from the ancestral (DDPM) sample;
    - the known pixels, from ``x0`` re-noised to the *same* step.

    This way the network always sees an input with one consistent noise level
    (RePaint-style, without resampling). ``x0`` itself is never modified.

    Schedule indexing: step ``t`` (1..T) reads ``alpha[t-1]``,
    ``alpha_bar[t-1]`` and ``beta[t-1]``, and the network is queried at
    timestep ``t - 1``.

    Args:
        diffusion_model: exposes ``alpha``, ``alpha_bar``, ``beta`` (indexable
            schedules) and ``function_approximator(x, t)``. ``.to(device)`` is
            called on it.
        x0: clean images, tensor or array-like; presumably (B, 1, H, W).
        T: number of reverse steps; requires ``T <= len(diffusion_model.alpha_bar)``.
            If T is shorter than the schedule, the masked region starts from a
            noised copy of the true pixels rather than from pure noise.
        device: torch device for all computation.
        mask: same shape as ``x0``; nonzero entries mark the pixels to generate.

    Returns:
        List of ``T + 1`` distinct tensors ``[x_T, x_{T-1}, ..., x_0]``; the
        last entry is the inpainted result.
    """
    diffusion_model = diffusion_model.to(device)
    # as_tensor avoids copy-constructing an existing tensor (torch.tensor(tensor) warns).
    x0 = torch.as_tensor(x0, device=device)
    mask = torch.as_tensor(mask, device=device).bool()

    def _coef(schedule, index):
        # One schedule entry shaped (1, 1, 1) so it broadcasts over each image.
        return torch.as_tensor(schedule[index], device=device).reshape(1, 1, 1)

    def _noised_x0(step):
        # Sample q(x_step | x0) for a 1-indexed step; step 0 is x0 itself.
        if step == 0:
            return x0
        alpha_bar = _coef(diffusion_model.alpha_bar, step - 1)
        return torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * torch.randn_like(x0)

    # Step T reads alpha_bar[T - 1], matching the loop below (was alpha_bar[T]).
    x = _noised_x0(T)
    inpainting_sample = [x]

    for t in tqdm(range(T, 0, -1)):
        alpha_t = _coef(diffusion_model.alpha, t - 1)
        alpha_bar_t = _coef(diffusion_model.alpha_bar, t - 1)
        beta_t = _coef(diffusion_model.beta, t - 1)
        t_tensor = torch.full((x.shape[0],), t - 1, dtype=torch.float32, device=device)
        eps_theta = diffusion_model.function_approximator(x, t_tensor)

        mean = (x - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * eps_theta) / torch.sqrt(alpha_t)
        # No noise on the final step (t == 1): its mean is the estimate of x_0.
        # The old `t == 0` check could never fire inside range(T, 0, -1).
        z = torch.randn_like(x) if t > 1 else torch.zeros_like(x)
        x_generated = mean + torch.sqrt(beta_t) * z

        # Build a new tensor each step. Previously x aliased x0 and was written
        # in place, so every stored entry was the same tensor.
        x = torch.where(mask, x_generated, _noised_x0(t - 1))
        inpainting_sample.append(x)
    return inpainting_sample

In [None]:
# Use a dedicated name so the training hyper-parameter `batch_size` is not
# silently overwritten (re-running the training cell would otherwise use 2).
n_inpaint_images = 2
x0 = sample_batch(n_inpaint_images, device)
plt.imshow(x0[0].squeeze(0).clip(0, 1).data.cpu().numpy(), cmap='gray')

In [None]:
# mask == 1 marks the pixels to regenerate: the first 16 columns (left half of
# a 32-wide image). x0_mask shows only the pixels the model is allowed to keep.
mask = torch.zeros_like(x0)
mask[..., :16] = 1
x0_mask = (1 - mask) * x0
plt.imshow(x0_mask[0].squeeze(0).clip(0, 1).data.cpu().numpy(), cmap='gray')

In [None]:
# Inpaint the masked half over 500 reverse steps. inpainting returns the whole
# trajectory [x_500, ..., x_0] (501 entries), so the final image is [-1], not [499].
# Starting at step 500 means the masked region begins from a noised copy of the
# true pixels rather than pure noise; use the full schedule for blind inpainting.
inpaint_steps = inpainting(diffusion_model, x0=x0, T=500, device=device, mask=mask)
x_inpainted = inpaint_steps[-1]
plt.imshow(x_inpainted[0].squeeze(0).clip(0, 1).data.cpu().numpy(), cmap='gray')