In [10]:
from PIL import Image
from pathlib import Path
import json
from augclip import AugCLIP
from tqdm import tqdm
import torch
from torch.nn import functional as F

In [16]:
work_dir = Path("/home/server08/yoonjeon_workspace/augclip_data")
edit_path = work_dir / "tedbench/edit_prompt.json"
dataset = "tedbench"
models = ["imagic"]
augclip = AugCLIP("cuda:0", "ViT-B/16")

score_path = work_dir / f"scores/{dataset}/augclip.json"

with open(edit_path, "r") as f:
    cfg = json.load(f)

In [17]:
cs = lambda x, y: torch.einsum("ij, kj -> ik", x, y)

In [83]:
case = cfg[1]

In [84]:
img_path = str(case["img_name"]).replace("ROOT_DIR", str(work_dir))
manip_path = str(case["output_name"]).replace("ROOT_DIR", str(work_dir))
src_image = Image.open(img_path).convert("RGB").resize((512, 512))

src_text, tgt_text = case["prompt"]
src_text_feat, tgt_text_feat = augclip.get_text_embeddings(src_text), augclip.get_text_embeddings(tgt_text)

src_desc, tgt_desc = case["source_desc"], case["target_desc"]
with open("augclip/desc/trial1/")
src_desc_feat, tgt_desc_feat = augclip.get_text_embeddings(src_desc), augclip.get_text_embeddings(tgt_desc)

delta_T = F.normalize(tgt_text_feat - src_text_feat, p=2, dim=-1)


In [85]:
model = "imagic"

In [86]:
manip_model_path = str(manip_path).replace("MODEL_NAME", model)
manip_image = Image.open(manip_model_path).convert("RGB").resize((512, 512))

src_img_feat, tgt_img_feat = augclip.get_image_embeddings([src_image]), augclip.get_image_embeddings([manip_image])

input_arguments = {
    "src_img_feat": src_img_feat,
    "tgt_img_feat": tgt_img_feat,
    "src_text_feat": src_text_feat,
    "tgt_text_feat": tgt_text_feat,
    "src_desc_feat": src_desc_feat,
    "tgt_desc_feat": tgt_desc_feat,
    "src_set_feat": src_desc_feat,
    "tgt_set_feat": tgt_desc_feat,
    "thres": 0 # probability of target + src attr happening
}
change_attr = False # if case=='attr' else False

In [94]:
src_sims = cs(src_desc_feat, src_img_feat)
in_source = (src_sims > 0.).flatten() # check if the source image has the given source property
src_desc_feat = src_desc_feat[in_source]

dist1 = cs(src_desc_feat, src_desc_feat)
mask = ~torch.eye(dist1.shape[0], dtype=bool)
dist1 = dist1[mask].reshape(dist1.shape[0], -1)

dist2 = cs(src_desc_feat, tgt_desc_feat)

w1 = dist1.var(dim=-1)
ordered_weight, idxs = w1.sort(descending=False)
ordered_desc = [src_desc[idx] for idx in idxs]
inv_w = 1 / ordered_weight
inv_w /= inv_w.max()
for d, w in zip(ordered_desc, inv_w):
    
    print(d, round(w, ndigits=2))

TypeError: type Tensor doesn't define __round__ method

In [88]:
dist1 = cs(tgt_desc_feat, src_desc_feat)
print(dist1.shape)

dist2 = cs(tgt_desc_feat, tgt_desc_feat)
mask = ~torch.eye(dist2.shape[0], dtype=bool)
dist2 = dist2[mask].reshape(dist2.shape[0], -1)
print(dist2.shape)

w2 = dist2.var(dim=-1)
idxs = w2.sort(descending=False).indices
ordered_desc = [tgt_desc[idx] for idx in idxs]
print(ordered_desc)

torch.Size([30, 30])
torch.Size([30, 29])
['Muscular', 'Adventurous', 'Graceful', 'Athletic', 'Playful', 'Energetic', 'Dynamic', 'Agile', 'Tail up', 'Free-spirited', 'Alert', 'Healthy', 'Paws off the ground', 'In motion', 'Lively', 'Vibrant', 'High energy', 'Determined', 'Happy', 'Action shot', 'Bouncy', 'Excited expression', 'Frolicking', 'Joyful', 'Acrobatic', 'Airborne', 'Focused', 'Mid-air', 'Leaping', 'Jumping']


