In [1]:
%reload_ext autoreload
%autoreload 2

# Use aligner_v7 kernel

import sys
add_paths = [
    "/fsx_0/user/tranx/rsync", # ALIGNER_PARENT_DIR
    "/fsx_0/user/tranx/rsync/llm_mm_aligner/replicated", # ALIGNER_PARENT_DIR/llm_mm_aligner/replicated
    "/fsx_0/shared/conda/aligner_20241030/python-packages"
]

for p in add_paths:
    if p not in sys.path:
        sys.path.append(p)
        
device = "cuda:0"

In [2]:
from typing import Any, Dict, List, Optional, Tuple, Union
from tqdm.auto import tqdm
import torch
import torchvision.datasets as datasets
from torchmetrics import Accuracy

from llm_mm_aligner.lib.encoders.metaclip_text_encoder import MetaCLIPTextTransformer
from llm_mm_aligner.lib.encoders.metaclip_vev01 import MetaCLIPVEv01
from llm_mm_aligner.lib.encoders.metaclip_vev02 import MetaCLIPVEv02

from transformers import CLIPProcessor

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def get_metaclip_model(model_path:str):
    if "vev0p1" in model_path or model_path.split("/")[-1] in [
        "MetaCLIP-hp14_336_fair_vev0_ckpt_epoch205"
    ]:
        return MetaCLIPVEv01
    
    if model_path.split("/")[-1] in [
        "MetaCLIP-Gs_224_fair_vev02_ckpt_epoch_360"
    ]:
        return MetaCLIPVEv02
    
    raise ValueError(f"Unknown metaclip model from path: {model_path}")
    
class OpenCLIPModelWrapper(torch.nn.Module):
    def __init__(self, vision_model: torch.nn.Module, text_model: torch.nn.Module):
        super().__init__()
        self.vision_model = vision_model
        self.visual_projection = self.vision_model.proj
        self.text_model = text_model
        self.text_projection = self.text_model.text_projection
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.605))

    @classmethod
    def from_pretrained(cls, pretrained_path: str, **kwargs):
        metaclip_model = get_metaclip_model(pretrained_path)

        vision_model = metaclip_model.from_pretrained(
            pretrained_path=pretrained_path
        )
                
        # vision_model = metaclip_model.from_pretrained(
        #     pretrained_path=pretrained_path,
        #     return_intermediate=False,
        # )
        text_model = MetaCLIPTextTransformer.from_pretrained(pretrained_path)
        return cls(vision_model, text_model)

    def get_text_features(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return self.text_model(input_ids)

    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.vision_model(pixel_values)
    
def load_clip_model(clip_model_path: str):

    clip_model = OpenCLIPModelWrapper.from_pretrained(clip_model_path)
    clip_processor = CLIPProcessor.from_pretrained(clip_model_path)
    print(f"Successfully loaded clip model and preprocessor from {clip_model_path}")
    
    return clip_model, clip_processor

def load_label_templates(filename: str) -> List[str]:
    """
    Load the ImageNet label templates or other custom templates in txt format.
    """
    with open(filename, "r") as f:
        templates = f.readlines()
    return [item.strip("\n") for item in templates]

def load_label_names(filename: str) -> List[str]:
    """
    Load the ImageNet label names / categories or other custom label names in txt format.
    """
    with open(filename, "r") as f:
        content = f.readlines()
        labels = [",".join(x.split(",")[1:]).strip("\n") for x in content]
    return labels


def get_label_embeddings(
    clip_model,
    processor: CLIPProcessor,
    labels: List[str],
    templates: List[str],
) -> torch.Tensor:
    """
    For ImageNet Evaluation, the first step is to get the averaged embedding of each category / label.
    """
    # pyre-fixme[16]: `CLIPModel` has no attribute `text_model`.
    clip_model.text_model = clip_model.text_model.cuda()
    # pyre-fixme[16]: `CLIPModel` has no attribute `text_projection`.
    clip_model.text_projection = clip_model.text_projection.cuda()
    with torch.no_grad():
        all_text_embeds = []
        
        for label in labels:
            # put one label / category in the templates
            texts = [template.format(label) for template in templates]
            # text to token ids in CLIP way
            tokenized_outputs = processor(text=texts, return_tensors="pt", padding=True)
            text_input_ids, text_attn_masks = (
                tokenized_outputs.input_ids,
                tokenized_outputs.attention_mask,
            )
            text_input_ids = text_input_ids.cuda()
            text_attn_masks = text_attn_masks.cuda()
            # pyre-fixme[16]: `CLIPModel` has no attribute `get_text_features`.
            text_features = clip_model.get_text_features(
                input_ids=text_input_ids,
                attention_mask=text_attn_masks,
            )
            text_embeds = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
            # averaged embedding of a category
            text_embeds = text_embeds.mean(dim=0)
            text_embeds /= text_embeds.norm()

            all_text_embeds.append(text_embeds)

    return torch.stack(all_text_embeds)


class ImageFolderWithNames(datasets.ImageFolder):
    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target, path
    
class ImageNetCollator():
    def __init__(self, f_transform):
        # Assuming f_transform is CLIPProcessor
        assert isinstance(f_transform, CLIPProcessor)
        self.f_transform = f_transform
    
    def __call__(self, batch_data: list[dict[str, Any]])-> Optional[dict[str, Any]]:
        images = [item[0] for item in batch_data]
        images = self.f_transform(images=images, return_tensors="pt")["pixel_values"]
        
        labels = [item[1] for item in batch_data]
        image_names = [item[2] for item in batch_data]
        
        return {"images": images, "labels": labels, "image_names": image_names}


def create_imagenet_dataloader(data_path: str, f_transform, batch_size=8, num_workers=0, sampler=None):

    dataset = ImageFolderWithNames(data_path)
    
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=0,
        sampler=sampler,
        collate_fn=ImageNetCollator(f_transform=f_transform),
    )
    
    return dataloader

