In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    The table is computed once in ``__init__`` and stored as the
    non-trainable buffer ``pe`` with shape ``(max_len, 1, d_model)``, so it
    follows the module across devices and is saved in ``state_dict``.

    Args:
        d_model: Size of the last (feature) dimension. Both even and odd
            values are supported.
        max_len: Number of positions to precompute; inputs with more than
            ``max_len`` steps along dim 0 are not supported.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # One angle per (position, frequency): (max_len, ceil(d_model / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd-indexed column than
        # even-indexed ones, so drop the surplus frequency for the cosines.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Singleton batch axis so the table broadcasts over (seq, batch, dim).
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Args:
            x: Tensor whose dim 0 is the sequence axis and whose last dim is
                ``d_model``, e.g. ``(seq_len, batch, d_model)``.

        Returns:
            A new tensor with the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0), :]
        return x

class CustomTransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder block.

    Two residual sub-layers are applied in sequence, each followed by a
    LayerNorm: multi-head self-attention, then a two-layer ReLU feed-forward
    network.

    Args:
        d_model: Feature size of the input and output.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside attention, inside the
            feed-forward network and on both residual branches.
    """

    def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.3):
        super().__init__()

        # Self-attention sub-layer.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward sub-layer: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm and one residual dropout per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = F.relu

    def forward(self, src):
        """Run the block on ``src`` and return a tensor of the same shape.

        ``src`` is passed as query, key and value to ``nn.MultiheadAttention``
        (constructed with its default sequence-first layout); no attention or
        padding mask is applied.
        """
        # Residual self-attention, then normalise.
        attended = self.self_attn(src, src, src)[0]
        src = self.norm1(src + self.dropout1(attended))

        # Residual feed-forward, then normalise.
        hidden = self.activation(self.linear1(src))
        projected = self.linear2(self.dropout(hidden))
        return self.norm2(src + self.dropout2(projected))

class CustomEncoderLayer(nn.Module):
    """VGG feature extractor followed by a stack of Transformer encoder blocks.

    The globally pooled VGG descriptor is linearly projected to ``d_model``,
    treated as a length-1 sequence, given a positional encoding and passed
    through ``num_layers`` ``CustomTransformerEncoderLayer`` blocks.

    Args:
        input_channel: Number of channels of the input image.
        d_model: Feature size used by the Transformer blocks.
        nhead: Number of attention heads per block.
        num_layers: Number of Transformer blocks.
    """

    def __init__(self, input_channel, d_model=256, nhead=4, num_layers=4):
        super().__init__()
        self.vgg = VGGnet(input_channel)

        # Linear projection of the 512-dim pooled VGG descriptor to d_model.
        self.collapse = nn.Linear(512, d_model)

        self.pos_encoder = PositionalEncoding(d_model)

        # Stored as nn.Sequential; forward() walks it block by block.
        self.transformer_encoder = nn.Sequential(
            *self._transformer_encoder_layers(d_model, nhead, num_layers)
        )

    def _transformer_encoder_layers(self, d_model, nhead, num_layers):
        """Build ``num_layers`` independent encoder blocks as a plain list."""
        return [
            CustomTransformerEncoderLayer(d_model=d_model, nhead=nhead)
            for _ in range(num_layers)
        ]

    def forward(self, x):
        """Encode an image batch.

        Returns:
            A tuple ``(encoded, xfeat)``: ``encoded`` is the Transformer
            output in batch-first layout ``(B, 1, d_model)`` and ``xfeat`` is
            the intermediate feature map returned by ``self.vgg``.
        """
        pooled, xfeat = self.vgg(x)

        # Project, then insert a length-1 sequence axis: (B, 1, d_model).
        tokens = self.collapse(pooled).unsqueeze(1)

        # The encoder blocks work sequence-first: (B, L, C) -> (L, B, C).
        tokens = tokens.transpose(0, 1)
        tokens = self.pos_encoder(tokens)

        for block in self.transformer_encoder:
            tokens = block(tokens)

        # Back to batch-first: (L, B, C) -> (B, L, C).
        return tokens.transpose(0, 1), xfeat

class VGGnet(nn.Module):
    """Small VGG-style CNN with four conv stages of 64, 128, 256 and 512 channels.

    Each stage is two 3x3 Conv-BatchNorm-ReLU layers followed by a 2x2
    max-pool, so the spatial size halves per stage.

    Args:
        input_channel: Number of channels of the input image.
    """

    def __init__(self, input_channel):
        super().__init__()
        in_ch = input_channel
        # Registers conv1/maxp1 ... conv4/maxp4 in that order.
        for stage, out_ch in enumerate((64, 128, 256, 512), start=1):
            setattr(self, f'conv{stage}', self._conv(in_ch, out_ch))
            setattr(self, f'maxp{stage}', nn.MaxPool2d(2, stride=2))
            in_ch = out_ch

        self.avg = nn.AdaptiveAvgPool2d(1)

    def _conv(self, in_channels, out_channels, nlayers=2):
        """Return ``nlayers`` stacked Conv3x3(no bias)-BatchNorm-ReLU units."""
        stack = []
        channels = in_channels
        for _ in range(nlayers):
            stack += [
                nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            # Only the first unit changes the channel count.
            channels = out_channels
        return nn.Sequential(*stack)

    def forward(self, x):
        """Return ``(pooled, xfeat)`` for an image batch ``x``.

        ``xfeat`` is the 256-channel map after the third stage (1/8 of the
        input resolution); ``pooled`` is the fourth-stage output globally
        average-pooled and flattened to ``(B, 512)``.
        """
        x = self.maxp1(self.conv1(x))
        x = self.maxp2(self.conv2(x))
        xfeat = self.maxp3(self.conv3(x))
        deep = self.maxp4(self.conv4(xfeat))
        return torch.flatten(self.avg(deep), 1), xfeat

class GrnnNet(nn.Module):
    """Classifier that fuses a global descriptor with local strip descriptors.

    The encoder yields a global feature vector (``d_model`` wide) and an
    intermediate convolutional feature map. The map is cut into horizontal
    strips (one row each) and vertical strips (two columns each); every strip
    is average-pooled, projected by ``self.ada`` and added to the global
    vector. The sum over all strips is fed to a linear classifier.

    Args:
        input_channel: Number of channels of the input image.
        num_classes: Number of output classes.
        d_model: Feature size of the global descriptor.
        nhead: Number of attention heads in the encoder.
        num_layers: Number of Transformer blocks in the encoder.
    """

    # Channel count of the feature map returned by the encoder (the output of
    # VGGnet's third conv stage). It is independent of d_model, so the
    # projection below must take this width as input, not d_model.
    _FEAT_CHANNELS = 256

    def __init__(self, input_channel, num_classes=105, d_model=256, nhead=4, num_layers=4):
        super().__init__()
        self.encoder_layer = CustomEncoderLayer(input_channel, d_model, nhead, num_layers)
        self.avg = nn.AdaptiveAvgPool2d(1)
        # Maps a pooled strip (256 channels) into the d_model space of the
        # global descriptor so the two can be summed for any d_model.
        self.ada = nn.Linear(self._FEAT_CHANNELS, d_model)
        self.classifier = nn.Linear(d_model, num_classes)

    def _local_descriptor(self, patch):
        """Average-pool a ``(B, C, h, w)`` patch and project it to ``d_model``."""
        return self.ada(torch.flatten(self.avg(patch), 1))

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for image batch ``x``."""
        glf, feat = self.encoder_layer(x)
        glf = glf.squeeze(1)  # Remove the length-1 sequence dimension

        # Horizontal mode: one strip per row of the feature map.
        glfa_horizontal = torch.zeros_like(glf)
        for n in range(feat.size(-2)):
            patch = feat[:, :, n, :].unsqueeze(2)
            glfa_horizontal += glf + self._local_descriptor(patch)

        # Vertical mode: non-overlapping strips two columns wide; a trailing
        # odd column is not covered.
        glfa_vertical = torch.zeros_like(glf)
        for n in range(feat.size(-1) // 2):
            s = 2 * n
            patch = feat[:, :, :, s:s + 2]
            glfa_vertical += glf + self._local_descriptor(patch)

        # Combine results from both modes
        glfa = glfa_horizontal + glfa_vertical

        logits = self.classifier(glfa)

        return logits

if __name__ == '__main__':
    # Smoke test: push one random single-channel 64x128 image through the
    # network and print the shape of the resulting logits.
    x = torch.rand((1, 1, 64, 128))
    mod = GrnnNet(input_channel=1, num_classes=105)
    logits = mod(x)
    print(logits.size())


torch.Size([1, 105])
