In [1]:
import numpy as np
import cv2
from PIL import Image
import torch
import pytorch_lightning as pl
import pandas as pd
import os
from copy import deepcopy
import torch.nn.functional as F
from torchmetrics import Metric
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from torch.utils import data
from torch.nn.utils.rnn import pad_sequence

from src.commons.utils import to_numpy, SegAnyChangeVersion, show_img, show_pair_img, show_prediction_sample, resize
from src.models.commons.mask_process import extract_object_from_batch, binarize_mask
from src.commons.constants import IMG_SIZE
from src.data.process import generate_grid_prompt
from src.commons.utils import create_sample_grid_with_prompt, get_mask_with_prompt, fig2arr

In [2]:
from src.commons.utils_io import load_sam
from src.models.commons.bisam import BiSam2, SamModeInference
from src.models.segment_any_change.model import BiSam

from src.commons.utils import batch_to_list

In [3]:
def show_prompts_on_mask(mask: torch.Tensor, batch, batch_idx: int):
    if mask.shape[-1] != IMG_SIZE[0]:
        mask = resize(mask, IMG_SIZE)
    coord_points = batch["point_coords"][batch_idx]
    mask_pt = get_mask_with_prompt(binarize_mask(mask[batch_idx], th=0) , coord_points)
    show_img(mask_pt)

### Load dloader manually

In [8]:
from src.data.loader import BiTemporalDataset
from src.data.process import DefaultTransform
from omegaconf import OmegaConf

params = {
    "prompt_type": "sample",
    "n_prompt": 1,
    "n_shape":3,
    "loc": "center",
    "batch_size": 2,
}
ds = BiTemporalDataset(
            name="levir-cd",
            dtype="test",
            transform=DefaultTransform(),
            params=OmegaConf.create(params),
        )

In [9]:
dloader = data.DataLoader(
            ds,
            batch_size=params.get('batch_size'),
            shuffle=False,
            num_workers=0,
        )

In [10]:
batch = next(iter(dloader))

/home/MDizier/data/dl/levir-cd/test/label/test_1.png
/home/MDizier/data/dl/levir-cd/test/label/test_2.png


### Load model

In [7]:
bisam2 = load_sam(
    model_type="vit_b", model_cls=BiSam2, version= "dev2", device="cpu"
)

2024-07-25 09:50:04,288 - INFO ::  build vit_b BiSam2


In [8]:
bisam = load_sam(
    model_type="vit_b", model_cls=BiSam, version= "dev", device="cpu"
)

2024-07-25 09:50:05,274 - INFO ::  build vit_b BiSam


### change model : many prompt to one mask :

In [3]:
from src.commons.utils_io import load_config
import hydra
from hydra.core.global_hydra import GlobalHydra

In [4]:
GlobalHydra.instance().clear()
list_args=["experiment=mp_naive", "sam_type=small", "data=levir-cd", "data.params.n_shape=3", "data.params.num_worker=0"]
cfg = load_config(list_args)

In [5]:
module = hydra.utils.instantiate(cfg.model.instance)

In [6]:
cfg.data.params.num_worker

0

In [9]:
from src.models.segment_anything.modeling.prompt_encoder import PromptEncoder

In [12]:
prompt_encoder = module.model.prompt_encoder

In [10]:
prompt_embed_dim = 512
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size


prompt_encoder_extent = PromptEncoder(
    embed_dim=prompt_embed_dim,
    image_embedding_size=(image_embedding_size, image_embedding_size),
    input_image_size=(image_size, image_size),
    mask_in_chans=16,
)

#### Get& set new weights

In [20]:
from torch import nn

In [21]:
for name, p in prompt_encoder.named_parameters():
    print(name, p.shape)

point_embeddings.0.weight torch.Size([1, 256])
point_embeddings.1.weight torch.Size([1, 256])
point_embeddings.2.weight torch.Size([1, 256])
point_embeddings.3.weight torch.Size([1, 256])
not_a_point_embed.weight torch.Size([1, 256])
mask_downscaling.0.weight torch.Size([4, 1, 2, 2])
mask_downscaling.0.bias torch.Size([4])
mask_downscaling.1.weight torch.Size([4])
mask_downscaling.1.bias torch.Size([4])
mask_downscaling.3.weight torch.Size([16, 4, 2, 2])
mask_downscaling.3.bias torch.Size([16])
mask_downscaling.4.weight torch.Size([16])
mask_downscaling.4.bias torch.Size([16])
mask_downscaling.6.weight torch.Size([256, 16, 1, 1])
mask_downscaling.6.bias torch.Size([256])
no_mask_embed.weight torch.Size([1, 256])


