In [1]:
import torch
from transformer import Transformer2D
from transformers import PreTrainedModel, PretrainedConfig


class Resnet(torch.nn.Module):
    """Residual convolution block conditioned on a timestep embedding.

    The input goes through GroupNorm -> SiLU -> 3x3 conv, the projected
    timestep embedding is added per channel, and a second
    GroupNorm -> SiLU -> 3x3 conv follows. The result is added to a skip
    path, which is the identity when ``dim_in == dim_out`` and a 1x1
    convolution otherwise.

    Args:
        dim_in: Number of input channels (must be divisible by 32 because
            of the 32-group GroupNorm).
        dim_out: Number of output channels (must be divisible by 32).
        dim_timestep: Width of the timestep embedding fed to ``forward``.
            Defaults to 1280, the value used everywhere in this file.
    """

    def __init__(self, dim_in, dim_out, dim_timestep=1280):
        super().__init__()

        self.s_in = torch.nn.Sequential(
            torch.nn.GroupNorm(32, dim_in, 1e-5, True), torch.nn.SiLU(),
            torch.nn.Conv2d(dim_in, dim_out, 3, 1, 1))

        self.s_timestep = torch.nn.Sequential(
            torch.nn.SiLU(), torch.nn.Linear(dim_timestep, dim_out))

        self.s_out = torch.nn.Sequential(
            torch.nn.GroupNorm(32, dim_out, 1e-5, True), torch.nn.SiLU(),
            torch.nn.Conv2d(dim_out, dim_out, 3, 1, 1))

        # The skip path only needs a projection when the channel count changes.
        self.res = None
        if dim_in != dim_out:
            self.res = torch.nn.Conv2d(dim_in, dim_out, 1, 1, 0)

    def forward(self, x, timestep):
        """Apply the block.

        Args:
            x: Feature map of shape [b, dim_in, h, w].
            timestep: Timestep embedding of shape [b, dim_timestep]
                (a leading dimension of 1 broadcasts over the batch).

        Returns:
            Tensor of shape [b, dim_out, h, w].
        """
        # Explicit None check: module truthiness is not what is meant here.
        #[b, dim_in, h, w] -> [b, dim_out, h, w]
        res = x if self.res is None else self.res(x)

        #[b, dim_in, h, w] -> [b, dim_out, h, w]
        x = self.s_in(x)

        #[b, dim_timestep] -> [b, dim_out, 1, 1] so it broadcasts over h, w
        timestep = self.s_timestep(timestep).unflatten(1, (-1, 1, 1))

        #[b, dim_out, h, w] + [b, dim_out, 1, 1] -> [b, dim_out, h, w]
        x = self.s_out(x + timestep)

        #[b, dim_out, h, w]
        return res + x


