In [1]:
%%writefile models/resnet.py
import torch
from torch import nn
from torch.nn import functional as F

class ResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, dropout: float=0.0):
        super().__init__()

        self.groupnorm_1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        
        self.groupnorm_2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if in_channels != out_channels:
            self.proj_input = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        else:
            self.proj_input = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_norm = self.groupnorm_1(x)
        out = F.silu(x_norm)
        out = self.conv_1(out)

        out = self.groupnorm_2(out)
        out = F.silu(out)
        out = self.dropout(out)
        out = self.conv_2(out)
        
        return out + self.proj_input(x)
        

Overwriting models/resnet.py


In [2]:
%%writefile models/vae.py
import torch
from torch import nn
from torch.nn import functional as F
from .resnet import ResidualBlock
from .attention import MultiheadSelfAttention
from typing import List

class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.pad = (0, 1, 0, 1)
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, self.pad)
        x = self.conv(x)
        return x

class UpSample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.upsample(x)
        x = self.conv(x)
        return x
        
class AttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.attn = MultiheadSelfAttention(num_heads=1, embedding_dim=in_channels)

    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n, c, h, w)
        batch_size, channels, h, w = x.shape
        x_norm = self.groupnorm(x)
        # (n, c, h, w) -> (n, c, h * w) -> (n, h * w, c)
        x_norm = x_norm.view((batch_size, channels, -1)).transpose(1, 2)

        # (n, h * w, c)
        out = self.attn(x=x_norm)

        # (n, h * w, c) -> (n, c, h * w) -> (n, c, h, w)
        out = out.transpose(1, 2).reshape(x.shape)
        
        return out + x
        
class VAE_Encoder(nn.Module):
    def __init__(self, in_channels: int, ch_mult: List[int]=[1, 2, 4, 8], dropout: float=0.0, z_channels: int=8):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, 128, kernel_size=3, stride=1, padding=1)

        # start
        self.down = nn.ModuleList()
        in_ch_mult = [1] + ch_mult
        ch = 128
        for i in range(len(ch_mult)):
            block_in = ch * in_ch_mult[i]
            block_out = ch * ch_mult[i]
            block = nn.Sequential(
                ResidualBlock(block_in, block_out, dropout),
                ResidualBlock(block_out, block_out, dropout),
            )
            
            down = nn.Module()
            down.block = block
            if i != len(ch_mult) - 1:
                down.downsample = Downsample(block_out)
            else:
                down.downsample = nn.Identity()
                
            self.down.append(down)
            curr_channels = block_out

        # middle
        self.mid = nn.Module()
        self.mid.res_block_1 = ResidualBlock(curr_channels, curr_channels)
        self.mid.attn_block_1 = AttentionBlock(in_channels=curr_channels)
        self.mid.res_block_2 = ResidualBlock(curr_channels, curr_channels)
        
        # end
        self.out = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=curr_channels),
            nn.SiLU(),
            nn.Conv2d(curr_channels, 2*z_channels, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(2*z_channels, 2*z_channels, kernel_size=1, stride=1, padding=0))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_in(x)
        
        for down in self.down:
            x = down.block(x)
            x = down.downsample(x)

        x = self.mid.res_block_1(x)
        x = self.mid.attn_block_1(x)
        x = self.mid.res_block_2(x)

        x = self.out(x)
        return x


class VAE_Decoder(nn.Module):
    def __init__(self, ch_mult: List[int]=[1, 2, 4, 8], dropout: float=0.0, z_channels: int=8):
        super().__init__()

        ch = 128
        block_in = ch*ch_mult[-1]
        self.conv_in = nn.Sequential(nn.Conv2d(z_channels, z_channels, kernel_size=1, padding=0), 
                                     nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1))
        
    
        # mid
        self.mid = nn.Module()
        self.mid.res_block_1 = ResidualBlock(block_in, block_in)
        self.mid.attn_block_1 = AttentionBlock(in_channels=block_in)
        self.mid.res_block_2 = ResidualBlock(block_in, block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i in reversed(range(len(ch_mult))):
            block_out = ch * ch_mult[i]
            block = nn.Sequential(
                ResidualBlock(block_in, block_out),
                ResidualBlock(block_out, block_out),
                ResidualBlock(block_out, block_out)
            )
            up = nn.Module()
            up.block = block
            if i != 0:
                up.upsample = UpSample(in_channels=block_out)
            else:
                up.upsample = nn.Identity()
            self.up.append(up)
            block_in = block_out

        self.out = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=ch), 
            nn.SiLU(),
            nn.Conv2d(ch, 3, kernel_size=3, stride=1, padding=1))

        
            
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, channels, h, w)
        x /= 0.18215
        x = self.conv_in(x)

        for up in self.up:
            x = up.block(x)
            x = up.upsample(x)

        out = self.out(x)
        return out
        
        
