In [5]:
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from src.util.misc import log_opts, set_submodule_paths, set_cache_directories
set_submodule_paths(submodule_dir="submodules")
from ldm.util import instantiate_from_config
from train import get_data
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import math
from src.data.preprocessing.data_modules import DataModuleFromConfig

In [10]:
config_path = "configs/autoencoder/4_adaptive_conv_no_shift/manual_crop_img_zoomout.yaml"
checkpoint_path = "logs/2024-08-19T08-19-05_manual_crop_img_zoomout/checkpoints/last.ckpt"
config = OmegaConf.load(config_path)

def load_model(config, ckpt_path):
    model = instantiate_from_config(config.model)
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model

model = load_model(config, checkpoint_path)
model.eval()

making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 16, 16, 16) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
Loaded dataset stats:
{'car': {'t1': tensor([ 9.8967e-09, -1.8421e+01]), 't2': tensor([ 4.8958e-09, -1.8421e+01]), 't3': tensor([ 0.5154, -0.6869]), 'v3': tensor([0.0591, 1.1842]), 'l': tensor([ 2.6982, -2.4282]), 'h': tensor([ 1.7239, -2.8687]), 'w': tensor([ 1.1395, -4.1269]), 'yaw': tensor([0.0582, 1.1975]), 'fill_factor': tensor([ 0.4309, -3.2367])}, 'truck': {'t1': tensor([ 1.3829e-08, -1.8421e+01]), 't2': tensor([ 3.3183e-09, -1.8421e+01]), 't3': tensor([ 0.2548, 

AdaptivePoseAutoencoder(
  (encoder): FeatEncoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down): ModuleList(
      (0): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (1): ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          

In [11]:
data = get_data(config)
iteration = iter(data.datasets['train'])

08/21 09:50:50 - mmengine - [4m[97mINFO[0m - ------------------------------
08/21 09:50:50 - mmengine - [4m[97mINFO[0m - The length of training dataset: 168780
08/21 09:50:50 - mmengine - [4m[97mINFO[0m - The number of instances per category in the dataset:
+----------------------+--------+
| category             | number |
+----------------------+--------+
| car                  | 413318 |
| truck                | 72815  |
| trailer              | 20701  |
| bus                  | 13163  |
| construction_vehicle | 11993  |
| bicycle              | 9478   |
| motorcycle           | 10109  |
| pedestrian           | 185847 |
| traffic_cone         | 82362  |
| barrier              | 125095 |
+----------------------+--------+
08/21 09:51:08 - mmengine - [4m[97mINFO[0m - ------------------------------
08/21 09:51:08 - mmengine - [4m[97mINFO[0m - The length of training dataset: 36114
08/21 09:51:08 - mmengine - [4m[97mINFO[0m - The number of instances per category in the 

In [21]:
batch = next(iteration)
model.chunk_size = 128
model.class_thresh = 0.0 # TODO: Set to 0.5 or other value that works on val set
model.fill_factor_thresh = 0.5 # TODO: Set to 0.5 or other value that works on val set
model.num_refinement_steps = 10
model.ref_lr=1.0e0

# Prepare Input
input_patches = batch[model.image_rgb_key].to(model.device).unsqueeze(0) # torch.Size([B, 3, 256, 256])
assert input_patches.dim() == 4 or (input_patches.dim() == 5 and input_patches.shape[0] == 1), f"Only supporting batch size 1. Input_patches shape: {input_patches.shape}"
if input_patches.dim() == 5:
    input_patches = input_patches.squeeze(0) # torch.Size([B, 3, 256, 256])
input_patches = model._rescale(input_patches) # torch.Size([B, 3, 256, 256])

# Chunked dec_pose[..., :POSE_6D_DIM + LHW_DIM + FILL_FACTOR_DIM]
all_objects = []
all_poses = []
all_patch_indices = []
all_scores = []
all_classes = []
chunk_size = model.chunk_size
with torch.no_grad():
    for i in range(0, len(input_patches), chunk_size):
        selected_patches = input_patches[i:i+chunk_size]
        global_patch_index = i + torch.arange(chunk_size)[:len(selected_patches)]
        selected_patch_indices, z_obj, dec_pose, score, class_idx = model._get_valid_patches(selected_patches, global_patch_index)
        all_objects.append(z_obj)
        all_poses.append(dec_pose)
        all_scores.append(score)
        all_classes.append(class_idx)
        
scores = torch.cat(all_scores)
            
# Inference refinement
all_patch_indices = torch.cat(all_patch_indices)
if not len(all_patch_indices):
    return torch.empty(0)
all_z_objects = torch.cat(all_objects)
all_z_poses = torch.cat(all_poses)
patches_w_objs = input_patches[all_patch_indices]
dec_pose_refined = model._refinement_step(patches_w_objs, all_z_objects, all_z_poses)

selected_patch_indices: tensor([], dtype=torch.int64)


RuntimeError: torch.cat(): expected a non-empty list of Tensors