# In [1]:
import torch
from transformers import PreTrainedModel, PretrainedConfig


class Attention(torch.nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, every spatial position is treated as one
    token, and the attended result is added back onto the untouched input
    (residual connection).

    Args:
        dim: Channel count of the input; also the width of the q/k/v/out
            projections. The normalisation uses 32 groups, so ``dim`` has to
            be a multiple of 32. Defaults to 512, the previously hard-coded
            width, so ``Attention()`` builds the same module as before.
    """

    def __init__(self, dim: int = 512):
        super().__init__()
        # Softmax temperature: 1 / sqrt(channel width).
        self.scale = dim**-0.5
        self.norm = torch.nn.GroupNorm(32, dim, 1e-6, True)
        self.q = torch.nn.Linear(dim, dim)
        self.k = torch.nn.Linear(dim, dim)
        self.v = torch.nn.Linear(dim, dim)
        self.out = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and return a tensor shaped like ``x``.

        Args:
            x: Tensor with ``dim`` channels on axis 1 and the positions to
                attend over on the axes from index 2 onwards, e.g.
                ``(batch, dim, height, width)``.

        Returns:
            ``x`` plus the attention output, in the original shape of ``x``.
        """
        res = x
        shape = x.shape

        # (batch, dim, *spatial) -> (batch, tokens, dim)
        x = self.norm(x.flatten(start_dim=2)).transpose(1, 2)
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        # With beta=0 baddbmm ignores the contents of its first argument, so
        # an uninitialised buffer (broadcast over the batch) is sufficient;
        # the call reduces to ``scale * (q @ k^T)``.
        atten = torch.empty(1,
                            x.shape[1],
                            x.shape[1],
                            dtype=q.dtype,
                            device=q.device)
        atten = torch.baddbmm(atten,
                              q,
                              k.transpose(1, 2),
                              beta=0,
                              alpha=self.scale)

        # Softmax is evaluated in float32 and cast back to the working dtype.
        atten = atten.float().softmax(dim=-1).to(q.dtype)
        atten = atten.bmm(v)
        atten = self.out(atten)
        # (batch, tokens, dim) -> original layout
        atten = atten.transpose(1, 2).reshape(shape)

        return atten + res


class Resnet(torch.nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip.

    Args:
        dim_in: Number of input channels (must be a multiple of 32, the
            group count used by the normalisation layers).
        dim_out: Number of output channels (same divisibility requirement).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        self.s = torch.nn.Sequential(
            torch.nn.GroupNorm(32, dim_in, 1e-6, True),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_in, dim_out, 3, 1, 1),
            torch.nn.GroupNorm(32, dim_out, 1e-6, True),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_out, dim_out, 3, 1, 1),
        )

        # A 1x1 projection on the skip path is only needed when the channel
        # count changes; otherwise the input is added back unchanged.
        self.res = None
        if dim_in != dim_out:
            self.res = torch.nn.Conv2d(dim_in, dim_out, 1, 1, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``skip(x) + s(x)``; spatial size is preserved."""
        res = self.s(x)

        # Explicit None check: module truthiness is not a reliable signal
        # (modules that define ``__len__`` can evaluate as False).
        if self.res is not None:
            x = self.res(x)

        return x + res


def Mid():
    """Build the 512-channel middle stage: Resnet, Attention, Resnet."""
    stages = [Resnet(512, 512), Attention(), Resnet(512, 512)]
    return torch.nn.Sequential(*stages)


class Down(torch.nn.Module):
    """Encoder stage: two Resnet blocks, optionally followed by a downsample.

    Args:
        dim_in: Channels entering the stage.
        dim_out: Channels leaving the stage.
        downsample: When truthy, append a stride-2 3x3 convolution.
    """

    def __init__(self, dim_in, dim_out, downsample):
        super().__init__()
        self.s = torch.nn.Sequential(Resnet(dim_in, dim_out),
                                     Resnet(dim_out, dim_out))

        self.downsample = None
        if downsample:
            self.downsample = torch.nn.Conv2d(dim_out, dim_out, 3, 2, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the Resnet blocks and, if configured, the downsample."""
        x = self.s(x)

        # Explicit None check rather than relying on module truthiness.
        if self.downsample is not None:
            # The convolution itself has no padding; instead one zero column
            # is appended on the right and one zero row at the bottom before
            # the stride-2 convolution.
            x = torch.nn.functional.pad(x, (0, 1, 0, 1),
                                        mode='constant',
                                        value=0)
            x = self.downsample(x)

        return x


class Up(torch.nn.Module):
    """Decoder stage: three Resnet blocks, optionally followed by an upsample.

    Args:
        dim_in: Channels entering the stage.
        dim_out: Channels leaving the stage.
        upsample: When truthy, append a 2x nearest-neighbour upsample that
            is followed by a 3x3 convolution.
    """

    def __init__(self, dim_in, dim_out, upsample):
        super().__init__()

        self.s = torch.nn.Sequential(Resnet(dim_in, dim_out),
                                     Resnet(dim_out, dim_out),
                                     Resnet(dim_out, dim_out))

        self.upsample = None
        if upsample:
            self.upsample = torch.nn.Conv2d(dim_out, dim_out, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the Resnet blocks and, if configured, the upsample."""
        x = self.s(x)

        # Explicit None check rather than relying on module truthiness.
        if self.upsample is not None:
            # Double the spatial size first, then apply the convolution
            # stored in ``self.upsample`` to the enlarged map.
            x = torch.nn.functional.interpolate(x,
                                                scale_factor=2.0,
                                                mode='nearest')
            x = self.upsample(x)

        return x


class VAE(PreTrainedModel):
    """Convolutional autoencoder wrapped as a ``PreTrainedModel``.

    The encoder takes 3-channel input, halves the spatial size three times
    and emits 8 channels; the decoder takes 4 channels, doubles the spatial
    size three times and emits 3 channels.
    """
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        nn = torch.nn

        # Layer order fixes the Sequential indices (and thus parameter names).
        encoder_layers = [
            nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1),
            Down(128, 128, True),
            Down(128, 256, True),
            Down(256, 512, True),
            Down(512, 512, False),
            Mid(),
            nn.GroupNorm(32, 512, eps=1e-6, affine=True),
            nn.SiLU(),
            nn.Conv2d(512, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, 8, kernel_size=1),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Conv2d(4, 4, kernel_size=1),
            nn.Conv2d(4, 512, kernel_size=3, stride=1, padding=1),
            Mid(),
            Up(512, 512, True),
            Up(512, 512, True),
            Up(512, 256, True),
            Up(256, 128, False),
            nn.GroupNorm(32, 128, eps=1e-6, affine=True),
            nn.SiLU(),
            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Encode ``x`` and return the pair ``(mean, std)``.

        The encoder output is split into two equal halves along the channel
        axis. The first half is returned as the mean. The second half is
        treated as a log-variance: it is clamped to [-30, 20], halved and
        exponentiated to give the standard deviation.
        """
        moments = self.encoder(x)

        mean, log_var = moments.chunk(2, dim=1)
        std = torch.exp(log_var.clamp(-30.0, 20.0) * 0.5)

        return mean, std

    def decode(self, h):
        """Run the decoder on the latent tensor ``h``."""
        return self.decoder(h)

    def forward(self, x):
        """Reconstruct ``x`` from the encoded mean; no sampling is done."""
        mean, _ = self.encode(x)
        return self.decode(mean)