In [6]:
checkpoint_dir = "/fsx_0/checkpoints/clip/MetaCLIP-Gs_224_fair_vev02_ckpt_epoch_360"
clip_model, clip_processor = load_clip_model(checkpoint_dir)

vision_config: {'output_dim': 1280, 'global_layers': -1, 'relative_pos_embed_type': 'rope_2d', 'pool_type': 'attn', 'embed_cls_token': False, 'mim_loss': True, 'mim_ratio': 0.0625, 'pos_embed_type': 'learnable', 'image_size': 392, 'patch_size': 14, 'layers': 50, 'width': 1536, 'heads': 16, 'mlp_ratio': 5.833333334, 'ckpt_path': '/fsx_0/checkpoints/clip/MetaCLIP-Gs_224_fair_vev02_ckpt_epoch_360/epoch_360.pt'}
img_idx.shape=torch.Size([784, 1])


MetaCLIPVEv02 Visual Pretrained Model Missing keys: []
MetaCLIPVEv02 Visual Pretrained Model Unexpected keys: []
MetaCLIPTextTransformer Pretrained Model Missing keys: []
MetaCLIPTextTransformer Pretrained Model Unexpected keys: []


Successfully loaded clip model and preprocessor from /fsx_0/checkpoints/clip/MetaCLIP-Gs_224_fair_vev02_ckpt_epoch_360


In [5]:
templates = load_label_templates(
    "/fsx_0/user/tranx/github/openCLIPMeta/scripts/tranx/clip_imagenet_templates.txt"
)

label_names = load_label_names(
    "/fsx_0/user/tranx/github/openCLIPMeta/scripts/tranx/clip_modified_full_size_labels.txt"
)
print(f"Each sample will compare against {len(label_names)} labels")



label_embeds = get_label_embeddings(
    clip_model, clip_processor, label_names, templates
)

Each sample will compare against 1000 labels


In [6]:
batch_size = 128
dataloader = create_imagenet_dataloader(
    data_path="/fsx_0/dataset01/imagenet/val",
    f_transform=clip_processor,
    batch_size=batch_size,
    num_workers=24
)
# batch = next(iter(dataloader))
# batch

