In [None]:
!pip install nltk gensim numpy ipywidgets


# load data

In [None]:



# Root folder that load_prompts() walks to find images and their caption files.
dataset_path = "./images"

import os
def load_prompts(path):
    """Yield ``(image_path, caption_path, caption)`` for each captioned image under ``path``.

    The directory tree is walked bottom-up. A file counts as an image when its
    extension is .jpg, .png or .jpeg (compared case-insensitively). Its caption
    is read from the ``.txt`` file with the same stem in the same folder, with
    every newline replaced by a space. Images without a caption file are skipped.
    """
    image_extensions = {".jpg", ".png", ".jpeg"}
    for root, _, files in os.walk(path, topdown=False):
        for file in files:
            filename_without_extension, extension = os.path.splitext(file)
            # Lower-case first so that e.g. "photo.JPG" is not silently ignored.
            if extension.lower() not in image_extensions:
                continue
            image_path = os.path.join(root, file)
            caption_path = os.path.join(root, filename_without_extension + ".txt")
            try:
                with open(caption_path, 'r') as caption_file:
                    caption = caption_file.read().replace('\n', ' ')
            except FileNotFoundError:
                continue
            # Yield only after the file is closed, so no handle stays open
            # while the consumer is processing the item.
            yield image_path, caption_path, caption

# Parallel lists filled by reload_data(): entry i of each belongs to the same image.
prompts, image_filenames = [], []

#with open('data.txt', 'r') as file:
#    data = file.read().replace('\n', '')

In [None]:
%%html
<style>
/* Use a sans-serif face in every textarea and input element. */
textarea, input {
    font-family: Arial, Helvetica, sans-serif;
}
/* NOTE(review): the value below is the quoted string '2em', not the length 2em.
   If this variable is consumed as a font-size, a quoted string is not a valid
   value there - confirm whether the quotes are intended. */
:root {
    --jp-ui-font-size1: '2em';
}
</style>


# build vocab

In [None]:
from nltk.corpus import words
from nltk import Text
import itertools
#import nltk
#nltk.download('words')

# Corpus state shared by the cells below; reload_data() repopulates all of it.
#   words_nested - one token list per prompt
#   words        - the same tokens flattened (this rebinds the `words` name
#                  imported from nltk.corpus above)
#   word_ids     - per prompt, the nltk_text index of each token
#   nltk_text    - nltk Text built from `words`
words_nested, words, word_ids = [], [], []
nltk_text = None

def rebuild_words_and_words_nested():
    """Re-tokenise every prompt into the globals ``words_nested`` and ``words``.

    Commas are padded with spaces so each becomes its own token, then the
    prompt is split on whitespace. ``words_nested`` holds one token list per
    prompt; ``words`` is the same tokens flattened in prompt order.
    """
    global words_nested, words
    tokenised = []
    for prompt in prompts:
        tokenised.append(prompt.replace(',', ' , ').split())
    words_nested = tokenised
    words = [token for sentence in tokenised for token in sentence]
    

def words_to_word_ids(sentence: list):
    """Map each word of ``sentence`` to the position ``nltk_text.index`` reports for it."""
    ids = []
    for word in sentence:
        ids.append(nltk_text.index(word))
    return ids

def word_ids_to_words(word_ids: list[int]):
    """Look up each id in ``nltk_text.tokens`` and return the tokens in the same order."""
    tokens = []
    for wid in word_ids:
        tokens.append(nltk_text.tokens[wid])
    return tokens

def reload_data(dataset_path):
    """Reload all captions under ``dataset_path`` and rebuild the derived globals.

    Resets ``prompts`` and ``image_filenames``, refills them from
    load_prompts(), re-tokenises via rebuild_words_and_words_nested(), then
    rebuilds ``nltk_text`` from the flat word list and ``word_ids`` from the
    per-prompt token lists.
    """
    global prompts, image_filenames, nltk_text, word_ids
    prompts, image_filenames = [], []
    for image_path, _caption_path, caption in load_prompts(dataset_path):
        prompts.append(caption)
        image_filenames.append(image_path)
    rebuild_words_and_words_nested()
    nltk_text = Text(words)
    word_ids = [words_to_word_ids(tokens) for tokens in words_nested]
    
