# Checking constant-sized tensors

Compiling PyTorch models with `torch.export` is a bit of a hassle because conditionals are not well supported. To work around this, we look for tensors that are constant, either in shape or in value, and hint the compiler about them.

## Setup

In [24]:
from detectron2.data.detection_utils import read_image
from demo.predictors import VisualizationDemo
from detectron2.config import LazyConfig, instantiate
from pathlib import Path
from detectron2.checkpoint import DetectionCheckpointer
import torch
from omegaconf import OmegaConf
from copy import copy, deepcopy
import detectron2.data.transforms as T
from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2.data import MetadataCatalog
import typing
from torch.export import Dim, export
from torchvision.transforms import functional as TTF
from typing import Callable, Tuple
import torch.nn as nn
import numpy as np
print(Path.cwd())
assert (Path.cwd() / ".project-root").exists(), "Please run this script from the root of the project"

/Users/dgcnz/development/amsterdam/edge


In [2]:
config_file = "projects/dino_dinov2/configs/COCO/dino_dinov2_b_12ep.py"
image_path = "artifacts/idea_raw.jpg"
CONFIDENCE_THRESHOLD = 0.5
opts = ["model.device=cpu", "train.device=cpu", "train.init_checkpoint=artifacts/model_final.pth"]
cfg = LazyConfig.load(config_file)
cfg = LazyConfig.apply_overrides(cfg, opts)



In [3]:
model = instantiate(OmegaConf.to_object(cfg.model)).eval();
checkpointer = DetectionCheckpointer(model);
checkpointer.load(cfg.train.init_checkpoint);

[08/19 18:27:02 timm backbone]: backbone out_indices: (11,)
[08/19 18:27:02 timm backbone]: backbone out_channels: [768]
[08/19 18:27:02 timm backbone]: backbone out_strides: [16]




In [4]:
img = read_image(image_path, format="BGR")
demo = VisualizationDemo(model=model)

## Testing

### T1: `spatial_shapes` over multiple image sizes

`spatial_shapes` has been one of the main culprits behind compiler errors, and in testing so far it has been constant. However, those tests used a single sample image, so different image sizes may well yield different spatial shapes.

To check, we resize the sample image to several sizes and hook a logger onto a layer that takes `spatial_shapes` as input.

In [20]:
original_shape = img.shape # (1920, 1281, 3)
new_shapes = [(200, 2000),(800, 800), (800, 1200), (1200, 800), (1200, 1200), (2000, 2000)]
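# HWC -> CHW, the layout expected by torchvision's resize
img_chw = torch.from_numpy(img.copy()).permute(2, 0, 1)
# Resize to each (height, width), then convert back to an HWC numpy array
images = [TTF.resize(img_chw, new_shape).permute(1, 2, 0).numpy() for new_shape in new_shapes]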

In [21]:
spatial_shapes_history = []

def track_spatial_shapes(layer: nn.Module, inputs: Tuple[torch.Tensor], kwargs: dict):
    global spatial_shapes_history
    spatial_shapes_history.append(kwargs["spatial_shapes"])


handle = demo.predictor.model.transformer.encoder.register_forward_pre_hook(track_spatial_shapes, with_kwargs=True)

for _img in images:
    demo.run_on_image(_img)

handle.remove()


In [22]:
print(*[(new_shapes[ix], x.flatten().tolist()) for ix, x in enumerate(spatial_shapes_history)], sep="\n")

((200, 2000), [18, 168, 9, 84, 4, 42, 2, 21])
((800, 800), [100, 100, 50, 50, 25, 25, 13, 13])
((800, 1200), [100, 150, 50, 75, 25, 37, 13, 19])
((1200, 800), [150, 100, 75, 50, 37, 25, 19, 13])
((1200, 1200), [100, 100, 50, 50, 25, 25, 13, 13])
((2000, 2000), [100, 100, 50, 50, 25, 25, 13, 13])


That's unfortunate: `spatial_shapes` is not constant across input sizes, and there is no obvious pattern.
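One way to look for structure in the output above is to group the input sizes by the `spatial_shapes` they produce. The sketch below does this in plain Python; the `observed` dictionary is copied by hand from the printed output, and `group_by_spatial_shapes` is a hypothetical helper, not part of detectron2.

```python
from collections import defaultdict

# Values copied from the printed output above:
# input (height, width) -> flattened spatial_shapes seen by the encoder.
observed = {
    (200, 2000): [18, 168, 9, 84, 4, 42, 2, 21],
    (800, 800): [100, 100, 50, 50, 25, 25, 13, 13],
    (800, 1200): [100, 150, 50, 75, 25, 37, 13, 19],
    (1200, 800): [150, 100, 75, 50, 37, 25, 19, 13],
    (1200, 1200): [100, 100, 50, 50, 25, 25, 13, 13],
    (2000, 2000): [100, 100, 50, 50, 25, 25, 13, 13],
}


def group_by_spatial_shapes(results):
    """Map each distinct spatial_shapes value to the input sizes that produced it."""
    groups = defaultdict(list)
    for input_size, shapes in results.items():
        groups[tuple(shapes)].append(input_size)
    return dict(groups)


groups = group_by_spatial_shapes(observed)
for shapes, sizes in groups.items():
    print(sizes, "->", shapes)
```

The six input sizes fall into four groups, and the three square inputs share one of them. That would be consistent with the predictor resizing images before they reach the encoder, but this is an assumption that the notebook does not verify.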

### T2: `spatial_shapes` over a single image size

In production, however, we will use at most two cameras, which presumably share the same resolution, so we can compile a model specifically for that resolution. As a sanity check, let's run several different images of the same size and see whether `spatial_shapes` changes.

In [27]:
def random_uint8_image(shape: Tuple[int, int, int]) -> np.ndarray:
    # The upper bound of randint is exclusive, so 256 covers the full uint8 range
    return np.random.randint(0, 256, shape, dtype=np.uint8)

# random_uint8_image(img.shape)

In [31]:
N_IMAGES = 4
images = [random_uint8_image(img.shape) for _ in range(N_IMAGES)]
images.append(np.zeros(img.shape, dtype=np.uint8))

In [32]:
spatial_shapes_history = []

def track_spatial_shapes(layer: nn.Module, inputs: Tuple[torch.Tensor], kwargs: dict):
    global spatial_shapes_history
    spatial_shapes_history.append(kwargs["spatial_shapes"])


handle = demo.predictor.model.transformer.encoder.register_forward_pre_hook(track_spatial_shapes, with_kwargs=True)

for _img in images:
    demo.run_on_image(_img)

handle.remove()


In [33]:
print(*[x.flatten().tolist() for x in spatial_shapes_history], sep="\n")

[150, 100, 75, 50, 37, 25, 19, 13]
[150, 100, 75, 50, 37, 25, 19, 13]
[150, 100, 75, 50, 37, 25, 19, 13]
[150, 100, 75, 50, 37, 25, 19, 13]
[150, 100, 75, 50, 37, 25, 19, 13]


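All five runs produce identical `spatial_shapes`, so for a fixed input resolution the value does not depend on image content. We'll compile the model for a single resolution.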
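Rather than comparing the printed rows by eye, the same check can be automated. The following is a minimal sketch that works on plain lists, such as those obtained with `x.flatten().tolist()`; `assert_constant_shapes` is a hypothetical helper name, and the `recorded` values are copied from the output above.

```python
from typing import List, Sequence


def assert_constant_shapes(history: Sequence[Sequence[int]]) -> List[int]:
    """Return the shared spatial_shapes, or raise if any run differs from the first."""
    if not history:
        raise ValueError("no spatial_shapes were recorded")
    reference = list(history[0])
    for i, shapes in enumerate(history[1:], start=1):
        if list(shapes) != reference:
            raise AssertionError(
                f"run {i} produced {list(shapes)}, expected {reference}"
            )
    return reference


# Flattened spatial_shapes from the five same-sized images above.
recorded = [[150, 100, 75, 50, 37, 25, 19, 13] for _ in range(5)]
constant = assert_constant_shapes(recorded)
print(constant)
```

If the assertion holds for the target camera resolution, the returned list is a candidate for hard-coding as a constant when compiling the model for that resolution.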