In [None]:
# Optional: load the lab_black IPython extension when it is installed.
# A missing module is deliberately ignored so the notebook still runs without it.
try:
    %load_ext lab_black
except ModuleNotFoundError:
    pass

In [None]:
import functools
import html
import json
import pickle
import random
import textwrap
import webbrowser

import numpy as np
import scipy.cluster.hierarchy as shc
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
from ipywidgets import Button, HBox, VBox, Output, Layout, Image, Checkbox, HTML
from ipyevents import Event

from ipywidgets import HTML

# from scipy.cluster.hierarchy import cut_tree, to_tree, leaves_list

# Apply matplotlib's "dark_background" style sheet to figures drawn from here on.
plt.style.use("dark_background")

In [None]:
# Base name of the dataset to explore. The loading cell reads
# ../data/{source}-embeddings.pickle and ../data/{source}.json.
source = "Lightbulb-top-1000-results"
# Alternative dataset (swap which line is commented out to switch):
# source = "SomebodyMakeThis-top-1000-results"

In [None]:
# Precomputed embeddings, stored as a mapping from post URL to embedding.
# SECURITY: pickle.load can execute arbitrary code while deserialising, so only
# point this at pickle files you produced yourself or otherwise trust.
with open(f"../data/{source}-embeddings.pickle", "rb") as file:
    url_to_embedding = pickle.load(file)

# Post metadata. The encoding is given explicitly because JSON is UTF-8 and the
# platform default (e.g. cp1252 on Windows) would fail on or garble non-ASCII
# titles. Only the list stored under the top-level "data" key is kept.
with open(f"../data/{source}.json", encoding="utf-8") as file:
    data = json.load(file)["data"]

In [None]:
# Lookup tables keyed by post URL. If a URL occurs more than once in `data`,
# the entry that appears last wins.
url_to_score = {post["url"]: post["score"] for post in data}
url_to_title = {post["url"]: post["title"] for post in data}

In [None]:
# Split the mapping into two parallel tuples: urls[i] belongs to embeddings[i].
# (Unpacking raises ValueError if url_to_embedding is empty.)
urls, embeddings = zip(*url_to_embedding.items())
# Converted to a NumPy array; it is indexed by integer leaf id further down
# (urls[leaf.id]) to map cluster leaves back to their URLs.
urls = np.array(urls)

In [None]:
# Hierarchical clustering of the embeddings using Ward linkage. Row order
# follows `embeddings`, so observation i corresponds to urls[i].
linkage = shc.linkage(embeddings, method="ward")

# Uncomment to plot the resulting dendrogram:
# plt.figure(figsize=(10, 7))
# dend = shc.dendrogram(linkage)

In [None]:
def split_into_n_children(tree, n):
    """Break ``tree`` into ``n`` subtrees by repeatedly halving the largest one.

    Starting from ``[tree]``, the node with the highest ``count`` (the earliest
    such node on ties) is replaced, in place, by its ``left`` and ``right``
    children until ``n`` nodes are held. For ``n <= 1`` the result is simply
    ``[tree]``.

    Raises ValueError when the largest remaining node is a leaf, i.e. the tree
    cannot be further divided.
    """
    parts = [tree]

    while len(parts) < n:
        # max() returns the first maximal position, so ties go to the earliest node.
        biggest_at = max(range(len(parts)), key=lambda k: parts[k].count)
        biggest = parts[biggest_at]
        if biggest.is_leaf():
            raise ValueError("tree cannot be further divided")
        # Swap the node for its two halves, keeping the left-to-right order.
        parts[biggest_at : biggest_at + 1] = [biggest.left, biggest.right]
    return parts


# from IPython.display import Javascript


# def window_open(url):
#     display(Javascript('window.open("{url}");'.format(url=url)))


# def window_open(_, url):
#     webbrowser.open(url)
#     # alternative is to use Javascript https://stackoverflow.com/a/61900572/11756613
#     # and it works even when jupyter is remote
#     # but here, when called by an event, it's broken for some reason

In [None]:
class TreeCrawler:
    """Tracks the current position while walking up and down a cluster tree.

    The crawler holds the currently selected node (``tree``), the stack of
    ancestors visited on the way there (``path``) and the current node's split
    into ``num_of_columns`` sub-clusters (``children``). After every successful
    split, ``display_updater`` is called so the UI can redraw.
    """

    def __init__(self, num_of_columns, message_output, display_updater):
        """
        Args:
            num_of_columns: how many sub-clusters each node is split into.
            message_output: output-widget-like object (usable as a context
                manager and offering ``clear_output()``) that status messages
                are printed into.
            display_updater: one-argument callable invoked after each
                successful split; it is always passed ``None``.
        """
        self.num_of_columns = num_of_columns
        self.message_output = message_output
        self.display_updater = display_updater
        # Defined up front so event handlers cannot hit a missing attribute
        # (e.g. when reset() was given a tree too small to split).
        self.tree = None
        self.path = []
        self.children = []

    def _show_message(self, text):
        """Replace the contents of the message area with ``text``."""
        with self.message_output:
            self.message_output.clear_output()
            print(text)

    def reset(self, tree):
        """Start over from ``tree`` as the top-level cluster."""
        self.tree = tree
        self.path = []
        self.update()

    def choose_column(self, _, i):
        """Descend into the ``i``-th displayed sub-cluster.

        The first argument (the event payload) is ignored. If the chosen
        sub-cluster is too small to be split again, the position is left
        unchanged and a message is shown. Checking this *before* moving keeps
        ``tree``, ``path`` and ``children`` consistent with each other.
        """
        if i >= len(self.children) or self.children[i].count < self.num_of_columns:
            self._show_message("already on the lowest cluster")
            return
        self.path.append(self.tree)
        self.tree = self.children[i]
        self.message_output.clear_output()
        self.update()

    def go_back(self, _):
        """Return to the parent cluster; the argument (the button) is ignored."""
        if not self.path:
            self._show_message("already on the highest cluster")
            return
        self.tree = self.path.pop()
        # Drop any message left over from a previous action.
        self.message_output.clear_output()
        self.update()

    def update(self):
        """Split the current node and ask the UI to redraw.

        Shows a message and leaves ``children`` untouched when the current
        node has fewer than ``num_of_columns`` members.
        """
        if self.tree.count < self.num_of_columns:
            self._show_message("already on the lowest cluster")
            return

        self.children = split_into_n_children(self.tree, n=self.num_of_columns)
        # The callback's argument is a placeholder; pass None explicitly rather
        # than depending on whatever a global `_` happens to hold.
        self.display_updater(None)


