In [1]:
import torch
import torch.nn as nn
import timm

In [9]:
# Toy classification batch: NUM_SAMPLES rows of unnormalized scores over
# NUM_CLASSES classes, plus one int64 class index per row.
# A private, seeded generator makes the values reproducible on
# Restart & Run All without reseeding the global RNG used by other cells.
SEED = 0
NUM_SAMPLES, NUM_CLASSES = 3, 5
generator = torch.Generator().manual_seed(SEED)
# Named `logits` rather than `input` so the builtin `input` is not shadowed.
logits = torch.randn(
    NUM_SAMPLES, NUM_CLASSES, generator=generator, requires_grad=True
)
target = torch.randint(
    0, NUM_CLASSES, (NUM_SAMPLES,), generator=generator, dtype=torch.long
)
target

tensor([1, 0, 4])

In [10]:
# Clamp the class indices into [-1, 1]; the method form is equivalent to
# torch.clamp(target, min=-1, max=1).
target.clamp(min=-1, max=1)

tensor([1, 0, 1])

In [4]:
# Re-display `target` with a sanity check: it must hold int64 class indices
# in [0, 5), matching how the toy batch above is generated. Failing loudly
# here catches a stale `target` left over from out-of-order execution.
assert target.dtype == torch.long, f"expected int64 targets, got {target.dtype}"
assert bool(((target >= 0) & (target < 5)).all()), f"targets out of range: {target}"
target

tensor([0, 0, 0])

In [2]:
class SpatialAttentionModule(nn.Module):
    """Re-weight a feature map with a learned per-pixel gate.

    A 3x3 convolution collapses the ``in_channels`` input channels into a
    single-channel map. A sigmoid squashes that map into [0, 1], and the
    result is broadcast-multiplied back onto every channel of the input.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super(SpatialAttentionModule, self).__init__()
        # A 3x3 kernel with padding=1 (stride 1) keeps H and W unchanged,
        # so the gate lines up with the input pixel-for-pixel.
        self.conv = nn.Conv2d(
            in_channels=in_channels, out_channels=1, kernel_size=3, padding=1
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by its sigmoid attention map (same shape as ``x``)."""
        gate = self.sigmoid(self.conv(x))
        return x * gate

In [3]:
class ImageClassificationWithAttention(nn.Module):
    """Image classifier: timm backbone features -> spatial attention -> linear head.

    Args:
        backbone_name: Model name passed to ``timm.create_model``
            (e.g. ``"efficientnet_b0"``).
        num_classes: Number of logits produced by ``head``.
        pretrained: Whether timm should load pretrained backbone weights.
            Defaults to True, matching the previous hard-coded behaviour.

    NOTE(review): assumes a CNN-style backbone whose ``forward_features``
    returns a (B, C, H, W) map; transformer backbones that return token
    sequences are not supported by the Conv2d-based attention module.
    """

    def __init__(self, backbone_name, num_classes, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained)
        # timm models expose the channel count of their final feature map as
        # `num_features`. Unlike `classifier.in_features`, this does not
        # assume the head attribute is named `classifier`.
        in_features = self.backbone.num_features

        # Remove the backbone's own classification head. reset_classifier(0)
        # is timm's architecture-agnostic way to replace it with an identity.
        self.backbone.reset_classifier(0)

        self.attention = SpatialAttentionModule(in_features)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.head = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for an image batch ``x``."""
        # Use forward_features, not self.backbone(x): timm's forward() also
        # runs forward_head, which globally pools to (B, C) before the
        # classifier. The attention Conv2d needs the unpooled spatial map.
        features = self.backbone.forward_features(x)
        features = self.attention(features)
        pooled = self.pool(features).flatten(1)
        return self.head(pooled)

In [6]:
# Pretrained TensorFlow-ported EfficientNetV2 variants available in timm.
timm.list_models(filter="tf_efficientnetv2*", pretrained=True)

['tf_efficientnetv2_b0.in1k',
 'tf_efficientnetv2_b1.in1k',
 'tf_efficientnetv2_b2.in1k',
 'tf_efficientnetv2_b3.in1k',
 'tf_efficientnetv2_b3.in21k',
 'tf_efficientnetv2_b3.in21k_ft_in1k',
 'tf_efficientnetv2_l.in1k',
 'tf_efficientnetv2_l.in21k',
 'tf_efficientnetv2_l.in21k_ft_in1k',
 'tf_efficientnetv2_m.in1k',
 'tf_efficientnetv2_m.in21k',
 'tf_efficientnetv2_m.in21k_ft_in1k',
 'tf_efficientnetv2_s.in1k',
 'tf_efficientnetv2_s.in21k',
 'tf_efficientnetv2_s.in21k_ft_in1k',
 'tf_efficientnetv2_xl.in21k',
 'tf_efficientnetv2_xl.in21k_ft_in1k']

In [2]:
# Pretrained timm-native EfficientNet variants (candidate backbones).
timm.list_models(filter="efficientnet*", pretrained=True)

['efficientnet_b0.ra_in1k',
 'efficientnet_b1.ft_in1k',
 'efficientnet_b1_pruned.in1k',
 'efficientnet_b2.ra_in1k',
 'efficientnet_b2_pruned.in1k',
 'efficientnet_b3.ra2_in1k',
 'efficientnet_b3_pruned.in1k',
 'efficientnet_b4.ra2_in1k',
 'efficientnet_b5.sw_in12k',
 'efficientnet_b5.sw_in12k_ft_in1k',
 'efficientnet_el.ra_in1k',
 'efficientnet_el_pruned.in1k',
 'efficientnet_em.ra2_in1k',
 'efficientnet_es.ra_in1k',
 'efficientnet_es_pruned.in1k',
 'efficientnet_lite0.ra_in1k',
 'efficientnetv2_rw_m.agc_in1k',
 'efficientnetv2_rw_s.ra2_in1k',
 'efficientnetv2_rw_t.ra2_in1k']

In [None]:
# Plain pretrained EfficientNet-B0 from timm; with pretrained=True timm may
# download the weights on first use.
model = timm.create_model(model_name="efficientnet_b0", pretrained=True)