In [25]:
layer = nn.Conv2d(1, 16 // 4, kernel_size=2, stride=2)

In [27]:
layer.weight.data.shape

torch.Size([4, 1, 2, 2])

In [29]:
# copy
new_weights = {}
for name, p in prompt_encoder.named_parameters():
    if "weight" in name:
        if not "mask_downscaling" in name:
            new_weights[name] = torch.cat([p, p], dim=1)
        else:
            new_weights[name] = p

    print(name, p.shape)

point_embeddings.0.weight torch.Size([1, 256])
point_embeddings.1.weight torch.Size([1, 256])
point_embeddings.2.weight torch.Size([1, 256])
point_embeddings.3.weight torch.Size([1, 256])
not_a_point_embed.weight torch.Size([1, 256])
mask_downscaling.0.weight torch.Size([4, 1, 2, 2])
mask_downscaling.0.bias torch.Size([4])
mask_downscaling.1.weight torch.Size([4])
mask_downscaling.1.bias torch.Size([4])
mask_downscaling.3.weight torch.Size([16, 4, 2, 2])
mask_downscaling.3.bias torch.Size([16])
mask_downscaling.4.weight torch.Size([16])
mask_downscaling.4.bias torch.Size([16])
mask_downscaling.6.weight torch.Size([256, 16, 1, 1])
mask_downscaling.6.bias torch.Size([256])
no_mask_embed.weight torch.Size([1, 256])


In [37]:
state_dict = prompt_encoder_extent.state_dict()
state_dict.keys()

odict_keys(['pe_layer.positional_encoding_gaussian_matrix', 'point_embeddings.0.weight', 'point_embeddings.1.weight', 'point_embeddings.2.weight', 'point_embeddings.3.weight', 'not_a_point_embed.weight', 'mask_downscaling.0.weight', 'mask_downscaling.0.bias', 'mask_downscaling.1.weight', 'mask_downscaling.1.bias', 'mask_downscaling.3.weight', 'mask_downscaling.3.bias', 'mask_downscaling.4.weight', 'mask_downscaling.4.bias', 'mask_downscaling.6.weight', 'mask_downscaling.6.bias', 'no_mask_embed.weight'])

In [40]:
for p, name in new_weights.items():
    if name in state_dict:
        print(f"Found parameter '{name}' with shape {state_dict[name].shape}")
        state_dict[name] = p

In [None]:
# check if now we can

In [41]:
prompt_encoder = module.model.prompt_encoder

In [43]:
prompt_encoder_extent.load_state_dict(state_dict)


<All keys matched successfully>

In [None]:
prompt_encoder

### Mask decoder

In [49]:
from src.models.segment_anything.modeling.mask_decoder_dev import MaskDecoder
from src.models.segment_anything.modeling.transformer_dev import TwoWayTransformer

In [54]:
mask_decoder = module.model.mask_decoder

In [53]:
new_dim=512
mask_decoder_extent=MaskDecoder(
    num_multimask_outputs=3,
    transformer=TwoWayTransformer(
        depth=2,
        embedding_dim=new_dim,
        mlp_dim=2048,
        num_heads=8,
        ),
    transformer_dim=new_dim,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
)

In [76]:
from src.models.segment_anything.modeling.common import LayerNorm2d

transformer_dim= 512
activation=nn.GELU
layer_up_extent = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )

In [89]:
print("extended : ")
for n, l2 in layer_up_extent.named_parameters():
    print(n, l2.shape)
print("======")
print("old : ")
for n, l2 in mask_decoder.output_upscaling.named_parameters():
    print(n, l2.shape)

extended : 
0.weight torch.Size([512, 128, 2, 2])
0.bias torch.Size([128])
1.weight torch.Size([128])
1.bias torch.Size([128])
3.weight torch.Size([128, 64, 2, 2])
3.bias torch.Size([64])
old : 
0.weight torch.Size([256, 64, 2, 2])
0.bias torch.Size([64])
1.weight torch.Size([64])
1.bias torch.Size([64])
3.weight torch.Size([64, 32, 2, 2])
3.bias torch.Size([32])


