In [None]:
class AttentionMILModel(nn.Module):
    """Gated-attention multiple-instance-learning (MIL) bag classifier.

    Each bag of N images is encoded instance-by-instance by a CNN backbone,
    globally average-pooled, and projected to a ``feature_dim`` (``L``)
    embedding. Gated attention (``tanh(V h) * sigmoid(U h)`` scored by a
    linear layer with ``attention_heads`` (``K``) outputs) is softmax-normalised
    over the instances of each bag and used to pool the embeddings. The ``K``
    pooled vectors are concatenated and fed to a linear classifier.

    Args:
        num_classes: Number of output classes.
        backbone_name: One of ``"resnet18"``, ``"resnet50"``, ``"resnet152"``,
            ``"efficientnet_b0"`` or ``"vgg16"``.
        pretrained: When True the backbone is built with ``weights='DEFAULT'``;
            otherwise with ``weights=None``.
        feature_dim: Size ``L`` of the per-instance embedding.
        attention_dim: Hidden size ``D`` of the gated attention network.
        attention_heads: Number ``K`` of attention heads.
        dropout: Dropout probability applied after the embedding layer.

    Raises:
        ValueError: If ``backbone_name`` is not supported.
    """

    def __init__(self, num_classes, backbone_name='resnet18', pretrained=True,
                 feature_dim=512, attention_dim=128, attention_heads=1, dropout=0.25):
        super().__init__()
        self.L = feature_dim
        self.D = attention_dim
        self.K = attention_heads
        self.num_classes = num_classes

        # Convolutional trunk (classification head removed) and its output channel count.
        self.feature_extractor, backbone_out_dim = self._get_backbone(backbone_name, pretrained)

        # Global average pool -> (M, backbone_out_dim, 1, 1), then project to L dims.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.feature_embedder = nn.Sequential(
            nn.Linear(backbone_out_dim, self.L),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Gated attention: the tanh branch (V) is modulated element-wise by a sigmoid gate (U).
        self.attention_V = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
        )
        self.attention_U = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Sigmoid(),
        )
        self.attention_weights = nn.Linear(self.D, self.K)

        # Classifier over the K pooled bag vectors, concatenated.
        self.classifier = nn.Linear(self.L * self.K, self.num_classes)

    def _get_backbone(self, name, pretrained):
        """Build the convolutional trunk for ``name``.

        Returns:
            ``(module, out_channels)`` where ``module`` maps ``(M, 3, H, W)``
            images to a feature map with ``out_channels`` channels.

        Raises:
            ValueError: If ``name`` is not a supported backbone.
        """
        weights = 'DEFAULT' if pretrained else None

        if name == "resnet18":
            model = resnet18(weights=weights)
            # Drop the last two children (global pool + fc in the ResNet layout).
            return nn.Sequential(*list(model.children())[:-2]), 512
        elif name == "resnet50":
            model = resnet50(weights=weights)
            return nn.Sequential(*list(model.children())[:-2]), 2048
        elif name == "resnet152":
            model = resnet152(weights=weights)
            return nn.Sequential(*list(model.children())[:-2]), 2048
        elif name == "efficientnet_b0":
            model = efficientnet_b0(weights=weights)
            return model.features, 1280
        elif name == "vgg16":
            model = vgg16(weights=weights)
            # Keep the conv stack but drop its final layer.
            return nn.Sequential(*list(model.features.children())[:-1]), 512
        else:
            raise ValueError(f"Unsupported backbone: {name}")

    def _embed_instances(self, images):
        """Map a flat stack of images ``(M, C, H, W)`` to embeddings ``(M, L)``."""
        features = self.pool(self.feature_extractor(images))
        return self.feature_embedder(features.flatten(1))

    def _attention_scores(self, embedded):
        """Return raw (pre-softmax) gated-attention scores ``(M, K)`` for embeddings ``(M, L)``."""
        return self.attention_weights(self.attention_V(embedded) * self.attention_U(embedded))

    def forward(self, x):
        """Classify a batch of bags.

        Args:
            x: Tensor of shape ``(B, N, C, H, W)``: ``B`` bags of ``N`` images.

        Returns:
            Tuple ``(logits, probs, attention)``. ``logits`` and ``probs`` have
            shape ``(B, num_classes)``, and ``probs`` is the softmax of
            ``logits``. ``attention`` has shape ``(B, N, K)`` and holds the
            instance weights, which sum to 1 over ``N`` for each bag and head.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.ndim != 5:
            raise ValueError(f"Expected input of shape (B, N, C, H, W), got {tuple(x.shape)}")
        B, N, C, H, W = x.shape

        # reshape (not view): view raises on non-contiguous inputs such as permuted tensors.
        embedded = self._embed_instances(x.reshape(B * N, C, H, W))  # (B*N, L)

        A = self._attention_scores(embedded).reshape(B, N, self.K)
        A = F.softmax(A, dim=1)  # normalise over the N instances of each bag

        A = A.transpose(1, 2)  # (B, K, N)
        M = torch.bmm(A, embedded.reshape(B, N, self.L))  # (B, K, L)

        # Concatenate the K pooled vectors -> (B, K*L) and classify.
        logits = self.classifier(M.reshape(B, -1))
        probs = F.softmax(logits, dim=1)

        return logits, probs, A.transpose(1, 2)

    def calculate_objective(self, x_batch, y_batch, criterion):
        """Return ``criterion(logits, y_batch)`` for a batch of bags ``x_batch``."""
        # self(...) rather than self.forward(...) so registered module hooks run.
        logits, _, _ = self(x_batch)
        return criterion(logits, y_batch)

    def calculate_prediction_accuracy(self, x_batch, y_batch):
        """Predict classes for ``x_batch`` and compare them with ``y_batch``.

        Returns:
            ``(preds, accuracy, correct, total)``. ``preds`` is the argmax class
            per bag, and ``accuracy`` equals ``correct / total``.
        """
        logits, _, _ = self(x_batch)
        preds = torch.argmax(logits, dim=1)
        correct = preds.eq(y_batch).sum().item()
        total = y_batch.size(0)
        return preds, correct / total, correct, total

    def get_attention_map(self, bag_tensor, head=0):
        """Compute per-instance attention weights for a single bag.

        Runs in eval mode without gradients. Each submodule's previous
        train/eval mode is restored afterwards, including when an error occurs.

        Args:
            bag_tensor: Tensor of shape ``(N, C, H, W)`` holding one bag. It is
                moved to the device of the model's parameters.
            head: Index of the attention head to return. The default is 0,
                matching the original behaviour. Pass ``None`` to return every
                head.

        Returns:
            NumPy array of shape ``(N,)``, or ``(N, K)`` when ``head`` is None.
            The weights sum to 1 over the N instances.

        Raises:
            ValueError: If ``bag_tensor`` is not 4-dimensional.
        """
        # Validate before touching the module's mode.
        if bag_tensor.ndim != 4:
            raise ValueError(f"Expected input of shape (bag_size, 3, H, W), got {bag_tensor.shape}")

        # Record every submodule's mode: a bare self.eval() would leave the model
        # stuck in eval mode (dropout/BatchNorm disabled) for later training steps.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                device = next(self.parameters()).device
                embedded = self._embed_instances(bag_tensor.to(device))  # (N, L)
                weights = F.softmax(self._attention_scores(embedded), dim=0)  # (N, K)
        finally:
            for module, mode in saved_modes:
                module.training = mode

        if head is not None:
            weights = weights[:, head]
        return weights.cpu().numpy()