total_samples = len(dataloader.dataset)
num_batches = (total_samples // batch_size) + (total_samples % batch_size != 0)
num_batches

391

In [7]:
clip_processor

CLIPProcessor:
- image_processor: CLIPImageProcessor {
  "crop_size": {
    "height": 392,
    "width": 392
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 392
  }
}

- tokenizer: CLIPTokenizerFast(name_or_path='/fsx_0/checkpoints/clip/MetaCLIP-Gs_224_fair_vev02_ckpt_epoch_360', vocab_size=49408, model_max_length=32, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	49406: AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, specia

In [8]:
all_preds = []
all_labels = []

clip_model.vision_model = clip_model.vision_model.cuda()
clip_model.visual_projection = clip_model.visual_projection.cuda()

with torch.no_grad():
    for batch in tqdm(dataloader, total=num_batches):
    # for i, batch in enumerate(dataloader):
    # for batch in tqdm(dataloader):
        images = batch["images"].cuda()
        image_features = clip_model.get_image_features(images)
        
        image_embeds = image_features / image_features.norm(
                p=2, dim=-1, keepdim=True
            )
        
        logits_per_text = torch.matmul(label_embeds, image_embeds.t())
        # pyre-fixme[16]: `CLIPModel` has no attribute `logit_scale`.
        logits_per_image = logits_per_text.t() * clip_model.logit_scale.exp()
        # confidences of each sample over the classes
        probs = logits_per_image.softmax(dim=1)
        # get the top1 prediction
        preds = probs.argmax(dim=1)
        # add to all_preds => List[int]
        all_preds.extend(preds.tolist())
        all_labels.extend(batch["labels"])
    
accuracy = 0.0
if len(all_labels) == len(all_preds):
    accuracy_metrics = Accuracy(num_classes=1000, task='multiclass')
    accuracy = accuracy_metrics(
        torch.Tensor(all_preds).long(), torch.Tensor(all_labels).long()
    ).item()
print(f"Top-1 Accuracy: {accuracy:.4f}")

  0%|          | 0/391 [00:00<?, ?it/s]

[tranx] type(x): <class 'torch.Tensor'>


  0%|          | 0/391 [00:02<?, ?it/s]


AssertionError: packed_img_idx is required for RoPE

# Test forward

In [7]:
vision_config = {
    "output_dim": 1280,
    "global_layers": -1,
    "relative_pos_embed_type": "rope_2d",
    "pool_type": "attn",
    "embed_cls_token": False,

    "mim_loss": True,
    "mim_ratio": 0.0625,
    
    "pos_embed_type": "learnable",
    "image_size": 392,
    
    "patch_size": 14,

    "layers": 50,
    "width": 1536,
    "heads": 16,
    "mlp_ratio": 5.833333334
}

In [37]:
from torchvision.transforms import Compose

In [39]:
transform2 = Compose([clip_processor, PackWindowsAndPad(
    window_size=1, patch_size=14, insert_cls_token=False
)])

transform2

Compose(
    CLIPProcessor:
- image_processor: CLIPImageProcessor {
  "crop_size": {
    "height": 392,
    "width": 392
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 392
  }
}

- tokenizer: CLIPTokenizerFast(name_or_path='/fsx_0/checkpoints/clip/MetaCLIP-Gs_224_fair_vev02_ckpt_epoch_360', vocab_size=49408, model_max_length=32, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	49406: AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized

In [47]:
dataset = datasets.ImageFolder("/fsx_0/dataset01/imagenet/val", transform=transform2)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=4,
    sampler=None,
    collate_fn=CollatePackedWindows,
)

In [42]:
dataset = ImageFolderWithNames("/fsx_0/dataset01/imagenet/val")

In [45]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=0,
    sampler=None,
    collate_fn=CollatePackedWindows,
)

In [48]:
batch = next(iter(dataloader))

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/hpcaas/.mounts/fs-036153e63d56f4dc2/home/tranx/conda/envs/aligner_v7/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/opt/hpcaas/.mounts/fs-036153e63d56f4dc2/home/tranx/conda/envs/aligner_v7/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/hpcaas/.mounts/fs-036153e63d56f4dc2/home/tranx/conda/envs/aligner_v7/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/hpcaas/.mounts/fs-036153e63d56f4dc2/home/tranx/conda/envs/aligner_v7/lib/python3.10/site-packages/torchvision/datasets/folder.py", line 247, in __getitem__
    sample = self.transform(sample)
  File "/opt/hpcaas/.mounts/fs-036153e63d56f4dc2/home/tranx/conda/envs/aligner_v7/lib/python3.10/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
    img = t(img)
  File "/home/tranx/.local/lib/python3.10/site-packages/transformers/models/clip/processing_clip.py", line 106, in __call__
    encoding = self.tokenizer(text, return_tensors=return_tensors, **tokenizer_kwargs)
  File "/home/tranx/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 3016, in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
  File "/home/tranx/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 3076, in _call_one
    raise ValueError(
ValueError: text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).


In [None]:

    
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=0,
        sampler=sampler,
        collate_fn=ImageNetCollator(f_transform=f_transform),
    )

In [8]:
batch_size = 4
dataloader = create_imagenet_dataloader(
    data_path="/fsx_0/dataset01/imagenet/val",
    f_transform=clip_processor,
    # f_transform=transform2,
    batch_size=batch_size,
    num_workers=24
)
# batch = next(iter(dataloader))
# batch

total_samples = len(dataloader.dataset)
num_batches = (total_samples // batch_size) + (total_samples % batch_size != 0)
num_batches

12498

In [9]:
batch = next(iter(dataloader))
images = batch["images"].cuda(device=0)

In [10]:
vision_model = MetaCLIPVEv02(**vision_config) 
vision_model = vision_model.cuda(device=0)

img_idx.shape=torch.Size([784, 1])


In [11]:
vision_model.vision_select_layer

-1

In [14]:
vision_model.forward_vev01(images)

x.shape=torch.Size([4, 784, 1536])
x.shape=torch.Size([4, 784, 1536])
self.positional_embedding.shape=torch.Size([784, 1536])
packed_img_idx.shape=torch.Size([4, 784, 7])
x.shape=torch.Size([4, 784, 1536])
int_x.shape=torch.Size([4, 784, 7])
return_x_batch=tensor([[      1136,          0,          0,  ...,         28,         28,
                  0],
        [     22069,          0,          1,  ...,         28,         28,
                  1],
        [1694805312,          0,          2,  ...,         28,         28,
                  2],
        ...,
        [     32684,         27,         25,  ...,         28,         28,
                781],
        [ 325064608,         27,         26,  ...,         28,         28,
                782],
        [     32684,         27,         27,  ...,         28,         28,
                783]], device='cuda:0', dtype=torch.int32)


IndexError: too many indices for tensor of dimension 2

In [None]:
import math 

class PACKED:
    Z         = 0  # Z (time) coordinate of the token in the original sample
    Y         = 1  # Y (height) coordinate of the token in the original sample
    X         = 2  # X (width) coordinate of the token in the original sample
    TIME      = 3  # Total number of time units (frames) in the original sample
    HEIGHT    = 4  # Height of the original sample
    WIDTH     = 5  # Width of the original sample
    # USE INDEX TO CHECK THE TYPE OF THE TOKEN (see ID fields below)
    IDX       = 6  # Full index of the token in the original sample (x + y * w + z * w * h)
    BATCH_IDX = 7  # Which batch element this token belongs to.

    # Total size of the enum, remember to update this!
    NUM_METADATA = 8

    # Note: For padding tokens IDX = -1
    #       For cls tokens,    IDX = -2
    ID_CLS_TOKEN = -2
    ID_PAD_TOKEN = -1
    
class PackWindowsAndPad:
    """ Takes a torch image, chunks it into windows, and pads to the nearest window. """
    def __init__(self, window_size, patch_size, insert_cls_token):
        self.window_size = window_size  # can be (h, w) or #tokens, if #tokens use hilbert, if (h, w) use z-order
        self.patch_size  = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
        self.insert_cls_token = insert_cls_token

        if isinstance(window_size, (list, tuple)):
            self.window_shape = "quad"
        elif window_size == 1:
            self.window_shape = "none"
        else:
            self.window_shape = "hilbert"

    def __call__(self, img: torch.Tensor):
        c, ih, iw = img.shape

        # Quad padding requires the image to be a multiple of window size, while hilbert is just patch size
        patch_h, patch_w = self.patch_size
        
        h = int(math.ceil(ih / patch_h)) * patch_h
        w = int(math.ceil(iw / patch_w)) * patch_w

        pt = (h - ih) // 2
        pb = (h - ih) - pt
        pl = (w - iw) // 2
        pr = (w - iw) - pl

        # Check if we need to pad at all because idk if pytorch internally checks
        if not (pt == pb and pl == pr and pr == 0):
            img = F.pad(img, [pl, pt, pr, pb], fill=0)

        img = img.unfold(-2, patch_h, patch_h).unfold(-2, patch_w, patch_w)
        _, grid_h, grid_w, _, _ = img.shape

        img_idx = torch.arange(h*w // (patch_h * patch_w), dtype=torch.int32)
        img_idx = img_idx.reshape(h // patch_h, w // patch_w)

        # Then rearrange the patches into windows and pad to the nearest window (in the case of hilbert)
        if self.window_shape == "quad":
            new_grid_h = int(math.ceil(grid_h / self.window_size[0])) * self.window_size[0]
            new_grid_w = int(math.ceil(grid_w / self.window_size[1])) * self.window_size[1]

            pb = (new_grid_h - grid_h)
            pr = (new_grid_w - grid_w)

            img = img.reshape(c, grid_h, -1)
            img = F.pad(img, [0, 0, pr * patch_h * patch_w, pb], fill=0)

            img = img.reshape(c, new_grid_h, new_grid_w, patch_h, patch_w)
            img = img.unfold(1, self.window_size[0], self.window_size[0]).unfold(2, self.window_size[1], self.window_size[1])
            img = img.permute(0, 1, 2, 5, 6, 3, 4)  # oof
            img = img.reshape(c, -1, self.window_size[0] * self.window_size[1], *self.patch_size)

            img_idx = F.pad(img_idx[None, ...], [0, 0, pr, pb], fill=PACKED.ID_PAD_TOKEN)[0]
            img_idx = img_idx.unfold(-2, self.window_size[0], self.window_size[0]).unfold(-2, self.window_size[1], self.window_size[1])
            img_idx = img_idx.reshape(img.shape[1], img.shape[2])

            if self.insert_cls_token:
                # To be honest, idk where to put the cls token with quad packing so I'll just put it at the end, which is probably a padding token
                img_idx[-1, -1] = PACKED.ID_CLS_TOKEN
            
        elif self.window_shape == "hilbert":
            # Rearrange the patches according to a hilbert curve sequence then pad.
                from .gilbert import gilbert2d
                idx = torch.Tensor(list(gilbert2d(grid_w, grid_h))).long().to(device=img.device, dtype=torch.long)
                img = img[:, idx[:, 1], idx[:, 0], :, :]
                img = img.reshape(c, -1, self.patch_size[0] * self.patch_size[1])

                extra_tokens = 1 if self.insert_cls_token else 0

                # Add padding tokens up to the nearest window
                pad_tokens = int(math.ceil((grid_h * grid_w + extra_tokens) / self.window_size)) * self.window_size - grid_h * grid_w
                img = F.pad(img, [0, 0, 0, pad_tokens], fill=0)
                img = img.reshape(c, -1, self.window_size, *self.patch_size)

                img_idx = img_idx[idx[:, 1], idx[:, 0]]
                img_idx = F.pad(img_idx.view(1, -1, 1), [0, 0, 0, pad_tokens], fill=PACKED.ID_PAD_TOKEN)  # use -1 idx for posemb padding tokens
                img_idx = img_idx.reshape(-1, self.window_size)

                if self.insert_cls_token:
                    # Not true anymore: We guaranteed above that there is at least one padding token at the end for the cls token
                    # Add a cls token to every window, we're gonna avg the cls token logits at the end
                    img_idx[:, -1] = PACKED.ID_CLS_TOKEN

        elif self.window_shape == "none":
            # Don't do any reordering, just immediately reshape into windows
            # TODO: Add padding if we want to use this with window size > 1
            img = img.reshape(c, -1, self.window_size, *self.patch_size)
            img_idx = img_idx.reshape(-1, self.window_size)

            # This currently assumes window_size = 1
            if self.insert_cls_token:
                assert self.window_size == 1
                img = torch.cat([img, img[:, :1]], dim=1)
                img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
                img_idx[-1, -1] = PACKED.ID_CLS_TOKEN
        else:
            raise NotImplementedError(f"Unknown window shape type: ({self.window_shape}) in {repr(self)}.")

        # Save the padded image size (after patchification, e.g., for position embeddings)
        idx_h = h // self.patch_size[0]
        idx_w = w // self.patch_size[1]

        # Last metadata is added by collate
        packed_img_idx = torch.empty(img_idx.shape[0], img_idx.shape[1], PACKED.NUM_METADATA - 1, dtype=torch.int32)

        packed_img_idx[:, :, PACKED.Z].fill_(0)
        packed_img_idx[:, :, PACKED.Y] = img_idx // idx_w
        packed_img_idx[:, :, PACKED.X] = img_idx % idx_w
        packed_img_idx[:, :, PACKED.TIME].fill_(1)
        packed_img_idx[:, :, PACKED.HEIGHT].fill_(idx_h)
        packed_img_idx[:, :, PACKED.WIDTH].fill_(idx_w)
        packed_img_idx[:, :, PACKED.IDX] = img_idx

        # Return shape: [c, #windows, window_size, patch_height, patch_width]
        return img, packed_img_idx
    
    def __repr__(self):
        format_string = self.__class__.__name__ + f'(window_size={self.window_size}'
        format_string += f', patch_size={self.patch_size}'
        format_string += f', window_shape={self.window_shape})'
        return format_string


pp = PackWindowsAndPad(
    window_size=1, patch_size=14, insert_cls_token=False
)



batch_pp = [pp(x) for x in images]


In [36]:
class CollatePackedWindows:

    def __call__(self, batch):
        batch = [x for x in batch if x is not None]

        # Sort the batch by number of windows
        batch.sort(key=lambda x: x[0][0].shape[1])

        data = [item[0] for item in batch]
        target = [item[1] for item in batch]
        
        # To handle both plain integer targets and captions
        if isinstance(target, (list, tuple)) and isinstance(target[0], torch.Tensor):
            target = torch.stack(target, dim=0)
        else:
            target = torch.Tensor(target)

        num_windows = torch.Tensor([x[0].shape[1] for x in data]).long()

        packed_num_windows, packed_counts = torch.unique(num_windows, return_counts=True)
        packed_end_idx = (packed_counts * packed_num_windows).cumsum(dim=0)
        packing_boundaries = [packed_num_windows.tolist(), packed_end_idx.tolist()]
        
        packed_img = torch.cat([x[0] for x in data], dim=1)
        print(packed_img.shape)
        packed_image = packed_img.permute(1, 0, 2, 3, 4).contiguous()
        
        packed_img_idx = torch.cat([x[1] for x in data], dim=0)

        element_idx = torch.Tensor(sum([[i]*n for i, n in enumerate(num_windows)], [])).to(torch.int32)
        element_idx = element_idx.view(-1, 1).tile((1, packed_img_idx.shape[1]))

        packed_img_idx = torch.cat([packed_img_idx, element_idx[..., None]], dim=-1)

        return [(packed_img, packed_img_idx, num_windows.tolist(), packing_boundaries), target]
    
pp_collator = CollatePackedWindows()

batch_packed = pp_collator(batch_pp)

torch.Size([784, 128, 14, 14])


RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 4 is not equal to len(dims) = 5

In [30]:
len(batch_pp[0])

2

In [14]:
vision_model.forward_metaclip(images)

[tranx] type(x): <class 'torch.Tensor'>


AssertionError: packed_img_idx is required for RoPE

In [11]:
vision_model(images)



RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [15]:
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    This is done by calculating the effective and wasted resolution for each possible resolution.

    The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.

    Args:
        original_size (tuple):
            The original size of the image in the format (height, width).
        possible_resolutions (list):
            A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].

    Returns:
        tuple: The best fit resolution in the format (height, width).
    """
    original_height, original_width = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for height, width in possible_resolutions:
        
        scale = min(width / original_width, height / original_height)
        
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        print(f"{height=}, {width=}, {scale=}, {downscaled_width=}, {downscaled_height=}, {wasted_resolution=}")
        print(f"{effective_resolution=}, scaled_resolution={width * height}, {wasted_resolution=}")
        
        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (height, width)

    return best_fit

possible_resolutions = []
chunk_size = 336
for x in range(1,4):
    for y in range(1,4):
        possible_resolutions.append((x*chunk_size, y*chunk_size))
   
# original_size = (432, 576)
original_size = (576, 576)
select_best_resolution(original_size, possible_resolutions)

height=336, width=336, scale=0.5833333333333334, downscaled_width=336, downscaled_height=336, wasted_resolution=0
effective_resolution=112896, scaled_resolution=112896, wasted_resolution=0
height=336, width=672, scale=0.5833333333333334, downscaled_width=336, downscaled_height=336, wasted_resolution=112896
effective_resolution=112896, scaled_resolution=225792, wasted_resolution=112896
height=336, width=1008, scale=0.5833333333333334, downscaled_width=336, downscaled_height=336, wasted_resolution=225792
effective_resolution=112896, scaled_resolution=338688, wasted_resolution=225792
height=672, width=336, scale=0.5833333333333334, downscaled_width=336, downscaled_height=336, wasted_resolution=112896
effective_resolution=112896, scaled_resolution=225792, wasted_resolution=112896
height=672, width=672, scale=1.1666666666666667, downscaled_width=672, downscaled_height=672, wasted_resolution=119808
effective_resolution=331776, scaled_resolution=451584, wasted_resolution=119808
height=672, wi

(672, 672)