class Upsampler(torch.nn.Module):
    """Nearest-neighbour upsampling followed by a 3x3 convolution.

    Args:
        dim: Number of channels; the convolution keeps it unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x, size=None):
        """Upsample ``x`` and convolve the result.

        Args:
            x: Feature map of shape [b, c, h, w].
            size: Optional target spatial size. When it is falsy (e.g.
                ``None``) the spatial dimensions are doubled instead.

        Returns:
            Tensor of shape [b, c, *size], or [b, c, h*2, w*2] when no
            size is given, in the dtype of ``x``.
        """
        # An explicit target size wins; otherwise fall back to a 2x scale.
        resize_args = {'size': size} if size else {'scale_factor': 2.0}

        # The interpolation itself always runs in float32 and the result is
        # cast back to the caller's dtype before the convolution.
        upsampled = torch.nn.functional.interpolate(x.to(torch.float32),
                                                    mode='nearest',
                                                    **resize_args)

        #[b, c, h*2, w*2] (or [b, c, *size])
        return self.conv(upsampled.to(x.dtype))


class Mid(torch.nn.Module):
    """Middle stage of the network: Resnet -> Transformer2D -> Resnet.

    All three sub-modules work at 1280 channels, so the feature map keeps
    its shape from input to output.
    """

    def __init__(self):
        super().__init__()

        width = 1280
        self.transformer = Transformer2D(heads=20, dim=width)
        self.resnet1 = Resnet(dim_in=width, dim_out=width)
        self.resnet2 = Resnet(dim_in=width, dim_out=width)

    def forward(self, q, kv, timestep):
        """Run the middle stage.

        Args:
            q: Feature map, e.g. [1, 1280, 8, 8].
            kv: Conditioning sequence handed to the transformer as ``kv``,
                e.g. [1, 77, 1024].
            timestep: Timestep embedding, e.g. [1, 1280].

        Returns:
            Feature map with the same shape as ``q``.
        """
        features = self.resnet1(q, timestep)
        features = self.transformer(features, kv=kv)
        return self.resnet2(features, timestep)


class TransformerDown(torch.nn.Module):
    """Encoder stage: two (Resnet, Transformer2D) pairs, then a stride-2 conv.

    Args:
        dim_in: Channels of the incoming feature map.
        dim_out: Channels produced by the stage.
        heads: Number of attention heads for both transformers.
    """

    def __init__(self, dim_in, dim_out, heads):
        super().__init__()

        self.resnet1 = Resnet(dim_in=dim_in, dim_out=dim_out)
        self.resnet2 = Resnet(dim_in=dim_out, dim_out=dim_out)

        self.transformer1 = Transformer2D(heads=heads, dim=dim_out)
        self.transformer2 = Transformer2D(heads=heads, dim=dim_out)

        # 3x3 convolution with stride 2 and padding 1.
        self.downsampler = torch.nn.Conv2d(dim_out, dim_out, 3, 2, 1)

    def forward(self, q, kv, timestep):
        """Run the stage and collect the skip features.

        Args:
            q: Feature map of shape [b, dim_in, h, w].
            kv: Conditioning sequence handed to the transformers as ``kv``.
            timestep: Timestep embedding, e.g. [1, 1280].

        Returns:
            A tuple ``(q, hidden)``: ``q`` is the downsampled feature map
            and ``hidden`` lists the output of each of the two
            resnet/transformer pairs followed by the downsampled map
            (three tensors, the last being ``q`` itself).
        """
        stages = (
            (self.resnet1, self.transformer1),
            (self.resnet2, self.transformer2),
        )

        hidden = []
        for resnet, transformer in stages:
            #[b, dim_in or dim_out, h, w] -> [b, dim_out, h, w]
            q = transformer(resnet(q, timestep), kv=kv)
            hidden.append(q)

        # Spatial size is halved by the stride-2 convolution.
        q = self.downsampler(q)
        hidden.append(q)

        return q, hidden


class Down(torch.nn.Module):
    """Last encoder stage: two 1280-channel Resnets, no attention, no downsampling."""

    def __init__(self):
        super().__init__()

        self.resnet1 = Resnet(dim_in=1280, dim_out=1280)
        self.resnet2 = Resnet(dim_in=1280, dim_out=1280)

    def forward(self, q, timestep, **kwargs):
        """Run both Resnets and report each intermediate result.

        Args:
            q: Feature map, e.g. [1, 1280, 8, 8].
            timestep: Timestep embedding, e.g. [1, 1280].
            **kwargs: Accepted and ignored, so this stage can be called
                with the same keywords (such as ``kv``) as
                ``TransformerDown``.

        Returns:
            A tuple ``(q, hidden)`` where ``hidden`` holds the outputs of
            the first and second Resnet and ``q`` is the second one.
        """
        first = self.resnet1(q, timestep)
        second = self.resnet2(first, timestep)
        return second, [first, second]


class Up(torch.nn.Module):
    """First decoder stage: three skip-fed Resnets, then an Upsampler.

    Each Resnet takes 2560 channels (the running feature map concatenated
    with one skip tensor) and produces 1280.
    """

    def __init__(self):
        super().__init__()

        self.resnet = torch.nn.ModuleList([
            Resnet(dim_in=2560, dim_out=1280),
            Resnet(dim_in=2560, dim_out=1280),
            Resnet(dim_in=2560, dim_out=1280)
        ])

        self.upsampler = Upsampler(1280)

    def forward(self, q, hidden, timestep, size, **kwargs):
        """Merge the skip features and upsample.

        Args:
            q: Feature map, e.g. [1, 1280, 8, 8].
            hidden: List of three skip tensors shaped like ``q``. They are
                consumed from the end, so the list is emptied in place.
            timestep: Timestep embedding, e.g. [1, 1280].
            size: Target spatial size forwarded to the Upsampler.
            **kwargs: Accepted and ignored (e.g. ``kv``).

        Returns:
            The upsampled 1280-channel feature map.
        """
        for block in self.resnet:
            skip = hidden.pop()
            #[b, 1280+1280, h, w] -> [b, 1280, h, w]
            q = block(torch.cat([q, skip], dim=1), timestep)

        #[b, 1280, h, w] -> [b, 1280, *size]
        return self.upsampler(q, size)


class TransformerUp(torch.nn.Module):
    """Decoder stage: three (Resnet, Transformer2D) pairs and an optional Upsampler.

    Args:
        dim_in: Channels of the skip tensor consumed by the third Resnet.
        dim_out: Channels produced by every Resnet/transformer of the stage.
        dim_hidden: Channels of the feature map entering the stage.
        heads: Number of attention heads for each transformer.
        add_upsample: Whether to finish the stage with an Upsampler.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, heads, add_upsample):
        super().__init__()

        self.transformer = torch.nn.ModuleList([
            Transformer2D(heads=heads, dim=dim_out),
            Transformer2D(heads=heads, dim=dim_out),
            Transformer2D(heads=heads, dim=dim_out)
        ])

        # Input widths are (running feature map) + (skip tensor) per step.
        self.resnet = torch.nn.ModuleList([
            Resnet(dim_in=dim_hidden + dim_out, dim_out=dim_out),
            Resnet(dim_in=dim_out + dim_out, dim_out=dim_out),
            Resnet(dim_in=dim_out + dim_in, dim_out=dim_out)
        ])

        self.upsampler = None
        if add_upsample:
            self.upsampler = Upsampler(dim_out)

    def forward(self, q, hidden, timestep, kv, size=None):
        """Merge the skip features, attend to ``kv`` and optionally upsample.

        Args:
            q: Feature map of shape [b, dim_hidden, h, w].
            hidden: List of three skip tensors with dim_in, dim_out and
                dim_out channels, in that order. They are consumed from
                the end, so the list is emptied in place.
            timestep: Timestep embedding, e.g. [1, 1280].
            kv: Conditioning sequence handed to the transformers as ``kv``.
            size: Optional target spatial size for the Upsampler.

        Returns:
            Feature map with dim_out channels, upsampled when the stage
            owns an Upsampler.
        """
        for resnet, transformer in zip(self.resnet, self.transformer):
            skip = hidden.pop()
            #[b, c_q + c_skip, h, w] -> [b, dim_out, h, w]
            merged = resnet(torch.cat([q, skip], dim=1), timestep)
            #[b, dim_out, h, w]
            q = transformer(merged, kv=kv)

        if self.upsampler is None:
            return q

        #[b, dim_out, h, w] -> [b, dim_out, h*2, w*2] (or *size)
        return self.upsampler(q, size)


