In [3]:
from segmentation_models_pytorch import Unet

# Initialize Unet with ResNet34 encoder pre-trained on ImageNet
model = Unet(encoder_name="resnet50", encoder_weights="imagenet")

In [4]:
model

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
      

In [12]:
ResNet50_Weights.SENTINEL2_ALL_MOCO.meta

{'dataset': 'SSL4EO-S12',
 'in_chans': 13,
 'model': 'resnet50',
 'publication': 'https://arxiv.org/abs/2211.07044',
 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12',
 'ssl_method': 'moco',
 'bands': ['B1',
  'B2',
  'B3',
  'B4',
  'B5',
  'B6',
  'B7',
  'B8',
  'B8a',
  'B9',
  'B10',
  'B11',
  'B12']}

In [40]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [39]:
%reload_ext autoreload

In [1]:
from src.model_zoo.models import define_model_torchgeo
import timm
import torch
import torchgeo

  print(f"\SATELLITE BAND ADAPTATION for {key_name}")
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## Bands dict mapping

weight_sentinel2_bands = {
0:'B1',
1:  'B2',
2:  'B3',
3:  'B4',
4:  'B5',
5:  'B6',
6:  'B7',
7:  'B8',
8:  'B8a',
9:  'B9',
10:  'B10',
11:  'B11',
12:  'B12'
}
select_bands = [1,2,3,4,5,6,7,8,11,12]   ## match with bands presented at BigEarthNet 
model = define_model_torchgeo('resnet50',
                               weights='ResNet50_Weights.SENTINEL2_ALL_MOCO',
                                num_classes=20,
                                input_channels=10,
                                bands = weight_sentinel2_bands,
                                selected_channels = select_bands,
                                freeze_backbone=False)

Loading PyTorchGeo weights by name: ResNet50_Weights.SENTINEL2_ALL_MOCO
\SATELLITE BAND ADAPTATION for conv1.weight
Source channels: 13 -> Target channels: 10
Available bands in source weights:
  Channel  0: B1
  Channel  1: B2
  Channel  2: B3
  Channel  3: B4
  Channel  4: B5
  Channel  5: B6
  Channel  6: B7
  Channel  7: B8
  Channel  8: B8a
  Channel  9: B9
  Channel 10: B10
  Channel 11: B11
  Channel 12: B12
REMOVING 3 specific channels:
 Removing Channel  0: B1
 Removing Channel  9: B9
 Removing Channel 10: B10
SELECTING 10 channels (in order):
  Position  0 <- Channel  1: B2
  Position  1 <- Channel  2: B3
  Position  2 <- Channel  3: B4
  Position  3 <- Channel  4: B5
  Position  4 <- Channel  5: B6
  Position  5 <- Channel  6: B7
  Position  6 <- Channel  7: B8
  Position  7 <- Channel  8: B8a
  Position  8 <- Channel 11: B11
  Position  9 <- Channel 12: B12
