In [28]:
import torch
from torch import nn
from collections import OrderedDict

In [29]:
class Resnet(nn.Module):
    
    def __init__(self, dim_in, dim_out):
        super(Resnet, self).__init__()

        # Time embedding layer
        self.time_embedding = nn.Sequential(
            nn.SiLU(),
            nn.Linear(1280, dim_out),
            nn.Unflatten(dim=1, unflattened_size=(dim_out, 1, 1)),
        )

        # First convolutional block
        self.conv_block1 = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=dim_in, eps=1e-05, affine=True),
            nn.SiLU(),
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1),
        )

        # Second convolutional block
        self.conv_block2 = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=dim_out, eps=1e-05, affine=True),
            nn.SiLU(),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1),
        )

        # Residual connection
        self.residual = None
        if dim_in != dim_out:
            self.residual = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x, time):
        # Save the input tensor for the residual connection
        residual = x

        # Apply time embedding
        time_emb = self.time_embedding(time)

        # Apply the first convolutional block and add time embedding
        x = self.conv_block1(x) + time_emb

        # Apply the second convolutional block
        x = self.conv_block2(x)

        # Apply residual connection if dimensions do not match
        if self.residual:
            residual = self.residual(residual)

        # Add the residual connection
        x = x + residual

        return x
    
Resnet(320, 640)(torch.randn(1, 320, 32, 32), torch.randn(1, 1280)).shape


torch.Size([1, 640, 32, 32])

