* Let's decompose attention

In [1]:
import hydra
from src.commons.constants import PROJECT_PATH
from omegaconf import DictConfig, OmegaConf

import numpy as np
import cv2
from PIL import Image
import torch
import pytorch_lightning as pl
import pandas as pd
import os
from copy import deepcopy
import torch.nn.functional as F
from torchmetrics import Metric
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from torch.utils import data
from torch.nn.utils.rnn import pad_sequence

from src.commons.utils_io import load_sam
from src.commons.utils import to_numpy, SegAnyChangeVersion, show_img, show_pair_img, show_prediction_sample, resize
from src.models.commons.mask_process import extract_object_from_batch, binarize_mask
from src.commons.constants import IMG_SIZE
from src.data.process import generate_grid_prompt
from src.commons.utils import create_sample_grid_with_prompt, get_mask_with_prompt, fig2arr

In [2]:
def load_config():
    # Initialize the Hydra configuration
    hydra.initialize(config_path="../../configs", version_base=None)
    
    # Compose the configuration with the desired environment override
    cfg = hydra.compose(config_name="train", overrides=["experiment=probing_diff", "sam_type=small", "data=levir-cd"])
    
    return cfg

In [3]:
from hydra.core.global_hydra import GlobalHydra
GlobalHydra.instance().clear()
cfg = load_config()
print(OmegaConf.to_yaml(cfg))

data:
  name: levir-cd
  _target_: src.data.datamodule.CDDataModule
  params:
    prompt_type: sample
    n_prompt: 1
    loc: center
    batch_size: 2
    num_worker: 4
    pin_memory: false
    n_shape: 3
model:
  network:
    image_encoder:
      _target_: src.models.segment_anything.modeling.image_encoder_dev.ImageEncoderViT
      depth: 12
      embed_dim: 768
      img_size: 1024
      mlp_ratio: 4
      norm_layer: null
      num_heads: 12
      patch_size: 16
      qkv_bias: true
      use_rel_pos: true
      global_attn_indexes:
      - 2
      - 5
      - 8
      - 11
      window_size: 14
      out_chans: 256
    prompt_encoder:
      _target_: src.models.segment_anything.modeling.prompt_encoder_dev.PromptEncoder
      embed_dim: 256
      image_embedding_size:
      - 64
      - 64
      input_image_size:
      - 1024
      - 1024
      mask_in_chans: 16
    mask_decoder:
      transformer:
        _target_: src.models.segment_anything.modeling.transformer_dev.TwoWayTransfo

### Load dloader manually

In [4]:
from src.data.loader import BiTemporalDataset
from src.data.process import DefaultTransform
from omegaconf import OmegaConf

params = {
    "prompt_type": "sample",
    "n_prompt": 1,
    "n_shape":3,
    "loc": "center",
    "batch_size": 2,
}
ds = BiTemporalDataset(
            name="levir-cd",
            dtype="test",
            transform=DefaultTransform(),
            params=OmegaConf.create(params),
        )

dloader = data.DataLoader(
            ds,
            batch_size=params.get('batch_size'),
            shuffle=False,
            num_workers=0,
        )

batch = next(iter(dloader))

In [5]:
module = hydra.utils.instantiate(cfg.model.instance)

INIT VIT


2024-08-05 16:11:18,889 - INFO ::  Weights loaded for : ['image_encoder']


In [6]:
path = "/var/data/usr/mdizier/stylo_magique/checkpoints/sam/sam_vit_b_01ec64.pth"
module.model.load_state_dict(torch.load(path))

<All keys matched successfully>

In [7]:
from src.models.segment_any_change.model import BiSam

bisam = load_sam(
    model_type="vit_b", model_cls=BiSam, version= "dev", device="cpu"
)

2024-08-05 16:11:19,831 - INFO ::  build vit_b BiSam


INIT VIT


In [8]:
patcher = bisam.image_encoder.patch_embed

In [9]:
img_patches = patcher(batch["img_A"])

In [10]:
next(iter(patcher.proj.named_parameters()))[1].shape

torch.Size([768, 3, 16, 16])

In [11]:
print(batch["img_A"].shape)
print(img_patches.shape)

torch.Size([2, 3, 1024, 1024])
torch.Size([2, 64, 64, 768])


In [12]:
3*1024*1024

3145728

In [13]:
64*64*768

3145728

### Attention

In [14]:
from torch import nn

In [15]:
num_heads = 8
dim = 768
qkv_bias=True
head_dim = dim // num_heads
scale = head_dim**-0.5

qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
proj = nn.Linear(dim, dim)

In [16]:
x = img_patches.clone()
print("patches", x.shape)

patches torch.Size([2, 64, 64, 768])


* expand last dim (channel dimension) with Linear

In [17]:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv_out = (
    qkv(x)
)
print("qkv_out", qkv_out.shape)
qkv_out = qkv_out.reshape(B, H * W, 3, num_heads, -1).permute(2, 0, 3, 1, 4)
print("qkv_out resh", qkv_out.shape)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv_out.reshape(3, B * num_heads, H * W, -1).unbind(0)
print("q :", q.shape)
print("k :", k.shape)
print("v :", v.shape)
attn = (q * scale) @ k.transpose(-2, -1)
print("attn :", v.shape)

qkv_out torch.Size([2, 64, 64, 2304])
qkv_out resh torch.Size([3, 2, 8, 4096, 96])
q : torch.Size([16, 4096, 96])
k : torch.Size([16, 4096, 96])
v : torch.Size([16, 4096, 96])
attn : torch.Size([16, 4096, 96])


In [30]:
k.transpose(-2, -1).shape

torch.Size([16, 96, 4096])

In [29]:
attn = attn.softmax(dim=-1)
x = (
    (attn @ v)
    .view(B, num_heads, H, W, -1)
    .permute(0, 2, 3, 1, 4)
    .reshape(B, H, W, -1)
)
print("x :", x.shape)
x = proj(x)
print("out :", x.shape)

x : torch.Size([2, 64, 64, 768])
out : torch.Size([2, 64, 64, 768])