# Initial load of the captions and the vocabulary derived from them.
reload_data(dataset_path=dataset_path)
    



In [None]:

# Registry of caption callbacks: image id -> list of callables.
caption_callbacks = {}

def register_caption_callback(image_id, callback):
    """Append ``callback`` to the list of callbacks registered for ``image_id``.

    ``setdefault`` stores the list in the registry on first use; a plain
    ``get(image_id, [])`` would append to a temporary list and lose the callback.
    Nothing in the visible notebook invokes the registered callbacks yet.
    """
    caption_callbacks.setdefault(image_id, []).append(callback)

def set_caption(image_id, caption):
    """Store ``caption`` as the prompt of image ``image_id`` and re-tokenise all prompts.

    Only ``words`` and ``words_nested`` are refreshed (through
    rebuild_words_and_words_nested); ``nltk_text`` and ``word_ids`` are not
    touched here.
    """
    prompts[image_id] = caption
    # Keep the token lists in step with the edited prompt.
    rebuild_words_and_words_nested()
    

# build ngrams

In [None]:
from collections.abc import Iterator

from nltk import ngrams, FreqDist

def get_sentence_ngrams(sentence: list, n: int=3) -> list[tuple]:
    """Return every run of ``n`` consecutive items of ``sentence``, in order.

    Each n-gram is a tuple as produced by ``nltk.ngrams``; the items are
    whatever ``sentence`` holds (word strings or word ids).
    """
    return list(ngrams(sentence, n))

def get_ngram_freq_dist_sorted_by_count(ngrams):
    """Count how often each n-gram occurs in ``ngrams`` and return the FreqDist.

    Despite the name, no sorting happens here; callers order the result
    themselves (``most_common`` / ``sorted``).
    """
    # iter() makes the constructor count element by element, exactly as a
    # plain loop over ``ngrams`` would, whatever container type is passed.
    return FreqDist(iter(ngrams))

def get_all_ngrams_per_sentence(ngram_length: int) -> list[list[tuple]]:
    """Return, for every prompt in ``words_nested``, the list of its n-grams.

    The outer list is parallel to ``words_nested``; each inner list holds
    the n-gram tuples of that prompt.
    """
    return [get_sentence_ngrams(sentence, ngram_length) for sentence in words_nested]

def get_all_ngrams_flat(ngram_length: int) -> list[tuple]:
    """Return the n-grams of all prompts as one flat list, in prompt order."""
    per_sentence = get_all_ngrams_per_sentence(ngram_length)
    return list(itertools.chain.from_iterable(per_sentence))
    
def count_ngrams_containing_word(ngram_length: int, search_word: str, max_results: int = 50) -> Iterator[tuple[int, tuple]]:
    """Yield ``(count, ngram)`` for the most frequent n-grams that contain ``search_word``.

    ``search_word`` must equal one of the n-gram's tokens (tuple membership,
    not substring matching). At most ``max_results`` pairs are yielded, most
    frequent first. This is a generator, so iterate it or wrap it in ``list()``.
    """
    matching = [ngram for ngram in get_all_ngrams_flat(ngram_length) if search_word in ngram]
    fdist = get_ngram_freq_dist_sorted_by_count(matching)
    for ngram, count in fdist.most_common(max_results):
        yield (count, ngram)
        
def find_prompts_containing_ngram(ngram_length: int, ngram_words: list[str]) -> list[int]:
    """Return the indices of the prompts whose n-grams include ``ngram_words``.

    ``ngram_words`` may be a list or a tuple. It is converted to a tuple
    before comparison because the n-grams are tuples and a list never
    compares equal to a tuple. ``ngram_length`` must equal
    ``len(ngram_words)``, otherwise nothing can match.
    """
    target = tuple(ngram_words)
    return [
        prompt_index
        for prompt_index, sentence in enumerate(words_nested)
        if target in get_sentence_ngrams(sentence, ngram_length)
    ]


def count_ngrams(ngram_length: int):
    """Return the frequency distribution of all n-grams of ``ngram_length`` over every prompt."""
    return get_ngram_freq_dist_sorted_by_count(get_all_ngrams_flat(ngram_length))

    