class TimestepEmbedding(torch.nn.Module):
    """Turn diffusion timesteps into 1280-wide embeddings.

    Each timestep is multiplied by 160 fixed frequencies
    ``exp(-ln(10000) * i / 160)``, the cosines and sines of the products
    are concatenated into a 320-wide vector, and that vector goes through
    Linear(320, 1280) -> SiLU -> Linear(1280, 1280).
    """

    def __init__(self):
        super().__init__()
        self.s = torch.nn.Sequential(torch.nn.Linear(320, 1280),
                                     torch.nn.SiLU(),
                                     torch.nn.Linear(1280, 1280))

        import math
        embedding = torch.arange(160, dtype=torch.float32)
        embedding = (embedding * -math.log(10000) / 160).exp()
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('embedding', embedding)

    def forward(self, timestep, dtype):
        """Embed one or more timesteps.

        Args:
            timestep: A Python number or a tensor holding one timestep
                (scalar or shape [1]) or a batch of them (shape [b] or
                [b, 1]).
            dtype: dtype the sinusoidal features are cast to before the
                linear layers.

        Returns:
            Tensor of shape [b, 1280], one row per timestep ([1, 1280]
            for a single timestep).
        """
        # Put every timestep on its own row *before* multiplying. Reshaping
        # after the multiply (as a single [1, -1] row) only works for a
        # batch of exactly one timestep.
        #[b] -> [b, 1]
        timestep = torch.as_tensor(timestep).to(self.embedding.device)
        timestep = timestep.reshape(-1, 1)

        #[b, 1] * [160] -> [b, 160]
        timestep = timestep * self.embedding

        #[b, 160] -> [b, 320]
        timestep = torch.cat([timestep.cos(), timestep.sin()],
                             dim=1).to(dtype=dtype)

        #[b, 320] -> [b, 1280]
        return self.s(timestep)


