<a href="https://colab.research.google.com/github/durml91/MMath-Project/blob/duo-branch/Image_Diffusion_(working)/DiT%20Diffusion/DIT_V5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install diffrax
!pip install equinox
!pip install einops
!pip install optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting diffrax
  Downloading diffrax-0.3.1-py3-none-any.whl (140 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.4/140.4 KB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
Collecting equinox>=0.10.0
  Downloading equinox-0.10.1-py3-none-any.whl (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.7/108.7 KB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jaxtyping>=0.2.12
  Downloading jaxtyping-0.2.14-py3-none-any.whl (20 kB)
Collecting typeguard>=2.13.3
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, jaxtyping, equinox, diffrax
Successfully installed diffrax-0.3.1 equinox-0.10.1 jaxtyping-0.2.14 typeguard-2.13.3
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/co

In [2]:
import array
import functools as ft
import gzip
import os
import struct
import urllib.request

import diffrax as dfx  # https://github.com/patrick-kidger/diffrax
import einops  # https://github.com/arogozhnikov/einops
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

import equinox as eqx

In [3]:
key = jr.PRNGKey(2023)



In [162]:
t1 = 10.0

In [201]:
def cifar():
  from tensorflow.keras.datasets import cifar10
  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  
  set1 = jnp.array(x_train)
  set2 = jnp.array(x_test)

  lab1 = jnp.array(y_train)
  lab2 = jnp.array(y_test)

  im = jnp.concatenate((set1, set2))
  lab = jnp.concatenate((lab1, lab2))
  im = einops.rearrange(im, "n h w c -> n c h w")
 

  
  return im, lab

In [177]:
@eqx.filter_jit
def modulate(x, shift, scale):
    scale = jnp.expand_dims(scale, axis=1)       #scale is the value you mutliply your array by
    shift = jnp.expand_dims(shift, axis=1)       #shift is the value you move your array b                     
    return x * (1 + scale) + shift #1+scale to allow for zero scaling

    #e.g. by inputting (x, 0 ,0) you end up with x returned
    #this is just a very intuitive function!

In [176]:
@eqx.filter_jit
def get_2d_sincos_pos_embed(n_embd, grid_size):

    """grid_size: int value of grid height and width - we denote by (H, W)
       return pos_embed of dim (grid_size*grid_size, n_embd)"""
    

    grid_h = jnp.arange(grid_size, dtype=float)
    grid_w = jnp.arange(grid_size, dtype=float)
    grid = jnp.meshgrid(grid_w, grid_h)
    grid = jnp.stack(grid, axis=0)

    grid = jnp.reshape(grid, (2, 1, grid_size, grid_size))
    pos_embed = get_2d_sincos_pos_embed_from_grid(n_embd, grid)
    return pos_embed

#we basically end up with each patch embedding having a fixed n_embd dimensional sin/cos embedding vector - this is fixed! doesn't change based on sample but allows the NN to understand the "spatial" representation of the patches
##################################################################################

@eqx.filter_jit
def get_2d_sincos_pos_embed_from_grid(n_embd, grid):
    assert n_embd % 2 == 0

    emb_h = get_1d_sincos_pos_embed_from_grid(n_embd // 2, grid[0]) # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(n_embd // 2, grid[1]) # (H*W, D/2)

    emb = jnp.concatenate([emb_h, emb_w], axis=1)  #(H*W, D)
    return emb

#if we have a grid - think two rows of however many columns - we take each row and encode and get an positional embedding for each element in the row

##################################################################################

@eqx.filter_jit
def get_1d_sincos_pos_embed_from_grid(n_embd, pos):
    """n_embd is the output dimension for each position (D,)
       pos is an array of positions to be encoded of size (M,)
       size is of pos (M,D)"""
    
    assert n_embd % 2 == 0                
    omega = jnp.arange(n_embd // 2, dtype=float)   #generates an array from 0 to (n_embd // 2) - 1 of integer values    
    omega /= n_embd / 2         #divide by n_embd / 2 - can be thought of as normalising the values from to 0 to 1 i.e. uniform values
    omega = 1. / 10000**omega #(D/2) - so n_embd is D  
   
    pos = jnp.array(pos)
    out = jnp.outer(pos, omega)   #so we have pos^T * omega to make a matrix of dim (M, D/2)
    emb_sin = jnp.sin(out)   #dim (M, D/2)
    emb_cos = jnp.cos(out)   #dim (M, D/2)

    emb = jnp.concatenate([emb_sin, emb_cos], axis=1)  #dim (M, D) - join sin and cos column wise -> <-
    return emb


    #if we think of the input as a sequence of words and the respective n_embedding, then we are simply adding a positional embedding across each word in the sequence 

In [256]:
#############################################
##############                ###############
##############    DiT model   ###############
##############                ###############
#############################################

"""Diffusion models meet Transformers!"""


###########   Time embedding    #############



###### Define silu activation ######

from typing import Callable

class Lambda1(eqx.Module):
    fn: Callable
    
    def __call__(self, x, *, key=None):
        return self.fn(x)

###### Time embedding ######

class TimeStepEmbedder(eqx.Module):
    mlp: eqx.nn.Sequential
    frequency_embedding_size: int
    
    def __init__(
        self,
        hidden_size,
        frequency_embedding_size,   #set as 256
        key
    ):
        l1key, l2key = jr.split(key, 2)
        self.mlp = eqx.nn.Sequential([
            eqx.nn.Linear(frequency_embedding_size, hidden_size, key=l1key),
            Lambda1(jax.nn.silu),
            eqx.nn.Linear(hidden_size, hidden_size, key=l2key)
        ])
        self.frequency_embedding_size = frequency_embedding_size

    def __call__(self, t, max_period=10000):
        dim = self.frequency_embedding_size
        half = dim // 2
        freqs = jnp.exp(
            -jnp.log(max_period) * jnp.arange(0, half, dtype=float) / half
        )
        args = t[:, None].astype(float) * freqs[None]
        embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
        if dim % 2:
            embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1)
        t_freq = embedding
        t_emb = jax.vmap(self.mlp)(t_freq)
        return t_emb

###########  Label Embedding   #########

class LabelEmbedder(eqx.Module):
  embedding_table: eqx.nn.Embedding
  num_classes: int
  dropout_prob: float

  def __init__(self, num_classes, hidden_size, dropout_prob, key):
      key1 = jr.split(key, 1)[0]
      self.embedding_table = eqx.nn.Embedding(num_classes, hidden_size, key=key1)
      self.num_classes = num_classes
      self.dropout_prob = dropout_prob

  def __call__(self, labels):
      use_dropout = self.dropout_prob
      embeddings = self.embedding_table(labels)
      return embeddings


##########   Multi-Head Attention   #########


class MHA(eqx.Module):
    num_heads: int
    head_dim: int
    scale: int

    qkv: eqx.nn.Linear
    o_proj: eqx.nn.Linear



    def __init__(self, dim, key, num_heads=8, ):
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        key1, key2 = jr.split(key, 2)
        self.qkv = eqx.nn.Linear(dim, dim *3, key=key1)
        self.o_proj = eqx.nn.Linear(dim, dim, key=key2)

    def __call__(self, x, mask=None):
        B, N, C = x.shape
        qkv = jax.vmap(jax.vmap(self.qkv))(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.transpose(2, 0, 3, 1, 4)

        q, k, v =jnp.moveaxis(qkv, 0, 0)

        q = q * self.scale
        attn_logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
        if mask is not None:
            attn_logits = jnp.where(mask == 0, -9e15, attn_logits)
        attention = jax.nn.softmax(attn_logits, axis=-1)
        values = jnp.matmul(attention, v)
        values = einops.rearrange(values, "B C N H -> B N (C H)")

        values = jax.vmap(jax.vmap(self.o_proj))(values)
        return values


##############    DiT Block     ##############


class DitBlock(eqx.Module):
    norm1: eqx.nn.LayerNorm
    norm2: eqx.nn.LayerNorm
    attn: eqx.Module
    Mlp: eqx.nn.Sequential
    adaLN_modulation: eqx.nn.Sequential

    def __init__(
        self,
        hidden_size,
        n_head,
        mlp_ratio,   # = 4.0
        key,
    ):
        m1key, m2key, adakey, attkey = jr.split(key, 4)
        self.norm1 = eqx.nn.LayerNorm(hidden_size, eps = 1e-06, elementwise_affine=False)
        self.attn = MHA(hidden_size, key=attkey, num_heads=n_head,)
        self.norm2 = eqx.nn.LayerNorm(hidden_size, eps = 1e-06, elementwise_affine=False)
        mlp_hidden_size = int(hidden_size * mlp_ratio)
        self.Mlp = eqx.nn.Sequential([
            eqx.nn.Linear(hidden_size, mlp_hidden_size, key=m1key),
            Lambda1(jax.nn.gelu),
            eqx.nn.Linear(mlp_hidden_size, hidden_size, key=m2key) ])
        self.adaLN_modulation = eqx.nn.Sequential([
            Lambda1(jax.nn.silu),
            eqx.nn.Linear(hidden_size, 6 * hidden_size, key=adakey)
        ])

    def __call__(self, x, t):

        temp = jax.vmap(self.adaLN_modulation)(t)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.array_split(temp, 6, axis=1)
        
        
        gate_msa = jnp.expand_dims(gate_msa, axis=1)
        a = jax.vmap(self.norm1)(x)
        tem = modulate(a, shift_msa, scale_msa)
        x = x + gate_msa * self.attn(tem)
        
        
        gate_mlp = jnp.expand_dims(gate_mlp, axis=1)
        b = jax.vmap(self.norm2)(x)
        tems = modulate(b, shift_mlp, scale_mlp)     
        x = x + gate_mlp * jax.vmap(jax.vmap(self.Mlp))(tems)
        
        
        return x



#################   Final Layer   ################



class FinalLayer(eqx.Module):
    norm_final: eqx.nn.LayerNorm
    linear: eqx.nn.Linear
    adaLN_modulation: eqx.nn.Sequential
    
    def __init__(
        self,
        hidden_size,
        patch_size,
        out_channels,
        key
    ):
        lkey, adakey = jr.split(key, 2)
        self.norm_final = eqx.nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
        self.linear = eqx.nn.Linear(hidden_size, patch_size * patch_size * out_channels, key=lkey)
        self.adaLN_modulation = eqx.nn.Sequential([
            Lambda1(jax.nn.silu),
            eqx.nn.Linear(hidden_size, 2 * hidden_size, key=adakey)
        ])

    def __call__(self, x, t):
        temp = jax.vmap(self.adaLN_modulation)(t)
        shift, scale = jnp.array_split(temp, 2, axis=1)
        x = modulate(jax.vmap(self.norm_final)(x), shift, scale)
        x = jax.vmap(jax.vmap(self.linear))(x)
        return x


###########   Patch embedding   ##############

class PatchEmbed(eqx.Module):
    num_patches:int
    proj: eqx.nn.Conv2d
    patch_size: int

    def __init__(
        self,
        img_size,
        patch_size,
        in_chans,
        n_embd,
        key
    ):
        patkey, _ = jr.split(key,2)
        self.patch_size = patch_size
        dg = img_size // self.patch_size
        self.num_patches = dg ** 2
        self.proj = eqx.nn.Conv2d(in_chans, n_embd, self.patch_size, self.patch_size, key=patkey)

    def __call__(self, x):
        B, C, H, W = x.shape
        x = jnp.array(x, dtype=float)
        x = jax.vmap(self.proj)(x)
        x = einops.rearrange(x, "B C H W -> B (H W) C")
        return x



###########   Parameter module    ##########

class Params(eqx.Module):
    param: jnp.ndarray

    def __init__(self, num_patches, hidden_size):
        self.param = jnp.zeros((1, num_patches, hidden_size), dtype = float)

    def __call__(self):
        return self.param


##########    DiT   ##########


class DiT(eqx.Module):
    in_channels: int
    out_channels: int
    patch_size: int
    n_head: int
    t1: float

    x_embedder: eqx.Module
    t_embedder: eqx.Module
    y_embedder: eqx.Module
    pos_embed: eqx.Module
    blocks: list
    final_layer: eqx.Module

    def __init__(
        self,
        input_size=32,
        patch_size=4,
        in_channels=3,
        hidden_size=128,
        depth=2,  
        n_head=4,  
        mlp_ratio=4.0,  #fixed
        frequency_embedding_size=256,   #fixed
        class_dropout_prob=0.1,
        num_classes=10,
        *,
        key=key,
        
    ):
        xkey, tkey, flkey, embkey, *dbkeys = jr.split(key, 4 + depth)
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.n_head = n_head
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, key=xkey)
        self.t_embedder = TimeStepEmbedder(hidden_size, frequency_embedding_size, key=tkey)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, key=embkey)
        num_patches = self.x_embedder.num_patches
        self.pos_embed = Params(num_patches, hidden_size)
        self.blocks = [
            DitBlock(
                hidden_size, n_head, mlp_ratio, key = key
            )
            for dbkey in dbkeys                                   #_ in range(depth)           #*bkeys = jr.split(key, num_blocks)
        ]
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, key=flkey)

        pos_embed = get_2d_sincos_pos_embed(self.pos_embed().shape[-1], int(self.x_embedder.num_patches ** 0.5))
        pos_embed = jnp.array(pos_embed, dtype=float)
        pos_embed = jnp.expand_dims(pos_embed, axis=0)
        self.pos_embed = pos_embed.copy()

        self.t1 = t1
    def unpatchify(self, x):
        """
        x: (N, T, patch_size ** 2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels      
        p = self.x_embedder.patch_size 
        h = w = int(x.shape[1] ** 0.5)    
        x = jnp.reshape(x, (x.shape[0], h, w, p, p, c))
        x = einops.rearrange(x, "n h w p q c->n c h p w q")
        imgs = jnp.reshape(x, (x.shape[0], c, h * p, h * p))
        return imgs
    
    def __call__(self, x, t, y):
        #pos_embed = get_2d_sincos_pos_embed(self.pos_embed().shape[-1], int(self.x_embedder.num_patches ** 0.5))
        
        
        """
        x: (N, C, H, W)
        t: (N, )
        """
        #t = jnp.array([t], dtype=int)
        print(self.pos_embed)
        t = t/self.t1
        x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)   # (N, D)

        y = self.y_embedder(y)

        t = t+y
  
        for block in self.blocks:
            x = block(x,t)    # (N, T, D)
        x = self.final_layer(x, t)     # (N, T, patch_size ** 2 * out_channels) - N is the batch_size, T is the number of patches, 
        x = self.unpatchify(x)        # (N, out_channels, H, W)
        #x = jnp.squeeze(x, axis=0)
        return x

In [258]:
model = DiT()

In [202]:
images, labels = cifar()

In [203]:
eg = images[0:1]

In [204]:
eg.shape

(1, 3, 32, 32)

In [205]:
egl = labels[0:1]

In [None]:
egl=jnp.squeeze(egl, axis=1)

In [230]:
t= jnp.array([5], dtype=float)

In [264]:
model(eg, t, egl)

[[[ 0.          0.          0.         ...  1.          1.
    1.        ]
  [ 0.84147096  0.68156135  0.53316844 ...  1.          1.
    1.        ]
  [ 0.9092974   0.99748     0.9021307  ...  1.          1.
    1.        ]
  ...
  [-0.9589243  -0.5711271   0.32393527 ...  0.9999986   0.9999992
    0.9999996 ]
  [-0.2794155  -0.97739613 -0.23036751 ...  0.9999986   0.9999992
    0.9999996 ]
  [ 0.6569866  -0.8593135  -0.7137213  ...  0.9999986   0.9999992
    0.9999996 ]]]


Array([[[[-0.34200123,  0.2966537 , -0.12871987, ...,  0.2676334 ,
           0.40355986, -0.01633361],
         [-0.08845013,  0.36638242, -0.11904924, ...,  0.20065154,
          -0.37602353, -0.4027549 ],
         [ 1.1493251 ,  0.04737014,  0.44141626, ...,  0.45272228,
           0.21076703, -0.29202688],
         ...,
         [-0.5765127 ,  0.158071  , -0.5851695 , ...,  0.2035378 ,
          -0.31120777, -0.07466888],
         [ 1.9556712 ,  0.7470492 ,  0.2461924 , ...,  0.33781418,
           0.76516306,  0.39535636],
         [ 0.61308324,  0.82563716,  0.00530211, ...,  0.63807684,
          -0.26639882,  0.6060345 ]],

        [[ 0.30808845, -0.19852664,  0.31034914, ..., -0.42339617,
           0.3716297 ,  0.05619001],
         [ 0.14791155, -0.45064726, -0.4066413 , ..., -0.8313336 ,
          -0.6323137 , -0.05502474],
         [-0.39554676,  0.08149413,  0.140227  , ...,  0.33878013,
           0.4687193 , -0.5533159 ],
         ...,
         [ 0.49349555, -1.372494  