In [89]:
print(tgt_text)

A photo of a jumping dog.


## Source Prompt Generation

In [2]:
import os
import sys

import PIL
import argparse
from tqdm import tqdm
from transformers import BlipProcessor, BlipForConditionalGeneration
from openai import OpenAI
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

# FILL API KEY HERE
client = OpenAI()

from utils import *


def generate_descs(caption, source_text, target_text, prompt_option=0, temperature=0., frequency_penalty=0.):
    if prompt_option==0:
        messages=[
            {
                "role": "system", 
                "content": f"Give me a python list of 30 visual characteristics \
                    describes {caption}."
            },
            {
                "role": "user", 
                "content": "Return python list of that 30 text \
                    descriptions that depict a photo of a fluffy brown coated dog"
            },
            {"role": "assistant", "content": 
                        """
                        [
                            "Furry coat",
                            "Four legs",
                            "Tail wagging",
                            "Barking",
                            "Playful behavior",
                            "Snout",
                            "Collar",
                            "Leash",
                            "Walking on all fours",
                            "Wagging tail",
                            ]
                        """ 
            }
            ,
            {
                "role": "user", 
                "content": f"Give me a python list of 30 visual characteristics \
                    describes {caption}."}
        ]
    
    elif prompt_option==1:
        question = f"""
            Given the source text and target text, give me a python list of 30 visual characteristics that a target might have. 
            Note that you should extract the characteristics of only the target compared to the source.
            The visual features must be easy to portray: color, texture, shape, material, objects that are seen together, usage and etc.
            For example, 
            1. Source: "a photo of a dog", Target: "a photo of a cat", answer: ["Slanted almond-shaped eyes", "Soft and fluffy fur coat", "Long and elegant whiskers", "Pointed ears with tufts of fur", "Graceful and agile movements", ...]
            2. Source: "a photo of a horse", Target: "A photo of a breakdancing horse", answer: ["Windmill", "Twisting body", "Fur and ears flowing", ...]

            Source: {source_text}
            Target: {target_text}
            The answer is:
            """
        messages = [{"role": "user", "content": question}]
            
    elif prompt_option==2:
        messages = [
            {
                "role": "system", 
                "content": f"""
                            Provide image level characteristics 
                            such as color, texture, object category information, 
                            context of appearance, background 
                            of a given text that represents the image.
                            """
            },
            {
                "role": "user", 
                "content": f"""
                            Given a sentence, analyze the 
                            visual characteristics of the image where 'A dog is breakdancing',
                            that is not shown in the image of 'A dog is sitting'.
                            For example, focus on the how breakdancing is different from sitting only and do not mention dog's appearance.
                            Rearrange into a python list format.
                            """
            },
            {"role": "assistant", 
            "content": 
                        """
                        [
                            "motion: A dog shows dynamic movement",
                            "motion: A dog is spinning on its head", 
                            "environment: A dog is dancing on dance floor",
                            "environment: A dog is dancing in a hip hop scene",
                            "appearance: A dog shows stylish dancer outfit", 
                            "appearance: A dog has its legs up in the air",
                        ]
                        """ 
            },
            {
                "role": "user", 
                "content": f"""
                Given a sentence, analyze the 
                visual characteristics of the image where {target_text},
                that is not shown in the image of {source_text}.
                Rearrange into a python list format.
                """
            }
        ]
    completion = client.chat.completions.create(
        model='gpt-3.5-turbo', 
        messages=messages,
        temperature=temperature,
        frequency_penalty=frequency_penalty,
    )
    response = completion.choices[0].message.content
    return response

  from .autonotebook import tqdm as notebook_tqdm


CelebA


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


100%|██████████| 50/50 [04:13<00:00,  5.06s/it]


## Check permutation of Source / Target directions

perm(tgt_desc_feat - src_desc_feat)


1. tgt text 와 비교해서 - 이면 bad direction
2. tgt text 와 비교해서 + 이면 good direction
3. src img 와 비교해서 차이가 크면 bad direction
4. src img 와 비교해서 차이가 작으면 good direction 


