### sliding window

In [7]:
import math
import torch
import torch.nn as nn
import numpy as np
from torch.nn.functional import interpolate
from torchvision.transforms.v2.functional import resize

class A(nn.Module):
    """Resize images to cover a fixed target size and cut them into windows.

    Each image is scaled uniformly until it covers ``self.img_size`` and is
    then sliced along its longer spatial axis into evenly spaced, possibly
    overlapping windows whose length is ``min(self.img_size)``.
    """

    def __init__(
        self,
    ):
        super().__init__()
        # Target (height, width) every resized image has to cover.
        self.img_size = (512, 512)

    def scale_img_size_semantic(self, size: tuple[int, int]):
        """Scale ``size`` uniformly so that it covers ``self.img_size``.

        The larger of the two per-axis ratios is used as the scale factor, so
        the aspect ratio is kept and neither side ends up below the target
        (up to rounding).

        Args:
            size: ``(height, width)`` of the image to scale.

        Returns:
            ``[height, width]`` after scaling, rounded to integers.
        """
        factor = max(
            self.img_size[0] / size[0],
            self.img_size[1] / size[1],
        )

        return [round(s * factor) for s in size]

    def window_imgs_semantic(self, imgs):
        """Resize every image and split it into windows along its longer side.

        Args:
            imgs: sequence of image tensors whose last two dimensions are
                height and width.

        Returns:
            A 3-tuple ``(crops, origins, img_sizes)``:

            * ``crops`` - all windows of all images stacked with
              ``torch.stack`` (which requires at least one crop and equal
              crop shapes).
            * ``origins`` - one ``(image_index, start, end)`` tuple per crop,
              where ``start``/``end`` are offsets along the longer axis of
              the *resized* image.
            * ``img_sizes`` - the original ``shape[-2:]`` of every input.
        """
        img_sizes = [img.shape[-2:] for img in imgs]
        window = min(self.img_size)

        crops, origins = [], []
        for i, (img, img_size) in enumerate(zip(imgs, img_sizes)):
            new_img = resize(img, self.scale_img_size_semantic(img_size))

            height, width = new_img.shape[-2:]
            long_side = max(height, width)
            num_crops = math.ceil(long_side / window)
            # Distance between the first and the last window start.
            span = long_side - window

            for j in range(num_crops):
                # Spread the starts evenly over ``span`` using integer
                # arithmetic: the float form ``j * (window - overlap / (n-1))``
                # can land just below an integer and truncate one pixel
                # short, leaving the last row/column uncovered. A single
                # window always starts at 0 (avoids dividing by zero).
                start = (j * span) // (num_crops - 1) if num_crops > 1 else 0
                end = start + window
                # Slice the longer spatial axis; ``...`` keeps the indexing
                # consistent with the ``shape[-2:]`` lookups above.
                crop = (
                    new_img[..., start:end, :]
                    if height > width
                    else new_img[..., :, start:end]
                )

                crops.append(crop)
                origins.append((i, start, end))

        return torch.stack(crops), origins, img_sizes

In [16]:
# Smoke test: window one random 3x600x500 image and show what comes back.
a = A()
imgs = [torch.rand(3, 600, 500)]
crops, origins, img_sizes = a.window_imgs_semantic(imgs)
print(
    "\n".join(
        f"{label} --> {value}"
        for label, value in (
            ("crops", crops.shape),
            ("origins", origins),
            ("img_sizes", img_sizes),
        )
    )
)

crops --> torch.Size([2, 3, 512, 512])
origins --> [(0, 0, 512), (0, 102, 614)]
img_sizes --> [torch.Size([600, 500])]


In [4]:
import torch

# Rewrite the parameter names of a BEiT-3 checkpoint and fuse the separate
# q/k/v projections into a single qkv tensor per block. The result is saved
# next to the input with a ".timm" suffix (presumably for loading into a
# timm-style ViT - confirm against the consuming model).
#
# NOTE: the renames below are plain substring replacements applied to every
# key, so their ORDER matters (e.g. "self_attn_layer_norm" has to be handled
# before "self_attn"). Do not reorder them.

depth = 24  # number of transformer blocks whose q/k/v projections are fused
ckpt_path = "/mnt/sda1/tkerssies/beit3_large_patch16_224.pth"

# NOTE(review): torch.load unpickles the file - only run this on a checkpoint
# from a trusted source. map_location="cpu" makes the conversion independent
# of the device the checkpoint was saved from; no GPU is needed to rename keys.
ckpt = torch.load(ckpt_path, map_location="cpu")["model"]

