In [None]:
import os
import torch
import argparse
import json
from pytorch_lightning import Trainer
from train_sample_utils import get_models, get_DDPM

In [None]:
# Declare the options this script understands.
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='config/train.json')
parser.add_argument(
    '-n', '--num_repeat',
    type=int,
    default=1,
    help='the number of images for each condition',
)
parser.add_argument(
    '-cond', '--condition',
    action='store_true',
    help='whether for unconditional sampling or conditional sampling (for some dataset)',
)

# Parse a fixed argument list (a notebook has no real command line), then read
# the JSON experiment config that -c points to.
args_raw = parser.parse_args(['-c', 'configs/coco_ldm_soft_partial_prompt_animal_subset.json'])
with open(args_raw.config, 'r') as IN:
    args = json.load(IN)

# The parsed options are merged last, so they override same-named keys of the JSON config.
args.update(vars(args_raw))
# args['gpu_ids'] = [0] # DEBUG

# Everything this run reads or writes lives under <expt_dir>/<expt_name>.
expt_name, expt_dir = args['expt_name'], args['expt_dir']
expt_path = os.path.join(expt_dir, expt_name)
os.makedirs(expt_path, exist_ok=True)

In [None]:
# 1. Create the denoising model(s) and wrap them in the diffusion (DDPM) model.
# NOTE(review): denoise_args is not read by any cell visible here.
denoise_args = args['denoising_model']['model_args']
models = get_models(args)

diffusion_configs = args['diffusion']
# DEBUG (shorter schedule): diffusion_configs['beta_schedule_args']['n_timestep'] = 10

# The entries of `models` are forwarded as keyword arguments; the result is moved to the GPU.
ddpm_model = get_DDPM(diffusion_configs=diffusion_configs, log_args=args, **models)
ddpm_model = ddpm_model.cuda()

In [None]:
# 2. Create the test dataset and the dataloader that iterates over it.
from test_sample_utils import get_test_callbacks, get_test_dataset

test_dataset, test_loader = get_test_dataset(args)

# 3. Create the test-time callbacks for this experiment directory.
callbacks = get_test_callbacks(args, expt_path)

In [None]:
# Restore the trained weights from <expt_path>/latest.ckpt; sampling cannot run without them.
ckpt_path = os.path.join(expt_path, 'latest.ckpt')
if os.path.exists(ckpt_path):
    print(f'INFO: Found checkpoint {ckpt_path}')
    # Deserialize onto the CPU so loading does not depend on the device the checkpoint
    # was saved from; load_state_dict then copies every tensor onto the model's own device.
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location='cpu')['state_dict']
    # DEBUG: to restore only the denoiser, load the 'denoise_fn.' entries instead:
    # ckpt_denoise_fn = {k.replace('denoise_fn.', ''): v for k, v in ckpt.items() if 'denoise_fn' in k}
    # ddpm_model.denoise_fn.load_state_dict(ckpt_denoise_fn)
    ddpm_model.load_state_dict(ckpt)
else:
    ckpt_path = None
    raise RuntimeError('Cannot do inference without pretrained checkpoint')

In [None]:
import math
import numpy as np
# Class id -> class name for the animal subset used in this experiment.
ANIMAL_ID_MAPPING = {
    16: 'bird',
    17: 'cat',
    18: 'dog',
    19: 'horse',
    20: 'sheep',
    21: 'cow',
    22: 'elephant',
    23: 'bear',
    24: 'zebra',
    25: 'giraffe',
}
class ColorMapping():
    """Give every class id its own point on a regular grid in the unit hypercube.

    The grid point serves as the class color (with the default ``mesh_dim=3`` it
    has three components in [0, 1]).
    """

    def __init__(self, id_class_mapping, mesh_dim=3):
        """Build the grid and assign one grid point per class.

        Args:
            id_class_mapping: dict from class id to class name.
            mesh_dim: number of components of each grid point.
        """
        self.id_class_mapping = id_class_mapping

        # Smallest number of steps per axis whose mesh_dim-th power covers all classes.
        steps_per_axis = math.ceil(len(id_class_mapping) ** (1 / mesh_dim))
        axis_values = np.linspace(0, 1, steps_per_axis)
        grids = np.meshgrid(*([axis_values] * mesh_dim))
        # One row per grid point, one column per axis.
        self.mesh = np.stack([grid.reshape(-1) for grid in grids], axis=-1)

        # Classes take grid points in the iteration order of the mapping.
        self.id_to_mesh_idx = {
            class_id: position
            for position, (class_id, _) in enumerate(id_class_mapping.items())
        }

    def __call__(self, class_id):
        """Return ``(color, class_name)`` for ``class_id``; unknown ids raise KeyError."""
        class_name = self.id_class_mapping[class_id]
        color = self.mesh[self.id_to_mesh_idx[class_id]]
        return color, class_name

