In [1]:
import torch
from transformers import PreTrainedModel, PretrainedConfig


class Embed(torch.nn.Module):
    """Turn an image batch into a token sequence with position information.

    The image is cut into non-overlapping 14x14 patches by a strided
    convolution, a learned head token is prepended to each sequence, and a
    learned position table (one head entry plus a 37x37 grid) is resized to
    the actual patch grid and added.
    """

    def __init__(self):
        super().__init__()

        # Learned token prepended to every patch sequence: [1, 1, 1024]
        self.head = torch.nn.Parameter(torch.randn(1, 1, 1024))
        # Patch projection: kernel 14, stride 14, padding 0
        self.body = torch.nn.Conv2d(3, 1024, 14, 14, 0)
        # Position table: 1 head entry + 37 * 37 grid entries
        self.pos = torch.nn.Parameter(torch.randn(1, 1369 + 1, 1024))

    def get_pos(self, size=(512, 512)):
        """Return the position table resized for an image of ``size``.

        Args:
            size: ``(height, width)`` of the input image in pixels. The
                default corresponds to a 512x512 image, i.e. a 36x36 grid.

        Returns:
            Tensor of shape ``[1, 1 + grid_h * grid_w, 1024]`` where each
            grid side is the floor of ``37 * scale`` for the scale below.
        """
        #[1, 1, 1024]
        head = self.pos[:, :1]
        #[1, 1369, 1024] -> [1, 37, 37, 1024] -> [1, 1024, 37, 37]
        body = self.pos[:, 1:].reshape(1, 37, 37, 1024).permute(0, 3, 1, 2)

        # Per-axis scale from the stored 37x37 grid to the patch grid.
        # NOTE(review): the +0.1 presumably keeps floor(37 * scale) from
        # dropping one below the patch count due to float rounding.
        scale = tuple((i // 14 + 0.1) / 37.0 for i in size)

        # Interpolate in float32, then cast back to the parameter dtype.
        body = torch.nn.functional.interpolate(body.float(),
                                               scale_factor=scale,
                                               mode='bicubic',
                                               align_corners=False)
        body = body.to(dtype=self.pos.dtype)

        #[1, 1024, h, w] -> [1, h, w, 1024] -> [1, h * w, 1024]
        body = body.permute(0, 2, 3, 1).reshape(1, -1, 1024)

        return torch.cat((head, body), dim=1)

    def forward(self, x):
        """Embed images ``x`` of shape ``[b, 3, H, W]``.

        Returns:
            Tensor of shape ``[b, 1 + (H // 14) * (W // 14), 1024]``.
        """
        size = list(x.shape[2:])
        #[b, 3, H, W] -> [b, 1024, h, w] -> [b, h * w, 1024]
        x = self.body(x).flatten(2).transpose(1, 2)
        # Repeat the single head token across the batch so any batch size
        # can be concatenated (expand creates a view, not a copy).
        head = self.head.expand(x.shape[0], -1, -1)
        x = torch.cat((head, x), dim=1)
        # The position table has batch dim 1 and broadcasts over the batch.
        return x + self.get_pos(size)


class Atten(torch.nn.Module):
    """Multi-head self-attention: 16 heads of width 64 over 1024 channels."""

    def __init__(self):
        super().__init__()
        self.q = torch.nn.Linear(1024, 1024)
        self.k = torch.nn.Linear(1024, 1024)
        self.v = torch.nn.Linear(1024, 1024)
        self.out = torch.nn.Linear(1024, 1024)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[b, dim, 1024]``.

        Returns:
            Tensor of the same shape ``[b, dim, 1024]``.
        """
        # Read the batch size from the input instead of assuming 1, so
        # batched inputs reshape correctly.
        batch, dim = x.shape[:2]

        # Shape unchanged: [b, dim, 1024]
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        #[b, dim, 1024] -> [b, dim, 16, 64] -> [b, 16, dim, 64]
        q = q.reshape(batch, dim, 16, 64).permute(0, 2, 1, 3)
        k = k.reshape(batch, dim, 16, 64).permute(0, 2, 1, 3)
        v = v.reshape(batch, dim, 16, 64).permute(0, 2, 1, 3)

        #[b, 16, dim, 64] * [b, 16, 64, dim] -> [b, 16, dim, dim]
        atten = q.matmul(k.transpose(2, 3))

        # Scale by sqrt(head width) before the softmax over keys.
        atten = atten / 64**0.5

        atten = atten.softmax(dim=-1)

        #[b, 16, dim, dim] * [b, 16, dim, 64] -> [b, 16, dim, 64]
        atten = atten.matmul(v)

        #[b, 16, dim, 64] -> [b, dim, 16, 64]
        atten = atten.permute(0, 2, 1, 3)

        #[b, dim, 16, 64] -> [b, dim, 1024]
        atten = atten.reshape(batch, dim, 1024)

        return self.out(atten)


class Layer(torch.nn.Module):
    """Residual block: normalized attention, then a normalized MLP.

    Each branch output is multiplied element-wise by a learned
    1024-element vector before being added back to its input.
    """

    def __init__(self):
        super().__init__()

        self.atten = Atten()

        # 1024 -> 4096 -> 1024 feed-forward stack with GELU in between.
        stages = [
            torch.nn.Linear(1024, 4096),
            torch.nn.GELU(),
            torch.nn.Linear(4096, 1024),
        ]
        self.mlp = torch.nn.Sequential(*stages)

        self.norm1 = torch.nn.LayerNorm(1024, eps=1e-6)
        self.norm2 = torch.nn.LayerNorm(1024, eps=1e-6)

        # Per-channel multipliers for the two branches, initialized to 1.
        self.scala1 = torch.nn.Parameter(torch.ones(1024))
        self.scala2 = torch.nn.Parameter(torch.ones(1024))

    def forward(self, x):
        """Run the attention branch and then the MLP branch on ``x``."""
        # Attention branch with skip connection.
        normed = self.norm1(x)
        attended = self.atten(normed) * self.scala1
        hidden = attended + x

        # MLP branch with skip connection.
        projected = self.mlp(self.norm2(hidden)) * self.scala2
        return projected + hidden


class Encoder(PreTrainedModel):
    """Embedding, 24 stacked ``Layer`` blocks, and a final LayerNorm."""

    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed = Embed()
        self.layers = torch.nn.ModuleList(Layer() for _ in range(24))
        self.norm = torch.nn.LayerNorm(1024, eps=1e-6)

    def forward(self, x):
        """Embed ``x``, pass it through every layer in order, normalize."""
        hidden = self.embed(x)

        for layer in self.layers:
            hidden = layer(hidden)

        normalized = self.norm(hidden)
        return normalized