### Load partial weights

* Use case: load the image-encoder and prompt-encoder weights, then train the decoder.

In [1]:
import numpy as np
import cv2
from PIL import Image
import torch
import pytorch_lightning as pl
import pandas as pd
import os
from copy import deepcopy
import torch.nn.functional as F
from torchmetrics import Metric
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from torch.utils import data
from torch.nn.utils.rnn import pad_sequence

from src.commons.utils import to_numpy, SegAnyChangeVersion, show_img, show_pair_img, show_prediction_sample, resize
from src.models.commons.mask_process import extract_object_from_batch, binarize_mask
from src.commons.constants import IMG_SIZE
from src.data.process import generate_grid_prompt
from src.commons.utils import create_sample_grid_with_prompt, get_mask_with_prompt, fig2arr

In [2]:
from src.commons.utils_io import load_sam
from src.models.commons.bisam import BiSam2, SamModeInference
from src.models.segment_any_change.model import BiSam

from src.commons.utils import batch_to_list

In [3]:
def show_prompts_on_mask(mask: torch.Tensor, batch, batch_idx: int):
    """Display one sample's binarized mask with its prompt points drawn on it.

    Args:
        mask: Batch of masks. It is indexed as ``mask[batch_idx]`` and its last
            two axes are treated as the spatial axes (assumed (height, width),
            matching ``IMG_SIZE`` — TODO confirm).
        batch: Mapping holding a ``"point_coords"`` entry indexable by sample.
        batch_idx: Index of the sample to display.

    Returns:
        None. The image is shown through ``show_img``.
    """
    # Resize when EITHER spatial axis differs from IMG_SIZE. The previous guard
    # compared only the last axis with IMG_SIZE[0], so a mask with the right
    # width but the wrong height was silently left at the wrong resolution.
    if tuple(mask.shape[-2:]) != tuple(IMG_SIZE):
        mask = resize(mask, IMG_SIZE)
    coord_points = batch["point_coords"][batch_idx]
    # Threshold at 0 before overlaying the prompt points on the mask.
    mask_pt = get_mask_with_prompt(binarize_mask(mask[batch_idx], th=0), coord_points)
    show_img(mask_pt)

### Load the dataloader manually

In [4]:
from src.data.loader import BiTemporalDataset
from src.data.process import DefaultTransform
from omegaconf import OmegaConf

# Dataset options: prompt sampling settings plus the batch size reused by the
# DataLoader cell. Kept as a plain dict; it is wrapped in an OmegaConf config
# only when handed to the dataset.
params = dict(
    prompt_type="sample",
    n_prompt=1,
    n_shape=3,
    loc="center",
    batch_size=2,
)

# Train split of levir-cd with the default transform.
ds = BiTemporalDataset(
    name="levir-cd",
    dtype="train",
    transform=DefaultTransform(),
    params=OmegaConf.create(params),
)

In [25]:
# Deterministic (unshuffled), single-process loader over `ds`; the batch size
# comes from the `params` dict defined with the dataset.
dloader = data.DataLoader(ds, batch_size=params["batch_size"], shuffle=False, num_workers=0)

In [26]:
# Fetch the first batch only; with shuffle=False this is the same batch on every run.
batch = next(iter(dloader))

/home/MDizier/data/dl/levir-cd/test/label/test_1.png
/home/MDizier/data/dl/levir-cd/test/label/test_2.png


In [5]:
# Threshold on the summed label values below which a sample is tracked
# (presumably a pixel count for a binary label — TODO confirm).
min_area = 10
no_label_track = []

# Full pass over the dataset: collect the label path of every sample whose
# label sum is under `min_area`. Slow — the recorded run took ~20 minutes.
for sample in tqdm(ds):
    if not (sample["label"].sum() < min_area):
        continue
    no_label_track.append(sample["label_path"])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 445/445 [19:37<00:00,  2.65s/it]


In [7]:
# Number of samples whose label sum fell below `min_area`.
len(no_label_track)

47

In [8]:
# Display the label paths collected above (the whole list is printed).
no_label_track

