On cherche à réécrire le checkpoint `.pth` de SAM pour qu'il soit « conforme » à Lightning, c'est-à-dire avec les clés attendues dans le dictionnaire de state_dict.

L'erreur initiale est : `KeyError: 'pytorch-lightning_version'`

In [1]:
import hydra
import torch
from src.commons.constants import PROJECT_PATH
from omegaconf import DictConfig, OmegaConf

In [2]:
# List the contents of the configs directory that Hydra is pointed at below.
ls ../../configs

[0m[01;34mcallbacks[0m/  eval.yaml    [01;34mhparams_search[0m/  [01;34mlocal[0m/    [01;34mmodel[0m/     [01;34mtrainer[0m/
[01;34mdata[0m/       [01;34mexperiment[0m/  [01;34mhydra[0m/           [01;34mlogger[0m/   [01;34mpaths[0m/     train.yaml
[01;34mdebug[0m/      [01;34mextras[0m/      __init__.py      [01;34mmetrics[0m/  [01;34msam_type[0m/


In [3]:
def load_config(config_name="train", overrides=None):
    """Compose and return a Hydra configuration from ``../../configs``.

    The function is safe to call several times in the same interpreter
    (e.g. when re-running a notebook cell): any previously initialized
    Hydra global state is cleared before initializing again.

    Args:
        config_name: Name of the root config to compose. Defaults to
            ``"train"``.
        overrides: Hydra override strings applied on top of the root
            config. When ``None``, the adapter / small SAM / LEVIR-CD
            overrides are used.

    Returns:
        The composed configuration object returned by ``hydra.compose``.
    """
    # Local import so this cell does not depend on another cell's imports.
    from hydra.core.global_hydra import GlobalHydra

    if overrides is None:
        overrides = ["experiment=adapter", "sam_type=small", "data=levir-cd"]

    # hydra.initialize() refuses to run when Hydra is already initialized,
    # so reset the global singleton first to make repeated calls work.
    global_hydra = GlobalHydra.instance()
    if global_hydra.is_initialized():
        global_hydra.clear()

    # Initialize the Hydra configuration
    hydra.initialize(config_path="../../configs", version_base=None)

    # Compose the configuration with the desired overrides
    return hydra.compose(config_name=config_name, overrides=list(overrides))

In [4]:
from hydra.core.global_hydra import GlobalHydra

# Drop any Hydra state left over from a previous run of this cell, so that
# load_config() can call hydra.initialize() again.
hydra_state = GlobalHydra.instance()
hydra_state.clear()

# Compose the configuration and show it as YAML for inspection.
cfg = load_config()
cfg_as_yaml = OmegaConf.to_yaml(cfg)
print(cfg_as_yaml)

data:
  name: levir-cd
  _target_: src.data.datamodule.CDDataModule
  params:
    prompt_type: sample
    n_prompt: 1
    loc: center
    batch_size: 2
    n_shape: 3
    num_worker: 2
    pin_memory: false
model:
  network:
    image_encoder:
      _target_: src.models.magic_pen.adapter.ImageEncoderViTAdapter
      depth: 12
      embed_dim: 768
      img_size: 1024
      mlp_ratio: 4
      norm_layer: null
      num_heads: 12
      patch_size: 16
      qkv_bias: true
      use_rel_pos: true
      global_attn_indexes:
      - 2
      - 5
      - 8
      - 11
      window_size: 14
      out_chans: 256
      adapter_inter_dim: 16
    prompt_encoder:
      _target_: src.models.segment_anything.modeling.prompt_encoder_dev.PromptEncoder
      embed_dim: 512
      image_embedding_size:
      - 64
      - 64
      input_image_size:
      - 1024
      - 1024
      mask_in_chans: 16
    mask_decoder:
      transformer:
        _target_: src.models.segment_anything.modeling.transformer_dev.TwoW

In [10]:
# Instantiate the object described by the `model.instance` config node.
module = hydra.utils.instantiate(cfg.model.instance)

2024-07-30 15:58:46,285 - INFO ::  Weights loaded for : ['image_encoder']


In [8]:
path = "/var/data/usr/mdizier/stylo_magique/checkpoints/sam/sam_vit_b_01ec64.pth"
# The file is loaded straight into load_state_dict, i.e. it is a plain mapping
# of parameter names to tensors:
# - map_location="cpu" keeps the load independent of the device the tensors
#   were saved from (load_state_dict copies them onto the model's parameters);
# - weights_only=True restricts unpickling to tensors and basic containers
#   instead of executing arbitrary pickled code.
state_dict = torch.load(path, map_location="cpu", weights_only=True)
module.model.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
# Display the image encoder submodule to inspect its layers.
module.model.image_encoder

ImageEncoderViTAdapter(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): ModuleList(
    (0-11): 12 x AdapterBlock(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLPBlock(
        (lin1): Linear(in_features=768, out_features=3072, bias=True)
        (lin2): Linear(in_features=3072, out_features=768, bias=True)
        (act): GELU(approximate='none')
      )
      (adapter): Adapter(
        (act): GELU(approximate='none')
        (down_layer): Linear(in_features=768, out_features=16, bias=True)
        (up_layer): Linear(in_features=16, out_features=768, bias=True)
      )
    )
  )
  (neck): Sequential(
    (0): Conv2d(768, 256, kernel_size=(1, 1), stri

In [13]:
import re

# Quick check of the pattern 'adapter' against a parameter-style name.
# NOTE(review): the sample name is spelled "adater" (missing "p"), so the
# search finds no match and the cell evaluates to False -- confirm whether
# this negative case is intentional.
bool(re.search('adapter', "blocks.0.adater.up_layer.weight"))

False

In [11]:
# Print the name and shape of every parameter of the image encoder.
for name, m in module.model.image_encoder.named_parameters():
    # TODO: unfinished name-based filter; note the str method is `startswith`, not `startwith`.
    print(name, m.shape)

pos_embed torch.Size([1, 64, 64, 768])
patch_embed.proj.weight torch.Size([768, 3, 16, 16])
patch_embed.proj.bias torch.Size([768])
blocks.0.scale torch.Size([768])
blocks.0.norm1.weight torch.Size([768])
blocks.0.norm1.bias torch.Size([768])
blocks.0.attn.rel_pos_h torch.Size([27, 64])
blocks.0.attn.rel_pos_w torch.Size([27, 64])
blocks.0.attn.qkv.weight torch.Size([2304, 768])
blocks.0.attn.qkv.bias torch.Size([2304])
blocks.0.attn.proj.weight torch.Size([768, 768])
blocks.0.attn.proj.bias torch.Size([768])
blocks.0.norm2.weight torch.Size([768])
blocks.0.norm2.bias torch.Size([768])
blocks.0.mlp.lin1.weight torch.Size([3072, 768])
blocks.0.mlp.lin1.bias torch.Size([3072])
blocks.0.mlp.lin2.weight torch.Size([768, 3072])
blocks.0.mlp.lin2.bias torch.Size([768])
blocks.0.adapter.down_layer.weight torch.Size([16, 768])
blocks.0.adapter.down_layer.bias torch.Size([16])
blocks.0.adapter.up_layer.weight torch.Size([768, 16])
blocks.0.adapter.up_layer.bias torch.Size([768])
blocks.1.scale 