In [None]:
import cv2
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
def format_image(x):
    """Turn an image tensor (or a batch of them) into a channels-last numpy array in [0, 1].

    Values are rescaled with ``(x + 1) / 2`` and clamped to [0, 1].  A 3-D input
    has its leading axis moved last.  A 4-D input with a single item is treated
    like that item; a larger batch is first tiled into one image with ``make_grid``.

    Args:
        x: 3-D or 4-D tensor (anything else fails the assertion).

    Returns:
        numpy array with the leading (channel) axis moved to the end.
    """
    x = ((x.cpu() + 1) / 2).clamp(0, 1)
    assert len(x.shape) in [3, 4]
    if len(x.shape) == 4:
        # A batch of one is shown as-is; bigger batches become a single grid image.
        x = x[0] if x.shape[0] == 1 else make_grid(x)
    return x.permute(1, 2, 0).detach().numpy()
def show_image(x):
    """Display ``x`` with matplotlib after converting it through ``format_image``."""
    formatted = format_image(x)
    plt.imshow(formatted)

def plot_bounding_box(image, bboxes, label_mapping=ANIMAL_ID_MAPPING):
    """Draw bounding boxes, plus a class caption where a label is present, onto ``image``.

    Args:
        image: array whose first two axes are height and width; OpenCV draws on it
            directly.  The caption color (1, 1, 1) and the [0, 1] class colors suggest
            a float image in [0, 1] -- TODO confirm against callers.
        bboxes: iterable of boxes (num_obj, 4 or 5).  The first four entries are
            ``x, y, w, h`` as fractions of the image width/height; an optional fifth
            entry is the class label as used by the network (class id minus one).
        label_mapping: dict from class id to class name, used for colors and captions.

    Returns:
        The annotated image.

    Raises:
        KeyError: if a labelled box maps to a class id missing from ``label_mapping``.
    """
    # Boxes without a label get a fixed outline color and no caption.
    unlabeled_color = (1.0, 1.0, 1.0)
    color_mapper = ColorMapping(label_mapping)

    img_h, img_w = image.shape[:2]
    for bbox in bboxes:
        x, y, w, h = bbox[:4]
        # Scale the fractional box to integer pixel coordinates.
        x, y, w, h = (int(v) for v in (x * img_w, y * img_h, w * img_w, h * img_h))

        if len(bbox) == 5:
            # in the network, we let label start from 0 by -1, now we add 1 back
            label = int(bbox[-1]) + 1
            color, class_name = color_mapper(label)
            # Hand OpenCV a plain tuple of Python floats rather than a numpy row.
            color = tuple(float(c) for c in color)
        else:
            # Previously this path looked up the color for label None and raised KeyError.
            label, class_name, color = None, None, unlabeled_color

        # plot the rectangle bounding box and label
        image = cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
        if label is not None:  # explicit None test so that class id 0 still gets a caption
            (text_w, _text_h), _ = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 1)
            # Filled strip behind the caption, as wide as the rendered text.
            image = cv2.rectangle(image, (x, y + 20), (x + text_w, y), color, -1)
            cv2.putText(image, class_name, (x, y + 18), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (1, 1, 1), 1)
    return image

