In [1]:
import numpy as np
import cv2
from PIL import Image
import torch
import pytorch_lightning as pl
import pandas as pd
import os
from copy import deepcopy
import torch.nn.functional as F
from torchmetrics import Metric
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from torch.utils import data
from torch.nn.utils.rnn import pad_sequence

from src.commons.utils import to_numpy, SegAnyChangeVersion, show_img, show_pair_img, show_prediction_sample, resize
from src.models.commons.mask_process import extract_object_from_batch, binarize_mask
from src.commons.constants import IMG_SIZE
from src.data.process import generate_grid_prompt
from src.commons.utils import create_sample_grid_with_prompt, get_mask_with_prompt, fig2arr

Welcome in JZAY


In [2]:
from src.commons.utils_io import load_sam
from src.models.commons.bisam import BiSam2, SamModeInference
from src.models.commons.model import BiSam

from src.commons.utils import batch_to_list

### Load the dataloader manually

In [3]:
from src.data.loader import BiTemporalDataset
from src.data.process import DefaultTransform
from omegaconf import OmegaConf

# Sampling options for the dataset; the same dict also supplies the DataLoader batch size.
params = dict(
    prompt_type="sample",
    n_prompt=1,
    n_shape=3,
    loc="center",
    batch_size=2,
)
ds = BiTemporalDataset(
    name="levir-cd",
    dtype="test",
    transform=DefaultTransform(),
    params=OmegaConf.create(params),
)

In [4]:
# Single-process loader with a fixed sample order (shuffle=False).
dloader = data.DataLoader(ds, batch_size=params.get('batch_size'), shuffle=False, num_workers=0)

In [5]:
# First batch only; the loader uses shuffle=False, so the same samples come back on every run.
batch = next(iter(dloader))

/home/rustt/Documents/IGN/data/levir-cd/test/label/test_1.png
/home/rustt/Documents/IGN/data/levir-cd/test/label/test_2.png


### Load model

In [6]:
# Build the ViT-B BiSam model ("dev" version) on CPU.
bisam = load_sam(model_type="vit_b", model_cls=BiSam, version="dev", device="cpu")

2024-07-23 21:21:21,942 - INFO ::  build vit_b BiSam


### Load config

In [7]:
from src.commons.utils_io import load_config
import hydra
from hydra.core.global_hydra import GlobalHydra

In [8]:
# Clear any previously initialized global Hydra instance before loading the config.
GlobalHydra.instance().clear()
list_args = [
    "experiment=mp_naive",
    "sam_type=small",
    "data=levir-cd",
    "data.params.n_shape=3",
    "data.params.num_worker=0",
]
cfg = load_config(list_args)

In [9]:
# Instantiate the object described by cfg.model.instance via Hydra.
module = hydra.utils.instantiate(cfg.model.instance)

In [10]:
# Stack img_A and img_B along the batch axis, then apply the model's preprocessing.
input_images = bisam.preprocess(torch.cat([batch["img_A"], batch["img_B"]]))

In [11]:
%%time
# Encode only the first img_B sample (kept as a batch of one) to time the image encoder.
emb_B = bisam.image_encoder(batch["img_B"][:1])

CPU times: user 49.7 s, sys: 17.5 s, total: 1min 7s
Wall time: 8.49 s


In [11]:
# Random stand-ins for two 256-channel, 64x64 image embeddings (one for A, one for B).
# Seeded so this cell, and the demo layers initialised after it, are reproducible.
torch.manual_seed(0)
img_A = torch.rand((256, 64, 64))
img_B = torch.rand((256, 64, 64))

In [12]:
# Channel-wise concatenation: (256, 64, 64) + (256, 64, 64) -> (512, 64, 64).
image_embeddings = torch.cat((img_A, img_B), dim=0)

In [13]:
# Encode the sampled point prompts with the stock prompt encoder, one mask per prompt:
# unsqueeze(1) inserts the per-mask axis in front of the points.
point_coords = batch["point_coords"]
point_labels = batch["point_labels"]

