In [None]:
import argparse
from PIL import Image, ImageDraw
from omegaconf import OmegaConf
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.VO import VOSampler
import os 
from transformers import CLIPProcessor, CLIPModel
from copy import deepcopy
import torch 
from ldm.util import instantiate_from_config
from trainer import batch_to_device
from inpaint_mask_func import draw_masks_from_boxes
import numpy as np
import clip 
from scipy.io import loadmat
from functools import partial
import torchvision.transforms.functional as F
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from box_utils import save_img, Pharse2idx_2, process_box_phrase, format_box, draw_box_2
import torchvision.transforms as transforms
from pytorch_lightning import seed_everything
from PIL import Image, ImageDraw, ImageFont
from urllib.request import urlopen
# Device used by load_ckpt, prepare_batch and the sampling cell. Note that
# get_clip_feature hard-codes .cuda() instead of reading this value.
device = "cuda"

def set_alpha_scale(model, alpha_scale):
    """Assign ``alpha_scale`` to ``.scale`` on every gated attention layer of *model*.

    Only modules whose exact type is GatedCrossAttentionDense or
    GatedSelfAttentionDense are updated; subclasses are not matched.
    """
    from ldm.modules.attention import GatedCrossAttentionDense, GatedSelfAttentionDense

    gated_types = (GatedCrossAttentionDense, GatedSelfAttentionDense)
    gated_layers = (layer for layer in model.modules() if type(layer) in gated_types)
    for layer in gated_layers:
        layer.scale = alpha_scale

def alpha_generator(length, type=None):
    """Build the per-step alpha (gating) schedule handed to the sampler.

    The schedule is three consecutive stages whose relative lengths come from
    ``type = [keep, decay, off]`` (non-negative fractions summing to 1):

    * ``keep``  - alpha is 1,
    * ``decay`` - alpha goes (n-1)/n, ..., 1/n, 0 for a stage of n steps,
    * ``off``   - alpha is 0; this stage absorbs the int() truncation remainder.

    Args:
        length: total number of entries (one per sampling step).
        type: the three stage fractions; defaults to ``[1, 0, 0]``. The name
            shadows the builtin but is kept because callers pass it by keyword
            (``partial(alpha_generator, type=...)``).

    Returns:
        A list of exactly ``length`` alpha values.

    Raises:
        ValueError: if ``type`` does not have three entries, contains a
            negative fraction, or does not sum to 1 within float tolerance.
    """
    import math

    if type is None:
        type = [1, 0, 0]

    if len(type) != 3:
        raise ValueError(f"alpha type must have 3 entries, got {len(type)}")
    if any(frac < 0 for frac in type):
        raise ValueError(f"alpha type fractions must be non-negative, got {list(type)}")
    # Tolerant comparison: e.g. 0.6 + 0.3 + 0.1 evaluates to 0.9999999999999999.
    if not math.isclose(sum(type), 1.0, rel_tol=1e-9, abs_tol=1e-9):
        raise ValueError(f"alpha type fractions must sum to 1, got {list(type)}")

    stage0_length = int(type[0] * length)
    stage1_length = int(type[1] * length)
    stage2_length = length - stage0_length - stage1_length

    # linspace(endpoint=False) always yields exactly stage1_length points
    # 0, 1/n, ..., (n-1)/n (and [] for n == 0); np.arange with a float step can
    # emit one extra point through rounding and break the length invariant.
    decay_alphas = list(np.linspace(0, 1, num=stage1_length, endpoint=False)[::-1])

    alphas = [1] * stage0_length + decay_alphas + [0] * stage2_length

    assert len(alphas) == length  # internal invariant
    return alphas

