In [None]:
!pip install nltk gensim numpy ipywidgets
import numpy


# Load data

In [None]:

import os
def load_prompts(path):
    """
    Walk `path` recursively and yield one entry per captioned image.

    An image is any file whose extension (compared case-insensitively) is
    .jpg, .jpeg or .png. Its caption is read from a sibling file with the
    same base name and a ".txt" extension; newlines in the caption are
    replaced by spaces. Images without a caption file are skipped.

    Yields:
        (image_path, caption_path, caption) tuples.
    """
    image_extensions = {".jpg", ".png", ".jpeg"}
    for root, _, files in os.walk(path, topdown=False):
        for file in files:
            filename_without_extension, extension = os.path.splitext(file)
            # lower() so that "photo.JPG" / "photo.Png" are recognised too
            if extension.lower() not in image_extensions:
                continue
            image_path = os.path.join(root, file)
            caption_path = os.path.join(root, filename_without_extension + ".txt")
            try:
                with open(caption_path, 'r') as caption_file:
                    caption = caption_file.read().replace('\n', ' ')
            except FileNotFoundError:
                # image has no caption sidecar file: nothing to yield for it
                continue
            # yield only after the file is closed so the handle is not held
            # open while the consumer processes the entry
            yield image_path, caption_path, caption
                
def reload_data(dataset_path):
    """
    Collect every caption and its image path found under `dataset_path`.

    Returns:
        (prompts, image_filenames): two lists of equal length, where
        prompts[i] is the caption belonging to image_filenames[i].
    """
    entries = [(caption, image_path)
               for image_path, _caption_path, caption in load_prompts(dataset_path)]
    captions = [caption for caption, _ in entries]
    image_paths = [image_path for _, image_path in entries]
    return captions, image_paths
    

# Module-level containers, initialised empty here so the names exist
# before they are assigned from reload_data() further down.
prompts = []
image_filenames = []

#with open('data.txt', 'r') as file:
#    data = file.read().replace('\n', '')

In [None]:
# Directory that is scanned (recursively) for images and caption files.
dataset_path="./images"
prompts, image_filenames = reload_data(dataset_path)

# Tokenise each caption on whitespace. Commas are padded with spaces first
# so that every comma becomes a token of its own instead of sticking to the
# preceding word.
prompts_split_to_words = [p.replace(',', ' , ').split() for p in prompts]




In [None]:
from dataclasses import dataclass
from typing import Union


def sequence_is_in(pattern: Union[tuple,list], sequence: list):
    """
    Report whether `pattern` occurs as a contiguous run inside `sequence`.

    `sequence` must be exactly a list; `pattern` may be a list or a tuple
    (a tuple is converted to a list before comparing). An empty pattern is
    contained in every list.

    Raises:
        ValueError: if `sequence` is not a list, or `pattern` is neither a
            list nor a tuple.
    """
    if type(sequence) is not list:
        raise ValueError(f"sequence {sequence} must be a list")
    if type(pattern) is tuple:
        pattern = list(pattern)
    elif type(pattern) is not list:
        raise ValueError(f"unhandled type {type(pattern).__name__} for pattern")

    window = len(pattern)
    last_start = len(sequence) - window
    # slide a window of the pattern's length over the sequence
    return any(sequence[start:start + window] == pattern
               for start in range(last_start + 1))

@dataclass
class TermGroup: 
    """
    A set of mutually comparable terms (e.g. colours) plus optional synonyms.

    Every term is a tuple of words that must appear consecutively in a
    tokenised prompt for the term to count as present.
    """
    # the terms tracked by this group, each a tuple of consecutive words
    terms: list[tuple[str, ...]]
    # synonym sets: each inner list holds word tuples that are treated as
    # interchangeable; a term found in one of them is matched by any member
    synonyms: list[list[tuple[str, ...]]]
        
    def get_synonyms(self, term: tuple[str, ...]) -> list[tuple[str, ...]]:
        """
        for the given term, return it and all of its synonyms (if any) as a list

        The first synonym set that contains `term` is returned as-is. If no
        set contains it, the result is a one-element list holding `term`.
        """
        # Compare whole terms, not word windows: a synonym set is a list of
        # word tuples, so a sub-sequence search over it can never match.
        wanted = tuple(term)
        for synonym_set in self.synonyms:
            if any(tuple(candidate) == wanted for candidate in synonym_set):
                return synonym_set
        return [term]
        
    
    def prompt_contains_terms(self, prompt: list[str]) -> list[bool]:
        """
        for the given prompt, return an array of bools  of each of our terms

        Entry i is True when terms[i], or any of its synonyms, occurs as a
        contiguous word sequence in `prompt`.
        """
        return [
            any(sequence_is_in(search_sequence, prompt)
                for search_sequence in self.get_synonyms(term))
            for term in self.terms
        ]
    
    def get_term_appearances(self, prompts: list[list[str]]) -> list[list[int]]:
        """
        for each term, return the indices of the prompts that contain it

        The result has one list per entry of `terms`, in the same order;
        each inner list holds indices into `prompts` in ascending order.
        """
        per_term_ids = [[] for _ in self.terms]
        for prompt_index, prompt in enumerate(prompts):
            for term_index, is_present in enumerate(self.prompt_contains_terms(prompt)):
                if is_present:
                    per_term_ids[term_index].append(prompt_index)
        return per_term_ids


    