sparse_embeddings, dense_embeddings = bisam.prompt_encoder(
    points=(point_coords.unsqueeze(1), point_labels.unsqueeze(1)),
    boxes=None,
    masks=None,
)

In [14]:
# Show both prompt-embedding shapes, one per line.
print(
    f"sparse_embeddings: {sparse_embeddings.shape}",
    f"dense_embeddings: {dense_embeddings.shape}",
    sep="\n",
)

sparse_embeddings: torch.Size([2, 1, 4, 256])
dense_embeddings: torch.Size([2, 1, 256, 64, 64])


In [15]:
import torch.nn as nn
# 3x3 "same" convolution halving the channel count of the unbatched concatenated embeddings.
layer = nn.Conv2d(
    in_channels=image_embeddings.size(0),
    out_channels=image_embeddings.size(0) // 2,
    kernel_size=3,
    padding=1,
)
out = layer(image_embeddings)

In [16]:
# Conv2d weight layout: (out_channels, in_channels, kH, kW).
layer.weight.size()

torch.Size([256, 512, 3, 3])

In [33]:
# Input vs. output of the channel-halving conv, one shape per line.
print(image_embeddings.shape, out.shape, sep="\n")

torch.Size([512, 64, 64])
torch.Size([256, 64, 64])


### Try to extend SAM modules to bitemporal inputs
* Prompt embedding
* Mask decoder

Let's try concatenating the embeddings.

In [42]:
# Shape of the concatenated embeddings.
image_embeddings.size()

torch.Size([512, 64, 64])

In [43]:
# Extend over the batch dimension: (C, H, W) -> (1, C, H, W).
# Guarded so that re-running this cell does not keep stacking leading dimensions.
if image_embeddings.dim() == 3:
    image_embeddings = image_embeddings.unsqueeze(0)

In [40]:
from src.models.segment_anything.modeling.prompt_encoder_dev import PromptEncoder

In [49]:
# Prompt encoder whose embedding size matches the concatenated A/B image embeddings
# (512 channels here) instead of the 256-d stock prompt encoder.
# shape[-3] is the channel axis for both (C, H, W) and (B, C, H, W), so this does not
# depend on whether the batch dimension has already been added.
prompt_embed_dim = image_embeddings.shape[-3]
image_size = 1024  # input resolution; the image batches in this notebook are 1024x1024
vit_patch_size = 16  # matches the 16x16 patch-embedding kernel inspected later on
image_embedding_size = image_size // vit_patch_size  # 64x64 embedding grid


prompt_encoder_extent = PromptEncoder(
    embed_dim=prompt_embed_dim,
    image_embedding_size=(image_embedding_size, image_embedding_size),
    input_image_size=(image_size, image_size),
    mask_in_chans=16,
)

In [50]:
# Same point prompts, encoded with the widened prompt encoder (one mask per prompt).
sparse_embeddings, dense_embeddings = prompt_encoder_extent(
    points=(point_coords.unsqueeze(1), point_labels.unsqueeze(1)),
    boxes=None,
    masks=None,
)

In [51]:
# Shapes produced by the widened prompt encoder, one per line.
print(
    f"sparse_embeddings: {sparse_embeddings.shape}",
    f"dense_embeddings: {dense_embeddings.shape}",
    sep="\n",
)

sparse_embeddings: torch.Size([2, 1, 4, 512])
dense_embeddings: torch.Size([2, 1, 512, 64, 64])