def load_ckpt(ckpt_path):
    """Rebuild the model components from a single checkpoint file.

    Per the keys read below, the checkpoint holds ``config_dict["_content"]``
    (instantiation config for each component) plus one state dict each for
    ``model``, ``autoencoder``, ``text_encoder`` and ``diffusion``.

    Args:
        ckpt_path: path passed to ``torch.load``.

    Returns:
        ``(model, autoencoder, text_encoder, diffusion, config)``. Every module
        is moved to ``device``; autoencoder and text_encoder are put in eval mode.
    """
    # Deserialize onto the CPU: load_state_dict copies the values into the
    # parameters of the modules (already on `device`), so materializing the raw
    # tensors on whatever GPU the checkpoint was saved from is unnecessary. That
    # GPU may not exist on this machine, and loading there would briefly hold a
    # second copy of all weights in VRAM.
    saved_ckpt = torch.load(ckpt_path, map_location="cpu")
    config = saved_ckpt["config_dict"]["_content"]

    model = instantiate_from_config(config['model']).to(device)
    autoencoder = instantiate_from_config(config['autoencoder']).to(device).eval()
    text_encoder = instantiate_from_config(config['text_encoder']).to(device).eval()
    diffusion = instantiate_from_config(config['diffusion']).to(device)

    # No separate official checkpoint is loaded into `model`: all of its
    # weights come from this checkpoint.
    model.load_state_dict(saved_ckpt['model'])
    autoencoder.load_state_dict(saved_ckpt["autoencoder"])
    text_encoder.load_state_dict(saved_ckpt["text_encoder"])
    diffusion.load_state_dict(saved_ckpt["diffusion"])

    return model, autoencoder, text_encoder, diffusion, config

def project(x, projection_matrix):
    """Return ``x @ projection_matrix^T`` (first two dims of the matrix swapped)."""
    swapped = projection_matrix.transpose(0, 1)
    return torch.matmul(x, swapped)

def get_clip_feature(model, processor, input, is_image=False):
    """Encode one grounding phrase or one grounding image with CLIP.

    Args:
        model: a CLIP model on the GPU. It is always called with both text and
            pixel inputs, so the unused modality is given a placeholder tensor.
        processor: the matching CLIP processor.
        input: the phrase (text mode) or an image path (image mode); ``None``
            means there is nothing to encode.
        is_image: select image mode instead of text mode.

    Returns:
        ``None`` if ``input`` is ``None``. Text mode: the text tower's
        ``pooler_output``. Image mode: ``image_embeds`` projected with the
        matrix loaded from the file ``projection_matrix``, rescaled to norm
        28.7, with a leading dim of size 1.
    """
    which_layer_text = 'before'
    which_layer_image = 'after_reproject'

    # Both modes share the same "nothing to encode" short-circuit.
    if input is None:
        return None

    if is_image:
        rgb = Image.open(input).convert("RGB")
        batch = processor(images=[rgb],  return_tensors="pt", padding=True)
        batch['pixel_values'] = batch['pixel_values'].cuda()  # original note: "we use our own preprocessing without center_crop"
        batch['input_ids'] = torch.tensor([[0,1,2,3]]).cuda()  # placeholder text input
        feature = model(**batch).image_embeds
        if which_layer_image == 'after_reproject':
            proj = torch.load('projection_matrix').cuda().T
            feature = project(feature, proj).squeeze(0)
            feature = (feature / feature.norm()) * 28.7
            feature = feature.unsqueeze(0)
        return feature

    batch = processor(text=input,  return_tensors="pt", padding=True)
    batch['input_ids'] = batch['input_ids'].cuda()
    batch['pixel_values'] = torch.ones(1,3,224,224).cuda()  # placeholder image input
    batch['attention_mask'] = batch['attention_mask'].cuda()
    outputs = model(**batch)
    if which_layer_text == 'before':
        feature = outputs.text_model_output.pooler_output
    return feature

def complete_mask(has_mask, max_objs):
    """Expand a per-object mask spec into a ``(1, max_objs)`` float tensor.

    Args:
        has_mask: ``None`` -> all ones; a real scalar (int, float, bool or a
            NumPy scalar) -> every slot set to that value; otherwise an
            iterable whose values overwrite the leading slots (the remaining
            slots stay 1).
        max_objs: number of object slots.

    Returns:
        A float tensor of shape ``(1, max_objs)``.
    """
    import numbers

    mask = torch.ones(1, max_objs)
    if has_mask is None:
        return mask

    # numbers.Real covers int, float, bool and NumPy scalars. The previous
    # `type(x) == int or type(x) == float` test routed e.g. True or
    # np.float64(0.5) to the iterable branch, which then raised TypeError.
    if isinstance(has_mask, numbers.Real):
        return mask * has_mask

    for idx, value in enumerate(has_mask):
        mask[0, idx] = value
    return mask