def count_groups_appearances(prompts: list[list[str]], groups: list[TermGroup]) -> list[tuple[TermGroup, list[list[int]]]]:
    """
    Pair every group with the prompt indices in which each of its terms appears.

    Returns:
        One (group, per_term_ids) tuple per entry of `groups`, where
        per_term_ids is the result of group.get_term_appearances(prompts).
    """
    return [(group, group.get_term_appearances(prompts)) for group in groups]



In [None]:
    
# Term groups whose coverage is inspected below. Each term is a tuple of
# words that has to appear consecutively in a tokenised prompt; no synonyms
# are configured for either group.
groups = [
    TermGroup(terms =[('flying', 'type'), ('psychic', 'type'), ('ground', 'type')], synonyms=[]),
    TermGroup(terms =[('red',), ('yellow',), ('blue',), ('green',)], synonyms=[]),
]

# For every group, print how many prompts contain each of its terms.
appearances = count_groups_appearances(prompts_split_to_words, groups)
print([(group,[len(ids) for ids in prompt_ids]) for group,prompt_ids in appearances])


## Optimize: score a prompt subset and search for a balanced one

In [None]:
import math

def percentage_to_error(percentage: float, base=0.333) -> float:
    """
    Map a fulfilment ratio to an error in [0, 1] along an exponential curve.

    Args:
        percentage: ratio of achieved to desired amount (a fraction, not
            0..100). 0 or less maps to error 1; 1 or more maps to error 0.
        base: shape of the curve between those end points. Must not be 1,
            which would divide by zero.

    Returns:
        The error, clamped to the range 0..1 (lower is better).
    """
    # Ratios outside 0..1 end up clamped to 1 or 0 anyway; clamping before
    # exponentiation avoids math.pow raising OverflowError on large ratios.
    percentage = min(1.0, max(0.0, percentage))
    powed = math.pow(base, percentage)
    # remap 1..base to 1..0
    return max(0, min(1, (powed-base) / (1.0-base)))
    

def get_group_score(group: TermGroup, permutation_prompts: list[list[str]], verbose = False) -> float:
    """
    Score how evenly the group's terms are represented (lower is better).

    Each term's share of all term appearances is compared with the even
    share 1/len(group.terms); the per-term errors are averaged over the
    number of terms. The share denominator is the appearance total plus
    one, so a selection without any appearance does not divide by zero.
    """
    per_term_ids = group.get_term_appearances(permutation_prompts)
    share_denominator = sum(len(ids) for ids in per_term_ids) + 1
    term_count = len(group.terms)
    even_share = 1 / term_count

    accumulated_error = 0
    for position, ids_for_term in enumerate(per_term_ids):
        hits = len(ids_for_term)
        share = hits / share_denominator
        term_error = percentage_to_error(share / even_share, base=10)
        if verbose:
            print(f" - pct coverage badness: {term_error}, count: {round(100 * share)}%: {group.terms[position]} ({hits} prompts)")
        accumulated_error += abs(term_error)

    return accumulated_error / term_count
    
def get_example_count_score(group: TermGroup, 
                            permutation_prompts: list[list[str]], 
                            target_example_count: int=30, 
                            verbose = False) -> float:
    """
    Score whether each term of the group has enough examples (lower is better).

    For every term the number of prompts containing it is divided by
    `target_example_count` and converted to an error; the result is the
    sum of those errors over all terms of the group.
    """
    per_term_ids = group.get_term_appearances(permutation_prompts)
    total_badness = 0
    for position, _term in enumerate(group.terms):
        hits = len(per_term_ids[position])
        term_badness = percentage_to_error(hits/target_example_count, base=5)
        if verbose:
            print(f" - count badness: {term_badness}, count: {hits}")
        total_badness += term_badness

    return total_badness

    