### Cross attention

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LocalCrossAttention(nn.Module):
    """Window-local cross attention from one feature map to another.

    Both inputs are split into non-overlapping ``window_size`` x ``window_size``
    windows. Inside each window, the tokens of ``x1`` (queries) attend to the tokens
    of ``x2`` (keys/values) at the same window position. The result passes through a
    residual + LayerNorm, then a residual feed-forward + LayerNorm, and is stitched
    back into a (B, C, H, W) map.

    Args:
        embed_dim: channel count C of the inputs (the attention embedding size).
        num_heads: number of attention heads (nn.MultiheadAttention requires it to
            divide ``embed_dim``).
        window_size: side length of the square attention windows.
    """

    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dim, embed_dim),
        )

    def _window_partition(self, x):
        """(B, C, H, W) -> (B * num_windows, C, ws, ws); windows ordered batch-major, then row-major."""
        ws = self.window_size
        B, C, H, W = x.shape
        x = x.view(B, C, H // ws, ws, W // ws, ws)
        windows = x.permute(0, 2, 4, 1, 3, 5).contiguous()
        return windows.view(-1, C, ws, ws)

    def _window_reverse(self, windows, H, W):
        """Inverse of ``_window_partition``: (B * num_windows, C, ws, ws) -> (B, C, H, W)."""
        ws = self.window_size
        n_h, n_w = H // ws, W // ws
        # Exact integer arithmetic rather than the float ratio used previously.
        B = windows.shape[0] // (n_h * n_w)
        x = windows.view(B, n_h, n_w, -1, ws, ws)
        x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
        return x.view(B, -1, H, W)

    def forward(self, x1, x2):
        """Let ``x1`` attend to ``x2`` within local windows.

        Args:
            x1: query feature map of shape (B, C, H, W), with C == ``embed_dim``.
            x2: key/value feature map; must have exactly the same shape as ``x1``.

        Returns:
            Tensor of shape (B, C, H, W): the updated query features.

        Raises:
            AssertionError: if H or W is not divisible by ``window_size``.
            ValueError: if ``x1`` and ``x2`` do not have the same shape.
        """
        B, C, H, W = x1.size()
        window_size = self.window_size
        assert H % window_size == 0 and W % window_size == 0, "Height and Width must be divisible by the window size."
        if x2.shape != x1.shape:
            # Different shapes can still produce the same number of windows, which would
            # silently pair up tokens from unrelated image locations.
            raise ValueError(
                f"x1 and x2 must have the same shape, got {tuple(x1.shape)} and {tuple(x2.shape)}."
            )

        # One token sequence per window: (B * num_windows, ws * ws, C).
        x1_tokens = self._window_partition(x1).flatten(2).transpose(1, 2)
        x2_tokens = self._window_partition(x2).flatten(2).transpose(1, 2)

        # Cross attention inside each window: query = x1, key = value = x2.
        attn_output, _ = self.multihead_attn(x1_tokens, x2_tokens, x2_tokens)
        x1_tokens = self.norm1(x1_tokens + attn_output)

        # Position-wise feed-forward with a residual connection.
        x1_tokens = self.norm2(x1_tokens + self.ff(x1_tokens))

        # Back to (B * num_windows, C, ws, ws), then stitch the windows together.
        x1_windows = x1_tokens.transpose(1, 2).reshape(-1, C, window_size, window_size)
        return self._window_reverse(x1_windows, H, W)


sam_decoder = None  # placeholder, not used in this cell — replace with actual decoder
embed_dim = 256
num_heads = 8
window_size = 8  # Example window size, adjust as needed; must divide the 64x64 inputs below

# Seed before building the module and the inputs so both the weights and the demo
# tensors are reproducible.
torch.manual_seed(0)

# Instantiate the local cross attention module
local_cross_attention_module = LocalCrossAttention(embed_dim, num_heads, window_size)

B = 2
# Example input tensors, replace with actual image embeddings. The channel count is tied
# to embed_dim so it always matches the module.
emb1 = torch.rand(B, embed_dim, 64, 64)
emb2 = torch.rand(B, embed_dim, 64, 64)

output = local_cross_attention_module(emb1, emb2)

In [6]:
# The module preserves the (B, C, H, W) shape of its query input.
output.size()

torch.Size([2, 256, 64, 64])

=> Next: implement this with the existing SAM `Attention` layer.

In [61]:
# Function to divide into windows
def window_partition(x, window_size=None):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, C, H, W); H and W must be multiples of ``window_size``.
        window_size: side length of each window. When omitted, the module-level
            ``window_size`` variable is used, which is what existing calls in this notebook rely on.

    Returns:
        Tensor of shape (B * num_windows, C, window_size, window_size). Windows are
        ordered batch-major, then row-major over the window grid.

    Raises:
        ValueError: if H or W is not divisible by ``window_size``.
    """
    if window_size is None:
        # Backward compatibility with calls that set the global right before calling.
        window_size = globals()["window_size"]
    B, C, H, W = x.shape
    if H % window_size or W % window_size:
        raise ValueError(
            f"Height ({H}) and width ({W}) must be divisible by window_size ({window_size})."
        )
    x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
    windows = x.permute(0, 2, 4, 1, 3, 5).contiguous()
    return windows.view(-1, C, window_size, window_size)

In [74]:
# Ten constant 64x64 planes stacked on the channel axis; plane i is filled with the value i,
# which makes it easy to see where each channel lands after window_partition.
# The loop variable is `i`, not `_`: assigning `_` shadows IPython's last-output variable.
stack_tensor = torch.cat(
    [torch.full((1, 1, 64, 64), float(i)) for i in range(10)],
    dim=1,
)

In [75]:
# One sample, ten channels, 64x64.
stack_tensor.size()

torch.Size([1, 10, 64, 64])

In [76]:
# window_partition, as called here, uses the module-level `window_size`,
# so set it right before the call.
window_size = 8
win_tensor = window_partition(stack_tensor)

In [77]:
# (B * num_windows, C, window_size, window_size)
win_tensor.size()

torch.Size([64, 10, 8, 8])

In [19]:
# Decode masks with the stock SAM mask decoder, select one mask per prompt
# (multimask_output=False), upsample to IMG_SIZE and binarize at 0.
# NOTE(review): the last run failed with "RuntimeError: The size of tensor a (32) must
# match the size of tensor b (64) at non-singleton dimension 4". Its execution count (19)
# suggests it ran before the unsqueeze / prompt_encoder_extent cells, i.e. with the
# unbatched 512-channel `image_embeddings` from the concatenation cell, while the decoder
# and get_dense_pe() are the 256-channel stock modules — TODO reconcile the shapes.
low_res_masks, iou_predictions =  bisam.mask_decoder.predict_masks_batch(
    image_embeddings=image_embeddings,  # expected (B, 256, 64, 64)
    image_pe=bisam.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
    sparse_prompt_embeddings=sparse_embeddings,  # (B, N, n_tokens, 256); printed as (2, 1, 4, 256) above
    dense_prompt_embeddings=dense_embeddings,  # (B, N, 256, 64, 64)
)

preds, iou_predictions = bisam.select_masks(
    low_res_masks, 
    iou_predictions, 
    multimask_output=False
)
preds = bisam.upscale_masks(
    preds,
    IMG_SIZE
)
preds = preds > 0

RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 4

In [2]:
# NOTE(review): `x` is never defined in this notebook, so this cell raises NameError.
# Point it at the intended tensor or delete the cell.
x.shape

NameError: name 'x' is not defined

### Custom implementation

In [5]:
import hydra
from src.commons.constants import PROJECT_PATH
from omegaconf import DictConfig, OmegaConf

import numpy as np
import cv2
from PIL import Image
import torch
import pytorch_lightning as pl
import pandas as pd
import os
from copy import deepcopy
import torch.nn.functional as F
from torchmetrics import Metric
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from torch.utils import data
from torch.nn.utils.rnn import pad_sequence

from src.commons.utils_io import load_sam
from src.commons.utils import to_numpy, SegAnyChangeVersion, show_img, show_pair_img, show_prediction_sample, resize
from src.models.commons.mask_process import extract_object_from_batch, binarize_mask
from src.commons.constants import IMG_SIZE
from src.data.process import generate_grid_prompt
from src.commons.utils import create_sample_grid_with_prompt, get_mask_with_prompt, fig2arr

### Load the dataloader manually

In [91]:
from src.data.loader import BiTemporalDataset
from src.data.process import DefaultTransform
from omegaconf import OmegaConf

# Sampling options for the dataset; the same dict also supplies the DataLoader batch size.
params = dict(
    prompt_type="sample",
    n_prompt=1,
    n_shape=3,
    loc="center",
    batch_size=2,
)

# LEVIR-CD training split, iterated in a fixed order by a single-process loader.
ds = BiTemporalDataset(
    name="levir-cd",
    dtype="train",
    transform=DefaultTransform(),
    params=OmegaConf.create(params),
)
dloader = data.DataLoader(ds, batch_size=params.get('batch_size'), shuffle=False, num_workers=0)

batch = next(iter(dloader))

/home/MDizier/data/dl/levir-cd/train/label/train_1.png
/home/MDizier/data/dl/levir-cd/train/label/train_2.png


In [92]:
# Number of samples in the train split loaded above (445 at last run).
len(ds)

445

In [88]:
# Samples covered by full batches only: what an epoch sees if the last partial batch
# is dropped (presumably the point of this check). Derived from the dataset and the
# batch size instead of hard-coded numbers.
(len(ds) // params.get("batch_size")) * params.get("batch_size")

444

In [8]:
#module = hydra.utils.instantiate(cfg.model.instance)

In [9]:
# path = "/var/data/usr/mdizier/stylo_magique/checkpoints/sam/sam_vit_b_01ec64.pth"
# module.model.load_state_dict(torch.load(path))

In [10]:
from src.models.segment_any_change.model import BiSam

# Build the ViT-B BiSam from segment_any_change ("dev" version) on CPU.
bisam = load_sam(model_type="vit_b", model_cls=BiSam, version="dev", device="cpu")

2024-07-30 14:14:00,322 - INFO ::  build vit_b BiSam


In [11]:
# Patch-embedding submodule of the image encoder, inspected in the next cells.
patcher = bisam.image_encoder.patch_embed

In [12]:
# Patch-embed the img_A batch; the shapes printed below go (2, 3, 1024, 1024) -> (2, 64, 64, 768).
img_patches = patcher(batch["img_A"])

In [15]:
# Weight of the patch projection. Its first registered parameter is `weight`, assuming
# proj is an nn.Conv2d: (out_channels, in_channels, kH, kW) = (768, 3, 16, 16) at last run.
patcher.proj.weight.shape

torch.Size([768, 3, 16, 16])

In [14]:
# Image batch vs. patch embeddings, one shape per line.
print(batch["img_A"].shape, img_patches.shape, sep="\n")

torch.Size([2, 3, 1024, 1024])
torch.Size([2, 64, 64, 768])


In [74]:
import math
from typing import Type
from models.segment_anything.modeling.common import MLPBlock
from models.segment_anything.modeling.transformer import Attention
from torch import Tensor, nn
import torch


class CrossAttentionBlock(nn.Module):
    """Block in which query tokens attend to key tokens.

    Computes ``q = norm1(queries)``, ``k = norm2(keys)``,
    ``q = q + cross_attn(q, k, k)``, ``q = norm3(q)`` and returns ``mlp(q) + q``.

    Args:
        embedding_dim: channel size of both token sequences.
        num_heads: number of attention heads passed to ``Attention``.
        mlp_dim: hidden size of the MLP block.
        activation: activation class used inside the MLP block.
    """

    def __init__(
            self,
            embedding_dim: int,
            num_heads: int,
            mlp_dim: int = 2048,
            activation: Type[nn.Module] = nn.ReLU,
        ) -> None:
        super().__init__()
        self.cross_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)  # queries, before attention
        self.norm2 = nn.LayerNorm(embedding_dim)  # keys / values
        self.norm3 = nn.LayerNorm(embedding_dim)  # attended queries, before the MLP
        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)

    def forward(self, queries: Tensor, keys: Tensor) -> Tensor:
        """Queries attend to keys.

        Args:
            queries: query tokens; called in this notebook with shape (B, N_q, C).
            keys: key/value tokens; called in this notebook with shape (B, N_k, C).

        Returns:
            Tensor with the same shape as ``queries``.
        """
        q = self.norm1(queries)
        k = self.norm2(keys)
        attn_out = self.cross_attn(q=q, k=k, v=k)
        q = q + attn_out
        # Previously this reused norm2 (the key norm), which left norm3 unused and tied the
        # key and post-attention query normalisations to the same parameters.
        q = self.norm3(q)
        out = self.mlp(q) + q
        return out


# class Attention(nn.Module):
#     """
#     An attention layer that allows for downscaling the size of the embedding
#     after projection to queries, keys, and values.
#     """

