In [5]:
%load_ext dotenv
%dotenv .env.hf-token

The dotenv extension is already loaded. To reload it, use:
  %reload_ext dotenv


In [7]:
import torch

state_dict_legacy = torch.load("/home/antoine/models/nereus/dino_resnet50/resnet50.pth", map_location="cpu")

In [8]:
import torch
import torchvision

with torch.no_grad():
    resnet50 = torchvision.models.resnet50(weights=None)
    resnet50.fc = torch.nn.Identity()
    
    conv1 = resnet50.conv1
    new_conv1 = torch.nn.Conv2d(
                in_channels=1,
                out_channels=conv1.out_channels,
                kernel_size=conv1.kernel_size,
                stride=conv1.stride,
                padding=conv1.padding,
                bias=(conv1.bias is not None)
            )
    
    new_conv1.weight = torch.nn.Parameter(conv1.weight.sum(dim=1, keepdim=True))

    resnet50.conv1 = new_conv1

resnet50.eval()

state_dict = {}

trad = {
    "0": "conv1",
    "1": "bn1",
    "4": "layer1",
    "5": "layer2",
    "6": "layer3",
    "7": "layer4",
}

for k, v in state_dict_legacy.items():
    if k.startswith("student_backbone.backbone."):
        k_new = k.removeprefix("student_backbone.backbone.")
        for kk, vv in trad.items():
            if k_new.startswith(kk):
                k_new = k_new.replace(kk, vv, 1)
                break
        state_dict[k_new] = v

resnet50.load_state_dict(state_dict)

<All keys matched successfully>

In [9]:
from torchinfo import summary

summary(resnet50, (1, 1, 256, 256), device="cpu", depth=10, row_settings=["var_names", "depth"], mode="eval")

Layer (type (var_name):depth-idx)                  Output Shape              Param #
ResNet (ResNet)                                    [1, 2048]                 --
├─Conv2d (conv1): 1-1                              [1, 64, 128, 128]         3,136
├─BatchNorm2d (bn1): 1-2                           [1, 64, 128, 128]         128
├─ReLU (relu): 1-3                                 [1, 64, 128, 128]         --
├─MaxPool2d (maxpool): 1-4                         [1, 64, 64, 64]           --
├─Sequential (layer1): 1-5                         [1, 256, 64, 64]          --
│    └─Bottleneck (0): 2-1                         [1, 256, 64, 64]          --
│    │    └─Conv2d (conv1): 3-1                    [1, 64, 64, 64]           4,096
│    │    └─BatchNorm2d (bn1): 3-2                 [1, 64, 64, 64]           128
│    │    └─ReLU (relu): 3-3                       [1, 64, 64, 64]           --
│    │    └─Conv2d (conv2): 3-4                    [1, 64, 64, 64]           36,864
│    │    └─BatchNorm2d

In [10]:
import torch

torch.save(resnet50.state_dict(), "resnet50.pth")

In [36]:
from transformers import ResNetModel, ResNetConfig

config_resnet = ResNetConfig(1)
resnet50_hf = ResNetModel(config_resnet)

summary(resnet50_hf, (1, 1, 256, 256), device="cpu", row_settings=["var_names", "depth"], depth=10)

Layer (type (var_name):depth-idx)                                           Output Shape              Param #
ResNetModel (ResNetModel)                                                   [1, 2048, 1, 1]           --
├─ResNetEmbeddings (embedder): 1-1                                          [1, 64, 64, 64]           --
│    └─ResNetConvLayer (embedder): 2-1                                      [1, 64, 128, 128]         --
│    │    └─Conv2d (convolution): 3-1                                       [1, 64, 128, 128]         3,136
│    │    └─BatchNorm2d (normalization): 3-2                                [1, 64, 128, 128]         128
│    │    └─ReLU (activation): 3-3                                          [1, 64, 128, 128]         --
│    └─MaxPool2d (pooler): 2-2                                              [1, 64, 64, 64]           --
├─ResNetEncoder (encoder): 1-2                                              [1, 2048, 8, 8]           --
│    └─ModuleList (stages): 2-3               

