In [2]:
from urllib.request import urlopen
from PIL import Image
import timm
import torch

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model('convnext_small.in12k_ft_in1k', pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

In [3]:
top5_class_indices

tensor([[967, 868, 504, 415, 505]])

In [5]:
import torch
import torch.nn as nn
import timm

class MLJETConvNext(nn.Module):
    def __init__(self, backbone='convnext_small.in12k_ft_in1k', pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(backbone, pretrained=pretrained, in_chans=1, num_classes=0)
        
        backbone_features = self.backbone.num_features
        
        # Multi-head classifier outputs
        self.energy_loss_head = nn.Sequential(
            nn.Linear(backbone_features, 1),
            nn.Sigmoid()
        )
        self.alpha_head = nn.Sequential(
            nn.Linear(backbone_features, 3),
            nn.Softmax(dim=1)
        )
        self.q0_head = nn.Sequential(
            nn.Linear(backbone_features, 4),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        features = self.backbone(x)
        energy_loss_output = self.energy_loss_head(features)
        alpha_output = self.alpha_head(features)
        q0_output = self.q0_head(features)
        
        return {
            'energy_loss_output': energy_loss_output,
            'alpha_output': alpha_output,
            'q0_output': q0_output
        }

# Example usage
model = MLJETConvNext()

# Example input
x = torch.randn((1, 1, 32, 32))  # Updated input shape for single-channel input
outputs = model(x)

print(outputs)

{'energy_loss_output': tensor([[0.2407]], grad_fn=<SigmoidBackward0>), 'alpha_output': tensor([[0.3277, 0.2453, 0.4270]], grad_fn=<SoftmaxBackward0>), 'q0_output': tensor([[0.1059, 0.1067, 0.0496, 0.7378]], grad_fn=<SoftmaxBackward0>)}