class VAE(nn.Module):
    def __init__(self, in_channels: int=3, z_channels: int=8):
        super().__init__()
        self.encoder = VAE_Encoder(in_channels=in_channels)
        self.decoder = VAE_Decoder(z_channels=z_channels)

    def encode(self, x: torch.Tensor, noise=None) -> torch.Tensor:
        # z: (n, c, h, w)
        z = self.encoder(x)
        mean, log_variance = z.chunk(2, dim=1)
        log_variance = torch.clamp(log_variance, -20, 30)
        variance = log_variance.exp()
        stdev = torch.sqrt(variance)
        if noise:
            return mean + stdev * noise
        else:
            return mean + stdev * torch.randn_like(stdev)
        

    def decode(self, z: torch.Tensor):
        return self.decoder(z)   

Overwriting models/vae.py


In [3]:
%%writefile models/cond_encoder.py
import torch
from torch import nn
from .attention import MultiheadSelfAttention
from .activation_fn import QuickGELU

class TextEncoder(nn.Module):
    def __init__(self, n_vocab: int=49408, embed_dim: int=768, max_len: int=77):
        super().__init__()
        self.text_embedding = TextEmbedding(n_vocab=n_vocab, embed_dim=embed_dim, max_len=max_len)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoder(num_heads=12, embed_dim=embed_dim, ffn_dim=embed_dim*8) for _ in range(12)
        ])
        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ## TODO: Padding text before putting into embedding
        
        x = self.text_embedding(x)

        for layer in self.encoder_layers:
            x = layer(x)

        x = self.layernorm(x)
        return x
        
class TextEmbedding(nn.Module):
    def __init__(self, n_vocab: int, embed_dim: int, max_len: int):
        super().__init__()
        self.embedding = nn.Embedding(n_vocab, embed_dim)
        self.positional_encoding = nn.Parameter(torch.zeros(max_len, embed_dim), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)
        x += self.positional_encoding
        return x
        
class TransformerEncoder(nn.Module):
    def __init__(self, num_heads: int, embed_dim: int, ffn_dim: int, dropout: float=0.0):
        super().__init__()
        self.attn_1 = MultiheadSelfAttention(num_heads=num_heads, embedding_dim=embed_dim)
        self.dropout_1 = nn.Dropout(dropout)
        self.layernorm_1 = nn.LayerNorm(embed_dim)

        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            QuickGELU(),
            nn.Linear(ffn_dim, embed_dim)
        )
        self.dropout_2 = nn.LayerNorm(embed_dim)
        self.layernorm_2 = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.LongTensor) -> torch.FloatTensor:
        x = x.type(torch.long)
        skip_connection = x
        x = self.attn_1(x=x.type(torch.float), lookahead_mask=True)
        x = self.dropout_1(x)
        x = self.layernorm_1(x + skip_connection)

        skip_connection = x
        x = self.ffn(x)
        x = self.dropout_2(x)
        output = self.layernorm_2(x + skip_connection)
        return output

Overwriting models/cond_encoder.py


In [4]:
%%writefile models/attention.py
import torch
from torch import nn
import math
import logging
from typing import Optional