In [30]:
class CrossAttention(nn.Module):

    def __init__(self, dim_q, dim_kv):
        super(CrossAttention, self).__init__()

        self.dim_q = dim_q

        self.q = nn.Linear(dim_q, dim_q, bias=False)
        self.k = nn.Linear(dim_kv, dim_q, bias=False)
        self.v = nn.Linear(dim_kv, dim_q, bias=False)
        self.out = nn.Linear(dim_q, dim_q)
    
    def reshape(self, x, split_size):
        b, lens, dim = x.shape
        x = x.reshape(b, lens, split_size, dim // split_size)
        x = x.transpose(1, 2)
        x = x.reshape(b * split_size, lens, dim // split_size)
        return x
    
    def reshape_back(self, x, split_size):
        b, lens, dim = x.shape
        x = x.reshape(b // split_size, split_size, lens, dim)
        x = x.transpose(1, 2)
        x = x.reshape(b // split_size, lens, dim * split_size)
        return x
    
    def forward(self, q, kv):
        q = self.q(q)
        k = self.k(kv)
        v = self.v(kv)

        q = self.reshape(q, 8)
        k = self.reshape(k, 8)
        v = self.reshape(v, 8)

        scale = (self.dim_q // 8) ** -0.5
        atten = torch.baddbmm(
            torch.empty(q.shape[0], q.shape[1], k.shape[1], device=q.device),
            q, k.transpose(1, 2),
            beta=0,
            alpha=scale
        )

        atten = atten.softmax(dim=-1)
        atten = atten.bmm(v)

        atten = self.reshape_back(atten, 8)
        atten = self.out(atten)

        return atten
    
CrossAttention(320, 768)(torch.randn(1, 4096, 320), torch.randn(1, 77, 768)).shape


torch.Size([1, 4096, 320])

In [31]:
class Transformer(nn.Module):
    def __init__(self, dim):
        super(Transformer, self).__init__()

        self.dim = dim

        self.input_block = nn.Sequential(OrderedDict([
            ('norm_in', nn.GroupNorm(num_groups=32, num_channels=dim, eps=1e-6, affine=True)),
            ('cnn_in', nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)),
        ]))

        self.attention_block = nn.Sequential(OrderedDict([
            ('cross_atten1', CrossAttention(dim, dim)),
            ('cross_atten2', CrossAttention(dim, 768)),
            ('norm_atten0', nn.LayerNorm(dim, elementwise_affine=True)),
            ('norm_atten1', nn.LayerNorm(dim, elementwise_affine=True)),
        ]))

        self.activation_block = nn.Sequential(OrderedDict([
            ('fc0', nn.Linear(dim, dim * 8)),
            ('act', nn.GELU()),
            ('fc1', nn.Linear(dim * 4, dim)),
            ('norm_act', nn.LayerNorm(dim, elementwise_affine=True)),
        ]))

        self.output_block = nn.Sequential(OrderedDict([
            ('cnn_out', nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)),
        ]))

    def forward(self, q, kv):
        b, _, h, w = q.shape
        res1 = q

        # ---- Input Processing ----
        q = self.input_block(q)
        q = q.permute(0, 2, 3, 1).reshape(b, h * w, self.dim)

        # ---- Attention ----
        q = self.attention_block.norm_atten0(q)
        q = self.attention_block.cross_atten1(q=q, kv=q) + q

        q = self.attention_block.norm_atten1(q)
        q = self.attention_block.cross_atten2(q=q, kv=kv) + q

        # ---- Activation ----
        res2 = q
        q = self.activation_block.norm_act(q)
        q = self.activation_block.fc0(q)

        d = q.shape[2] // 2
        q = q[:, :, :d] * self.activation_block.act(q[:, :, d:])
        q = self.activation_block.fc1(q) + res2

        # ---- Output Processing ----
        q = q.reshape(b, h, w, self.dim).permute(0, 3, 1, 2).contiguous()
        q = self.output_block.cnn_out(q) + res1

        return q
Transformer(320)(torch.randn(1, 320, 64, 64), torch.randn(1, 77, 768)).shape

torch.Size([1, 320, 64, 64])

In [32]:
class DownBlock(nn.Module):
    
    def __init__(self, dim_in, dim_out) -> None:
        super(DownBlock, self).__init__()

        self.DownBlock1 = nn.Sequential(OrderedDict([
            ('transformer1', Transformer(dim_out)),
            ('resnet1', Resnet(dim_in, dim_out)),
        ]))

        self.DownBlock2 = nn.Sequential(OrderedDict([
            ('transformer2', Transformer(dim_out)),
            ('resnet2', Resnet(dim_out, dim_out)),
        ]))

        self.downsample = nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=2, padding=1)

    def forward(self, out_vae, out_encoder, time):
        outs = []

        out_vae = self.DownBlock1.resnet1(out_vae, time)
        out_vae = self.DownBlock1.transformer1(out_vae, out_encoder)
        outs.append(out_vae)

        out_vae = self.DownBlock2.resnet2(out_vae, time)
        out_vae = self.DownBlock2.transformer2(out_vae, out_encoder)
        outs.append(out_vae)

        out_vae = self.downsample(out_vae)
        outs.append(out_vae)

        return out_vae, outs

DownBlock(320, 640)(torch.randn(1, 320, 32, 32), torch.randn(1, 77, 768), torch.randn(1, 1280))[0].shape


torch.Size([1, 640, 16, 16])

In [33]:
class UpBlock(nn.Module):
    def __init__(self, dim_in, dim_out, dim_prev, add_up):
        super(UpBlock, self).__init__()

        self.UpBlock1 = nn.Sequential(OrderedDict([
            ('resnet1', Resnet(dim_out + dim_prev, dim_out)),
            ('resnet2', Resnet(dim_out + dim_out, dim_out)),
            ('resnet3', Resnet(dim_in + dim_out, dim_out)),
        ]))

        self.UpBlock2 = nn.Sequential(OrderedDict([
            ('transformer1', Transformer(dim_out)),
            ('transformer2', Transformer(dim_out)),
            ('transformer3', Transformer(dim_out)),
        ]))

        self.upsample = None
        if add_up:
            self.upsample = nn.Sequential(OrderedDict([
                ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
                ('conv', nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1)),
            ]))

    def forward(self, out_vae, out_encoder, time, out_down):
        out_vae = self.UpBlock1.resnet1(torch.cat([out_vae, out_down.pop()], dim=1), time)
        out_vae = self.UpBlock2.transformer1(out_vae, out_encoder)

        out_vae = self.UpBlock1.resnet2(torch.cat([out_vae, out_down.pop()], dim=1), time)
        out_vae = self.UpBlock2.transformer2(out_vae, out_encoder)

        out_vae = self.UpBlock1.resnet3(torch.cat([out_vae, out_down.pop()], dim=1), time)
        out_vae = self.UpBlock2.transformer3(out_vae, out_encoder)

        if self.upsample:
            out_vae = self.upsample(out_vae)
        return out_vae
UpBlock(320, 640, 1280, True)(torch.randn(1, 1280, 32, 32),
                                    torch.randn(1, 77, 768),
                                    torch.randn(1, 1280),
                                    [torch.randn(1, 320, 32, 32), 
                                    torch.randn(1, 640, 32, 32),
                                    torch.randn(1, 640, 32, 32)]).shape 


torch.Size([1, 640, 64, 64])

In [34]:
class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()

        #  input layer
        self.in_vae = nn.Conv2d(4, 320, kernel_size=3, padding=1)
        self.in_time = nn.Sequential(
            nn.Linear(320, 1280),
            nn.SiLU(),
            nn.Linear(1280, 1280),
        )

        #  Down_sampling blocks
        self.down_blocks = nn.Sequential(OrderedDict([
            ('down_block1', DownBlock(320, 320)),
            ('down_block2', DownBlock(320 ,640)),
            ('down_block3', DownBlock(640, 1280)),
        ]))
        self.down_residuals = nn.Sequential(OrderedDict([
            ('down_resnet1', Resnet(1280, 1280)),
            ('down_resnet2', Resnet(1280, 1280)),
        ]))

        #  Mid layers
        self.mid_layers = nn.Sequential(OrderedDict([
            ('mid_resnet1', Resnet(1280, 1280)),
            ('mid_transformer', Transformer(1280)),
            ('mid_resnet2', Resnet(1280, 1280)),
        ]))

        #  Upsampling layers
        self.up_residuals = nn.Sequential(OrderedDict([
            ('up_resnet1', Resnet(2560, 1280)),
            ('up_resnet2', Resnet(2560, 1280)),
            ('up_resnet3', Resnet(2560, 1280)),
        ]))
        self.up_in = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(1280, 1280, kernel_size=3, padding=1),
        )

        #  Upsampling blocks
        self.up_blocks = nn.Sequential(OrderedDict([
            ('up_block1', UpBlock(640, 1280, 1280, True)),
            ('up_block2', UpBlock(320, 640, 1280, True)),
            ('up_block3', UpBlock(320, 320, 640, False)),
        ]))

        #  Output layer
        self.out = nn.Sequential(
            nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5),
            nn.SiLU(),
            nn.Conv2d(320, 4, kernel_size=3, padding=1),
        )

    def get_time_embed(self, t):
        e = torch.arange(160) * -9.210340371976184 / 160    #   -9.210340371976184 = -math.log(10000)
        e = e.exp().to(t.device) * t
        e = torch.cat([e.cos(), e.sin()]).unsqueeze(dim=0)
        return e
    
    def forward(self, out_vae, out_encoder, time):
        #  Input processing
        out_vae = self.in_vae(out_vae)
        time = self.get_time_embed(time)
        time = self.in_time(time)

        #  Downsampling
        out_down = [out_vae]
        for down_block in self.down_blocks:
            out_vae, out = down_block(out_vae=out_vae, out_encoder=out_encoder, time=time)
            out_down.extend(out)
        
        for down_residual in self.down_residuals:
            out_vae = down_residual(out_vae, time)
            out_down.append(out_vae)
        
        #  Mid layers
        for layer in self.mid_layers:
            if isinstance(layer, Transformer):
                out_vae = layer(out_vae, out_encoder)
            else:
                out_vae = layer(out_vae, time)
        
        #  Upsampling
        for up_residual in self.up_residuals:
            out_vae = up_residual(torch.cat([out_vae, out_down.pop()], dim=1), time)
        
        out_vae = self.up_in(out_vae)

        for up_block in self.up_blocks:
            out_vae = up_block(out_vae=out_vae, out_encoder=out_encoder, time=time, out_down=out_down)

        #  Output layer
        out_vae = self.out(out_vae)
        return out_vae