['/home/MDizier/data/dl/levir-cd/train/label/train_183.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_192.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_195.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_196.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_197.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_198.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_204.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_205.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_206.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_212.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_213.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_214.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_217.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_218.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_219.png',
 '/home/MDizier/data/dl/levir-cd/train/label/train_221.png',
 '/home/MDizier/data/dl/

### Change model: many prompts to one mask

In [7]:
from src.commons.utils_io import load_config
import hydra
from hydra.core.global_hydra import GlobalHydra
import pprint as pp

In [10]:
# Reset Hydra's global state so the configuration can be composed again when
# this cell is re-run in the same kernel.
GlobalHydra.instance().clear()

# Overrides passed to the project's config loader.
list_args = [
    "experiment=probing_diff",
    "sam_type=small",
    "data=levir-cd",
    "data.params.n_shape=3",
    "data.params.num_worker=0",
]
cfg = load_config(list_args)

In [11]:
# Build the object described by `cfg.model.instance` through Hydra.
# In the recorded run this logged "Weights loaded for : ['image_encoder']".
module = hydra.utils.instantiate(cfg.model.instance)

2024-07-31 09:22:41,540 - INFO ::  Weights loaded for : ['image_encoder']


In [10]:
# Pretty-print the composed configuration. The recorded output shows
# interpolations such as ${paths.output_dir} left unresolved.
pp.pprint(dict(cfg))

{'callbacks': {'model_checkpoint': {'_target_': 'lightning.pytorch.callbacks.ModelCheckpoint', 'dirpath': '${paths.output_dir}/checkpoints', 'filename': 'epoch_{epoch:03d}', 'monitor': 'val/loss', 'verbose': False, 'save_last': True, 'save_top_k': 1, 'mode': 'min', 'auto_insert_metric_name': False, 'save_weights_only': False, 'every_n_train_steps': None, 'train_time_interval': None, 'every_n_epochs': None, 'save_on_train_epoch_end': None}, 'rich_progress_bar': {'_target_': 'lightning.pytorch.callbacks.RichProgressBar'}},
 'ckpt_path': None,
 'data': {'name': 'levir-cd', '_target_': 'src.data.datamodule.CDDataModule', 'params': {'prompt_type': 'sample', 'n_prompt': 1, 'loc': 'center', 'batch_size': 2, 'n_shape': 3, 'num_worker': 0, 'pin_memory': False}},
 'logger': {'tensorboard': {'_target_': 'lightning.pytorch.loggers.tensorboard.TensorBoardLogger', 'save_dir': '${paths.output_dir}/tensorboard/', 'name': None, 'default_hp_metric': False}},
 'model': {'network': {'image_encoder': {'_ta

In [12]:
# Network held by the instantiated module (its recorded repr is a BiSamDiff).
model = module.model

In [16]:
# Count every parameter of the full model and report it in millions.
total_param_count = sum(p.numel() for p in model.parameters())
print(f"total model parameters : {total_param_count/1e6:.2f}M")

total model parameters : 93.74M


In [18]:
# Parameter count of the image encoder alone, in millions.
image_encoder_param_count = sum(p.numel() for p in model.image_encoder.parameters())
print(f"image encoder number parameters : {image_encoder_param_count/1e6:.2f}M")

image encoder number parameters : 89.67M


In [19]:
# Parameter count of the prompt encoder alone, in millions.
prompt_encoder_param_count = sum(p.numel() for p in model.prompt_encoder.parameters())
print(f"prompt_encoder number parameters : {prompt_encoder_param_count/1e6:.2f}M")

prompt_encoder number parameters : 0.01M


In [20]:
# Parameter count of the mask decoder alone, in millions.
mask_decoder_param_count = sum(p.numel() for p in model.mask_decoder.parameters())
print(f"mask_decoder number parameters : {mask_decoder_param_count/1e6:.2f}M")

mask_decoder number parameters : 4.06M


In [12]:
# Display the model's module tree (repr).
model

BiSamDiff(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d()


In [14]:
from src.commons.utils_io import load_sam
from src.models.magic_pen.bisam_concat import BiSamConcat

In [15]:
# Try to build a BiSamConcat with embed_dim=512 and load the vit_b checkpoint into it.
# In the recorded run this raised RuntimeError (size mismatch): the checkpoint tensors
# are 256-wide (e.g. image_encoder.neck.0.weight is [256, 768, 1, 1]) while the model
# built here expects 512, and is_strict=False did not prevent the failure.
# NOTE(review): presumably is_strict is forwarded to load_state_dict(strict=...), which
# tolerates missing/unexpected keys but not shape mismatches — confirm in load_sam.
bisam_concat = load_sam(model_type="vit_b", model_cls=BiSamConcat, is_strict=False, embed_dim=512)

2024-07-26 16:04:41,574 - INFO ::  build vit_b BiSamConcat


RuntimeError: Error(s) in loading state_dict for BiSamConcat:
	size mismatch for image_encoder.neck.0.weight: copying a param with shape torch.Size([256, 768, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 768, 1, 1]).
	size mismatch for image_encoder.neck.1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for image_encoder.neck.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for image_encoder.neck.2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for image_encoder.neck.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for image_encoder.neck.3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for prompt_encoder.pe_layer.positional_encoding_gaussian_matrix: copying a param with shape torch.Size([2, 128]) from checkpoint, the shape in current model is torch.Size([2, 256]).
	size mismatch for prompt_encoder.point_embeddings.0.weight: copying a param with shape torch.Size([1, 256]) from checkpoint, the shape in current model is torch.Size([1, 512]).
	size mismatch for prompt_encoder.point_embeddings.1.weight: copying a param with shape torch.Size([1, 256]) from checkpoint, the shape in current model is torch.Size([1, 512]).
	size mismatch for prompt_encoder.point_embeddings.2.weight: copying a param with shape torch.Size([1, 256]) from checkpoint, the shape in current model is torch.Size([1, 512]).
	size mismatch for prompt_encoder.point_embeddings.3.weight: copying a param with shape torch.Size([1, 256]) from checkpoint, the shape in current model is torch.Size([1, 512]).
	size mismatch for prompt_encoder.not_a_point_embed.weight: copying a param with shape torch.Size([1, 256]) from checkpoint, the shape in current model is torch.Size([1, 512]).
	size mismatch for prompt_encoder.mask_downscaling.6.weight: copying a param with shape torch.Size([256, 16, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 16, 1, 1]).
	size mismatch for prompt_encoder.mask_downscaling.6.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for prompt_encoder.no_mask_embed.weight: copying a param with shape torch.Size([1, 256]) from checkpoint, the shape in current model is torch.Size([1, 512]).
	size mismatch for mask_decoder.transformer.layers.0.self_attn.q_proj.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.transformer.layers.0.self_attn.q_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.transformer.layers.0.self_attn.k_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.transformer.layers.0.self_attn.v_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.self_attn.out_proj.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.transformer.layers.0.self_attn.out_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.weight: copying a param with shape torch.Size([256, 128]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.mlp.lin1.weight: copying a param with shape torch.Size([2048, 256]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
	size mismatch for mask_decoder.transformer.layers.0.mlp.lin2.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for mask_decoder.transformer.layers.0.mlp.lin2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.norm3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.norm3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.norm4.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.norm4.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.weight: copying a param with shape torch.Size([256, 128]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.self_attn.q_proj.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.transformer.layers.1.self_attn.q_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.transformer.layers.1.self_attn.k_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.transformer.layers.1.self_attn.v_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.self_attn.out_proj.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.transformer.layers.1.self_attn.out_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.weight: copying a param with shape torch.Size([256, 128]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.mlp.lin1.weight: copying a param with shape torch.Size([2048, 256]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
	size mismatch for mask_decoder.transformer.layers.1.mlp.lin2.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
	size mismatch for mask_decoder.transformer.layers.1.mlp.lin2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.norm3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.norm3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.norm4.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.norm4.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.weight: copying a param with shape torch.Size([256, 128]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.final_attn_token_to_image.q_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.final_attn_token_to_image.q_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.final_attn_token_to_image.k_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.final_attn_token_to_image.k_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.final_attn_token_to_image.v_proj.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for mask_decoder.transformer.final_attn_token_to_image.v_proj.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for mask_decoder.transformer.final_attn_token_to_image.out_proj.weight: copying a param with shape torch.Size([256, 128]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for mask_decoder.transformer.final_attn_token_to_image.out_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.norm_final_attn.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.transformer.norm_final_attn.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.iou_token.weight: copying a param with shape torch.Size([1, 256]) from checkpoint, the shape in current model is torch.Size([1, 512]).
	size mismatch for mask_decoder.mask_tokens.weight: copying a param with shape torch.Size([4, 256]) from checkpoint, the shape in current model is torch.Size([4, 512]).
	size mismatch for mask_decoder.output_upscaling.0.weight: copying a param with shape torch.Size([256, 64, 2, 2]) from checkpoint, the shape in current model is torch.Size([512, 128, 2, 2]).
	size mismatch for mask_decoder.output_upscaling.0.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for mask_decoder.output_upscaling.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for mask_decoder.output_upscaling.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for mask_decoder.output_upscaling.3.weight: copying a param with shape torch.Size([64, 32, 2, 2]) from checkpoint, the shape in current model is torch.Size([128, 64, 2, 2]).
	size mismatch for mask_decoder.output_upscaling.3.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.0.layers.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.0.layers.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.0.layers.1.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.0.layers.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.0.layers.2.weight: copying a param with shape torch.Size([32, 256]) from checkpoint, the shape in current model is torch.Size([64, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.0.layers.2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.1.layers.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.1.layers.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.1.layers.1.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.1.layers.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.1.layers.2.weight: copying a param with shape torch.Size([32, 256]) from checkpoint, the shape in current model is torch.Size([64, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.1.layers.2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.2.layers.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.2.layers.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.2.layers.1.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.2.layers.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.2.layers.2.weight: copying a param with shape torch.Size([32, 256]) from checkpoint, the shape in current model is torch.Size([64, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.2.layers.2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.3.layers.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.3.layers.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.3.layers.1.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.3.layers.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.3.layers.2.weight: copying a param with shape torch.Size([32, 256]) from checkpoint, the shape in current model is torch.Size([64, 512]).
	size mismatch for mask_decoder.output_hypernetworks_mlps.3.layers.2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for mask_decoder.iou_prediction_head.layers.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).

In [None]:
# Partial weight loading: copy the pretrained tensors into the model for everything
# except the mask decoder, which keeps its default initialization so it can be trained.
#
# The original snippet referenced `self.model` and a `pretrained_weights` name that is
# never defined in this notebook, so it could not run on a fresh kernel. It now uses
# the notebook-level `module` (assumes `self` stood for that instantiated module —
# TODO confirm) and loads the checkpoint itself from `cfg.sam_ckpt_path`.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
pretrained_weights = torch.load(cfg.sam_ckpt_path)
model_dict = module.model.state_dict()

# Keep only tensors that (a) do not belong to the mask decoder, (b) exist in the
# model and (c) have the same shape. Unknown keys or mismatched shapes would
# otherwise make load_state_dict raise.
pretrained_weights = {
    k: v
    for k, v in pretrained_weights.items()
    if not k.startswith("mask_decoder")
    and k in model_dict
    and v.shape == model_dict[k].shape
}
print(f"loading {len(pretrained_weights)}/{len(model_dict)} tensors from checkpoint")
model_dict.update(pretrained_weights)
# Entries not overwritten above keep the model's default initialization.
module.model.load_state_dict(model_dict)

In [None]:
# Load the checkpoint referenced by the config.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
weights = torch.load(cfg.sam_ckpt_path)

In [21]:
# List the checkpoint's entry names; the recorded output shows keys prefixed by
# sub-module (e.g. "image_encoder.").
weights.keys()

dict_keys(['image_encoder.neck.0.weight', 'image_encoder.neck.1.weight', 'image_encoder.neck.1.bias', 'image_encoder.neck.2.weight', 'image_encoder.neck.3.weight', 'image_encoder.neck.3.bias', 'image_encoder.patch_embed.proj.weight', 'image_encoder.patch_embed.proj.bias', 'image_encoder.blocks.0.norm1.weight', 'image_encoder.blocks.0.norm1.bias', 'image_encoder.blocks.0.attn.rel_pos_h', 'image_encoder.blocks.0.attn.rel_pos_w', 'image_encoder.blocks.0.attn.qkv.weight', 'image_encoder.blocks.0.attn.qkv.bias', 'image_encoder.blocks.0.attn.proj.weight', 'image_encoder.blocks.0.attn.proj.bias', 'image_encoder.blocks.0.norm2.weight', 'image_encoder.blocks.0.norm2.bias', 'image_encoder.blocks.0.mlp.lin1.weight', 'image_encoder.blocks.0.mlp.lin1.bias', 'image_encoder.blocks.0.mlp.lin2.weight', 'image_encoder.blocks.0.mlp.lin2.bias', 'image_encoder.blocks.1.norm1.weight', 'image_encoder.blocks.1.norm1.bias', 'image_encoder.blocks.1.attn.rel_pos_h', 'image_encoder.blocks.1.attn.rel_pos_w', 'imag