CUSTOM CHANNEL SELECTION APPLIED
   Original order: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
   Selected order: [1, 

In [4]:
model_teste = timm.create_model(
    "resnet50",
    num_classes=20,
    in_chans=10,
    pretrained=False
)

In [7]:
weight_num = torchgeo.models.get_weight("ResNet50_Weights.SENTINEL2_ALL_MOCO")
state_dict = weight_num.get_state_dict(progress=True)

In [16]:
import re
def _keys_match(state_key: str, model_key: str) -> bool:
    """Check if two keys represent the same parameter with different naming conventions."""
    # Remove common prefixes/suffixes that might differ
    state_clean = re.sub(r'^(backbone\.|encoder\.|features\.)', '', state_key)
    model_clean = re.sub(r'^(backbone\.|encoder\.|features\.)', '', model_key)
    
    # Check for exact match after cleaning
    if state_clean == model_clean:
        return True
    
    # Check for common substitutions
    substitutions = [
        (r'\.weight$', '.weight'),
        (r'\.bias$', '.bias'),
        (r'bn(\d+)', r'norm\1'),  # batch norm naming
        (r'norm(\d+)', r'bn\1'),
        (r'downsample\.0', 'downsample.conv'),
        (r'downsample\.1', 'downsample.norm'),
    ]
    
    for pattern, replacement in substitutions:
        if re.sub(pattern, replacement, state_clean) == model_clean:
            return True
        if state_clean == re.sub(pattern, replacement, model_clean):
            return True
    
    return False

In [11]:
key_mapping = {}
unmatched_keys = []
    
for state_key in state_dict.keys():
    # Try exact match first
    if state_key in state_dict:
        key_mapping[state_key] = state_key
        continue
    
    # Try to find similar keys (handle different naming conventions)
    matched = False
    for model_key in state_dict.keys():
        if _keys_match(state_key, model_key):
            key_mapping[state_key] = model_key
            matched = True
            break
    
    if not matched:
        unmatched_keys.append(state_key)

In [20]:
for state_key, model_key in key_mapping.items():
    state_tensor = state_dict[state_key]
    model_tensor = state_dict[model_key]

In [22]:
key_mapping.items()

dict_items([('conv1.weight', 'conv1.weight'), ('bn1.weight', 'bn1.weight'), ('bn1.bias', 'bn1.bias'), ('bn1.running_mean', 'bn1.running_mean'), ('bn1.running_var', 'bn1.running_var'), ('bn1.num_batches_tracked', 'bn1.num_batches_tracked'), ('layer1.0.conv1.weight', 'layer1.0.conv1.weight'), ('layer1.0.bn1.weight', 'layer1.0.bn1.weight'), ('layer1.0.bn1.bias', 'layer1.0.bn1.bias'), ('layer1.0.bn1.running_mean', 'layer1.0.bn1.running_mean'), ('layer1.0.bn1.running_var', 'layer1.0.bn1.running_var'), ('layer1.0.bn1.num_batches_tracked', 'layer1.0.bn1.num_batches_tracked'), ('layer1.0.conv2.weight', 'layer1.0.conv2.weight'), ('layer1.0.bn2.weight', 'layer1.0.bn2.weight'), ('layer1.0.bn2.bias', 'layer1.0.bn2.bias'), ('layer1.0.bn2.running_mean', 'layer1.0.bn2.running_mean'), ('layer1.0.bn2.running_var', 'layer1.0.bn2.running_var'), ('layer1.0.bn2.num_batches_tracked', 'layer1.0.bn2.num_batches_tracked'), ('layer1.0.conv3.weight', 'layer1.0.conv3.weight'), ('layer1.0.bn3.weight', 'layer1.0.bn

In [33]:
state_dict['conv1.weight'].shape

torch.Size([64, 13, 7, 7])

In [34]:
model_teste.state_dict()['conv1.weight'].shape[1]

10

In [17]:
_keys_match(state_key, 'conv1.weight')

False

In [18]:
state_key

'layer4.2.bn3.num_batches_tracked'

In [9]:
key_mapping

{'conv1.weight': 'conv1.weight',
 'bn1.weight': 'bn1.weight',
 'bn1.bias': 'bn1.bias',
 'bn1.running_mean': 'bn1.running_mean',
 'bn1.running_var': 'bn1.running_var',
 'bn1.num_batches_tracked': 'bn1.num_batches_tracked',
 'layer1.0.conv1.weight': 'layer1.0.conv1.weight',
 'layer1.0.bn1.weight': 'layer1.0.bn1.weight',
 'layer1.0.bn1.bias': 'layer1.0.bn1.bias',
 'layer1.0.bn1.running_mean': 'layer1.0.bn1.running_mean',
 'layer1.0.bn1.running_var': 'layer1.0.bn1.running_var',
 'layer1.0.bn1.num_batches_tracked': 'layer1.0.bn1.num_batches_tracked',
 'layer1.0.conv2.weight': 'layer1.0.conv2.weight',
 'layer1.0.bn2.weight': 'layer1.0.bn2.weight',
 'layer1.0.bn2.bias': 'layer1.0.bn2.bias',
 'layer1.0.bn2.running_mean': 'layer1.0.bn2.running_mean',
 'layer1.0.bn2.running_var': 'layer1.0.bn2.running_var',
 'layer1.0.bn2.num_batches_tracked': 'layer1.0.bn2.num_batches_tracked',
 'layer1.0.conv3.weight': 'layer1.0.conv3.weight',
 'layer1.0.bn3.weight': 'layer1.0.bn3.weight',
 'layer1.0.bn3.bias'

In [None]:
model_teste = timm.create_model(
    "restnet50",
    num_classes=20,
    in_chans=10,
    pretrained=False,
    **kwargs
)

# Load PyTorchGeo weights
if weights and weights is not True:
    try:
        # Handle different weight types
        if isinstance(weights, WeightsEnum):
            print(f"Loading PyTorchGeo weights: {weights}")
            state_dict = weights.get_state_dict(progress=True)

        elif isinstance(weights, str):
            if weights.endswith('.pth') or weights.endswith('.pt'):
                # Load from file path
                print(f"Loading weights from file: {weights}")
                state_dict = torch.load(weights, map_location='cpu')
                # Handle different state dict formats
                if 'state_dict' in state_dict:
                    state_dict = state_dict['state_dict']
                elif 'model' in state_dict:
                    state_dict = state_dict['model']
            else:
                # Load by PyTorchGeo weight name
                print(f"Loading PyTorchGeo weights by name: {weights}")
                weight_enum = get_weight(weights)
                state_dict = weight_enum.get_state_dict(progress=True)
        else:
            raise ValueError(f"Unsupported weight type: {type(weights)}")
        
        # Use flexible loading with satellite band information
        load_state_dict_with_flexibility(model, state_dict, strict=False, bands=bands)
        print("✓ Weights loaded successfully")

    except Exception as e:
        print(f"Failed to load weights: {e}")
        print("Continuing with model initialization...")

timm.models.resnet.ResNet

In [6]:
?model.load_state_dict

[31mSignature:[39m
model.load_state_dict(
    state_dict: collections.abc.Mapping[str, typing.Any],
    strict: bool = [38;5;28;01mTrue[39;00m,
    assign: bool = [38;5;28;01mFalse[39;00m,
)
[31mDocstring:[39m
Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.

    If :attr:`assign` is ``True`` the optimizer must be created after
    the call to :attr:`load_state_dict` unless
    :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

Args:
    state_dict (dict): a dict containing parameters and
        persistent buffers.
    strict (bool, optional): whether to strictly enforce that the keys
        in :attr:`state_dict` match the keys returned by this module's
        :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
    assign (bool, opt

In [9]:
torch.hub.load_state_dict_from_url(
    torchgeo.models.ResNet50_Weights.SENTINEL2_ALL_MOCO
)

AttributeError: 'ResNet50_Weights' object has no attribute 'decode'

In [14]:
type(torchgeo.models.get_weight("ResNet50_Weights.SENTINEL2_ALL_MOCO"))

<enum 'ResNet50_Weights'>

In [None]:
weight_num = torchgeo.models.get_weight("ResNet50_Weights.SENTINEL2_ALL_MOCO")
weight_num.get_state_dict(progress=True)


TypeError: 'ResNet50_Weights' object is not iterable

In [18]:
torch.hub.load_state_dict_from_url(weight_num.get_state_dict())

AttributeError: 'collections.OrderedDict' object has no attribute 'decode'

In [19]:
weight_num.get_state_dict(progress=True)

OrderedDict([('conv1.weight',
              tensor([[[[-2.5599e-03, -2.5125e-02, -2.9809e-02,  ..., -1.4785e-02,
                         -1.8853e-02, -1.7345e-02],
                        [-2.0063e-02, -3.7546e-02, -4.3469e-02,  ..., -2.9081e-02,
                         -3.2164e-02, -3.0974e-02],
                        [-6.6917e-03, -2.3737e-02, -2.7291e-02,  ..., -2.1477e-02,
                         -2.5538e-02, -2.4389e-02],
                        ...,
                        [ 1.0729e-02, -7.7632e-03, -8.8432e-03,  ..., -6.5538e-03,
                         -1.4675e-02, -1.7500e-02],
                        [ 1.7471e-02,  3.1923e-03, -3.4344e-04,  ...,  4.2906e-03,
                         -5.4006e-03, -8.1703e-03],
                        [ 2.0078e-02, -5.6803e-04, -7.4397e-03,  ..., -4.1560e-03,
                         -1.2066e-02, -1.2058e-02]],
              
                       [[-7.7723e-03, -2.5698e-02, -2.3960e-02,  ..., -1.3476e-02,
                         -1.4616

In [20]:
state_dict = weight_num.get_state_dict(progress=True)

In [23]:
state_dict.keys()

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.num_batches_tracked', 'layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight', 'layer1.0.downsample.1.bias', 'layer1.0.downsample.1.running_mean', 'layer1.0.downsample.1.running_var', 'layer1.0.downsample.1.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.conv2.we

In [36]:
model = timm.create_model('resnet50', pretrained=False, in_chans=13, num_classes=20)

In [37]:
model_dict = model.state_dict()

In [38]:
model_dict.keys()

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.num_batches_tracked', 'layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight', 'layer1.0.downsample.1.bias', 'layer1.0.downsample.1.running_mean', 'layer1.0.downsample.1.running_var', 'layer1.0.downsample.1.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.conv2.we

In [39]:
model.load_state_dict(state_dict)

RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "fc.weight", "fc.bias". 

In [8]:
getattr(torchgeo.models, 'resnet50')

<function torchgeo.models.resnet.resnet50(weights: torchgeo.models.resnet.ResNet50_Weights | None = None, *args: Any, **kwargs: Any) -> timm.models.resnet.ResNet | torch.nn.modules.container.ModuleDict>

In [9]:
modelClass = getattr(torchgeo.models, 'resnet50')



In [16]:
mmodel = modelClass(
    in_chans = 12,
    num_classes = 10,
    pretrained=True
   # 'ResNet50.Weights.SENTINEL2_ALL_MOCO'
)

In [19]:
mmodel.parameters()

<generator object Module.parameters at 0x74a334bdf680>

In [22]:
mmodel.num_classes

10

In [26]:
mmodel.num_features

2048

UsageError: Missing module name.


In [71]:
%load_ext autoreload
%autoreload 2
from src.model_zoo.classification import define_model_torchgeo

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
 %reload_ext autoreload

In [72]:
?define_model_torchgeo

[31mSignature:[39m define_model_torchgeo(name, out_channels=[32m3[39m, in_channel=[32m3[39m)
[31mDocstring:[39m <no docstring>
[31mFile:[39m      Dynamically generated function. No source code available.
[31mType:[39m      function

In [None]:
mgmodel = define_model_torchgeo(
    'resnet50',
    out_channels=20,
    in_channel=12
)

## How does timm decrease the input channek???
## How to select the correct activation function 

In [76]:
mgmodel.parameters

<bound method Module.parameters of ResNet(
  (conv1): Conv2d(12, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act2): ReLU(inplace=True)
      (aa): Identity()
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=T