class UNet(PreTrainedModel):
    """Conditional U-Net built from the down, mid and up stages of this file.

    The encoder has three ``TransformerDown`` stages and one ``Down``
    stage, followed by ``Mid``, then one ``Up`` stage and three
    ``TransformerUp`` stages. Every encoder output is kept as a skip
    feature and consumed by the decoder, three tensors per decoder stage.
    """
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)

        self.timestep_embedding = TimestepEmbedding()

        # (dim_in, dim_out, heads) of each attention-equipped encoder stage.
        down_specs = ((320, 320, 5), (320, 640, 10), (640, 1280, 20))
        self.down = torch.nn.ModuleList([
            TransformerDown(dim_in=dim_in, dim_out=dim_out, heads=heads)
            for dim_in, dim_out, heads in down_specs
        ] + [Down()])

        self.mid = Mid()

        # (dim_in, dim_out, dim_hidden, heads, add_upsample) of each
        # attention-equipped decoder stage.
        up_specs = (
            (640, 1280, 1280, 20, True),
            (320, 640, 1280, 10, True),
            (320, 320, 640, 5, False),
        )
        self.up = torch.nn.ModuleList([Up()] + [
            TransformerUp(dim_in=dim_in,
                          dim_out=dim_out,
                          dim_hidden=dim_hidden,
                          heads=heads,
                          add_upsample=add_upsample)
            for dim_in, dim_out, dim_hidden, heads, add_upsample in up_specs
        ])

        self.s_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.s_out = torch.nn.Sequential(
            torch.nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5),
            torch.nn.SiLU(), torch.nn.Conv2d(320, 4, kernel_size=3, padding=1))

    def forward(self,
                q,
                kv,
                timestep,
                controlnet_down=None,
                controlnet_mid=None):
        """Run the full network.

        Args:
            q: 4-channel input, e.g. [1, 4, 64, 64].
            kv: Conditioning sequence passed to every attention stage,
                e.g. [1, 77, 1024].
            timestep: Timestep, e.g. shape [1].
            controlnet_down: Optional indexable of residuals; entry ``i``
                is added to the ``i``-th skip feature (the stem output
                followed by every encoder ``hidden`` tensor, in order).
            controlnet_mid: Optional residual added to the output of the
                middle stage.

        Returns:
            4-channel output, e.g. [1, 4, 64, 64].
        """
        #[1] -> [1, 1280]
        timestep = self.timestep_embedding(timestep, q.dtype)

        #[1, 4, 64, 64] -> [1, 320, 64, 64]
        q = self.s_in(q)

        # Skip features, oldest first; the decoder consumes them from the end.
        skips = [q]

        #[1, 320, 64, 64] -> [1, 320, 32, 32]
        #[1, 320, 32, 32] -> [1, 640, 16, 16]
        #[1, 640, 16, 16] -> [1, 1280, 8, 8]
        #[1, 1280, 8, 8] -> [1, 1280, 8, 8]
        for stage in self.down:
            q, hidden = stage(q=q, kv=kv, timestep=timestep)
            skips.extend(hidden)

        if controlnet_down is not None:
            skips = [
                skip + controlnet_down[idx] for idx, skip in enumerate(skips)
            ]

        #[1, 1280, 8, 8]
        q = self.mid(q=q, kv=kv, timestep=timestep)

        if controlnet_mid is not None:
            q = q + controlnet_mid

        #[1, 1280, 8, 8] -> [1, 1280, 16, 16]
        #[1, 1280, 16, 16] -> [1, 1280, 32, 32]
        #[1, 1280, 32, 32] -> [1, 640, 64, 64]
        #[1, 640, 64, 64] -> [1, 320, 64, 64]
        for stage in self.up:
            # Take the three most recent skips, keeping their original order.
            hidden = skips[-3:]
            del skips[-3:]

            # Upsample to the spatial size of the next skip still waiting,
            # if there is one.
            size = skips[-1].shape[2:] if skips else None

            q = stage(q=q,
                      hidden=hidden,
                      kv=kv,
                      timestep=timestep,
                      size=size)

        #[1, 320, 64, 64] -> [1, 4, 64, 64]
        return self.s_out(q)

config.json:   0%|          | 0.00/104 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.46G [00:00<?, ?B/s]