In [1]:
# load env variables
from datetime import datetime
import os
from dotenv import load_dotenv
load_dotenv()
# set local dir logs based on date _${now:%Y-%m-%d}_${now:%H-%M-%S}
os.environ["SLURM_JOB_ID"] = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

In [1]:
import hydra
import torch
from src.commons.constants import PROJECT_PATH
from omegaconf import DictConfig, OmegaConf

In [2]:
def load_config():
    # Initialize the Hydra configuration
    hydra.initialize(config_path="../configs", version_base=None)
    
    # Compose the configuration with the desired environment override
    cfg = hydra.compose(
        config_name="train", 
        overrides=["experiment=adapter_concat", 
                   "sam_type=small", 
                   "data=levir-cd",
                   "data.params.batch_size=1"
                  ])
    
    return cfg

def get_dloader(mode: str, dmodule):

    def wrap_mode(mode):
        if mode == "train":
            return "fit"
        return mode
    if not dmodule.ds_dict_type:
        mode_ = wrap_mode(mode)
        dmodule.setup(mode_)
    factory_dl = {
        "train": dmodule.train_dataloader,
        "val": dmodule.val_dataloader,
        "test": dmodule.test_dataloader,
    }
    return factory_dl[mode]()

In [3]:
from hydra.core.global_hydra import GlobalHydra
GlobalHydra.instance().clear()
cfg = load_config()

In [4]:
cfg.sam_ckpt_path

'/var/data/usr/mdizier/stylo_magique/checkpoints/sam/sam_vit_b_01ec64.pth'

In [6]:
module = hydra.utils.instantiate(cfg.model.instance)

INIT ADAPTER VIT


2024-08-13 20:45:18,587 - INFO ::  Weights loaded for : ['image_encoder']


In [7]:
data_module = hydra.utils.instantiate(cfg.data)

In [8]:
dloader = get_dloader("test", data_module)

In [9]:
batch = next(iter(dloader))

In [11]:
preds, _ = module.model(batch, multimask_output=False)
preds = preds.squeeze()

In [12]:
preds.shape

torch.Size([1024, 1024])

In [15]:
img_embeddings = module.model.image_embeddings

In [24]:
t = torch.rand((2, 2, 256, 64, 64))

In [29]:
res = F.max_pool2d(t, kernel_size=4, stride=2)

NameError: name 'F' is not defined

In [26]:
res.shape

torch.Size([2, 256, 32, 32])

In [33]:
ld = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=2, stride=2)
ld(t.view(-1, *t.shape[2:])).shape

torch.Size([4, 256, 32, 32])

In [42]:
l = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2)
res2 = l(res)

In [3]:
import torch.nn as nn

In [3]:
q = torch.rand(4096, 1, 256)
k = torch.rand(4096, 2, 256)
v = torch.rand(4096, 2, 256)

In [4]:
k.shape

torch.Size([4096, 2, 256])

In [5]:
print(q.shape)
print(k.transpose(-2, -1).shape)

torch.Size([4096, 1, 256])
torch.Size([4096, 256, 2])


In [6]:
attn = (q @ k.transpose(-2, -1))

In [7]:
attn.shape

torch.Size([4096, 1, 2])

In [35]:
torch.sqrt([5])

TypeError: sqrt(): argument 'input' (position 1) must be Tensor, not list

On veut 4096 x 256

In [8]:
attn2 = (attn @ v)
attn2.shape

torch.Size([4096, 1, 256])

In [9]:
x = torch.rand(2, 256, 2, 64, 64)
B, C, T, H, W = x.shape

# flat spatial dimensions
x = x.permute(0, 3, 4, 2, 1).view(B, -1, T, C)

In [10]:
x.shape

torch.Size([2, 4096, 2, 256])

In [11]:
dim = 256
num_patches = 64 * 64
qkv_bias=False

wk = nn.Linear(dim, dim, bias=qkv_bias)
wv = nn.Linear(dim, dim, bias=qkv_bias)
wq = nn.Linear(dim, dim, bias=qkv_bias)
# self.q_learned = nn.Parameter(torch.zeros(1, 1, dim))


q_learned = nn.Parameter(torch.zeros(1, num_patches, dim))
# self.pos_embed = nn.Parameter(torch.zeros(1, self.num

In [15]:
q_learned.expand(B, -1, -1).shape

torch.Size([2, 4096, 256])

In [18]:
wk(x).shape

torch.Size([2, 4096, 2, 256])

In [21]:
B, N, T, C = x.shape
q = q_learned.expand(B, -1, -1).unsqueeze(2)
k = wk(x)
v = wv(x)

print("q", q.shape)
print("k", k.shape)
print("v", v.shape)

# attn torch.Size([B, 8, 4096, 4096])
attn = (q @ k.transpose(-2, -1)) 
print("attn", attn.shape)


q torch.Size([2, 4096, 1, 256])
k torch.Size([2, 4096, 2, 256])
v torch.Size([2, 4096, 2, 256])
attn torch.Size([2, 4096, 1, 2])


In [22]:

# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B, num_patches, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# print("xb", x.shape)
# x = x[:,  :self.num_patches_original, :]
print("x cut", x.shape)

x cut torch.Size([2, 4096, 256])