#     def __init__(
#         self,
#         embedding_dim: int,
#         num_heads: int,
#         downsample_rate: int = 1,
#     ) -> None:
#         super().__init__()
#         self.embedding_dim = embedding_dim
#         self.internal_dim = embedding_dim // downsample_rate
#         self.num_heads = num_heads
#         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

#         self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
#         self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
#         self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
#         self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

#     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
#         b, n, c = x.shape
#         x = x.reshape(b, n, num_heads, c // num_heads)
#         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

#     def _recombine_heads(self, x: Tensor) -> Tensor:
#         b, n_heads, n_tokens, c_per_head = x.shape
#         x = x.transpose(1, 2)
#         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

#     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
#         # Input projections
#         q = self.q_proj(q)
#         k = self.k_proj(k)
#         v = self.v_proj(v)

#         # Separate into heads
#         print("q", q.shape)
#         q = self._separate_heads(q, self.num_heads)
#         print("q", q.shape)

#         k = self._separate_heads(k, self.num_heads)
#         v = self._separate_heads(v, self.num_heads)

#         # Attention
#         _, _, _, c_per_head = q.shape
#         attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
#         attn = attn / math.sqrt(c_per_head)
#         attn = torch.softmax(attn, dim=-1)
#         print("attn shape", attn.shape)
#         # Get output
#         out = attn @ v
#         print(out.shape)
#         out = self._recombine_heads(out)
#         out = self.out_proj(out)

