# In [1]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Embed overlapping square image patches as a sequence of tokens.

    A single ``nn.Conv2d`` with ``kernel_size=patch_size`` and
    ``stride=patch_size - overlap`` projects every patch to ``embed_dim``
    features. The resulting spatial grid is then flattened into a token
    sequence.

    Args:
        img_height: Not used by this module; kept for signature compatibility.
        img_width: Not used by this module; kept for signature compatibility.
        patch_size: Side length of each square patch (the conv kernel size).
        overlap: Number of pixels shared by adjacent patches.
        in_chans: Number of input image channels.
        embed_dim: Output feature size of each token.

    Raises:
        ValueError: If ``overlap >= patch_size``, which would give a
            non-positive stride.
    """

    def __init__(self, img_height=64, img_width=128, patch_size=8, overlap=4, in_chans=1, embed_dim=256):
        super().__init__()
        stride = patch_size - overlap
        if stride <= 0:
            # Fail at construction instead of deep inside the conv forward pass.
            raise ValueError(
                f"overlap ({overlap}) must be smaller than patch_size ({patch_size})"
            )
        self.patch_size = patch_size
        self.overlap = overlap
        self.stride = stride
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)

    def forward(self, x):
        """Return patch tokens of shape ``(B, num_patches, embed_dim)``.

        Args:
            x: Image batch. Assumed to be batched 4-D ``(B, in_chans, H, W)``;
                TODO confirm that callers never pass unbatched input.
        """
        x = self.proj(x)  # (B, embed_dim, H', W')
        # Flatten the grid row-major into a sequence: (B, H' * W', embed_dim).
        return x.flatten(2).transpose(1, 2)

class GrnnNet(nn.Module):
    """Transformer classifier over overlapping image patches.

    The image is split into overlapping patches by ``PatchEmbedding``. A
    learnable class token is prepended and learnable positional embeddings
    are added. The sequence then passes through a stack of
    ``nn.TransformerEncoderLayer`` blocks. A slice of the output tokens,
    chosen by ``mode``, is mean-pooled, layer-normalised and projected to
    ``num_classes`` logits.

    Args:
        in_chans: Number of input image channels.
        num_classes: Number of output logits.
        img_height: Input height the positional embedding is sized for.
        img_width: Input width the positional embedding is sized for.
        patch_size: Side length of each square patch.
        overlap: Number of pixels shared by adjacent patches. The patch stride
            is ``patch_size - overlap`` and must be positive.
        embed_dim: Token embedding dimension.
        depth: Number of transformer encoder layers.
        num_heads: Number of attention heads per layer.
        mlp_ratio: Feed-forward hidden size as a multiple of ``embed_dim``.
        mode: ``'vertical'`` averages the first half of the output token
            sequence and ``'horizontal'`` averages the second half. Any other
            value averages all tokens.

    Raises:
        ValueError: If ``overlap >= patch_size``, or if the configured image is
            smaller than a single patch.
    """

    def __init__(self, in_chans=1, num_classes=105, img_height=64, img_width=128, patch_size=8, overlap=4, embed_dim=256, depth=4, num_heads=4, mlp_ratio=2., mode='vertical'):
        super().__init__()
        stride = patch_size - overlap
        if stride <= 0:
            raise ValueError(
                f"overlap ({overlap}) must be smaller than patch_size ({patch_size})"
            )
        # Conv2d output length per axis: floor((L - kernel) / stride) + 1.
        # Each axis is computed separately, so the product is exact for any
        # image size. The previous one-line expression was evaluated
        # left-to-right as ((h // s) * w) // s, which is wrong unless w is a
        # multiple of s.
        grid_h = (img_height - patch_size) // stride + 1
        grid_w = (img_width - patch_size) // stride + 1
        if grid_h < 1 or grid_w < 1:
            raise ValueError(
                f"image size ({img_height}x{img_width}) is smaller than one patch ({patch_size})"
            )

        self.mode = mode
        self.patch_embed = PatchEmbedding(img_height, img_width, patch_size, overlap, in_chans, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position per patch, plus one for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, grid_h * grid_w + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=0.1)

        # Tokens are laid out as (batch, seq, embed), so batch_first=True is
        # required. With the default (False) the encoder treats dim 0 as the
        # sequence and attends across samples in the batch instead of across
        # tokens.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self.init_weights()

    def init_weights(self):
        """Initialise the learnable tokens and every Linear/LayerNorm submodule."""
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser used with ``nn.Module.apply``.

        ``nn.Linear`` gets weights ~ N(0, 0.02) and a zero bias. ``nn.LayerNorm``
        gets weight 1 and bias 0. All other modules, such as the Conv2d inside
        ``patch_embed``, and parameters not owned by a Linear/LayerNorm keep
        PyTorch's default initialisation.
        """
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch, assumed to be ``(B, in_chans, img_height, img_width)``.
                The patch count must equal the one ``pos_embed`` was sized for,
                otherwise the positional-embedding addition fails to broadcast.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)

        # Prepend the shared learnable class token to every sample.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, num_patches + 1, embed_dim)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = self.transformer(x)

        # Pool a slice of the output tokens. The split is by token index
        # (class token first, then patches in row-major grid order), not by
        # image geometry, so it may cut through the middle of a patch row.
        # NOTE(review): confirm which image region 'vertical'/'horizontal'
        # are meant to select.
        n_tokens = x.size(1)
        if self.mode == 'vertical':
            x = x[:, :n_tokens // 2].mean(dim=1)
        elif self.mode == 'horizontal':
            x = x[:, n_tokens // 2:].mean(dim=1)
        else:
            x = x.mean(dim=1)

        x = self.norm(x)
        return self.head(x)

if __name__ == '__main__':
    # Smoke test: run one random single-channel 64x128 image through a
    # default-configured model and report the shape of the logits.
    sample = torch.rand(1, 1, 64, 128)
    net = GrnnNet(in_chans=1, num_classes=105, mode='vertical')
    print(net(sample).shape)




# Output: torch.Size([1, 105])