@torch.no_grad()
def prepare_batch(meta, batch=1, max_objs=30):
    """Build the grounding inputs (boxes, masks, CLIP embeddings) for one layout.

    Each object slot up to ``max_objs`` gets its box, a presence mask, and
    optional 768-wide CLIP text/image embeddings with their own masks.
    Everything is then repeated ``batch`` times along a new leading dim and
    moved to ``device``.

    ``meta`` keys read: ``phrases`` and ``images`` (if one is missing/None it
    is padded with None to the other's length), ``locations`` (one box per
    object), and the optional ``text_mask`` / ``image_mask`` specs consumed by
    ``complete_mask``.
    """
    phrases = meta.get("phrases")
    images = meta.get("images")
    if images is None:
        images = [None] * len(phrases)
    if phrases is None:
        phrases = [None] * len(images)

    version = "openai/clip-vit-large-patch14"
    clip_model = CLIPModel.from_pretrained(version).cuda()
    clip_processor = CLIPProcessor.from_pretrained(version)

    boxes = torch.zeros(max_objs, 4)
    masks = torch.zeros(max_objs)
    text_masks = torch.zeros(max_objs)
    image_masks = torch.zeros(max_objs)
    text_embeddings = torch.zeros(max_objs, 768)
    image_embeddings = torch.zeros(max_objs, 768)

    # (text feature, image feature) per object; either may be None.
    features = [
        (
            get_clip_feature(clip_model, clip_processor, phrase, is_image=False),
            get_clip_feature(clip_model, clip_processor, image, is_image=True),
        )
        for phrase, image in zip(phrases, images)
    ]

    for slot, (box, (text_feature, image_feature)) in enumerate(zip(meta['locations'], features)):
        boxes[slot] = torch.tensor(box)
        masks[slot] = 1
        if text_feature is not None:
            text_embeddings[slot] = text_feature
            text_masks[slot] = 1
        if image_feature is not None:
            image_embeddings[slot] = image_feature
            image_masks[slot] = 1

    def _tile(per_layout):
        # Prepend a batch dim and repeat the tensor `batch` times along it.
        return per_layout.unsqueeze(0).repeat(batch, *([1] * per_layout.dim()))

    out = {
        "boxes": _tile(boxes),
        "masks": _tile(masks),
        "text_masks": _tile(text_masks) * complete_mask(meta.get("text_mask"), max_objs),
        "image_masks": _tile(image_masks) * complete_mask(meta.get("image_mask"), max_objs),
        "text_embeddings": _tile(text_embeddings),
        "image_embeddings": _tile(image_embeddings),
    }
    return batch_to_device(out, device)

def crop_and_resize(image):
    """Center-crop *image* to a square on its shorter side, then resize it to 512x512."""
    side = min(image.size)
    square = TF.center_crop(image, side)
    return square.resize((512, 512))


