# In [1]:
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from functools import partial
from einops import rearrange
from torch import einsum

# --- Transformer blocks (same as before)
class PreNorm(nn.Module):
    """Layer-normalize the input, then hand it to a wrapped callable.

    Args:
        dim: Size of the last dimension, passed to ``nn.LayerNorm``.
        fn: Callable (typically an ``nn.Module``) applied to the normalized
            tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(LayerNorm(x), **kwargs)``.

        Extra keyword arguments are forwarded untouched to the wrapped callable.
        """
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        mlp_dim: Hidden feature size between the two linear layers.
        dropout: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, mlp_dim, dropout=0.1):
        super().__init__()
        # Layers are listed in execution order; their positions inside
        # ``self.net`` define the parameter names in the state dict.
        layers = [
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension stays ``dim``."""
        return self.net(x)

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        dim: Feature size of each input/output token.
        heads: Number of attention heads.
        dim_head: Feature size of each head; the internal width is
            ``heads * dim_head``.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads=6, dim_head=64, dropout=0.1):
        super().__init__()
        self.heads = heads
        # Standard 1/sqrt(d) scaling of the dot products.
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)

        inner_dim = dim_head * heads
        # A single bias-free projection produces queries, keys and values.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Attend every token to every other token.

        Args:
            x: Tensor laid out as (batch, tokens, dim).

        Returns:
            Tensor with the same layout as ``x``.
        """
        # Split the fused projection into three parts and expose the head axis.
        queries, keys, values = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        scores = einsum("b h i d, b h j d -> b h i j", queries, keys) * self.scale
        weights = self.attend(scores)

        context = einsum("b h i j, b h j d -> b h i d", weights, values)
        # Fold the heads back into the feature dimension.
        merged = rearrange(context, "b h n d -> b n (h d)")
        return self.to_out(merged)

# --- EfficientNetB0 
class EfficientNetB0(nn.Module):
    """Feature extractor built from torchvision's ImageNet-pretrained EfficientNet-B0.

    Only the ``features`` part of the torchvision model is kept; it is followed
    by global average pooling so each image becomes a flat vector.

    Attributes set here:
        features: ``features`` sub-module of the pretrained model.
        pool: Adaptive average pooling to a 1x1 spatial map.
        out_channels: Hard-coded to 1280, the vector width consumers of this
            branch rely on.
    """

    def __init__(self):
        super().__init__()
        # Instantiating with IMAGENET1K_V1 weights loads the pretrained checkpoint.
        backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
        self.features = backbone.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.out_channels = 1280

    def forward(self, x):
        """Return one flattened, globally pooled feature vector per sample."""
        pooled = self.pool(self.features(x))
        batch_size = pooled.size(0)
        return pooled.view(batch_size, -1)

# --- Final ConTrans with EfficientNetB0 + ViT + attributes
class ConTrans(nn.Module):
    """Hybrid classifier fusing a small ViT branch, an EfficientNet-B0 branch and attributes.

    The image goes through two parallel branches:

    * a convolutional stem followed by patch embedding and a stack of
      pre-norm transformer layers, mean-pooled over tokens (64 features);
    * the ``EfficientNetB0`` feature extractor (1280 features).

    Both vectors are concatenated with a per-sample attribute vector and fed
    to a two-layer classification head.

    Args:
        attr_dim: Width of the attribute vector appended before the classifier.
            ``0`` builds an image-only model.
        num_classes: Number of output logits.
    """

    def __init__(self, attr_dim=120, num_classes=3):
        super().__init__()

        # Hyperparameters
        n_feats = 64
        n_heads = 4
        n_layers = 4
        dim_head = n_feats // n_heads
        expansion_ratio = 4
        dropout_rate = 0.1
        image_size = 70
        patch_size = 5
        # The positional embedding is sized for 70x70 inputs cut into 5x5
        # patches, i.e. (70 // 5) ** 2 = 196 tokens.
        num_patches = (image_size // patch_size) ** 2

        # Remembered so forward() can validate / synthesize the attribute vector.
        self.attr_dim = attr_dim

        self.stem = nn.Sequential(
            nn.Conv2d(3, n_feats, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(n_feats),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
        )

        # Patch embedding: a strided convolution turns each patch into one token.
        self.patch_embedding = nn.Conv2d(n_feats, n_feats, kernel_size=patch_size, stride=patch_size)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, n_feats))
        self.embedding_dim = n_feats

        # Transformer: each layer is an (attention, feed-forward) pair.
        self.transformer = nn.ModuleList([
            nn.ModuleList([
                PreNorm(n_feats, SelfAttention(n_feats, n_heads, dim_head, dropout_rate)),
                PreNorm(n_feats, FeedForward(n_feats, n_feats * expansion_ratio, dropout_rate)),
            ]) for _ in range(n_layers)
        ])

        # EfficientNetB0 branch
        self.effnet_branch = EfficientNetB0()
        self.effnet_out_channels = self.effnet_branch.out_channels

        # Final classifier
        fusion_dim = self.embedding_dim + self.effnet_out_channels
        self.fc = nn.Sequential(
            nn.Linear(fusion_dim + attr_dim, 256),  # input = concatenation of all vectors,
                                                    # reduced to a smaller vector
            nn.ReLU(),                              # non-linear activation
            nn.Dropout(0.5),                        # regularization
            nn.Linear(256, num_classes),            # final output, one logit per class
        )

    def forward(self, x, attr=None):
        """Classify a batch of images, optionally conditioned on attributes.

        Args:
            x: Image batch laid out as (B, 3, H, W). The positional embedding
                covers at most 196 patches, which matches 70x70 inputs.
            attr: Optional tensor of shape (B, attr_dim). When omitted, an
                all-zero attribute vector is used so the classifier still
                receives an input of the width it was built for.

        Returns:
            Logits of shape (B, num_classes).

        Raises:
            ValueError: If ``attr`` is given but is not 2-D with ``attr_dim``
                columns.
        """
        # Stem
        x_stem = self.stem(x)  # [B, 64, H, W]

        # ViT branch
        patches = self.patch_embedding(x_stem)  # [B, C, H', W']
        patches = patches.flatten(2).transpose(1, 2)  # [B, N, C]
        # Out-of-place add: the transposed tensor is a view of the conv output,
        # so avoid mutating it in place.
        patches = patches + self.pos_embedding[:, :patches.size(1)]

        for attn, ff in self.transformer:
            patches = attn(patches) + patches
            patches = ff(patches) + patches

        vit_out = patches.mean(dim=1)  # [B, 64]

        # EfficientNet branch
        effnet_out = self.effnet_branch(x)  # [B, 1280]

        # Fusion
        fusion = torch.cat([vit_out, effnet_out], dim=1)  # [B, 1344]

        if attr is None:
            # Without this placeholder the classifier input would be attr_dim
            # columns short and the first Linear layer would fail.
            attr = fusion.new_zeros(fusion.size(0), self.attr_dim)
        elif attr.dim() != 2 or attr.size(1) != self.attr_dim:
            raise ValueError(
                f"attr must have shape (batch, {self.attr_dim}), "
                f"got {tuple(attr.shape)}"
            )
        else:
            # Keep the fused tensor in the image-feature dtype (e.g. float64
            # attributes would otherwise promote it and break the classifier).
            attr = attr.to(dtype=fusion.dtype)

        fusion = torch.cat([fusion, attr], dim=1)  # [B, 1344 + attr_dim]

        return self.fc(fusion)  # [B, num_classes]

def get_ConTrans_func_by_name(name):
    """Look up a model factory by its registered name.

    Args:
        name: Model identifier; only ``"ConTrans"`` is recognized.

    Returns:
        A callable accepting an optional ``deploy`` argument (ignored) that
        builds a ``ConTrans`` with its default constructor arguments.

    Raises:
        ValueError: If ``name`` is not a known model.
    """
    if name != "ConTrans":
        raise ValueError(f"Modelo '{name}' não reconhecido")

    def build(deploy=False):
        # ``deploy`` is accepted for call-site compatibility but not used.
        return ConTrans()

    return build