class MultiheadSelfAttention(nn.Module):
    def __init__(self, num_heads: int, embedding_dim: int, cond_dim: int=768, use_bias=True):
        super().__init__()
        
        if not cond_dim:
            cond_dim = embedding_dim
            
        self.proj_q = nn.Linear(embedding_dim, embedding_dim, bias=use_bias)
        self.proj_k = nn.Linear(cond_dim, embedding_dim, bias=use_bias)
        self.proj_v = nn.Linear(cond_dim, embedding_dim, bias=use_bias)
        self.num_heads = num_heads
        self.head_dim = embedding_dim // self.num_heads
        self.proj_out = nn.Linear(embedding_dim, embedding_dim, bias=use_bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor=None, lookahead_mask: bool=True) -> torch.Tensor:
        # x: (n, seq_len, embedding_dim)
        # cond: (n, seq_len, dim)
        
        batch_size, seq_len, embedding_dim = x.shape
       
        if cond is None:
            cond = x
            
        q = self.proj_q(x)
        k = self.proj_k(cond)
        v = self.proj_v(cond)
            
        # (batch_size, seq_len, embedding_dim) -> (n, seq_len, num_heads, head_dim) -> (n, num_heads, seq_len, head_dim)
        q = q.view(*q.shape[:2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)

       

        # (n, num_heads, seq_len, head_dim) @ (n, num_heads, head_dim, seq_len) -> (n, seq_len, seq_len, seq_len)
        attn_weights = q @ k.transpose(-1, -2)
        if lookahead_mask:
            mask = torch.ones_like(attn_weights, dtype=torch.bool).triu(1)
            attn_weights.masked_fill_(mask, -torch.inf)
            
        attn_weights /= math.sqrt(self.head_dim)
        attn_weights = torch.softmax(attn_weights, dim=-1)

        # (n, num_heads, seq_len, seq_len) @ (n, num_heads, seq_len, head_dim) -> (n, num_heads, seq_len, head_dim)
        attn_weights = attn_weights @ v

        # (n, num_heads, seq_len, head_dim) -> (n, seq_len, num_heads, head_dim) -> (n, seq_len, embedding_dim)
        attn_weights = attn_weights.transpose(1, 2).reshape((batch_size, seq_len, embedding_dim))

        out = self.proj_out(attn_weights)

        return out

Overwriting models/attention.py


In [5]:
%%writefile models/unet.py
import torch
from torch import nn
from torch.nn import functional as F
from .attention import MultiheadSelfAttention
from .activation_fn import GeGELU
from typing import Optional, List

class UNet_TransformerEncoder(nn.Module):
    def __init__(self, num_heads: int, embedding_dim: int, cond_dim: int=768):
        super().__init__()
        channels = embedding_dim * num_heads
        self.groupnorm = nn.GroupNorm(32, channels)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

        self.transformer_block = UNet_AttentionBlock(num_heads=num_heads, embedding_dim=channels, cond_dim=cond_dim)

        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor, cond: torch.Tensor=None) -> torch.Tensor:
        # x: (b, c, h, w)
        b, c, h, w = x.shape

        x_in = x

        x = self.groupnorm(x)
        x = self.conv_input(x)

        # (b, c, h, w) -> (b, c, h * w) -> (b, h * w, c)
        x = x.view(b, c, -1).transpose(-1, -2)

        x = self.transformer_block(x=x, cond=cond)

        x = x.transpose(-1, -2).view(b, c, h, w)

        x = self.conv_output(x)

        return x + x_in
        
class UNet_AttentionBlock(nn.Module):
    def __init__(self, num_heads: int, embedding_dim: int, cond_dim: int=768):
        super().__init__()
        
        if embedding_dim % num_heads:
            raise ValueError('Number of heads must be divisible by Embedding Dimension')
            
        self.head_dim = embedding_dim // num_heads

        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.attn1 = MultiheadSelfAttention(num_heads=num_heads, embedding_dim=embedding_dim, cond_dim=cond_dim, use_bias=False)
        
        self.layer_norm2 = nn.LayerNorm(embedding_dim)
        self.attn2 = MultiheadSelfAttention(num_heads=num_heads, embedding_dim=embedding_dim, cond_dim=cond_dim, use_bias=False)

        self.layer_norm3 = nn.LayerNorm(embedding_dim)
        self.ff = nn.Sequential(
            GeGELU(embedding_dim, embedding_dim * 4),
            nn.Linear(embedding_dim * 4, embedding_dim))
        

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        x = self.attn1(self.layer_norm1(x), cond=cond) + x

        x = self.attn2(self.layer_norm2(x), cond=cond) + x

        x = self.ff(self.layer_norm3(x)) + x
        
        return x
        