def run(meta, models, info_files, p, starting_noise=None, iter_id=0, img_id=0, save=True, count=-1):
    """Sample one batch of images for a single layout and optionally save them.

    Args:
        meta: layout description. Keys read here: ``prompt``, ``alpha_type``,
            ``ll``, ``position``, and (when saving) ``save_folder_name``,
            ``location_draw`` and ``phrases``; ``prepare_batch`` reads more.
        models: ``(model, autoencoder, text_encoder, diffusion, config)`` as
            returned by ``load_ckpt``. ``config`` is not modified.
        info_files: dict updated in place with
            ``{image_name: (phrases, boxes)}`` for every saved sample.
        p: not read by this function; settings come from the module-level
            ``args``. NOTE(review): the notebook passes ``args`` here - consider
            reading it instead of the global.
        starting_noise: initial latent handed to the sampler as ``x``.
        iter_id, img_id, count: forwarded unchanged to ``save_img``.
        save: if true, write the plain and box-annotated images under
            ``args.folder``.

    Returns:
        The decoded samples (autoencoder output, not clamped).
    """
    model, autoencoder, text_encoder, diffusion, config = models

    grounding_tokenizer_input = instantiate_from_config(config['grounding_tokenizer_input'])
    model.grounding_tokenizer_input = grounding_tokenizer_input

    grounding_downsampler_input = None
    if "grounding_downsampler_input" in config:
        grounding_downsampler_input = instantiate_from_config(config['grounding_downsampler_input'])

    # - - - - - update config from args - - - - - #
    # Merge into a new mapping: config.update(...) would mutate the config held
    # in the caller's shared `models` tuple on every call.
    config = OmegaConf.create({**config, **vars(args)})

    # - - - - - prepare batch - - - - - #
    batch = prepare_batch(meta, config.batch_size)
    context = text_encoder.encode([meta["prompt"]] * config.batch_size)
    # `uc` is the context passed to the sampler alongside `guidance_scale`:
    # the negative prompt if one is set, otherwise the empty prompt. Encoded
    # once (the empty prompt used to be encoded and then overwritten).
    if args.negative_prompt is not None:
        with torch.no_grad():
            uc = text_encoder.encode(config.batch_size * [args.negative_prompt])
    else:
        uc = text_encoder.encode(config.batch_size * [""])

    # - - - - - sampler - - - - - #
    alpha_generator_func = partial(alpha_generator, type=meta.get("alpha_type"))
    sampler = VOSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale)
    steps = 50
    inpainting_mask = z0 = None

    grounding_input = grounding_tokenizer_input.prepare(batch)
    grounding_extra_input = None
    if grounding_downsampler_input is not None:
        grounding_extra_input = grounding_downsampler_input.prepare(batch)

    sampler_input = dict(
        x=starting_noise,
        timesteps=None,
        context=context,
        grounding_input=grounding_input,
        inpainting_extra_input=None,
        grounding_extra_input=grounding_extra_input,
        boxes=meta['ll'],
        object_position=meta['position'],
    )

    # - - - - - start sampling - - - - - #
    shape = (config.batch_size, model.in_channels, model.image_size, model.image_size)
    samples_fake, img_list, x0_list = sampler.sample(
        S=steps, shape=shape, input=sampler_input, uc=uc,
        guidance_scale=config.guidance_scale, mask=inpainting_mask, x0=z0, loss_type=None,
    )
    # The per-step latents (img_list / x0_list) are intentionally not decoded.
    # The previous loop decoded all of them (2 * steps autoencoder passes) and
    # then discarded the resulting images; decode them explicitly here if
    # intermediate frames are ever needed.
    with torch.no_grad():
        samples_fake = autoencoder.decode(samples_fake)

    # - - - - - save images - - - - - #
    if save:
        output_folder1 = os.path.join(args.folder, meta["save_folder_name"] + '_img')
        os.makedirs(output_folder1, exist_ok=True)
        output_folder2 = os.path.join(args.folder, meta["save_folder_name"] + '_box')
        os.makedirs(output_folder2, exist_ok=True)
        start = len(os.listdir(output_folder2))
        image_ids = list(range(start, start + config.batch_size))
        print(image_ids)
        try:
            font = ImageFont.truetype("Roboto-LightItalic.ttf", size=20)
        except OSError:
            # Missing font file: label with Pillow's built-in font rather than
            # aborting after the (expensive) sampling has already finished.
            print("Roboto-LightItalic.ttf not found; using the default font")
            font = ImageFont.load_default()
        for image_id, sample in zip(image_ids, samples_fake):
            img_name = meta['prompt'].replace(' ', '_') + str(int(image_id)) + '.png'
            # Map [-1, 1] decoder output to an 8-bit HWC image.
            sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
            sample = sample.cpu().numpy().transpose(1, 2, 0) * 255
            sample = Image.fromarray(sample.astype(np.uint8))
            img2 = sample.copy()  # un-annotated copy
            draw = ImageDraw.Draw(sample)
            boxes = meta['location_draw']
            text = meta["phrases"]

            info_files.update({img_name: (text, boxes)})
            for i, box in enumerate(boxes):
                draw.rectangle([(box[0], box[1]), (box[2], box[3])], outline=128, width=2)
                draw.text((box[0] + 5, box[1] + 5), text[i], fill=200, font=font)
            # NOTE(review): every sample in the batch is passed the same
            # (iter_id, img_id, count); confirm save_img does not overwrite
            # files when batch_size > 1.
            save_img(output_folder2, sample, meta['prompt'], iter_id, img_id, count)
            save_img(output_folder1, img2, meta['prompt'], iter_id, img_id, count)

    return samples_fake


In [None]:

# Notebook stand-in for a command line: the options are declared here and the
# parser is fed an empty argument list, so every option keeps its default.
parser = argparse.ArgumentParser()
parser.add_argument("--folder", default="visual", type=str)
parser.add_argument("--ckpt", default='gligen_checkpoints/diffusion_pytorch_model.bin', type=str)
parser.add_argument("--batch_size", default=1, type=int)
parser.add_argument("--guidance_scale", default=7.5, type=float)
parser.add_argument(
    "--negative_prompt",
    default='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality',
    type=str,
)
parser.add_argument("--file_save", default='result', type=str)
args = parser.parse_args([])


In [3]:
# Run configuration. Values are also written onto `args` because `run` merges
# vars(args) into the model config (e.g. config.batch_size).
ckpt = args.ckpt = 'gligen_checkpoints/diffusion_pytorch_model.bin'
file_save = args.file_save = 'result'
# Was `args.batchsize` (typo): nothing reads that attribute, so `run` kept the
# default args.batch_size while the starting noise in the sampling cell used
# `batch_size` - changing this value desynchronized the two.
batch_size = args.batch_size = 1

