<a href="https://colab.research.google.com/github/ichiyan/BrainDecoding/blob/master/ddpm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# import os
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
import logging

# Configure root logging for the notebook. The empty filename is falsy, so
# basicConfig appears to fall back to its default stream (console) handler
# rather than opening a log file -- NOTE(review): confirm this is intended.
logging.basicConfig(
    filename='',
    format="%(asctime)s - %(levelname)s: %(message)s",
    level=logging.INFO,
    datefmt="%I:%M:%S",
)

Basic Diffusion (still in pixel space)

In [None]:
# TODO
#  - use image latent embeddings instead of raw images
#  - improve based on the fast.ai and LDM/SD (latent diffusion / Stable Diffusion) implementations
#  - add text cells for the formulas used (see notes)

class Diffusion:
  """Forward (noising) and reverse (sampling) diffusion process in pixel space.

  On construction the beta schedule is computed and moved to ``device``; the
  derived ``alpha`` (``1 - beta``) and ``alpha_bar`` (cumulative product of
  ``alpha``) tensors are cached alongside it. All three have length
  ``noise_steps`` and are indexed by timestep.
  """

  def __init__(self, schedule="linear", noise_steps=1000, beta_start=1e-4, beta_end=0.02, cosine_s=8e-3, img_size=64, device="cuda"):
    """Store the hyper-parameters and precompute the noise schedule.

    Args:
      schedule: "linear" or "cosine"; selects how the betas are generated.
      noise_steps: number of diffusion timesteps.
      beta_start: first beta of the linear schedule.
      beta_end: last beta of the linear schedule.
      cosine_s: offset added to the normalized timesteps of the cosine schedule.
      img_size: height and width of the square images produced by ``sample``.
      device: torch device on which the schedule tensors and samples live.

    Raises:
      ValueError: if ``schedule`` is not a supported name.
    """
    self.schedule = schedule
    self.noise_steps = noise_steps
    self.beta_start = beta_start
    self.beta_end = beta_end
    self.cosine_s = cosine_s
    self.img_size = img_size
    self.device = device

    self.beta = self.prepare_noise_schedule().to(device)
    self.alpha = 1. - self.beta
    self.alpha_bar = torch.cumprod(self.alpha, dim=0)

  # TODO: add dtype=torch.float64 ?
  def prepare_noise_schedule(self):
    """Return a 1-D tensor of ``noise_steps`` betas for ``self.schedule``.

    Raises:
      ValueError: if ``self.schedule`` is neither "linear" nor "cosine".
    """
    if self.schedule == "linear":
      return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    if self.schedule == "cosine":
      # noise_steps + 1 grid points, so the consecutive ratios below yield
      # exactly noise_steps betas
      timesteps = (
          torch.arange(self.noise_steps + 1) / self.noise_steps + self.cosine_s
      )

      alphas = timesteps / (1 + self.cosine_s) * np.pi / 2
      alphas = torch.cos(alphas).pow(2)
      alphas = alphas / alphas[0]
      betas = 1 - alphas[1:] / alphas[:-1]
      # clamp with torch (not numpy) so the result stays a torch tensor
      return torch.clamp(betas, min=0, max=0.999)

    raise ValueError(
        f"Unknown noise schedule {self.schedule!r}; expected 'linear' or 'cosine'"
    )

  # adding noise in one single step instead of adding noise iteratively
  def noise_images(self, x, t):
    """Noise ``x`` directly to timestep(s) ``t``.

    Args:
      x: batch of images; the broadcasting below expects a 4-D tensor.
      t: 1-D integer tensor of timestep indices, one per image in ``x``.

    Returns:
      Tuple ``(noised_x, epsilon)`` where ``epsilon`` is the standard-normal
      noise that was mixed into ``x``.
    """
    sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t])[:, None, None, None]
    sqrt_one_minus_alpha_bar = torch.sqrt(1. - self.alpha_bar[t])[:, None, None, None]
    epsilon = torch.randn_like(x)
    return sqrt_alpha_bar * x + sqrt_one_minus_alpha_bar * epsilon, epsilon

  def sample_timesteps(self, n):
    """Return ``n`` random integer timesteps drawn from ``[1, noise_steps)``.

    The tensor is created without an explicit device, so callers that need it
    on ``self.device`` must move it themselves.
    """
    return torch.randint(low=1, high=self.noise_steps, size=(n,))

  def sample(self, model, n):
    """Generate ``n`` images by iteratively denoising pure Gaussian noise.

    Args:
      model: callable invoked as ``model(x, t)`` that returns the predicted
        noise for the batch ``x`` at timesteps ``t``. It is switched to eval
        mode for sampling and left in train mode afterwards.
      n: number of images to generate.

    Returns:
      ``torch.uint8`` tensor with values in ``[0, 255]``; the initial noise has
      shape ``(n, 3, img_size, img_size)``.
    """
    logging.info(f"Sampling {n} new images...")
    model.eval()
    with torch.no_grad():
      # create initial images by sampling from the normal distribution
      x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
      # denoising loop, from timestep noise_steps - 1 down to 1; total is
      # passed explicitly because the reversed range iterator has no len()
      for i in tqdm(reversed(range(1, self.noise_steps)), position=0, total=self.noise_steps - 1):
        # creating timestep (tensor of length n with the current timestep)
        t = torch.full((n,), i, dtype=torch.long, device=self.device)
        predicted_noise = model(x, t)
        alpha = self.alpha[t][:, None, None, None]
        alpha_bar = self.alpha_bar[t][:, None, None, None]
        beta = self.beta[t][:, None, None, None]

        # no fresh noise is injected on the final step
        if i > 1:
          noise = torch.randn_like(x)
        else:
          noise = torch.zeros_like(x)

        # remove portion of noise in image
        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_bar))) * predicted_noise) + torch.sqrt(beta) * noise

      model.train()
      # clip values and then bring values back to 0-1
      x = (x.clamp(-1, 1) + 1) / 2
      # convert to valid pixel range and change data type for saving
      x = (x * 255).type(torch.uint8)

      return x





  


UNet Model 

In [None]:
class UNet(nn.Module):
  """Placeholder for the UNet model.

  Per the original note, the actual implementation is kept in a separate
  notebook; this stub adds nothing on top of ``nn.Module``.
  """