Unet()(torch.randn(2, 4, 64, 64), torch.randn(2, 77, 768), torch.LongTensor([26])).shape


torch.Size([2, 4, 64, 64])

In [35]:
import torch

from torch import nn
from transformers import CLIPTextModel
from collections import OrderedDict
from diffusers import AutoencoderKL, UNet2DConditionModel

pretrain_model_path = '/home/junlei/.cache/huggingface/lansinuote/diffsion_from_scratch.params'
Unet_params = UNet2DConditionModel.from_pretrained(pretrain_model_path, subfolder='unet')


def load_transformer(model, param):

    #  load input block
    model.input_block.norm_in.load_state_dict(param.norm.state_dict())
    model.input_block.cnn_in.load_state_dict(param.proj_in.state_dict())
    # print(model.input_block.cnn_in.state_dict(), param.proj_in.state_dict())
    # return

    #  load attention_block

    for (name, proj) in zip(['q', 'k', 'v', 'out'], [param.transformer_blocks[0].attn1.to_q, param.transformer_blocks[0].attn1.to_k, param.transformer_blocks[0].attn1.to_v, param.transformer_blocks[0].attn1.to_out[0]]):
        getattr(model.attention_block.cross_atten1, name).load_state_dict(proj.state_dict())
        # print(proj.state_dict(), "#######\n")

    for (name, proj) in zip(['q', 'k', 'v', 'out'], [param.transformer_blocks[0].attn2.to_q, param.transformer_blocks[0].attn2.to_k, param.transformer_blocks[0].attn2.to_v, param.transformer_blocks[0].attn2.to_out[0]]):
        getattr(model.attention_block.cross_atten2, name).load_state_dict(proj.state_dict())
        # print(proj.state_dict(), "#######\n")
    
    model.activation_block.fc0.load_state_dict(param.transformer_blocks[0].ff.net[0].proj.state_dict())
    model.activation_block.fc1.load_state_dict(param.transformer_blocks[0].ff.net[2].state_dict())

    model.attention_block.norm_atten0.load_state_dict(param.transformer_blocks[0].norm1.state_dict())
    model.attention_block.norm_atten1.load_state_dict(param.transformer_blocks[0].norm2.state_dict())

    #  load activate block
    model.activation_block.norm_act.load_state_dict(param.transformer_blocks[0].norm3.state_dict())

    #  load output block
    model.output_block.cnn_out.load_state_dict(param.proj_out.state_dict())