In [34]:
layer_name_mapping = {
    # Initial convolution and batch norm layers
    "conv1.weight": "embedder.embedder.convolution.weight",
    "bn1.weight": "embedder.embedder.normalization.weight",
    "bn1.bias": "embedder.embedder.normalization.bias",
    "bn1.running_mean": "embedder.embedder.normalization.running_mean",
    "bn1.running_var": "embedder.embedder.normalization.running_var",
    
    # Layer 1
    "layer1.0.conv1.weight": "encoder.stages.0.layers.0.layer.0.convolution.weight",
    "layer1.0.bn1.weight": "encoder.stages.0.layers.0.layer.0.normalization.weight",
    "layer1.0.bn1.bias": "encoder.stages.0.layers.0.layer.0.normalization.bias",
    "layer1.0.bn1.running_mean": "encoder.stages.0.layers.0.layer.0.normalization.running_mean",
    "layer1.0.bn1.running_var": "encoder.stages.0.layers.0.layer.0.normalization.running_var",
    "layer1.0.conv2.weight": "encoder.stages.0.layers.0.layer.1.convolution.weight",
    "layer1.0.bn2.weight": "encoder.stages.0.layers.0.layer.1.normalization.weight",
    "layer1.0.bn2.bias": "encoder.stages.0.layers.0.layer.1.normalization.bias",
    "layer1.0.bn2.running_mean": "encoder.stages.0.layers.0.layer.1.normalization.running_mean",
    "layer1.0.bn2.running_var": "encoder.stages.0.layers.0.layer.1.normalization.running_var",
    "layer1.0.conv3.weight": "encoder.stages.0.layers.0.layer.2.convolution.weight",
    "layer1.0.bn3.weight": "encoder.stages.0.layers.0.layer.2.normalization.weight",
    "layer1.0.bn3.bias": "encoder.stages.0.layers.0.layer.2.normalization.bias",
    "layer1.0.bn3.running_mean": "encoder.stages.0.layers.0.layer.2.normalization.running_mean",
    "layer1.0.bn3.running_var": "encoder.stages.0.layers.0.layer.2.normalization.running_var",
    "layer1.0.downsample.0.weight": "encoder.stages.0.layers.0.shortcut.convolution.weight",
    "layer1.0.downsample.1.weight": "encoder.stages.0.layers.0.shortcut.normalization.weight",
    "layer1.0.downsample.1.bias": "encoder.stages.0.layers.0.shortcut.normalization.bias",
    "layer1.0.downsample.1.running_mean": "encoder.stages.0.layers.0.shortcut.normalization.running_mean",
    "layer1.0.downsample.1.running_var": "encoder.stages.0.layers.0.shortcut.normalization.running_var",
    
    "layer1.1.conv1.weight": "encoder.stages.0.layers.1.layer.0.convolution.weight",
    "layer1.1.bn1.weight": "encoder.stages.0.layers.1.layer.0.normalization.weight",
    "layer1.1.bn1.bias": "encoder.stages.0.layers.1.layer.0.normalization.bias",
    "layer1.1.bn1.running_mean": "encoder.stages.0.layers.1.layer.0.normalization.running_mean",
    "layer1.1.bn1.running_var": "encoder.stages.0.layers.1.layer.0.normalization.running_var",
    "layer1.1.conv2.weight": "encoder.stages.0.layers.1.layer.1.convolution.weight",
    "layer1.1.bn2.weight": "encoder.stages.0.layers.1.layer.1.normalization.weight",
    "layer1.1.bn2.bias": "encoder.stages.0.layers.1.layer.1.normalization.bias",
    "layer1.1.bn2.running_mean": "encoder.stages.0.layers.1.layer.1.normalization.running_mean",
    "layer1.1.bn2.running_var": "encoder.stages.0.layers.1.layer.1.normalization.running_var",
    "layer1.1.conv3.weight": "encoder.stages.0.layers.1.layer.2.convolution.weight",
    "layer1.1.bn3.weight": "encoder.stages.0.layers.1.layer.2.normalization.weight",
    "layer1.1.bn3.bias": "encoder.stages.0.layers.1.layer.2.normalization.bias",
    "layer1.1.bn3.running_mean": "encoder.stages.0.layers.1.layer.2.normalization.running_mean",
    "layer1.1.bn3.running_var": "encoder.stages.0.layers.1.layer.2.normalization.running_var",
    
    "layer1.2.conv1.weight": "encoder.stages.0.layers.2.layer.0.convolution.weight",
    "layer1.2.bn1.weight": "encoder.stages.0.layers.2.layer.0.normalization.weight",
    "layer1.2.bn1.bias": "encoder.stages.0.layers.2.layer.0.normalization.bias",
    "layer1.2.bn1.running_mean": "encoder.stages.0.layers.2.layer.0.normalization.running_mean",
    "layer1.2.bn1.running_var": "encoder.stages.0.layers.2.layer.0.normalization.running_var",
    "layer1.2.conv2.weight": "encoder.stages.0.layers.2.layer.1.convolution.weight",
    "layer1.2.bn2.weight": "encoder.stages.0.layers.2.layer.1.normalization.weight",
    "layer1.2.bn2.bias": "encoder.stages.0.layers.2.layer.1.normalization.bias",
    "layer1.2.bn2.running_mean": "encoder.stages.0.layers.2.layer.1.normalization.running_mean",
    "layer1.2.bn2.running_var": "encoder.stages.0.layers.2.layer.1.normalization.running_var",
    "layer1.2.conv3.weight": "encoder.stages.0.layers.2.layer.2.convolution.weight",
    "layer1.2.bn3.weight": "encoder.stages.0.layers.2.layer.2.normalization.weight",
    "layer1.2.bn3.bias": "encoder.stages.0.layers.2.layer.2.normalization.bias",
    "layer1.2.bn3.running_mean": "encoder.stages.0.layers.2.layer.2.normalization.running_mean",
    "layer1.2.bn3.running_var": "encoder.stages.0.layers.2.layer.2.normalization.running_var",
    
    # Layer 2
    "layer2.0.conv1.weight": "encoder.stages.1.layers.0.layer.0.convolution.weight",
    "layer2.0.bn1.weight": "encoder.stages.1.layers.0.layer.0.normalization.weight",
    "layer2.0.bn1.bias": "encoder.stages.1.layers.0.layer.0.normalization.bias",
    "layer2.0.bn1.running_mean": "encoder.stages.1.layers.0.layer.0.normalization.running_mean",
    "layer2.0.bn1.running_var": "encoder.stages.1.layers.0.layer.0.normalization.running_var",
    "layer2.0.conv2.weight": "encoder.stages.1.layers.0.layer.1.convolution.weight",
    "layer2.0.bn2.weight": "encoder.stages.1.layers.0.layer.1.normalization.weight",
    "layer2.0.bn2.bias": "encoder.stages.1.layers.0.layer.1.normalization.bias",
    "layer2.0.bn2.running_mean": "encoder.stages.1.layers.0.layer.1.normalization.running_mean",
    "layer2.0.bn2.running_var": "encoder.stages.1.layers.0.layer.1.normalization.running_var",
    "layer2.0.conv3.weight": "encoder.stages.1.layers.0.layer.2.convolution.weight",
    "layer2.0.bn3.weight": "encoder.stages.1.layers.0.layer.2.normalization.weight",
    "layer2.0.bn3.bias": "encoder.stages.1.layers.0.layer.2.normalization.bias",
    "layer2.0.bn3.running_mean": "encoder.stages.1.layers.0.layer.2.normalization.running_mean",
    "layer2.0.bn3.running_var": "encoder.stages.1.layers.0.layer.2.normalization.running_var",
    "layer2.0.downsample.0.weight": "encoder.stages.1.layers.0.shortcut.convolution.weight",
    "layer2.0.downsample.1.weight": "encoder.stages.1.layers.0.shortcut.normalization.weight",
    "layer2.0.downsample.1.bias": "encoder.stages.1.layers.0.shortcut.normalization.bias",
    "layer2.0.downsample.1.running_mean": "encoder.stages.1.layers.0.shortcut.normalization.running_mean",
    "layer2.0.downsample.1.running_var": "encoder.stages.1.layers.0.shortcut.normalization.running_var",
    
    "layer2.1.conv1.weight": "encoder.stages.1.layers.1.layer.0.convolution.weight",
    "layer2.1.bn1.weight": "encoder.stages.1.layers.1.layer.0.normalization.weight",
    "layer2.1.bn1.bias": "encoder.stages.1.layers.1.layer.0.normalization.bias",
    "layer2.1.bn1.running_mean": "encoder.stages.1.layers.1.layer.0.normalization.running_mean",
    "layer2.1.bn1.running_var": "encoder.stages.1.layers.1.layer.0.normalization.running_var",
    "layer2.1.conv2.weight": "encoder.stages.1.layers.1.layer.1.convolution.weight",
    "layer2.1.bn2.weight": "encoder.stages.1.layers.1.layer.1.normalization.weight",
    "layer2.1.bn2.bias": "encoder.stages.1.layers.1.layer.1.normalization.bias",
    "layer2.1.bn2.running_mean": "encoder.stages.1.layers.1.layer.1.normalization.running_mean",
    "layer2.1.bn2.running_var": "encoder.stages.1.layers.1.layer.1.normalization.running_var",
    "layer2.1.conv3.weight": "encoder.stages.1.layers.1.layer.2.convolution.weight",
    "layer2.1.bn3.weight": "encoder.stages.1.layers.1.layer.2.normalization.weight",
    "layer2.1.bn3.bias": "encoder.stages.1.layers.1.layer.2.normalization.bias",
    "layer2.1.bn3.running_mean": "encoder.stages.1.layers.1.layer.2.normalization.running_mean",
    "layer2.1.bn3.running_var": "encoder.stages.1.layers.1.layer.2.normalization.running_var",
    
    "layer2.2.conv1.weight": "encoder.stages.1.layers.2.layer.0.convolution.weight",
    "layer2.2.bn1.weight": "encoder.stages.1.layers.2.layer.0.normalization.weight",
    "layer2.2.bn1.bias": "encoder.stages.1.layers.2.layer.0.normalization.bias",
    "layer2.2.bn1.running_mean": "encoder.stages.1.layers.2.layer.0.normalization.running_mean",
    "layer2.2.bn1.running_var": "encoder.stages.1.layers.2.layer.0.normalization.running_var",
    "layer2.2.conv2.weight": "encoder.stages.1.layers.2.layer.1.convolution.weight",
    "layer2.2.bn2.weight": "encoder.stages.1.layers.2.layer.1.normalization.weight",
    "layer2.2.bn2.bias": "encoder.stages.1.layers.2.layer.1.normalization.bias",
    "layer2.2.bn2.running_mean": "encoder.stages.1.layers.2.layer.1.normalization.running_mean",
    "layer2.2.bn2.running_var": "encoder.stages.1.layers.2.layer.1.normalization.running_var",
    "layer2.2.conv3.weight": "encoder.stages.1.layers.2.layer.2.convolution.weight",
    "layer2.2.bn3.weight": "encoder.stages.1.layers.2.layer.2.normalization.weight",
    "layer2.2.bn3.bias": "encoder.stages.1.layers.2.layer.2.normalization.bias",
    "layer2.2.bn3.running_mean": "encoder.stages.1.layers.2.layer.2.normalization.running_mean",
    "layer2.2.bn3.running_var": "encoder.stages.1.layers.2.layer.2.normalization.running_var",
    
    "layer2.3.conv1.weight": "encoder.stages.1.layers.3.layer.0.convolution.weight",
    "layer2.3.bn1.weight": "encoder.stages.1.layers.3.layer.0.normalization.weight",
    "layer2.3.bn1.bias": "encoder.stages.1.layers.3.layer.0.normalization.bias",
    "layer2.3.bn1.running_mean": "encoder.stages.1.layers.3.layer.0.normalization.running_mean",
    "layer2.3.bn1.running_var": "encoder.stages.1.layers.3.layer.0.normalization.running_var",
    "layer2.3.conv2.weight": "encoder.stages.1.layers.3.layer.1.convolution.weight",
    "layer2.3.bn2.weight": "encoder.stages.1.layers.3.layer.1.normalization.weight",
    "layer2.3.bn2.bias": "encoder.stages.1.layers.3.layer.1.normalization.bias",
    "layer2.3.bn2.running_mean": "encoder.stages.1.layers.3.layer.1.normalization.running_mean",
    "layer2.3.bn2.running_var": "encoder.stages.1.layers.3.layer.1.normalization.running_var",
    "layer2.3.conv3.weight": "encoder.stages.1.layers.3.layer.2.convolution.weight",
    "layer2.3.bn3.weight": "encoder.stages.1.layers.3.layer.2.normalization.weight",
    "layer2.3.bn3.bias": "encoder.stages.1.layers.3.layer.2.normalization.bias",
    "layer2.3.bn3.running_mean": "encoder.stages.1.layers.3.layer.2.normalization.running_mean",
    "layer2.3.bn3.running_var": "encoder.stages.1.layers.3.layer.2.normalization.running_var",
    
    # Layer 3
    "layer3.0.conv1.weight": "encoder.stages.2.layers.0.layer.0.convolution.weight",
    "layer3.0.bn1.weight": "encoder.stages.2.layers.0.layer.0.normalization.weight",
    "layer3.0.bn1.bias": "encoder.stages.2.layers.0.layer.0.normalization.bias",
    "layer3.0.bn1.running_mean": "encoder.stages.2.layers.0.layer.0.normalization.running_mean",
    "layer3.0.bn1.running_var": "encoder.stages.2.layers.0.layer.0.normalization.running_var",
    "layer3.0.conv2.weight": "encoder.stages.2.layers.0.layer.1.convolution.weight",
    "layer3.0.bn2.weight": "encoder.stages.2.layers.0.layer.1.normalization.weight",
    "layer3.0.bn2.bias": "encoder.stages.2.layers.0.layer.1.normalization.bias",
    "layer3.0.bn2.running_mean": "encoder.stages.2.layers.0.layer.1.normalization.running_mean",
    "layer3.0.bn2.running_var": "encoder.stages.2.layers.0.layer.1.normalization.running_var",
    "layer3.0.conv3.weight": "encoder.stages.2.layers.0.layer.2.convolution.weight",
    "layer3.0.bn3.weight": "encoder.stages.2.layers.0.layer.2.normalization.weight",
    "layer3.0.bn3.bias": "encoder.stages.2.layers.0.layer.2.normalization.bias",
    "layer3.0.bn3.running_mean": "encoder.stages.2.layers.0.layer.2.normalization.running_mean",
    "layer3.0.bn3.running_var": "encoder.stages.2.layers.0.layer.2.normalization.running_var",
    "layer3.0.downsample.0.weight": "encoder.stages.2.layers.0.shortcut.convolution.weight",
    "layer3.0.downsample.1.weight": "encoder.stages.2.layers.0.shortcut.normalization.weight",
    "layer3.0.downsample.1.bias": "encoder.stages.2.layers.0.shortcut.normalization.bias",
    "layer3.0.downsample.1.running_mean": "encoder.stages.2.layers.0.shortcut.normalization.running_mean",
    "layer3.0.downsample.1.running_var": "encoder.stages.2.layers.0.shortcut.normalization.running_var",
    
    "layer3.1.conv1.weight": "encoder.stages.2.layers.1.layer.0.convolution.weight",
    "layer3.1.bn1.weight": "encoder.stages.2.layers.1.layer.0.normalization.weight",
    "layer3.1.bn1.bias": "encoder.stages.2.layers.1.layer.0.normalization.bias",
    "layer3.1.bn1.running_mean": "encoder.stages.2.layers.1.layer.0.normalization.running_mean",
    "layer3.1.bn1.running_var": "encoder.stages.2.layers.1.layer.0.normalization.running_var",
    "layer3.1.conv2.weight": "encoder.stages.2.layers.1.layer.1.convolution.weight",
    "layer3.1.bn2.weight": "encoder.stages.2.layers.1.layer.1.normalization.weight",
    "layer3.1.bn2.bias": "encoder.stages.2.layers.1.layer.1.normalization.bias",
    "layer3.1.bn2.running_mean": "encoder.stages.2.layers.1.layer.1.normalization.running_mean",
    "layer3.1.bn2.running_var": "encoder.stages.2.layers.1.layer.1.normalization.running_var",
    "layer3.1.conv3.weight": "encoder.stages.2.layers.1.layer.2.convolution.weight",
    "layer3.1.bn3.weight": "encoder.stages.2.layers.1.layer.2.normalization.weight",
    "layer3.1.bn3.bias": "encoder.stages.2.layers.1.layer.2.normalization.bias",
    "layer3.1.bn3.running_mean": "encoder.stages.2.layers.1.layer.2.normalization.running_mean",
    "layer3.1.bn3.running_var": "encoder.stages.2.layers.1.layer.2.normalization.running_var",
    
    "layer3.2.conv1.weight": "encoder.stages.2.layers.2.layer.0.convolution.weight",
    "layer3.2.bn1.weight": "encoder.stages.2.layers.2.layer.0.normalization.weight",
    "layer3.2.bn1.bias": "encoder.stages.2.layers.2.layer.0.normalization.bias",
    "layer3.2.bn1.running_mean": "encoder.stages.2.layers.2.layer.0.normalization.running_mean",
    "layer3.2.bn1.running_var": "encoder.stages.2.layers.2.layer.0.normalization.running_var",
    "layer3.2.conv2.weight": "encoder.stages.2.layers.2.layer.1.convolution.weight",
    "layer3.2.bn2.weight": "encoder.stages.2.layers.2.layer.1.normalization.weight",
    "layer3.2.bn2.bias": "encoder.stages.2.layers.2.layer.1.normalization.bias",
    "layer3.2.bn2.running_mean": "encoder.stages.2.layers.2.layer.1.normalization.running_mean",
    "layer3.2.bn2.running_var": "encoder.stages.2.layers.2.layer.1.normalization.running_var",
    "layer3.2.conv3.weight": "encoder.stages.2.layers.2.layer.2.convolution.weight",
    "layer3.2.bn3.weight": "encoder.stages.2.layers.2.layer.2.normalization.weight",
    "layer3.2.bn3.bias": "encoder.stages.2.layers.2.layer.2.normalization.bias",
    "layer3.2.bn3.running_mean": "encoder.stages.2.layers.2.layer.2.normalization.running_mean",
    "layer3.2.bn3.running_var": "encoder.stages.2.layers.2.layer.2.normalization.running_var",
    
    "layer3.3.conv1.weight": "encoder.stages.2.layers.3.layer.0.convolution.weight",
    "layer3.3.bn1.weight": "encoder.stages.2.layers.3.layer.0.normalization.weight",
    "layer3.3.bn1.bias": "encoder.stages.2.layers.3.layer.0.normalization.bias",
    "layer3.3.bn1.running_mean": "encoder.stages.2.layers.3.layer.0.normalization.running_mean",
    "layer3.3.bn1.running_var": "encoder.stages.2.layers.3.layer.0.normalization.running_var",
    "layer3.3.conv2.weight": "encoder.stages.2.layers.3.layer.1.convolution.weight",
    "layer3.3.bn2.weight": "encoder.stages.2.layers.3.layer.1.normalization.weight",
    "layer3.3.bn2.bias": "encoder.stages.2.layers.3.layer.1.normalization.bias",
    "layer3.3.bn2.running_mean": "encoder.stages.2.layers.3.layer.1.normalization.running_mean",
    "layer3.3.bn2.running_var": "encoder.stages.2.layers.3.layer.1.normalization.running_var",
    "layer3.3.conv3.weight": "encoder.stages.2.layers.3.layer.2.convolution.weight",
    "layer3.3.bn3.weight": "encoder.stages.2.layers.3.layer.2.normalization.weight",
    "layer3.3.bn3.bias": "encoder.stages.2.layers.3.layer.2.normalization.bias",
    "layer3.3.bn3.running_mean": "encoder.stages.2.layers.3.layer.2.normalization.running_mean",
    "layer3.3.bn3.running_var": "encoder.stages.2.layers.3.layer.2.normalization.running_var",
    
    "layer3.4.conv1.weight": "encoder.stages.2.layers.4.layer.0.convolution.weight",
    "layer3.4.bn1.weight": "encoder.stages.2.layers.4.layer.0.normalization.weight",
    "layer3.4.bn1.bias": "encoder.stages.2.layers.4.layer.0.normalization.bias",
    "layer3.4.bn1.running_mean": "encoder.stages.2.layers.4.layer.0.normalization.running_mean",
    "layer3.4.bn1.running_var": "encoder.stages.2.layers.4.layer.0.normalization.running_var",
    "layer3.4.conv2.weight": "encoder.stages.2.layers.4.layer.1.convolution.weight",
    "layer3.4.bn2.weight": "encoder.stages.2.layers.4.layer.1.normalization.weight",
    "layer3.4.bn2.bias": "encoder.stages.2.layers.4.layer.1.normalization.bias",
    "layer3.4.bn2.running_mean": "encoder.stages.2.layers.4.layer.1.normalization.running_mean",
    "layer3.4.bn2.running_var": "encoder.stages.2.layers.4.layer.1.normalization.running_var",
    "layer3.4.conv3.weight": "encoder.stages.2.layers.4.layer.2.convolution.weight",
    "layer3.4.bn3.weight": "encoder.stages.2.layers.4.layer.2.normalization.weight",
    "layer3.4.bn3.bias": "encoder.stages.2.layers.4.layer.2.normalization.bias",
    "layer3.4.bn3.running_mean": "encoder.stages.2.layers.4.layer.2.normalization.running_mean",
    "layer3.4.bn3.running_var": "encoder.stages.2.layers.4.layer.2.normalization.running_var",
    
    "layer3.5.conv1.weight": "encoder.stages.2.layers.5.layer.0.convolution.weight",
    "layer3.5.bn1.weight": "encoder.stages.2.layers.5.layer.0.normalization.weight",
    "layer3.5.bn1.bias": "encoder.stages.2.layers.5.layer.0.normalization.bias",
    "layer3.5.bn1.running_mean": "encoder.stages.2.layers.5.layer.0.normalization.running_mean",
    "layer3.5.bn1.running_var": "encoder.stages.2.layers.5.layer.0.normalization.running_var",
    "layer3.5.conv2.weight": "encoder.stages.2.layers.5.layer.1.convolution.weight",
    "layer3.5.bn2.weight": "encoder.stages.2.layers.5.layer.1.normalization.weight",
    "layer3.5.bn2.bias": "encoder.stages.2.layers.5.layer.1.normalization.bias",
    "layer3.5.bn2.running_mean": "encoder.stages.2.layers.5.layer.1.normalization.running_mean",
    "layer3.5.bn2.running_var": "encoder.stages.2.layers.5.layer.1.normalization.running_var",
    "layer3.5.conv3.weight": "encoder.stages.2.layers.5.layer.2.convolution.weight",
    "layer3.5.bn3.weight": "encoder.stages.2.layers.5.layer.2.normalization.weight",
    "layer3.5.bn3.bias": "encoder.stages.2.layers.5.layer.2.normalization.bias",
    "layer3.5.bn3.running_mean": "encoder.stages.2.layers.5.layer.2.normalization.running_mean",
    "layer3.5.bn3.running_var": "encoder.stages.2.layers.5.layer.2.normalization.running_var",
    
    # Layer 4
    "layer4.0.conv1.weight": "encoder.stages.3.layers.0.layer.0.convolution.weight",
    "layer4.0.bn1.weight": "encoder.stages.3.layers.0.layer.0.normalization.weight",
    "layer4.0.bn1.bias": "encoder.stages.3.layers.0.layer.0.normalization.bias",
    "layer4.0.bn1.running_mean": "encoder.stages.3.layers.0.layer.0.normalization.running_mean",
    "layer4.0.bn1.running_var": "encoder.stages.3.layers.0.layer.0.normalization.running_var",
    "layer4.0.conv2.weight": "encoder.stages.3.layers.0.layer.1.convolution.weight",
    "layer4.0.bn2.weight": "encoder.stages.3.layers.0.layer.1.normalization.weight",
    "layer4.0.bn2.bias": "encoder.stages.3.layers.0.layer.1.normalization.bias",
    "layer4.0.bn2.running_mean": "encoder.stages.3.layers.0.layer.1.normalization.running_mean",
    "layer4.0.bn2.running_var": "encoder.stages.3.layers.0.layer.1.normalization.running_var",
    "layer4.0.conv3.weight": "encoder.stages.3.layers.0.layer.2.convolution.weight",
    "layer4.0.bn3.weight": "encoder.stages.3.layers.0.layer.2.normalization.weight",
    "layer4.0.bn3.bias": "encoder.stages.3.layers.0.layer.2.normalization.bias",
    "layer4.0.bn3.running_mean": "encoder.stages.3.layers.0.layer.2.normalization.running_mean",
    "layer4.0.bn3.running_var": "encoder.stages.3.layers.0.layer.2.normalization.running_var",
    "layer4.0.downsample.0.weight": "encoder.stages.3.layers.0.shortcut.convolution.weight",
    "layer4.0.downsample.1.weight": "encoder.stages.3.layers.0.shortcut.normalization.weight",
    "layer4.0.downsample.1.bias": "encoder.stages.3.layers.0.shortcut.normalization.bias",
    "layer4.0.downsample.1.running_mean": "encoder.stages.3.layers.0.shortcut.normalization.running_mean",
    "layer4.0.downsample.1.running_var": "encoder.stages.3.layers.0.shortcut.normalization.running_var",
    
    "layer4.1.conv1.weight": "encoder.stages.3.layers.1.layer.0.convolution.weight",
    "layer4.1.bn1.weight": "encoder.stages.3.layers.1.layer.0.normalization.weight",
    "layer4.1.bn1.bias": "encoder.stages.3.layers.1.layer.0.normalization.bias",
    "layer4.1.bn1.running_mean": "encoder.stages.3.layers.1.layer.0.normalization.running_mean",
    "layer4.1.bn1.running_var": "encoder.stages.3.layers.1.layer.0.normalization.running_var",
    "layer4.1.conv2.weight": "encoder.stages.3.layers.1.layer.1.convolution.weight",
    "layer4.1.bn2.weight": "encoder.stages.3.layers.1.layer.1.normalization.weight",
    "layer4.1.bn2.bias": "encoder.stages.3.layers.1.layer.1.normalization.bias",
    "layer4.1.bn2.running_mean": "encoder.stages.3.layers.1.layer.1.normalization.running_mean",
    "layer4.1.bn2.running_var": "encoder.stages.3.layers.1.layer.1.normalization.running_var",
    "layer4.1.conv3.weight": "encoder.stages.3.layers.1.layer.2.convolution.weight",
    "layer4.1.bn3.weight": "encoder.stages.3.layers.1.layer.2.normalization.weight",
    "layer4.1.bn3.bias": "encoder.stages.3.layers.1.layer.2.normalization.bias",
    "layer4.1.bn3.running_mean": "encoder.stages.3.layers.1.layer.2.normalization.running_mean",
    "layer4.1.bn3.running_var": "encoder.stages.3.layers.1.layer.2.normalization.running_var",
    
    "layer4.2.conv1.weight": "encoder.stages.3.layers.2.layer.0.convolution.weight",
    "layer4.2.bn1.weight": "encoder.stages.3.layers.2.layer.0.normalization.weight",
    "layer4.2.bn1.bias": "encoder.stages.3.layers.2.layer.0.normalization.bias",
    "layer4.2.bn1.running_mean": "encoder.stages.3.layers.2.layer.0.normalization.running_mean",
    "layer4.2.bn1.running_var": "encoder.stages.3.layers.2.layer.0.normalization.running_var",
    "layer4.2.conv2.weight": "encoder.stages.3.layers.2.layer.1.convolution.weight",
    "layer4.2.bn2.weight": "encoder.stages.3.layers.2.layer.1.normalization.weight",
    "layer4.2.bn2.bias": "encoder.stages.3.layers.2.layer.1.normalization.bias",
    "layer4.2.bn2.running_mean": "encoder.stages.3.layers.2.layer.1.normalization.running_mean",
    "layer4.2.bn2.running_var": "encoder.stages.3.layers.2.layer.1.normalization.running_var",
    "layer4.2.conv3.weight": "encoder.stages.3.layers.2.layer.2.convolution.weight",
    "layer4.2.bn3.weight": "encoder.stages.3.layers.2.layer.2.normalization.weight",
    "layer4.2.bn3.bias": "encoder.stages.3.layers.2.layer.2.normalization.bias",
    "layer4.2.bn3.running_mean": "encoder.stages.3.layers.2.layer.2.normalization.running_mean",
    "layer4.2.bn3.running_var": "encoder.stages.3.layers.2.layer.2.normalization.running_var",
}