# Strip the wrapper prefix, drop the text embedding, rename the patch embedding.
ckpt = {k.replace("beit3.", ""): v for k, v in ckpt.items()}
ckpt = {k: v for k, v in ckpt.items() if not k.startswith("text_embed")}
ckpt = {k.replace("vision_embed.", "patch_embed."): v for k, v in ckpt.items()}

# Remove the mask token and the mlm/mim heads; each pop raises KeyError if
# the key is missing from the checkpoint.
ckpt.pop("patch_embed.mask_token")
ckpt.pop("mlm_head.weight")
ckpt.pop("mlm_head.bias")
ckpt.pop("mim_head.weight")
ckpt.pop("mim_head.bias")

ckpt = {k.replace("patch_embed.cls_token", "cls_token"): v for k, v in ckpt.items()}
ckpt = {k.replace("encoder.", ""): v for k, v in ckpt.items()}
ckpt = {k.replace("layers", "blocks"): v for k, v in ckpt.items()}

# Keep only the ".A." parameters and drop every ".B." one (presumably A is
# the vision branch and B the language branch - confirm).
ckpt = {k: v for k, v in ckpt.items() if ".B." not in k}
ckpt = {k.replace(".A.", "."): v for k, v in ckpt.items()}

ckpt["pos_embed"] = ckpt.pop("embed_positions.weight")

# Per-block renames: attention, its norms and projections, then the MLP.
ckpt = {k.replace("self_attn_layer_norm", "norm1"): v for k, v in ckpt.items()}
ckpt = {k.replace("self_attn", "attn"): v for k, v in ckpt.items()}
ckpt = {k.replace("inner_attn_ln", "proj.0"): v for k, v in ckpt.items()}
ckpt = {k.replace("out_proj", "proj.1"): v for k, v in ckpt.items()}
ckpt = {k.replace("ffn", "mlp"): v for k, v in ckpt.items()}
ckpt = {k.replace("final_layer_norm", "norm2"): v for k, v in ckpt.items()}
# "ffn_layernorm" has already become "mlp_layernorm" via the "ffn" rename.
ckpt = {k.replace("mlp_layernorm", "norm"): v for k, v in ckpt.items()}

ckpt["norm.weight"] = ckpt.pop("layer_norm.weight")
ckpt["norm.bias"] = ckpt.pop("layer_norm.bias")

# Drop the first two position rows and add a leading dimension of size 1
# (presumably the dropped rows are unused offset slots - confirm).
ckpt["pos_embed"] = ckpt["pos_embed"][None, 2:, :]

# Fuse the separate q/k/v projections of every block into one qkv tensor,
# concatenated along the output dimension in q, k, v order.
for block in range(depth):
    q_key = f"blocks.{block}.attn.q_proj"
    k_key = f"blocks.{block}.attn.k_proj"
    v_key = f"blocks.{block}.attn.v_proj"

    q_w = ckpt.pop(f"{q_key}.weight")
    k_w = ckpt.pop(f"{k_key}.weight")
    v_w = ckpt.pop(f"{v_key}.weight")

    q_b = ckpt.pop(f"{q_key}.bias")
    k_b = ckpt.pop(f"{k_key}.bias")
    v_b = ckpt.pop(f"{v_key}.bias")

    qkv_w = torch.cat([q_w, k_w, v_w], dim=0)
    qkv_b = torch.cat([q_b, k_b, v_b])

    ckpt[f"blocks.{block}.attn.qkv.weight"] = qkv_w
    ckpt[f"blocks.{block}.attn.qkv.bias"] = qkv_b

# List the final key names for manual inspection before writing the file.
print("\n".join(ckpt.keys()))

torch.save(ckpt, ckpt_path + ".timm")

module.positional_embedding
module.text_projection
module.logit_scale
module.visual.class_embedding
module.visual.positional_embedding
module.visual.proj
module.visual.conv1.weight
module.visual.ln_pre.weight
module.visual.ln_pre.bias
module.visual.transformer.resblocks.0.ln_1.weight
module.visual.transformer.resblocks.0.ln_1.bias
module.visual.transformer.resblocks.0.attn.in_proj_weight
module.visual.transformer.resblocks.0.attn.in_proj_bias
module.visual.transformer.resblocks.0.attn.out_proj.weight
module.visual.transformer.resblocks.0.attn.out_proj.bias
module.visual.transformer.resblocks.0.ln_2.weight
module.visual.transformer.resblocks.0.ln_2.bias
module.visual.transformer.resblocks.0.mlp.c_fc.weight
module.visual.transformer.resblocks.0.mlp.c_fc.bias
module.visual.transformer.resblocks.0.mlp.c_proj.weight
module.visual.transformer.resblocks.0.mlp.c_proj.bias
module.visual.transformer.resblocks.1.ln_1.weight
module.visual.transformer.resblocks.1.ln_1.bias
module.visual.transformer