def load_unet_resnet(model, params):
    model.time_embedding[1].load_state_dict(params.time_emb_proj.state_dict())

    model.conv_block1[0].load_state_dict(params.norm1.state_dict())
    model.conv_block1[2].load_state_dict(params.conv1.state_dict())

    model.conv_block2[0].load_state_dict(params.norm2.state_dict())
    model.conv_block2[2].load_state_dict(params.conv2.state_dict())
    
    if isinstance(model.residual, torch.nn.Module):
        model.residual.load_state_dict(params.conv_shortcut.state_dict())

def load_unet_down_block(model, params):
    load_transformer(model.DownBlock1.transformer1, params.attentions[0])
    return
    load_transformer(model.DownBlock2.transformer2, params.attentions[1])

    load_unet_resnet(model.DownBlock1.resnet1, params.resnets[0])
    load_unet_resnet(model.DownBlock2.resnet2, params.resnets[1])

    model.downsample.load_state_dict(params.downsamplers[0].conv.state_dict())


def load_unet_up_block(model, params):
    load_transformer(model.UpBlock2.transformer1, params.attentions[0])
    load_transformer(model.UpBlock2.transformer2, params.attentions[1])
    load_transformer(model.UpBlock2.transformer3, params.attentions[2])

    load_unet_resnet(model.UpBlock1.resnet1, params.resnets[0])
    load_unet_resnet(model.UpBlock1.resnet2, params.resnets[1])
    load_unet_resnet(model.UpBlock1.resnet3, params.resnets[2])

    if isinstance(model.upsample, torch.nn.Module):
        model.upsample.conv.load_state_dict(params.upsamplers[0].conv.state_dict())

unet = Unet()

unet.in_vae.load_state_dict(Unet_params.conv_in.state_dict())
unet.in_time[0].load_state_dict(Unet_params.time_embedding.linear_1.state_dict())
unet.in_time[2].load_state_dict(Unet_params.time_embedding.linear_2.state_dict())