In [37]:
torchvision_state_dict = resnet50.state_dict()

mapped_state_dict = {}
for k, v in torchvision_state_dict.items():
    if k in layer_name_mapping:
        mapped_state_dict[layer_name_mapping[k]] = v

resnet50_hf.load_state_dict(mapped_state_dict)


<All keys matched successfully>

In [149]:
import os

# resnet50_hf.save_pretrained("resnet50_hf", push_to_hub=True, repo_id="galeio-research/nereus-sar-1", token=os.getenv("HF_TOKEN"))
resnet50_hf.save_pretrained("resnet50_hf")

In [20]:
resnet50_hf.forward(torch.randn(1, 1, 256, 256)).pooler_output.shape

torch.Size([1, 2048, 1, 1])

In [189]:
TASK = "wave"

In [41]:
import torch

path_w = "/home/antoine/models/nereus/dino_resnet50/probing_head_resnet50_regression_wind_speed/w.pt"
path_b = "/home/antoine/models/nereus/dino_resnet50/probing_head_resnet50_regression_wind_speed/b.pt"

w = torch.load(path_w, weights_only=False)
b = torch.load(path_b, weights_only=False)

w = torch.from_numpy(w)
b = torch.from_numpy(b)

w = w.unsqueeze(1)

state_dict_head = {'weight': w.T, 'bias': b}

