In [18]:
!pip install nltk gensim numpy




# declare some helper classes

In [80]:
from dataclasses import dataclass
from typing import Union


def sequence_is_in(pattern: Union[tuple,list], sequence: list):
    """
    return True if `pattern` occurs as a contiguous run of elements inside `sequence`.

    pattern: a tuple or list of elements to look for (a tuple is compared as a list)
    sequence: the list to search in; must be exactly a list
    raises ValueError if sequence is not a list, or if pattern is neither a tuple nor a list
    """
    if type(sequence) is not list:
        raise ValueError(f"sequence {sequence} must be a list")
    if type(pattern) is tuple:
        # list slices only compare equal to lists, so normalise the pattern
        pattern = list(pattern)
    elif type(pattern) is not list:
        raise ValueError(f"unhandled type {type(pattern).__name__} for pattern")

    width = len(pattern)
    last_start = len(sequence) - width
    # slide a window of the pattern's width over the sequence
    return any(sequence[start:start+width] == pattern for start in range(last_start + 1))

@dataclass
class TermGroup:
    """
    a group of terms (word sequences) whose appearances in prompts are counted together.

    terms: the word sequences to look for, each one a tuple of words
    synonyms: sets of interchangeable terms; a term that is a member of one of these sets
        counts as present in a prompt when any member of that set is present
    """
    terms: list[tuple[str]]
    synonyms: list[list[tuple[str]]]

    def get_synonyms(self, term: tuple[str]) -> list[tuple[str]]:
        """
        for the given term, return it and all of its synonyms (if any) as a list
        """
        # each synonym set is a list of whole terms, so test membership of the term itself;
        # a word-level subsequence search can never match inside a list of tuples
        for synonym_set in self.synonyms:
            if term in synonym_set:
                return synonym_set
        return [term]

    def prompt_contains_terms(self, prompt: list[str]) -> list[bool]:
        """
        for the given prompt (a list of words), return one bool per term, in the order of
        self.terms, telling whether that term or one of its synonyms appears in the prompt
        """
        return [
            any(sequence_is_in(search_sequence, prompt) for search_sequence in self.get_synonyms(term))
            for term in self.terms
        ]

    def get_term_appearances(self, prompts: list[list[str]]) -> list[list[int]]:
        """
        for each term (in the order of self.terms), return the list of indices into
        `prompts` of the prompts that contain that term or one of its synonyms
        """
        per_term_ids = [[] for _ in self.terms]
        for prompt_index, prompt in enumerate(prompts):
            for term_index, is_present in enumerate(self.prompt_contains_terms(prompt)):
                if is_present:
                    per_term_ids[term_index].append(prompt_index)
        return per_term_ids


    
def count_groups_appearances(prompts: list[list[str]], groups: list["TermGroup"]) -> list[tuple["TermGroup", list[list[int]]]]:
    """
    for every group, find which prompts each of its terms appears in.

    prompts: the prompts, each already split into a list of words
    groups: the term groups to look for
    returns one (group, per_term_prompt_ids) tuple per group, in the order of `groups`,
        where per_term_prompt_ids is whatever group.get_term_appearances(prompts) returns
    """
    return [(group, group.get_term_appearances(prompts)) for group in groups]



# load data

In [3]:

import os
def load_prompts_from_txt_files(root_folder):
    """
    walk root_folder and yield (image_path, caption_path, caption) for every image file
    that has a caption .txt file with the same base name next to it.

    image files are recognised by extension (.jpg, .png, .jpeg, compared case-insensitively).
    images without a caption file are skipped. newlines in the caption become spaces.
    """
    image_extensions = (".jpg", ".png", ".jpeg")
    for root, _, files in os.walk(root_folder, topdown=False):
        for file in files:
            filename_without_extension, extension = os.path.splitext(file)
            # lower-case before comparing so that .JPG / .PNG files are not skipped
            if extension.lower() not in image_extensions:
                continue
            image_path = os.path.join(root, file)
            caption_path = os.path.join(root, filename_without_extension + ".txt")
            try:
                with open(caption_path, 'r') as caption_file:
                    caption = caption_file.read().replace('\n', ' ')
            except FileNotFoundError:
                continue
            # yield only after the file is closed, so no handle stays open while the
            # consumer holds the generator suspended
            yield image_path, caption_path, caption
                

# start with an empty prompt list; `prompts` is reassigned by the loading cells below
prompts = []
# not populated anywhere in the visible cells — TODO confirm this is still needed
image_filenames = []

# leftover alternative loader (single text file), kept commented out
#with open('data.txt', 'r') as file:
#    data = file.read().replace('\n', '')

# load prompts

## from .txt files saved alongside image files


In [6]:
root_folder = "./images"
# every yielded entry is (image_path, caption_path, caption); only the caption is kept
prompts = [entry[2] for entry in load_prompts_from_txt_files(root_folder)]
# surround commas with spaces so that each comma becomes its own token
prompts_split_to_words = [caption.replace(',', ' , ').split() for caption in prompts]

## from a huggingface dataset


In [22]:
from datasets import load_dataset
# dataset repository id passed to load_dataset
repo_id = "m1guelpf/nouns"
dataset = load_dataset(repo_id)
# the captions are read from the 'text' column of the 'train' split
prompts = dataset['train']['text']
# surround commas with spaces so that each comma becomes its own token
prompts_split_to_words = [caption.replace(',', ' , ').split() for caption in prompts]


Using custom data configuration m1guelpf--nouns-cc6819088b485316
Found cached dataset parquet (/Users/damian/.cache/huggingface/datasets/m1guelpf___parquet/m1guelpf--nouns-cc6819088b485316/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/1 [00:00<?, ?it/s]

# what did we load?

In [15]:
prompts[:5]

['a pixel art character with square black glasses, a hotdog-shaped head and a peachy-colored body on a warm background',
 'a pixel art character with square black sunglasses, a shower-shaped head and a blue-colored body on a cool background',
 'a pixel art character with square dark green glasses, a yeti-shaped head and a redpinkish-colored body on a cool background',
 'a pixel art character with square dark gray glasses, a void-shaped head and a teal-colored body on a warm background',
 'a pixel art character with square dark green glasses, a rainbow-shaped head and a orange-colored body on a cool background']

# n-gram analysis..

In [75]:
from nltk import ngrams, FreqDist
import itertools

def get_ngram_freqdist(prompts_split_to_words: list[list[str]], n:int=3, filter_by_lambda=None) -> FreqDist:
    """
    count how often each n-gram occurs across all prompts.

    prompts_split_to_words: the prompts, each already split into a list of words
    n: the n-gram length
    filter_by_lambda: optional predicate taking an n-gram tuple; when given, only n-grams
        for which it returns a truthy value are counted
    returns a FreqDist mapping n-gram tuple -> number of occurrences
    """
    accept = (lambda _ngram: True) if filter_by_lambda is None else filter_by_lambda
    counts = FreqDist()
    for words in prompts_split_to_words:
        for gram in filter(accept, ngrams(words, n)):
            counts[gram] += 1
    return counts

def get_most_common(fdist: "FreqDist", count:int=10):
    """
    return the `count` most frequent keys of fdist as a list, most frequent first.

    the ordering is requested explicitly through most_common() rather than relying on the
    iteration order of the container, so this also works for a plain collections.Counter.
    """
    return [sample for sample, _ in fdist.most_common(count)]


## check out what common phrases are in here

In [71]:
freqdist = get_ngram_freqdist(prompts_split_to_words, n=3)
# (ngram, count) pairs for the 50 most frequent trigrams, most frequent first
freqdist.most_common(50)

None


[(('a', 'pixel', 'art'), 49859),
 (('pixel', 'art', 'character'), 49859),
 (('art', 'character', 'with'), 49859),
 (('character', 'with', 'square'), 49859),
 (('head', 'and', 'a'), 49859),
 (('body', 'on', 'a'), 49859),
 (('glasses', ',', 'a'), 43219),
 (('on', 'a', 'warm'), 25100),
 (('a', 'warm', 'background'), 25100),
 (('on', 'a', 'cool'), 24759),
 (('a', 'cool', 'background'), 24759),
 (('green', 'glasses', ','), 8648),
 (('orange', 'glasses', ','), 8484),
 (('and', 'a', 'grayscale-colored'), 6690),
 (('a', 'grayscale-colored', 'body'), 6690),
 (('grayscale-colored', 'body', 'on'), 6690),
 (('with', 'square', 'black'), 6671),
 (('red', 'glasses', ','), 6665),
 (('with', 'square', 'light'), 6623),
 (('with', 'square', 'dark'), 6565),
 (('blue', 'glasses', ','), 6479),
 (('with', 'square', 'orange'), 6360),
 (('square', 'orange', 'glasses'), 6360),
 (('square', 'black', 'sunglasses'), 4467),
 (('black', 'sunglasses', ','), 4467),
 (('sunglasses', ',', 'a'), 4467),
 (('square', 'ligh

so it looks like `warm background` and `cool background` are a good training pair, as are `blue`/`pink`/`yellow`/`green` `glasses` and `body`.

let's try `background`:

In [76]:
# only count trigrams that mention 'background'
freqdist = get_ngram_freqdist(prompts_split_to_words, n=3, filter_by_lambda=lambda ngram: 'background' in ngram)
# (ngram, count) pairs, most frequent first
freqdist.most_common(50)

[(('a', 'warm', 'background'), 25100), (('a', 'cool', 'background'), 24759)]

those are pretty evenly balanced. ok, how about `glasses`:

In [78]:
# only count bigrams that mention 'glasses'
freqdist = get_ngram_freqdist(prompts_split_to_words, n=2, filter_by_lambda=lambda ngram: 'glasses' in ngram)
# (ngram, count) pairs, most frequent first
freqdist.most_common(50)

[(('glasses', ','), 43219),
 (('green', 'glasses'), 8648),
 (('orange', 'glasses'), 8484),
 (('red', 'glasses'), 6665),
 (('blue', 'glasses'), 6479),
 (('gray', 'glasses'), 4355),
 (('black', 'glasses'), 2204),
 (('square', 'glasses'), 2173),
 (('glasses', 'with'), 2173),
 (('brown', 'glasses'), 2156),
 (('yellow', 'glasses'), 2118),
 (('pink', 'glasses'), 2110)]

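the glasses colours are far from balanced (8648 `green glasses` vs 2110 `pink glasses`), so there is room for optimization here. what about `body`: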

In [85]:
# only count bigrams that mention 'body'
freqdist = get_ngram_freqdist(prompts_split_to_words, n=2, filter_by_lambda=lambda ngram: 'body' in ngram)
# (ngram, count) pairs, most frequent first
freqdist.most_common(50)

[(('body', 'on'), 49859),
 (('grayscale-colored', 'body'), 6690),
 (('peachy-colored', 'body'), 3339),
 (('orange-colored', 'body'), 3326),
 (('bege-colored', 'body'), 3317),
 (('teal-colored', 'body'), 3288),
 (('purple-colored', 'body'), 1741),
 (('redpinkish-colored', 'body'), 1716),
 (('hotbrown-colored', 'body'), 1705),
 (('computerblue-colored', 'body'), 1704),
 (('red-colored', 'body'), 1696),
 (('gold-colored', 'body'), 1680),
 (('yellow-colored', 'body'), 1678),
 (('magenta-colored', 'body'), 1678),
 (('slimegreen-colored', 'body'), 1668),
 (('darkbrown-colored', 'body'), 1660),
 (('cold-colored', 'body'), 1653),
 (('gunk-colored', 'body'), 1644),
 (('foggrey-colored', 'body'), 1637),
 (('green-colored', 'body'), 1635),
 (('blue-colored', 'body'), 1626),
 (('rust-colored', 'body'), 1620),
 (('bluegrey-colored', 'body'), 1619),
 (('darkpink-colored', 'body'), 1539)]

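ok, the body colours are unbalanced too (6690 `grayscale-colored` vs roughly 1600-1700 for most of the others), so there is some more room for optimization here. let's define some TermGroups: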

In [97]:
# one TermGroup per attribute we want evenly represented; none of them uses synonyms

# the two background variants seen in the trigram counts
background_term_group = TermGroup(terms=[('warm', 'background'), ('cool', 'background')], synonyms=[])
# the "<colour> glasses" bigrams seen in the 'glasses' counts
glasses_color_term_group = TermGroup(terms=[('black', 'glasses'), ('green', 'glasses'), ('gray', 'glasses'), 
                                    ('blue', 'glasses'), ('orange', 'glasses'), ('red', 'glasses'), 
                                    ('brown', 'glasses'), ('pink', 'glasses'), 
                                    ('yellow', 'glasses')], synonyms=[])
# the "<colour>-colored body" bigrams seen in the 'body' counts
body_color_term_group = TermGroup(terms=[('peachy-colored', 'body'), ('blue-colored', 'body'),
                                         ('redpinkish-colored', 'body'), ('teal-colored', 'body'), 
                                         ('orange-colored', 'body'), ('gold-colored', 'body'), 
                                         ('purple-colored', 'body'), ('darkbrown-colored', 'body'), 
                                         ('cold-colored', 'body'), ('green-colored', 'body'), 
                                         ('grayscale-colored', 'body'), ('gunk-colored', 'body'), 
                                         ('yellow-colored', 'body'), ('rust-colored', 'body'),
                                         ('red-colored', 'body'), ('darkpink-colored', 'body'), 
                                         ('computerblue-colored', 'body'), ('bege-colored', 'body'),
                                         ('slimegreen-colored', 'body'), ('bluegrey-colored', 'body'), 
                                         ('magenta-colored', 'body'), ('hotbrown-colored', 'body'), 
                                         ('foggrey-colored', 'body')], synonyms=[])

# the list of groups passed to the counting and scoring functions
groups = [background_term_group, glasses_color_term_group, body_color_term_group]

appearances = count_groups_appearances(prompts_split_to_words, groups)
# report, per group, how many prompts each term appears in
for group, prompt_ids in appearances:
    print(f"group {group} terms appear:")
    for term, ids_for_term in zip(group.terms, prompt_ids):
        print(f" - {term}: {len(ids_for_term)} times")
        
#print([(group,[len(ids) for ids in prompt_ids]) for group,prompt_ids in appearances])


group TermGroup(terms=[('warm', 'background'), ('cool', 'background')], synonyms=[]) terms appear:
 - ('warm', 'background'): 25100 times
 - ('cool', 'background'): 24759 times
group TermGroup(terms=[('black', 'glasses'), ('green', 'glasses'), ('gray', 'glasses'), ('blue', 'glasses'), ('orange', 'glasses'), ('red', 'glasses'), ('brown', 'glasses'), ('pink', 'glasses'), ('yellow', 'glasses')], synonyms=[]) terms appear:
 - ('black', 'glasses'): 2204 times
 - ('green', 'glasses'): 8648 times
 - ('gray', 'glasses'): 4355 times
 - ('blue', 'glasses'): 6479 times
 - ('orange', 'glasses'): 8484 times
 - ('red', 'glasses'): 6665 times
 - ('brown', 'glasses'): 2156 times
 - ('pink', 'glasses'): 2110 times
 - ('yellow', 'glasses'): 2118 times
group TermGroup(terms=[('peachy-colored', 'body'), ('blue-colored', 'body'), ('redpinkish-colored', 'body'), ('teal-colored', 'body'), ('orange-colored', 'body'), ('gold-colored', 'body'), ('purple-colored', 'body'), ('darkbrown-colored', 'body'), ('cold-c

## optimize

In [138]:
import math

def percentage_to_error(percentage: float, base=0.333) -> float:
    """
    map a fraction onto an error value along an exponential curve.

    computes base**percentage and rescales it linearly so that a result of 1
    (percentage 0) becomes 1.0 and a result of `base` (percentage 1) becomes 0.0.
    an OverflowError from the power is reported on stdout and re-raised.
    """
    try:
        scaled = math.pow(base, percentage)
    except OverflowError:
        print(f"OverflowError doing pow on {base} {percentage}")
        raise
    # remap the range 1..base onto 1..0
    distance_from_base = scaled - base
    full_range = 1.0 - base
    return distance_from_base / full_range
    

def get_group_score(group: TermGroup, permutation_prompts: list[list[str]], verbose = False) -> float:
    """
    score how evenly the terms of `group` are represented in the given prompts
    (lower is better).

    each term's share of the group's total appearances is compared with an equal share
    (1 / number of terms); the per-term errors are averaged over the terms.
    with verbose=True one line per term is printed.
    """
    appearances = group.get_term_appearances(permutation_prompts)
    total_count = sum(len(ids) for ids in appearances)

    term_count = len(group.terms)
    target_percentage_per_term = 1/term_count

    badness = 0
    for term_index, term_ids in enumerate(appearances):
        # the +1 keeps the division defined when no term appears at all
        percentage = len(term_ids) / (total_count+1)
        error = abs(percentage_to_error(percentage / target_percentage_per_term, base=10))
        if verbose:
            print(f" - pct coverage badness: {error}, count: {round(100 * percentage)}%: {group.terms[term_index]} ({len(term_ids)} prompts)")
        # error is already non-negative here
        badness += error

    return badness/term_count
    
def get_example_count_score(group: TermGroup, 
                            permutation_prompts: list[list[str]], 
                            target_example_count: int=30, 
                            verbose = False) -> float:
    """
    score whether every term of `group` has enough example prompts (lower is better).

    a term's contribution is an error computed from example_count / target_example_count,
    capped at 1; the contributions of all terms are summed.
    with verbose=True one line per term is printed.
    """
    appearances = group.get_term_appearances(permutation_prompts)
    total_badness = 0
    for term_index, _term in enumerate(group.terms):
        example_count = len(appearances[term_index])
        # having more examples than the target earns no extra credit
        fill_ratio = min(1,example_count/target_example_count)
        term_badness = percentage_to_error(fill_ratio, base=5)
        if verbose:
            print(f" - count badness: {term_badness}, count: {example_count}")
        total_badness += term_badness

    return total_badness

    
def get_score(groups: list[TermGroup], 
              permutation_prompts: list[list[str]], 
              all_prompts_count: int, 
              verbose: bool = False) -> float:
    """
    overall badness of a selection of prompts (lower is better): a weighted sum of

    - coverage: how evenly each group's terms are spread (weight 10)
    - completeness: whether each term has enough examples (weight 1)
    - size: what fraction of all prompts the selection uses (weight 1)

    all_prompts_count is the number of prompts the selection was drawn from.
    with verbose=True the individual components are printed.
    """
    if verbose:
        print("coverage:")
    # each group gets a score based on how evenly spread it is
    coverage_badness = sum(get_group_score(group, permutation_prompts, verbose=verbose) for group in groups)

    if verbose:
        print("completeness:")
    # each term gets a score based on how many examples it includes
    example_count_badness = sum(get_example_count_score(group, permutation_prompts, verbose=verbose) for group in groups)

    # the selection as a whole gets a score based on the fraction of all prompts it includes
    taken_badness = percentage_to_error(len(permutation_prompts)/all_prompts_count)
    if verbose:
        print(f"permutation uses {len(permutation_prompts)} prompts out of {all_prompts_count} -> badness {taken_badness}")

    return coverage_badness * 10 + example_count_badness * 1 + taken_badness*1
    


In [139]:
# imported here because this cell runs before the cell that imports random further down
import random

# score a random sample of (at most) 500 prompts to see what the components look like;
# the min() keeps random.sample from raising on datasets with fewer than 500 prompts
score = get_score(groups,
                  permutation_prompts=random.sample(prompts_split_to_words,
                                                    min(500, len(prompts_split_to_words))),
                  all_prompts_count=len(prompts_split_to_words),
                  verbose=True)
print("score: ", score)


coverage:
 - pct coverage badness: 0.09029441789566287, count: 52%: ('warm', 'background') (259 prompts)
 - pct coverage badness: 0.09291053398431713, count: 48%: ('cool', 'background') (241 prompts)
 - pct coverage badness: 0.7942055178746232, count: 5%: ('black', 'glasses') (22 prompts)
 - pct coverage badness: 3.2426133392567826, count: 18%: ('green', 'glasses') (77 prompts)
 - pct coverage badness: 0.035855020654458634, count: 11%: ('gray', 'glasses') (49 prompts)
 - pct coverage badness: 1.12351961713274, count: 14%: ('blue', 'glasses') (63 prompts)
 - pct coverage badness: 3.911504062622492, count: 18%: ('orange', 'glasses') (80 prompts)
 - pct coverage badness: 3.4550433032676966, count: 18%: ('red', 'glasses') (78 prompts)
 - pct coverage badness: 0.8613748367533154, count: 4%: ('brown', 'glasses') (17 prompts)
 - pct coverage badness: 0.8230062447328148, count: 5%: ('pink', 'glasses') (20 prompts)
 - pct coverage badness: 0.689348734913187, count: 6%: ('yellow', 'glasses') (28

In [144]:
import random
from tqdm.notebook import tqdm
random.seed(101)

def find_best_monte_carlo(groups: list[TermGroup], 
                          prompts_split_to_words: list[list[str]], 
                          num_iterations: int=100,
                          annealing_base: float=0.1
                          ):
    """
    random search for a subset of prompts with a low get_score().

    starts from the empty selection. each iteration removes and adds a random number of
    prompt ids and keeps the result only when it scores strictly better than the current
    selection. the maximum number of ids changed per iteration starts at a quarter of all
    prompts and shrinks exponentially (controlled by annealing_base) down to 1.

    groups: the term groups the score is computed for
    prompts_split_to_words: all available prompts, each split into a list of words
    num_iterations: how many candidate selections to try
    annealing_base: base of the exponential decay of the step size
    returns the set of indices into prompts_split_to_words of the best selection found
    """
    num_ids = len(prompts_split_to_words)
    if num_ids == 0:
        # nothing to select from (and get_score would divide by zero)
        return set()

    def get_available_ids(selected_ids: set[int]) -> set[int]:
        return set(range(num_ids)) - selected_ids

    def get_prompts_split_to_words_for_permutation(selected_ids: set[int]) -> list[list[str]]:
        return [prompts_split_to_words[i] for i in selected_ids]

    def get_score_for_permutation_ids(permutation: set[int]) -> float:
        return get_score(groups, get_prompts_split_to_words_for_permutation(permutation), num_ids)

    def get_current_annealing_factor(iteration):
        # decays from 1 at iteration 0 towards 0 at the last iteration
        x = iteration / num_iterations
        return (math.pow(annealing_base, x)-annealing_base) / (1-annealing_base)

    current_selected_ids = set()
    current_score = get_score_for_permutation_ids(current_selected_ids)

    for iteration in tqdm(range(num_iterations)):
        # remove and add a random number of elements
        current_available_ids = get_available_ids(current_selected_ids)
        permutation = current_selected_ids.copy()

        max_add_remove_count = max(1,int(0.25 * num_ids * get_current_annealing_factor(iteration)))

        # randint is inclusive at both ends, so a limit of 1 can still change one id
        # (randrange(1) would always return 0 and freeze the search).
        # random.sample needs a sequence: sets are rejected from Python 3.11 on, and
        # sorting keeps the draw reproducible for a given seed.
        if permutation:
            to_remove_count = random.randint(0, min(max_add_remove_count, len(permutation)))
            permutation.difference_update(random.sample(sorted(permutation), to_remove_count))
        if current_available_ids:
            to_add_count = random.randint(0, min(max_add_remove_count, len(current_available_ids)))
            permutation.update(random.sample(sorted(current_available_ids), to_add_count))

        permutation_score = get_score_for_permutation_ids(permutation)
        if permutation_score < current_score:
            current_score = permutation_score
            current_selected_ids = permutation
        if iteration % (num_iterations/10) == 0:
            print(f"current score: {current_score}")

    return current_selected_ids


In [145]:
best_ids = find_best_monte_carlo(groups, prompts_split_to_words, num_iterations=10, annealing_base=0.01)

# score the winning selection verbosely to see the per-term breakdown
best_prompts_split_to_words = [prompts_split_to_words[prompt_id] for prompt_id in best_ids]
score = get_score(groups,
                  permutation_prompts=best_prompts_split_to_words,
                  all_prompts_count=len(prompts_split_to_words),
                  verbose=True)

  0%|          | 0/10 [00:00<?, ?it/s]

current score: 65.0
current score: 65.0
current score: 65.0
current score: 65.0
current score: 65.0
current score: 65.0
current score: 65.0
current score: 65.0
current score: 65.0
current score: 65.0
coverage:
 - pct coverage badness: 1.0, count: 0%: ('warm', 'background') (0 prompts)
 - pct coverage badness: 1.0, count: 0%: ('cool', 'background') (0 prompts)
 - pct coverage badness: 1.0, count: 0%: ('black', 'glasses') (0 prompts)
 - pct coverage badness: 1.0, count: 0%: ('green', 'glasses') (0 prompts)
 - pct coverage badness: 1.0, count: 0%: ('gray', 'glasses') (0 prompts)
 - pct coverage badness: 1.0, count: 0%: ('blue', 'glasses') (0 prompts)
 - pct coverage badness: 1.0, count: 0%: ('orange', 'glasses') (0 prompts)
 - pct coverage badness: 1.0, count: 0%: ('red', 'glasses') (0 prompts)
 - pct coverage badness: 1.0, count: 0%: ('brown', 'glasses') (0 prompts)
 - pct coverage badness: 1.0, count: 0%: ('pink', 'glasses') (0 prompts)
 - pct coverage badness: 1.0, count: 0%: ('yellow'

In [140]:
# preview only the first few tokenised prompts instead of dumping the whole list
prompts_split_to_words[:5]

[['a',
  'pixel',
  'art',
  'character',
  'with',
  'square',
  'black',
  'glasses',
  ',',
  'a',
  'hotdog-shaped',
  'head',
  'and',
  'a',
  'peachy-colored',
  'body',
  'on',
  'a',
  'warm',
  'background'],
 ['a',
  'pixel',
  'art',
  'character',
  'with',
  'square',
  'black',
  'sunglasses',
  ',',
  'a',
  'shower-shaped',
  'head',
  'and',
  'a',
  'blue-colored',
  'body',
  'on',
  'a',
  'cool',
  'background'],
 ['a',
  'pixel',
  'art',
  'character',
  'with',
  'square',
  'dark',
  'green',
  'glasses',
  ',',
  'a',
  'yeti-shaped',
  'head',
  'and',
  'a',
  'redpinkish-colored',
  'body',
  'on',
  'a',
  'cool',
  'background'],
 ['a',
  'pixel',
  'art',
  'character',
  'with',
  'square',
  'dark',
  'gray',
  'glasses',
  ',',
  'a',
  'void-shaped',
  'head',
  'and',
  'a',
  'teal-colored',
  'body',
  'on',
  'a',
  'warm',
  'background'],
 ['a',
  'pixel',
  'art',
  'character',
  'with',
  'square',
  'dark',
  'green',
  'glasses',
  ',',
 