In [11]:
from src.models.commons.rpe.irpe import build_rpe, get_rpe_config
num_heads=8
head_dim = dim // num_heads
rpe_config = get_rpe_config(
    ratio=1.9,
    method="euc",
    mode='ctx',
    shared_head=True,
    skip=0,
    rpe_on='k',# do we need more ? 
)
rpe_q, rpe_k, rpe_v = build_rpe(rpe_config,
                               head_dim=head_dim,
                               num_heads=num_heads)



In [12]:
B, N, T,C = x.shape
num_patches = N // T
q_ = q_learned.expand(B, N, -1).reshape(B, N, num_heads, C // num_heads).permute(0, 2, 1, 3)

# print(self.wq(x).shape)
q = wq(x).reshape(B, N, T, num_heads, C // num_heads).permute(0, 3, 1, 2, 4)
# BNC -> BNH(C/H) -> BHN(C/H)
k = wk(x).reshape(B, N, T, num_heads, C // num_heads).permute(0, 3, 1, 2, 4)
# BNC -> BNH(C/H) -> BHN(C/H)
v = wv(x).reshape(B, N, T, num_heads, C // num_heads).permute(0, 3, 1, 2, 4)

In [13]:
xx = torch.rand(B, 4096, 256)
kk = wk(xx).reshape(B, N, num_heads, C // num_heads).permute(0, 2, 1, 3)
print(kk.shape)
print(kk.transpose(-2, -1).shape)

torch.Size([2, 8, 4096, 32])
torch.Size([2, 8, 32, 4096])


In [39]:
q1 = q_learned.expand(B, N, -1).unsqueeze(2).reshape(B, N, 1, num_heads, C // num_heads)
k1 = wk(x).reshape(B, N, T, num_heads, C // num_heads)
print(q1.shape)
print(k1.shape)

torch.Size([2, 4096, 1, 8, 32])
torch.Size([2, 4096, 2, 8, 32])


In [15]:
# q1 = q1.permute(0, 3, 1, 1, 4)
# k1 = k1.permute(0, 3, 3, 1, 4)

In [16]:
k1.shape

torch.Size([2, 4096, 2, 256])

In [17]:
q1.shape

torch.Size([2, 4096, 1, 256])

In [45]:
k1.transpose(-2, -1).shape

torch.Size([2, 4096, 2, 32, 8])

In [47]:
attn = (q1 @ k1.transpose(-2, -1))

In [48]:
attn.shape

torch.Size([2, 4096, 2, 8, 8])

In [24]:
q1.shape

torch.Size([2, 4096, 1, 256])

In [42]:
# Given tensors
A = torch.randn(2, 4096, 1, 8, 32)  # Tensor A
B = torch.randn(2, 4096, 2, 8, 32)  # Tensor B

# Step 1: Perform the dot product along the last two dimensions (dim=-2 and dim=-1)
# We use torch.einsum for a flexible dot product operation.
# "ijklm,ijnlm->ijkn" indicates the reduction of the last two dimensions.

result = torch.einsum('ijklm,ijnlm->ijkn', A, B)
result.shape

torch.Size([2, 4096, 1, 2])

In [30]:
qq = q1.reshape(2, 8, 4096, 1, 256 // 8).squeeze(3)
qq.shape

torch.Size([2, 8, 4096, 32])

In [36]:
n_modalities = 2
print("full rpe k",  rpe_k(q1).shape)
print("rpe k extend", rpe_k(q1).repeat(1, 1, 1, n_modalities).shape)
repq1 = rpe_k(q1)

full rpe k torch.Size([2, 32768, 1, 1])
rpe k extend torch.Size([2, 32768, 1, 2])


In [38]:
repq1.reshape(2, 8, 32768 // 8, 1, 1).shape

torch.Size([2, 8, 4096, 1, 1])

In [31]:
n_modalities = 2
print("full rpe k",  rpe_k(qq).shape)
print("rpe k extend", rpe_k(qq).repeat(1, 1, 1, n_modalities).shape)


x irpe torch.Size([2, 8, 4096, 32])
L 4096
skip 0
7
7
full rpe k torch.Size([2, 8, 4096, 4096])
rpe k extend torch.Size([2, 8, 4096, 8192])


In [34]:
32768/8

4096.0

In [None]:
attn2 = (attn @ v)

In [None]:
attn2.shape

In [23]:
from torch import nn
dim=256
num_patches=8192
num_heads=8
x = torch.rand((2, 8192, 256))
q_learned = nn.Parameter(torch.zeros(1, 1, dim))

In [6]:
B, N, C = x.shape
# B1C -> B1H(C/H) -> BH1(C/H)
q_ = q_learned.expand(B, num_patches, -1)
q = q_.reshape(B, num_patches, num_heads, C // num_heads).permute(0, 2, 1, 3)

In [7]:
q.shape

torch.Size([2, 8, 8192, 32])

In [8]:
x = x.permute(0, 2, 1)

In [10]:
x.shape

torch.Size([2, 256, 8192])

In [17]:
import torch.nn.functional as F
import numpy as np

pa = 91*91 - x.shape[-1]
x = F.pad(x, (0, int(pa)))
x.shape

torch.Size([2, 256, 8281])

In [18]:
pa

-8190

In [19]:
91**2

8281

In [22]:
x.permute(0, 2, 1)[0, 8280, :]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 