FileNotFoundError: [Errno 2] No such file or directory: '/home/antoine/models/nereus/dino_resnet50/probing_head_resnet50_regression_wind_speed/w.pt'

In [48]:
import torch

head = torch.nn.Linear(2048, 1)
head.load_state_dict(state_dict_head)

<All keys matched successfully>

In [49]:
head_path = f"linear-head_resnet50_{TASK}.pth"

torch.save(head.state_dict(), head_path)

In [197]:
if TASK == "wind":
    PROBLEM_TYPE = "regression"
    NUM_LABELS = 1
    LABEL2ID = {"wind_speed": 0}
    ID2LABEL = {0: "wind_speed"}
elif TASK == "wave":
    PROBLEM_TYPE = "regression"
    NUM_LABELS = 1
    LABEL2ID = {"wave_height": 0}
    ID2LABEL = {0: "wave_height"}
elif TASK == "tengeop":
    PROBLEM_TYPE = "multi_label_classification"
    NUM_LABELS = 10

In [None]:
import torch
from transformers import ResNetForImageClassification, ResNetConfig
from torchinfo import summary

config_resnet = ResNetConfig.from_pretrained("/home/antoine/models/nereus/dino_resnet50/resnet50_hf")
config_with_linear_head = ResNetConfig(hidden_size=2048, finetuning_task=TASK, num_labels=NUM_LABELS, problem_type=PROBLEM_TYPE, **config_resnet.to_diff_dict())
nereus = ResNetForImageClassification(config_with_linear_head)

