In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Union, Optional

## Coding the VAE

In [None]:
class SelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention over a sequence.

    Attributes:
        scale (float): Scaling factor applied to the attention scores (d_head ** -0.5).
        n_heads (int): Number of attention heads.
        d_head (int): Dimension of each attention head (d_embed // n_heads).
        QKV (nn.Linear): Fused linear projection producing queries, keys and values.
        O (nn.Linear): Linear output projection.
    """

    def __init__(self, d_embed: int, n_heads: int = 4, qkv_bias: bool = True, out_bias: bool = True) -> None:
        """
        Initializes the SelfAttention module.

        Args:
            d_embed (int): Dimension of the embedding. Must be divisible by n_heads.
            n_heads (int): Number of attention heads. Defaults to 4.
            qkv_bias (bool): If True, adds bias to the QKV linear layer. Defaults to True.
            out_bias (bool): If True, adds bias to the O linear layer. Defaults to True.

        Raises:
            ValueError: If d_embed is not divisible by n_heads (the per-head
                reshape in forward would otherwise fail later).
        """
        super().__init__()

        if d_embed % n_heads != 0:
            raise ValueError(f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})")

        self.n_heads = n_heads
        self.d_head = d_embed // n_heads
        self.scale = self.d_head ** -0.5
        self.QKV = nn.Linear(d_embed, d_embed * 3, bias=qkv_bias)
        self.O = nn.Linear(d_embed, d_embed, bias=out_bias)

    def forward(self, x: torch.Tensor, mask: bool = False) -> torch.Tensor:
        """
        Forward pass for the self-attention mechanism.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_embed).
            mask (bool): If True, applies a causal mask so position i only attends
                to positions <= i. Defaults to False (full, bidirectional attention).

        Returns:
            torch.Tensor: Tensor of shape (batch_size, seq_len, d_embed).
        """
        bs, seq_len, d_embed = x.shape

        # (bs, seq_len, d_embed) -> (bs, seq_len, 3*d_embed) -> 3 x (bs, seq_len, d_embed)
        q, k, v = self.QKV(x).chunk(3, dim=-1)

        # (bs, seq_len, d_embed) -> (bs, seq_len, n_heads, d_head) -> (bs, n_heads, seq_len, d_head)
        q = q.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        # (bs, n_heads, seq_len, seq_len)
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale
        if mask:
            # Strictly-upper-triangular (seq_len, seq_len) mask broadcasts over batch and heads.
            # -inf is safe here because the diagonal is never masked, so no row is all -inf;
            # unlike -1e9 it also does not overflow in half precision.
            causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(1)
            attn_scores = attn_scores.masked_fill(causal, float('-inf'))
        weights = F.softmax(attn_scores, dim=-1)

        # (bs, n_heads, seq_len, seq_len) @ (bs, n_heads, seq_len, d_head) -> (bs, n_heads, seq_len, d_head)
        output = weights @ v
        # (bs, n_heads, seq_len, d_head) -> (bs, seq_len, n_heads, d_head) -> (bs, seq_len, d_embed)
        output = output.transpose(1, 2).reshape(bs, seq_len, d_embed)
        return self.O(output)


In [None]:
class VaeAttentionBlock(nn.Module):
    """
    Residual spatial self-attention block for 4-D feature maps.

    The input (B, C, H, W) is group-normalised, flattened to a sequence of H*W
    tokens of dimension C, passed through self-attention, reshaped back to
    (B, C, H, W) and added to the un-normalised input.

    Args:
        n_channels (int): Number of channels C (must be divisible by 32 for the GroupNorm).
    """

    def __init__(self, n_channels: int) -> None:
        super().__init__()
        self.groupnorm = nn.GroupNorm(num_groups=32, num_channels=n_channels)
        # One head over the full channel dimension (scale = n_channels ** -0.5), matching the
        # single-head attention of the Stable Diffusion VAE whose q/k/v weights this notebook
        # converts into the fused QKV layer.
        self.attention = SelfAttention(n_channels, n_heads=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residue = x
        n, c, h, w = x.shape  # BxCxHxW
        x = self.groupnorm(x)
        x = x.view(n, c, h * w).transpose(-1, -2)  # BxH*WxC
        x = self.attention(x)  # BxH*WxC
        # The attention output is contiguous in (B, H*W, C); after the transpose it is not,
        # so reshape (not view) is required to get back to BxCxHxW.
        x = x.transpose(-1, -2).reshape(n, c, h, w)
        return x + residue

In [None]:
class VaeResidualBlock(nn.Module):
    """
    Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip path.

    The skip path is the identity when the channel count is unchanged and a
    1x1 convolution otherwise, so the two branches can be summed.

    Args:
        in_channels (int): Input channels (must be divisible by 32 for the GroupNorm).
        out_channels (int): Output channels (must be divisible by 32 for the GroupNorm).
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.groupnorm_1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.groupnorm_2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # A projection is only needed when the channel counts differ.
        self.skip = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.skip(x)
        hidden = F.silu(self.groupnorm_1(x))
        hidden = self.conv_1(hidden)
        hidden = F.silu(self.groupnorm_2(hidden))
        hidden = self.conv_2(hidden)
        return hidden + shortcut

