# Imports

In [1]:
import json, requests, time, textwrap, csv
import torch
import numpy as np
import matplotlib.pyplot as plt

from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, EulerDiscreteScheduler
from transformers import BlipProcessor, BlipForConditionalGeneration
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import nltk
nltk.download("popular") # required to download the stopwords lists
from nltk.corpus import stopwords
STOPWORDS = stopwords.words('english')

[nltk_data] Downloading collection 'popular'
[nltk_data]    | 
[nltk_data]    | Downloading package cmudict to /root/nltk_data...
[nltk_data]    |   Package cmudict is already up-to-date!
[nltk_data]    | Downloading package gazetteers to /root/nltk_data...
[nltk_data]    |   Package gazetteers is already up-to-date!
[nltk_data]    | Downloading package genesis to /root/nltk_data...
[nltk_data]    |   Package genesis is already up-to-date!
[nltk_data]    | Downloading package gutenberg to /root/nltk_data...
[nltk_data]    |   Package gutenberg is already up-to-date!
[nltk_data]    | Downloading package inaugural to /root/nltk_data...
[nltk_data]    |   Package inaugural is already up-to-date!
[nltk_data]    | Downloading package movie_reviews to
[nltk_data]    |     /root/nltk_data...
[nltk_data]    |   Package movie_reviews is already up-to-date!
[nltk_data]    | Downloading package names to /root/nltk_data...
[nltk_data]    |   Package names is already up-to-date!
[nltk_data]    | Do

# Models loading

In [3]:
# Stable Difussion XL
#STABLE_DIFF_PIPELINE = StableDiffusionXLPipeline.from_pretrained(
#    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
#)
#STABLE_DIFF_PIPELINE = STABLE_DIFF_PIPELINE.to("cuda")

# Stable Difussion v1
# STABLE_DIFF_PIPELINE = StableDiffusionPipeline.from_pretrained(
#     "CompVis/stable-diffusion-v1-1", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
# )
# STABLE_DIFF_PIPELINE = STABLE_DIFF_PIPELINE.to("cuda")

# Stable Difussion v2
model_id = "stabilityai/stable-diffusion-2"
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
STABLE_DIFF_PIPELINE = StableDiffusionPipeline.from_pretrained(
    model_id, scheduler=scheduler, torch_dtype=torch.float16
)
STABLE_DIFF_PIPELINE = STABLE_DIFF_PIPELINE.to("cuda")

# BLIP
BLIP_PROCCESOR = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", max_length=500)
BLIP_MODEL = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", max_length=500)

# MPNET embedding
MPNET = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

Loading pipeline components...: 100%|██████████| 6/6 [00:09<00:00,  1.64s/it]


# Captioning

In [4]:
# Caption generation using Blip
def generate_caption(img):
    inputs = BLIP_PROCCESOR(img, return_tensors="pt")

    out = BLIP_MODEL.generate(**inputs)
    return BLIP_PROCCESOR.decode(out[0], skip_special_tokens=True)

In [5]:
# Calculate captions for a batch of images
def get_batch_captions(images):
    captions = []
    for i, img in enumerate(tqdm(images, desc="Generating captions")):
        caption = generate_caption(img)
        captions.append(caption)
        
    return captions

# Sentence similarity

In [6]:
# Cosine similarity distance
def cosine_sim(a, b):
    return  np.dot(a, b)/(np.linalg.norm(a)*np.linalg.norm(b))

In [7]:
# Using cosine distance between text embeddings of MPNet
def get_batch_similarities(prompt, captions):
    capt_em = MPNET.encode(captions)
    prompt_em = MPNET.encode(prompt)

    similarities = []

    for cap in tqdm(capt_em, desc="Calculating similarities"):
        similarities.append(cosine_sim(cap, prompt_em))
    
    return similarities

In [8]:
def get_similarity(prompt, caption):
    prompt_em = MPNET.encode(prompt)
    capt_em = MPNET.encode(caption)
    
    return cosine_sim(prompt_em, capt_em)

## Asses which words produce a negative impact in the similarity

In [19]:
def get_neg_pos_words(prompt, captions):
    neg_words = []
    pos_words = []
    for caption in captions:
        original_sim = get_similarity(prompt, caption)
        for word in caption.split():
            # Test each caption without every word combination (if it isnt a stop word)
            if word not in STOPWORDS:
                new_cap = caption.replace(word, '')
                similarity = get_similarity(prompt, new_cap)

                if original_sim-similarity < -(original_sim*0.3):
                    neg_words.append(word)
                elif original_sim-similarity > original_sim*0.3:
                    pos_words.append(word)
                
    return neg_words, pos_words

# Show results