In [None]:
# Reload the captions from disk and count every trigram across all prompts.
reload_data(dataset_path)
all_ngrams_fdist = count_ngrams(3)

# Ascending by count; sorted() is stable, so ties keep the FreqDist's own order.
least_to_most = sorted(all_ngrams_fdist.items(), key=lambda item: item[1])

print("most common:")
print("\n".join(str(entry) for entry in reversed(least_to_most[-10:])))
print("least common:")
print("\n".join(str(entry) for entry in least_to_most[:10]))

In [None]:
def find_sentences_simple(search_word: str) -> list[int]:
    """Return the indices of the prompts in ``words_nested`` that contain ``search_word``.

    The match is on whole tokens (list membership), not substrings.
    """
    return [
        index
        for index, sentence_words in enumerate(words_nested)
        if search_word in sentence_words
    ]

def find_sentences_complex(with_words: list[str], without_words: list[str]=[]) -> set[int]:
    """Return the prompt indices that match ``with_words`` but none of ``without_words``.

    A prompt is a positive match when it contains at least one of
    ``with_words``; it is then dropped if it contains any of ``without_words``.
    NOTE(review): "any of with_words" (union) mirrors the original intent of
    pooling all matches into one set - confirm that "all of" was not meant.

    The default list for ``without_words`` is only read, never mutated.
    """
    positive_matches = {i for word in with_words for i in find_sentences_simple(word)}
    negative_matches = {i for word in without_words for i in find_sentences_simple(word)}
    return positive_matches - negative_matches
                           

In [None]:

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

import asyncio

class Timer:
    """One-shot asyncio timer: after ``start()``, waits ``timeout`` seconds and calls ``callback()``.

    ``start()`` schedules the wait with ``asyncio.ensure_future``.
    ``cancel()`` cancels the pending wait and is a no-op if the timer was
    never started.
    """

    def __init__(self, timeout, callback):
        self._timeout = timeout
        self._callback = callback
        # Set by start(); None means there is nothing to cancel yet.
        self._task = None

    async def _job(self):
        await asyncio.sleep(self._timeout)
        self._callback()

    def start(self):
        self._task = asyncio.ensure_future(self._job())

    def cancel(self):
        if self._task is not None:
            self._task.cancel()

def debounce(wait):
    """Decorator factory: delay calls to the wrapped function by ``wait`` seconds.

    Each call to the wrapped function cancels the previously scheduled Timer
    (if any) and schedules a new one, so only the most recent call in a burst
    actually runs. The wrapper itself always returns None.
    """
    def decorator(fn):
        pending = None  # Timer of the most recent, not yet executed call

        def debounced(*args, **kwargs):
            nonlocal pending
            if pending is not None:
                pending.cancel()
            pending = Timer(wait, lambda: fn(*args, **kwargs))
            pending.start()

        return debounced
    return decorator

def make_caption_edit_box(image_id: int):
    """Build a VBox holding a filename label above an editable caption textarea.

    The textarea starts with ``prompts[image_id]``, carries ``image_id`` as an
    attribute so the change handler can tell which caption was edited, and
    reports value changes to ``textarea_contents_did_change``.
    """
    caption_area = widgets.Textarea(
        value=prompts[image_id],
        layout={'width': '800px', 'height': '100px'}
    )
    caption_area.style.font_size = '1.5em'
    caption_area.image_id = image_id
    caption_area.observe(textarea_contents_did_change, names='value')

    filename_label = widgets.Label(image_filenames[image_id] + ":")
    filename_label.style.font_family = 'courier'

    return widgets.VBox([filename_label, caption_area])



@debounce(0.5)
def textarea_contents_did_change(observation):
    """Debounced change handler: store the textarea's new text as the image's caption.

    ``observation`` is the change dict delivered by the widget; the image id
    is read from the ``image_id`` attribute set on the owning textarea.
    """
    image_id = observation['owner'].image_id
    set_caption(image_id, observation['new'])
    print('stored caption for', image_id)

    
