# In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class CCT_Tokenizer(nn.Module):
    """Convolutional tokenizer for CCT, used instead of ViT-style patching.

    A small conv stack (conv -> ReLU -> max-pool -> conv -> ReLU) downsamples
    the image and lifts it to ``embed_dim`` channels. Every remaining spatial
    position then becomes one token, which keeps local spatial structure in
    the token features.

    Args:
        img_size: Accepted for API symmetry with the other modules; not used.
        embed_dim: Channel width of the emitted tokens.
        kernel_size: Kernel size of the first convolution.
        stride: Stride of the first convolution.
        padding: Padding of the first convolution.
    """

    def __init__(self, img_size=224, embed_dim=256, kernel_size=7, stride=2, padding=3):
        super().__init__()
        stem = nn.Conv2d(3, 64, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        downsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        project = nn.Conv2d(64, embed_dim, kernel_size=3, stride=1, padding=1, bias=False)
        # One Sequential named ``conv`` with layers at indices 0..4 keeps the
        # state_dict keys ("conv.0.weight", "conv.3.weight") stable.
        self.conv = nn.Sequential(stem, nn.ReLU(), downsample, project, nn.ReLU())
        self.embed_dim = embed_dim

    def forward(self, x):
        """Turn images into a token sequence.

        (B, 3, H, W) -> (B, H' * W', embed_dim), where H', W' are the spatial
        dims after the conv stack.
        """
        feature_map = self.conv(x)                        # (B, C, H', W')
        flat = torch.flatten(feature_map, start_dim=2)    # (B, C, H'*W')
        return flat.permute(0, 2, 1)                      # (B, H'*W', C)

class SequencePooling(nn.Module):
    """Sequence pooling (SeqPool) as used in CCT, replacing a CLS token.

    A linear layer scores every token. The scores are softmax-normalised over
    the sequence axis, and the tokens are averaged with those weights.

    Args:
        embed_dim: Feature size of each token.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.attention_pool = nn.Linear(embed_dim, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Pool (batch, seq_len, embed_dim) tokens down to (batch, embed_dim)."""
        weights = self.softmax(self.attention_pool(x))   # (B, N, 1), sums to 1 over N
        channels_first = x.transpose(1, 2)               # (B, C, N)
        pooled = torch.bmm(channels_first, weights)      # (B, C, 1)
        return pooled[..., 0]

class CompactConvolutionalTransformer(nn.Module):
    """Compact Convolutional Transformer (CCT) feature extractor.

    Pipeline: convolutional tokenizer -> learnable positional embedding ->
    ``nn.TransformerEncoder`` -> sequence pooling. Produces one
    ``embed_dim``-sized vector per input image.

    Args:
        img_size: Side length of the square inputs the positional embedding is
            sized for. Inputs that yield more tokens than this are rejected;
            smaller inputs use a prefix of the embedding.
        embed_dim: Token width / transformer ``d_model``.
        num_layers: Number of transformer encoder layers.
        num_heads: Attention heads per layer (must divide ``embed_dim``).
        dim_feedforward: Hidden width of each encoder layer's MLP. Default 2048,
            the ``nn.TransformerEncoderLayer`` default previously used implicitly.
        dropout: Dropout inside each encoder layer. Default 0.1, likewise the
            previously implicit ``nn.TransformerEncoderLayer`` default.
    """

    def __init__(self, img_size=224, embed_dim=128, num_layers=4, num_heads=4,
                 dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.tokenizer = CCT_Tokenizer(img_size=img_size, embed_dim=embed_dim)

        # Size the positional embedding from the tokenizer's actual output
        # length for an img_size x img_size image instead of a hard-coded
        # constant. For the default 224 this is 56 * 56 = 3136, so the
        # parameter shape (and checkpoint compatibility) is unchanged.
        with torch.no_grad():
            probe = torch.zeros(1, 3, img_size, img_size)
            seq_len = self.tokenizer(probe).shape[1]
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.seq_pool = SequencePooling(embed_dim)

    def forward(self, x):
        """Encode (B, 3, H, W) images into (B, embed_dim) feature vectors.

        Raises:
            RuntimeError: if the input produces more tokens than the positional
                embedding holds (i.e. the image is larger than ``img_size``).
        """
        tokens = self.tokenizer(x)
        n_tokens = tokens.shape[1]
        max_tokens = self.pos_embed.shape[1]
        if n_tokens > max_tokens:
            # Previously this surfaced as an opaque broadcasting error.
            raise RuntimeError(
                f"Input yields {n_tokens} tokens but the positional embedding "
                f"holds only {max_tokens}; build the model with a larger img_size."
            )
        # Smaller inputs use the leading slice of the embedding.
        tokens = tokens + self.pos_embed[:, :n_tokens, :]
        tokens = self.transformer(tokens)
        return self.seq_pool(tokens)

class HybridEarNet(nn.Module):
    """Hybrid two-branch classifier: MobileNetV2 features + CCT features.

    Both branches see the same image. The MobileNetV2 backbone (classifier
    removed, globally average-pooled) and the CCT branch each yield a feature
    vector. The vectors are concatenated and passed through
    FC(512) -> ReLU -> Dropout(0.2) -> FC(num_classes).

    Args:
        cct_embed_dim: Output width of the CCT branch.
        num_classes: Number of output logits (default 2, e.g. male / female).
        img_size: Input side length the CCT positional embedding is sized for.
        pretrained: Load ImageNet weights into MobileNetV2. Default True, as
            before. Set False to build the model without downloading weights.
    """

    def __init__(self, cct_embed_dim=128, num_classes=2, img_size=224, pretrained=True):
        super(HybridEarNet, self).__init__()

        # --- Branch 1: MobileNetV2 (last layer removed, features kept) ---
        # `pretrained=` is deprecated in torchvision >= 0.13. IMAGENET1K_V1 is
        # the weight set `pretrained=True` used to select. Fall back to the
        # legacy keyword on torchvision versions without the weights enums.
        weights_enum = getattr(models, "MobileNet_V2_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            base_mobilenet = models.mobilenet_v2(weights=weights)
        else:
            base_mobilenet = models.mobilenet_v2(pretrained=pretrained)
        self.mobilenet_features = base_mobilenet.features
        self.mobilenet_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Width of the backbone's final feature map (1280 for the default model).
        mobilenet_dim = base_mobilenet.last_channel

        # --- Branch 2: CCT (Vision Transformer) ---
        self.cct = CompactConvolutionalTransformer(img_size=img_size, embed_dim=cct_embed_dim)

        # --- Fusion head: concat -> FC 512 -> classes ---
        combined_dim = mobilenet_dim + cct_embed_dim
        self.classifier = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),  # regularisation choice; not specified by the design
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return (B, num_classes) logits for a (B, 3, H, W) image batch."""
        # 1. MobileNet path: (B, 1280, h, w) -> (B, 1280, 1, 1) -> (B, 1280)
        m_out = self.mobilenet_pool(self.mobilenet_features(x))
        m_out = torch.flatten(m_out, 1)

        # 2. CCT path: (B, cct_embed_dim)
        c_out = self.cct(x)

        # 3. Concatenate the two feature vectors, then 4. classify.
        combined = torch.cat((m_out, c_out), dim=1)
        return self.classifier(combined)

# Instantiate the model and report its actual size (rather than asserting it).
model = HybridEarNet(cct_embed_dim=128)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model created with {n_params:,} parameters.")