#  load down block layer
load_unet_down_block(unet.down_blocks.down_block1, Unet_params.down_blocks[0])
print(unet.down_blocks.down_block1.state_dict(), "\n###########\n", Unet_params.down_blocks[0].state_dict())
# load_unet_down_block(unet.down_blocks.down_block2, Unet_params.down_blocks[1])
# load_unet_down_block(unet.down_blocks.down_block3, Unet_params.down_blocks[2])

# load_unet_resnet(unet.down_residuals.down_resnet1, Unet_params.down_blocks[3].resnets[0])
# load_unet_resnet(unet.down_residuals.down_resnet2, Unet_params.down_blocks[3].resnets[1])

# #  load mid block layer
# load_transformer(unet.mid_layers.mid_transformer, Unet_params.mid_block.attentions[0])
# load_unet_resnet(unet.mid_layers.mid_resnet1, Unet_params.mid_block.resnets[0])
# load_unet_resnet(unet.mid_layers.mid_resnet2, Unet_params.mid_block.resnets[1])

# #  load up block layer
# load_unet_resnet(unet.up_residuals.up_resnet1, Unet_params.up_blocks[0].resnets[0])
# load_unet_resnet(unet.up_residuals.up_resnet2, Unet_params.up_blocks[0].resnets[1])
# load_unet_resnet(unet.up_residuals.up_resnet3, Unet_params.up_blocks[0].resnets[2])
# unet.up_in[1].load_state_dict(Unet_params.up_blocks[0].upsamplers[0].conv.state_dict())

# load_unet_up_block(unet.up_blocks.up_block1, Unet_params.up_blocks[1])
# load_unet_up_block(unet.up_blocks.up_block2, Unet_params.up_blocks[2])
# load_unet_up_block(unet.up_blocks.up_block3, Unet_params.up_blocks[3])

# #  load output layer
# unet.out[0].load_state_dict(Unet_params.conv_norm_out.state_dict())
# unet.out[2].load_state_dict(Unet_params.conv_out.state_dict())

# out_vae = torch.randn(1, 4, 64, 64)
# out_encoder = torch.randn(1, 77, 768)
# time = torch.LongTensor([26])

# a = unet(out_vae=out_vae, out_encoder=out_encoder, time=time)
# b = Unet_params(out_vae, time, out_encoder).sample

# (a == b).all()

OrderedDict([('DownBlock1.transformer1.input_block.norm_in.weight', tensor([0.3490, 0.2030, 0.2801, 0.4251, 0.4367, 0.3979, 0.5001, 0.4106, 0.4244,
        0.4812, 0.4026, 0.3334, 0.4434, 0.4175, 0.3872, 0.2145, 0.3776, 0.4181,
        0.4261, 0.2849, 0.3011, 0.2483, 0.2570, 0.2685, 0.2780, 0.2617, 0.2677,
        0.2888, 0.2755, 0.2995, 0.3901, 0.4083, 0.3994, 0.3687, 0.4601, 0.4188,
        0.2570, 0.4555, 0.4450, 0.4097, 0.3300, 0.2953, 0.3260, 0.2577, 0.3079,
        0.2378, 0.3315, 0.3126, 0.2107, 0.2976, 0.2905, 0.2212, 0.4036, 0.3347,
        0.2961, 0.4213, 0.3420, 0.4301, 0.5623, 0.3347, 0.2853, 0.2879, 0.2788,
        0.2691, 0.2539, 0.2760, 0.2594, 0.2581, 0.2488, 0.2640, 0.4241, 0.4493,
        0.2774, 0.2226, 0.3849, 0.3289, 0.4136, 0.5445, 0.2573, 0.3671, 0.2960,
        0.2682, 0.3079, 0.3148, 0.2531, 0.2308, 0.2861, 0.2718, 0.2787, 0.2440,
        0.4583, 0.4646, 0.5300, 0.3048, 0.2136, 0.2923, 0.5164, 0.2869, 0.5453,
        0.4225, 0.1585, 0.3478, 0.3200, 0.3825, 0.36