def count_ngrams_and_display_editboxes(ngram_length:int, word:str, max_count:int):
    """Show an accordion of the most frequent n-grams containing ``word``.

    There is one section per n-gram, titled with its count; n-grams that start
    with a comma are skipped. A section with at most ``max_count`` occurrences
    holds one caption edit box per prompt containing the n-gram. More frequent
    n-grams get an empty section.
    """
    vboxes = []
    titles = []
    for count, ngram in count_ngrams_containing_word(ngram_length, word):
        if ngram[0] == ',':
            continue
        titles.append(f"{count} captions contain {ngram}")
        if count > max_count:
            edit_boxes = []
        else:
            # Look the n-gram up with the same length it was built with; a
            # fixed length of 2 can never match n-grams of any other length.
            image_ids = find_prompts_containing_ngram(ngram_length, ngram)
            edit_boxes = [make_caption_edit_box(image_id) for image_id in image_ids]
        vboxes.append(widgets.VBox(edit_boxes))

    accordion = widgets.Accordion(vboxes, titles=titles)
    display(accordion)


# Hook count_ngrams_and_display_editboxes up to three controls via ipywidgets'
# interact: an n-gram length slider (2-10), a search-word text box, and the
# maximum count up to which caption edit boxes are generated.
interact(count_ngrams_and_display_editboxes, 
         ngram_length=widgets.IntSlider(value=2, min=2, max=10, description='Length'), 
         word=widgets.Text(value='', description='Word', placeholder='type search word here'), 
         max_count=widgets.IntText(value=5, description='Max count')
        )
                

# cluster


In [None]:
from nltk import cluster
from nltk.cluster import euclidean_distance
from numpy import array
import numpy

# NOTE(review): `all_ngrams_flat` was not assigned in any visible cell, so this
# cell failed with NameError on a fresh kernel. It is rebuilt here from
# `word_ids`, because the vectors must be numeric and the cells below map the
# n-gram entries back through word_ids_to_words(). Trigrams are assumed
# (get_sentence_ngrams' default) - confirm the intended n-gram length.
cluster_ngram_length = 3
all_ngrams_flat = list(itertools.chain.from_iterable(
    get_sentence_ngrams(sentence_ids, cluster_ngram_length) for sentence_ids in word_ids
))

# initialise the clusterer (will also assign the vectors to clusters)
vectors = [array(f) for f in all_ngrams_flat]

num_clusters = 10
clusterer = cluster.KMeansClusterer(num_clusters, euclidean_distance, avoid_empty_clusters=True)
clusters = clusterer.cluster(vectors, True)

In [None]:
import numpy
# Tally how many vectors were assigned to each cluster label.
unique, counts = numpy.unique(clusters, return_counts=True)

print({label: size for label, size in zip(unique, counts)})



#for cluster_index in range(0, num_clusters):
#    count = numpy.

#print(vectors[0],vectors[1])
#diff = vectors[1]-vectors[0]
#numpy.dot(diff, diff)

In [None]:

def get_cluster_contents(which_cluster):
    """Return ``{ngram: occurrences}`` for the n-grams assigned to ``which_cluster``.

    ``clusters[i]`` is the cluster label of ``all_ngrams_flat[i]``; keys
    appear in the order the n-grams are first encountered.
    """
    contents = {}
    for ngram_index, assigned_cluster in enumerate(clusters):
        if assigned_cluster == which_cluster:
            ngram = all_ngrams_flat[ngram_index]
            contents[ngram] = contents.get(ngram, 0) + 1
    return contents


#for cluster_index in range(num_clusters):
#    this_cluster_contents = []
#    for ngram_index, ngram in enumerate(all_ngrams_flat):
#        print(f"ngram {ngram} is in cluster {clusters[ngram_index]}")
#        break

In [None]:

# sort by occurrence count
def get_clustered_ngrams_sorted_by_count(cluster_index: int):
    """Return ``(words, count)`` pairs for one cluster, most frequent n-gram first.

    The n-grams come from get_cluster_contents(); each one is translated back
    to its words with word_ids_to_words(). N-grams with equal counts keep the
    order in which get_cluster_contents() returned them.
    """
    counts_by_ngram = get_cluster_contents(cluster_index)
    ranked = sorted(counts_by_ngram.items(), key=lambda pair: pair[1], reverse=True)
    result = []
    for ngram_ids, count in ranked:
        result.append((word_ids_to_words(list(ngram_ids)), count))
    return result