summary(nereus, input_size=(1, 1, 256, 256), row_settings=["var_names"], depth=6)


Layer (type (var_name))                                                     Output Shape              Param #
ResNetForImageClassification (ResNetForImageClassification)                 [1, 1]                    --
├─ResNetModel (resnet)                                                      [1, 2048, 1, 1]           --
│    └─ResNetEmbeddings (embedder)                                          [1, 64, 64, 64]           --
│    │    └─ResNetConvLayer (embedder)                                      [1, 64, 128, 128]         --
│    │    │    └─Conv2d (convolution)                                       [1, 64, 128, 128]         3,136
│    │    │    └─BatchNorm2d (normalization)                                [1, 64, 128, 128]         128
│    │    │    └─ReLU (activation)                                          [1, 64, 128, 128]         --
│    │    └─MaxPool2d (pooler)                                              [1, 64, 64, 64]           --
│    └─ResNetEncoder (encoder)                

In [222]:
nereus.resnet.from_pretrained("/home/antoine/models/nereus/dino_resnet50/resnet50_hf")
state_dict_head = torch.load(f"/home/antoine/models/nereus/dino_resnet50/linear-head_resnet50_{TASK}.weights.pth")
nereus.classifier[1].load_state_dict(state_dict_head)

<All keys matched successfully>