In [20]:
def show_results(images, captions, similarities, prompt=None, path=None):
    fig, ax = plt.subplots(1, len(captions), figsize=(5*len(captions), 6))
    i=0

    for img, caption, similarity in zip(images, captions, similarities):
        ax[i].imshow(img)
        ax[i].set_title(textwrap.fill(caption, 40))
        ax[i].set_xlabel('Similarity ' + str(similarity))
        ax[i].set_xticks([])
        ax[i].set_yticks([])

        i+=1
        
    if len(images) == 3:
        ax[0].set_xlabel('Worst image' + '\nSimilarity ' + str(similarities[0]))
        ax[1].set_xlabel('Best image' + '\nSimilarity ' + str(similarities[1]))
        ax[2].set_xlabel('Base image' + '\nSimilarity ' + str(similarities[2]))
    
    if prompt!=None:
        fig.suptitle('Prompt:' + textwrap.fill(prompt, 100))
        
    if path!=None:
        plt.savefig(path)
        
    plt.show()

# Configuration

In [21]:
IT_GEN_IMGS = 5
n_gens = 50

In [22]:
# Set guidance scale value depending on average fitness.
# It will define a range between 5 and 9 depending on the fitness, with a max
# fitness of 0.6 and a min fitness of 0.2
def get_guidance(fit):
    if fit >= 0.2 and fit <= 0.6:
        new_fit = (fit-0.2)/(0.6-0.2)
        new_fit = 7+new_fit*(13-7)
        return new_fit
    elif fit < 0.2:
        return 7
    else:
        return 13

# Optimize generated images

In [23]:
def optimization(prompt, original_img, n_gens, csv_writer, idx, res_folder):
    # Initial ttention to the caption
    guid_scale = 7

    # Save all generated images
    total_imgs = []
    total_captions = []
    total_similarities = []
    
    negative_words = ''
    positive_words = ''
    negative_words_list = ''
    positive_words_list = ''
    
    # Save evolution of similarity within optimization process
    it_mean_sim = []
    it_max_sim = []
    it_min_sim = []
    
    # Save negative and positive word list for each iteration
    it_neg_words = []
    it_pos_words = []
    
    it_neg_words_num = []
    it_pos_words_num = []
    
    i=0
    
    while i<n_gens:
        images_batch = []
        captions_batch = []
        similarities_batch = []

        print("Guidance", guid_scale)
        # Generating original batch of images
        images_batch = STABLE_DIFF_PIPELINE(prompt=prompt, num_images_per_prompt=IT_GEN_IMGS,
                                            negative_prompt=negative_words_list,
                                            #prompt_2=positive_words_list,
                                            guidance_scale=guid_scale, output_type='pil').images
        captions_batch = get_batch_captions(images_batch)
        similarities_batch = get_batch_similarities(prompt, captions_batch)

        #show_results(images_batch, captions_batch, similarities_batch)

        # Get the similarity value of the iteration
        mean_sim = np.mean(similarities_batch)

        it_mean_sim.append(mean_sim)
        it_max_sim.append(np.max(similarities_batch))
        it_min_sim.append(np.min(similarities_batch))

        # Add the new batch of samples to the total
        total_imgs.extend(images_batch)
        total_captions.extend(captions_batch)
        total_similarities.extend(similarities_batch)

        print('-'*5, 'Mean similarity of iteration', i, '=', mean_sim)

        guid_scale = get_guidance(mean_sim)

        # Get words that produces negative similarity
        neg_words, pos_words = get_neg_pos_words(prompt, captions_batch)

        # Add new words to the string
        neg_words = ' '.join(neg_words)
        pos_words = ' '.join(pos_words)
        negative_words = negative_words + ' ' + neg_words
        positive_words = positive_words + ' ' + pos_words
        # Delete duplicates
        negative_words = ' '.join(list(set(negative_words.split(' '))))
        positive_words = ' '.join(list(set(positive_words.split(' '))))
        # Add commas between words
        negative_words_list = negative_words.replace(" ", ", ")[2:]
        positive_words_list = positive_words.replace(" ", ", ")[2:]
        print('Positive words list:', positive_words_list)
        print('Negative words list:', negative_words_list)
        
        # Save word list and number of words
        it_neg_words.append(negative_words)
        it_pos_words.append(positive_words)
        
        it_neg_words_num.append(len(negative_words.split(' ')) - 1)
        it_pos_words_num.append(len(positive_words.split(' ')) - 1)
        
        i+=1

    # Order samples by their similarity
    zip_list = zip(total_imgs, total_captions, total_similarities, it_neg_words, it_pos_words)
    sort_zip_list = sorted(zip_list, key=lambda x: x[2], reverse=True)
    sort_total_imgs, sort_total_captions, sort_total_similarities, sort_neg_words, sort_pos_words = zip(*sort_zip_list)

    # Show all results ordered
    #show_results(sort_total_imgs, sort_total_captions, sort_total_similarities)

    # Get best result
    best_img = sort_total_imgs[0]
    best_caption = sort_total_captions[0]
    best_similarity = sort_total_similarities[0]
    best_neg_words = sort_neg_words[0]
    best_pos_words = sort_pos_words[0]

    worst_img = sort_total_imgs[-1]
    worst_caption = sort_total_captions[-1]
    worst_similarity = sort_total_similarities[-1]
    worst_neg_words = sort_neg_words[-1]
    worst_pos_words = sort_pos_words[-1]

    # Get caption and similarity of original image
    original_caption = generate_caption(original_img)
    original_similarity = get_similarity(prompt, original_caption)

    show_results([worst_img, best_img, original_img],
                 [worst_caption, best_caption, original_caption],
                 [worst_similarity, best_similarity, original_similarity],
                 prompt = prompt,
                 path=res_folder + "{:03d}".format(idx) + '_images.png')

    best_img.save(res_folder + "{:03d}".format(idx) + '_best.png')
    worst_img.save(res_folder + "{:03d}".format(idx) + '_worst.png')
    original_img.save(res_folder + "{:03d}".format(idx) + '_base.png')

    # Show evolution of similarity
    fig, ax = plt.subplots(figsize=(5,5))
    plt.suptitle('Similarity evolution with the prompt')
    
    color = 'tab:blue'
    ax.plot(it_mean_sim, color=color)
    ax.fill_between(np.arange(len(it_mean_sim)), 
                     np.array(it_min_sim),
                     np.array(it_max_sim),
                     color=color, alpha=0.2)
    # Original similarity
    ax.axhline(y=original_similarity, color='k', linestyle='dashed', linewidth=0.7)

    ax.set_xlabel('Iteration')
    ax.set_ylabel('Similarity', color=color)
    
    ax2 = ax.twinx()
    color = 'tab:red'
    ax2.set_ylabel('Neg words number', color=color)
    ax2.plot(it_neg_words_num, color=color)
    ax2.tick_params(axis='y', labelcolor=color)
    fig.tight_layout()
    
    plt.savefig(res_folder + "{:03d}".format(idx) + '_similarities.png')
    plt.show()

    # Write in the .csv
    csv_writer.writerow({'id' : "{:03d}".format(idx),
                         'prompt': prompt,
                         'neg_words': negative_words_list,
                         'pos_words': positive_words_list,
                         'best_caption': best_caption, 'best_sim': best_similarity,
                         'best_neg_words': best_neg_words, 'best_pos_words': best_pos_words,
                         'worst_caption': worst_caption, 'worst_sim': worst_similarity,
                         'worst_neg_words': worst_neg_words, 'worst_pos_words': worst_pos_words,
                         'base_caption': original_caption, 'base_sim': original_similarity,
                         'sim_hist': str(it_mean_sim), 'min_sim_hist': str(it_min_sim), 'max_sim_hist': str(it_max_sim),
                         'neg_words_hist': str(it_neg_words), 'pos_words_hist': str(it_pos_words),
                         'neg_words_num_hist': str(it_neg_words_num), 'pos_words_num_hist': str(it_pos_words_num)})