def save_image(image, save_path):
    """Write ``image`` to ``save_path`` with OpenCV, creating parent directories as needed.

    Floating-point images (any float dtype) are treated as values in [0, 1]: they are
    clipped, scaled to 8-bit, and -- when they have a channel axis -- channel-reversed
    (RGB to OpenCV's BGR order).  Images of any other dtype are passed to
    ``cv2.imwrite`` unchanged, so they are expected to be in OpenCV's channel order
    already.

    Args:
        image: numpy array, 2-D (grayscale) or 3-D with channels last.
        save_path: destination file; may be a bare filename in the current directory.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file could not be written.
    """
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        # dirname is '' for a bare filename, and os.makedirs('') raises FileNotFoundError.
        os.makedirs(parent_dir, exist_ok=True)

    if np.issubdtype(image.dtype, np.floating):
        # Clip first so out-of-range values saturate instead of wrapping around in uint8.
        image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
        if image.ndim == 3:
            # Reverse only a real channel axis; on a 2-D image this would mirror the columns.
            image = np.ascontiguousarray(image[..., ::-1])

    # cv2.imwrite signals failure through its return value instead of raising.
    if not cv2.imwrite(save_path, image):
        raise OSError(f'cv2.imwrite failed to write {save_path!r}')

In [None]:
# Sample one image per test batch, overlay that batch's boxes, and save the result
# under <expt_path>/sampling/.
for idx, batch in enumerate(test_loader):
    images = batch[0].cuda()
    context = torch.tensor(batch[1]).cuda()

    # Sampling starts from a single 3x64x64 Gaussian-noise image, conditioned on `context`.
    random_image = torch.randn(1, 3, 64, 64).cuda()
    sampling_kwargs = {'context': context}
    res = ddpm_model.fast_sampling(noise=random_image, model_kwargs=sampling_kwargs)
    y_0_hat = res[0]

    # Draw the first item's boxes on a copy so `image` itself stays clean.
    image = format_image(y_0_hat)
    bboxes = batch[1][0]
    image_annoed = plot_bounding_box(
        image.copy(),
        torch.tensor(bboxes).numpy(),
        ANIMAL_ID_MAPPING,
    )

    save_path = os.path.join(expt_path, 'sampling', f'{idx:04d}.png')
    save_image(image_annoed, save_path=save_path)

In [1]:
from data.coco_detect import get_coco_id_mapping

In [2]:
# Fetch the full class-id -> class-name mapping from the project's data.coco_detect module.
x = get_coco_id_mapping()

In [3]:
# Show the mapping as the cell output.
x

{1: 'person',
 2: 'bicycle',
 3: 'car',
 4: 'motorcycle',
 5: 'airplane',
 6: 'bus',
 7: 'train',
 8: 'truck',
 9: 'boat',
 10: 'traffic light',
 11: 'fire hydrant',
 13: 'stop sign',
 14: 'parking meter',
 15: 'bench',
 16: 'bird',
 17: 'cat',
 18: 'dog',
 19: 'horse',
 20: 'sheep',
 21: 'cow',
 22: 'elephant',
 23: 'bear',
 24: 'zebra',
 25: 'giraffe',
 27: 'backpack',
 28: 'umbrella',
 31: 'handbag',
 32: 'tie',
 33: 'suitcase',
 34: 'frisbee',
 35: 'skis',
 36: 'snowboard',
 37: 'sports ball',
 38: 'kite',
 39: 'baseball bat',
 40: 'baseball glove',
 41: 'skateboard',
 42: 'surfboard',
 43: 'tennis racket',
 44: 'bottle',
 46: 'wine glass',
 47: 'cup',
 48: 'fork',
 49: 'knife',
 50: 'spoon',
 51: 'bowl',
 52: 'banana',
 53: 'apple',
 54: 'sandwich',
 55: 'orange',
 56: 'broccoli',
 57: 'carrot',
 58: 'hot dog',
 59: 'pizza',
 60: 'donut',
 61: 'cake',
 62: 'chair',
 63: 'couch',
 64: 'potted plant',
 65: 'bed',
 67: 'dining table',
 70: 'toilet',
 72: 'tv',
 73: 'laptop',
 74: 'mo

In [None]:
#### visualize prediction results