class UNet_ResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, t_embed_dim: int):
            super().__init__()
            
            self.groupnorm_1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
            self.conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

            self.groupnorm_2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)
            self.conv_2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
            
            self.t_embed = nn.Linear(t_embed_dim, out_channels)
            
            if in_channels == out_channels:
                self.proj_input = nn.Identity()
            else:
                self.proj_input = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor, t_embed: torch.Tensor) -> torch.Tensor:
        # x: (n, c, h, w)
        h = self.groupnorm_1(x)
        h = F.silu(h)
        h = self.conv_1(h)

        # time: (1, t_embed_dim) -> (1, out_channels)
        time = F.silu(t_embed)
        time = self.t_embed(t_embed)

        # (n, out_channels, h, w) + (1, out_channels, 1, 1) -> (n, out_channels, h, w)
        h = h + time[:, :, None, None]

        h = self.groupnorm_2(h)
        h = F.silu(h)
        h = self.conv_2(h)
        return h + self.proj_input(x)

class TimeEmbedding(nn.Module):
    def __init__(self, t_embed_dim: int=320):
        super().__init__()
        self.t_embed_dim = t_embed_dim
        self.ffn = nn.Sequential(
            # (1, 320) -> (1, 1280)
            nn.Linear(t_embed_dim, t_embed_dim * 4),
            nn.SiLU(),
            # (1, 1280) -> (1, 1280)
            nn.Linear(t_embed_dim * 4,  t_embed_dim * 4))

    def _get_time_embedding(self, timestep):
        half = self.t_embed_dim // 2
        freqs = torch.pow(1000, -torch.arange(0, half, dtype=torch.float32)/half)
        x = torch.tensor([timestep], dtype=torch.float32, device=timestep.device)[None, :] * freqs[None, :].to(timestep.device)
        return torch.cat([torch.cos(x), torch.sin(x)], dim=1)
            
    def forward(self, timestep: int) -> torch.Tensor:
        t_embed = self._get_time_embedding(timestep)
        print(t_embed.device)
        return self.ffn(t_embed)

class TimeStepSequential(nn.Sequential):
    def forward(self, x: torch.Tensor, t_embed: torch.Tensor, cond=None) -> torch.Tensor:
        for layer in self:
            if isinstance(layer, UNet_ResBlock):
                x = layer(x, t_embed)
            elif isinstance(layer, UNet_TransformerEncoder):
                x = layer(x, cond)
            else:
                x = layer(x)
        return x
        
class UNet_Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)

class UNet_Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(self.upsample(x))
    