In [None]:
class VaeEncoder(nn.Sequential):
    """
    VAE encoder: image (B, 3, H, W) -> scaled latent sample (B, 4, H/8, W/8).

    Three stride-2 convolutions reduce the resolution by 8x. The final 1x1 conv
    emits 8 channels, which are split into a 4-channel mean and a 4-channel
    log-variance. The sample is then scaled by 0.18215, the same constant that
    VaeDecoder divides by. Layer indices are load-bearing: the checkpoint
    conversion cell addresses parameters as 'encoder.<index>...'.
    """

    def __init__(self) -> None:
        layers = [
            # Initial layers
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, padding=1),

            # Residual blocks 512x512
            VaeResidualBlock(in_channels=128, out_channels=128),
            VaeResidualBlock(in_channels=128, out_channels=128),

            # Downsample to 256x256
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2),
            VaeResidualBlock(in_channels=128, out_channels=256),
            VaeResidualBlock(in_channels=256, out_channels=256),

            # Downsample to 128x128
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2),
            VaeResidualBlock(in_channels=256, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),

            # Downsample to 64x64
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),

            # Attention block 64x64
            VaeAttentionBlock(n_channels=512),

            # Additional residual block 64x64
            VaeResidualBlock(in_channels=512, out_channels=512),

            # Final layers
            nn.GroupNorm(num_groups=32, num_channels=512),
            nn.SiLU(),
            nn.Conv2d(in_channels=512, out_channels=8, kernel_size=3, padding=1),
            nn.Conv2d(in_channels=8, out_channels=8, kernel_size=1)
        ]
        super().__init__(*layers)

    def forward(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """
        Encode an image and draw a reparameterised latent sample.

        Args:
            x: Image batch (B, 3, H, W).
            noise: Sample with the latent's shape (B, 4, H/8, W/8); presumably
                standard normal (the smoke test passes torch.randn) — confirm at call sites.

        Returns:
            (mean + std * noise) * 0.18215, shape (B, 4, H/8, W/8).
        """
        hidden = x
        for module in self:
            # The stride-2 convs have no padding; pad one pixel on the right and bottom
            # so each one halves an even H and W exactly.
            if getattr(module, 'stride', None) == (2, 2):
                hidden = F.pad(hidden, (0, 1, 0, 1))
            hidden = module(hidden)

        # Bx8xhxw -> mean and log-variance, each Bx4xhxw
        mean, log_var = torch.chunk(hidden, 2, dim=1)
        log_var = torch.clamp(log_var, -30, 20)
        std = torch.sqrt(torch.exp(log_var))
        sample = mean + std * noise
        return sample * 0.18215


In [None]:
import torch.nn as nn

class VaeDecoder(nn.Sequential):
    """
    VAE decoder: scaled latent (B, 4, h, w) -> image (B, 3, 8h, 8w).

    The latent is first divided by 0.18215, undoing the scaling that VaeEncoder
    applies. Three nearest-neighbour Upsample(x2) stages then restore the
    resolution. Layer indices are load-bearing: the checkpoint conversion cell
    addresses parameters as 'decoder.<index>...'.
    """

    def __init__(self) -> None:
        layers = [
            # Initial layers
            nn.Conv2d(in_channels=4, out_channels=4, kernel_size=1, padding=0),
            nn.Conv2d(in_channels=4, out_channels=512, kernel_size=3, padding=1),

            # Residual blocks 64x64
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeAttentionBlock(n_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),

            # Upsampling to 128x128
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),

            # Upsampling to 256x256
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            VaeResidualBlock(in_channels=512, out_channels=256),
            VaeResidualBlock(in_channels=256, out_channels=256),
            VaeResidualBlock(in_channels=256, out_channels=256),

            # Upsampling to 512x512
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            VaeResidualBlock(in_channels=256, out_channels=128),
            VaeResidualBlock(in_channels=128, out_channels=128),
            VaeResidualBlock(in_channels=128, out_channels=128),

            nn.GroupNorm(num_groups=32, num_channels=128),
            nn.SiLU(),
            nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, padding=1)
        ]
        super(VaeDecoder, self).__init__(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Decode a scaled latent into an image.

        Args:
            x: Latent (B, 4, h, w) as produced by VaeEncoder. It is not modified.

        Returns:
            Image tensor (B, 3, 8h, 8w).
        """
        # Out-of-place on purpose: `x /= ...` would silently rescale the caller's latent
        # and raises for leaf tensors that require grad.
        x = x / 0.18215
        for layer in self:
            x = layer(x)
        return x


In [None]:
class VAE(nn.Module):
    """
    Wrapper holding the VAE encoder and decoder as the submodules
    `encoder` and `decoder` (these attribute names prefix the state_dict keys).
    """

    def __init__(self) -> None:
        super().__init__()
        self.encoder = VaeEncoder()
        self.decoder = VaeDecoder()

    def encode(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Map an image batch to a latent sample using the supplied noise."""
        latent = self.encoder(x, noise)
        return latent

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Map a latent back to image space."""
        image = self.decoder(x)
        return image

    def forward(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Reconstruct `x`: encode with `noise`, then decode."""
        return self.decode(self.encode(x, noise))

In [None]:
# Randomly initialised VAE on the default device; the next cell prints its state_dict keys.
vae = VAE()

In [None]:
def print_parameters(model: nn.Module, show_shapes: bool = False) -> None:
    """
    Print every key of ``model.state_dict()``, one per line.

    ``state_dict()`` contains registered buffers as well as parameters, i.e. the
    keys that ``load_state_dict`` matches a converted checkpoint against.

    Args:
        model: Module whose state_dict keys are printed.
        show_shapes: If True, append a tab and the tensor shape to each key, which
            helps when checking converted checkpoint tensors. Defaults to False
            (key names only).
    """
    for name, tensor in model.state_dict().items():
        print(f"{name}\t{tuple(tensor.shape)}" if show_shapes else name)
        
# List the state_dict keys of `vae`; these are the target names used by the
# checkpoint conversion (e.g. 'encoder.1.conv_1.weight').
print_parameters(vae)

In [None]:
# Smoke test: encode a random 512x512 RGB image and check the latent shape.
# Prefer Apple's MPS backend, then CUDA, then CPU, so the cell also runs off macOS.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
x = torch.randn(1, 3, 512, 512).to(device)
vae = VAE().to(device)
with torch.no_grad():  # inference only: don't keep activations for backprop
    enc_x = vae.encode(x, torch.randn(1, 4, 64, 64).to(device))
assert enc_x.shape == (1, 4, 64, 64), enc_x.shape

In [None]:
input_file = '../data/checkpoints/v1-5-pruned-emaonly.ckpt'
# Prefer Apple's MPS backend, then CUDA, then CPU, instead of hard-coding 'mps'.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# SECURITY: weights_only=False unpickles arbitrary Python objects and can execute code
# embedded in the file. Only load checkpoints obtained from a trusted source.
original_model = torch.load(input_file, map_location=device, weights_only = False)["state_dict"]
converted = {}

In [None]:
converted['vae'] = {}

_ENC_SRC = 'first_stage_model.encoder'


def _copy_layer(dst: str, src: str) -> None:
    """Copy `<src>.weight` and `<src>.bias` from the checkpoint to `<dst>.weight` / `<dst>.bias`."""
    for suffix in ('weight', 'bias'):
        converted['vae'][f'{dst}.{suffix}'] = original_model[f'{src}.{suffix}']


def _copy_resblock(dst: str, src: str, has_skip: bool = False) -> None:
    """Copy one residual block: norm1/conv1/norm2/conv2 (and nin_shortcut if present)."""
    renames = [('groupnorm_1', 'norm1'), ('conv_1', 'conv1'), ('groupnorm_2', 'norm2'), ('conv_2', 'conv2')]
    if has_skip:
        renames.append(('skip', 'nin_shortcut'))
    for ours, theirs in renames:
        _copy_layer(f'{dst}.{ours}', f'{src}.{theirs}')


# Stem
_copy_layer('encoder.0', f'{_ENC_SRC}.conv_in')
# down.0 (128 channels) + downsample
_copy_resblock('encoder.1', f'{_ENC_SRC}.down.0.block.0')
_copy_resblock('encoder.2', f'{_ENC_SRC}.down.0.block.1')
_copy_layer('encoder.3', f'{_ENC_SRC}.down.0.downsample.conv')
# down.1 (128 -> 256 channels, first block has a 1x1 skip conv) + downsample
_copy_resblock('encoder.4', f'{_ENC_SRC}.down.1.block.0', has_skip=True)
_copy_resblock('encoder.5', f'{_ENC_SRC}.down.1.block.1')
_copy_layer('encoder.6', f'{_ENC_SRC}.down.1.downsample.conv')
# down.2 (256 -> 512 channels, first block has a 1x1 skip conv) + downsample
_copy_resblock('encoder.7', f'{_ENC_SRC}.down.2.block.0', has_skip=True)
_copy_resblock('encoder.8', f'{_ENC_SRC}.down.2.block.1')
_copy_layer('encoder.9', f'{_ENC_SRC}.down.2.downsample.conv')
# down.3 (512 channels, no downsample)
_copy_resblock('encoder.10', f'{_ENC_SRC}.down.3.block.0')
_copy_resblock('encoder.11', f'{_ENC_SRC}.down.3.block.1')
# mid: resblock, attention norm, resblock
_copy_resblock('encoder.12', f'{_ENC_SRC}.mid.block_1')
_copy_layer('encoder.13.groupnorm', f'{_ENC_SRC}.mid.attn_1.norm')
_copy_resblock('encoder.14', f'{_ENC_SRC}.mid.block_2')
# Output head: norm_out (15), conv_out (17; index 16 is the parameter-free SiLU), quant_conv (18)
_copy_layer('encoder.15', f'{_ENC_SRC}.norm_out')
_copy_layer('encoder.17', f'{_ENC_SRC}.conv_out')
_copy_layer('encoder.18', 'first_stage_model.quant_conv')

# Mid attention: separate q/k/v weights are stacked along dim 0 into the fused QKV Linear;
# the reshapes drop any trailing singleton dims (presumably 1x1 conv kernels — confirm).
_ENC_ATTN_SRC = f'{_ENC_SRC}.mid.attn_1'
converted['vae']['encoder.13.attention.QKV.weight'] = torch.cat(
    tuple(original_model[f'{_ENC_ATTN_SRC}.{part}.weight'] for part in 'qkv'), 0
).reshape((1536, 512))
converted['vae']['encoder.13.attention.QKV.bias'] = torch.cat(
    tuple(original_model[f'{_ENC_ATTN_SRC}.{part}.bias'] for part in 'qkv'), 0
)
converted['vae']['encoder.13.attention.O.weight'] = original_model[f'{_ENC_ATTN_SRC}.proj_out.weight'].reshape((512, 512))
converted['vae']['encoder.13.attention.O.bias'] = original_model[f'{_ENC_ATTN_SRC}.proj_out.bias']

# --- Decoder ---
# post_quant_conv -> decoder.0 (1x1 conv), decoder.conv_in -> decoder.1
converted['vae']['decoder.0.weight'] = original_model['first_stage_model.post_quant_conv.weight']
converted['vae']['decoder.0.bias'] = original_model['first_stage_model.post_quant_conv.bias']
converted['vae']['decoder.1.weight'] = original_model['first_stage_model.decoder.conv_in.weight']
converted['vae']['decoder.1.bias'] = original_model['first_stage_model.decoder.conv_in.bias']
# mid.block_1 -> decoder.2 (residual block)
converted['vae']['decoder.2.groupnorm_1.weight'] = original_model['first_stage_model.decoder.mid.block_1.norm1.weight']
converted['vae']['decoder.2.groupnorm_1.bias'] = original_model['first_stage_model.decoder.mid.block_1.norm1.bias']
converted['vae']['decoder.2.conv_1.weight'] = original_model['first_stage_model.decoder.mid.block_1.conv1.weight']
converted['vae']['decoder.2.conv_1.bias'] = original_model['first_stage_model.decoder.mid.block_1.conv1.bias']
converted['vae']['decoder.2.groupnorm_2.weight'] = original_model['first_stage_model.decoder.mid.block_1.norm2.weight']
converted['vae']['decoder.2.groupnorm_2.bias'] = original_model['first_stage_model.decoder.mid.block_1.norm2.bias']
converted['vae']['decoder.2.conv_2.weight'] = original_model['first_stage_model.decoder.mid.block_1.conv2.weight']
converted['vae']['decoder.2.conv_2.bias'] = original_model['first_stage_model.decoder.mid.block_1.conv2.bias']
# mid.attn_1 -> decoder.3: GroupNorm, then q/k/v weights stacked along dim 0 into the fused
# QKV Linear and proj_out -> O. The reshapes drop any trailing singleton dims of the
# checkpoint tensors (presumably 1x1 conv kernels — confirm).
converted['vae']['decoder.3.groupnorm.weight'] = original_model['first_stage_model.decoder.mid.attn_1.norm.weight']
converted['vae']['decoder.3.groupnorm.bias'] = original_model['first_stage_model.decoder.mid.attn_1.norm.bias']
converted['vae']['decoder.3.attention.QKV.weight'] = torch.cat((original_model['first_stage_model.decoder.mid.attn_1.q.weight'], original_model['first_stage_model.decoder.mid.attn_1.k.weight'], original_model['first_stage_model.decoder.mid.attn_1.v.weight']), 0).reshape((1536, 512))
converted['vae']['decoder.3.attention.QKV.bias'] = torch.cat((original_model['first_stage_model.decoder.mid.attn_1.q.bias'], original_model['first_stage_model.decoder.mid.attn_1.k.bias'], original_model['first_stage_model.decoder.mid.attn_1.v.bias']), 0)
converted['vae']['decoder.3.attention.O.weight'] = original_model['first_stage_model.decoder.mid.attn_1.proj_out.weight'].reshape((512, 512))
converted['vae']['decoder.3.attention.O.bias'] = original_model['first_stage_model.decoder.mid.attn_1.proj_out.bias']
# mid.block_2 -> decoder.4 (residual block)
converted['vae']['decoder.4.groupnorm_1.weight'] = original_model['first_stage_model.decoder.mid.block_2.norm1.weight']
converted['vae']['decoder.4.groupnorm_1.bias'] = original_model['first_stage_model.decoder.mid.block_2.norm1.bias']
converted['vae']['decoder.4.conv_1.weight'] = original_model['first_stage_model.decoder.mid.block_2.conv1.weight']
converted['vae']['decoder.4.conv_1.bias'] = original_model['first_stage_model.decoder.mid.block_2.conv1.bias']
converted['vae']['decoder.4.groupnorm_2.weight'] = original_model['first_stage_model.decoder.mid.block_2.norm2.weight']
converted['vae']['decoder.4.groupnorm_2.bias'] = original_model['first_stage_model.decoder.mid.block_2.norm2.bias']
converted['vae']['decoder.4.conv_2.weight'] = original_model['first_stage_model.decoder.mid.block_2.conv2.weight']
converted['vae']['decoder.4.conv_2.bias'] = original_model['first_stage_model.decoder.mid.block_2.conv2.bias']
# up.0 -> decoder.20-22. The checkpoint numbers its `up` stages in the opposite order to our
# Sequential: up.0 holds the final 256->128 channel blocks (block 0 has a 1x1 skip conv),
# which sit at indices 20-22 in VaeDecoder. up.0 has no upsample conv.
converted['vae']['decoder.20.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.0.block.0.norm1.weight']
converted['vae']['decoder.20.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.0.block.0.norm1.bias']
converted['vae']['decoder.20.conv_1.weight'] = original_model['first_stage_model.decoder.up.0.block.0.conv1.weight']
converted['vae']['decoder.20.conv_1.bias'] = original_model['first_stage_model.decoder.up.0.block.0.conv1.bias']
converted['vae']['decoder.20.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.0.block.0.norm2.weight']
converted['vae']['decoder.20.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.0.block.0.norm2.bias']
converted['vae']['decoder.20.conv_2.weight'] = original_model['first_stage_model.decoder.up.0.block.0.conv2.weight']
converted['vae']['decoder.20.conv_2.bias'] = original_model['first_stage_model.decoder.up.0.block.0.conv2.bias']
converted['vae']['decoder.20.skip.weight'] = original_model['first_stage_model.decoder.up.0.block.0.nin_shortcut.weight']
converted['vae']['decoder.20.skip.bias'] = original_model['first_stage_model.decoder.up.0.block.0.nin_shortcut.bias']
converted['vae']['decoder.21.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.0.block.1.norm1.weight']
converted['vae']['decoder.21.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.0.block.1.norm1.bias']
converted['vae']['decoder.21.conv_1.weight'] = original_model['first_stage_model.decoder.up.0.block.1.conv1.weight']
converted['vae']['decoder.21.conv_1.bias'] = original_model['first_stage_model.decoder.up.0.block.1.conv1.bias']
converted['vae']['decoder.21.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.0.block.1.norm2.weight']
converted['vae']['decoder.21.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.0.block.1.norm2.bias']
converted['vae']['decoder.21.conv_2.weight'] = original_model['first_stage_model.decoder.up.0.block.1.conv2.weight']
converted['vae']['decoder.21.conv_2.bias'] = original_model['first_stage_model.decoder.up.0.block.1.conv2.bias']
converted['vae']['decoder.22.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.0.block.2.norm1.weight']
converted['vae']['decoder.22.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.0.block.2.norm1.bias']
converted['vae']['decoder.22.conv_1.weight'] = original_model['first_stage_model.decoder.up.0.block.2.conv1.weight']
converted['vae']['decoder.22.conv_1.bias'] = original_model['first_stage_model.decoder.up.0.block.2.conv1.bias']
converted['vae']['decoder.22.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.0.block.2.norm2.weight']
converted['vae']['decoder.22.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.0.block.2.norm2.bias']
converted['vae']['decoder.22.conv_2.weight'] = original_model['first_stage_model.decoder.up.0.block.2.conv2.weight']
converted['vae']['decoder.22.conv_2.bias'] = original_model['first_stage_model.decoder.up.0.block.2.conv2.bias']
# up.1 -> decoder.15-17 (512->256 channels, block 0 has a 1x1 skip conv);
# up.1.upsample.conv -> decoder.19 (index 18 is the parameter-free nn.Upsample)
converted['vae']['decoder.15.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.1.block.0.norm1.weight']
converted['vae']['decoder.15.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.1.block.0.norm1.bias']
converted['vae']['decoder.15.conv_1.weight'] = original_model['first_stage_model.decoder.up.1.block.0.conv1.weight']
converted['vae']['decoder.15.conv_1.bias'] = original_model['first_stage_model.decoder.up.1.block.0.conv1.bias']
converted['vae']['decoder.15.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.1.block.0.norm2.weight']
converted['vae']['decoder.15.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.1.block.0.norm2.bias']
converted['vae']['decoder.15.conv_2.weight'] = original_model['first_stage_model.decoder.up.1.block.0.conv2.weight']
converted['vae']['decoder.15.conv_2.bias'] = original_model['first_stage_model.decoder.up.1.block.0.conv2.bias']
converted['vae']['decoder.15.skip.weight'] = original_model['first_stage_model.decoder.up.1.block.0.nin_shortcut.weight']
converted['vae']['decoder.15.skip.bias'] = original_model['first_stage_model.decoder.up.1.block.0.nin_shortcut.bias']
converted['vae']['decoder.16.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.1.block.1.norm1.weight']
converted['vae']['decoder.16.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.1.block.1.norm1.bias']
converted['vae']['decoder.16.conv_1.weight'] = original_model['first_stage_model.decoder.up.1.block.1.conv1.weight']
converted['vae']['decoder.16.conv_1.bias'] = original_model['first_stage_model.decoder.up.1.block.1.conv1.bias']
converted['vae']['decoder.16.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.1.block.1.norm2.weight']
converted['vae']['decoder.16.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.1.block.1.norm2.bias']
converted['vae']['decoder.16.conv_2.weight'] = original_model['first_stage_model.decoder.up.1.block.1.conv2.weight']
converted['vae']['decoder.16.conv_2.bias'] = original_model['first_stage_model.decoder.up.1.block.1.conv2.bias']
converted['vae']['decoder.17.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.1.block.2.norm1.weight']
converted['vae']['decoder.17.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.1.block.2.norm1.bias']
converted['vae']['decoder.17.conv_1.weight'] = original_model['first_stage_model.decoder.up.1.block.2.conv1.weight']
converted['vae']['decoder.17.conv_1.bias'] = original_model['first_stage_model.decoder.up.1.block.2.conv1.bias']
converted['vae']['decoder.17.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.1.block.2.norm2.weight']
converted['vae']['decoder.17.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.1.block.2.norm2.bias']
converted['vae']['decoder.17.conv_2.weight'] = original_model['first_stage_model.decoder.up.1.block.2.conv2.weight']
converted['vae']['decoder.17.conv_2.bias'] = original_model['first_stage_model.decoder.up.1.block.2.conv2.bias']
converted['vae']['decoder.19.weight'] = original_model['first_stage_model.decoder.up.1.upsample.conv.weight']
converted['vae']['decoder.19.bias'] = original_model['first_stage_model.decoder.up.1.upsample.conv.bias']
# up.2 -> decoder.10-12 (512 channels); the mapping continues in the statements that follow
converted['vae']['decoder.10.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.2.block.0.norm1.weight']
converted['vae']['decoder.10.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.2.block.0.norm1.bias']
converted['vae']['decoder.10.conv_1.weight'] = original_model['first_stage_model.decoder.up.2.block.0.conv1.weight']
converted['vae']['decoder.10.conv_1.bias'] = original_model['first_stage_model.decoder.up.2.block.0.conv1.bias']
converted['vae']['decoder.10.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.2.block.0.norm2.weight']
converted['vae']['decoder.10.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.2.block.0.norm2.bias']
converted['vae']['decoder.10.conv_2.weight'] = original_model['first_stage_model.decoder.up.2.block.0.conv2.weight']
converted['vae']['decoder.10.conv_2.bias'] = original_model['first_stage_model.decoder.up.2.block.0.conv2.bias']
converted['vae']['decoder.11.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.2.block.1.norm1.weight']
converted['vae']['decoder.11.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.2.block.1.norm1.bias']
converted['vae']['decoder.11.conv_1.weight'] = original_model['first_stage_model.decoder.up.2.block.1.conv1.weight']
converted['vae']['decoder.11.conv_1.bias'] = original_model['first_stage_model.decoder.up.2.block.1.conv1.bias']
converted['vae']['decoder.11.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.2.block.1.norm2.weight']
converted['vae']['decoder.11.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.2.block.1.norm2.bias']
converted['vae']['decoder.11.conv_2.weight'] = original_model['first_stage_model.decoder.up.2.block.1.conv2.weight']
converted['vae']['decoder.11.conv_2.bias'] = original_model['first_stage_model.decoder.up.2.block.1.conv2.bias']
converted['vae']['decoder.12.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.2.block.2.norm1.weight']
converted['vae']['decoder.12.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.2.block.2.norm1.bias']
converted['vae']['decoder.12.conv_1.weight'] = original_model['first_stage_model.decoder.up.2.block.2.conv1.weight']
converted['vae']['decoder.12.conv_1.bias'] = original_model['first_stage_model.decoder.up.2.block.2.conv1.bias']
converted['vae']['decoder.12.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.2.block.2.norm2.weight']
converted['vae']['decoder.12.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.2.block.2.norm2.bias']
converted['vae']['decoder.12.conv_2.weight'] = original_model['first_stage_model.decoder.up.2.block.2.conv2.weight']
converted['vae']['decoder.12.conv_2.bias'] = original_model['first_stage_model.decoder.up.2.block.2.conv2.bias']
converted['vae']['decoder.14.weight'] = original_model['first_stage_model.decoder.up.2.upsample.conv.weight']
converted['vae']['decoder.14.bias'] = original_model['first_stage_model.decoder.up.2.upsample.conv.bias']
converted['vae']['decoder.5.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.3.block.0.norm1.weight']
converted['vae']['decoder.5.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.3.block.0.norm1.bias']
converted['vae']['decoder.5.conv_1.weight'] = original_model['first_stage_model.decoder.up.3.block.0.conv1.weight']
converted['vae']['decoder.5.conv_1.bias'] = original_model['first_stage_model.decoder.up.3.block.0.conv1.bias']
converted['vae']['decoder.5.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.3.block.0.norm2.weight']
converted['vae']['decoder.5.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.3.block.0.norm2.bias']
converted['vae']['decoder.5.conv_2.weight'] = original_model['first_stage_model.decoder.up.3.block.0.conv2.weight']
converted['vae']['decoder.5.conv_2.bias'] = original_model['first_stage_model.decoder.up.3.block.0.conv2.bias']
converted['vae']['decoder.6.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.3.block.1.norm1.weight']
converted['vae']['decoder.6.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.3.block.1.norm1.bias']
converted['vae']['decoder.6.conv_1.weight'] = original_model['first_stage_model.decoder.up.3.block.1.conv1.weight']
converted['vae']['decoder.6.conv_1.bias'] = original_model['first_stage_model.decoder.up.3.block.1.conv1.bias']
converted['vae']['decoder.6.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.3.block.1.norm2.weight']
converted['vae']['decoder.6.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.3.block.1.norm2.bias']
converted['vae']['decoder.6.conv_2.weight'] = original_model['first_stage_model.decoder.up.3.block.1.conv2.weight']
converted['vae']['decoder.6.conv_2.bias'] = original_model['first_stage_model.decoder.up.3.block.1.conv2.bias']
converted['vae']['decoder.7.groupnorm_1.weight'] = original_model['first_stage_model.decoder.up.3.block.2.norm1.weight']
converted['vae']['decoder.7.groupnorm_1.bias'] = original_model['first_stage_model.decoder.up.3.block.2.norm1.bias']
converted['vae']['decoder.7.conv_1.weight'] = original_model['first_stage_model.decoder.up.3.block.2.conv1.weight']
converted['vae']['decoder.7.conv_1.bias'] = original_model['first_stage_model.decoder.up.3.block.2.conv1.bias']
converted['vae']['decoder.7.groupnorm_2.weight'] = original_model['first_stage_model.decoder.up.3.block.2.norm2.weight']
converted['vae']['decoder.7.groupnorm_2.bias'] = original_model['first_stage_model.decoder.up.3.block.2.norm2.bias']
converted['vae']['decoder.7.conv_2.weight'] = original_model['first_stage_model.decoder.up.3.block.2.conv2.weight']
converted['vae']['decoder.7.conv_2.bias'] = original_model['first_stage_model.decoder.up.3.block.2.conv2.bias']
converted['vae']['decoder.9.weight'] = original_model['first_stage_model.decoder.up.3.upsample.conv.weight']
converted['vae']['decoder.9.bias'] = original_model['first_stage_model.decoder.up.3.upsample.conv.bias']
converted['vae']['decoder.23.weight'] = original_model['first_stage_model.decoder.norm_out.weight']
converted['vae']['decoder.23.bias'] = original_model['first_stage_model.decoder.norm_out.bias']
converted['vae']['decoder.25.weight'] = original_model['first_stage_model.decoder.conv_out.weight']
converted['vae']['decoder.25.bias'] = original_model['first_stage_model.decoder.conv_out.bias']

In [None]:
# Load the hand-mapped weights. strict=True is PyTorch's default, spelled out here because the
# mapping above was written by hand: any missing or unexpected key makes this call raise instead
# of silently leaving layers at their random initialisation. The returned (missing_keys,
# unexpected_keys) result is the cell's displayed output.
vae.load_state_dict(converted['vae'], strict=True)