In [1]:
import copy

import torch
from transformers import PreTrainedModel, PretrainedConfig


class ControlNet(PreTrainedModel):
    """ControlNet branch built from copies of an existing UNet's encoder half.

    The UNet's ``timestep_embedding``, ``s_in``, ``down`` and ``mid``
    submodules are deep-copied, so this branch has its own weights and the
    source UNet is left untouched. A small convolutional network embeds the
    3-channel conditioning input and adds it to the stem output. Every skip
    feature and the mid-block output are then passed through a 1x1
    convolution whose weights and biases start at zero. As a result, the
    branch initially emits all-zero residuals.
    """

    config_class = PretrainedConfig

    def __init__(self, config, unet):
        """Build the ControlNet branch.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel``.
            unet: Source UNet. It must expose ``timestep_embedding``,
                ``s_in``, ``down`` (iterable of blocks) and ``mid``; those
                submodules are deep-copied, not shared.
        """
        super().__init__(config)

        # Independent copies of the UNet's encoder half. ``copy.deepcopy``
        # replaces a torch.save/torch.load round-trip through the hard-coded
        # path 'model/temp'. That approach failed when the directory was
        # missing, left a stray file behind, raced under concurrent
        # construction, and is rejected by ``torch.load(weights_only=True)``.
        self.timestep_embedding = copy.deepcopy(unet.timestep_embedding)
        self.unet_s_in = copy.deepcopy(unet.s_in)
        self.unet_down = copy.deepcopy(unet.down)
        self.unet_mid = copy.deepcopy(unet.mid)

        # Conditioning encoder: 3 input channels -> 320 output channels.
        # The three stride-2 convolutions reduce spatial size by ~8x.
        self.embedding = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, 3, 1, 1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(16, 16, 3, 1, 1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(16, 32, 3, 2, 1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 32, 3, 1, 1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 96, 3, 2, 1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(96, 96, 3, 1, 1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(96, 256, 3, 2, 1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 320, 3, 1, 1),
        )

        # One 1x1 projection per collected down feature, in collection order.
        # The channel counts must match what the copied UNet emits.
        self.controlnet_down = torch.nn.ModuleList([
            torch.nn.Conv2d(320, 320, 1),
            torch.nn.Conv2d(320, 320, 1),
            torch.nn.Conv2d(320, 320, 1),
            torch.nn.Conv2d(320, 320, 1),
            torch.nn.Conv2d(640, 640, 1),
            torch.nn.Conv2d(640, 640, 1),
            torch.nn.Conv2d(640, 640, 1),
            torch.nn.Conv2d(1280, 1280, 1),
            torch.nn.Conv2d(1280, 1280, 1),
            torch.nn.Conv2d(1280, 1280, 1),
            torch.nn.Conv2d(1280, 1280, 1),
            torch.nn.Conv2d(1280, 1280, 1),
        ])

        self.controlnet_mid = torch.nn.Conv2d(1280, 1280, 1)

        # Zero-initialise every projection weight and bias so the branch
        # outputs zeros until it is trained.
        for projection in (self.controlnet_down, self.controlnet_mid):
            for param in projection.parameters():
                torch.nn.init.zeros_(param)

    def forward(self, q, kv, timestep, controlnet_cond):
        """Run the ControlNet branch.

        Args:
            q: Input fed to the copied UNet stem (``unet_s_in``). It is
                presumably the noisy latent; the spatial crop below assumes
                4-D NCHW. TODO confirm.
            kv: Passed unchanged to every down block and to the mid block.
            timestep: Embedded via ``timestep_embedding(timestep, q.dtype)``.
            controlnet_cond: 3-channel conditioning tensor, presumably an
                image at ~8x the latent resolution.

        Returns:
            A tuple ``(down_residuals, mid_residual)``:
            ``down_residuals`` is a list with one projected tensor per
            ``self.controlnet_down`` entry; ``mid_residual`` is the projected
            mid-block output.

        Raises:
            ValueError: if the stem plus down blocks do not yield exactly
                ``len(self.controlnet_down)`` features.
        """
        # Use the dtype of the incoming ``q``, before the stem rebinds it.
        timestep = self.timestep_embedding(timestep, q.dtype)

        q = self.unet_s_in(q)

        # Crop the conditioning features to the stem output's spatial size,
        # presumably to absorb off-by-one sizes from the strided convs.
        cond = self.embedding(controlnet_cond)
        cond = cond[:, :, :q.shape[2], :q.shape[3]]
        q = q + cond

        # Collect the stem output plus every skip feature from the down blocks.
        features = [q]
        for block in self.unet_down:
            q, skips = block(q=q, kv=kv, timestep=timestep)
            features.extend(skips)

        if len(features) != len(self.controlnet_down):
            raise ValueError(
                f"expected {len(self.controlnet_down)} down features from "
                f"the UNet, got {len(features)}"
            )

        q = self.unet_mid(q, kv=kv, timestep=timestep)

        controlnet_down = [
            projection(feature)
            for projection, feature in zip(self.controlnet_down, features)
        ]
        controlnet_mid = self.controlnet_mid(q)

        return controlnet_down, controlnet_mid