def get_score(groups: list[TermGroup], 
              permutation_prompts: list[list[str]], 
              all_prompts_count: int, 
              verbose: bool = False) -> float:
    """
    Combined badness of a prompt selection (lower is better).

    Three weighted parts are added up:
      * evenness of term coverage within each group (weight 10),
      * whether each term has enough examples (weight 1),
      * what fraction of all prompts the selection uses (weight 1).
    """
    coverage_weight = 10
    example_count_weight = 1
    taken_weight = 1

    # evenness of each group's term distribution
    coverage_badness = sum(
        get_group_score(group, permutation_prompts, verbose=verbose) for group in groups
    )

    # sufficiency of examples per term
    example_count_badness = sum(
        get_example_count_score(group, permutation_prompts, verbose=verbose) for group in groups
    )

    # share of the whole dataset that this selection keeps
    selected_count = len(permutation_prompts)
    taken_badness = percentage_to_error(selected_count/all_prompts_count)
    if verbose:
        print(f"permutation uses {selected_count} prompts out of {all_prompts_count} -> badness {taken_badness}")

    return (coverage_badness * coverage_weight
            + example_count_badness * example_count_weight
            + taken_badness * taken_weight)
    


In [None]:
# Sanity check: score the first 500 tokenised prompts as a candidate
# selection, printing the per-term breakdown along the way.
score = get_score(groups, 
                  permutation_prompts=prompts_split_to_words[:500], 
                  all_prompts_count=len(prompts_split_to_words),
                 verbose = True)
print("score: ", score)


In [None]:
import random
from tqdm.notebook import tqdm
random.seed(101)

def find_best_monte_carlo(groups: list[TermGroup], 
                          prompts_split_to_words: list[list[str]], 
                          num_iterations: int=100,
                          ):
    """
    Randomised search for the subset of prompts with the lowest get_score().

    Starts from a random half of the prompts. Every iteration removes and
    adds a random number of prompt ids; the candidate replaces the current
    selection only if its score is strictly lower. The maximum number of
    ids changed per iteration shrinks as the iterations progress, down to a
    floor of one.

    Args:
        groups: term groups handed to get_score().
        prompts_split_to_words: all tokenised prompts; selections are sets
            of indices into this list.
        num_iterations: number of candidate selections to try.

    Returns:
        The set of selected prompt indices (empty if there are no prompts).
    """
    num_ids = len(prompts_split_to_words)
    if num_ids == 0:
        # nothing to choose from, and get_score() would divide by zero
        return set()

    def score_ids(selected_ids) -> float:
        selected_prompts = [prompts_split_to_words[i] for i in selected_ids]
        return get_score(groups, selected_prompts, num_ids)

    current_selected_ids = set(random.sample(range(num_ids), num_ids//2))
    current_score = score_ids(current_selected_ids)
    
    annealing_base = 0.5
    def get_current_annealing_factor(iteration: int) -> float:
        # decays from 1 at the first iteration towards 0 at the last
        x = iteration / num_iterations
        return (math.pow(annealing_base, x)-annealing_base) / (1-annealing_base)

    # integer interval so the progress check works for any num_iterations
    report_every = max(1, num_iterations // 10)

    for iteration in tqdm(range(num_iterations)):
        # random.sample() needs a sequence (sets are rejected on Python
        # 3.11+); sorted lists also keep runs reproducible for a fixed seed
        selected_pool = sorted(current_selected_ids)
        available_pool = [i for i in range(num_ids) if i not in current_selected_ids]

        max_add_remove_count = max(1,int(0.25 * num_ids * get_current_annealing_factor(iteration)))

        # remove and add a random number of elements; randint is inclusive,
        # so a cap of 1 can still change an element
        permutation = set(current_selected_ids)
        to_remove_count = random.randint(0, min(max_add_remove_count, len(selected_pool)))
        permutation.difference_update(random.sample(selected_pool, to_remove_count))
        to_add_count = random.randint(0, min(max_add_remove_count, len(available_pool)))
        permutation.update(random.sample(available_pool, to_add_count))

        permutation_score = score_ids(permutation)
        if permutation_score < current_score:
            current_score = permutation_score
            current_selected_ids = permutation
        if iteration % report_every == 0:
            print(f"current score: {current_score}")

    return current_selected_ids

# Run the search over all tokenised prompts.
best_ids = find_best_monte_carlo(groups, prompts_split_to_words, num_iterations=1000)
#print(f"best: {best_ids}")

# Re-score the winning selection with verbose=True to print its breakdown.
best_prompts_split_to_words = [prompts_split_to_words[i] for i in best_ids]
score = get_score(groups, best_prompts_split_to_words, len(prompts_split_to_words), True)
    