#         return out


In [69]:
# Encode both images and move channels last: (B, C, H, W) -> (B, H, W, C).
x1, x2 = (
    bisam.image_encoder(batch[key]).permute(0, 2, 3, 1)
    for key in ("img_A", "img_B")
)

In [75]:
# Cross-attention block sized from the channels-last encoder features (x1 is (b, h, w, c)),
# rather than a hard-coded 256.
blk = CrossAttentionBlock(embedding_dim=x1.shape[-1], num_heads=4, mlp_dim=2048)

In [76]:
# Channels-last encoder output: (b, h, w, c).
x1.size()

torch.Size([2, 64, 64, 256])

In [77]:
# Batch, grid height, grid width and channel count of the channels-last features.
b, h, w, c = x1.size()

In [78]:
# Flatten the spatial dimensions: (b, h, w, c) -> (b, h * w, c) token sequences.
# Uses the shape unpacked from x1 instead of a hard-coded batch size (2) and channel
# count (256); assumes x2 has the same shape as x1.
out = blk(queries=x1.view(b, h * w, c), keys=x2.view(b, h * w, c))

In [79]:
# Number of tokens after flattening the spatial grid.
h * w

4096

In [80]:
# (b, h * w, c) token sequence returned by the block.
out.size()

torch.Size([2, 4096, 256])

In [81]:
# Tokens back to a channels-first map: (b, h * w, c) -> (b, c, h * w) -> (b, c, h, w).
out = out.transpose(1, 2).view(b, c, h, w)

In [82]:
# Channels-first output map: (b, c, h, w).
out.size()

torch.Size([2, 256, 64, 64])

In [101]:
# NOTE(review): scratch arithmetic (672 / 24 = 28.0); its purpose is not evident from the notebook.
672 / 24

28.0