class UNet_Encoder(nn.Module):
    def __init__(self, in_channels: int=8, num_heads: int=8, t_embed_dim: int=1280, cond_dim: int=768, ch_multiplier=[1, 2, 4, 4]):
        super().__init__()
        ch = 320
        
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
        
        self.down = nn.ModuleList()
        in_ch_multiplier = [1] + ch_multiplier
        
        for i in range(len(ch_multiplier)):
            down = nn.Module()
            in_channels = ch * in_ch_multiplier[i]
            out_channels = ch * ch_multiplier[i]
            block = TimeStepSequential(
                UNet_ResBlock(in_channels, out_channels, t_embed_dim), 
                UNet_TransformerEncoder(num_heads=num_heads, embedding_dim=out_channels // num_heads, cond_dim=cond_dim),
                UNet_ResBlock(out_channels, out_channels, t_embed_dim), 
                UNet_TransformerEncoder(num_heads=num_heads, embedding_dim=out_channels // num_heads, cond_dim=cond_dim)
            )
            if i != len(ch_multiplier) - 1:
                downsample = UNet_Downsample(out_channels)
            else:
                downsample = nn.Identity()
            
            down.block = block
            down.downsample = downsample
            
            self.down.append(down)
            
    def forward(self, x: torch.Tensor, t_embed: torch.Tensor, cond: Optional[torch.Tensor]) -> torch.Tensor:
        
        x = self.conv_in(x)
        skip_connections = [x]
        for down in self.down:
            x = down.block(x, t_embed, cond)
            skip_connections.append(x)
            x = down.downsample(x)
            
        return x, skip_connections

class UNet_Decoder(nn.Module):
    def __init__(self, num_heads: int=8, t_embed_dim: int=1280, cond_dim: int=768, ch_multiplier=[1, 2, 4, 4]):
        super().__init__()
        ch = 320
        in_ch_multiplier = [1] + ch_multiplier
        
        self.up = nn.ModuleList()
        for i in reversed(range(4)):
            up = nn.Module()
            in_ch = in_ch_multiplier[i+1] * ch
            out_ch = in_ch_multiplier[i] * ch
            block = TimeStepSequential(
                UNet_ResBlock(in_ch * 2, out_ch, t_embed_dim), 
                UNet_TransformerEncoder(num_heads=num_heads, embedding_dim=out_ch // num_heads, cond_dim=cond_dim),
                UNet_ResBlock(out_ch, out_ch, t_embed_dim), 
                UNet_TransformerEncoder(num_heads=num_heads, embedding_dim=out_ch // num_heads, cond_dim=cond_dim),
                UNet_ResBlock(out_ch, out_ch, t_embed_dim), 
                UNet_TransformerEncoder(num_heads=num_heads, embedding_dim=out_ch // num_heads, cond_dim=cond_dim)
            )
            
            if i != 0:
                upsample = UNet_Upsample(out_ch)
            else:
                upsample = nn.Identity()

            up.block = block
            up.upsample = upsample

            self.up.append(up)

    def forward(self, x: torch.Tensor, skip_connections: List[torch.Tensor], t_embed: torch.Tensor, cond: Optional[torch.Tensor]) -> torch.Tensor:
        # x: (b, c, h, w)
        for up in self.up:
            x = torch.cat([x, skip_connections.pop()], dim=1)
            x = up.block(x, t_embed, cond)
            x = up.upsample(x)
        return x

class UNet(nn.Module):
    def __init__(self, in_channels: int=8, out_channels: int=8, num_heads: int=8, t_embed_dim: int=320, cond_dim: int=768):
        super().__init__()
        self.time_embedding = TimeEmbedding(t_embed_dim)
        self.encoder = UNet_Encoder(in_channels=in_channels, num_heads=num_heads, t_embed_dim=t_embed_dim * 4, cond_dim=cond_dim)
        self.bottle_neck = TimeStepSequential(
            UNet_ResBlock(1280, 1280, t_embed_dim * 4),
            UNet_TransformerEncoder(num_heads=8, embedding_dim=160, cond_dim=cond_dim),
            UNet_ResBlock(1280, 1280, t_embed_dim * 4)
        )
        self.decoder = UNet_Decoder(num_heads=num_heads, t_embed_dim=t_embed_dim * 4, cond_dim=cond_dim)
        self.output = nn.Sequential(
            nn.GroupNorm(32, 320),
            nn.SiLU(),
            nn.Conv2d(320, out_channels, kernel_size=3, stride=1, padding=1))

    def forward(self, x: torch.Tensor, timestep: int, cond: torch.Tensor) -> torch.Tensor:

        # t: int -> (1, 1280)
        t_embed = self.time_embedding(timestep)
        
        x, skip_connections = self.encoder(x, t_embed, cond)
        x = self.bottle_neck(x, t_embed, cond)
        x = self.decoder(x, skip_connections, t_embed, cond)
        
        output = self.output(x)
        return output
        

Overwriting models/unet.py


In [6]:
%%writefile models/diffusion.py
import torch
from torch import nn
import torchvision
from torchvision import transforms
from .vae import VAE
from .unet import UNet
from .cond_encoder import TextEncoder
import numpy as np
from tqdm.auto import tqdm
from .utils import denormalize_img
from PIL import Image


IMG_HEIGHT = 64
IMG_WIDTH = 64
Z_HEIGHT = 64 // 8
Z_WIDTH = 64 // 8

class StableDiffusion:
    def __init__(self, noise_step: int=1000, beta_start: float=1e-4, beta_end: float=0.02):
        self.betas = torch.linspace(beta_start, beta_end, noise_step, dtype=torch.float)
        self.alphas = 1 - self.betas
        self.alphas_hat = torch.cumprod(self.alphas, dim=0)
        self.noise_step = noise_step

        self.timesteps = torch.from_numpy(np.arange(0, noise_step)[::-1].copy())
        self.vae = VAE()
        self.unet = UNet()
        self.cond_encoder = TextEncoder()

    def _set_inference_step(self, inference_steps=50):
        self.inference_steps = inference_steps
        ratio = self.noise_step // self.inference_steps
        self.timesteps = torch.from_numpy((np.arange(0, self.inference_steps) * ratio).round()[::-1].copy().astype(np.int64))
        

    def _get_prev_timestep(self, timestep: int):
        prev_t = timestep - self.noise_step // self.inference_steps
        return prev_t

    def set_strength(self, strength: float=0.8):
        start_t = self.inference_steps - (self.inference_steps * strength)
        self.timesteps = self.timesteps[start_t:]

    # x_t ~ q(x_t | x_0) = N(x_t, sqrt(a_hat_t) * x_0, sqrt(1 - a_hat_t) * I)
    def forward_process(self, x_0: torch.Tensor, timestep: int):
        # x_0: (b, c, h, w)
        t = timestep
        # (1,) -> (1, 1, 1, 1)
        alpha_hat_t = self.alphas_hat[t, None, None, None].to(x_0.device)
              
        noise = torch.randn_like(x_0, dtype=torch.float32, device=x_0.device)
        latent = torch.sqrt(alpha_hat_t) * x_0 + torch.sqrt(1 - alpha_hat_t) * noise
        
        return latent, noise

    # x_(t-1) ~ p(x_(t-1) | x_t) = N(x_(t - 1), mu_theta(x_t, x_0), beta_tilda_t * I)
    # mu_theta(x_t, x-0)1/sqrt(alpha_hat_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_hat_t) * epsilon_t)
    # beta_tilda_t = (1 - alpha_hat_(t-1)) / (1 - alpha_t) * beta_t
    def reverse_process(self, x_t: torch.Tensor, timestep: int, model_output=torch.Tensor) -> torch.Tensor:
        t = timestep
        prev_t = self._get_prev_timestep(t)

        
        alpha_t = self.alphas[t, None, None, None].to(x_t.device)
        alpha_hat_t = self.alphas_hat[t, None, None, None].to(x_t.device)
        prev_alpha_hat_t = self.alphas_hat[prev_t] if prev_t >= 0 else torch.tensor(1.0)
        prev_alpha_hat_t = prev_alpha_hat_t.to(x_t.device)
        
        mu = 1/torch.sqrt(alpha_t) * (x_t - (1 - alpha_t)/torch.sqrt(1 - alpha_hat_t) * model_output)

        stdev = 0
        if t > 0:
            variance = (1 - prev_alpha_hat_t) / (1 - alpha_hat_t) * self.betas[t]
            variance = torch.clamp(variance, min=1e-20)
            stdev = torch.sqrt(variance)
            
        noise = torch.randn_like(x_t, dtype=torch.float32, device=x_t.device)
        less_noise_sample = mu + stdev * noise
        return less_noise_sample
    
    def generate(self, input_image: Image, 
                 transforms: torchvision.transforms,
                 prompt: str,
                 uncond_promt: str,
                 do_cfg: bool,
                 cfg_scale: int,
                 device: torch.device,
                 strength:float,
                 inference_steps: int,
                 tokenizer=None) -> torch.Tensor:
        
        z_shape = (1, 8, Z_HEIGHT, Z_WIDTH)
        with torch.inference_mode():
            # Encoding Condition
            self.cond_encoder.to(device)
            if do_cfg:
                cond_tokens = torch.tensor(tokenizer.batch_encode_plus([prompt], padding='max_length', max_length=77).input_ids, dtype=torch.long, device=device)
                uncond_tokens = torch.tensor(tokenizer.batch_encode_plus([uncond_promt], padding='max_length', max_length=77).input_ids, dtype=torch.long, device=device)
    
                context = torch.cat([cond_tokens, uncond_tokens], dim=0)
                context_embedding = self.cond_encoder(context)
    
            else:
                cond_tokens = torch.tensor(tokenizer.batch_encode_plus([prompt], padding='max_length', max_length=77).input_ids, dtype=torch.long, device=device)
                context_embedding = self.cond_encoder(cond_tokens)
                
            self.cond_encoder.to('cpu')
    
            self._set_inference_step(inference_steps)
    
            # Encoding Image
            self.vae.to(device)
            if input_image:
                input_image = input_image.resize(IMG_HEIGHT, IMG_WIDTH)
                input_image = np.array(input_image)
                input_image = torch.from_array(input_image, dtype=torch.float32, device=device)
                input_image = input_image.unsqueeze(0)
                input_image = input_image.permute(0, 3, 1, 2)
                
                transformed_img = transforms(input_image)
                latent_features =  self.vae.encode(transformed_img)
    
                self.set_strength(strength=strength)
                latent_features = self.forward_process(latent_features, self.timesteps[0])
                
            else:
                latent_features = torch.randn(z_shape, dtype=torch.float32, device=device)
            self.vae.to('cpu')

            # Denoising
            timesteps = tqdm(self.timesteps.to(device))
            # x_t: torch.Tensor, 
            # timestep: int, 
            # model_output=torch.Tensor
    
            self.unet.to(device)
            for i, timestep in enumerate(timesteps):
                # (b, 8, latent_height, latent_width)
                model_input = latent_features
                if do_cfg:
                    model_input = model_input.repeat(2, 1, 1, 1)

                pred_noise = self.unet(model_input, timestep, context_embedding)
                
                if do_cfg:
                    cond_output, uncond_output = pred_noise.chunk(2)
                    pred_noise = cfg_scale * (cond_output - uncond_output) + uncond_output
    
                latent_features = self.reverse_process(latent_features, timestep, pred_noise)
            self.unet.to('cpu')
    
            self.vae.to(device)
            generated_imgs = self.vae.decode(latent_features)
            self.vae.to('cpu')
    
            generated_imgs = denormalize_img(generated_imgs, (-1, 1), (0, 255), clamp=True)
            generated_imgs = generated_imgs.permute(0, 2, 3, 1)
            generated_imgs = generated_imgs.to('cpu', torch.uint8).numpy()
            return generated_imgs[0]
    
        

Overwriting models/diffusion.py


In [7]:
from models.diffusion import StableDiffusion
import torch
from torchvision import transforms
from transformers import CLIPTokenizer
from PIL import Image


prompt = "A cat stretching on the floor, highly detailed, ultra sharp, cinematic, 100mm lens, 8k resolution."
uncond_prompt = ""  # Also known as negative prompt
do_cfg = True
cfg_scale = 8  # min: 1, max: 14

## IMAGE TO IMAGE

input_image = None
# Comment to disable image to image
# image_path = "../images/dog.jpg"
# input_image = Image.open(image_path)
# Higher values means more noise will be added to the input image, so the result will further from the input image.
# Lower values means less noise is added to the input image, so output will be closer to the input image.
strength = 0.9

## SAMPLER
num_inference_steps = 50
seed = 4

model = StableDiffusion()

output = model.generate(input_image=None,
                        prompt=prompt,
                        uncond_promt=uncond_prompt,
                        do_cfg=do_cfg,
                        cfg_scale=cfg_scale,
                        device='mps',
                        strength=strength,
                        inference_steps=num_inference_steps,
                        transforms=transforms.ToTensor(),
                        tokenizer=CLIPTokenizer('./weights/clip/tokenizer_vocab.json', merges_file='./weights/clip/tokenizer_merges.txt')
                       )


  from .autonotebook import tqdm as notebook_tqdm
  0%|                                                                                                   | 0/50 [00:00<?, ?it/s]

mps:0


  0%|                                                                                                   | 0/50 [00:06<?, ?it/s]


UnboundLocalError: local variable 'pred_alpha_hat_t' referenced before assignment

In [None]:
from matplotlib import pyplot as plt
plt.imshow(output)