In [None]:
dataset = "tedbench"
data_dir = "../augclip_data/"
freq_pen = 0.
prompt_option = 2
temp = 0.
device = "cuda:0"
trial = "trial5"
model = "kosmos"

os.makedirs(f"{trial}/", exist_ok=True)

print(dataset)
edit_data= read_json(f"{data_dir}{dataset}/edit_prompt.json")
if dataset == "dreambooth":
    all_images = set([(item['img_name'], (item['prompt'][0], '')) for item in edit_data])
    edit_data = [{'img_name': path[0], 'prompt': path[1], 'output_name': path[0]} for path in all_images]

desc_path = f"{trial}/{dataset}_src.json"
fail_path = f"{trial}/{dataset}_src_failed.json"
descs = read_json_or_dict(desc_path)
failed = read_json_or_dict(fail_path)

if model=="blip":
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
elif model=="kosmos":
    processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
    model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")


prompt_dict = {
    "CelebA": (lambda src_prompt: f"<grounding>An image has facial features like {src_prompt}, "),
    "sameswap": (lambda src_prompt: f"<grounding>{src_prompt} in the image has visual features like "),
    "tedbench": (lambda src_obj: f"<grounding>An image of {src_obj} has visual features like "),
}

descs = {}
for item in tqdm(edit_data):
    output_name = item["output_name"].split('/')[-1]
    if (output_name in descs) or (output_name in failed):
        continue
    src_prompt, tgt_prompt = item["prompt"]
    img_path = item["img_name"].replace('ROOT_DIR/', data_dir)
    image = PIL.Image.open(img_path)
    if dataset=="tedbench":
        inp = item["object"][0]
    else:
        inp = item["prompt"][0]

    prompt = prompt_dict[dataset](inp)
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        image_embeds=None,
        image_embeds_position_mask=inputs["image_embeds_position_mask"],
        use_cache=True,
        max_new_tokens=128,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Specify `cleanup_and_extract=False` in order to see the raw model generation.
    # processed_text1 = processor.post_process_generation(generated_text, cleanup_and_extract=False)
    caption, entities = processor.post_process_generation(generated_text)
    messages = [
            {
                "role": "system", 
                "content": f"""
                            Extract all descriptions on the given entity in the sentence.
                            """
            },
            {
                "role": "user", 
                "content": f"""
                            Given a sentence 
                            'red haired girl in the image has visual features like curly hair, wearing a hat and blue pajamas and holding a cup of coffee, 
                            while a black haired woman is wearing black suit and writing down notes. They are both in a cafe, sitting around a round table'
                            Extract all descriptions on the entity 'red haired girl'.
                            Then given the descsriptions, augment the given descriptions into a sub-categories of visual characteristics.
                            """
            },
            {"role": "assistant", 
            "content": 
                        """
                        descs: [
                            "Red haired girl has curly hair",
                            "Red haired girl is wearing a hat", 
                            "Red haired girl is wearing blue pajamas", 
                            "Red haired girl is holding a cup of coffee",
                            "Red haired girl is in a cafe",
                            "Red haired girl is sitting around a round table"
                        ]
                        """ 
            },
            {
                "role": "user", 
                "content": f"""
                            Given a sentence 
                            '{caption}'
                            Extract all descriptions on the entity '{inp}'.

                            """
            }
        ]
    # print(f"caption: {caption} / source prompt: {src_prompt} / target prompt: {tgt_prompt} / entities: {entities}")
    # descs["caption"] = caption
    completion = client.chat.completions.create(
        model='gpt-3.5-turbo', 
        messages=messages,
        temperature=temp,
        frequency_penalty=freq_pen,
    )
    response = completion.choices[0].message.content
    x = str(response)
    s_idx, e_idx = x.find("["), x.find("]")
    desc_list = x[s_idx+1: e_idx].split(",")
    desc_list = [desc.strip()[1:-1] for desc in desc_list if len(desc.strip()[1:-1])]
    processed_desc_list = [desc.split(":")[-1].strip() for desc in desc_list]
    descs[output_name] = processed_desc_list
    # desc = generate_descs(caption, src_prompt, tgt_prompt, entities, temp, freq_pen)
    # descs[output_name] = eval(desc.strip('.'))
write_json("./desc/trial5/sameswap_src.json", descs)
        