In [91]:
for layer in layer_up_extent.children():
    print(layer)
    if isinstance(layer, nn.Conv2d):
        print(layer.weight.data.shape)

ConvTranspose2d(512, 128, kernel_size=(2, 2), stride=(2, 2))
LayerNorm2d()
GELU(approximate='none')
ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
GELU(approximate='none')


In [85]:
layer_up_extent[0].weight.shape

torch.Size([512, 128, 2, 2])

In [86]:
mask_decoder.output_upscaling[0].weight.shape

torch.Size([256, 64, 2, 2])

In [99]:
type(layer_up_extent)

torch.nn.modules.container.Sequential

In [97]:
state_dict = mask_decoder.output_upscaling.state_dict()

new_weights = {}

for name, p in  mask_decoder.output_upscaling.named_parameters():
    if "weight" in name and p.ndim > 1:
        print(p.shape)
        #new_weights[name] = torch.cat([p, p], dim)
        new_weights[name] = p.repeat(2, 1, 1, 1).repeat(1, 2, 1, 1)
    else:
        new_weights[name] = torch.cat([p, p])
for name, p in new_weights.items():
    if name in state_dict:
        print(f"Found parameter '{name}' with shape {p.shape}")
        state_dict[name] = p


torch.Size([256, 64, 2, 2])
torch.Size([64, 32, 2, 2])
Found parameter '0.weight' with shape torch.Size([512, 128, 2, 2])
Found parameter '0.bias' with shape torch.Size([128])
Found parameter '1.weight' with shape torch.Size([128])
Found parameter '1.bias' with shape torch.Size([128])
Found parameter '3.weight' with shape torch.Size([128, 64, 2, 2])
Found parameter '3.bias' with shape torch.Size([64])


In [98]:
layer_up_extent.load_state_dict(state_dict)

<All keys matched successfully>

In [100]:
img_A = torch.rand(1, 256, 64, 64)
img_B = torch.rand(1, 256, 64, 64)
img = torch.cat([img_A, img_B], dim=1)
print(f"im shape : {img.shape}")
out = layer_up_extent(img)

im shape : torch.Size([1, 512, 64, 64])


In [None]:
# copy
new_weights = {}
for name, p in mask_decoder.named_parameters():
    if "weight" in name:
        if not "mask_downscaling" in name:
            new_weights[name] = torch.cat([p, p], dim=1)
        else:
            new_weights[name] = p

    print(name, p.shape)

In [101]:
for name, p in mask_decoder_extent.named_parameters():
    print(name, p.shape)

transformer.layers.0.self_attn.q_proj.weight torch.Size([512, 512])
transformer.layers.0.self_attn.q_proj.bias torch.Size([512])
transformer.layers.0.self_attn.k_proj.weight torch.Size([512, 512])
transformer.layers.0.self_attn.k_proj.bias torch.Size([512])
transformer.layers.0.self_attn.v_proj.weight torch.Size([512, 512])
transformer.layers.0.self_attn.v_proj.bias torch.Size([512])
transformer.layers.0.self_attn.out_proj.weight torch.Size([512, 512])
transformer.layers.0.self_attn.out_proj.bias torch.Size([512])
transformer.layers.0.norm1.weight torch.Size([512])
transformer.layers.0.norm1.bias torch.Size([512])
transformer.layers.0.cross_attn_token_to_image.q_proj.weight torch.Size([256, 512])
transformer.layers.0.cross_attn_token_to_image.q_proj.bias torch.Size([256])
transformer.layers.0.cross_attn_token_to_image.k_proj.weight torch.Size([256, 512])
transformer.layers.0.cross_attn_token_to_image.k_proj.bias torch.Size([256])
transformer.layers.0.cross_attn_token_to_image.v_proj.we

In [None]:
t1 = torch.tensor([10, 10])
t2 = torch.tensor([20, 20])

In [56]:
for name, p in mask_decoder.named_parameters():
    print(name, p.shape)