meta_list = [
    # - - - - - - - - GLIGEN on text grounding for generation - - - - - - - - #
    # prompt / phrases / locations / ll are placeholders; the sampling loop
    # fills them in for each layout before calling `run`.
    dict(
        ckpt=ckpt,
        prompt=None,
        phrases=None,
        locations=None,
        # [keep, decay, off] fractions of the sampling steps for alpha_generator
        alpha_type=[0.3, 0.0, 0.7],
        save_folder_name=file_save,
        ll=None,
    )
]

models = load_ckpt(meta_list[0]["ckpt"])

making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.6.layer_norm2.weight', 'vision_model.encoder.layers.10.mlp.fc2.bias', 'vision_model.encoder.layers.20.self_attn.q_proj.weight', 'vision_model.encoder.layers.9.self_attn.k_proj.weight', 'vision_model.encoder.layers.19.self_attn.k_proj.bias', 'vision_model.encoder.layers.1.mlp.fc2.bias', 'vision_model.encoder.layers.0.self_attn.out_proj.weight', 'vision_model.encoder.layers.0.self_attn.k_proj.weight', 'vision_model.encoder.layers.18.self_attn.k_proj.bias', 'vision_model.encoder.layers.13.layer_norm1.weight', 'vision_model.encoder.layers.1.self_attn.out_proj.bias', 'vision_model.encoder.layers.8.layer_norm1.weight', 'vision_model.encoder.layers.7.self_attn.q_proj.weight', 'vision_model.encoder.layers.23.self_attn.out_proj.weight', 'vision_model.encoder.layers.6.mlp.fc1.weight', 'vision_model.encoder.layers.10.mlp.fc1.weight', 'vision_model.enco

In [4]:
# Layout to render: a caption, the grounded object names, and one
# (x0, y0, x1, y1) box per name in 512x512 pixel coordinates.
caption = 'A car and a bike in front of a house.'
names_list = ['house', 'car', 'bike']
layout = [(66, 197, 452, 390), (326, 358, 402, 432), (111, 347, 216, 431)]

info_files = {}

for meta in meta_list:
    pp = text = caption
    o_names, o_boxes = names_list, layout
    meta["prompt"] = pp

    for k in range(1):
        # (batch, 4, 64, 64) latent noise from the default (CPU) generator,
        # then moved to `device`.
        starting_noise = torch.randn(batch_size, 4, 64, 64).to(device)

        p, ll = format_box(o_names, o_boxes)
        l = np.array(o_boxes)
        name_box = process_box_phrase(o_names, o_boxes)
        position, box_att = Pharse2idx_2(pp, name_box)
        print('position', position, pp)

        meta.update({
            "phrases": p,
            "location_draw": l,      # pixel boxes, used for drawing
            "locations": l / 512,    # boxes normalized by 512 for prepare_batch
            "ll": box_att,
            "position": position,
        })
        run(meta, models, info_files, args, starting_noise, k, 1)

position [[10], [2], [5]] A car and a bike in front of a house.
step  0
optimize 0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


step  1
optimize 1
step  2
optimize 2
step  3
optimize 3
step  4
optimize 4
step  5
optimize 5
step  6
optimize 6
step  7
optimize 7
step  8
optimize 8
step  9
optimize 9
step  10
optimize 10
step  11
optimize 11
step  12
optimize 12
step  13
optimize 13
step  14
optimize 14
step  15
optimize 15
step  16
optimize 16
step  17
optimize 17
step  18
optimize 18
step  19
optimize 19
step  20
optimize 20
step  21
optimize 21
step  22
optimize 22
step  23
optimize 23
step  24
optimize 24
step  25
optimize 25
step  26
optimize 26
step  27
optimize 27
step  28
optimize 28
step  29
optimize 29
step  30
optimize 30
step  31
optimize 31
step  32
optimize 32
step  33
optimize 33
step  34
optimize 34
step  35
optimize 35
step  36
optimize 36
step  37
optimize 37
step  38
optimize 38
step  39
optimize 39
step  40
optimize 40
step  41
optimize 41
step  42
optimize 42
step  43
optimize 43
step  44
optimize 44
step  45
optimize 45
step  46
optimize 46
step  47
optimize 47
step  48
optimize 48
step  49
o