class PostsWall:
    """Interactive wall of post titles with one column per sub-cluster.

    Each column lists the highest-scoring posts of one sub-cluster of the
    current tree node. An ``auxclick`` on a column descends into that
    sub-cluster and the "Go back" button returns to the parent. Scores and
    titles are read from the module-level ``url_to_score`` / ``url_to_title``
    mappings.

    Display ``whole_output`` to show the widget.
    """

    def __init__(
        self,
        tree,
        num_of_columns=3,
        posts_in_column=20,
        text_width=40,
        width=1000,
    ):
        """
        Args:
            tree: root node of the cluster tree to browse.
            num_of_columns: number of sub-clusters (and columns) shown at once.
            posts_in_column: maximum number of posts listed per column.
            text_width: titles are wrapped to this many characters per line.
            width: total width of the wall in pixels, shared evenly by the columns.
        """
        self.num_of_columns = num_of_columns
        self.posts_in_column = posts_in_column
        self.text_width = text_width

        self.message_output = Output()
        self.tree_crawler = TreeCrawler(
            num_of_columns, self.message_output, self.update_displayed_posts
        )
        # Not used yet; kept for the sampling idea noted in update_displayed_posts.
        self.random_seed = random.randint(0, 1000000)

        column_width = width / self.num_of_columns
        layout = Layout(width=f"{column_width}px")
        self.columns = [Output(layout=layout) for _ in range(num_of_columns)]
        go_back_button = Button(description="Go back")
        go_back_button.on_click(self.tree_crawler.go_back)
        self.whole_output = VBox(
            [
                go_back_button,
                self.message_output,
                HBox(self.columns),
            ]
        )

        # Bind "auxclick" (a non-primary mouse button, e.g. middle click) on
        # each column to choose_column. A plain click stays free for the links.
        for i, output in enumerate(self.columns):
            event = Event(source=output, watched_events=["auxclick"])
            func = functools.partial(self.tree_crawler.choose_column, i=i)
            event.on_dom_event(func)

        # Triggers the first draw through the crawler's display_updater callback.
        self.tree_crawler.reset(tree)

    def _post_link_html(self, url):
        """Return the HTML anchor for one post, with URL and title escaped."""
        # NOTE(review): titles may already contain HTML entities (Reddit's API
        # encodes them by default - confirm for this dataset). Decoding first and
        # escaping afterwards renders both raw and pre-encoded titles correctly
        # and never lets markup from a title reach the page.
        raw_title = html.unescape(url_to_title[url])
        lines = textwrap.wrap(raw_title, width=self.text_width)
        # Escape after wrapping so entities do not count towards the line width.
        title = "<br>".join(html.escape(line) for line in lines)
        href = html.escape(url, quote=True)
        return (
            f'<a href="{href}" style="color:#EEEEEE;" target="_blank">'
            f"<b>{title}<br><br></b></a>"
        )

    def update_displayed_posts(self, _):
        """Redraw every column with the top posts of its sub-cluster.

        The argument is ignored; it exists so the method can serve as a callback.
        """
        for child, column in zip(self.tree_crawler.children, self.columns):
            # Ids of the leaves below this node. They are used as keys into
            # url_to_score, so they are expected to be post URLs already.
            cluster_urls = child.pre_order()

            # Highest score first; equal scores fall back to the URL, descending.
            ranked = sorted(
                ((url_to_score[url], url) for url in cluster_urls), reverse=True
            )
            best_urls = [url for _score, url in ranked[: self.posts_in_column]]

            # TODO: optionally show a seeded random sample (self.random_seed) of
            # the top-ranked posts instead of always the very best ones; the
            # cutoff values for that could be parametrized.

            with column:
                column.clear_output(wait=True)
                print("total posts: ", len(cluster_urls))

                for url in best_urls:
                    display(HTML(self._post_link_html(url)))

In [None]:
# Convert the linkage matrix into a tree of linked node objects and keep its root.
tree = shc.to_tree(linkage)


def substitute_url(leaf):
    """Replace the node's integer id with the URL stored at that index of ``urls``."""
    leaf.id = urls[leaf.id]


# pre_order() applies the function to the tree's leaf nodes (per SciPy's
# ClusterNode documentation), so afterwards every leaf carries its post URL as id.
# The collected return values are discarded by assigning them to `_`.
_ = tree.pre_order(substitute_url)

In [None]:
# Build the interactive wall for the whole tree. The bare expression on the last
# line makes the widget this cell's displayed output.
posts_wall = PostsWall(tree)
posts_wall.whole_output