In [237]:
x = torch.randn(1, 1, 256, 256).cuda()

In [238]:
with torch.no_grad():
    outputs = nereus.forward(x)
outputs.values()

odict_values([tensor([[13.1773]], device='cuda:0')])

In [200]:
nereus.push_to_hub(f"galeio-research/nereus-sar-1-{TASK}",private=True, token=os.environ['HF_TOKEN'])

model.safetensors: 100%|██████████| 94.3M/94.3M [00:02<00:00, 32.6MB/s]


CommitInfo(commit_url='https://huggingface.co/galeio-research/nereus-sar-1-wave/commit/a9c9ffe6acd725bec9bdc179a262215aa8b96e91', commit_message='Upload ResNetForImageClassification', commit_description='', oid='a9c9ffe6acd725bec9bdc179a262215aa8b96e91', pr_url=None, repo_url=RepoUrl('https://huggingface.co/galeio-research/nereus-sar-1-wave', endpoint='https://huggingface.co', repo_type='model', repo_id='galeio-research/nereus-sar-1-wave'), pr_revision=None, pr_num=None)

In [240]:
import os
import torch
from transformers import AutoModelForImageClassification

nereus = AutoModelForImageClassification.from_pretrained(f"galeio-research/nereus-sar-1-{TASK}", token=os.environ['HF_TOKEN']).to("cuda")

with torch.no_grad():
    outputs = nereus(x)
outputs.values()

odict_values([tensor([[224.4390]], device='cuda:0')])

In [243]:
nereus.finetuning_task

AttributeError: 'ResNetForImageClassification' object has no attribute 'finetuning_task'

In [202]:
from torchinfo import summary

summary(nereus, input_size=(1, 1, 256, 256), row_settings=["var_names"], depth=6)

Layer (type (var_name))                                                     Output Shape              Param #
ResNetForImageClassification (ResNetForImageClassification)                 [1, 1]                    --
├─ResNetModel (resnet)                                                      [1, 2048, 1, 1]           --
│    └─ResNetEmbeddings (embedder)                                          [1, 64, 64, 64]           --
│    │    └─ResNetConvLayer (embedder)                                      [1, 64, 128, 128]         --
│    │    │    └─Conv2d (convolution)                                       [1, 64, 128, 128]         3,136
│    │    │    └─BatchNorm2d (normalization)                                [1, 64, 128, 128]         128
│    │    │    └─ReLU (activation)                                          [1, 64, 128, 128]         --
│    │    └─MaxPool2d (pooler)                                              [1, 64, 64, 64]           --
│    └─ResNetEncoder (encoder)                