# Download DiffusionDB dataset

In [24]:
from datasets import load_dataset

# Load the dataset with the `large_random_1k` subset
dataset = load_dataset('poloclub/diffusiondb', 'large_random_1k')
data = dataset['train']

# Delete duplicates
prompts_list = list(set(data['prompt']))

# Delete prompts with commas, 'and' o less than 5 words and more than 25
prompts = []
for prompt in prompts_list:
    if len(prompt.split()) >= 5 and len(prompt.split()) <= 25:
        if ',' not in prompt and 'and' not in prompt:
            prompts.append(prompt)

In [25]:
prompts

['a cell shaded cartoon still of walter white ',
 '( steve jobs ) gold king tut mask ',
 'steve jobs crossing the alps painting by jacques louis david. ',
 'tom waits as the goblin king by brian froud ',
 'honey label by adolphe mucha ',
 'turn of the century sepia photo of a man waiting at the train station while using an ipad ',
 'astral projection by gustave dore ',
 'huge glitter bomb explosion above city ',
 'a capybara fighting an alien ',
 'turin in 2 0 7 0 ',
 'the most amazing smore you have ever seen ',
 'steve jobs breaks the tablets of the law by gustave dore. ',
 'funko pop of hillary clinton ',
 'book cover for the epic fantasy book titled echos of dragons ',
 'painting by bridget bate tichenor ',
 'funko pop of danny devito ',
 '3 d octane render of a glowing yellow orb with white clear wings flying ']

In [None]:
# Open results .csv
res_folder = 'results/'
with open(res_folder + 'csvfile.csv', 'a+', newline ='') as csv_file:
    header = ['id', 'prompt', 'neg_words', 'pos_words',
              'best_caption', 'best_sim', 'best_neg_words', 'best_pos_words',
              'base_caption', 'base_sim',
              'worst_caption', 'worst_sim', 'worst_neg_words', 'worst_pos_words',
              'sim_hist', 'min_sim_hist', 'max_sim_hist',
              'neg_words_hist', 'pos_words_hist',
              'neg_words_num_hist', 'pos_words_num_hist']
    csv_writer = csv.DictWriter(csv_file, fieldnames = header, delimiter=';')
    csv_writer.writeheader()

    for i, prompt in enumerate(prompts):
        print('Prompt:', prompt)
        # Get the original image
        image = data['image'][data['prompt'].index(prompt)]
        optimization(prompt, image, n_gens, csv_writer, i, res_folder)