transformer.layers.0.self_attn.q_proj.weight torch.Size([256, 256])
transformer.layers.0.self_attn.q_proj.bias torch.Size([256])
transformer.layers.0.self_attn.k_proj.weight torch.Size([256, 256])
transformer.layers.0.self_attn.k_proj.bias torch.Size([256])
transformer.layers.0.self_attn.v_proj.weight torch.Size([256, 256])
transformer.layers.0.self_attn.v_proj.bias torch.Size([256])
transformer.layers.0.self_attn.out_proj.weight torch.Size([256, 256])
transformer.layers.0.self_attn.out_proj.bias torch.Size([256])
transformer.layers.0.norm1.weight torch.Size([256])
transformer.layers.0.norm1.bias torch.Size([256])
transformer.layers.0.cross_attn_token_to_image.q_proj.weight torch.Size([128, 256])
transformer.layers.0.cross_attn_token_to_image.q_proj.bias torch.Size([128])
transformer.layers.0.cross_attn_token_to_image.k_proj.weight torch.Size([128, 256])
transformer.layers.0.cross_attn_token_to_image.k_proj.bias torch.Size([128])
transformer.layers.0.cross_attn_token_to_image.v_proj.we

Compliqué d'aligner les dimenions, i.e de copier les poids. Faisons simple pour l'instant : MLP qui projette la concatentaion dans la dimension d'entrée du transformer (256). On initialisera le reseau pour renvoyer une des images ($f(i1, i2) = i1$)

## Concat + proj layer

In [4]:
from src.models.magic_pen.bisam_concat import BiSamConcat

In [5]:
bisam_concat = load_sam(model_type="vit_b", model_cls=BiSamConcat, version="dev", is_strict=False)

2024-07-25 17:35:18,550 - INFO ::  build vit_b BiSamConcat


RuntimeError: [enforce fail at alloc_cpu.cpp:75] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 8796093022208 bytes. Error code 12 (Cannot allocate memory)

In [10]:
t1 = torch.rand(1, 256, 64, 64)
t2 = torch.rand(1, 256, 64, 64)

t = torch.cat([t1, t2], dim=1)



In [11]:
from torch import nn
embedding_dim = 512
proj_layer = nn.Sequential(
    nn.Linear(embedding_dim, embedding_dim // 2),
    nn.GELU(),
    nn.Linear(embedding_dim // 2, embedding_dim // 2),
)
nt = t.permute(0, 2, 3, 1)
out = proj_layer(nt)

In [14]:
proj_layer[2].weight.shape

torch.Size([256, 256])

In [12]:
out.shape

torch.Size([1, 64, 64, 256])

In [11]:
bisam_concat.proj_layer

Sequential(
  (0): Linear(in_features=512, out_features=256, bias=True)
  (1): GELU(approximate='none')
  (2): Linear(in_features=256, out_features=256, bias=True)
)

In [15]:
bisam_concat.proj_layer[2].weight.shape

torch.Size([256, 256])

In [30]:
t1[0]

tensor([[[0.1032, 0.7729, 0.2045,  ..., 0.2749, 0.0041, 0.5167],
         [0.0074, 0.6442, 0.1886,  ..., 0.9020, 0.1105, 0.5572],
         [0.0708, 0.1315, 0.0636,  ..., 0.1769, 0.3098, 0.7716],
         ...,
         [0.8921, 0.3043, 0.5191,  ..., 0.4278, 0.4532, 0.6163],
         [0.8727, 0.4626, 0.5150,  ..., 0.9024, 0.8074, 0.7236],
         [0.1782, 0.8031, 0.4868,  ..., 0.1242, 0.3426, 0.9536]],

        [[0.5906, 0.2554, 0.1775,  ..., 0.3373, 0.0346, 0.0101],
         [0.9056, 0.8721, 0.1368,  ..., 0.9305, 0.0161, 0.0986],
         [0.2562, 0.3361, 0.3596,  ..., 0.0277, 0.9163, 0.9335],
         ...,
         [0.5878, 0.8925, 0.1020,  ..., 0.4737, 0.1308, 0.6430],
         [0.3662, 0.3958, 0.0709,  ..., 0.5917, 0.2893, 0.7657],
         [0.5449, 0.9591, 0.5408,  ..., 0.8525, 0.2673, 0.5322]],

        [[0.3269, 0.6760, 0.1862,  ..., 0.9873, 0.6220, 0.9696],
         [0.6660, 0.1807, 0.7212,  ..., 0.2625, 0.5854, 0.7509],
         [0.7402, 0.2325, 0.5359,  ..., 0.9713, 0.7115, 0.

In [31]:
t.shape

torch.Size([1, 512, 64, 64])

In [32]:
nt = t.view(1, 512, -1)
nt.shape

torch.Size([1, 512, 4096])

In [27]:
bisam_concat.proj_layer(t)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32768x64 and 512x256)